refactor: model attributes are walked from the top to the bottom

This commit is contained in:
Éloi Rivard 2024-04-06 22:46:11 +02:00
parent fe809161ff
commit 75837fa207
No known key found for this signature in database
GPG key ID: 7EDA204EA57DD184
4 changed files with 15 additions and 16 deletions

View file

@ -100,8 +100,7 @@ class LDAPObjectQuery:
def guess_class(self, klass, object_classes): def guess_class(self, klass, object_classes):
if klass == LDAPObject: if klass == LDAPObject:
for oc in object_classes: for oc in object_classes:
if oc.decode() in LDAPObjectMetaclass.ldap_to_python_class: return LDAPObjectMetaclass.ldap_to_python_class[oc.decode()]
return LDAPObjectMetaclass.ldap_to_python_class[oc.decode()]
return klass return klass

View file

@ -41,6 +41,16 @@ class Model:
the value MUST be the same as the value of :attr:`~canaille.backends.models.Model.created`. the value MUST be the same as the value of :attr:`~canaille.backends.models.Model.created`.
""" """
@classproperty
def attributes(cls):
return ChainMap(
*(
klass.__annotations__
for klass in reversed(cls.__mro__)
if "__annotations__" in klass.__dict__ and issubclass(klass, Model)
)
)
class BackendModel: class BackendModel:
"""The backend model abstract class. """The backend model abstract class.
@ -49,12 +59,6 @@ class BackendModel:
implemented for every model and for every backend. implemented for every model and for every backend.
""" """
@classproperty
def attributes(cls):
return ChainMap(
*(c.__annotations__ for c in cls.__mro__ if "__annotations__" in c.__dict__)
)
@classmethod @classmethod
def query(cls, **kwargs): def query(cls, **kwargs):
""" """

View file

@ -61,6 +61,8 @@ class SqlAlchemyModel(BackendModel):
getattr(cls, attribute_name).ilike(f"%{query}%") getattr(cls, attribute_name).ilike(f"%{query}%")
for attribute_name in attributes for attribute_name in attributes
if "str" in str(cls.attributes[attribute_name]) if "str" in str(cls.attributes[attribute_name])
# erk, photo is an URL string according to SCIM, but bytes here
and attribute_name != "photo"
) )
return ( return (
@ -72,9 +74,7 @@ class SqlAlchemyModel(BackendModel):
if isinstance(value, list): if isinstance(value, list):
return or_(cls.attribute_filter(name, v) for v in value) return or_(cls.attribute_filter(name, v) for v in value)
# extract the sqlalchemy.orm.Mapped type multiple = typing.get_origin(cls.attributes[name]) is list
attribute_type = typing.get_args(cls.attributes[name])[0]
multiple = typing.get_origin(attribute_type) is list
if multiple: if multiple:
return getattr(cls, name).contains(value) return getattr(cls, name).contains(value)
@ -200,8 +200,7 @@ class User(canaille.core.models.User, Base, SqlAlchemyModel):
return all( return all(
self.normalize_filter_value(attribute, value) self.normalize_filter_value(attribute, value)
in getattr(self, attribute, []) in getattr(self, attribute, [])
if typing.get_origin(typing.get_args(self.attributes[attribute])[0]) if typing.get_origin(self.attributes[attribute]) is list
is list
else self.normalize_filter_value(attribute, value) else self.normalize_filter_value(attribute, value)
== getattr(self, attribute, None) == getattr(self, attribute, None)
for attribute, value in filter.items() for attribute, value in filter.items()

View file

@ -190,9 +190,6 @@ def test_guess_object_from_dn(backend, testclient, foo_group):
assert g == foo_group assert g == foo_group
assert g.display_name == foo_group.display_name assert g.display_name == foo_group.display_name
ou = LDAPObject.get(dn=f"{models.Group.base},{models.Group.root_dn}")
assert isinstance(ou, LDAPObject)
def test_object_class_update(backend, testclient): def test_object_class_update(backend, testclient):
testclient.app.config["CANAILLE_LDAP"]["USER_CLASS"] = ["inetOrgPerson"] testclient.app.config["CANAILLE_LDAP"]["USER_CLASS"] = ["inetOrgPerson"]