diff --git a/canaille/backends/sql/models.py b/canaille/backends/sql/models.py index 4d90a440..a693ca35 100644 --- a/canaille/backends/sql/models.py +++ b/canaille/backends/sql/models.py @@ -111,7 +111,6 @@ class User(canaille.core.models.User, Base, SqlAlchemyModel): one_time_password_emission_date: Mapped[datetime.datetime] = mapped_column( TZDateTime(timezone=True), nullable=True ) - scim_id: Mapped[str] = mapped_column(String, nullable=True, unique=True) @property def password_failure_timestamps(self): diff --git a/canaille/core/models.py b/canaille/core/models.py index 2ab57ac5..2f157a41 100644 --- a/canaille/core/models.py +++ b/canaille/core/models.py @@ -7,6 +7,7 @@ from flask import current_app from httpx import Client as httpx_client from scim2_client.engines.httpx import SyncSCIMClient from scim2_models import SearchRequest +from werkzeug.security import gen_salt from canaille.app import models from canaille.backends import Backend @@ -284,8 +285,6 @@ class User(Model): one_time_password_emission_date: datetime.datetime | None = None """A DateTime indicating when the user last emitted an email or sms one-time password.""" - scim_id: str | None = None - _readable_fields = None _writable_fields = None _permissions = None @@ -504,11 +503,33 @@ class User(Model): def propagate_scim_changes(self): for client in self.get_clients(): + scim_tokens = Backend.instance.query( + models.Token, client=client, subject=None + ) + valid_scim_tokens = [ + token + for token in scim_tokens + if not token.is_expired() and not token.is_revoked() + ] + if valid_scim_tokens: + scim_token = valid_scim_tokens[0] + else: + scim_token = models.Token( + token_id=gen_salt(48), + access_token=gen_salt(48), + subject=None, + audience=[client], + client=client, + refresh_token=gen_salt(48), + scope=["openid", "profile"], + issue_date=datetime.datetime.now(datetime.timezone.utc), + lifetime=3600, + ) + Backend.instance.save(scim_token) + client_httpx = httpx_client( base_url=client.client_uri, - headers={ - "Authorization": "Bearer NXeEceY820rnlzoh0FUxc4TFVKO5aAqikopiPEQacvL81ukk" - }, + headers={"Authorization": f"Bearer {scim_token.access_token}"}, ) scim = SyncSCIMClient(client_httpx) scim.discover() @@ -518,7 +539,6 @@ class User(Model): req = SearchRequest(filter=f'userName eq "{self.user_name}"') response = scim.query(User, search_request=req) - if not response.resources: try: scim.create(user) @@ -536,7 +556,6 @@ class User(Model): ) req = SearchRequest(filter=f'userName eq "{self.user_name}"') response = scim.query(User, search_request=req) - print("response:", response) def propagate_scim_delete(self): client = httpx_client( @@ -554,7 +573,13 @@ class User(Model): def get_clients(self): if self.id: consents = Backend.instance.query(models.Consent, subject=self) - return {t.client for t in consents} + consented_clients = {t.client for t in consents} + preconsented_clients = [ + client + for client in Backend.instance.query(models.Client) + if client.preconsent and client not in consented_clients + ] + return list(consented_clients) + list(preconsented_clients) return [] diff --git a/canaille/oidc/oauth.py b/canaille/oidc/oauth.py index c7f1a424..12177611 100644 --- a/canaille/oidc/oauth.py +++ b/canaille/oidc/oauth.py @@ -394,18 +394,20 @@ class IntrospectionEndpoint(_IntrospectionEndpoint): def introspect_token(self, token): audience = [aud.client_id for aud in token.audience] - return { + response = { "active": True, "client_id": token.client.client_id, "token_type": token.type, - "username": token.subject.formatted_name, "scope": token.get_scope(), - "sub": token.subject.user_name, "aud": audience, "iss": get_issuer(), "exp": token.get_expires_at(), "iat": token.get_issued_at(), } + if token.subject: + response["username"] = token.subject.formatted_name + response["sub"] = token.subject.user_name + return response class ClientManagementMixin: diff --git a/demo/client/__init__.py b/demo/client/__init__.py index 204942c6..7ccdefd6 100644 --- a/demo/client/__init__.py +++ b/demo/client/__init__.py @@ -1,17 +1,17 @@ -import datetime import json import uuid from http import HTTPStatus from urllib.parse import urlsplit from urllib.parse import urlunsplit +import requests from authlib.common.errors import AuthlibBaseError from authlib.integrations.flask_client import OAuth from authlib.integrations.flask_oauth2 import ResourceProtector from authlib.integrations.flask_oauth2.errors import ( _HTTPException as AuthlibHTTPException, ) -from authlib.oauth2.rfc6750 import BearerTokenValidator +from authlib.oauth2.rfc7662 import IntrospectTokenValidator from authlib.oidc.discovery import get_well_known_url from flask import Blueprint from flask import Flask @@ -37,7 +37,6 @@ from scim2_models import User from werkzeug.exceptions import HTTPException from canaille import csrf -from canaille.oidc.models import Token from canaille.scim.models import get_resource_types from canaille.scim.models import get_schemas from canaille.scim.models import get_service_provider_config @@ -46,17 +45,17 @@ bp = Blueprint("scim", __name__) oauth = OAuth() -class SCIMBearerTokenValidator(BearerTokenValidator): - def authenticate_token(self, token_string: str): - token = Token() - token.token_id = "bidulos" - token.access_token = "NXeEceY820rnlzoh0FUxc4TFVKO5aAqikopiPEQacvL81ukk" - token.issue_date = datetime.datetime.now(datetime.timezone.utc) - token.lifetime = 600 - token.revokation_date = None - token.scope = "openid profile email groups address phone" - backend.add_token(token) - return backend.get_token(token_string) +class SCIMBearerTokenValidator(IntrospectTokenValidator): + def introspect_token(self, token_string: str): + url = current_app.config["OAUTH_AUTH_SERVER"] + "/oauth/introspect" + data = {"token": token_string, "token_type_hint": "access_token"} + auth = ( + current_app.config["OAUTH_CLIENT_ID"], + current_app.config["OAUTH_CLIENT_SECRET"], + ) + resp = requests.post(url, data=data, auth=auth) + resp.raise_for_status() + return resp.json() require_oauth = ResourceProtector() @@ -68,9 +67,9 @@ class ClientBackend: groups: list[Group] = [] tokens: list = [] - def get_user_by_username(self, username): + def get_user_by_id(self, id): for user in self.users: - if user.user_name == username: + if user.id == id: return user return None @@ -79,9 +78,9 @@ class ClientBackend: self.users.append(user) def replace_user(self, user): - for saved_user in self.users: - if saved_user.userName == user.userName: - saved_user = user + for i, saved_user in enumerate(self.users): + if saved_user.id == user.id: + self.users[i] = user break def delete_user(self, user): @@ -105,15 +104,6 @@ class ClientBackend: self.groups.remove(saved_group) break - def add_token(self, token): - self.tokens.append(token) - - def get_token(self, access_token): - for token in self.tokens: - if token.access_token == access_token: - return token - return None - backend = ClientBackend() @@ -205,11 +195,11 @@ def setup_routes(app): ) return payload - @bp.route("/Users/", methods=["GET"]) + @bp.route("/Users/", methods=["GET"]) @csrf.exempt @require_oauth() - def query_user(username): - user = backend.get_user_by_username(username) + def query_user(id): + user = backend.get_user_by_id(id) if user: return user.model_dump( scim_ctx=Context.RESOURCE_QUERY_RESPONSE, @@ -283,27 +273,28 @@ def setup_routes(app): payload = user.model_dump_json(scim_ctx=Context.RESOURCE_CREATION_RESPONSE) return Response(payload, status=HTTPStatus.CREATED) - @bp.route("/Users/", methods=["PUT"]) + @bp.route("/Users/", methods=["PUT"]) @csrf.exempt @require_oauth() - def replace_user(username): - user = backend.get_user_by_username(username) + def replace_user(id): + user = backend.get_user_by_id(id) request_scim_user = User[EnterpriseUser].model_validate( request.json, scim_ctx=Context.RESOURCE_REPLACEMENT_REQUEST, original=user, ) + request_scim_user.id = user.id backend.replace_user(request_scim_user) payload = request_scim_user.model_dump( scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE ) return payload - @bp.route("/Users/", methods=["DELETE"]) + @bp.route("/Users/", methods=["DELETE"]) @csrf.exempt @require_oauth() - def delete_user(username): - user = backend.get_user_by_username(username) + def delete_user(id): + user = backend.get_user_by_id(id) backend.delete_user(user) return "", HTTPStatus.NO_CONTENT diff --git a/pyproject.toml b/pyproject.toml index 07ae80c7..58e2d7d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "flask-wtf >= 1.2.1", "httpx>=0.28.1", "pydantic-settings >= 2.0.3", + "q>=2.7", "requests>=2.32.3", "wtforms >= 3.1.1", ] diff --git a/uv.lock b/uv.lock index c8e48be0..5d022d0b 100644 --- a/uv.lock +++ b/uv.lock @@ -143,6 +143,7 @@ dependencies = [ { name = "flask-wtf" }, { name = "httpx" }, { name = "pydantic-settings" }, + { name = "q" }, { name = "requests" }, { name = "wtforms" }, ] @@ -255,6 +256,7 @@ requires-dist = [ { name = "pydantic-settings", specifier = ">=2.0.3" }, { name = "python-ldap", marker = "extra == 'ldap'", specifier = ">=3.4.0" }, { name = "pytz", marker = "extra == 'front'", specifier = ">=2022.7" }, + { name = "q", specifier = ">=2.7" }, { name = "qrcode", marker = "extra == 'otp'", specifier = ">=8.0" }, { name = "requests", specifier = ">=2.32.3" }, { name = "scim2-models", marker = "extra == 'scim'", specifier = ">=0.2.2" }, @@ -1659,6 +1661,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446 }, ] +[[package]] +name = "q" +version = "2.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/15/90/2649ecc3b4b335e62de4a0c3762c7cd7b2f77a023c5c00649f549cebb56c/q-2.7.tar.gz", hash = "sha256:8e0b792f6658ab9e1133b5ea17af1b530530e60124cf9743bc0fa051b8c64f4e", size = 7946 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/75/f0/ae942c0530d02092702211fd36d9a465e203f732789c84d0b96fbebe3039/q-2.7-py2.py3-none-any.whl", hash = "sha256:8388a3ef7e79b3b6224189e44ddba8dc1a6e9ed3212ce96f83f6056fa532459c", size = 10390 }, +] + [[package]] name = "qrcode" version = "8.0"