From 582ac90dab30eecc7e88754a25dc86fadecbfbda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89loi=20Rivard?= Date: Wed, 13 Oct 2021 11:52:02 +0200 Subject: [PATCH] tokens can have multiple audiences --- canaille/admin/clients.py | 26 ++++++++-- canaille/models.py | 1 - canaille/oauth2utils.py | 17 ++++++- schemas/oauth2-openldap.ldif | 17 +++++-- schemas/oauth2-openldap.schema | 17 +++++-- tests/conftest.py | 39 ++++++++++++++- tests/test_authorization_code_flow.py | 13 +++-- tests/test_client_admin.py | 15 ++++-- tests/test_hybrid_flow.py | 6 ++- tests/test_implicit_flow.py | 10 ++-- tests/test_token_introspection.py | 69 ++++++++++++++++++++++++++- 11 files changed, 200 insertions(+), 30 deletions(-) diff --git a/canaille/admin/clients.py b/canaille/admin/clients.py index 9d1f5cf9..221b5e2a 100644 --- a/canaille/admin/clients.py +++ b/canaille/admin/clients.py @@ -18,6 +18,12 @@ def index(user): return render_template("admin/client_list.html", clients=clients, menuitem="admin") +def client_audiences(): + return [ + (client.dn, client.oauthClientName) for client in Client.filter() + ] + + class ClientAdd(FlaskForm): oauthClientName = wtforms.StringField( _("Name"), @@ -73,6 +79,12 @@ class ClientAdd(FlaskForm): ], default="client_secret_basic", ) + oauthAudience = wtforms.SelectMultipleField( + _("Token audiences"), + validators=[wtforms.validators.Optional()], + choices=client_audiences, + validate_choice=False, + ) oauthLogoURI = wtforms.URLField( _("Logo URI"), validators=[wtforms.validators.Optional()], @@ -120,7 +132,8 @@ def add(user): if not form.validate(): flash( - _("The client has not been added. Please check your information."), "error", + _("The client has not been added. Please check your information."), + "error", ) return render_template("admin/client_add.html", form=form, menuitem="admin") @@ -148,9 +161,11 @@ def add(user): if form["oauthTokenEndpointAuthMethod"].data == "none" else gen_salt(48), ) + client.oauthAudience = [client.dn] client.save() flash( - _("The client has been created."), "success", + _("The client has been created."), + "success", ) return redirect(url_for("admin_clients.edit", client_id=client_id)) @@ -203,10 +218,12 @@ def client_edit(client_id): oauthSoftwareVersion=form["oauthSoftwareVersion"].data, oauthJWK=form["oauthJWK"].data, oauthJWKURI=form["oauthJWKURI"].data, + oauthAudience=form["oauthAudience"].data, ) client.save() flash( - _("The client has been edited."), "success", + _("The client has been edited."), + "success", ) return render_template( @@ -217,7 +234,8 @@ def client_edit(client_id): def client_delete(client_id): client = Client.get(client_id) or abort(404) flash( - _("The client has been deleted."), "success", + _("The client has been deleted."), + "success", ) client.delete() return redirect(url_for("admin_clients.index")) diff --git a/canaille/models.py b/canaille/models.py index 5924e9b3..a47191d6 100644 --- a/canaille/models.py +++ b/canaille/models.py @@ -1,7 +1,6 @@ import datetime import ldap import uuid -from authlib.common.encoding import json_loads, json_dumps from authlib.oauth2.rfc6749 import ( ClientMixin, TokenMixin, diff --git a/canaille/oauth2utils.py b/canaille/oauth2utils.py index 8d186c3c..ec87c024 100644 --- a/canaille/oauth2utils.py +++ b/canaille/oauth2utils.py @@ -125,6 +125,10 @@ class OpenIDCode(_OpenIDCode): def generate_user_info(self, user, scope): return generate_user_info(user, scope) + def get_audiences(self, request): + client = request.client + return [Client.get(aud).oauthClientID for aud in client.oauthAudience] + class PasswordGrant(_ResourceOwnerPasswordCredentialsGrant): def authenticate_user(self, username, password): @@ -157,6 +161,10 @@ class OpenIDImplicitGrant(_OpenIDImplicitGrant): def generate_user_info(self, user, scope): return generate_user_info(user, scope) + def get_audiences(self, request): + client = request.client + return [Client.get(aud).oauthClientID for aud in client.oauthAudience] + class OpenIDHybridGrant(_OpenIDHybridGrant): def save_authorization_code(self, code, request): @@ -171,6 +179,11 @@ class OpenIDHybridGrant(_OpenIDHybridGrant): def generate_user_info(self, user, scope): return generate_user_info(user, scope) + def get_audiences(self, request): + client = request.client + print(client) + return [Client.get(aud).oauthClientID for aud in client.oauthAudience] + def query_client(client_id): return Client.get(client_id) @@ -187,6 +200,7 @@ def save_token(token, request): oauthClient=request.client.dn, oauthRefreshToken=token.get("refresh_token"), oauthSubject=request.user, + oauthAudience=request.client.oauthAudience, ) t.save() @@ -244,6 +258,7 @@ class IntrospectionEndpoint(_IntrospectionEndpoint): def introspect_token(self, token): client_id = Client.get(token.oauthClient).oauthClientID user = User.get(dn=token.oauthSubject) + audience = [Client.get(aud).oauthClientID for aud in token.oauthAudience] return { "active": True, "client_id": client_id, @@ -251,7 +266,7 @@ class IntrospectionEndpoint(_IntrospectionEndpoint): "username": user.name, "scope": token.get_scope(), "sub": user.uid[0], - "aud": client_id, + "aud": audience, "iss": authorization.metadata["issuer"], "exp": token.get_expires_at(), "iat": token.get_issued_at(), diff --git a/schemas/oauth2-openldap.ldif b/schemas/oauth2-openldap.ldif index 0dd3a44e..cb52b593 100644 --- a/schemas/oauth2-openldap.ldif +++ b/schemas/oauth2-openldap.ldif @@ -62,7 +62,7 @@ olcAttributeTypes: ( 1.3.6.1.4.1.56207.1.1.7 NAME 'oauthAuthorizationDate' USAGE userApplications X-ORIGIN 'OAuth 2.0' ) olcAttributeTypes: ( 1.3.6.1.4.1.56207.1.1.8 NAME 'oauthCodeChallenge' - DESC 'OAuth 2.0 nonce' + DESC 'OAuth 2.0 code challenge' EQUALITY caseExactMatch ORDERING caseExactOrderingMatch SUBSTR caseExactSubstringsMatch @@ -71,7 +71,7 @@ olcAttributeTypes: ( 1.3.6.1.4.1.56207.1.1.8 NAME 'oauthCodeChallenge' USAGE userApplications X-ORIGIN 'OAuth 2.0' ) olcAttributeTypes: ( 1.3.6.1.4.1.56207.1.1.9 NAME 'oauthCodeChallengeMethod' - DESC 'OAuth 2.0 nonce' + DESC 'OAuth 2.0 code challenge method' EQUALITY caseExactMatch ORDERING caseExactOrderingMatch SUBSTR caseExactSubstringsMatch @@ -295,6 +295,13 @@ olcAttributeTypes: ( 1.3.6.1.4.1.56207.1.1.34 NAME 'oauthClient' SINGLE-VALUE USAGE userApplications X-ORIGIN 'OAuth 2.0' ) +olcAttributeTypes: ( 1.3.6.1.4.1.56207.1.1.35 NAME 'oauthAudience' + DESC 'Token audience' + EQUALITY distinguishedNameMatch + SUBSTR caseExactSubstringsMatch + SYNTAX 1.3.6.1.4.1.1466.115.121.1.12 + USAGE userApplications + X-ORIGIN 'OAuth 2.0' ) olcObjectClasses: ( 1.3.6.1.4.1.56207.1.2.1 NAME 'oauthClient' DESC 'OAuth 2.0 Authorization Code' SUP top @@ -318,7 +325,8 @@ olcObjectClasses: ( 1.3.6.1.4.1.56207.1.2.1 NAME 'oauthClient' oauthJWK $ oauthTokenEndpointAuthMethod $ oauthSoftwareID $ - oauthSoftwareVersion ) + oauthSoftwareVersion $ + oauthAudience ) ) X-ORIGIN 'OAuth 2.0' ) olcObjectClasses: ( 1.3.6.1.4.1.56207.1.2.2 NAME 'oauthAuthorizationCode' @@ -352,7 +360,8 @@ olcObjectClasses: ( 1.3.6.1.4.1.56207.1.2.3 NAME 'oauthToken' oauthScope $ oauthIssueDate $ oauthTokenLifetime $ - oauthRevokationDate ) + oauthRevokationDate $ + oauthAudience ) X-ORIGIN 'OAuth 2.0' ) olcObjectClasses: ( 1.3.6.1.4.1.56207.1.2.4 NAME 'oauthConsent' DESC 'OAuth 2.0 User consents' diff --git a/schemas/oauth2-openldap.schema b/schemas/oauth2-openldap.schema index f2dd4345..6f291c37 100644 --- a/schemas/oauth2-openldap.schema +++ b/schemas/oauth2-openldap.schema @@ -59,7 +59,7 @@ attributetype ( 1.3.6.1.4.1.56207.1.1.7 NAME 'oauthAuthorizationDate' USAGE userApplications X-ORIGIN 'OAuth 2.0' ) attributetype ( 1.3.6.1.4.1.56207.1.1.8 NAME 'oauthCodeChallenge' - DESC 'OAuth 2.0 nonce' + DESC 'OAuth 2.0 code challenge' EQUALITY caseExactMatch ORDERING caseExactOrderingMatch SUBSTR caseExactSubstringsMatch @@ -68,7 +68,7 @@ attributetype ( 1.3.6.1.4.1.56207.1.1.8 NAME 'oauthCodeChallenge' USAGE userApplications X-ORIGIN 'OAuth 2.0' ) attributetype ( 1.3.6.1.4.1.56207.1.1.9 NAME 'oauthCodeChallengeMethod' - DESC 'OAuth 2.0 nonce' + DESC 'OAuth 2.0 code challenge method' EQUALITY caseExactMatch ORDERING caseExactOrderingMatch SUBSTR caseExactSubstringsMatch @@ -292,6 +292,13 @@ attributetypes ( 1.3.6.1.4.1.56207.1.1.34 NAME 'oauthClient' SINGLE-VALUE USAGE userApplications X-ORIGIN 'OAuth 2.0' ) +attributetypes ( 1.3.6.1.4.1.56207.1.1.35 NAME 'oauthAudience' + DESC 'Token Audience' + EQUALITY distinguishedNameMatch + SUBSTR caseExactSubstringsMatch + SYNTAX 1.3.6.1.4.1.1466.115.121.1.12 + USAGE userApplications + X-ORIGIN 'OAuth 2.0' ) objectclass ( 1.3.6.1.4.1.56207.1.2.1 NAME 'oauthClient' DESC 'OAuth 2.0 Authorization Code' SUP top @@ -315,7 +322,8 @@ objectclass ( 1.3.6.1.4.1.56207.1.2.1 NAME 'oauthClient' oauthJWK $ oauthTokenEndpointAuthMethod $ oauthSoftwareID $ - oauthSoftwareVersion ) + oauthSoftwareVersion $ + oauthAudience) ) X-ORIGIN 'OAuth 2.0' ) objectclass ( 1.3.6.1.4.1.56207.1.2.2 NAME 'oauthAuthorizationCode' @@ -349,7 +357,8 @@ objectclass ( 1.3.6.1.4.1.56207.1.2.3 NAME 'oauthToken' oauthScope $ oauthIssueDate $ oauthTokenLifetime $ - oauthRevokationDate ) + oauthRevokationDate $ + oauthAudience ) X-ORIGIN 'OAuth 2.0' ) objectclass ( 1.3.6.1.4.1.56207.1.2.4 NAME 'oauthConsent' DESC 'OAuth 2.0 User consents' diff --git a/tests/conftest.py b/tests/conftest.py index 6e2a229d..cebaaf52 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -197,7 +197,7 @@ def testclient(app): @pytest.fixture -def client(app, slapd_connection): +def client(app, slapd_connection, other_client): Client.ocs_by_name(slapd_connection) c = Client( oauthClientID=gen_salt(24), @@ -225,6 +225,42 @@ def client(app, slapd_connection): oauthJWKURI="https://mydomain.tld/jwk", oauthTokenEndpointAuthMethod="client_secret_basic", ) + c.oauthAudience = [c.dn, other_client.dn] + c.save(slapd_connection) + + return c + + +@pytest.fixture +def other_client(app, slapd_connection): + Client.ocs_by_name(slapd_connection) + c = Client( + oauthClientID=gen_salt(24), + oauthClientName="Some other client", + oauthClientContact="contact@myotherdomain.tld", + oauthClientURI="https://myotherdomain.tld", + oauthRedirectURIs=[ + "https://myotherdomain.tld/redirect1", + "https://myotherdomain.tld/redirect2", + ], + oauthLogoURI="https://myotherdomain.tld/logo.png", + oauthIssueDate=datetime.datetime.now().strftime("%Y%m%d%H%S%MZ"), + oauthClientSecret=gen_salt(48), + oauthGrantType=[ + "password", + "authorization_code", + "implicit", + "hybrid", + "refresh_token", + ], + oauthResponseType=["code", "token", "id_token"], + oauthScope=["openid", "profile", "groups"], + oauthTermsOfServiceURI="https://myotherdomain.tld/tos", + oauthPolicyURI="https://myotherdomain.tld/policy", + oauthJWKURI="https://myotherdomain.tld/jwk", + oauthTokenEndpointAuthMethod="client_secret_basic", + ) + c.oauthAudience = [c.dn] c.save(slapd_connection) return c @@ -301,6 +337,7 @@ def token(slapd_connection, client, user): Token.ocs_by_name(slapd_connection) t = Token( oauthAccessToken=gen_salt(48), + oauthAudience=[client.dn], oauthClient=client.dn, oauthSubject=user.dn, oauthTokenType=None, diff --git a/tests/test_authorization_code_flow.py b/tests/test_authorization_code_flow.py index 70f80dcc..b9ef8b1e 100644 --- a/tests/test_authorization_code_flow.py +++ b/tests/test_authorization_code_flow.py @@ -1,11 +1,12 @@ from . import client_credentials +from authlib.jose import jwt from authlib.oauth2.rfc7636 import create_s256_code_challenge from urllib.parse import urlsplit, parse_qs from canaille.models import AuthorizationCode, Token, Consent from werkzeug.security import gen_salt -def test_authorization_code_flow(testclient, slapd_connection, logged_user, client): +def test_authorization_code_flow(testclient, slapd_connection, logged_user, client, keypair, other_client): res = testclient.get( "/oauth/authorize", params=dict( @@ -36,12 +37,18 @@ def test_authorization_code_flow(testclient, slapd_connection, logged_user, clie headers={"Authorization": f"Basic {client_credentials(client)}"}, status=200, ) - access_token = res.json["access_token"] + access_token = res.json["access_token"] token = Token.get(access_token, conn=slapd_connection) assert token.oauthClient == client.dn assert token.oauthSubject == logged_user.dn + id_token = res.json["id_token"] + claims = jwt.decode(id_token, keypair[1]) + assert logged_user.uid[0] == claims["sub"] + assert logged_user.cn[0] == claims["name"] + assert [client.oauthClientID, other_client.oauthClientID] == claims["aud"] + res = testclient.get( "/oauth/userinfo", headers={"Authorization": f"Bearer {access_token}"}, @@ -99,8 +106,8 @@ def test_logout_login(testclient, slapd_connection, logged_user, client): headers={"Authorization": f"Basic {client_credentials(client)}"}, status=200, ) - access_token = res.json["access_token"] + access_token = res.json["access_token"] token = Token.get(access_token, conn=slapd_connection) assert token.oauthClient == client.dn assert token.oauthSubject == logged_user.dn diff --git a/tests/test_client_admin.py b/tests/test_client_admin.py index a550c970..a0b307a0 100644 --- a/tests/test_client_admin.py +++ b/tests/test_client_admin.py @@ -32,15 +32,17 @@ def test_client_add(testclient, logged_admin, slapd_connection): "oauthSoftwareVersion": "1", "oauthJWK": "jwk", "oauthJWKURI": "https://foo.bar/jwks.json", + "oauthAudience": [], } for k, v in data.items(): - res.form[k] = v + res.form[k].force_value(v) res = res.form.submit(status=302, name="action", value="edit") res = res.follow(status=200) client_id = res.forms["readonly"]["oauthClientID"].value client = Client.get(client_id, conn=slapd_connection) + data["oauthAudience"] = [client.dn] for k, v in data.items(): client_value = getattr(client, k) if k == "oauthScope": @@ -49,7 +51,7 @@ def test_client_add(testclient, logged_admin, slapd_connection): assert v == client_value -def test_client_edit(testclient, client, logged_admin, slapd_connection): +def test_client_edit(testclient, client, logged_admin, slapd_connection, other_client): res = testclient.get("/admin/client/edit/" + client.oauthClientID) data = { "oauthClientName": "foobar", @@ -67,12 +69,17 @@ def test_client_edit(testclient, client, logged_admin, slapd_connection): "oauthSoftwareVersion": "1", "oauthJWK": "jwk", "oauthJWKURI": "https://foo.bar/jwks.json", + "oauthAudience": [client.dn, other_client.dn], } for k, v in data.items(): - res.forms["clientadd"][k] = v + res.forms["clientadd"][k].force_value(v) res = res.forms["clientadd"].submit(status=200, name="action", value="edit") - client.reload(conn=slapd_connection) + assert ( + "The client has not been edited. Please check your information." not in res.text + ) + + client = Client.get(client.dn, conn=slapd_connection) for k, v in data.items(): client_value = getattr(client, k) if k == "oauthScope": diff --git a/tests/test_hybrid_flow.py b/tests/test_hybrid_flow.py index 35b8586c..ad54ef41 100644 --- a/tests/test_hybrid_flow.py +++ b/tests/test_hybrid_flow.py @@ -50,7 +50,9 @@ def test_oauth_hybrid(testclient, slapd_connection, user, client): } == res.json -def test_oidc_hybrid(testclient, slapd_connection, logged_user, client, keypair): +def test_oidc_hybrid( + testclient, slapd_connection, logged_user, client, keypair, other_client +): res = testclient.get( "/oauth/authorize", params=dict( @@ -79,7 +81,7 @@ def test_oidc_hybrid(testclient, slapd_connection, logged_user, client, keypair) claims = jwt.decode(id_token, keypair[1]) assert logged_user.uid[0] == claims["sub"] assert logged_user.cn[0] == claims["name"] - assert [client.oauthClientID] == claims["aud"] + assert [client.oauthClientID, other_client.oauthClientID] == claims["aud"] res = testclient.get( "/oauth/userinfo", diff --git a/tests/test_implicit_flow.py b/tests/test_implicit_flow.py index 944629f5..04110709 100644 --- a/tests/test_implicit_flow.py +++ b/tests/test_implicit_flow.py @@ -52,7 +52,9 @@ def test_oauth_implicit(testclient, slapd_connection, user, client): client.save(slapd_connection) -def test_oidc_implicit(testclient, keypair, slapd_connection, user, client): +def test_oidc_implicit( + testclient, keypair, slapd_connection, user, client, other_client +): client.oauthGrantType = ["token id_token"] client.oauthTokenEndpointAuthMethod = "none" @@ -89,7 +91,7 @@ def test_oidc_implicit(testclient, keypair, slapd_connection, user, client): claims = jwt.decode(id_token, keypair[1]) assert user.uid[0] == claims["sub"] assert user.cn[0] == claims["name"] - assert [client.oauthClientID] == claims["aud"] + assert [client.oauthClientID, other_client.oauthClientID] == claims["aud"] res = testclient.get( "/oauth/userinfo", @@ -110,7 +112,7 @@ def test_oidc_implicit(testclient, keypair, slapd_connection, user, client): def test_oidc_implicit_with_group( - testclient, keypair, slapd_connection, user, client, foo_group + testclient, keypair, slapd_connection, user, client, foo_group, other_client ): client.oauthGrantType = ["token id_token"] client.oauthTokenEndpointAuthMethod = "none" @@ -148,7 +150,7 @@ def test_oidc_implicit_with_group( claims = jwt.decode(id_token, keypair[1]) assert user.uid[0] == claims["sub"] assert user.cn[0] == claims["name"] - assert [client.oauthClientID] == claims["aud"] + assert [client.oauthClientID, other_client.oauthClientID] == claims["aud"] assert ["foo"] == claims["groups"] res = testclient.get( diff --git a/tests/test_token_introspection.py b/tests/test_token_introspection.py index ccf40fb1..3d04628e 100644 --- a/tests/test_token_introspection.py +++ b/tests/test_token_introspection.py @@ -1,10 +1,14 @@ +from canaille.models import AuthorizationCode, Token, Client +from urllib.parse import urlsplit, parse_qs from . import client_credentials def test_token_introspection(testclient, user, client, token): res = testclient.post( "/oauth/introspect", - params=dict(token=token.oauthAccessToken,), + params=dict( + token=token.oauthAccessToken, + ), headers={"Authorization": f"Basic {client_credentials(client)}"}, status=200, ) @@ -15,7 +19,7 @@ def test_token_introspection(testclient, user, client, token): "username": user.name, "scope": token.get_scope(), "sub": user.uid[0], - "aud": client.oauthClientID, + "aud": [client.oauthClientID], "iss": "https://mydomain.tld", "exp": token.get_expires_at(), "iat": token.get_issued_at(), @@ -30,3 +34,64 @@ def test_token_invalid(testclient, client): status=200, ) assert {"active": False} == res.json + + +def test_full_flow( + testclient, slapd_connection, logged_user, client, user, other_client +): + res = testclient.get( + "/oauth/authorize", + params=dict( + response_type="code", + client_id=client.oauthClientID, + scope="profile", + nonce="somenonce", + ), + status=200, + ) + + res = res.form.submit(name="answer", value="accept", status=302) + + assert res.location.startswith(client.oauthRedirectURIs[0]) + params = parse_qs(urlsplit(res.location).query) + code = params["code"][0] + authcode = AuthorizationCode.get(code, conn=slapd_connection) + assert authcode is not None + + res = testclient.post( + "/oauth/token", + params=dict( + grant_type="authorization_code", + code=code, + scope="profile", + redirect_uri=client.oauthRedirectURIs[0], + ), + headers={"Authorization": f"Basic {client_credentials(client)}"}, + status=200, + ) + access_token = res.json["access_token"] + + token = Token.get(access_token, conn=slapd_connection) + assert token.oauthClient == client.dn + assert token.oauthSubject == logged_user.dn + + res = testclient.post( + "/oauth/introspect", + params=dict( + token=token.oauthAccessToken, + ), + headers={"Authorization": f"Basic {client_credentials(client)}"}, + status=200, + ) + assert { + "aud": [client.oauthClientID, other_client.oauthClientID], + "active": True, + "client_id": client.oauthClientID, + "token_type": token.oauthTokenType, + "username": user.name, + "scope": token.get_scope(), + "sub": user.uid[0], + "iss": "https://mydomain.tld", + "exp": token.get_expires_at(), + "iat": token.get_issued_at(), + } == res.json