forked from Github-Mirrors/canaille
refactor: factorize match_filter in the main User class
This commit is contained in:
parent
76cd3dc169
commit
f113188368
7 changed files with 79 additions and 90 deletions
|
@ -6,6 +6,7 @@ import ldap.dn
|
|||
import ldap.filter
|
||||
from ldap.controls.readentry import PostReadControl
|
||||
|
||||
from canaille.app import classproperty
|
||||
from canaille.backends.models import BackendModel
|
||||
|
||||
from .backend import Backend
|
||||
|
@ -131,6 +132,10 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
|
|||
else "<LDAPOBject>"
|
||||
)
|
||||
|
||||
@classproperty
|
||||
def identifier_attribute(cls):
|
||||
return cls.rdn_attribute
|
||||
|
||||
def __html__(self):
|
||||
return self.id
|
||||
|
||||
|
@ -302,10 +307,17 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
|
|||
return cls._attribute_type_by_name
|
||||
|
||||
@classmethod
|
||||
def get(cls, dn=None, **kwargs):
|
||||
def get(cls, identifier=None, /, **kwargs):
|
||||
try:
|
||||
return cls.query(dn, **kwargs)[0]
|
||||
return cls.query(identifier, **kwargs)[0]
|
||||
except (IndexError, ldap.NO_SUCH_OBJECT):
|
||||
if identifier and cls.base:
|
||||
return (
|
||||
cls.get(**{cls.identifier_attribute: identifier})
|
||||
or cls.get(id=identifier)
|
||||
or None
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -55,36 +55,11 @@ class User(canaille.core.models.User, LDAPObject):
|
|||
return cls.get(filter=filter, **kwargs)
|
||||
|
||||
def match_filter(self, filter):
|
||||
if isinstance(filter, str):
|
||||
conn = Backend.get().connection
|
||||
filter_ = self.acl_filter_to_ldap_filter(filter)
|
||||
return not filter_ or (
|
||||
self.dn and conn.search_s(self.dn, ldap.SCOPE_SUBTREE, filter_)
|
||||
)
|
||||
return self.dn and conn.search_s(self.dn, ldap.SCOPE_SUBTREE, filter)
|
||||
|
||||
@classmethod
|
||||
def acl_filter_to_ldap_filter(cls, filter_):
|
||||
if isinstance(filter_, dict):
|
||||
# not super generic, but how can we improve this? ¯\_(ツ)_/¯
|
||||
if "groups" in filter_ and "=" not in filter_.get("groups"):
|
||||
group_by_id = Group.get(id=filter_["groups"])
|
||||
filter_["groups"] = (
|
||||
group_by_id.dn if group_by_id else Group.dn_for(filter_["groups"])
|
||||
)
|
||||
|
||||
base = "".join(
|
||||
f"({cls.python_attribute_to_ldap(key)}={value})"
|
||||
for key, value in filter_.items()
|
||||
)
|
||||
return f"(&{base})" if len(filter_) > 1 else base
|
||||
|
||||
if isinstance(filter_, list):
|
||||
return (
|
||||
"(|"
|
||||
+ "".join(cls.acl_filter_to_ldap_filter(mapping) for mapping in filter_)
|
||||
+ ")"
|
||||
)
|
||||
|
||||
return filter_
|
||||
return super().match_filter(filter)
|
||||
|
||||
@classmethod
|
||||
def get(cls, *args, **kwargs):
|
||||
|
@ -159,9 +134,7 @@ class User(canaille.core.models.User, LDAPObject):
|
|||
return super().save(*args, **kwargs)
|
||||
|
||||
old_groups = self.state.get(group_attr) or []
|
||||
new_groups = [
|
||||
v if isinstance(v, Group) else Group.get(dn=v) for v in new_groups
|
||||
]
|
||||
new_groups = [v if isinstance(v, Group) else Group.get(v) for v in new_groups]
|
||||
to_add = set(new_groups) - set(old_groups)
|
||||
to_del = set(old_groups) - set(new_groups)
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ def ldap_to_python(value, syntax):
|
|||
return value.decode("utf-8").upper() == "TRUE"
|
||||
|
||||
if syntax == Syntax.DISTINGUISHED_NAME:
|
||||
return LDAPObject.get(dn=value.decode("utf-8"))
|
||||
return LDAPObject.get(value.decode("utf-8"))
|
||||
|
||||
return value.decode("utf-8")
|
||||
|
||||
|
|
|
@ -73,9 +73,13 @@ class MemoryModel(BackendModel):
|
|||
]
|
||||
|
||||
@classmethod
|
||||
def get(cls, identifier=None, **kwargs):
|
||||
def get(cls, identifier=None, /, **kwargs):
|
||||
if identifier:
|
||||
kwargs[cls.identifier_attribute] = identifier
|
||||
return (
|
||||
cls.get(**{cls.identifier_attribute: identifier})
|
||||
or cls.get(id=identifier)
|
||||
or None
|
||||
)
|
||||
|
||||
results = cls.query(**kwargs)
|
||||
return results[0] if results else None
|
||||
|
@ -233,26 +237,6 @@ class MemoryModel(BackendModel):
|
|||
def identifier(self):
|
||||
return getattr(self, self.identifier_attribute)
|
||||
|
||||
def match_filter(self, filter):
|
||||
if filter is None:
|
||||
return True
|
||||
|
||||
if isinstance(filter, dict):
|
||||
# not super generic, but we can improve this when we have
|
||||
# type checking and/or pydantic for the models
|
||||
if "groups" in filter:
|
||||
if models.Group.get(id=filter["groups"]):
|
||||
filter["groups"] = models.Group.get(id=filter["groups"])
|
||||
elif models.Group.get(display_name=filter["groups"]):
|
||||
filter["groups"] = models.Group.get(display_name=filter["groups"])
|
||||
|
||||
return all(
|
||||
getattr(self, attribute) and value in getattr(self, attribute)
|
||||
for attribute, value in filter.items()
|
||||
)
|
||||
|
||||
return any(self.match_filter(subfilter) for subfilter in filter)
|
||||
|
||||
|
||||
class User(canaille.core.models.User, MemoryModel):
|
||||
identifier_attribute = "user_name"
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
import datetime
|
||||
import inspect
|
||||
import typing
|
||||
from collections import ChainMap
|
||||
from typing import Optional
|
||||
|
||||
from canaille.app import classproperty
|
||||
from canaille.app import models
|
||||
|
||||
|
||||
class Model:
|
||||
|
@ -141,3 +143,45 @@ class BackendModel:
|
|||
George
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def get_attribute_type(cls, attribute_name):
|
||||
"""Reads the attribute typing and extract the type, possibly burried
|
||||
under list or Optional."""
|
||||
attribute = cls.attributes[attribute_name]
|
||||
core_type = (
|
||||
typing.get_args(attribute)[0]
|
||||
if typing.get_origin(attribute) == list
|
||||
else attribute
|
||||
)
|
||||
return (
|
||||
typing._eval_type(core_type, globals(), locals())
|
||||
if isinstance(core_type, typing.ForwardRef)
|
||||
else core_type
|
||||
)
|
||||
|
||||
def match_filter(self, filter):
|
||||
if filter is None:
|
||||
return True
|
||||
|
||||
if isinstance(filter, list):
|
||||
return any(self.match_filter(subfilter) for subfilter in filter)
|
||||
|
||||
# If attribute are models, resolve the instance
|
||||
for attribute, value in filter.items():
|
||||
attribute_type = self.get_attribute_type(attribute)
|
||||
|
||||
if not inspect.isclass(attribute_type) or not issubclass(
|
||||
attribute_type, Model
|
||||
):
|
||||
continue
|
||||
|
||||
model = getattr(models, attribute_type.__name__)
|
||||
|
||||
if instance := model.get(value):
|
||||
filter[attribute] = instance
|
||||
|
||||
return all(
|
||||
getattr(self, attribute) and value in getattr(self, attribute)
|
||||
for attribute, value in filter.items()
|
||||
)
|
||||
|
|
|
@ -22,7 +22,6 @@ from sqlalchemy_utils import force_auto_coercion
|
|||
|
||||
import canaille.core.models
|
||||
import canaille.oidc.models
|
||||
from canaille.app import models
|
||||
from canaille.backends.models import BackendModel
|
||||
|
||||
from .backend import Backend
|
||||
|
@ -82,9 +81,13 @@ class SqlAlchemyModel(BackendModel):
|
|||
return getattr(cls, name) == value
|
||||
|
||||
@classmethod
|
||||
def get(cls, identifier=None, **kwargs):
|
||||
def get(cls, identifier=None, /, **kwargs):
|
||||
if identifier:
|
||||
kwargs[cls.identifier_attribute] = identifier
|
||||
return (
|
||||
cls.get(**{cls.identifier_attribute: identifier})
|
||||
or cls.get(id=identifier)
|
||||
or None
|
||||
)
|
||||
|
||||
filter = [
|
||||
cls.attribute_filter(attribute_name, expected_value)
|
||||
|
@ -181,33 +184,6 @@ class User(canaille.core.models.User, Base, SqlAlchemyModel):
|
|||
def load_permissions(self):
|
||||
super().load_permissions()
|
||||
|
||||
def normalize_filter_value(self, attribute, value):
|
||||
# not super generic, but we can improve this when we have
|
||||
# type checking and/or pydantic for the models
|
||||
if attribute == "groups":
|
||||
return (
|
||||
models.Group.get(id=value)
|
||||
or models.Group.get(display_name=value)
|
||||
or None
|
||||
)
|
||||
return value
|
||||
|
||||
def match_filter(self, filter):
|
||||
if filter is None:
|
||||
return True
|
||||
|
||||
if isinstance(filter, dict):
|
||||
return all(
|
||||
self.normalize_filter_value(attribute, value)
|
||||
in getattr(self, attribute, [])
|
||||
if typing.get_origin(self.attributes[attribute]) is list
|
||||
else self.normalize_filter_value(attribute, value)
|
||||
== getattr(self, attribute, None)
|
||||
for attribute, value in filter.items()
|
||||
)
|
||||
|
||||
return any(self.match_filter(subfilter) for subfilter in filter)
|
||||
|
||||
@classmethod
|
||||
def get_from_login(cls, login=None, **kwargs):
|
||||
return User.get(user_name=login)
|
||||
|
|
|
@ -55,7 +55,7 @@ def test_dn_when_leading_space_in_id_attribute(testclient, backend):
|
|||
|
||||
assert user == models.User.get(user.identifier)
|
||||
assert user == models.User.get(user_name=user.identifier)
|
||||
assert user == models.User.get(dn=dn)
|
||||
assert user == models.User.get(dn)
|
||||
|
||||
user.delete()
|
||||
|
||||
|
@ -76,7 +76,7 @@ def test_special_chars_in_rdn(testclient, backend):
|
|||
|
||||
assert user == models.User.get(user.identifier)
|
||||
assert user == models.User.get(user_name=user.identifier)
|
||||
assert user == models.User.get(dn=dn)
|
||||
assert user == models.User.get(dn)
|
||||
|
||||
user.delete()
|
||||
|
||||
|
@ -185,7 +185,7 @@ def test_guess_object_from_dn(backend, testclient, foo_group):
|
|||
foo_group.members = [foo_group]
|
||||
foo_group.save()
|
||||
dn = foo_group.dn
|
||||
g = LDAPObject.get(dn=dn)
|
||||
g = LDAPObject.get(dn)
|
||||
assert isinstance(g, models.Group)
|
||||
assert g == foo_group
|
||||
assert g.display_name == foo_group.display_name
|
||||
|
|
Loading…
Reference in a new issue