diff --git a/canaille/groups.py b/canaille/groups.py index 382abece..4a90d506 100644 --- a/canaille/groups.py +++ b/canaille/groups.py @@ -32,7 +32,7 @@ def create_group(user): flash(_("Group creation failed."), "error") else: group = Group() - group.member = [user.dn] + group.member = [user] group.cn = [form.name.data] group.description = [form.description.data] group.save() diff --git a/canaille/ldap_backend/backend.py b/canaille/ldap_backend/backend.py index 48e8fcc5..a65b23d6 100644 --- a/canaille/ldap_backend/backend.py +++ b/canaille/ldap_backend/backend.py @@ -155,7 +155,7 @@ def validate_configuration(config): group = Group( cn=f"canaille_{uuid.uuid4()}", - member=[user.dn], + member=[user], ) group.save(conn) group.delete(conn) diff --git a/canaille/ldap_backend/ldapobject.py b/canaille/ldap_backend/ldapobject.py index 60c99e19..97381176 100644 --- a/canaille/ldap_backend/ldapobject.py +++ b/canaille/ldap_backend/ldapobject.py @@ -6,7 +6,24 @@ from .utils import ldap_to_python from .utils import python_to_ldap -class LDAPObject: +class LDAPObjectMetaclass(type): + ldap_to_python_class = {} + + def __new__(cls, name, bases, attrs): + klass = super().__new__(cls, name, bases, attrs) + if attrs.get("object_class"): + for oc in attrs["object_class"]: + cls.ldap_to_python_class[oc] = klass + return klass + + def __setattr__(cls, name, value): + super().__setattr__(name, value) + if name == "object_class": + for oc in value: + cls.ldap_to_python_class[oc] = cls + + +class LDAPObject(metaclass=LDAPObjectMetaclass): _object_class_by_name = None _attribute_type_by_name = None _may = None @@ -22,23 +39,44 @@ class LDAPObject: self.changes = {} kwargs.setdefault("objectClass", self.object_class) + for name, value in kwargs.items(): setattr(self, name, value) def __repr__(self): - return f"<{self.__class__.__name__} {self.rdn_attribute}={self.rdn_value}>" + return ( + f"<{self.__class__.__name__} {self.rdn_attribute}={self.rdn_value}>" + if self.rdn_attribute + else "" + ) def __eq__(self, other): - return ( + if not ( isinstance(other, self.__class__) and self.may() == other.may() and self.must() == other.must() and all( - getattr(self, attr) == getattr(other, attr) + hasattr(self, attr) == hasattr(other, attr) for attr in self.may() + self.must() - if hasattr(self, attr) and hasattr(other, attr) ) + ): + return False + + self_attributes = self.python_attrs_to_ldap( + { + attr: getattr(self, attr) + for attr in self.may() + self.must() + if hasattr(self, attr) + } ) + other_attributes = other.python_attrs_to_ldap( + { + attr: getattr(other, attr) + for attr in self.may() + self.must() + if hasattr(self, attr) + } + ) + return self_attributes == other_attributes def __hash__(self): return hash(self.dn) @@ -59,8 +97,7 @@ class LDAPObject: # Lazy conversion from ldap format to python format if any(isinstance(value, bytes) for value in self.attrs[name]): - ldap_attrs = LDAPObject.ldap_object_attributes() - syntax = ldap_attrs[name].syntax if name in ldap_attrs else None + syntax = self.attribute_ldap_syntax(name) self.attrs[name] = [ ldap_to_python(value, syntax) for value in self.attrs[name] ] @@ -102,8 +139,6 @@ class LDAPObject: return self._may def must(self): - if not self._must: - self.update_ldap_attributes() return self._must @classmethod @@ -178,7 +213,6 @@ class LDAPObject: @classmethod def python_attrs_to_ldap(cls, attrs, encode=True): - ldap_attrs = LDAPObject.ldap_object_attributes() if cls.attribute_table: attrs = { cls.attribute_table.get(name, name): values @@ -186,16 +220,23 @@ class LDAPObject: } return { name: [ - python_to_ldap( - value, - ldap_attrs[name].syntax if name in ldap_attrs else None, - encode=encode, - ) + python_to_ldap(value, cls.attribute_ldap_syntax(name), encode=encode) for value in (values if isinstance(values, list) else [values]) ] for name, values in attrs.items() } + @classmethod + def attribute_ldap_syntax(cls, attribute_name): + ldap_attrs = LDAPObject.ldap_object_attributes() + if attribute_name not in ldap_attrs: + return None + + if ldap_attrs[attribute_name].syntax: + return ldap_attrs[attribute_name].syntax + + return cls.attribute_ldap_syntax(ldap_attrs[attribute_name].sup[0]) + @classmethod def get(cls, dn=None, filter=None, conn=None, **kwargs): try: @@ -241,11 +282,20 @@ class LDAPObject: objects = [] for _, args in result: + cls = cls.guess_class(args["objectClass"]) obj = cls() obj.attrs = args objects.append(obj) return objects + @classmethod + def guess_class(cls, object_classes): + if cls == LDAPObject: + for oc in object_classes: + if oc.decode() in LDAPObjectMetaclass.ldap_to_python_class: + return LDAPObjectMetaclass.ldap_to_python_class[oc.decode()] + return cls + @classmethod def update_ldap_attributes(cls): all_object_classes = cls.ldap_object_classes() diff --git a/canaille/ldap_backend/utils.py b/canaille/ldap_backend/utils.py index 45f33ee0..2d9d0cad 100644 --- a/canaille/ldap_backend/utils.py +++ b/canaille/ldap_backend/utils.py @@ -6,22 +6,26 @@ LDAP_NULL_DATE = "000001010000Z" class Syntax(str, Enum): # fmt: off - BOOLEAN = "1.3.6.1.4.1.1466.115.121.1.7" - DIRECTORY_STRING = "1.3.6.1.4.1.1466.115.121.1.15" - FAX_IMAGE = "1.3.6.1.4.1.1466.115.121.1.23" - GENERALIZED_TIME = "1.3.6.1.4.1.1466.115.121.1.24" - IA5_STRING = "1.3.6.1.4.1.1466.115.121.1.26" - INTEGER = "1.3.6.1.4.1.1466.115.121.1.27" - JPEG = "1.3.6.1.4.1.1466.115.121.1.28" - NUMERIC_STRING = "1.3.6.1.4.1.1466.115.121.1.36" - OCTET_STRING = "1.3.6.1.4.1.1466.115.121.1.40" - POSTAL_ADDRESS = "1.3.6.1.4.1.1466.115.121.1.41" - PRINTABLE_STRING = "1.3.6.1.4.1.1466.115.121.1.44" - TELEPHONE_NUMBER = "1.3.6.1.4.1.1466.115.121.1.50" + BINARY = "1.3.6.1.4.1.1466.115.121.1.5" + BOOLEAN = "1.3.6.1.4.1.1466.115.121.1.7" + DISTINGUISHED_NAME = "1.3.6.1.4.1.1466.115.121.1.12" + DIRECTORY_STRING = "1.3.6.1.4.1.1466.115.121.1.15" + FAX_IMAGE = "1.3.6.1.4.1.1466.115.121.1.23" + GENERALIZED_TIME = "1.3.6.1.4.1.1466.115.121.1.24" + IA5_STRING = "1.3.6.1.4.1.1466.115.121.1.26" + INTEGER = "1.3.6.1.4.1.1466.115.121.1.27" + JPEG = "1.3.6.1.4.1.1466.115.121.1.28" + NUMERIC_STRING = "1.3.6.1.4.1.1466.115.121.1.36" + OCTET_STRING = "1.3.6.1.4.1.1466.115.121.1.40" + POSTAL_ADDRESS = "1.3.6.1.4.1.1466.115.121.1.41" + PRINTABLE_STRING = "1.3.6.1.4.1.1466.115.121.1.44" + TELEPHONE_NUMBER = "1.3.6.1.4.1.1466.115.121.1.50" # fmt: on def ldap_to_python(value, syntax): + from .ldapobject import LDAPObject + if syntax == Syntax.GENERALIZED_TIME: value = value.decode("utf-8") if value == LDAP_NULL_DATE: @@ -39,6 +43,9 @@ def ldap_to_python(value, syntax): if syntax == Syntax.BOOLEAN: return value.decode("utf-8").upper() == "TRUE" + if syntax == Syntax.DISTINGUISHED_NAME: + return LDAPObject.get(dn=value.decode("utf-8")) + return value.decode("utf-8") @@ -59,6 +66,9 @@ def python_to_ldap(value, syntax, encode=True): if syntax == Syntax.BOOLEAN and isinstance(value, bool): value = "TRUE" if value else "FALSE" + if syntax == Syntax.DISTINGUISHED_NAME: + value = value.dn + if not value: return None diff --git a/canaille/models.py b/canaille/models.py index ae0d77b8..75b6d3d1 100644 --- a/canaille/models.py +++ b/canaille/models.py @@ -190,16 +190,12 @@ class Group(LDAPObject): return self[attribute][0] def get_members(self, conn=None): - return [ - User.get(dn=user_id, conn=conn) - for user_id in self.member - if User.get(dn=user_id, conn=conn) - ] + return [member for member in self.member if member] def add_member(self, user): - self.member = self.member + [user.dn] + self.member = self.member + [user] self.save() def remove_member(self, user): - self.member = [m for m in self.member if m != user.dn] + self.member = [m for m in self.member if m != user] self.save() diff --git a/canaille/oidc/clients.py b/canaille/oidc/clients.py index 75a36f96..b05b3948 100644 --- a/canaille/oidc/clients.py +++ b/canaille/oidc/clients.py @@ -71,7 +71,7 @@ def add(user): if form["token_endpoint_auth_method"].data == "none" else gen_salt(48), ) - client.audience = [client.dn] + client.audience = [client] client.save() flash( _("The client has been created."), @@ -142,7 +142,7 @@ def client_edit(client_id): software_version=form["software_version"].data, jwk=form["jwk"].data, jwks_uri=form["jwks_uri"].data, - audience=form["audience"].data, + audience=[Client.get(dn=dn) for dn in form["audience"].data], preconsent=form["preconsent"].data, ) client.save() diff --git a/canaille/oidc/consents.py b/canaille/oidc/consents.py index 6c726f7d..03902431 100644 --- a/canaille/oidc/consents.py +++ b/canaille/oidc/consents.py @@ -20,19 +20,17 @@ bp = Blueprint("consents", __name__, url_prefix="/consent") @bp.route("/") @user_needed() def consents(user): - consents = Consent.query(subject=user.dn) - client_dns = {t.client for t in consents} - clients = {dn: Client.get(dn) for dn in client_dns} + consents = Consent.query(subject=user) + clients = {t.client for t in consents} preconsented = [ client for client in Client.query() - if client.preconsent and client.dn not in clients + if client.preconsent and client not in clients ] return render_template( "oidc/user/consent_list.html", consents=consents, - clients=clients, menuitem="consents", scope_details=SCOPE_DETAILS, ignored_scopes=["openid"], @@ -45,7 +43,7 @@ def consents(user): def revoke(user, consent_id): consent = Consent.get(consent_id) - if not consent or consent.subject != user.dn: + if not consent or consent.subject != user: flash(_("Could not revoke this access"), "error") elif consent.revokation_date: @@ -63,7 +61,7 @@ def revoke(user, consent_id): def restore(user, consent_id): consent = Consent.get(consent_id) - if not consent or consent.subject != user.dn: + if not consent or consent.subject != user: flash(_("Could not restore this access"), "error") elif not consent.revokation_date: @@ -88,14 +86,14 @@ def revoke_preconsent(user, client_id): flash(_("Could not revoke this access"), "error") return redirect(url_for("oidc.consents.consents")) - consent = Consent.get(client=client.dn, subject=user.dn) + consent = Consent.get(client=client, subject=user) if consent: return redirect(url_for("oidc.consents.revoke", consent_id=consent.cn[0])) consent = Consent( cn=str(uuid.uuid4()), - client=client.dn, - subject=user.dn, + client=client, + subject=user, scope=client.scope, ) consent.revoke() diff --git a/canaille/oidc/endpoints.py b/canaille/oidc/endpoints.py index 2701a6ba..11b9e7cd 100644 --- a/canaille/oidc/endpoints.py +++ b/canaille/oidc/endpoints.py @@ -90,8 +90,8 @@ def authorize(): # CONSENT consents = Consent.query( - client=client.dn, - subject=user.dn, + client=client, + subject=user, ) consent = consents[0] if consents else None @@ -101,7 +101,7 @@ def authorize(): or (consent and all(scope in set(consent.scope) for scope in scopes)) and not consent.revoked ): - return authorization.create_authorization_response(grant_user=user.dn) + return authorization.create_authorization_response(grant_user=user) elif request.args.get("prompt") == "none": response = {"error": "consent_required"} @@ -135,7 +135,7 @@ def authorize(): grant_user = None if request.form["answer"] == "accept": - grant_user = user.dn + grant_user = user if consent: if consent.revoked: @@ -146,8 +146,8 @@ def authorize(): else: consent = Consent( cn=str(uuid.uuid4()), - client=client.dn, - subject=user.dn, + client=client, + subject=user, scope=scopes, issue_date=datetime.datetime.now(), ) diff --git a/canaille/oidc/models.py b/canaille/oidc/models.py index f75c7c5c..c3ce43a4 100644 --- a/canaille/oidc/models.py +++ b/canaille/oidc/models.py @@ -96,13 +96,13 @@ class Client(LDAPObject, ClientMixin): return metadata def delete(self): - for consent in Consent.query(client=self.dn): + for consent in Consent.query(client=self): consent.delete() - for code in AuthorizationCode.query(client=self.dn): + for code in AuthorizationCode.query(client=self): code.delete() - for token in Token.query(client=self.dn): + for token in Token.query(client=self): token.delete() super().delete() @@ -206,7 +206,7 @@ class Token(LDAPObject, TokenMixin): return bool(self.revokation_date) def check_client(self, client): - return client.client_id == Client.get(self.client).client_id + return client.client_id == self.client.client_id class Consent(LDAPObject): diff --git a/canaille/oidc/oauth.py b/canaille/oidc/oauth.py index 56898383..f9ac5d67 100644 --- a/canaille/oidc/oauth.py +++ b/canaille/oidc/oauth.py @@ -43,7 +43,8 @@ AUTHORIZATION_CODE_LIFETIME = 84400 def exists_nonce(nonce, req): - exists = AuthorizationCode.query(client=req.client_id, nonce=nonce) + client = Client.get(dn=req.client_id) + exists = AuthorizationCode.query(client=client, nonce=nonce) return bool(exists) @@ -98,7 +99,6 @@ def claims_from_scope(scope): def generate_user_info(user, scope): - user = User.get(dn=user) claims = claims_from_scope(scope) data = generate_user_claims(user, claims) return UserInfo(**data) @@ -131,7 +131,7 @@ def save_authorization_code(code, request): authorization_code_id=gen_salt(48), code=code, subject=request.user, - client=request.client.dn, + client=request.client, redirect_uri=request.redirect_uri or request.client.redirect_uris[0], scope=scope, nonce=nonce, @@ -151,7 +151,7 @@ class AuthorizationCodeGrant(_AuthorizationCodeGrant): return save_authorization_code(code, request) def query_authorization_code(self, code, client): - item = AuthorizationCode.query(code=code, client=client.dn) + item = AuthorizationCode.query(code=code, client=client) if item and not item[0].is_expired(): return item[0] @@ -159,9 +159,7 @@ class AuthorizationCodeGrant(_AuthorizationCodeGrant): authorization_code.delete() def authenticate_user(self, authorization_code): - user = User.get(dn=authorization_code.subject) - if user: - return user.dn + return authorization_code.subject class OpenIDCode(_OpenIDCode): @@ -176,16 +174,14 @@ class OpenIDCode(_OpenIDCode): def get_audiences(self, request): client = request.client - return [Client.get(aud).client_id for aud in client.audience] + return [aud.client_id for aud in client.audience] class PasswordGrant(_ResourceOwnerPasswordCredentialsGrant): TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"] def authenticate_user(self, username, password): - user = User.authenticate(username, password) - if user: - return user.dn + return User.authenticate(username, password) class RefreshTokenGrant(_RefreshTokenGrant): @@ -195,9 +191,7 @@ class RefreshTokenGrant(_RefreshTokenGrant): return token[0] def authenticate_user(self, credential): - user = User.get(dn=credential.subject) - if user: - return user.dn + return credential.subject def revoke_old_credential(self, credential): credential.revokation_date = datetime.datetime.now() @@ -216,7 +210,7 @@ class OpenIDImplicitGrant(_OpenIDImplicitGrant): def get_audiences(self, request): client = request.client - return [Client.get(aud).client_id for aud in client.audience] + return [aud.client_id for aud in client.audience] class OpenIDHybridGrant(_OpenIDHybridGrant): @@ -234,7 +228,7 @@ class OpenIDHybridGrant(_OpenIDHybridGrant): def get_audiences(self, request): client = request.client - return [Client.get(aud).client_id for aud in client.audience] + return [aud.client_id for aud in client.audience] def query_client(client_id): @@ -250,7 +244,7 @@ def save_token(token, request): issue_date=now, lifetime=token["expires_in"], scope=token["scope"], - client=request.client.dn, + client=request.client, refresh_token=token.get("refresh_token"), subject=request.user, audience=request.client.audience, @@ -294,19 +288,17 @@ class IntrospectionEndpoint(_IntrospectionEndpoint): return query_token(token, token_type_hint) def check_permission(self, token, client, request): - return client.dn in token.audience + return client in token.audience def introspect_token(self, token): - client_id = Client.get(token.client).client_id - user = User.get(dn=token.subject) - audience = [Client.get(aud).client_id for aud in token.audience] + audience = [aud.client_id for aud in token.audience] return { "active": True, - "client_id": client_id, + "client_id": token.client.client_id, "token_type": token.type, - "username": user.name, + "username": token.subject.name, "scope": token.get_scope(), - "sub": user.uid[0], + "sub": token.subject.uid[0], "aud": audience, "iss": get_issuer(), "exp": token.get_expires_at(), @@ -402,7 +394,7 @@ require_oauth = ResourceProtector() def generate_access_token(client, grant_type, user, scope): - audience = [Client.get(dn).client_id for dn in client.audience] + audience = [client.client_id for client in client.audience] bearer_token_generator = authorization._token_generators["default"] kwargs = { "token": {}, diff --git a/canaille/oidc/tokens.py b/canaille/oidc/tokens.py index 46b14242..90f143d7 100644 --- a/canaille/oidc/tokens.py +++ b/canaille/oidc/tokens.py @@ -20,11 +20,9 @@ bp = Blueprint("tokens", __name__, url_prefix="/admin/token") @permissions_needed("manage_oidc") def index(user): tokens = Token.query() - items = ( - (token, Client.get(token.client), User.get(dn=token.subject)) - for token in tokens + return render_template( + "oidc/admin/token_list.html", tokens=tokens, menuitem="admin" ) - return render_template("oidc/admin/token_list.html", items=items, menuitem="admin") @bp.route("/", methods=["GET", "POST"]) @@ -35,15 +33,9 @@ def view(user, token_id): if not token: abort(404) - token_client = Client.get(token.client) - token_user = User.get(dn=token.subject) - token_audience = [Client.get(aud) for aud in token.audience] return render_template( "oidc/admin/token_view.html", token=token, - token_client=token_client, - token_user=token_user, - token_audience=token_audience, menuitem="admin", ) diff --git a/canaille/populate.py b/canaille/populate.py index b6ae09af..ab89defc 100644 --- a/canaille/populate.py +++ b/canaille/populate.py @@ -45,7 +45,7 @@ def fake_groups(nb=1, nb_users_max=1): description=fake.sentence(), ) nb_users = random.randrange(1, nb_users_max + 1) - group.member = list({random.choice(users).dn for _ in range(nb_users)}) + group.member = list({random.choice(users) for _ in range(nb_users)}) group.save() groups.append(group) return groups diff --git a/canaille/templates/oidc/admin/token_list.html b/canaille/templates/oidc/admin/token_list.html index f526bcb0..f3f85468 100644 --- a/canaille/templates/oidc/admin/token_list.html +++ b/canaille/templates/oidc/admin/token_list.html @@ -20,7 +20,7 @@ {% trans %}Subject{% endtrans %} {% trans %}Created{% endtrans %} - {% for token, client, user in items %} + {% for token in tokens %} @@ -28,13 +28,13 @@ - - {{ client.client_name }} + + {{ token.client.client_name }} - - {{ user.uid[0] }} + + {{ token.subject.uid[0] }} {{ token.issue_date }} diff --git a/canaille/templates/oidc/admin/token_view.html b/canaille/templates/oidc/admin/token_view.html index 0bd6da8e..67fac098 100644 --- a/canaille/templates/oidc/admin/token_view.html +++ b/canaille/templates/oidc/admin/token_view.html @@ -33,16 +33,16 @@ {{ _("Client") }} - - {{ token_client.client_name }} + + {{ token.client.client_name }} {{ _("Subject") }} - - {{ token_user.name }} - {{ token_user.uid[0] }} + + {{ token.subject.name }} - {{ token.subject.uid[0] }} @@ -60,7 +60,7 @@ {{ _("Audience") }}