refactor: factorize match_filter in the main User class

This commit is contained in:
Éloi Rivard 2024-04-07 01:21:55 +02:00
parent 76cd3dc169
commit f113188368
No known key found for this signature in database
GPG key ID: 7EDA204EA57DD184
7 changed files with 79 additions and 90 deletions

View file

@ -6,6 +6,7 @@ import ldap.dn
import ldap.filter
from ldap.controls.readentry import PostReadControl
from canaille.app import classproperty
from canaille.backends.models import BackendModel
from .backend import Backend
@ -131,6 +132,10 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
else "<LDAPOBject>"
)
@classproperty
def identifier_attribute(cls):
return cls.rdn_attribute
def __html__(self):
return self.id
@ -302,10 +307,17 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
return cls._attribute_type_by_name
@classmethod
def get(cls, dn=None, **kwargs):
def get(cls, identifier=None, /, **kwargs):
try:
return cls.query(dn, **kwargs)[0]
return cls.query(identifier, **kwargs)[0]
except (IndexError, ldap.NO_SUCH_OBJECT):
if identifier and cls.base:
return (
cls.get(**{cls.identifier_attribute: identifier})
or cls.get(id=identifier)
or None
)
return None
@classmethod

View file

@ -55,36 +55,11 @@ class User(canaille.core.models.User, LDAPObject):
return cls.get(filter=filter, **kwargs)
def match_filter(self, filter):
if isinstance(filter, str):
conn = Backend.get().connection
filter_ = self.acl_filter_to_ldap_filter(filter)
return not filter_ or (
self.dn and conn.search_s(self.dn, ldap.SCOPE_SUBTREE, filter_)
)
return self.dn and conn.search_s(self.dn, ldap.SCOPE_SUBTREE, filter)
@classmethod
def acl_filter_to_ldap_filter(cls, filter_):
if isinstance(filter_, dict):
# not super generic, but how can we improve this? ¯\_(ツ)_/¯
if "groups" in filter_ and "=" not in filter_.get("groups"):
group_by_id = Group.get(id=filter_["groups"])
filter_["groups"] = (
group_by_id.dn if group_by_id else Group.dn_for(filter_["groups"])
)
base = "".join(
f"({cls.python_attribute_to_ldap(key)}={value})"
for key, value in filter_.items()
)
return f"(&{base})" if len(filter_) > 1 else base
if isinstance(filter_, list):
return (
"(|"
+ "".join(cls.acl_filter_to_ldap_filter(mapping) for mapping in filter_)
+ ")"
)
return filter_
return super().match_filter(filter)
@classmethod
def get(cls, *args, **kwargs):
@ -159,9 +134,7 @@ class User(canaille.core.models.User, LDAPObject):
return super().save(*args, **kwargs)
old_groups = self.state.get(group_attr) or []
new_groups = [
v if isinstance(v, Group) else Group.get(dn=v) for v in new_groups
]
new_groups = [v if isinstance(v, Group) else Group.get(v) for v in new_groups]
to_add = set(new_groups) - set(old_groups)
to_del = set(old_groups) - set(new_groups)

View file

@ -50,7 +50,7 @@ def ldap_to_python(value, syntax):
return value.decode("utf-8").upper() == "TRUE"
if syntax == Syntax.DISTINGUISHED_NAME:
return LDAPObject.get(dn=value.decode("utf-8"))
return LDAPObject.get(value.decode("utf-8"))
return value.decode("utf-8")

View file

@ -73,9 +73,13 @@ class MemoryModel(BackendModel):
]
@classmethod
def get(cls, identifier=None, **kwargs):
def get(cls, identifier=None, /, **kwargs):
if identifier:
kwargs[cls.identifier_attribute] = identifier
return (
cls.get(**{cls.identifier_attribute: identifier})
or cls.get(id=identifier)
or None
)
results = cls.query(**kwargs)
return results[0] if results else None
@ -233,26 +237,6 @@ class MemoryModel(BackendModel):
def identifier(self):
return getattr(self, self.identifier_attribute)
def match_filter(self, filter):
if filter is None:
return True
if isinstance(filter, dict):
# not super generic, but we can improve this when we have
# type checking and/or pydantic for the models
if "groups" in filter:
if models.Group.get(id=filter["groups"]):
filter["groups"] = models.Group.get(id=filter["groups"])
elif models.Group.get(display_name=filter["groups"]):
filter["groups"] = models.Group.get(display_name=filter["groups"])
return all(
getattr(self, attribute) and value in getattr(self, attribute)
for attribute, value in filter.items()
)
return any(self.match_filter(subfilter) for subfilter in filter)
class User(canaille.core.models.User, MemoryModel):
identifier_attribute = "user_name"

