forked from Github-Mirrors/canaille
feat: Added proper SCIM access token verification by client
This commit is contained in:
parent
6e64f51ad4
commit
6c1557cf27
6 changed files with 78 additions and 49 deletions
|
@ -111,7 +111,6 @@ class User(canaille.core.models.User, Base, SqlAlchemyModel):
|
||||||
one_time_password_emission_date: Mapped[datetime.datetime] = mapped_column(
|
one_time_password_emission_date: Mapped[datetime.datetime] = mapped_column(
|
||||||
TZDateTime(timezone=True), nullable=True
|
TZDateTime(timezone=True), nullable=True
|
||||||
)
|
)
|
||||||
scim_id: Mapped[str] = mapped_column(String, nullable=True, unique=True)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def password_failure_timestamps(self):
|
def password_failure_timestamps(self):
|
||||||
|
|
|
@ -7,6 +7,7 @@ from flask import current_app
|
||||||
from httpx import Client as httpx_client
|
from httpx import Client as httpx_client
|
||||||
from scim2_client.engines.httpx import SyncSCIMClient
|
from scim2_client.engines.httpx import SyncSCIMClient
|
||||||
from scim2_models import SearchRequest
|
from scim2_models import SearchRequest
|
||||||
|
from werkzeug.security import gen_salt
|
||||||
|
|
||||||
from canaille.app import models
|
from canaille.app import models
|
||||||
from canaille.backends import Backend
|
from canaille.backends import Backend
|
||||||
|
@ -284,8 +285,6 @@ class User(Model):
|
||||||
one_time_password_emission_date: datetime.datetime | None = None
|
one_time_password_emission_date: datetime.datetime | None = None
|
||||||
"""A DateTime indicating when the user last emitted an email or sms one-time password."""
|
"""A DateTime indicating when the user last emitted an email or sms one-time password."""
|
||||||
|
|
||||||
scim_id: str | None = None
|
|
||||||
|
|
||||||
_readable_fields = None
|
_readable_fields = None
|
||||||
_writable_fields = None
|
_writable_fields = None
|
||||||
_permissions = None
|
_permissions = None
|
||||||
|
@ -504,11 +503,33 @@ class User(Model):
|
||||||
|
|
||||||
def propagate_scim_changes(self):
|
def propagate_scim_changes(self):
|
||||||
for client in self.get_clients():
|
for client in self.get_clients():
|
||||||
|
scim_tokens = Backend.instance.query(
|
||||||
|
models.Token, client=client, subject=None
|
||||||
|
)
|
||||||
|
valid_scim_tokens = [
|
||||||
|
token
|
||||||
|
for token in scim_tokens
|
||||||
|
if not token.is_expired() and not token.is_revoked()
|
||||||
|
]
|
||||||
|
if valid_scim_tokens:
|
||||||
|
scim_token = valid_scim_tokens[0]
|
||||||
|
else:
|
||||||
|
scim_token = models.Token(
|
||||||
|
token_id=gen_salt(48),
|
||||||
|
access_token=gen_salt(48),
|
||||||
|
subject=None,
|
||||||
|
audience=[client],
|
||||||
|
client=client,
|
||||||
|
refresh_token=gen_salt(48),
|
||||||
|
scope=["openid", "profile"],
|
||||||
|
issue_date=datetime.datetime.now(datetime.timezone.utc),
|
||||||
|
lifetime=3600,
|
||||||
|
)
|
||||||
|
Backend.instance.save(scim_token)
|
||||||
|
|
||||||
client_httpx = httpx_client(
|
client_httpx = httpx_client(
|
||||||
base_url=client.client_uri,
|
base_url=client.client_uri,
|
||||||
headers={
|
headers={"Authorization": f"Bearer {scim_token.access_token}"},
|
||||||
"Authorization": "Bearer NXeEceY820rnlzoh0FUxc4TFVKO5aAqikopiPEQacvL81ukk"
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
scim = SyncSCIMClient(client_httpx)
|
scim = SyncSCIMClient(client_httpx)
|
||||||
scim.discover()
|
scim.discover()
|
||||||
|
@ -518,7 +539,6 @@ class User(Model):
|
||||||
|
|
||||||
req = SearchRequest(filter=f'userName eq "{self.user_name}"')
|
req = SearchRequest(filter=f'userName eq "{self.user_name}"')
|
||||||
response = scim.query(User, search_request=req)
|
response = scim.query(User, search_request=req)
|
||||||
|
|
||||||
if not response.resources:
|
if not response.resources:
|
||||||
try:
|
try:
|
||||||
scim.create(user)
|
scim.create(user)
|
||||||
|
@ -536,7 +556,6 @@ class User(Model):
|
||||||
)
|
)
|
||||||
req = SearchRequest(filter=f'userName eq "{self.user_name}"')
|
req = SearchRequest(filter=f'userName eq "{self.user_name}"')
|
||||||
response = scim.query(User, search_request=req)
|
response = scim.query(User, search_request=req)
|
||||||
print("response:", response)
|
|
||||||
|
|
||||||
def propagate_scim_delete(self):
|
def propagate_scim_delete(self):
|
||||||
client = httpx_client(
|
client = httpx_client(
|
||||||
|
@ -554,7 +573,13 @@ class User(Model):
|
||||||
def get_clients(self):
|
def get_clients(self):
|
||||||
if self.id:
|
if self.id:
|
||||||
consents = Backend.instance.query(models.Consent, subject=self)
|
consents = Backend.instance.query(models.Consent, subject=self)
|
||||||
return {t.client for t in consents}
|
consented_clients = {t.client for t in consents}
|
||||||
|
preconsented_clients = [
|
||||||
|
client
|
||||||
|
for client in Backend.instance.query(models.Client)
|
||||||
|
if client.preconsent and client not in consented_clients
|
||||||
|
]
|
||||||
|
return list(consented_clients) + list(preconsented_clients)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -394,18 +394,20 @@ class IntrospectionEndpoint(_IntrospectionEndpoint):
|
||||||
|
|
||||||
def introspect_token(self, token):
|
def introspect_token(self, token):
|
||||||
audience = [aud.client_id for aud in token.audience]
|
audience = [aud.client_id for aud in token.audience]
|
||||||
return {
|
response = {
|
||||||
"active": True,
|
"active": True,
|
||||||
"client_id": token.client.client_id,
|
"client_id": token.client.client_id,
|
||||||
"token_type": token.type,
|
"token_type": token.type,
|
||||||
"username": token.subject.formatted_name,
|
|
||||||
"scope": token.get_scope(),
|
"scope": token.get_scope(),
|
||||||
"sub": token.subject.user_name,
|
|
||||||
"aud": audience,
|
"aud": audience,
|
||||||
"iss": get_issuer(),
|
"iss": get_issuer(),
|
||||||
"exp": token.get_expires_at(),
|
"exp": token.get_expires_at(),
|
||||||
"iat": token.get_issued_at(),
|
"iat": token.get_issued_at(),
|
||||||
}
|
}
|
||||||
|
if token.subject:
|
||||||
|
response["username"] = token.subject.formatted_name
|
||||||
|
response["sub"] = token.subject.user_name
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
class ClientManagementMixin:
|
class ClientManagementMixin:
|
||||||
|
|
|
@ -1,17 +1,17 @@
|
||||||
import datetime
|
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from urllib.parse import urlsplit
|
from urllib.parse import urlsplit
|
||||||
from urllib.parse import urlunsplit
|
from urllib.parse import urlunsplit
|
||||||
|
|
||||||
|
import requests
|
||||||
from authlib.common.errors import AuthlibBaseError
|
from authlib.common.errors import AuthlibBaseError
|
||||||
from authlib.integrations.flask_client import OAuth
|
from authlib.integrations.flask_client import OAuth
|
||||||
from authlib.integrations.flask_oauth2 import ResourceProtector
|
from authlib.integrations.flask_oauth2 import ResourceProtector
|
||||||
from authlib.integrations.flask_oauth2.errors import (
|
from authlib.integrations.flask_oauth2.errors import (
|
||||||
_HTTPException as AuthlibHTTPException,
|
_HTTPException as AuthlibHTTPException,
|
||||||
)
|
)
|
||||||
from authlib.oauth2.rfc6750 import BearerTokenValidator
|
from authlib.oauth2.rfc7662 import IntrospectTokenValidator
|
||||||
from authlib.oidc.discovery import get_well_known_url
|
from authlib.oidc.discovery import get_well_known_url
|
||||||
from flask import Blueprint
|
from flask import Blueprint
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
|
@ -37,7 +37,6 @@ from scim2_models import User
|
||||||
from werkzeug.exceptions import HTTPException
|
from werkzeug.exceptions import HTTPException
|
||||||
|
|
||||||
from canaille import csrf
|
from canaille import csrf
|
||||||
from canaille.oidc.models import Token
|
|
||||||
from canaille.scim.models import get_resource_types
|
from canaille.scim.models import get_resource_types
|
||||||
from canaille.scim.models import get_schemas
|
from canaille.scim.models import get_schemas
|
||||||
from canaille.scim.models import get_service_provider_config
|
from canaille.scim.models import get_service_provider_config
|
||||||
|
@ -46,17 +45,17 @@ bp = Blueprint("scim", __name__)
|
||||||
oauth = OAuth()
|
oauth = OAuth()
|
||||||
|
|
||||||
|
|
||||||
class SCIMBearerTokenValidator(BearerTokenValidator):
|
class SCIMBearerTokenValidator(IntrospectTokenValidator):
|
||||||
def authenticate_token(self, token_string: str):
|
def introspect_token(self, token_string: str):
|
||||||
token = Token()
|
url = current_app.config["OAUTH_AUTH_SERVER"] + "/oauth/introspect"
|
||||||
token.token_id = "bidulos"
|
data = {"token": token_string, "token_type_hint": "access_token"}
|
||||||
token.access_token = "NXeEceY820rnlzoh0FUxc4TFVKO5aAqikopiPEQacvL81ukk"
|
auth = (
|
||||||
token.issue_date = datetime.datetime.now(datetime.timezone.utc)
|
current_app.config["OAUTH_CLIENT_ID"],
|
||||||
token.lifetime = 600
|
current_app.config["OAUTH_CLIENT_SECRET"],
|
||||||
token.revokation_date = None
|
)
|
||||||
token.scope = "openid profile email groups address phone"
|
resp = requests.post(url, data=data, auth=auth)
|
||||||
backend.add_token(token)
|
resp.raise_for_status()
|
||||||
return backend.get_token(token_string)
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
require_oauth = ResourceProtector()
|
require_oauth = ResourceProtector()
|
||||||
|
@ -68,9 +67,9 @@ class ClientBackend:
|
||||||
groups: list[Group] = []
|
groups: list[Group] = []
|
||||||
tokens: list = []
|
tokens: list = []
|
||||||
|
|
||||||
def get_user_by_username(self, username):
|
def get_user_by_id(self, id):
|
||||||
for user in self.users:
|
for user in self.users:
|
||||||
if user.user_name == username:
|
if user.id == id:
|
||||||
return user
|
return user
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -79,9 +78,9 @@ class ClientBackend:
|
||||||
self.users.append(user)
|
self.users.append(user)
|
||||||
|
|
||||||
def replace_user(self, user):
|
def replace_user(self, user):
|
||||||
for saved_user in self.users:
|
for i, saved_user in enumerate(self.users):
|
||||||
if saved_user.userName == user.userName:
|
if saved_user.id == user.id:
|
||||||
saved_user = user
|
self.users[i] = user
|
||||||
break
|
break
|
||||||
|
|
||||||
def delete_user(self, user):
|
def delete_user(self, user):
|
||||||
|
@ -105,15 +104,6 @@ class ClientBackend:
|
||||||
self.groups.remove(saved_group)
|
self.groups.remove(saved_group)
|
||||||
break
|
break
|
||||||
|
|
||||||
def add_token(self, token):
|
|
||||||
self.tokens.append(token)
|
|
||||||
|
|
||||||
def get_token(self, access_token):
|
|
||||||
for token in self.tokens:
|
|
||||||
if token.access_token == access_token:
|
|
||||||
return token
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
backend = ClientBackend()
|
backend = ClientBackend()
|
||||||
|
|
||||||
|
@ -205,11 +195,11 @@ def setup_routes(app):
|
||||||
)
|
)
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
@bp.route("/Users/<string:username>", methods=["GET"])
|
@bp.route("/Users/<string:id>", methods=["GET"])
|
||||||
@csrf.exempt
|
@csrf.exempt
|
||||||
@require_oauth()
|
@require_oauth()
|
||||||
def query_user(username):
|
def query_user(id):
|
||||||
user = backend.get_user_by_username(username)
|
user = backend.get_user_by_id(id)
|
||||||
if user:
|
if user:
|
||||||
return user.model_dump(
|
return user.model_dump(
|
||||||
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
|
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
|
||||||
|
@ -283,27 +273,28 @@ def setup_routes(app):
|
||||||
payload = user.model_dump_json(scim_ctx=Context.RESOURCE_CREATION_RESPONSE)
|
payload = user.model_dump_json(scim_ctx=Context.RESOURCE_CREATION_RESPONSE)
|
||||||
return Response(payload, status=HTTPStatus.CREATED)
|
return Response(payload, status=HTTPStatus.CREATED)
|
||||||
|
|
||||||
@bp.route("/Users/<string:username>", methods=["PUT"])
|
@bp.route("/Users/<string:id>", methods=["PUT"])
|
||||||
@csrf.exempt
|
@csrf.exempt
|
||||||
@require_oauth()
|
@require_oauth()
|
||||||
def replace_user(username):
|
def replace_user(id):
|
||||||
user = backend.get_user_by_username(username)
|
user = backend.get_user_by_id(id)
|
||||||
request_scim_user = User[EnterpriseUser].model_validate(
|
request_scim_user = User[EnterpriseUser].model_validate(
|
||||||
request.json,
|
request.json,
|
||||||
scim_ctx=Context.RESOURCE_REPLACEMENT_REQUEST,
|
scim_ctx=Context.RESOURCE_REPLACEMENT_REQUEST,
|
||||||
original=user,
|
original=user,
|
||||||
)
|
)
|
||||||
|
request_scim_user.id = user.id
|
||||||
backend.replace_user(request_scim_user)
|
backend.replace_user(request_scim_user)
|
||||||
payload = request_scim_user.model_dump(
|
payload = request_scim_user.model_dump(
|
||||||
scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE
|
scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE
|
||||||
)
|
)
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
@bp.route("/Users/<string:username>", methods=["DELETE"])
|
@bp.route("/Users/<string:id>", methods=["DELETE"])
|
||||||
@csrf.exempt
|
@csrf.exempt
|
||||||
@require_oauth()
|
@require_oauth()
|
||||||
def delete_user(username):
|
def delete_user(id):
|
||||||
user = backend.get_user_by_username(username)
|
user = backend.get_user_by_id(id)
|
||||||
backend.delete_user(user)
|
backend.delete_user(user)
|
||||||
return "", HTTPStatus.NO_CONTENT
|
return "", HTTPStatus.NO_CONTENT
|
||||||
|
|
||||||
|
|
|
@ -34,6 +34,7 @@ dependencies = [
|
||||||
"flask-wtf >= 1.2.1",
|
"flask-wtf >= 1.2.1",
|
||||||
"httpx>=0.28.1",
|
"httpx>=0.28.1",
|
||||||
"pydantic-settings >= 2.0.3",
|
"pydantic-settings >= 2.0.3",
|
||||||
|
"q>=2.7",
|
||||||
"requests>=2.32.3",
|
"requests>=2.32.3",
|
||||||
"wtforms >= 3.1.1",
|
"wtforms >= 3.1.1",
|
||||||
]
|
]
|
||||||
|
|
11
uv.lock
11
uv.lock
|
@ -143,6 +143,7 @@ dependencies = [
|
||||||
{ name = "flask-wtf" },
|
{ name = "flask-wtf" },
|
||||||
{ name = "httpx" },
|
{ name = "httpx" },
|
||||||
{ name = "pydantic-settings" },
|
{ name = "pydantic-settings" },
|
||||||
|
{ name = "q" },
|
||||||
{ name = "requests" },
|
{ name = "requests" },
|
||||||
{ name = "wtforms" },
|
{ name = "wtforms" },
|
||||||
]
|
]
|
||||||
|
@ -255,6 +256,7 @@ requires-dist = [
|
||||||
{ name = "pydantic-settings", specifier = ">=2.0.3" },
|
{ name = "pydantic-settings", specifier = ">=2.0.3" },
|
||||||
{ name = "python-ldap", marker = "extra == 'ldap'", specifier = ">=3.4.0" },
|
{ name = "python-ldap", marker = "extra == 'ldap'", specifier = ">=3.4.0" },
|
||||||
{ name = "pytz", marker = "extra == 'front'", specifier = ">=2022.7" },
|
{ name = "pytz", marker = "extra == 'front'", specifier = ">=2022.7" },
|
||||||
|
{ name = "q", specifier = ">=2.7" },
|
||||||
{ name = "qrcode", marker = "extra == 'otp'", specifier = ">=8.0" },
|
{ name = "qrcode", marker = "extra == 'otp'", specifier = ">=8.0" },
|
||||||
{ name = "requests", specifier = ">=2.32.3" },
|
{ name = "requests", specifier = ">=2.32.3" },
|
||||||
{ name = "scim2-models", marker = "extra == 'scim'", specifier = ">=0.2.2" },
|
{ name = "scim2-models", marker = "extra == 'scim'", specifier = ">=0.2.2" },
|
||||||
|
@ -1659,6 +1661,15 @@ wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446 },
|
{ url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "q"
|
||||||
|
version = "2.7"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/15/90/2649ecc3b4b335e62de4a0c3762c7cd7b2f77a023c5c00649f549cebb56c/q-2.7.tar.gz", hash = "sha256:8e0b792f6658ab9e1133b5ea17af1b530530e60124cf9743bc0fa051b8c64f4e", size = 7946 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/75/f0/ae942c0530d02092702211fd36d9a465e203f732789c84d0b96fbebe3039/q-2.7-py2.py3-none-any.whl", hash = "sha256:8388a3ef7e79b3b6224189e44ddba8dc1a6e9ed3212ce96f83f6056fa532459c", size = 10390 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "qrcode"
|
name = "qrcode"
|
||||||
version = "8.0"
|
version = "8.0"
|
||||||
|
|
Loading…
Reference in a new issue