forked from Github-Mirrors/canaille
Authorization code flow unit tests
This commit is contained in:
parent
43e230750a
commit
61f941c319
8 changed files with 95 additions and 30 deletions
|
@ -10,6 +10,7 @@ skipsdist=True
|
|||
install_command = pip install {packages}
|
||||
commands = {envbindir}/pytest --showlocals --full-trace {posargs}
|
||||
deps =
|
||||
flask-webtest
|
||||
pytest
|
||||
pdbpp
|
||||
--requirement requirements.txt
|
||||
|
@ -17,6 +18,7 @@ deps =
|
|||
[testenv:coverage]
|
||||
skip_install = true
|
||||
deps =
|
||||
flask-webtest
|
||||
pdbpp
|
||||
pytest
|
||||
pytest-coverage
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
import base64
|
||||
|
||||
|
||||
def client_credentials(client):
|
||||
return base64.b64encode(
|
||||
client.oauthClientID.encode("utf-8")
|
||||
+ b":"
|
||||
+ client.oauthClientSecret.encode("utf-8")
|
||||
).decode("utf-8")
|
|
@ -3,6 +3,7 @@ import ldap.ldapobject
|
|||
import os
|
||||
import pytest
|
||||
import slapdtest
|
||||
from flask_webtest import TestApp
|
||||
from werkzeug.security import gen_salt
|
||||
from web import create_app
|
||||
from web.models import User, Client, Token, AuthorizationCode
|
||||
|
@ -50,6 +51,15 @@ def slapd_server():
|
|||
+ "\n"
|
||||
)
|
||||
|
||||
conn = ldap.ldapobject.SimpleLDAPObject(slapd.ldap_uri)
|
||||
conn.simple_bind_s(slapd.root_dn, slapd.root_pw)
|
||||
LDAPObjectHelper.root_dn = slapd.suffix
|
||||
Client.initialize(conn)
|
||||
User.initialize(conn)
|
||||
Token.initialize(conn)
|
||||
AuthorizationCode.initialize(conn)
|
||||
conn.unbind_s()
|
||||
|
||||
yield slapd
|
||||
finally:
|
||||
slapd.stop()
|
||||
|
@ -65,23 +75,18 @@ def slapd_connection(slapd_server):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def app(slapd_server, slapd_connection):
|
||||
def app(slapd_server):
|
||||
os.environ["AUTHLIB_INSECURE_TRANSPORT"] = "true"
|
||||
|
||||
LDAPObjectHelper.root_dn = slapd_server.suffix
|
||||
Client.initialize(slapd_connection)
|
||||
User.initialize(slapd_connection)
|
||||
Token.initialize(slapd_connection)
|
||||
AuthorizationCode.initialize(slapd_connection)
|
||||
|
||||
app = create_app(
|
||||
{
|
||||
"SECRET_KEY": gen_salt(24),
|
||||
"LDAP": {
|
||||
"ROOT_DN": slapd_server.suffix,
|
||||
"URI": slapd_server.ldap_uri,
|
||||
"BIND_DN": slapd_server.root_dn,
|
||||
"BIND_PW": slapd_server.root_pw,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
return app
|
||||
|
@ -90,9 +95,7 @@ def app(slapd_server, slapd_connection):
|
|||
@pytest.fixture
|
||||
def testclient(app):
|
||||
app.config["TESTING"] = True
|
||||
|
||||
with app.test_client() as client:
|
||||
yield client
|
||||
return TestApp(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -118,6 +121,7 @@ def client(app, slapd_connection):
|
|||
oauthTokenEndpointAuthMethod="client_secret_basic",
|
||||
)
|
||||
c.save(slapd_connection)
|
||||
|
||||
return c
|
||||
|
||||
|
||||
|
|
53
tests/test_authorization_code_flow.py
Normal file
53
tests/test_authorization_code_flow.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
from . import client_credentials
|
||||
from urllib.parse import urlsplit, parse_qs
|
||||
from web.models import AuthorizationCode, Token
|
||||
|
||||
|
||||
def test_success(testclient, slapd_connection, user, client):
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
params=dict(
|
||||
response_type="code",
|
||||
client_id=client.oauthClientID,
|
||||
scope="profile",
|
||||
nonce="somenonce",
|
||||
),
|
||||
)
|
||||
assert 200 == res.status_code
|
||||
|
||||
res.form["login"] = user.name
|
||||
res.form["password"] = "valid"
|
||||
res = res.form.submit()
|
||||
assert 302 == res.status_code
|
||||
|
||||
res = res.follow()
|
||||
assert 200 == res.status_code
|
||||
|
||||
res = res.forms["accept"].submit()
|
||||
assert 302 == res.status_code
|
||||
|
||||
assert res.location.startswith(client.oauthRedirectURIs[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code, slapd_connection)
|
||||
assert authcode is not None
|
||||
|
||||
res = testclient.post(
|
||||
"/oauth/token",
|
||||
params=dict(
|
||||
grant_type="authorization_code",
|
||||
code=code,
|
||||
scope="profile",
|
||||
redirect_uri=client.oauthRedirectURIs[0],
|
||||
),
|
||||
headers={"Authorization": f"Basic {client_credentials(client)}"},
|
||||
)
|
||||
assert 200 == res.status_code
|
||||
access_token = res.json["access_token"]
|
||||
|
||||
token = Token.get(access_token, slapd_connection)
|
||||
assert token is not None
|
||||
|
||||
res = testclient.get("/api/me", headers={"Authorization": f"Bearer {access_token}"})
|
||||
assert 200 == res.status_code
|
||||
assert {"foo": "bar"} == res.json
|
|
@ -1,22 +1,17 @@
|
|||
import base64
|
||||
from . import client_credentials
|
||||
from web.models import Token
|
||||
|
||||
|
||||
def test_success(testclient, user, client):
|
||||
client_credentials = base64.b64encode(
|
||||
client.oauthClientID.encode("utf-8")
|
||||
+ b":"
|
||||
+ client.oauthClientSecret.encode("utf-8")
|
||||
).decode("utf-8")
|
||||
|
||||
def test_success(testclient, slapd_connection, user, client):
|
||||
res = testclient.post(
|
||||
"/oauth/token",
|
||||
data=dict(
|
||||
params=dict(
|
||||
grant_type="password",
|
||||
username=user.name,
|
||||
password="valid",
|
||||
scope="profile",
|
||||
),
|
||||
headers={"Authorization": f"Basic {client_credentials}"},
|
||||
headers={"Authorization": f"Basic {client_credentials(client)}"},
|
||||
)
|
||||
assert 200 == res.status_code
|
||||
|
||||
|
@ -24,8 +19,9 @@ def test_success(testclient, user, client):
|
|||
assert res.json["token_type"] == "Bearer"
|
||||
access_token = res.json["access_token"]
|
||||
|
||||
res = testclient.get(
|
||||
"/api/me", headers={"Authorization": f"Bearer {access_token}"}
|
||||
)
|
||||
token = Token.get(access_token, slapd_connection)
|
||||
assert token is not None
|
||||
|
||||
res = testclient.get("/api/me", headers={"Authorization": f"Bearer {access_token}"})
|
||||
assert 200 == res.status_code
|
||||
assert {"foo": "bar"} == res.json
|
||||
|
|
|
@ -133,10 +133,11 @@ class LDAPObjectHelper:
|
|||
conn.add_s(self.dn, attributes)
|
||||
|
||||
@classmethod
|
||||
def get(cls, dn):
|
||||
def get(cls, dn, conn=None):
|
||||
conn = conn or cls.ldap()
|
||||
if "=" not in dn:
|
||||
dn = f"{cls.id}={dn},{cls.base},{cls.root_dn}"
|
||||
result = cls.ldap().search_s(dn, ldap.SCOPE_SUBTREE)
|
||||
result = conn.search_s(dn, ldap.SCOPE_SUBTREE)
|
||||
|
||||
if not result:
|
||||
return None
|
||||
|
|
|
@ -53,7 +53,7 @@ class AuthorizationCodeGrant(_AuthorizationCodeGrant):
|
|||
def create_authorization_code(self, client, grant_user, request):
|
||||
return create_authorization_code(client, grant_user, request)
|
||||
|
||||
def parse_authorization_code(self, code, client):
|
||||
def query_authorization_code(self, code, client):
|
||||
item = AuthorizationCode.filter(
|
||||
oauthCode=code, oauthClientID=client.oauthClientID
|
||||
)
|
||||
|
|
|
@ -14,17 +14,17 @@
|
|||
<p>{{ gettext('from: %(user)s', user=user.name) }}</p>
|
||||
|
||||
<div class="ui buttons">
|
||||
<form action="{{ request.url }}" method="post">
|
||||
<form action="{{ request.url }}" method="post" id="deny">
|
||||
<input type="hidden" name="answer" value="deny" />
|
||||
<input type="submit" class="ui negative button" value="{% trans %}Deny{% endtrans %}" />
|
||||
</form>
|
||||
|
||||
<form action="{{ request.url }}" method="post">
|
||||
<form action="{{ request.url }}" method="post" id="logout">
|
||||
<input type="hidden" name="answer" value="logout" />
|
||||
<input type="submit" class="ui button" value="{% trans %}Switch user{% endtrans %}" />
|
||||
</form>
|
||||
|
||||
<form action="{{ request.url }}" method="post">
|
||||
<form action="{{ request.url }}" method="post" id="accept">
|
||||
<input type="hidden" name="answer" value="accept" />
|
||||
<input type="submit" class="ui positive button" value="{% trans %}Accept{% endtrans %}" />
|
||||
</form>
|
||||
|
|
Loading…
Reference in a new issue