feat: scim endpoint authentication

This commit is contained in:
Éloi Rivard 2024-12-06 15:15:04 +01:00
parent a299bb92ba
commit 10abb2013a
No known key found for this signature in database
GPG key ID: 7EDA204EA57DD184
6 changed files with 139 additions and 9 deletions

View file

@ -1,5 +1,7 @@
from http import HTTPStatus from http import HTTPStatus
from authlib.integrations.flask_oauth2 import ResourceProtector
from authlib.oauth2.rfc6750 import BearerTokenValidator
from flask import Blueprint from flask import Blueprint
from flask import Response from flask import Response
from flask import abort from flask import abort
@ -46,6 +48,16 @@ group_schema.attributes[0].required = Required.true
Group = Resource.from_schema(group_schema) Group = Resource.from_schema(group_schema)
class SCIMBearerTokenValidator(BearerTokenValidator):
def authenticate_token(self, token_string: str):
token = Backend.instance.get(models.Token, access_token=token_string)
return token if token and not token.subject else None
require_oauth = ResourceProtector()
require_oauth.register_token_validator(SCIMBearerTokenValidator())
@bp.after_request @bp.after_request
def add_scim_content_type(response): def add_scim_content_type(response):
response.headers["Content-Type"] = "application/scim+json" response.headers["Content-Type"] = "application/scim+json"
@ -291,6 +303,7 @@ def group_from_scim_to_canaille(scim_group: Group, group):
@bp.route("/Users", methods=["GET"]) @bp.route("/Users", methods=["GET"])
@csrf.exempt @csrf.exempt
@require_oauth()
def query_users(): def query_users():
req = parse_search_request(request) req = parse_search_request(request)
start_index_1 = req.start_index or 1 start_index_1 = req.start_index or 1
@ -313,6 +326,7 @@ def query_users():
@bp.route("/Users/<user:user>", methods=["GET"]) @bp.route("/Users/<user:user>", methods=["GET"])
@csrf.exempt @csrf.exempt
@require_oauth()
def query_user(user): def query_user(user):
scim_user = user_from_canaille_to_scim(user) scim_user = user_from_canaille_to_scim(user)
return scim_user.model_dump( return scim_user.model_dump(
@ -322,6 +336,7 @@ def query_user(user):
@bp.route("/Groups", methods=["GET"]) @bp.route("/Groups", methods=["GET"])
@csrf.exempt @csrf.exempt
@require_oauth()
def query_groups(): def query_groups():
req = parse_search_request(request) req = parse_search_request(request)
start_index_1 = req.start_index or 1 start_index_1 = req.start_index or 1
@ -344,6 +359,7 @@ def query_groups():
@bp.route("/Groups/<group:group>", methods=["GET"]) @bp.route("/Groups/<group:group>", methods=["GET"])
@csrf.exempt @csrf.exempt
@require_oauth()
def query_group(group): def query_group(group):
scim_group = group_from_canaille_to_scim(group) scim_group = group_from_canaille_to_scim(group)
return scim_group.model_dump( return scim_group.model_dump(
@ -353,6 +369,7 @@ def query_group(group):
@bp.route("/Schemas", methods=["GET"]) @bp.route("/Schemas", methods=["GET"])
@csrf.exempt @csrf.exempt
@require_oauth()
def query_schemas(): def query_schemas():
req = parse_search_request(request) req = parse_search_request(request)
start_index_1 = req.start_index or 1 start_index_1 = req.start_index or 1
@ -370,6 +387,7 @@ def query_schemas():
@bp.route("/Schemas/<string:schema_id>", methods=["GET"]) @bp.route("/Schemas/<string:schema_id>", methods=["GET"])
@csrf.exempt @csrf.exempt
@require_oauth()
def query_schema(schema_id): def query_schema(schema_id):
schema = get_schemas().get(schema_id) schema = get_schemas().get(schema_id)
if not schema: if not schema:
@ -380,6 +398,7 @@ def query_schema(schema_id):
@bp.route("/ResourceTypes", methods=["GET"]) @bp.route("/ResourceTypes", methods=["GET"])
@csrf.exempt @csrf.exempt
@require_oauth()
def query_resource_types(): def query_resource_types():
req = parse_search_request(request) req = parse_search_request(request)
start_index_1 = req.start_index or 1 start_index_1 = req.start_index or 1
@ -397,6 +416,7 @@ def query_resource_types():
@bp.route("/ResourceTypes/<string:resource_type_name>", methods=["GET"]) @bp.route("/ResourceTypes/<string:resource_type_name>", methods=["GET"])
@csrf.exempt @csrf.exempt
@require_oauth()
def query_resource_type(resource_type_name): def query_resource_type(resource_type_name):
resource_type = get_resource_types().get(resource_type_name) resource_type = get_resource_types().get(resource_type_name)
if not resource_type: if not resource_type:
@ -407,6 +427,7 @@ def query_resource_type(resource_type_name):
@bp.route("/ServiceProviderConfig", methods=["GET"]) @bp.route("/ServiceProviderConfig", methods=["GET"])
@csrf.exempt @csrf.exempt
@require_oauth()
def query_service_provider_config(): def query_service_provider_config():
spc = ServiceProviderConfig( spc = ServiceProviderConfig(
meta=Meta( meta=Meta(
@ -436,6 +457,7 @@ def query_service_provider_config():
@bp.route("/Users", methods=["POST"]) @bp.route("/Users", methods=["POST"])
@csrf.exempt @csrf.exempt
@require_oauth()
def create_user(): def create_user():
request_user = User[EnterpriseUser].model_validate( request_user = User[EnterpriseUser].model_validate(
request.json, scim_ctx=Context.RESOURCE_CREATION_REQUEST request.json, scim_ctx=Context.RESOURCE_CREATION_REQUEST
@ -449,6 +471,7 @@ def create_user():
@bp.route("/Groups", methods=["POST"]) @bp.route("/Groups", methods=["POST"])
@csrf.exempt @csrf.exempt
@require_oauth()
def create_group(): def create_group():
request_group = Group.model_validate( request_group = Group.model_validate(
request.json, scim_ctx=Context.RESOURCE_CREATION_REQUEST request.json, scim_ctx=Context.RESOURCE_CREATION_REQUEST
@ -464,6 +487,7 @@ def create_group():
@bp.route("/Users/<user:user>", methods=["PUT"]) @bp.route("/Users/<user:user>", methods=["PUT"])
@csrf.exempt @csrf.exempt
@require_oauth()
def replace_user(user): def replace_user(user):
request_user = User[EnterpriseUser].model_validate( request_user = User[EnterpriseUser].model_validate(
request.json, scim_ctx=Context.RESOURCE_REPLACEMENT_REQUEST request.json, scim_ctx=Context.RESOURCE_REPLACEMENT_REQUEST
@ -477,6 +501,7 @@ def replace_user(user):
@bp.route("/Groups/<group:group>", methods=["PUT"]) @bp.route("/Groups/<group:group>", methods=["PUT"])
@csrf.exempt @csrf.exempt
@require_oauth()
def replace_group(group): def replace_group(group):
request_group = Group.model_validate( request_group = Group.model_validate(
request.json, scim_ctx=Context.RESOURCE_REPLACEMENT_REQUEST request.json, scim_ctx=Context.RESOURCE_REPLACEMENT_REQUEST
@ -490,6 +515,7 @@ def replace_group(group):
@bp.route("/Users/<user:user>", methods=["DELETE"]) @bp.route("/Users/<user:user>", methods=["DELETE"])
@csrf.exempt @csrf.exempt
@require_oauth()
def delete_user(user): def delete_user(user):
Backend.instance.delete(user) Backend.instance.delete(user)
return "", HTTPStatus.NO_CONTENT return "", HTTPStatus.NO_CONTENT
@ -497,6 +523,7 @@ def delete_user(user):
@bp.route("/Groups/<group:group>", methods=["DELETE"]) @bp.route("/Groups/<group:group>", methods=["DELETE"])
@csrf.exempt @csrf.exempt
@require_oauth()
def delete_group(group): def delete_group(group):
Backend.instance.delete(group) Backend.instance.delete(group)
return "", HTTPStatus.NO_CONTENT return "", HTTPStatus.NO_CONTENT

View file

@ -54,6 +54,7 @@ oidc = [
scim = [ scim = [
"scim2-models>=0.2.2", "scim2-models>=0.2.2",
"authlib >= 1.3.0",
] ]
ldap = [ ldap = [

View file

@ -1,4 +1,12 @@
import datetime
import pytest import pytest
from scim2_client.engines.werkzeug import TestSCIMClient
from werkzeug.security import gen_salt
from werkzeug.test import Client
from canaille.app import models
from canaille.scim.endpoints import bp
@pytest.fixture @pytest.fixture
@ -13,3 +21,54 @@ def configuration(configuration):
"level": "INFO", "level": "INFO",
} }
return configuration return configuration
@pytest.fixture
def oidc_client(testclient, backend):
c = models.Client(
client_id=gen_salt(24),
client_name="Some client",
contacts=["contact@mydomain.test"],
client_uri="https://mydomain.test",
redirect_uris=[
"https://mydomain.test/redirect1",
],
client_id_issued_at=datetime.datetime.now(datetime.timezone.utc),
client_secret=gen_salt(48),
grant_types=[
"client_credentials",
],
response_types=["code", "token", "id_token"],
scope=["openid", "email", "profile", "groups", "address", "phone"],
token_endpoint_auth_method="client_secret_basic",
)
backend.save(c)
yield c
backend.delete(c)
@pytest.fixture
def oidc_token(testclient, oidc_client, backend):
t = models.Token(
token_id=gen_salt(48),
access_token=gen_salt(48),
audience=[oidc_client],
client=oidc_client,
refresh_token=gen_salt(48),
scope=["openid", "profile"],
issue_date=datetime.datetime.now(datetime.timezone.utc),
lifetime=3600,
)
backend.save(t)
yield t
backend.delete(t)
@pytest.fixture
def scim_client(app, oidc_client, oidc_token):
return TestSCIMClient(
Client(app),
scim_prefix=bp.url_prefix,
environ={"headers": {"Authorization": f"Bearer {oidc_token.access_token}"}},
)

View file

@ -0,0 +1,45 @@
import datetime
import pytest
from scim2_client import SCIMResponseErrorObject
from scim2_client.engines.werkzeug import TestSCIMClient
from werkzeug.security import gen_salt
from werkzeug.test import Client
from canaille.app import models
from canaille.scim.endpoints import bp
def test_authentication_failure(app):
"""Test authentication with an invalid token."""
scim_client = TestSCIMClient(
Client(app),
scim_prefix=bp.url_prefix,
environ={"headers": {"Authorization": "Bearer invalid"}},
)
with pytest.raises(SCIMResponseErrorObject):
scim_client.discover()
def test_authentication_with_an_user_token(app, backend, oidc_client, user):
"""Test authentication with an user token."""
scim_token = models.Token(
token_id=gen_salt(48),
access_token=gen_salt(48),
subject=user,
audience=[oidc_client],
client=oidc_client,
refresh_token=gen_salt(48),
scope=["openid", "profile"],
issue_date=datetime.datetime.now(datetime.timezone.utc),
lifetime=3600,
)
backend.save(scim_token)
scim_client = TestSCIMClient(
Client(app),
scim_prefix=bp.url_prefix,
environ={"headers": {"Authorization": f"Bearer {scim_token.access_token}"}},
)
with pytest.raises(SCIMResponseErrorObject):
scim_client.discover()

View file

@ -1,11 +1,8 @@
import pytest import pytest
from scim2_client.engines.werkzeug import TestSCIMClient
from scim2_tester import check_server from scim2_tester import check_server
from canaille.scim.endpoints import bp
def test_scim_tester(scim_client, backend):
def test_scim_tester(app, backend):
# currently the tester create empty groups because it cannot handle references # currently the tester create empty groups because it cannot handle references
# but LDAP does not support empty groups # but LDAP does not support empty groups
# https://github.com/python-scim/scim2-tester/issues/15 # https://github.com/python-scim/scim2-tester/issues/15
@ -13,5 +10,4 @@ def test_scim_tester(app, backend):
if "ldap" in backend.__class__.__module__: if "ldap" in backend.__class__.__module__:
pytest.skip() pytest.skip()
client = TestSCIMClient(app, scim_prefix=bp.url_prefix) check_server(scim_client, raise_exceptions=True)
check_server(client, raise_exceptions=True)

View file

@ -165,6 +165,7 @@ postgresql = [
{ name = "sqlalchemy-utils" }, { name = "sqlalchemy-utils" },
] ]
scim = [ scim = [
{ name = "authlib" },
{ name = "scim2-models" }, { name = "scim2-models" },
] ]
sentry = [ sentry = [
@ -222,6 +223,7 @@ doc = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "authlib", marker = "extra == 'oidc'", specifier = ">=1.3.0" }, { name = "authlib", marker = "extra == 'oidc'", specifier = ">=1.3.0" },
{ name = "authlib", marker = "extra == 'scim'", specifier = ">=1.3.0" },
{ name = "email-validator", marker = "extra == 'front'", specifier = ">=2.0.0" }, { name = "email-validator", marker = "extra == 'front'", specifier = ">=2.0.0" },
{ name = "flask", specifier = ">=3.0.0" }, { name = "flask", specifier = ">=3.0.0" },
{ name = "flask-babel", marker = "extra == 'front'", specifier = ">=4.0.0" }, { name = "flask-babel", marker = "extra == 'front'", specifier = ">=4.0.0" },
@ -1634,14 +1636,14 @@ wheels = [
[[package]] [[package]]
name = "scim2-client" name = "scim2-client"
version = "0.4.3" version = "0.5.0"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "scim2-models" }, { name = "scim2-models" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/c9/44/1b228a6a680ca96a1274f2ca1dd22aa3e61e656c5e829c348b27f793dc9d/scim2_client-0.4.3.tar.gz", hash = "sha256:69f55e1c296cb018cb4d71954485b6dab8153bb59935647b1e063a659c141ede", size = 85428 } sdist = { url = "https://files.pythonhosted.org/packages/4f/d0/06a2a68c8b6a840fd8020ebfaf0141e1eadff0a24b4a2ba87c1d0fb9607d/scim2_client-0.5.0.tar.gz", hash = "sha256:f485864c0148cbbddd6a4120a4b3c2553ca89a8076d5cf7bdfa8ad6aba2c1e6e", size = 85783 }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/75/d56f022664d6db564b0e85265ab5e97b49133f796ff1d502bd652e7c075f/scim2_client-0.4.3-py3-none-any.whl", hash = "sha256:051578f4e56e57149b1b6ea06c30cd4d831b8f2278c301e92421b0ebf524b813", size = 22373 }, { url = "https://files.pythonhosted.org/packages/4b/e3/195d64ace80effcb948773914b72a8705565afbd02471ac827b28dfa977a/scim2_client-0.5.0-py3-none-any.whl", hash = "sha256:9f290aafea88d4220372a4902a17b3e7ea4dbdae69dfe9489b938d8d7a7ac827", size = 22500 },
] ]
[[package]] [[package]]