refactor: Use annotations to mark model attributes

This commit is contained in:
Éloi Rivard 2024-04-21 11:47:23 +02:00
parent 5965cef133
commit 9c86f5e9af
No known key found for this signature in database
GPG key ID: 7EDA204EA57DD184
3 changed files with 46 additions and 63 deletions

View file

@ -3,8 +3,6 @@ import datetime
import typing
import uuid
from typing import ClassVar
from typing import Dict
from typing import Optional
import canaille.core.models
import canaille.oidc.models
@ -115,11 +113,10 @@ class MemoryModel(BackendModel):
)
return [] if multiple_attribute else None
if attribute_name in cls.model_attributes and not isinstance(
value, MemoryModel
):
model = getattr(models, cls.model_attributes[attribute_name][0])
return model.get(id=value)
model, _ = cls.get_model_annotations(attribute_name)
if model and not isinstance(value, model):
backend_model = getattr(models, model.__name__)
return backend_model.get(id=value)
return value
@ -151,17 +148,17 @@ class MemoryModel(BackendModel):
self.attribute_index(attribute).setdefault(value, set()).add(self.id)
# update the mirror attributes of the submodel instances
for attribute in self.model_attributes:
model, mirror_attribute = self.model_attributes[attribute]
if not self.index(model) or not mirror_attribute:
for attribute in self.attributes:
model, mirror_attribute = self.get_model_annotations(attribute)
if not model or not self.index(model.__name__) or not mirror_attribute:
continue
mirror_attribute_index = self.attribute_index(
mirror_attribute, model
mirror_attribute, model.__name__
).setdefault(self.id, set())
for subinstance_id in self.listify(self._state.get(attribute, [])):
# add the current objet in the subinstance state
subinstance_state = self.index(model)[subinstance_id]
subinstance_state = self.index(model.__name__)[subinstance_id]
subinstance_state.setdefault(mirror_attribute, [])
subinstance_state[mirror_attribute].append(self.id)
@ -175,22 +172,22 @@ class MemoryModel(BackendModel):
old_state = self.index()[self.id]
# update the mirror attributes of the submodel instances
for attribute in self.model_attributes:
for attribute in self.attributes:
attribute_values = self.listify(old_state.get(attribute, []))
for value in attribute_values:
self.attribute_index(attribute)[value].remove(self.id)
# update the mirror attributes of the submodel instances
model, mirror_attribute = self.model_attributes[attribute]
if not self.index(model) or not mirror_attribute:
model, mirror_attribute = self.get_model_annotations(attribute)
if not model or not self.index(model.__name__) or not mirror_attribute:
continue
mirror_attribute_index = self.attribute_index(
mirror_attribute, model
mirror_attribute, model.__name__
).setdefault(self.id, set())
for subinstance_id in self.index()[self.id].get(attribute, []):
# remove the current objet from the subinstance state
subinstance_state = self.index(model)[subinstance_id]
subinstance_state = self.index(model.__name__)[subinstance_id]
subinstance_state[mirror_attribute].remove(self.id)
# remove the current objet from the subinstance index
@ -245,45 +242,23 @@ class MemoryModel(BackendModel):
class User(canaille.core.models.User, MemoryModel):
identifier_attribute: ClassVar[str] = "user_name"
model_attributes: ClassVar[Dict[str, Optional[str]]] = {
"groups": ("Group", "members"),
}
class Group(canaille.core.models.Group, MemoryModel):
identifier_attribute: ClassVar[str] = "display_name"
model_attributes: ClassVar[Dict[str, Optional[str]]] = {
"members": ("User", "groups"),
}
class Client(canaille.oidc.models.Client, MemoryModel):
identifier_attribute: ClassVar[str] = "client_id"
model_attributes: ClassVar[Dict[str, Optional[str]]] = {
"audience": ("Client", None),
}
class AuthorizationCode(canaille.oidc.models.AuthorizationCode, MemoryModel):
identifier_attribute: ClassVar[str] = "authorization_code_id"
model_attributes: ClassVar[Dict[str, Optional[str]]] = {
"client": ("Client", None),
"subject": ("User", None),
}
class Token(canaille.oidc.models.Token, MemoryModel):
identifier_attribute: ClassVar[str] = "token_id"
model_attributes: ClassVar[Dict[str, Optional[str]]] = {
"client": ("Client", None),
"subject": ("User", None),
"audience": ("Client", None),
}
class Consent(canaille.oidc.models.Consent, MemoryModel):
identifier_attribute: ClassVar[str] = "consent_id"
model_attributes: ClassVar[Dict[str, Optional[str]]] = {
"client": ("Client", None),
"subject": ("User", None),
}

View file

@ -1,12 +1,11 @@
import datetime
import inspect
import typing
from collections import ChainMap
from typing import Annotated
from typing import ClassVar
from typing import ForwardRef
from typing import List
from typing import Optional
from typing import _eval_type
from typing import get_args
from typing import get_origin
from typing import get_type_hints
@ -57,7 +56,7 @@ class Model:
if not cls._attributes:
annotations = ChainMap(
*(
get_type_hints(klass)
get_type_hints(klass, include_extras=True)
for klass in reversed(cls.__mro__)
if issubclass(klass, Model)
)
@ -161,19 +160,31 @@ class BackendModel:
raise NotImplementedError()
@classmethod
def get_attribute_type(cls, attribute_name):
"""Reads the attribute typing and extract the type, possibly burried
under list or Optional."""
attribute = cls.attributes[attribute_name]
core_type = (
get_args(attribute)[0] if get_origin(attribute) == list else attribute
def get_model_annotations(cls, attribute):
annotations = cls.attributes[attribute]
# Extract the list type from list annotations
attribute_type = (
typing.get_args(annotations)[0]
if typing.get_origin(annotations) is list
else annotations
)
return (
_eval_type(core_type, globals(), locals())
if isinstance(core_type, ForwardRef)
else core_type
# Extract the Annotated annotation
attribute_type, metadata = (
typing.get_args(attribute_type)
if typing.get_origin(attribute_type) == Annotated
else (attribute_type, None)
)
if not inspect.isclass(attribute_type) or not issubclass(attribute_type, Model):
return None, None
if not metadata:
return attribute_type, None
return attribute_type, metadata.get("backref")
def match_filter(self, filter):
if filter is None:
return True
@ -183,18 +194,14 @@ class BackendModel:
# If attribute are models, resolve the instance
for attribute, value in filter.items():
attribute_type = self.get_attribute_type(attribute)
model, _ = self.get_model_annotations(attribute)
if (
isinstance(value, Model)
or not inspect.isclass(attribute_type)
or not issubclass(attribute_type, Model)
):
if not model or isinstance(value, Model):
continue
model = getattr(models, attribute_type.__name__)
backend_model = getattr(models, model.__name__)
if instance := model.get(value):
if instance := backend_model.get(value):
filter[attribute] = instance
return all(

View file

@ -1,4 +1,5 @@
import datetime
from typing import Annotated
from typing import List
from typing import Optional
@ -211,7 +212,7 @@ class User(Model):
department: Optional[str] = None
"""Identifies the name of a department."""
groups: List["Group"] = []
groups: List[Annotated["Group", {"backref": "members"}]] = []
"""A list of groups to which the user belongs, either through direct
membership, through nested groups, or dynamically calculated.
@ -330,7 +331,7 @@ class Group(Model):
REQUIRED.
"""
members: List["User"] = []
members: List[Annotated["User", {"backref": "groups"}]] = []
"""A list of members of the Group.
While values MAY be added or removed, sub-attributes of members are