diff --git a/canaille/ldap_backend/backend.py b/canaille/ldap_backend/backend.py index 41b7f178..48e8fcc5 100644 --- a/canaille/ldap_backend/backend.py +++ b/canaille/ldap_backend/backend.py @@ -18,7 +18,9 @@ def setup_ldap_models(config): user_base = config["LDAP"]["USER_BASE"].replace(f',{config["LDAP"]["ROOT_DN"]}', "") User.base = user_base - User.rdn = config["LDAP"].get("USER_ID_ATTRIBUTE", User.DEFAULT_ID_ATTRIBUTE) + User.rdn_attribute = config["LDAP"].get( + "USER_ID_ATTRIBUTE", User.DEFAULT_ID_ATTRIBUTE + ) User.object_class = [config["LDAP"].get("USER_CLASS", User.DEFAULT_OBJECT_CLASS)] group_base = ( @@ -27,7 +29,9 @@ def setup_ldap_models(config): .replace(f',{config["LDAP"]["ROOT_DN"]}', "") ) Group.base = group_base or None - Group.rdn = config["LDAP"].get("GROUP_ID_ATTRIBUTE", Group.DEFAULT_ID_ATTRIBUTE) + Group.rdn_attribute = config["LDAP"].get( + "GROUP_ID_ATTRIBUTE", Group.DEFAULT_ID_ATTRIBUTE + ) Group.object_class = [config["LDAP"].get("GROUP_CLASS", Group.DEFAULT_OBJECT_CLASS)] diff --git a/canaille/ldap_backend/ldapobject.py b/canaille/ldap_backend/ldapobject.py index 637d5fa2..d3670aa2 100644 --- a/canaille/ldap_backend/ldapobject.py +++ b/canaille/ldap_backend/ldapobject.py @@ -13,7 +13,7 @@ class LDAPObject: _must = None base = None root_dn = None - rdn = None + rdn_attribute = None attribute_table = None object_class = None @@ -26,8 +26,7 @@ class LDAPObject: setattr(self, name, value) def __repr__(self): - rdn = getattr(self, self.rdn, "?") - return f"<{self.__class__.__name__} {self.rdn}={rdn}>" + return f"<{self.__class__.__name__} {self.rdn_attribute}={self.rdn_value}>" def __eq__(self, other): return ( @@ -77,13 +76,14 @@ class LDAPObject: def __setitem__(self, item, value): return setattr(self, item, value) + @property + def rdn_value(self): + value = getattr(self, self.rdn_attribute) + return (value[0] if isinstance(value, list) else value).strip() + @property def dn(self): - if self.rdn in self.changes: - rdn = self.changes[self.rdn][0] - else: - rdn = self.attrs[self.rdn][0] - return f"{self.rdn}={ldap.dn.escape_dn_chars(rdn.strip())},{self.base},{self.root_dn}" + return f"{self.rdn_attribute}={ldap.dn.escape_dn_chars(self.rdn_value)},{self.base},{self.root_dn}" def may(self): if not self._may: @@ -205,7 +205,7 @@ class LDAPObject: if base is None: base = f"{cls.base},{cls.root_dn}" elif "=" not in base: - base = f"{cls.rdn}={base},{cls.base},{cls.root_dn}" + base = f"{cls.rdn_attribute}={base},{cls.base},{cls.root_dn}" class_filter = ( "".join([f"(objectClass={oc})" for oc in cls.object_class]) diff --git a/canaille/oidc/models.py b/canaille/oidc/models.py index 51ca6255..f75c7c5c 100644 --- a/canaille/oidc/models.py +++ b/canaille/oidc/models.py @@ -10,7 +10,7 @@ from canaille.ldap_backend.ldapobject import LDAPObject class Client(LDAPObject, ClientMixin): object_class = ["oauthClient"] base = "ou=clients,ou=oauth" - rdn = "oauthClientID" + rdn_attribute = "oauthClientID" client_info_attributes = { "client_id": "oauthClientID", @@ -111,7 +111,7 @@ class Client(LDAPObject, ClientMixin): class AuthorizationCode(LDAPObject, AuthorizationCodeMixin): object_class = ["oauthAuthorizationCode"] base = "ou=authorizations,ou=oauth" - rdn = "oauthAuthorizationCodeID" + rdn_attribute = "oauthAuthorizationCodeID" attribute_table = { "authorization_code_id": "oauthAuthorizationCodeID", "description": "description", @@ -151,7 +151,7 @@ class AuthorizationCode(LDAPObject, AuthorizationCodeMixin): class Token(LDAPObject, TokenMixin): object_class = ["oauthToken"] base = "ou=tokens,ou=oauth" - rdn = "oauthTokenID" + rdn_attribute = "oauthTokenID" attribute_table = { "token_id": "oauthTokenID", "access_token": "oauthAccessToken", @@ -212,7 +212,7 @@ class Token(LDAPObject, TokenMixin): class Consent(LDAPObject): object_class = ["oauthConsent"] base = "ou=consents,ou=oauth" - rdn = "cn" + rdn_attribute = "cn" attribute_table = { "cn": "cn", "subject": "oauthSubject", diff --git a/tests/ldap/test_utils.py b/tests/ldap/test_utils.py index 2897fc62..36c0ef53 100644 --- a/tests/ldap/test_utils.py +++ b/tests/ldap/test_utils.py @@ -10,8 +10,8 @@ from canaille.models import User def test_repr(slapd_connection, foo_group, user): - assert repr(foo_group) == "" - assert repr(user) == "" + assert repr(foo_group) == "" + assert repr(user) == "" def test_equality(slapd_connection, foo_group, bar_group):