refactor: reliably detect the model attribute cardinality

This commit is contained in:
Éloi Rivard 2024-03-30 23:39:09 +01:00
parent 7418d10efb
commit 58b967a43e
No known key found for this signature in database
GPG key ID: 7EDA204EA57DD184
3 changed files with 10 additions and 4 deletions

View file

@ -1,4 +1,5 @@
import itertools
import typing
from collections.abc import Iterable
import ldap.dn
@ -174,7 +175,7 @@ class LDAPObject(Model, metaclass=LDAPObjectMetaclass):
if ldap_name == "dn":
return self.dn_for(self.rdn_value)
python_single_value = "List" not in str(self.attributes[name])
python_single_value = typing.get_origin(self.attributes[name]) is not list
ldap_value = self.get_ldap_attribute(ldap_name)
return cardinalize_attribute(python_single_value, ldap_value)

View file

@ -1,5 +1,6 @@
import copy
import datetime
import typing
import uuid
from flask import current_app
@ -196,7 +197,7 @@ class MemoryModel(Model):
]
values = [value for value in values if value]
unique_attribute = "List" not in str(self.attributes[name])
unique_attribute = typing.get_origin(self.attributes[name]) is not list
if unique_attribute:
return values[0] if values else None
else:

View file

@ -1,4 +1,5 @@
import datetime
import typing
import uuid
from typing import List
@ -72,7 +73,10 @@ class SqlAlchemyModel(Model):
if isinstance(value, list):
return or_(cls.attribute_filter(name, v) for v in value)
multiple = "List" in str(cls.attributes[name])
# extract the sqlalchemy.orm.Mapped type
attribute_type = typing.get_args(cls.attributes[name])[0]
multiple = typing.get_origin(attribute_type) is list
if multiple:
return getattr(cls, name).contains(value)
@ -203,7 +207,7 @@ class User(canaille.core.models.User, Base, SqlAlchemyModel):
return all(
self.normalize_filter_value(attribute, value)
in getattr(self, attribute, [])
if "List" in str(self.attributes[attribute])
if typing.get_origin(self.attributes[attribute]) is list
else self.normalize_filter_value(attribute, value)
== getattr(self, attribute, None)
for attribute, value in filter.items()