From 5d9a41f18b735901f6e76cf73a6eec3923befe5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89loi=20Rivard?= Date: Wed, 8 Mar 2023 00:53:27 +0100 Subject: [PATCH] Delayed LDAPObject may and must initialization --- canaille/account.py | 4 +- canaille/ldap_backend/ldapobject.py | 37 +++++++++++-------- .../oidc/admin/authorization_view.html | 4 +- tests/oidc/test_code_admin.py | 2 +- 4 files changed, 27 insertions(+), 20 deletions(-) diff --git a/canaille/account.py b/canaille/account.py index 24fa3115..beaf7bf8 100644 --- a/canaille/account.py +++ b/canaille/account.py @@ -366,7 +366,7 @@ def registration(data, hash): def profile_create(current_app, form): user = User() for attribute in form: - if attribute.name in user.may + user.must: + if attribute.name in user.may() + user.must(): if isinstance(attribute.data, FileStorage): data = attribute.data.stream.read() else: @@ -477,7 +477,7 @@ def profile_edit(editor, username): else: for attribute in form: if ( - attribute.name in user.may + user.must + attribute.name in user.may() + user.must() and attribute.name in editor.write ): if isinstance(attribute.data, FileStorage): diff --git a/canaille/ldap_backend/ldapobject.py b/canaille/ldap_backend/ldapobject.py index 146ce3a7..637d5fa2 100644 --- a/canaille/ldap_backend/ldapobject.py +++ b/canaille/ldap_backend/ldapobject.py @@ -9,8 +9,8 @@ from .utils import python_to_ldap class LDAPObject: _object_class_by_name = None _attribute_type_by_name = None - may = None - must = None + _may = None + _must = None base = None root_dn = None rdn = None @@ -25,9 +25,6 @@ class LDAPObject: for name, value in kwargs.items(): setattr(self, name, value) - if not self.may and not self.must: - self.update_ldap_attributes() - def __repr__(self): rdn = getattr(self, self.rdn, "?") return f"<{self.__class__.__name__} {self.rdn}={rdn}>" @@ -35,11 +32,11 @@ class LDAPObject: def __eq__(self, other): return ( isinstance(other, self.__class__) - and self.may == other.may - and self.must == other.must + and self.may() == other.may() + and self.must() == other.must() and all( getattr(self, attr) == getattr(other, attr) - for attr in self.may + self.must + for attr in self.may() + self.must() if hasattr(self, attr) and hasattr(other, attr) ) ) @@ -88,6 +85,16 @@ class LDAPObject: rdn = self.attrs[self.rdn][0] return f"{self.rdn}={ldap.dn.escape_dn_chars(rdn.strip())},{self.base},{self.root_dn}" + def may(self): + if not self._may: + self.update_ldap_attributes() + return self._may + + def must(self): + if not self._must: + self.update_ldap_attributes() + return self._must + @classmethod def ldap_connection(cls): return g.ldap_connection @@ -246,8 +253,8 @@ class LDAPObject: this_object_classes = {all_object_classes[name] for name in cls.object_class} done = set() - cls.may = [] - cls.must = [] + cls._may = [] + cls._must = [] while len(this_object_classes) > 0: object_class = this_object_classes.pop() done.add(object_class) @@ -256,11 +263,11 @@ class LDAPObject: for ocsup in object_class.sup if ocsup not in done } - cls.may.extend(object_class.may) - cls.must.extend(object_class.must) + cls._may.extend(object_class.may) + cls._must.extend(object_class.must) - cls.may = list(set(cls.may)) - cls.must = list(set(cls.must)) + cls._may = list(set(cls._may)) + cls._must = list(set(cls._must)) def reload(self, conn=None): conn = conn or self.ldap_connection() @@ -325,6 +332,6 @@ class LDAPObject: conn.delete_s(self.dn) def keys(self): - ldap_keys = self.must + self.may + ldap_keys = self.must() + self.may() inverted_table = {value: key for key, value in self.attribute_table.items()} return [inverted_table.get(key, key) for key in ldap_keys] diff --git a/canaille/templates/oidc/admin/authorization_view.html b/canaille/templates/oidc/admin/authorization_view.html index 048e6be0..b0fc4c91 100644 --- a/canaille/templates/oidc/admin/authorization_view.html +++ b/canaille/templates/oidc/admin/authorization_view.html @@ -9,10 +9,10 @@
diff --git a/tests/oidc/test_code_admin.py b/tests/oidc/test_code_admin.py index 1fc36cc0..204870db 100644 --- a/tests/oidc/test_code_admin.py +++ b/tests/oidc/test_code_admin.py @@ -13,5 +13,5 @@ def test_authorizaton_list(testclient, authorization, logged_admin): def test_authorizaton_view(testclient, authorization, logged_admin): res = testclient.get("/admin/authorization/" + authorization.authorization_code_id) - for attr in authorization.may + authorization.must: + for attr in authorization.may() + authorization.must(): assert attr in res.text