feat: Added proper SCIM access token verification by client

This commit is contained in:
Félix Rohrlich 2024-12-30 22:47:48 +01:00
parent 6e64f51ad4
commit 6c1557cf27
6 changed files with 78 additions and 49 deletions

View file

@ -111,7 +111,6 @@ class User(canaille.core.models.User, Base, SqlAlchemyModel):
one_time_password_emission_date: Mapped[datetime.datetime] = mapped_column( one_time_password_emission_date: Mapped[datetime.datetime] = mapped_column(
TZDateTime(timezone=True), nullable=True TZDateTime(timezone=True), nullable=True
) )
scim_id: Mapped[str] = mapped_column(String, nullable=True, unique=True)
@property @property
def password_failure_timestamps(self): def password_failure_timestamps(self):

View file

@ -7,6 +7,7 @@ from flask import current_app
from httpx import Client as httpx_client from httpx import Client as httpx_client
from scim2_client.engines.httpx import SyncSCIMClient from scim2_client.engines.httpx import SyncSCIMClient
from scim2_models import SearchRequest from scim2_models import SearchRequest
from werkzeug.security import gen_salt
from canaille.app import models from canaille.app import models
from canaille.backends import Backend from canaille.backends import Backend
@ -284,8 +285,6 @@ class User(Model):
one_time_password_emission_date: datetime.datetime | None = None one_time_password_emission_date: datetime.datetime | None = None
"""A DateTime indicating when the user last emitted an email or sms one-time password.""" """A DateTime indicating when the user last emitted an email or sms one-time password."""
scim_id: str | None = None
_readable_fields = None _readable_fields = None
_writable_fields = None _writable_fields = None
_permissions = None _permissions = None
@ -504,11 +503,33 @@ class User(Model):
def propagate_scim_changes(self): def propagate_scim_changes(self):
for client in self.get_clients(): for client in self.get_clients():
scim_tokens = Backend.instance.query(
models.Token, client=client, subject=None
)
valid_scim_tokens = [
token
for token in scim_tokens
if not token.is_expired() and not token.is_revoked()
]
if valid_scim_tokens:
scim_token = valid_scim_tokens[0]
else:
scim_token = models.Token(
token_id=gen_salt(48),
access_token=gen_salt(48),
subject=None,
audience=[client],
client=client,
refresh_token=gen_salt(48),
scope=["openid", "profile"],
issue_date=datetime.datetime.now(datetime.timezone.utc),
lifetime=3600,
)
Backend.instance.save(scim_token)
client_httpx = httpx_client( client_httpx = httpx_client(
base_url=client.client_uri, base_url=client.client_uri,
headers={ headers={"Authorization": f"Bearer {scim_token.access_token}"},
"Authorization": "Bearer NXeEceY820rnlzoh0FUxc4TFVKO5aAqikopiPEQacvL81ukk"
},
) )
scim = SyncSCIMClient(client_httpx) scim = SyncSCIMClient(client_httpx)
scim.discover() scim.discover()
@ -518,7 +539,6 @@ class User(Model):
req = SearchRequest(filter=f'userName eq "{self.user_name}"') req = SearchRequest(filter=f'userName eq "{self.user_name}"')
response = scim.query(User, search_request=req) response = scim.query(User, search_request=req)
if not response.resources: if not response.resources:
try: try:
scim.create(user) scim.create(user)
@ -536,7 +556,6 @@ class User(Model):
) )
req = SearchRequest(filter=f'userName eq "{self.user_name}"') req = SearchRequest(filter=f'userName eq "{self.user_name}"')
response = scim.query(User, search_request=req) response = scim.query(User, search_request=req)
print("response:", response)
def propagate_scim_delete(self): def propagate_scim_delete(self):
client = httpx_client( client = httpx_client(
@ -554,7 +573,13 @@ class User(Model):
def get_clients(self): def get_clients(self):
if self.id: if self.id:
consents = Backend.instance.query(models.Consent, subject=self) consents = Backend.instance.query(models.Consent, subject=self)
return {t.client for t in consents} consented_clients = {t.client for t in consents}
preconsented_clients = [
client
for client in Backend.instance.query(models.Client)
if client.preconsent and client not in consented_clients
]
return list(consented_clients) + list(preconsented_clients)
return [] return []

View file

@ -394,18 +394,20 @@ class IntrospectionEndpoint(_IntrospectionEndpoint):
def introspect_token(self, token): def introspect_token(self, token):
audience = [aud.client_id for aud in token.audience] audience = [aud.client_id for aud in token.audience]
return { response = {
"active": True, "active": True,
"client_id": token.client.client_id, "client_id": token.client.client_id,
"token_type": token.type, "token_type": token.type,
"username": token.subject.formatted_name,
"scope": token.get_scope(), "scope": token.get_scope(),
"sub": token.subject.user_name,
"aud": audience, "aud": audience,
"iss": get_issuer(), "iss": get_issuer(),
"exp": token.get_expires_at(), "exp": token.get_expires_at(),
"iat": token.get_issued_at(), "iat": token.get_issued_at(),
} }
if token.subject:
response["username"] = token.subject.formatted_name
response["sub"] = token.subject.user_name
return response
class ClientManagementMixin: class ClientManagementMixin:

View file

@ -1,17 +1,17 @@
import datetime
import json import json
import uuid import uuid
from http import HTTPStatus from http import HTTPStatus
from urllib.parse import urlsplit from urllib.parse import urlsplit
from urllib.parse import urlunsplit from urllib.parse import urlunsplit
import requests
from authlib.common.errors import AuthlibBaseError from authlib.common.errors import AuthlibBaseError
from authlib.integrations.flask_client import OAuth from authlib.integrations.flask_client import OAuth
from authlib.integrations.flask_oauth2 import ResourceProtector from authlib.integrations.flask_oauth2 import ResourceProtector
from authlib.integrations.flask_oauth2.errors import ( from authlib.integrations.flask_oauth2.errors import (
_HTTPException as AuthlibHTTPException, _HTTPException as AuthlibHTTPException,
) )
from authlib.oauth2.rfc6750 import BearerTokenValidator from authlib.oauth2.rfc7662 import IntrospectTokenValidator
from authlib.oidc.discovery import get_well_known_url from authlib.oidc.discovery import get_well_known_url
from flask import Blueprint from flask import Blueprint
from flask import Flask from flask import Flask
@ -37,7 +37,6 @@ from scim2_models import User
from werkzeug.exceptions import HTTPException from werkzeug.exceptions import HTTPException
from canaille import csrf from canaille import csrf
from canaille.oidc.models import Token
from canaille.scim.models import get_resource_types from canaille.scim.models import get_resource_types
from canaille.scim.models import get_schemas from canaille.scim.models import get_schemas
from canaille.scim.models import get_service_provider_config from canaille.scim.models import get_service_provider_config
@ -46,17 +45,17 @@ bp = Blueprint("scim", __name__)
oauth = OAuth() oauth = OAuth()
class SCIMBearerTokenValidator(BearerTokenValidator): class SCIMBearerTokenValidator(IntrospectTokenValidator):
def authenticate_token(self, token_string: str): def introspect_token(self, token_string: str):
token = Token() url = current_app.config["OAUTH_AUTH_SERVER"] + "/oauth/introspect"
token.token_id = "bidulos" data = {"token": token_string, "token_type_hint": "access_token"}
token.access_token = "NXeEceY820rnlzoh0FUxc4TFVKO5aAqikopiPEQacvL81ukk" auth = (
token.issue_date = datetime.datetime.now(datetime.timezone.utc) current_app.config["OAUTH_CLIENT_ID"],
token.lifetime = 600 current_app.config["OAUTH_CLIENT_SECRET"],
token.revokation_date = None )
token.scope = "openid profile email groups address phone" resp = requests.post(url, data=data, auth=auth)
backend.add_token(token) resp.raise_for_status()
return backend.get_token(token_string) return resp.json()
require_oauth = ResourceProtector() require_oauth = ResourceProtector()
@ -68,9 +67,9 @@ class ClientBackend:
groups: list[Group] = [] groups: list[Group] = []
tokens: list = [] tokens: list = []
def get_user_by_username(self, username): def get_user_by_id(self, id):
for user in self.users: for user in self.users:
if user.user_name == username: if user.id == id:
return user return user
return None return None
@ -79,9 +78,9 @@ class ClientBackend:
self.users.append(user) self.users.append(user)
def replace_user(self, user): def replace_user(self, user):
for saved_user in self.users: for i, saved_user in enumerate(self.users):
if saved_user.userName == user.userName: if saved_user.id == user.id:
saved_user = user self.users[i] = user
break break
def delete_user(self, user): def delete_user(self, user):
@ -105,15 +104,6 @@ class ClientBackend:
self.groups.remove(saved_group) self.groups.remove(saved_group)
break break
def add_token(self, token):
self.tokens.append(token)
def get_token(self, access_token):
for token in self.tokens:
if token.access_token == access_token:
return token
return None
backend = ClientBackend() backend = ClientBackend()
@ -205,11 +195,11 @@ def setup_routes(app):
) )
return payload return payload
@bp.route("/Users/<string:username>", methods=["GET"]) @bp.route("/Users/<string:id>", methods=["GET"])
@csrf.exempt @csrf.exempt
@require_oauth() @require_oauth()
def query_user(username): def query_user(id):
user = backend.get_user_by_username(username) user = backend.get_user_by_id(id)
if user: if user:
return user.model_dump( return user.model_dump(
scim_ctx=Context.RESOURCE_QUERY_RESPONSE, scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
@ -283,27 +273,28 @@ def setup_routes(app):
payload = user.model_dump_json(scim_ctx=Context.RESOURCE_CREATION_RESPONSE) payload = user.model_dump_json(scim_ctx=Context.RESOURCE_CREATION_RESPONSE)
return Response(payload, status=HTTPStatus.CREATED) return Response(payload, status=HTTPStatus.CREATED)
@bp.route("/Users/<string:username>", methods=["PUT"]) @bp.route("/Users/<string:id>", methods=["PUT"])
@csrf.exempt @csrf.exempt
@require_oauth() @require_oauth()
def replace_user(username): def replace_user(id):
user = backend.get_user_by_username(username) user = backend.get_user_by_id(id)
request_scim_user = User[EnterpriseUser].model_validate( request_scim_user = User[EnterpriseUser].model_validate(
request.json, request.json,
scim_ctx=Context.RESOURCE_REPLACEMENT_REQUEST, scim_ctx=Context.RESOURCE_REPLACEMENT_REQUEST,
original=user, original=user,
) )
request_scim_user.id = user.id
backend.replace_user(request_scim_user) backend.replace_user(request_scim_user)
payload = request_scim_user.model_dump( payload = request_scim_user.model_dump(
scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE
) )
return payload return payload
@bp.route("/Users/<string:username>", methods=["DELETE"]) @bp.route("/Users/<string:id>", methods=["DELETE"])
@csrf.exempt @csrf.exempt
@require_oauth() @require_oauth()
def delete_user(username): def delete_user(id):
user = backend.get_user_by_username(username) user = backend.get_user_by_id(id)
backend.delete_user(user) backend.delete_user(user)
return "", HTTPStatus.NO_CONTENT return "", HTTPStatus.NO_CONTENT

View file

@ -34,6 +34,7 @@ dependencies = [
"flask-wtf >= 1.2.1", "flask-wtf >= 1.2.1",
"httpx>=0.28.1", "httpx>=0.28.1",
"pydantic-settings >= 2.0.3", "pydantic-settings >= 2.0.3",
"q>=2.7",
"requests>=2.32.3", "requests>=2.32.3",
"wtforms >= 3.1.1", "wtforms >= 3.1.1",
] ]

11
uv.lock
View file

@ -143,6 +143,7 @@ dependencies = [
{ name = "flask-wtf" }, { name = "flask-wtf" },
{ name = "httpx" }, { name = "httpx" },
{ name = "pydantic-settings" }, { name = "pydantic-settings" },
{ name = "q" },
{ name = "requests" }, { name = "requests" },
{ name = "wtforms" }, { name = "wtforms" },
] ]
@ -255,6 +256,7 @@ requires-dist = [
{ name = "pydantic-settings", specifier = ">=2.0.3" }, { name = "pydantic-settings", specifier = ">=2.0.3" },
{ name = "python-ldap", marker = "extra == 'ldap'", specifier = ">=3.4.0" }, { name = "python-ldap", marker = "extra == 'ldap'", specifier = ">=3.4.0" },
{ name = "pytz", marker = "extra == 'front'", specifier = ">=2022.7" }, { name = "pytz", marker = "extra == 'front'", specifier = ">=2022.7" },
{ name = "q", specifier = ">=2.7" },
{ name = "qrcode", marker = "extra == 'otp'", specifier = ">=8.0" }, { name = "qrcode", marker = "extra == 'otp'", specifier = ">=8.0" },
{ name = "requests", specifier = ">=2.32.3" }, { name = "requests", specifier = ">=2.32.3" },
{ name = "scim2-models", marker = "extra == 'scim'", specifier = ">=0.2.2" }, { name = "scim2-models", marker = "extra == 'scim'", specifier = ">=0.2.2" },
@ -1659,6 +1661,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446 }, { url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446 },
] ]
[[package]]
name = "q"
version = "2.7"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/15/90/2649ecc3b4b335e62de4a0c3762c7cd7b2f77a023c5c00649f549cebb56c/q-2.7.tar.gz", hash = "sha256:8e0b792f6658ab9e1133b5ea17af1b530530e60124cf9743bc0fa051b8c64f4e", size = 7946 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/75/f0/ae942c0530d02092702211fd36d9a465e203f732789c84d0b96fbebe3039/q-2.7-py2.py3-none-any.whl", hash = "sha256:8388a3ef7e79b3b6224189e44ddba8dc1a6e9ed3212ce96f83f6056fa532459c", size = 10390 },
]
[[package]] [[package]]
name = "qrcode" name = "qrcode"
version = "8.0" version = "8.0"