forked from Github-Mirrors/canaille
feat: Achieved communication with SCIM client
This commit is contained in:
parent
efe79505fd
commit
6e64f51ad4
2 changed files with 355 additions and 31 deletions
|
@ -504,14 +504,13 @@ class User(Model):
|
|||
|
||||
def propagate_scim_changes(self):
|
||||
for client in self.get_clients():
|
||||
tokens = Backend.instance.query(models.Token, subject=self, client=client)
|
||||
if tokens:
|
||||
token = tokens[0]
|
||||
client = httpx_client(
|
||||
client_httpx = httpx_client(
|
||||
base_url=client.client_uri,
|
||||
headers={"Authorization": f"Bearer {token.access_token}"},
|
||||
headers={
|
||||
"Authorization": "Bearer NXeEceY820rnlzoh0FUxc4TFVKO5aAqikopiPEQacvL81ukk"
|
||||
},
|
||||
)
|
||||
scim = SyncSCIMClient(client)
|
||||
scim = SyncSCIMClient(client_httpx)
|
||||
scim.discover()
|
||||
User = scim.get_resource_model("User")
|
||||
EnterpriseUser = User.get_extension_model("EnterpriseUser")
|
||||
|
@ -520,21 +519,24 @@ class User(Model):
|
|||
req = SearchRequest(filter=f'userName eq "{self.user_name}"')
|
||||
response = scim.query(User, search_request=req)
|
||||
|
||||
if not response:
|
||||
if not response.resources:
|
||||
try:
|
||||
scim.create(user)
|
||||
except:
|
||||
except Exception:
|
||||
current_app.logger.warning(
|
||||
f"SCIM User {self.user_name} creation for client {client.client_name} failed"
|
||||
)
|
||||
else:
|
||||
user.id = response.id
|
||||
user.id = response.resources[0].id
|
||||
try:
|
||||
scim.replace(user)
|
||||
except:
|
||||
current_app.logger.warning(
|
||||
f"SCIM User {self.user_name} update for client {client.client_name} failed"
|
||||
)
|
||||
req = SearchRequest(filter=f'userName eq "{self.user_name}"')
|
||||
response = scim.query(User, search_request=req)
|
||||
print("response:", response)
|
||||
|
||||
def propagate_scim_delete(self):
|
||||
client = httpx_client(
|
||||
|
@ -550,8 +552,10 @@ class User(Model):
|
|||
current_app.logger.warning(f"SCIM User {self.user_name} delete failed")
|
||||
|
||||
def get_clients(self):
|
||||
if self.id:
|
||||
consents = Backend.instance.query(models.Consent, subject=self)
|
||||
return {t.client for t in consents}
|
||||
return []
|
||||
|
||||
|
||||
class Group(Model):
|
||||
|
|
|
@ -1,20 +1,123 @@
|
|||
import datetime
|
||||
import json
|
||||
import uuid
|
||||
from http import HTTPStatus
|
||||
from urllib.parse import urlsplit
|
||||
from urllib.parse import urlunsplit
|
||||
|
||||
from authlib.common.errors import AuthlibBaseError
|
||||
from authlib.integrations.flask_client import OAuth
|
||||
from authlib.integrations.flask_oauth2 import ResourceProtector
|
||||
from authlib.integrations.flask_oauth2.errors import (
|
||||
_HTTPException as AuthlibHTTPException,
|
||||
)
|
||||
from authlib.oauth2.rfc6750 import BearerTokenValidator
|
||||
from authlib.oidc.discovery import get_well_known_url
|
||||
from flask import Blueprint
|
||||
from flask import Flask
|
||||
from flask import Response
|
||||
from flask import abort
|
||||
from flask import current_app
|
||||
from flask import flash
|
||||
from flask import redirect
|
||||
from flask import render_template
|
||||
from flask import request
|
||||
from flask import session
|
||||
from flask import url_for
|
||||
from pydantic import ValidationError
|
||||
from scim2_models import Context
|
||||
from scim2_models import EnterpriseUser
|
||||
from scim2_models import Error
|
||||
from scim2_models import Group
|
||||
from scim2_models import ListResponse
|
||||
from scim2_models import ResourceType
|
||||
from scim2_models import Schema
|
||||
from scim2_models import SearchRequest
|
||||
from scim2_models import User
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
from canaille import csrf
|
||||
from canaille.oidc.models import Token
|
||||
from canaille.scim.models import get_resource_types
|
||||
from canaille.scim.models import get_schemas
|
||||
from canaille.scim.models import get_service_provider_config
|
||||
|
||||
bp = Blueprint("scim", __name__)
|
||||
oauth = OAuth()
|
||||
|
||||
|
||||
class SCIMBearerTokenValidator(BearerTokenValidator):
|
||||
def authenticate_token(self, token_string: str):
|
||||
token = Token()
|
||||
token.token_id = "bidulos"
|
||||
token.access_token = "NXeEceY820rnlzoh0FUxc4TFVKO5aAqikopiPEQacvL81ukk"
|
||||
token.issue_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
token.lifetime = 600
|
||||
token.revokation_date = None
|
||||
token.scope = "openid profile email groups address phone"
|
||||
backend.add_token(token)
|
||||
return backend.get_token(token_string)
|
||||
|
||||
|
||||
require_oauth = ResourceProtector()
|
||||
require_oauth.register_token_validator(SCIMBearerTokenValidator())
|
||||
|
||||
|
||||
class ClientBackend:
|
||||
users: list[User[EnterpriseUser]] = []
|
||||
groups: list[Group] = []
|
||||
tokens: list = []
|
||||
|
||||
def get_user_by_username(self, username):
|
||||
for user in self.users:
|
||||
if user.user_name == username:
|
||||
return user
|
||||
return None
|
||||
|
||||
def add_user(self, user):
|
||||
user.id = str(uuid.uuid4())
|
||||
self.users.append(user)
|
||||
|
||||
def replace_user(self, user):
|
||||
for saved_user in self.users:
|
||||
if saved_user.userName == user.userName:
|
||||
saved_user = user
|
||||
break
|
||||
|
||||
def delete_user(self, user):
|
||||
for saved_user in self.users:
|
||||
if saved_user.user_name == user.userName:
|
||||
self.users.remove(saved_user)
|
||||
break
|
||||
|
||||
def add_group(self, group):
|
||||
self.groups.append(group)
|
||||
|
||||
def replace_group(self, group):
|
||||
for saved_group in self.groups:
|
||||
if saved_group.groupName == group.groupName:
|
||||
saved_group = group
|
||||
break
|
||||
|
||||
def delete_group(self, group):
|
||||
for saved_group in self.groups:
|
||||
if saved_group.group_name == group.groupName:
|
||||
self.groups.remove(saved_group)
|
||||
break
|
||||
|
||||
def add_token(self, token):
|
||||
self.tokens.append(token)
|
||||
|
||||
def get_token(self, access_token):
|
||||
for token in self.tokens:
|
||||
if token.access_token == access_token:
|
||||
return token
|
||||
return None
|
||||
|
||||
|
||||
backend = ClientBackend()
|
||||
|
||||
|
||||
def setup_routes(app):
|
||||
@app.route("/")
|
||||
@app.route("/tos")
|
||||
|
@ -84,6 +187,206 @@ def setup_routes(app):
|
|||
flash("You have been successfully logged out", "success")
|
||||
return redirect(url_for("index"))
|
||||
|
||||
@bp.route("/Users")
|
||||
@csrf.exempt
|
||||
@require_oauth()
|
||||
def query_users():
|
||||
req = parse_search_request(request)
|
||||
users = backend.users[req.start_index_0 : req.stop_index_0]
|
||||
total = len(users)
|
||||
list_response = ListResponse[User[EnterpriseUser]](
|
||||
start_index=req.start_index,
|
||||
items_per_page=req.count,
|
||||
total_results=total,
|
||||
resources=users,
|
||||
)
|
||||
payload = list_response.model_dump(
|
||||
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
|
||||
)
|
||||
return payload
|
||||
|
||||
@bp.route("/Users/<string:username>", methods=["GET"])
|
||||
@csrf.exempt
|
||||
@require_oauth()
|
||||
def query_user(username):
|
||||
user = backend.get_user_by_username(username)
|
||||
if user:
|
||||
return user.model_dump(
|
||||
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
|
||||
)
|
||||
|
||||
@bp.route("/Schemas", methods=["GET"])
|
||||
@csrf.exempt
|
||||
@require_oauth()
|
||||
def query_schemas():
|
||||
req = parse_search_request(request)
|
||||
schemas = list(get_schemas().values())[req.start_index_0 : req.stop_index_0]
|
||||
response = ListResponse[Schema](
|
||||
total_results=len(schemas),
|
||||
items_per_page=req.count or len(schemas),
|
||||
start_index=req.start_index,
|
||||
resources=schemas,
|
||||
)
|
||||
return response.model_dump(scim_ctx=Context.RESOURCE_QUERY_RESPONSE)
|
||||
|
||||
@bp.route("/Schemas/<string:schema_id>", methods=["GET"])
|
||||
@csrf.exempt
|
||||
@require_oauth()
|
||||
def query_schema(schema_id):
|
||||
schema = get_schemas().get(schema_id)
|
||||
if not schema:
|
||||
abort(404)
|
||||
|
||||
return schema.model_dump(scim_ctx=Context.RESOURCE_QUERY_RESPONSE)
|
||||
|
||||
@bp.route("/ResourceTypes", methods=["GET"])
|
||||
@csrf.exempt
|
||||
@require_oauth()
|
||||
def query_resource_types():
|
||||
req = parse_search_request(request)
|
||||
resource_types = list(get_resource_types().values())[
|
||||
req.start_index_0 : req.stop_index_0
|
||||
]
|
||||
response = ListResponse[ResourceType](
|
||||
total_results=len(resource_types),
|
||||
items_per_page=req.count or len(resource_types),
|
||||
start_index=req.start_index,
|
||||
resources=resource_types,
|
||||
)
|
||||
return response.model_dump(scim_ctx=Context.RESOURCE_QUERY_RESPONSE)
|
||||
|
||||
@bp.route("/ResourceTypes/<string:resource_type_name>", methods=["GET"])
|
||||
@csrf.exempt
|
||||
@require_oauth()
|
||||
def query_resource_type(resource_type_name):
|
||||
resource_type = get_resource_types().get(resource_type_name)
|
||||
if not resource_type:
|
||||
abort(404)
|
||||
|
||||
return resource_type.model_dump(scim_ctx=Context.RESOURCE_QUERY_RESPONSE)
|
||||
|
||||
@bp.route("/ServiceProviderConfig", methods=["GET"])
|
||||
@csrf.exempt
|
||||
@require_oauth()
|
||||
def query_service_provider_config():
|
||||
spc = get_service_provider_config()
|
||||
return spc.model_dump(scim_ctx=Context.RESOURCE_QUERY_RESPONSE)
|
||||
|
||||
@bp.route("/Users", methods=["POST"])
|
||||
@csrf.exempt
|
||||
@require_oauth()
|
||||
def create_user():
|
||||
user = User[EnterpriseUser].model_validate(
|
||||
request.json, scim_ctx=Context.RESOURCE_CREATION_REQUEST
|
||||
)
|
||||
backend.add_user(user)
|
||||
payload = user.model_dump_json(scim_ctx=Context.RESOURCE_CREATION_RESPONSE)
|
||||
return Response(payload, status=HTTPStatus.CREATED)
|
||||
|
||||
@bp.route("/Users/<string:username>", methods=["PUT"])
|
||||
@csrf.exempt
|
||||
@require_oauth()
|
||||
def replace_user(username):
|
||||
user = backend.get_user_by_username(username)
|
||||
request_scim_user = User[EnterpriseUser].model_validate(
|
||||
request.json,
|
||||
scim_ctx=Context.RESOURCE_REPLACEMENT_REQUEST,
|
||||
original=user,
|
||||
)
|
||||
backend.replace_user(request_scim_user)
|
||||
payload = request_scim_user.model_dump(
|
||||
scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE
|
||||
)
|
||||
return payload
|
||||
|
||||
@bp.route("/Users/<string:username>", methods=["DELETE"])
|
||||
@csrf.exempt
|
||||
@require_oauth()
|
||||
def delete_user(username):
|
||||
user = backend.get_user_by_username(username)
|
||||
backend.delete_user(user)
|
||||
return "", HTTPStatus.NO_CONTENT
|
||||
|
||||
@bp.route("/Groups", methods=["GET"])
|
||||
@csrf.exempt
|
||||
@require_oauth()
|
||||
def query_groups():
|
||||
req = parse_search_request(request)
|
||||
groups = backend.groups[req.start_index_0 : req.stop_index_0]
|
||||
total = len(groups)
|
||||
list_response = ListResponse[Group](
|
||||
start_index=req.start_index,
|
||||
items_per_page=req.count,
|
||||
total_results=total,
|
||||
resources=groups,
|
||||
)
|
||||
payload = list_response.model_dump(
|
||||
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
|
||||
)
|
||||
return payload
|
||||
|
||||
@bp.route("/Groups/<string:groupName>", methods=["GET"])
|
||||
@csrf.exempt
|
||||
@require_oauth()
|
||||
def query_group(groupName):
|
||||
return groupName.model_dump(
|
||||
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
|
||||
)
|
||||
|
||||
@bp.route("/Groups/<string:groupName>", methods=["PUT"])
|
||||
@csrf.exempt
|
||||
@require_oauth()
|
||||
def replace_group(groupName):
|
||||
group = Group.model_validate(
|
||||
request.json,
|
||||
scim_ctx=Context.RESOURCE_REPLACEMENT_REQUEST,
|
||||
original=groupName,
|
||||
)
|
||||
backend.save_group(group)
|
||||
payload = group.model_dump(scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE)
|
||||
return payload
|
||||
|
||||
@bp.route("/Groups/<string:groupName>", methods=["DELETE"])
|
||||
@csrf.exempt
|
||||
@require_oauth()
|
||||
def delete_group(groupName):
|
||||
backend.delete_group(groupName)
|
||||
return "", HTTPStatus.NO_CONTENT
|
||||
|
||||
@bp.after_request
|
||||
def add_scim_content_type(response):
|
||||
response.headers["Content-Type"] = "application/scim+json"
|
||||
return response
|
||||
|
||||
@bp.errorhandler(HTTPException)
|
||||
def http_error_handler(error):
|
||||
obj = Error(detail=str(error), status=error.code)
|
||||
return obj.model_dump(), obj.status
|
||||
|
||||
@bp.errorhandler(AuthlibHTTPException)
|
||||
def oauth2_error(error):
|
||||
body = json.loads(error.body)
|
||||
obj = Error(
|
||||
detail=f"{body['error']}: {body['error_description']}"
|
||||
if "error_description" in body
|
||||
else body["error"],
|
||||
status=error.code,
|
||||
)
|
||||
return obj.model_dump(), error.code
|
||||
|
||||
@bp.errorhandler(ValidationError)
|
||||
def scim_error_handler(error):
|
||||
error_details = error.errors()[0]
|
||||
obj = Error(status=400, detail=error_details["msg"])
|
||||
# TODO: maybe the Pydantic <=> SCIM error code mapping could go in scim2_models
|
||||
obj.scim_type = (
|
||||
"invalidValue" if error_details["type"] == "required_error" else None
|
||||
)
|
||||
|
||||
return obj.model_dump(), obj.status
|
||||
|
||||
app.register_blueprint(bp)
|
||||
|
||||
|
||||
def setup_oauth(app):
|
||||
oauth.init_app(app)
|
||||
|
@ -119,3 +422,20 @@ def set_parameter_in_url_query(url, **kwargs):
|
|||
split[3] = parameters
|
||||
|
||||
return urlunsplit(split)
|
||||
|
||||
|
||||
def parse_search_request(request) -> SearchRequest:
|
||||
"""Create a SearchRequest object from the request arguments."""
|
||||
max_nb_items_per_page = 1000
|
||||
count = (
|
||||
min(request.args["count"], max_nb_items_per_page)
|
||||
if request.args.get("count")
|
||||
else None
|
||||
)
|
||||
req = SearchRequest(
|
||||
attributes=request.args.get("attributes"),
|
||||
excluded_attributes=request.args.get("excludedAttributes"),
|
||||
start_index=request.args.get("startIndex"),
|
||||
count=count,
|
||||
)
|
||||
return req
|
||||
|
|
Loading…
Reference in a new issue