diff --git a/canaille/backends/ldap/ldapobject.py b/canaille/backends/ldap/ldapobject.py index 1787a589..79a3c0e7 100644 --- a/canaille/backends/ldap/ldapobject.py +++ b/canaille/backends/ldap/ldapobject.py @@ -174,9 +174,6 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass): ldap_name = self.python_attribute_to_ldap(name) - if ldap_name == "dn": - return self.dn_for(self.rdn_value) - python_single_value = typing.get_origin(self.attributes[name]) is not list ldap_value = self.get_ldap_attribute(ldap_name) return cardinalize_attribute(python_single_value, ldap_value) @@ -306,18 +303,18 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass): return cls._attribute_type_by_name @classmethod - def get(cls, id=None, filter=None, **kwargs): + def get(cls, dn=None, filter=None, **kwargs): try: - return cls.query(id, filter, **kwargs)[0] + return cls.query(dn, filter, **kwargs)[0] except (IndexError, ldap.NO_SUCH_OBJECT): return None @classmethod - def query(cls, id=None, filter=None, **kwargs): + def query(cls, dn=None, filter=None, **kwargs): conn = Backend.get().connection - base = id or kwargs.get("id") - if base is None: + base = dn + if dn is None: base = f"{cls.base},{cls.root_dn}" elif "=" not in base: base = ldap.dn.escape_dn_chars(base) @@ -330,15 +327,17 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass): ) if class_filter: class_filter = f"(|{class_filter})" + arg_filter = "" - kwargs = python_attrs_to_ldap( + ldap_args = python_attrs_to_ldap( { cls.python_attribute_to_ldap(name): values for name, values in kwargs.items() + if values is not None }, encode=False, ) - for key, value in kwargs.items(): + for key, value in ldap_args.items(): if len(value) == 1: escaped_value = ldap.filter.escape_filter_chars(value[0]) arg_filter += f"({key}={escaped_value})" @@ -420,7 +419,6 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass): attributes = ["objectClass"] + [ self.python_attribute_to_ldap(name) for name in self.attributes ] - attributes.remove("dn") read_post_control = PostReadControl(criticality=True, attrList=attributes) # Object already exists in the LDAP database diff --git a/canaille/backends/ldap/models.py b/canaille/backends/ldap/models.py index 12df8b7e..5bf8c57c 100644 --- a/canaille/backends/ldap/models.py +++ b/canaille/backends/ldap/models.py @@ -13,7 +13,7 @@ from .ldapobject import LDAPObject class User(canaille.core.models.User, LDAPObject): attribute_map = { - "id": "dn", + "id": "entryUUID", "created": "createTimestamp", "last_modified": "modifyTimestamp", "user_name": "uid", @@ -59,7 +59,10 @@ class User(canaille.core.models.User, LDAPObject): if isinstance(filter_, dict): # not super generic, but how can we improve this? ¯\_(ツ)_/¯ if "groups" in filter_ and "=" not in filter_.get("groups"): - filter_["groups"] = Group.dn_for(filter_["groups"]) + group_by_id = Group.get(id=filter_["groups"]) + filter_["groups"] = ( + group_by_id.dn if group_by_id else Group.dn_for(filter_["groups"]) + ) base = "".join( f"({cls.python_attribute_to_ldap(key)}={value})" @@ -150,7 +153,7 @@ class User(canaille.core.models.User, LDAPObject): old_groups = self.state.get(group_attr) or [] new_groups = [ - v if isinstance(v, Group) else Group.get(id=v) for v in new_groups + v if isinstance(v, Group) else Group.get(dn=v) for v in new_groups ] to_add = set(new_groups) - set(old_groups) to_del = set(old_groups) - set(new_groups) @@ -186,7 +189,7 @@ class User(canaille.core.models.User, LDAPObject): class Group(canaille.core.models.Group, LDAPObject): attribute_map = { - "id": "dn", + "id": "entryUUID", "created": "createTimestamp", "last_modified": "modifyTimestamp", "display_name": "cn", @@ -198,11 +201,6 @@ class Group(canaille.core.models.Group, LDAPObject): def identifier(self): return self.rdn_value - @property - def display_name(self): - attribute = current_app.config["CANAILLE_LDAP"]["GROUP_NAME_ATTRIBUTE"] - return getattr(self, attribute)[0] - class Client(canaille.oidc.models.Client, LDAPObject): ldap_object_class = ["oauthClient"] @@ -235,7 +233,7 @@ class Client(canaille.oidc.models.Client, LDAPObject): } attribute_map = { - "id": "dn", + "id": "entryUUID", "created": "createTimestamp", "last_modified": "modifyTimestamp", "preconsent": "oauthPreconsent", @@ -256,7 +254,7 @@ class AuthorizationCode(canaille.oidc.models.AuthorizationCode, LDAPObject): base = "ou=authorizations,ou=oauth" rdn_attribute = "oauthAuthorizationCodeID" attribute_map = { - "id": "dn", + "id": "entryUUID", "created": "createTimestamp", "last_modified": "modifyTimestamp", "authorization_code_id": "oauthAuthorizationCodeID", @@ -284,7 +282,7 @@ class Token(canaille.oidc.models.Token, LDAPObject): base = "ou=tokens,ou=oauth" rdn_attribute = "oauthTokenID" attribute_map = { - "id": "dn", + "id": "entryUUID", "created": "createTimestamp", "last_modified": "modifyTimestamp", "token_id": "oauthTokenID", @@ -310,7 +308,7 @@ class Consent(canaille.oidc.models.Consent, LDAPObject): base = "ou=consents,ou=oauth" rdn_attribute = "cn" attribute_map = { - "id": "dn", + "id": "entryUUID", "created": "createTimestamp", "last_modified": "modifyTimestamp", "consent_id": "cn", diff --git a/canaille/backends/ldap/utils.py b/canaille/backends/ldap/utils.py index 50eca69c..e0d7e8a6 100644 --- a/canaille/backends/ldap/utils.py +++ b/canaille/backends/ldap/utils.py @@ -50,7 +50,7 @@ def ldap_to_python(value, syntax): return value.decode("utf-8").upper() == "TRUE" if syntax == Syntax.DISTINGUISHED_NAME: - return LDAPObject.get(id=value.decode("utf-8")) + return LDAPObject.get(dn=value.decode("utf-8")) return value.decode("utf-8") @@ -75,7 +75,7 @@ def python_to_ldap(value, syntax, encode=True): value = "TRUE" if value else "FALSE" if syntax == Syntax.DISTINGUISHED_NAME: - value = value.id if value else None + value = value.dn if value else None if not value: return None diff --git a/canaille/backends/memory/models.py b/canaille/backends/memory/models.py index 347c5905..56528f0c 100644 --- a/canaille/backends/memory/models.py +++ b/canaille/backends/memory/models.py @@ -19,7 +19,6 @@ class MemoryModel(BackendModel): """Associates attribute values and ids.""" def __init__(self, *args, **kwargs): - kwargs.setdefault("id", str(uuid.uuid4())) self._state = {} self._cache = {} for attribute, value in kwargs.items(): @@ -116,6 +115,9 @@ class MemoryModel(BackendModel): return value def save(self): + if not self.id: + self.id = str(uuid.uuid4()) + self.last_modified = datetime.datetime.now(datetime.timezone.utc).replace( microsecond=0 ) diff --git a/canaille/oidc/endpoints/clients.py b/canaille/oidc/endpoints/clients.py index 1836fc0a..2cc541af 100644 --- a/canaille/oidc/endpoints/clients.py +++ b/canaille/oidc/endpoints/clients.py @@ -73,6 +73,7 @@ def add(user): if form["token_endpoint_auth_method"].data == "none" else gen_salt(48), ) + client.save() client.audience = [client] client.save() flash( diff --git a/canaille/oidc/oauth.py b/canaille/oidc/oauth.py index 0b6b6919..2151e25b 100644 --- a/canaille/oidc/oauth.py +++ b/canaille/oidc/oauth.py @@ -456,6 +456,7 @@ class ClientRegistrationEndpoint(ClientManagementMixin, _ClientRegistrationEndpo post_logout_redirect_uris=request.data.get("post_logout_redirect_uris"), **self.client_convert_data(**client_info, **client_metadata), ) + client.save() client.audience = [client] client.save() return client diff --git a/demo/demoapp.py b/demo/demoapp.py index 17c78ad6..c325e850 100644 --- a/demo/demoapp.py +++ b/demo/demoapp.py @@ -119,6 +119,7 @@ def populate(app): response_types=["code", "id_token"], token_endpoint_auth_method="client_secret_basic", ) + client1.save() client1.audience = [client1] client1.save() @@ -142,6 +143,7 @@ def populate(app): token_endpoint_auth_method="client_secret_basic", preconsent=True, ) + client2.save() client2.audience = [client2] client2.save() diff --git a/tests/backends/ldap/test_permissions.py b/tests/backends/ldap/test_permissions.py new file mode 100644 index 00000000..f2e5d4ff --- /dev/null +++ b/tests/backends/ldap/test_permissions.py @@ -0,0 +1,9 @@ +def test_group_permissions_by_dn(testclient, user, foo_group): + assert not user.can_manage_users + + testclient.app.config["CANAILLE"]["ACL"]["ADMIN"]["FILTER"] = { + "groups": foo_group.dn + } + user.reload() + + assert user.can_manage_users diff --git a/tests/backends/ldap/test_utils.py b/tests/backends/ldap/test_utils.py index df11997a..872fc257 100644 --- a/tests/backends/ldap/test_utils.py +++ b/tests/backends/ldap/test_utils.py @@ -48,14 +48,14 @@ def test_dn_when_leading_space_in_id_attribute(testclient, backend): ) user.save() - dn = user.id + dn = user.dn assert dn == "uid=user,ou=users,dc=mydomain,dc=tld" assert ldap.dn.is_dn(dn) assert ldap.dn.dn2str(ldap.dn.str2dn(dn)) == dn assert user == models.User.get(user.identifier) assert user == models.User.get(user_name=user.identifier) - assert user == models.User.get(id=dn) + assert user == models.User.get(dn=dn) user.delete() @@ -69,14 +69,14 @@ def test_special_chars_in_rdn(testclient, backend): ) user.save() - dn = user.id + dn = user.dn assert ldap.dn.is_dn(dn) assert ldap.dn.dn2str(ldap.dn.str2dn(dn)) == dn assert dn == "uid=\\#user,ou=users,dc=mydomain,dc=tld" assert user == models.User.get(user.identifier) assert user == models.User.get(user_name=user.identifier) - assert user == models.User.get(id=dn) + assert user == models.User.get(dn=dn) user.delete() @@ -184,13 +184,13 @@ def test_operational_attribute_conversion(backend): def test_guess_object_from_dn(backend, testclient, foo_group): foo_group.members = [foo_group] foo_group.save() - dn = foo_group.id - g = LDAPObject.get(id=dn) + dn = foo_group.dn + g = LDAPObject.get(dn=dn) assert isinstance(g, models.Group) assert g == foo_group assert g.display_name == foo_group.display_name - ou = LDAPObject.get(id=f"{models.Group.base},{models.Group.root_dn}") + ou = LDAPObject.get(dn=f"{models.Group.base},{models.Group.root_dn}") assert isinstance(ou, LDAPObject) diff --git a/tests/backends/test_models.py b/tests/backends/test_models.py index ec84b87f..9d64adcf 100644 --- a/tests/backends/test_models.py +++ b/tests/backends/test_models.py @@ -35,6 +35,7 @@ def test_model_lifecycle(testclient, backend): formatted_name="formatted_name", ) + assert not user.id assert not models.User.query() assert not models.User.query(id=user.id) assert not models.User.query(id="invalid") diff --git a/tests/oidc/conftest.py b/tests/oidc/conftest.py index 5310fb11..f4cc958d 100644 --- a/tests/oidc/conftest.py +++ b/tests/oidc/conftest.py @@ -67,6 +67,7 @@ def client(testclient, trusted_client, backend): token_endpoint_auth_method="client_secret_basic", post_logout_redirect_uris=["https://mydomain.tld/disconnected"], ) + c.save() c.audience = [c, trusted_client] c.save() @@ -104,6 +105,7 @@ def trusted_client(testclient, backend): post_logout_redirect_uris=["https://myotherdomain.tld/disconnected"], preconsent=True, ) + c.save() c.audience = [c] c.save()