refactor: model attributes are walked from the top to the bottom

This commit is contained in:
Éloi Rivard 2024-04-06 22:46:11 +02:00
parent fe809161ff
commit 75837fa207
No known key found for this signature in database
GPG key ID: 7EDA204EA57DD184
4 changed files with 15 additions and 16 deletions

View file

@ -100,7 +100,6 @@ class LDAPObjectQuery:
def guess_class(self, klass, object_classes):
if klass == LDAPObject:
for oc in object_classes:
if oc.decode() in LDAPObjectMetaclass.ldap_to_python_class:
return LDAPObjectMetaclass.ldap_to_python_class[oc.decode()]
return klass

View file

@ -41,6 +41,16 @@ class Model:
the value MUST be the same as the value of :attr:`~canaille.backends.models.Model.created`.
"""
@classproperty
def attributes(cls):
return ChainMap(
*(
klass.__annotations__
for klass in reversed(cls.__mro__)
if "__annotations__" in klass.__dict__ and issubclass(klass, Model)
)
)
class BackendModel:
"""The backend model abstract class.
@ -49,12 +59,6 @@ class BackendModel:
implemented for every model and for every backend.
"""
@classproperty
def attributes(cls):
return ChainMap(
*(c.__annotations__ for c in cls.__mro__ if "__annotations__" in c.__dict__)
)
@classmethod
def query(cls, **kwargs):
"""

View file

@ -61,6 +61,8 @@ class SqlAlchemyModel(BackendModel):
getattr(cls, attribute_name).ilike(f"%{query}%")
for attribute_name in attributes
if "str" in str(cls.attributes[attribute_name])
# erk, photo is an URL string according to SCIM, but bytes here
and attribute_name != "photo"
)
return (
@ -72,9 +74,7 @@ class SqlAlchemyModel(BackendModel):
if isinstance(value, list):
return or_(cls.attribute_filter(name, v) for v in value)
# extract the sqlalchemy.orm.Mapped type
attribute_type = typing.get_args(cls.attributes[name])[0]
multiple = typing.get_origin(attribute_type) is list
multiple = typing.get_origin(cls.attributes[name]) is list
if multiple:
return getattr(cls, name).contains(value)
@ -200,8 +200,7 @@ class User(canaille.core.models.User, Base, SqlAlchemyModel):
return all(
self.normalize_filter_value(attribute, value)
in getattr(self, attribute, [])
if typing.get_origin(typing.get_args(self.attributes[attribute])[0])
is list
if typing.get_origin(self.attributes[attribute]) is list
else self.normalize_filter_value(attribute, value)
== getattr(self, attribute, None)
for attribute, value in filter.items()

View file

@ -190,9 +190,6 @@ def test_guess_object_from_dn(backend, testclient, foo_group):
assert g == foo_group
assert g.display_name == foo_group.display_name
ou = LDAPObject.get(dn=f"{models.Group.base},{models.Group.root_dn}")
assert isinstance(ou, LDAPObject)
def test_object_class_update(backend, testclient):
testclient.app.config["CANAILLE_LDAP"]["USER_CLASS"] = ["inetOrgPerson"]