forked from Github-Mirrors/canaille
refactor: reliably detect the model attribute cardinality
This commit is contained in:
parent
7418d10efb
commit
58b967a43e
3 changed files with 10 additions and 4 deletions
|
@ -1,4 +1,5 @@
|
|||
import itertools
|
||||
import typing
|
||||
from collections.abc import Iterable
|
||||
|
||||
import ldap.dn
|
||||
|
@ -174,7 +175,7 @@ class LDAPObject(Model, metaclass=LDAPObjectMetaclass):
|
|||
if ldap_name == "dn":
|
||||
return self.dn_for(self.rdn_value)
|
||||
|
||||
python_single_value = "List" not in str(self.attributes[name])
|
||||
python_single_value = typing.get_origin(self.attributes[name]) is not list
|
||||
ldap_value = self.get_ldap_attribute(ldap_name)
|
||||
return cardinalize_attribute(python_single_value, ldap_value)
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import copy
|
||||
import datetime
|
||||
import typing
|
||||
import uuid
|
||||
|
||||
from flask import current_app
|
||||
|
@ -196,7 +197,7 @@ class MemoryModel(Model):
|
|||
]
|
||||
values = [value for value in values if value]
|
||||
|
||||
unique_attribute = "List" not in str(self.attributes[name])
|
||||
unique_attribute = typing.get_origin(self.attributes[name]) is not list
|
||||
if unique_attribute:
|
||||
return values[0] if values else None
|
||||
else:
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import datetime
|
||||
import typing
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
|
@ -72,7 +73,10 @@ class SqlAlchemyModel(Model):
|
|||
if isinstance(value, list):
|
||||
return or_(cls.attribute_filter(name, v) for v in value)
|
||||
|
||||
multiple = "List" in str(cls.attributes[name])
|
||||
# extract the sqlalchemy.orm.Mapped type
|
||||
attribute_type = typing.get_args(cls.attributes[name])[0]
|
||||
multiple = typing.get_origin(attribute_type) is list
|
||||
|
||||
if multiple:
|
||||
return getattr(cls, name).contains(value)
|
||||
|
||||
|
@ -203,7 +207,7 @@ class User(canaille.core.models.User, Base, SqlAlchemyModel):
|
|||
return all(
|
||||
self.normalize_filter_value(attribute, value)
|
||||
in getattr(self, attribute, [])
|
||||
if "List" in str(self.attributes[attribute])
|
||||
if typing.get_origin(self.attributes[attribute]) is list
|
||||
else self.normalize_filter_value(attribute, value)
|
||||
== getattr(self, attribute, None)
|
||||
for attribute, value in filter.items()
|
||||
|
|
Loading…
Reference in a new issue