feat: Achieved communication with SCIM client

This commit is contained in:
Félix Rohrlich 2024-12-17 18:11:08 +01:00
parent efe79505fd
commit 6e64f51ad4
2 changed files with 355 additions and 31 deletions

View file

@ -504,37 +504,39 @@ class User(Model):
def propagate_scim_changes(self): def propagate_scim_changes(self):
for client in self.get_clients(): for client in self.get_clients():
tokens = Backend.instance.query(models.Token, subject=self, client=client) client_httpx = httpx_client(
if tokens: base_url=client.client_uri,
token = tokens[0] headers={
client = httpx_client( "Authorization": "Bearer NXeEceY820rnlzoh0FUxc4TFVKO5aAqikopiPEQacvL81ukk"
base_url=client.client_uri, },
headers={"Authorization": f"Bearer {token.access_token}"}, )
) scim = SyncSCIMClient(client_httpx)
scim = SyncSCIMClient(client) scim.discover()
scim.discover() User = scim.get_resource_model("User")
User = scim.get_resource_model("User") EnterpriseUser = User.get_extension_model("EnterpriseUser")
EnterpriseUser = User.get_extension_model("EnterpriseUser") user = user_from_canaille_to_scim_for_client(self, User, EnterpriseUser)
user = user_from_canaille_to_scim_for_client(self, User, EnterpriseUser)
req = SearchRequest(filter=f'userName eq "{self.user_name}"') req = SearchRequest(filter=f'userName eq "{self.user_name}"')
response = scim.query(User, search_request=req) response = scim.query(User, search_request=req)
if not response: if not response.resources:
try: try:
scim.create(user) scim.create(user)
except: except Exception:
current_app.logger.warning( current_app.logger.warning(
f"SCIM User {self.user_name} creation for client {client.client_name} failed" f"SCIM User {self.user_name} creation for client {client.client_name} failed"
) )
else: else:
user.id = response.id user.id = response.resources[0].id
try: try:
scim.replace(user) scim.replace(user)
except: except:
current_app.logger.warning( current_app.logger.warning(
f"SCIM User {self.user_name} update for client {client.client_name} failed" f"SCIM User {self.user_name} update for client {client.client_name} failed"
) )
req = SearchRequest(filter=f'userName eq "{self.user_name}"')
response = scim.query(User, search_request=req)
print("response:", response)
def propagate_scim_delete(self): def propagate_scim_delete(self):
client = httpx_client( client = httpx_client(
@ -550,8 +552,10 @@ class User(Model):
current_app.logger.warning(f"SCIM User {self.user_name} delete failed") current_app.logger.warning(f"SCIM User {self.user_name} delete failed")
def get_clients(self): def get_clients(self):
consents = Backend.instance.query(models.Consent, subject=self) if self.id:
return {t.client for t in consents} consents = Backend.instance.query(models.Consent, subject=self)
return {t.client for t in consents}
return []
class Group(Model): class Group(Model):

View file

@ -1,20 +1,123 @@
import datetime
import json
import uuid
from http import HTTPStatus
from urllib.parse import urlsplit from urllib.parse import urlsplit
from urllib.parse import urlunsplit from urllib.parse import urlunsplit
from authlib.common.errors import AuthlibBaseError from authlib.common.errors import AuthlibBaseError
from authlib.integrations.flask_client import OAuth from authlib.integrations.flask_client import OAuth
from authlib.integrations.flask_oauth2 import ResourceProtector
from authlib.integrations.flask_oauth2.errors import (
_HTTPException as AuthlibHTTPException,
)
from authlib.oauth2.rfc6750 import BearerTokenValidator
from authlib.oidc.discovery import get_well_known_url from authlib.oidc.discovery import get_well_known_url
from flask import Blueprint
from flask import Flask from flask import Flask
from flask import Response
from flask import abort
from flask import current_app from flask import current_app
from flask import flash from flask import flash
from flask import redirect from flask import redirect
from flask import render_template from flask import render_template
from flask import request
from flask import session from flask import session
from flask import url_for from flask import url_for
from pydantic import ValidationError
from scim2_models import Context
from scim2_models import EnterpriseUser
from scim2_models import Error
from scim2_models import Group
from scim2_models import ListResponse
from scim2_models import ResourceType
from scim2_models import Schema
from scim2_models import SearchRequest
from scim2_models import User
from werkzeug.exceptions import HTTPException
from canaille import csrf
from canaille.oidc.models import Token
from canaille.scim.models import get_resource_types
from canaille.scim.models import get_schemas
from canaille.scim.models import get_service_provider_config
bp = Blueprint("scim", __name__)
oauth = OAuth() oauth = OAuth()
class SCIMBearerTokenValidator(BearerTokenValidator):
def authenticate_token(self, token_string: str):
token = Token()
token.token_id = "bidulos"
token.access_token = "NXeEceY820rnlzoh0FUxc4TFVKO5aAqikopiPEQacvL81ukk"
token.issue_date = datetime.datetime.now(datetime.timezone.utc)
token.lifetime = 600
token.revokation_date = None
token.scope = "openid profile email groups address phone"
backend.add_token(token)
return backend.get_token(token_string)
require_oauth = ResourceProtector()
require_oauth.register_token_validator(SCIMBearerTokenValidator())
class ClientBackend:
users: list[User[EnterpriseUser]] = []
groups: list[Group] = []
tokens: list = []
def get_user_by_username(self, username):
for user in self.users:
if user.user_name == username:
return user
return None
def add_user(self, user):
user.id = str(uuid.uuid4())
self.users.append(user)
def replace_user(self, user):
for saved_user in self.users:
if saved_user.userName == user.userName:
saved_user = user
break
def delete_user(self, user):
for saved_user in self.users:
if saved_user.user_name == user.userName:
self.users.remove(saved_user)
break
def add_group(self, group):
self.groups.append(group)
def replace_group(self, group):
for saved_group in self.groups:
if saved_group.groupName == group.groupName:
saved_group = group
break
def delete_group(self, group):
for saved_group in self.groups:
if saved_group.group_name == group.groupName:
self.groups.remove(saved_group)
break
def add_token(self, token):
self.tokens.append(token)
def get_token(self, access_token):
for token in self.tokens:
if token.access_token == access_token:
return token
return None
backend = ClientBackend()
def setup_routes(app): def setup_routes(app):
@app.route("/") @app.route("/")
@app.route("/tos") @app.route("/tos")
@ -84,6 +187,206 @@ def setup_routes(app):
flash("You have been successfully logged out", "success") flash("You have been successfully logged out", "success")
return redirect(url_for("index")) return redirect(url_for("index"))
@bp.route("/Users")
@csrf.exempt
@require_oauth()
def query_users():
req = parse_search_request(request)
users = backend.users[req.start_index_0 : req.stop_index_0]
total = len(users)
list_response = ListResponse[User[EnterpriseUser]](
start_index=req.start_index,
items_per_page=req.count,
total_results=total,
resources=users,
)
payload = list_response.model_dump(
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
)
return payload
@bp.route("/Users/<string:username>", methods=["GET"])
@csrf.exempt
@require_oauth()
def query_user(username):
user = backend.get_user_by_username(username)
if user:
return user.model_dump(
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
)
@bp.route("/Schemas", methods=["GET"])
@csrf.exempt
@require_oauth()
def query_schemas():
req = parse_search_request(request)
schemas = list(get_schemas().values())[req.start_index_0 : req.stop_index_0]
response = ListResponse[Schema](
total_results=len(schemas),
items_per_page=req.count or len(schemas),
start_index=req.start_index,
resources=schemas,
)
return response.model_dump(scim_ctx=Context.RESOURCE_QUERY_RESPONSE)
@bp.route("/Schemas/<string:schema_id>", methods=["GET"])
@csrf.exempt
@require_oauth()
def query_schema(schema_id):
schema = get_schemas().get(schema_id)
if not schema:
abort(404)
return schema.model_dump(scim_ctx=Context.RESOURCE_QUERY_RESPONSE)
@bp.route("/ResourceTypes", methods=["GET"])
@csrf.exempt
@require_oauth()
def query_resource_types():
req = parse_search_request(request)
resource_types = list(get_resource_types().values())[
req.start_index_0 : req.stop_index_0
]
response = ListResponse[ResourceType](
total_results=len(resource_types),
items_per_page=req.count or len(resource_types),
start_index=req.start_index,
resources=resource_types,
)
return response.model_dump(scim_ctx=Context.RESOURCE_QUERY_RESPONSE)
@bp.route("/ResourceTypes/<string:resource_type_name>", methods=["GET"])
@csrf.exempt
@require_oauth()
def query_resource_type(resource_type_name):
resource_type = get_resource_types().get(resource_type_name)
if not resource_type:
abort(404)
return resource_type.model_dump(scim_ctx=Context.RESOURCE_QUERY_RESPONSE)
@bp.route("/ServiceProviderConfig", methods=["GET"])
@csrf.exempt
@require_oauth()
def query_service_provider_config():
spc = get_service_provider_config()
return spc.model_dump(scim_ctx=Context.RESOURCE_QUERY_RESPONSE)
@bp.route("/Users", methods=["POST"])
@csrf.exempt
@require_oauth()
def create_user():
user = User[EnterpriseUser].model_validate(
request.json, scim_ctx=Context.RESOURCE_CREATION_REQUEST
)
backend.add_user(user)
payload = user.model_dump_json(scim_ctx=Context.RESOURCE_CREATION_RESPONSE)
return Response(payload, status=HTTPStatus.CREATED)
@bp.route("/Users/<string:username>", methods=["PUT"])
@csrf.exempt
@require_oauth()
def replace_user(username):
user = backend.get_user_by_username(username)
request_scim_user = User[EnterpriseUser].model_validate(
request.json,
scim_ctx=Context.RESOURCE_REPLACEMENT_REQUEST,
original=user,
)
backend.replace_user(request_scim_user)
payload = request_scim_user.model_dump(
scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE
)
return payload
@bp.route("/Users/<string:username>", methods=["DELETE"])
@csrf.exempt
@require_oauth()
def delete_user(username):
user = backend.get_user_by_username(username)
backend.delete_user(user)
return "", HTTPStatus.NO_CONTENT
@bp.route("/Groups", methods=["GET"])
@csrf.exempt
@require_oauth()
def query_groups():
req = parse_search_request(request)
groups = backend.groups[req.start_index_0 : req.stop_index_0]
total = len(groups)
list_response = ListResponse[Group](
start_index=req.start_index,
items_per_page=req.count,
total_results=total,
resources=groups,
)
payload = list_response.model_dump(
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
)
return payload
@bp.route("/Groups/<string:groupName>", methods=["GET"])
@csrf.exempt
@require_oauth()
def query_group(groupName):
return groupName.model_dump(
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
)
@bp.route("/Groups/<string:groupName>", methods=["PUT"])
@csrf.exempt
@require_oauth()
def replace_group(groupName):
group = Group.model_validate(
request.json,
scim_ctx=Context.RESOURCE_REPLACEMENT_REQUEST,
original=groupName,
)
backend.save_group(group)
payload = group.model_dump(scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE)
return payload
@bp.route("/Groups/<string:groupName>", methods=["DELETE"])
@csrf.exempt
@require_oauth()
def delete_group(groupName):
backend.delete_group(groupName)
return "", HTTPStatus.NO_CONTENT
@bp.after_request
def add_scim_content_type(response):
response.headers["Content-Type"] = "application/scim+json"
return response
@bp.errorhandler(HTTPException)
def http_error_handler(error):
obj = Error(detail=str(error), status=error.code)
return obj.model_dump(), obj.status
@bp.errorhandler(AuthlibHTTPException)
def oauth2_error(error):
body = json.loads(error.body)
obj = Error(
detail=f"{body['error']}: {body['error_description']}"
if "error_description" in body
else body["error"],
status=error.code,
)
return obj.model_dump(), error.code
@bp.errorhandler(ValidationError)
def scim_error_handler(error):
error_details = error.errors()[0]
obj = Error(status=400, detail=error_details["msg"])
# TODO: maybe the Pydantic <=> SCIM error code mapping could go in scim2_models
obj.scim_type = (
"invalidValue" if error_details["type"] == "required_error" else None
)
return obj.model_dump(), obj.status
app.register_blueprint(bp)
def setup_oauth(app): def setup_oauth(app):
oauth.init_app(app) oauth.init_app(app)
@ -119,3 +422,20 @@ def set_parameter_in_url_query(url, **kwargs):
split[3] = parameters split[3] = parameters
return urlunsplit(split) return urlunsplit(split)
def parse_search_request(request) -> SearchRequest:
"""Create a SearchRequest object from the request arguments."""
max_nb_items_per_page = 1000
count = (
min(request.args["count"], max_nb_items_per_page)
if request.args.get("count")
else None
)
req = SearchRequest(
attributes=request.args.get("attributes"),
excluded_attributes=request.args.get("excludedAttributes"),
start_index=request.args.get("startIndex"),
count=count,
)
return req