forked from Github-Mirrors/canaille
Authorization code flow unit tests
This commit is contained in:
parent
43e230750a
commit
61f941c319
8 changed files with 95 additions and 30 deletions
|
@ -10,6 +10,7 @@ skipsdist=True
|
||||||
install_command = pip install {packages}
|
install_command = pip install {packages}
|
||||||
commands = {envbindir}/pytest --showlocals --full-trace {posargs}
|
commands = {envbindir}/pytest --showlocals --full-trace {posargs}
|
||||||
deps =
|
deps =
|
||||||
|
flask-webtest
|
||||||
pytest
|
pytest
|
||||||
pdbpp
|
pdbpp
|
||||||
--requirement requirements.txt
|
--requirement requirements.txt
|
||||||
|
@ -17,6 +18,7 @@ deps =
|
||||||
[testenv:coverage]
|
[testenv:coverage]
|
||||||
skip_install = true
|
skip_install = true
|
||||||
deps =
|
deps =
|
||||||
|
flask-webtest
|
||||||
pdbpp
|
pdbpp
|
||||||
pytest
|
pytest
|
||||||
pytest-coverage
|
pytest-coverage
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
import base64
|
||||||
|
|
||||||
|
|
||||||
|
def client_credentials(client):
|
||||||
|
return base64.b64encode(
|
||||||
|
client.oauthClientID.encode("utf-8")
|
||||||
|
+ b":"
|
||||||
|
+ client.oauthClientSecret.encode("utf-8")
|
||||||
|
).decode("utf-8")
|
|
@ -3,6 +3,7 @@ import ldap.ldapobject
|
||||||
import os
|
import os
|
||||||
import pytest
|
import pytest
|
||||||
import slapdtest
|
import slapdtest
|
||||||
|
from flask_webtest import TestApp
|
||||||
from werkzeug.security import gen_salt
|
from werkzeug.security import gen_salt
|
||||||
from web import create_app
|
from web import create_app
|
||||||
from web.models import User, Client, Token, AuthorizationCode
|
from web.models import User, Client, Token, AuthorizationCode
|
||||||
|
@ -50,6 +51,15 @@ def slapd_server():
|
||||||
+ "\n"
|
+ "\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
conn = ldap.ldapobject.SimpleLDAPObject(slapd.ldap_uri)
|
||||||
|
conn.simple_bind_s(slapd.root_dn, slapd.root_pw)
|
||||||
|
LDAPObjectHelper.root_dn = slapd.suffix
|
||||||
|
Client.initialize(conn)
|
||||||
|
User.initialize(conn)
|
||||||
|
Token.initialize(conn)
|
||||||
|
AuthorizationCode.initialize(conn)
|
||||||
|
conn.unbind_s()
|
||||||
|
|
||||||
yield slapd
|
yield slapd
|
||||||
finally:
|
finally:
|
||||||
slapd.stop()
|
slapd.stop()
|
||||||
|
@ -65,23 +75,18 @@ def slapd_connection(slapd_server):
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def app(slapd_server, slapd_connection):
|
def app(slapd_server):
|
||||||
os.environ["AUTHLIB_INSECURE_TRANSPORT"] = "true"
|
os.environ["AUTHLIB_INSECURE_TRANSPORT"] = "true"
|
||||||
|
|
||||||
LDAPObjectHelper.root_dn = slapd_server.suffix
|
|
||||||
Client.initialize(slapd_connection)
|
|
||||||
User.initialize(slapd_connection)
|
|
||||||
Token.initialize(slapd_connection)
|
|
||||||
AuthorizationCode.initialize(slapd_connection)
|
|
||||||
|
|
||||||
app = create_app(
|
app = create_app(
|
||||||
{
|
{
|
||||||
|
"SECRET_KEY": gen_salt(24),
|
||||||
"LDAP": {
|
"LDAP": {
|
||||||
"ROOT_DN": slapd_server.suffix,
|
"ROOT_DN": slapd_server.suffix,
|
||||||
"URI": slapd_server.ldap_uri,
|
"URI": slapd_server.ldap_uri,
|
||||||
"BIND_DN": slapd_server.root_dn,
|
"BIND_DN": slapd_server.root_dn,
|
||||||
"BIND_PW": slapd_server.root_pw,
|
"BIND_PW": slapd_server.root_pw,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return app
|
return app
|
||||||
|
@ -90,9 +95,7 @@ def app(slapd_server, slapd_connection):
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def testclient(app):
|
def testclient(app):
|
||||||
app.config["TESTING"] = True
|
app.config["TESTING"] = True
|
||||||
|
return TestApp(app)
|
||||||
with app.test_client() as client:
|
|
||||||
yield client
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -118,6 +121,7 @@ def client(app, slapd_connection):
|
||||||
oauthTokenEndpointAuthMethod="client_secret_basic",
|
oauthTokenEndpointAuthMethod="client_secret_basic",
|
||||||
)
|
)
|
||||||
c.save(slapd_connection)
|
c.save(slapd_connection)
|
||||||
|
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
|
53
tests/test_authorization_code_flow.py
Normal file
53
tests/test_authorization_code_flow.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
from . import client_credentials
|
||||||
|
from urllib.parse import urlsplit, parse_qs
|
||||||
|
from web.models import AuthorizationCode, Token
|
||||||
|
|
||||||
|
|
||||||
|
def test_success(testclient, slapd_connection, user, client):
|
||||||
|
res = testclient.get(
|
||||||
|
"/oauth/authorize",
|
||||||
|
params=dict(
|
||||||
|
response_type="code",
|
||||||
|
client_id=client.oauthClientID,
|
||||||
|
scope="profile",
|
||||||
|
nonce="somenonce",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert 200 == res.status_code
|
||||||
|
|
||||||
|
res.form["login"] = user.name
|
||||||
|
res.form["password"] = "valid"
|
||||||
|
res = res.form.submit()
|
||||||
|
assert 302 == res.status_code
|
||||||
|
|
||||||
|
res = res.follow()
|
||||||
|
assert 200 == res.status_code
|
||||||
|
|
||||||
|
res = res.forms["accept"].submit()
|
||||||
|
assert 302 == res.status_code
|
||||||
|
|
||||||
|
assert res.location.startswith(client.oauthRedirectURIs[0])
|
||||||
|
params = parse_qs(urlsplit(res.location).query)
|
||||||
|
code = params["code"][0]
|
||||||
|
authcode = AuthorizationCode.get(code, slapd_connection)
|
||||||
|
assert authcode is not None
|
||||||
|
|
||||||
|
res = testclient.post(
|
||||||
|
"/oauth/token",
|
||||||
|
params=dict(
|
||||||
|
grant_type="authorization_code",
|
||||||
|
code=code,
|
||||||
|
scope="profile",
|
||||||
|
redirect_uri=client.oauthRedirectURIs[0],
|
||||||
|
),
|
||||||
|
headers={"Authorization": f"Basic {client_credentials(client)}"},
|
||||||
|
)
|
||||||
|
assert 200 == res.status_code
|
||||||
|
access_token = res.json["access_token"]
|
||||||
|
|
||||||
|
token = Token.get(access_token, slapd_connection)
|
||||||
|
assert token is not None
|
||||||
|
|
||||||
|
res = testclient.get("/api/me", headers={"Authorization": f"Bearer {access_token}"})
|
||||||
|
assert 200 == res.status_code
|
||||||
|
assert {"foo": "bar"} == res.json
|
|
@ -1,22 +1,17 @@
|
||||||
import base64
|
from . import client_credentials
|
||||||
|
from web.models import Token
|
||||||
|
|
||||||
|
|
||||||
def test_success(testclient, user, client):
|
def test_success(testclient, slapd_connection, user, client):
|
||||||
client_credentials = base64.b64encode(
|
|
||||||
client.oauthClientID.encode("utf-8")
|
|
||||||
+ b":"
|
|
||||||
+ client.oauthClientSecret.encode("utf-8")
|
|
||||||
).decode("utf-8")
|
|
||||||
|
|
||||||
res = testclient.post(
|
res = testclient.post(
|
||||||
"/oauth/token",
|
"/oauth/token",
|
||||||
data=dict(
|
params=dict(
|
||||||
grant_type="password",
|
grant_type="password",
|
||||||
username=user.name,
|
username=user.name,
|
||||||
password="valid",
|
password="valid",
|
||||||
scope="profile",
|
scope="profile",
|
||||||
),
|
),
|
||||||
headers={"Authorization": f"Basic {client_credentials}"},
|
headers={"Authorization": f"Basic {client_credentials(client)}"},
|
||||||
)
|
)
|
||||||
assert 200 == res.status_code
|
assert 200 == res.status_code
|
||||||
|
|
||||||
|
@ -24,8 +19,9 @@ def test_success(testclient, user, client):
|
||||||
assert res.json["token_type"] == "Bearer"
|
assert res.json["token_type"] == "Bearer"
|
||||||
access_token = res.json["access_token"]
|
access_token = res.json["access_token"]
|
||||||
|
|
||||||
res = testclient.get(
|
token = Token.get(access_token, slapd_connection)
|
||||||
"/api/me", headers={"Authorization": f"Bearer {access_token}"}
|
assert token is not None
|
||||||
)
|
|
||||||
|
res = testclient.get("/api/me", headers={"Authorization": f"Bearer {access_token}"})
|
||||||
assert 200 == res.status_code
|
assert 200 == res.status_code
|
||||||
assert {"foo": "bar"} == res.json
|
assert {"foo": "bar"} == res.json
|
||||||
|
|
|
@ -133,10 +133,11 @@ class LDAPObjectHelper:
|
||||||
conn.add_s(self.dn, attributes)
|
conn.add_s(self.dn, attributes)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls, dn):
|
def get(cls, dn, conn=None):
|
||||||
|
conn = conn or cls.ldap()
|
||||||
if "=" not in dn:
|
if "=" not in dn:
|
||||||
dn = f"{cls.id}={dn},{cls.base},{cls.root_dn}"
|
dn = f"{cls.id}={dn},{cls.base},{cls.root_dn}"
|
||||||
result = cls.ldap().search_s(dn, ldap.SCOPE_SUBTREE)
|
result = conn.search_s(dn, ldap.SCOPE_SUBTREE)
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
return None
|
return None
|
||||||
|
|
|
@ -53,7 +53,7 @@ class AuthorizationCodeGrant(_AuthorizationCodeGrant):
|
||||||
def create_authorization_code(self, client, grant_user, request):
|
def create_authorization_code(self, client, grant_user, request):
|
||||||
return create_authorization_code(client, grant_user, request)
|
return create_authorization_code(client, grant_user, request)
|
||||||
|
|
||||||
def parse_authorization_code(self, code, client):
|
def query_authorization_code(self, code, client):
|
||||||
item = AuthorizationCode.filter(
|
item = AuthorizationCode.filter(
|
||||||
oauthCode=code, oauthClientID=client.oauthClientID
|
oauthCode=code, oauthClientID=client.oauthClientID
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,17 +14,17 @@
|
||||||
<p>{{ gettext('from: %(user)s', user=user.name) }}</p>
|
<p>{{ gettext('from: %(user)s', user=user.name) }}</p>
|
||||||
|
|
||||||
<div class="ui buttons">
|
<div class="ui buttons">
|
||||||
<form action="{{ request.url }}" method="post">
|
<form action="{{ request.url }}" method="post" id="deny">
|
||||||
<input type="hidden" name="answer" value="deny" />
|
<input type="hidden" name="answer" value="deny" />
|
||||||
<input type="submit" class="ui negative button" value="{% trans %}Deny{% endtrans %}" />
|
<input type="submit" class="ui negative button" value="{% trans %}Deny{% endtrans %}" />
|
||||||
</form>
|
</form>
|
||||||
|
|
||||||
<form action="{{ request.url }}" method="post">
|
<form action="{{ request.url }}" method="post" id="logout">
|
||||||
<input type="hidden" name="answer" value="logout" />
|
<input type="hidden" name="answer" value="logout" />
|
||||||
<input type="submit" class="ui button" value="{% trans %}Switch user{% endtrans %}" />
|
<input type="submit" class="ui button" value="{% trans %}Switch user{% endtrans %}" />
|
||||||
</form>
|
</form>
|
||||||
|
|
||||||
<form action="{{ request.url }}" method="post">
|
<form action="{{ request.url }}" method="post" id="accept">
|
||||||
<input type="hidden" name="answer" value="accept" />
|
<input type="hidden" name="answer" value="accept" />
|
||||||
<input type="submit" class="ui positive button" value="{% trans %}Accept{% endtrans %}" />
|
<input type="submit" class="ui positive button" value="{% trans %}Accept{% endtrans %}" />
|
||||||
</form>
|
</form>
|
||||||
|
|
Loading…
Reference in a new issue