diff --git a/canaille/scim/endpoints.py b/canaille/scim/endpoints.py index 043593b3..37f56151 100644 --- a/canaille/scim/endpoints.py +++ b/canaille/scim/endpoints.py @@ -1,5 +1,7 @@ from http import HTTPStatus +from authlib.integrations.flask_oauth2 import ResourceProtector +from authlib.oauth2.rfc6750 import BearerTokenValidator from flask import Blueprint from flask import Response from flask import abort @@ -46,6 +48,16 @@ group_schema.attributes[0].required = Required.true Group = Resource.from_schema(group_schema) +class SCIMBearerTokenValidator(BearerTokenValidator): + def authenticate_token(self, token_string: str): + token = Backend.instance.get(models.Token, access_token=token_string) + return token if token and not token.subject else None + + +require_oauth = ResourceProtector() +require_oauth.register_token_validator(SCIMBearerTokenValidator()) + + @bp.after_request def add_scim_content_type(response): response.headers["Content-Type"] = "application/scim+json" @@ -291,6 +303,7 @@ def group_from_scim_to_canaille(scim_group: Group, group): @bp.route("/Users", methods=["GET"]) @csrf.exempt +@require_oauth() def query_users(): req = parse_search_request(request) start_index_1 = req.start_index or 1 @@ -313,6 +326,7 @@ def query_users(): @bp.route("/Users/", methods=["GET"]) @csrf.exempt +@require_oauth() def query_user(user): scim_user = user_from_canaille_to_scim(user) return scim_user.model_dump( @@ -322,6 +336,7 @@ def query_user(user): @bp.route("/Groups", methods=["GET"]) @csrf.exempt +@require_oauth() def query_groups(): req = parse_search_request(request) start_index_1 = req.start_index or 1 @@ -344,6 +359,7 @@ def query_groups(): @bp.route("/Groups/", methods=["GET"]) @csrf.exempt +@require_oauth() def query_group(group): scim_group = group_from_canaille_to_scim(group) return scim_group.model_dump( @@ -353,6 +369,7 @@ def query_group(group): @bp.route("/Schemas", methods=["GET"]) @csrf.exempt +@require_oauth() def query_schemas(): req = parse_search_request(request) start_index_1 = req.start_index or 1 @@ -370,6 +387,7 @@ def query_schemas(): @bp.route("/Schemas/", methods=["GET"]) @csrf.exempt +@require_oauth() def query_schema(schema_id): schema = get_schemas().get(schema_id) if not schema: @@ -380,6 +398,7 @@ def query_schema(schema_id): @bp.route("/ResourceTypes", methods=["GET"]) @csrf.exempt +@require_oauth() def query_resource_types(): req = parse_search_request(request) start_index_1 = req.start_index or 1 @@ -397,6 +416,7 @@ def query_resource_types(): @bp.route("/ResourceTypes/", methods=["GET"]) @csrf.exempt +@require_oauth() def query_resource_type(resource_type_name): resource_type = get_resource_types().get(resource_type_name) if not resource_type: @@ -407,6 +427,7 @@ def query_resource_type(resource_type_name): @bp.route("/ServiceProviderConfig", methods=["GET"]) @csrf.exempt +@require_oauth() def query_service_provider_config(): spc = ServiceProviderConfig( meta=Meta( @@ -436,6 +457,7 @@ def query_service_provider_config(): @bp.route("/Users", methods=["POST"]) @csrf.exempt +@require_oauth() def create_user(): request_user = User[EnterpriseUser].model_validate( request.json, scim_ctx=Context.RESOURCE_CREATION_REQUEST @@ -449,6 +471,7 @@ def create_user(): @bp.route("/Groups", methods=["POST"]) @csrf.exempt +@require_oauth() def create_group(): request_group = Group.model_validate( request.json, scim_ctx=Context.RESOURCE_CREATION_REQUEST @@ -464,6 +487,7 @@ def create_group(): @bp.route("/Users/", methods=["PUT"]) @csrf.exempt +@require_oauth() def replace_user(user): request_user = User[EnterpriseUser].model_validate( request.json, scim_ctx=Context.RESOURCE_REPLACEMENT_REQUEST @@ -477,6 +501,7 @@ def replace_user(user): @bp.route("/Groups/", methods=["PUT"]) @csrf.exempt +@require_oauth() def replace_group(group): request_group = Group.model_validate( request.json, scim_ctx=Context.RESOURCE_REPLACEMENT_REQUEST @@ -490,6 +515,7 @@ def replace_group(group): @bp.route("/Users/", methods=["DELETE"]) @csrf.exempt +@require_oauth() def delete_user(user): Backend.instance.delete(user) return "", HTTPStatus.NO_CONTENT @@ -497,6 +523,7 @@ def delete_user(user): @bp.route("/Groups/", methods=["DELETE"]) @csrf.exempt +@require_oauth() def delete_group(group): Backend.instance.delete(group) return "", HTTPStatus.NO_CONTENT diff --git a/pyproject.toml b/pyproject.toml index fbce15be..5a66b968 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ oidc = [ scim = [ "scim2-models>=0.2.2", + "authlib >= 1.3.0", ] ldap = [ diff --git a/tests/scim/conftest.py b/tests/scim/conftest.py index 856a37b9..bd939752 100644 --- a/tests/scim/conftest.py +++ b/tests/scim/conftest.py @@ -1,4 +1,12 @@ +import datetime + import pytest +from scim2_client.engines.werkzeug import TestSCIMClient +from werkzeug.security import gen_salt +from werkzeug.test import Client + +from canaille.app import models +from canaille.scim.endpoints import bp @pytest.fixture @@ -13,3 +21,54 @@ def configuration(configuration): "level": "INFO", } return configuration + + +@pytest.fixture +def oidc_client(testclient, backend): + c = models.Client( + client_id=gen_salt(24), + client_name="Some client", + contacts=["contact@mydomain.test"], + client_uri="https://mydomain.test", + redirect_uris=[ + "https://mydomain.test/redirect1", + ], + client_id_issued_at=datetime.datetime.now(datetime.timezone.utc), + client_secret=gen_salt(48), + grant_types=[ + "client_credentials", + ], + response_types=["code", "token", "id_token"], + scope=["openid", "email", "profile", "groups", "address", "phone"], + token_endpoint_auth_method="client_secret_basic", + ) + backend.save(c) + + yield c + backend.delete(c) + + +@pytest.fixture +def oidc_token(testclient, oidc_client, backend): + t = models.Token( + token_id=gen_salt(48), + access_token=gen_salt(48), + audience=[oidc_client], + client=oidc_client, + refresh_token=gen_salt(48), + scope=["openid", "profile"], + issue_date=datetime.datetime.now(datetime.timezone.utc), + lifetime=3600, + ) + backend.save(t) + yield t + backend.delete(t) + + +@pytest.fixture +def scim_client(app, oidc_client, oidc_token): + return TestSCIMClient( + Client(app), + scim_prefix=bp.url_prefix, + environ={"headers": {"Authorization": f"Bearer {oidc_token.access_token}"}}, + ) diff --git a/tests/scim/test_authentication.py b/tests/scim/test_authentication.py new file mode 100644 index 00000000..12a1e88f --- /dev/null +++ b/tests/scim/test_authentication.py @@ -0,0 +1,45 @@ +import datetime + +import pytest +from scim2_client import SCIMResponseErrorObject +from scim2_client.engines.werkzeug import TestSCIMClient +from werkzeug.security import gen_salt +from werkzeug.test import Client + +from canaille.app import models +from canaille.scim.endpoints import bp + + +def test_authentication_failure(app): + """Test authentication with an invalid token.""" + scim_client = TestSCIMClient( + Client(app), + scim_prefix=bp.url_prefix, + environ={"headers": {"Authorization": "Bearer invalid"}}, + ) + with pytest.raises(SCIMResponseErrorObject): + scim_client.discover() + + +def test_authentication_with_an_user_token(app, backend, oidc_client, user): + """Test authentication with an user token.""" + scim_token = models.Token( + token_id=gen_salt(48), + access_token=gen_salt(48), + subject=user, + audience=[oidc_client], + client=oidc_client, + refresh_token=gen_salt(48), + scope=["openid", "profile"], + issue_date=datetime.datetime.now(datetime.timezone.utc), + lifetime=3600, + ) + backend.save(scim_token) + + scim_client = TestSCIMClient( + Client(app), + scim_prefix=bp.url_prefix, + environ={"headers": {"Authorization": f"Bearer {scim_token.access_token}"}}, + ) + with pytest.raises(SCIMResponseErrorObject): + scim_client.discover() diff --git a/tests/scim/test_scim_tester.py b/tests/scim/test_scim_tester.py index df4f04b2..3693d708 100644 --- a/tests/scim/test_scim_tester.py +++ b/tests/scim/test_scim_tester.py @@ -1,11 +1,8 @@ import pytest -from scim2_client.engines.werkzeug import TestSCIMClient from scim2_tester import check_server -from canaille.scim.endpoints import bp - -def test_scim_tester(app, backend): +def test_scim_tester(scim_client, backend): # currently the tester create empty groups because it cannot handle references # but LDAP does not support empty groups # https://github.com/python-scim/scim2-tester/issues/15 @@ -13,5 +10,4 @@ def test_scim_tester(app, backend): if "ldap" in backend.__class__.__module__: pytest.skip() - client = TestSCIMClient(app, scim_prefix=bp.url_prefix) - check_server(client, raise_exceptions=True) + check_server(scim_client, raise_exceptions=True) diff --git a/uv.lock b/uv.lock index bab489ad..d17d33ce 100644 --- a/uv.lock +++ b/uv.lock @@ -165,6 +165,7 @@ postgresql = [ { name = "sqlalchemy-utils" }, ] scim = [ + { name = "authlib" }, { name = "scim2-models" }, ] sentry = [ @@ -222,6 +223,7 @@ doc = [ [package.metadata] requires-dist = [ { name = "authlib", marker = "extra == 'oidc'", specifier = ">=1.3.0" }, + { name = "authlib", marker = "extra == 'scim'", specifier = ">=1.3.0" }, { name = "email-validator", marker = "extra == 'front'", specifier = ">=2.0.0" }, { name = "flask", specifier = ">=3.0.0" }, { name = "flask-babel", marker = "extra == 'front'", specifier = ">=4.0.0" }, @@ -1634,14 +1636,14 @@ wheels = [ [[package]] name = "scim2-client" -version = "0.4.3" +version = "0.5.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "scim2-models" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c9/44/1b228a6a680ca96a1274f2ca1dd22aa3e61e656c5e829c348b27f793dc9d/scim2_client-0.4.3.tar.gz", hash = "sha256:69f55e1c296cb018cb4d71954485b6dab8153bb59935647b1e063a659c141ede", size = 85428 } +sdist = { url = "https://files.pythonhosted.org/packages/4f/d0/06a2a68c8b6a840fd8020ebfaf0141e1eadff0a24b4a2ba87c1d0fb9607d/scim2_client-0.5.0.tar.gz", hash = "sha256:f485864c0148cbbddd6a4120a4b3c2553ca89a8076d5cf7bdfa8ad6aba2c1e6e", size = 85783 } wheels = [ - { url = "https://files.pythonhosted.org/packages/fa/75/d56f022664d6db564b0e85265ab5e97b49133f796ff1d502bd652e7c075f/scim2_client-0.4.3-py3-none-any.whl", hash = "sha256:051578f4e56e57149b1b6ea06c30cd4d831b8f2278c301e92421b0ebf524b813", size = 22373 }, + { url = "https://files.pythonhosted.org/packages/4b/e3/195d64ace80effcb948773914b72a8705565afbd02471ac827b28dfa977a/scim2_client-0.5.0-py3-none-any.whl", hash = "sha256:9f290aafea88d4220372a4902a17b3e7ea4dbdae69dfe9489b938d8d7a7ac827", size = 22500 }, ] [[package]]