View file

@ -1,9 +1,11 @@
import datetime
import inspect
import typing
from collections import ChainMap
from typing import Optional
from canaille.app import classproperty
from canaille.app import models
class Model:
@ -141,3 +143,45 @@ class BackendModel:
George
"""
raise NotImplementedError()
@classmethod
def get_attribute_type(cls, attribute_name):
"""Reads the attribute typing and extract the type, possibly burried
under list or Optional."""
attribute = cls.attributes[attribute_name]
core_type = (
typing.get_args(attribute)[0]
if typing.get_origin(attribute) == list
else attribute
)
return (
typing._eval_type(core_type, globals(), locals())
if isinstance(core_type, typing.ForwardRef)
else core_type
)
def match_filter(self, filter):
if filter is None:
return True
if isinstance(filter, list):
return any(self.match_filter(subfilter) for subfilter in filter)
# If attribute are models, resolve the instance
for attribute, value in filter.items():
attribute_type = self.get_attribute_type(attribute)
if not inspect.isclass(attribute_type) or not issubclass(
attribute_type, Model
):
continue
model = getattr(models, attribute_type.__name__)
if instance := model.get(value):
filter[attribute] = instance
return all(
getattr(self, attribute) and value in getattr(self, attribute)
for attribute, value in filter.items()
)

View file

@ -22,7 +22,6 @@ from sqlalchemy_utils import force_auto_coercion
import canaille.core.models
import canaille.oidc.models
from canaille.app import models
from canaille.backends.models import BackendModel
from .backend import Backend
@ -82,9 +81,13 @@ class SqlAlchemyModel(BackendModel):
return getattr(cls, name) == value
@classmethod
def get(cls, identifier=None, **kwargs):
def get(cls, identifier=None, /, **kwargs):
if identifier:
kwargs[cls.identifier_attribute] = identifier
return (
cls.get(**{cls.identifier_attribute: identifier})
or cls.get(id=identifier)
or None
)
filter = [
cls.attribute_filter(attribute_name, expected_value)
@ -181,33 +184,6 @@ class User(canaille.core.models.User, Base, SqlAlchemyModel):
def load_permissions(self):
super().load_permissions()
def normalize_filter_value(self, attribute, value):
# not super generic, but we can improve this when we have
# type checking and/or pydantic for the models
if attribute == "groups":
return (
models.Group.get(id=value)
or models.Group.get(display_name=value)
or None
)
return value
def match_filter(self, filter):
if filter is None:
return True
if isinstance(filter, dict):
return all(
self.normalize_filter_value(attribute, value)
in getattr(self, attribute, [])
if typing.get_origin(self.attributes[attribute]) is list
else self.normalize_filter_value(attribute, value)
== getattr(self, attribute, None)
for attribute, value in filter.items()
)
return any(self.match_filter(subfilter) for subfilter in filter)
@classmethod
def get_from_login(cls, login=None, **kwargs):
return User.get(user_name=login)

View file

@ -55,7 +55,7 @@ def test_dn_when_leading_space_in_id_attribute(testclient, backend):
assert user == models.User.get(user.identifier)
assert user == models.User.get(user_name=user.identifier)
assert user == models.User.get(dn=dn)
assert user == models.User.get(dn)
user.delete()
@ -76,7 +76,7 @@ def test_special_chars_in_rdn(testclient, backend):
assert user == models.User.get(user.identifier)
assert user == models.User.get(user_name=user.identifier)
assert user == models.User.get(dn=dn)
assert user == models.User.get(dn)
user.delete()
@ -185,7 +185,7 @@ def test_guess_object_from_dn(backend, testclient, foo_group):
foo_group.members = [foo_group]
foo_group.save()
dn = foo_group.dn
g = LDAPObject.get(dn=dn)
g = LDAPObject.get(dn)
assert isinstance(g, models.Group)
assert g == foo_group
assert g.display_name == foo_group.display_name