diff --git a/canaille/core/models.py b/canaille/core/models.py index 7ea4f2c7..2ab57ac5 100644 --- a/canaille/core/models.py +++ b/canaille/core/models.py @@ -504,37 +504,39 @@ class User(Model): def propagate_scim_changes(self): for client in self.get_clients(): - tokens = Backend.instance.query(models.Token, subject=self, client=client) - if tokens: - token = tokens[0] - client = httpx_client( - base_url=client.client_uri, - headers={"Authorization": f"Bearer {token.access_token}"}, - ) - scim = SyncSCIMClient(client) - scim.discover() - User = scim.get_resource_model("User") - EnterpriseUser = User.get_extension_model("EnterpriseUser") - user = user_from_canaille_to_scim_for_client(self, User, EnterpriseUser) + client_httpx = httpx_client( + base_url=client.client_uri, + headers={ + "Authorization": "Bearer NXeEceY820rnlzoh0FUxc4TFVKO5aAqikopiPEQacvL81ukk" + }, + ) + scim = SyncSCIMClient(client_httpx) + scim.discover() + User = scim.get_resource_model("User") + EnterpriseUser = User.get_extension_model("EnterpriseUser") + user = user_from_canaille_to_scim_for_client(self, User, EnterpriseUser) - req = SearchRequest(filter=f'userName eq "{self.user_name}"') - response = scim.query(User, search_request=req) + req = SearchRequest(filter=f'userName eq "{self.user_name}"') + response = scim.query(User, search_request=req) - if not response: - try: - scim.create(user) - except: - current_app.logger.warning( - f"SCIM User {self.user_name} creation for client {client.client_name} failed" - ) - else: - user.id = response.id - try: - scim.replace(user) - except: - current_app.logger.warning( - f"SCIM User {self.user_name} update for client {client.client_name} failed" - ) + if not response.resources: + try: + scim.create(user) + except Exception: + current_app.logger.warning( + f"SCIM User {self.user_name} creation for client {client.client_name} failed" + ) + else: + user.id = response.resources[0].id + try: + scim.replace(user) + except: + current_app.logger.warning( + f"SCIM User {self.user_name} update for client {client.client_name} failed" + ) + req = SearchRequest(filter=f'userName eq "{self.user_name}"') + response = scim.query(User, search_request=req) + print("response:", response) def propagate_scim_delete(self): client = httpx_client( @@ -550,8 +552,10 @@ class User(Model): current_app.logger.warning(f"SCIM User {self.user_name} delete failed") def get_clients(self): - consents = Backend.instance.query(models.Consent, subject=self) - return {t.client for t in consents} + if self.id: + consents = Backend.instance.query(models.Consent, subject=self) + return {t.client for t in consents} + return [] class Group(Model): diff --git a/demo/client/__init__.py b/demo/client/__init__.py index bee28fb9..204942c6 100644 --- a/demo/client/__init__.py +++ b/demo/client/__init__.py @@ -1,20 +1,123 @@ +import datetime +import json +import uuid +from http import HTTPStatus from urllib.parse import urlsplit from urllib.parse import urlunsplit from authlib.common.errors import AuthlibBaseError from authlib.integrations.flask_client import OAuth +from authlib.integrations.flask_oauth2 import ResourceProtector +from authlib.integrations.flask_oauth2.errors import ( + _HTTPException as AuthlibHTTPException, +) +from authlib.oauth2.rfc6750 import BearerTokenValidator from authlib.oidc.discovery import get_well_known_url +from flask import Blueprint from flask import Flask +from flask import Response +from flask import abort from flask import current_app from flask import flash from flask import redirect from flask import render_template +from flask import request from flask import session from flask import url_for +from pydantic import ValidationError +from scim2_models import Context +from scim2_models import EnterpriseUser +from scim2_models import Error +from scim2_models import Group +from scim2_models import ListResponse +from scim2_models import ResourceType +from scim2_models import Schema +from scim2_models import SearchRequest +from scim2_models import User +from werkzeug.exceptions import HTTPException +from canaille import csrf +from canaille.oidc.models import Token +from canaille.scim.models import get_resource_types +from canaille.scim.models import get_schemas +from canaille.scim.models import get_service_provider_config + +bp = Blueprint("scim", __name__) oauth = OAuth() +class SCIMBearerTokenValidator(BearerTokenValidator): + def authenticate_token(self, token_string: str): + token = Token() + token.token_id = "bidulos" + token.access_token = "NXeEceY820rnlzoh0FUxc4TFVKO5aAqikopiPEQacvL81ukk" + token.issue_date = datetime.datetime.now(datetime.timezone.utc) + token.lifetime = 600 + token.revokation_date = None + token.scope = "openid profile email groups address phone" + backend.add_token(token) + return backend.get_token(token_string) + + +require_oauth = ResourceProtector() +require_oauth.register_token_validator(SCIMBearerTokenValidator()) + + +class ClientBackend: + users: list[User[EnterpriseUser]] = [] + groups: list[Group] = [] + tokens: list = [] + + def get_user_by_username(self, username): + for user in self.users: + if user.user_name == username: + return user + return None + + def add_user(self, user): + user.id = str(uuid.uuid4()) + self.users.append(user) + + def replace_user(self, user): + for saved_user in self.users: + if saved_user.userName == user.userName: + saved_user = user + break + + def delete_user(self, user): + for saved_user in self.users: + if saved_user.user_name == user.userName: + self.users.remove(saved_user) + break + + def add_group(self, group): + self.groups.append(group) + + def replace_group(self, group): + for saved_group in self.groups: + if saved_group.groupName == group.groupName: + saved_group = group + break + + def delete_group(self, group): + for saved_group in self.groups: + if saved_group.group_name == group.groupName: + self.groups.remove(saved_group) + break + + def add_token(self, token): + self.tokens.append(token) + + def get_token(self, access_token): + for token in self.tokens: + if token.access_token == access_token: + return token + return None + + +backend = ClientBackend() + + def setup_routes(app): @app.route("/") @app.route("/tos") @@ -84,6 +187,206 @@ def setup_routes(app): flash("You have been successfully logged out", "success") return redirect(url_for("index")) + @bp.route("/Users") + @csrf.exempt + @require_oauth() + def query_users(): + req = parse_search_request(request) + users = backend.users[req.start_index_0 : req.stop_index_0] + total = len(users) + list_response = ListResponse[User[EnterpriseUser]]( + start_index=req.start_index, + items_per_page=req.count, + total_results=total, + resources=users, + ) + payload = list_response.model_dump( + scim_ctx=Context.RESOURCE_QUERY_RESPONSE, + ) + return payload + + @bp.route("/Users/", methods=["GET"]) + @csrf.exempt + @require_oauth() + def query_user(username): + user = backend.get_user_by_username(username) + if user: + return user.model_dump( + scim_ctx=Context.RESOURCE_QUERY_RESPONSE, + ) + + @bp.route("/Schemas", methods=["GET"]) + @csrf.exempt + @require_oauth() + def query_schemas(): + req = parse_search_request(request) + schemas = list(get_schemas().values())[req.start_index_0 : req.stop_index_0] + response = ListResponse[Schema]( + total_results=len(schemas), + items_per_page=req.count or len(schemas), + start_index=req.start_index, + resources=schemas, + ) + return response.model_dump(scim_ctx=Context.RESOURCE_QUERY_RESPONSE) + + @bp.route("/Schemas/", methods=["GET"]) + @csrf.exempt + @require_oauth() + def query_schema(schema_id): + schema = get_schemas().get(schema_id) + if not schema: + abort(404) + + return schema.model_dump(scim_ctx=Context.RESOURCE_QUERY_RESPONSE) + + @bp.route("/ResourceTypes", methods=["GET"]) + @csrf.exempt + @require_oauth() + def query_resource_types(): + req = parse_search_request(request) + resource_types = list(get_resource_types().values())[ + req.start_index_0 : req.stop_index_0 + ] + response = ListResponse[ResourceType]( + total_results=len(resource_types), + items_per_page=req.count or len(resource_types), + start_index=req.start_index, + resources=resource_types, + ) + return response.model_dump(scim_ctx=Context.RESOURCE_QUERY_RESPONSE) + + @bp.route("/ResourceTypes/", methods=["GET"]) + @csrf.exempt + @require_oauth() + def query_resource_type(resource_type_name): + resource_type = get_resource_types().get(resource_type_name) + if not resource_type: + abort(404) + + return resource_type.model_dump(scim_ctx=Context.RESOURCE_QUERY_RESPONSE) + + @bp.route("/ServiceProviderConfig", methods=["GET"]) + @csrf.exempt + @require_oauth() + def query_service_provider_config(): + spc = get_service_provider_config() + return spc.model_dump(scim_ctx=Context.RESOURCE_QUERY_RESPONSE) + + @bp.route("/Users", methods=["POST"]) + @csrf.exempt + @require_oauth() + def create_user(): + user = User[EnterpriseUser].model_validate( + request.json, scim_ctx=Context.RESOURCE_CREATION_REQUEST + ) + backend.add_user(user) + payload = user.model_dump_json(scim_ctx=Context.RESOURCE_CREATION_RESPONSE) + return Response(payload, status=HTTPStatus.CREATED) + + @bp.route("/Users/", methods=["PUT"]) + @csrf.exempt + @require_oauth() + def replace_user(username): + user = backend.get_user_by_username(username) + request_scim_user = User[EnterpriseUser].model_validate( + request.json, + scim_ctx=Context.RESOURCE_REPLACEMENT_REQUEST, + original=user, + ) + backend.replace_user(request_scim_user) + payload = request_scim_user.model_dump( + scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE + ) + return payload + + @bp.route("/Users/", methods=["DELETE"]) + @csrf.exempt + @require_oauth() + def delete_user(username): + user = backend.get_user_by_username(username) + backend.delete_user(user) + return "", HTTPStatus.NO_CONTENT + + @bp.route("/Groups", methods=["GET"]) + @csrf.exempt + @require_oauth() + def query_groups(): + req = parse_search_request(request) + groups = backend.groups[req.start_index_0 : req.stop_index_0] + total = len(groups) + list_response = ListResponse[Group]( + start_index=req.start_index, + items_per_page=req.count, + total_results=total, + resources=groups, + ) + payload = list_response.model_dump( + scim_ctx=Context.RESOURCE_QUERY_RESPONSE, + ) + return payload + + @bp.route("/Groups/", methods=["GET"]) + @csrf.exempt + @require_oauth() + def query_group(groupName): + return groupName.model_dump( + scim_ctx=Context.RESOURCE_QUERY_RESPONSE, + ) + + @bp.route("/Groups/", methods=["PUT"]) + @csrf.exempt + @require_oauth() + def replace_group(groupName): + group = Group.model_validate( + request.json, + scim_ctx=Context.RESOURCE_REPLACEMENT_REQUEST, + original=groupName, + ) + backend.save_group(group) + payload = group.model_dump(scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE) + return payload + + @bp.route("/Groups/", methods=["DELETE"]) + @csrf.exempt + @require_oauth() + def delete_group(groupName): + backend.delete_group(groupName) + return "", HTTPStatus.NO_CONTENT + + @bp.after_request + def add_scim_content_type(response): + response.headers["Content-Type"] = "application/scim+json" + return response + + @bp.errorhandler(HTTPException) + def http_error_handler(error): + obj = Error(detail=str(error), status=error.code) + return obj.model_dump(), obj.status + + @bp.errorhandler(AuthlibHTTPException) + def oauth2_error(error): + body = json.loads(error.body) + obj = Error( + detail=f"{body['error']}: {body['error_description']}" + if "error_description" in body + else body["error"], + status=error.code, + ) + return obj.model_dump(), error.code + + @bp.errorhandler(ValidationError) + def scim_error_handler(error): + error_details = error.errors()[0] + obj = Error(status=400, detail=error_details["msg"]) + # TODO: maybe the Pydantic <=> SCIM error code mapping could go in scim2_models + obj.scim_type = ( + "invalidValue" if error_details["type"] == "required_error" else None + ) + + return obj.model_dump(), obj.status + + app.register_blueprint(bp) + def setup_oauth(app): oauth.init_app(app) @@ -119,3 +422,20 @@ def set_parameter_in_url_query(url, **kwargs): split[3] = parameters return urlunsplit(split) + + +def parse_search_request(request) -> SearchRequest: + """Create a SearchRequest object from the request arguments.""" + max_nb_items_per_page = 1000 + count = ( + min(request.args["count"], max_nb_items_per_page) + if request.args.get("count") + else None + ) + req = SearchRequest( + attributes=request.args.get("attributes"), + excluded_attributes=request.args.get("excludedAttributes"), + start_index=request.args.get("startIndex"), + count=count, + ) + return req