forked from Github-Mirrors/canaille
feat: Added proper SCIM access token verification by client
This commit is contained in:
parent
6e64f51ad4
commit
6c1557cf27
6 changed files with 78 additions and 49 deletions
|
@ -111,7 +111,6 @@ class User(canaille.core.models.User, Base, SqlAlchemyModel):
|
|||
one_time_password_emission_date: Mapped[datetime.datetime] = mapped_column(
|
||||
TZDateTime(timezone=True), nullable=True
|
||||
)
|
||||
scim_id: Mapped[str] = mapped_column(String, nullable=True, unique=True)
|
||||
|
||||
@property
|
||||
def password_failure_timestamps(self):
|
||||
|
|
|
@ -7,6 +7,7 @@ from flask import current_app
|
|||
from httpx import Client as httpx_client
|
||||
from scim2_client.engines.httpx import SyncSCIMClient
|
||||
from scim2_models import SearchRequest
|
||||
from werkzeug.security import gen_salt
|
||||
|
||||
from canaille.app import models
|
||||
from canaille.backends import Backend
|
||||
|
@ -284,8 +285,6 @@ class User(Model):
|
|||
one_time_password_emission_date: datetime.datetime | None = None
|
||||
"""A DateTime indicating when the user last emitted an email or sms one-time password."""
|
||||
|
||||
scim_id: str | None = None
|
||||
|
||||
_readable_fields = None
|
||||
_writable_fields = None
|
||||
_permissions = None
|
||||
|
@ -504,11 +503,33 @@ class User(Model):
|
|||
|
||||
def propagate_scim_changes(self):
|
||||
for client in self.get_clients():
|
||||
scim_tokens = Backend.instance.query(
|
||||
models.Token, client=client, subject=None
|
||||
)
|
||||
valid_scim_tokens = [
|
||||
token
|
||||
for token in scim_tokens
|
||||
if not token.is_expired() and not token.is_revoked()
|
||||
]
|
||||
if valid_scim_tokens:
|
||||
scim_token = valid_scim_tokens[0]
|
||||
else:
|
||||
scim_token = models.Token(
|
||||
token_id=gen_salt(48),
|
||||
access_token=gen_salt(48),
|
||||
subject=None,
|
||||
audience=[client],
|
||||
client=client,
|
||||
refresh_token=gen_salt(48),
|
||||
scope=["openid", "profile"],
|
||||
issue_date=datetime.datetime.now(datetime.timezone.utc),
|
||||
lifetime=3600,
|
||||
)
|
||||
Backend.instance.save(scim_token)
|
||||
|
||||
client_httpx = httpx_client(
|
||||
base_url=client.client_uri,
|
||||
headers={
|
||||
"Authorization": "Bearer NXeEceY820rnlzoh0FUxc4TFVKO5aAqikopiPEQacvL81ukk"
|
||||
},
|
||||
headers={"Authorization": f"Bearer {scim_token.access_token}"},
|
||||
)
|
||||
scim = SyncSCIMClient(client_httpx)
|
||||
scim.discover()
|
||||
|
@ -518,7 +539,6 @@ class User(Model):
|
|||
|
||||
req = SearchRequest(filter=f'userName eq "{self.user_name}"')
|
||||
response = scim.query(User, search_request=req)
|
||||
|
||||
if not response.resources:
|
||||
try:
|
||||
scim.create(user)
|
||||
|
@ -536,7 +556,6 @@ class User(Model):
|
|||
)
|
||||
req = SearchRequest(filter=f'userName eq "{self.user_name}"')
|
||||
response = scim.query(User, search_request=req)
|
||||
print("response:", response)
|
||||
|
||||
def propagate_scim_delete(self):
|
||||
client = httpx_client(
|
||||
|
@ -554,7 +573,13 @@ class User(Model):
|
|||
def get_clients(self):
|
||||
if self.id:
|
||||
consents = Backend.instance.query(models.Consent, subject=self)
|
||||
return {t.client for t in consents}
|
||||
consented_clients = {t.client for t in consents}
|
||||
preconsented_clients = [
|
||||
client
|
||||
for client in Backend.instance.query(models.Client)
|
||||
if client.preconsent and client not in consented_clients
|
||||
]
|
||||
return list(consented_clients) + list(preconsented_clients)
|
||||
return []
|
||||
|
||||
|
||||
|
|
|
@ -394,18 +394,20 @@ class IntrospectionEndpoint(_IntrospectionEndpoint):
|
|||
|
||||
def introspect_token(self, token):
|
||||
audience = [aud.client_id for aud in token.audience]
|
||||
return {
|
||||
response = {
|
||||
"active": True,
|
||||
"client_id": token.client.client_id,
|
||||
"token_type": token.type,
|
||||
"username": token.subject.formatted_name,
|
||||
"scope": token.get_scope(),
|
||||
"sub": token.subject.user_name,
|
||||
"aud": audience,
|
||||
"iss": get_issuer(),
|
||||
"exp": token.get_expires_at(),
|
||||
"iat": token.get_issued_at(),
|
||||
}
|
||||
if token.subject:
|
||||
response["username"] = token.subject.formatted_name
|
||||
response["sub"] = token.subject.user_name
|
||||
return response
|
||||
|
||||
|
||||
class ClientManagementMixin:
|
||||
|
|
|
@ -1,17 +1,17 @@
|
|||
import datetime
|
||||
import json
|
||||
import uuid
|
||||
from http import HTTPStatus
|
||||
from urllib.parse import urlsplit
|
||||
from urllib.parse import urlunsplit
|
||||
|
||||
import requests
|
||||
from authlib.common.errors import AuthlibBaseError
|
||||
from authlib.integrations.flask_client import OAuth
|
||||
from authlib.integrations.flask_oauth2 import ResourceProtector
|
||||
from authlib.integrations.flask_oauth2.errors import (
|
||||
_HTTPException as AuthlibHTTPException,
|
||||
)
|
||||
from authlib.oauth2.rfc6750 import BearerTokenValidator
|
||||
from authlib.oauth2.rfc7662 import IntrospectTokenValidator
|
||||
from authlib.oidc.discovery import get_well_known_url
|
||||
from flask import Blueprint
|
||||
from flask import Flask
|
||||
|
@ -37,7 +37,6 @@ from scim2_models import User
|
|||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
from canaille import csrf
|
||||
from canaille.oidc.models import Token
|
||||
from canaille.scim.models import get_resource_types
|
||||
from canaille.scim.models import get_schemas
|
||||
from canaille.scim.models import get_service_provider_config
|
||||
|
@ -46,17 +45,17 @@ bp = Blueprint("scim", __name__)
|
|||
oauth = OAuth()
|
||||
|
||||
|
||||
class SCIMBearerTokenValidator(BearerTokenValidator):
|
||||
def authenticate_token(self, token_string: str):
|
||||
token = Token()
|
||||
token.token_id = "bidulos"
|
||||
token.access_token = "NXeEceY820rnlzoh0FUxc4TFVKO5aAqikopiPEQacvL81ukk"
|
||||
token.issue_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
token.lifetime = 600
|
||||
token.revokation_date = None
|
||||
token.scope = "openid profile email groups address phone"
|
||||
backend.add_token(token)
|
||||
return backend.get_token(token_string)
|
||||
class SCIMBearerTokenValidator(IntrospectTokenValidator):
|
||||
def introspect_token(self, token_string: str):
|
||||
url = current_app.config["OAUTH_AUTH_SERVER"] + "/oauth/introspect"
|
||||
data = {"token": token_string, "token_type_hint": "access_token"}
|
||||
auth = (
|
||||
current_app.config["OAUTH_CLIENT_ID"],
|
||||
current_app.config["OAUTH_CLIENT_SECRET"],
|
||||
)
|
||||
resp = requests.post(url, data=data, auth=auth)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
require_oauth = ResourceProtector()
|
||||
|
@ -68,9 +67,9 @@ class ClientBackend:
|
|||
groups: list[Group] = []
|
||||
tokens: list = []
|
||||
|
||||
def get_user_by_username(self, username):
|
||||
def get_user_by_id(self, id):
|
||||
for user in self.users:
|
||||
if user.user_name == username:
|
||||
if user.id == id:
|
||||
return user
|
||||
return None
|
||||
|
||||
|
@ -79,9 +78,9 @@ class ClientBackend:
|
|||
self.users.append(user)
|
||||
|
||||
def replace_user(self, user):
|
||||
for saved_user in self.users:
|
||||
if saved_user.userName == user.userName:
|
||||
saved_user = user
|
||||
for i, saved_user in enumerate(self.users):
|
||||
if saved_user.id == user.id:
|
||||
self.users[i] = user
|
||||
break
|
||||
|
||||
def delete_user(self, user):
|
||||
|
@ -105,15 +104,6 @@ class ClientBackend:
|
|||
self.groups.remove(saved_group)
|
||||
break
|
||||
|
||||
def add_token(self, token):
|
||||
self.tokens.append(token)
|
||||
|
||||
def get_token(self, access_token):
|
||||
for token in self.tokens:
|
||||
if token.access_token == access_token:
|
||||
return token
|
||||
return None
|
||||
|
||||
|
||||
backend = ClientBackend()
|
||||
|
||||
|
@ -205,11 +195,11 @@ def setup_routes(app):
|
|||
)
|
||||
return payload
|
||||
|
||||
@bp.route("/Users/<string:username>", methods=["GET"])
|
||||
@bp.route("/Users/<string:id>", methods=["GET"])
|
||||
@csrf.exempt
|
||||
@require_oauth()
|
||||
def query_user(username):
|
||||
user = backend.get_user_by_username(username)
|
||||
def query_user(id):
|
||||
user = backend.get_user_by_id(id)
|
||||
if user:
|
||||
return user.model_dump(
|
||||
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
|
||||
|
@ -283,27 +273,28 @@ def setup_routes(app):
|
|||
payload = user.model_dump_json(scim_ctx=Context.RESOURCE_CREATION_RESPONSE)
|
||||
return Response(payload, status=HTTPStatus.CREATED)
|
||||
|
||||
@bp.route("/Users/<string:username>", methods=["PUT"])
|
||||
@bp.route("/Users/<string:id>", methods=["PUT"])
|
||||
@csrf.exempt
|
||||
@require_oauth()
|
||||
def replace_user(username):
|
||||
user = backend.get_user_by_username(username)
|
||||
def replace_user(id):
|
||||
user = backend.get_user_by_id(id)
|
||||
request_scim_user = User[EnterpriseUser].model_validate(
|
||||
request.json,
|
||||
scim_ctx=Context.RESOURCE_REPLACEMENT_REQUEST,
|
||||
original=user,
|
||||
)
|
||||
request_scim_user.id = user.id
|
||||
backend.replace_user(request_scim_user)
|
||||
payload = request_scim_user.model_dump(
|
||||
scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE
|
||||
)
|
||||
return payload
|
||||
|
||||
@bp.route("/Users/<string:username>", methods=["DELETE"])
|
||||
@bp.route("/Users/<string:id>", methods=["DELETE"])
|
||||
@csrf.exempt
|
||||
@require_oauth()
|
||||
def delete_user(username):
|
||||
user = backend.get_user_by_username(username)
|
||||
def delete_user(id):
|
||||
user = backend.get_user_by_id(id)
|
||||
backend.delete_user(user)
|
||||
return "", HTTPStatus.NO_CONTENT
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@ dependencies = [
|
|||
"flask-wtf >= 1.2.1",
|
||||
"httpx>=0.28.1",
|
||||
"pydantic-settings >= 2.0.3",
|
||||
"q>=2.7",
|
||||
"requests>=2.32.3",
|
||||
"wtforms >= 3.1.1",
|
||||
]
|
||||
|
|
11
uv.lock
11
uv.lock
|
@ -143,6 +143,7 @@ dependencies = [
|
|||
{ name = "flask-wtf" },
|
||||
{ name = "httpx" },
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "q" },
|
||||
{ name = "requests" },
|
||||
{ name = "wtforms" },
|
||||
]
|
||||
|
@ -255,6 +256,7 @@ requires-dist = [
|
|||
{ name = "pydantic-settings", specifier = ">=2.0.3" },
|
||||
{ name = "python-ldap", marker = "extra == 'ldap'", specifier = ">=3.4.0" },
|
||||
{ name = "pytz", marker = "extra == 'front'", specifier = ">=2022.7" },
|
||||
{ name = "q", specifier = ">=2.7" },
|
||||
{ name = "qrcode", marker = "extra == 'otp'", specifier = ">=8.0" },
|
||||
{ name = "requests", specifier = ">=2.32.3" },
|
||||
{ name = "scim2-models", marker = "extra == 'scim'", specifier = ">=0.2.2" },
|
||||
|
@ -1659,6 +1661,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "q"
|
||||
version = "2.7"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/15/90/2649ecc3b4b335e62de4a0c3762c7cd7b2f77a023c5c00649f549cebb56c/q-2.7.tar.gz", hash = "sha256:8e0b792f6658ab9e1133b5ea17af1b530530e60124cf9743bc0fa051b8c64f4e", size = 7946 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/75/f0/ae942c0530d02092702211fd36d9a465e203f732789c84d0b96fbebe3039/q-2.7-py2.py3-none-any.whl", hash = "sha256:8388a3ef7e79b3b6224189e44ddba8dc1a6e9ed3212ce96f83f6056fa532459c", size = 10390 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "qrcode"
|
||||
version = "8.0"
|
||||
|
|
Loading…
Reference in a new issue