forked from Github-Mirrors/canaille
refactor: Use annotations to mark model attributes
This commit is contained in:
parent
5965cef133
commit
9c86f5e9af
3 changed files with 46 additions and 63 deletions
|
@ -3,8 +3,6 @@ import datetime
|
|||
import typing
|
||||
import uuid
|
||||
from typing import ClassVar
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
import canaille.core.models
|
||||
import canaille.oidc.models
|
||||
|
@ -115,11 +113,10 @@ class MemoryModel(BackendModel):
|
|||
)
|
||||
return [] if multiple_attribute else None
|
||||
|
||||
if attribute_name in cls.model_attributes and not isinstance(
|
||||
value, MemoryModel
|
||||
):
|
||||
model = getattr(models, cls.model_attributes[attribute_name][0])
|
||||
return model.get(id=value)
|
||||
model, _ = cls.get_model_annotations(attribute_name)
|
||||
if model and not isinstance(value, model):
|
||||
backend_model = getattr(models, model.__name__)
|
||||
return backend_model.get(id=value)
|
||||
|
||||
return value
|
||||
|
||||
|
@ -151,17 +148,17 @@ class MemoryModel(BackendModel):
|
|||
self.attribute_index(attribute).setdefault(value, set()).add(self.id)
|
||||
|
||||
# update the mirror attributes of the submodel instances
|
||||
for attribute in self.model_attributes:
|
||||
model, mirror_attribute = self.model_attributes[attribute]
|
||||
if not self.index(model) or not mirror_attribute:
|
||||
for attribute in self.attributes:
|
||||
model, mirror_attribute = self.get_model_annotations(attribute)
|
||||
if not model or not self.index(model.__name__) or not mirror_attribute:
|
||||
continue
|
||||
|
||||
mirror_attribute_index = self.attribute_index(
|
||||
mirror_attribute, model
|
||||
mirror_attribute, model.__name__
|
||||
).setdefault(self.id, set())
|
||||
for subinstance_id in self.listify(self._state.get(attribute, [])):
|
||||
# add the current objet in the subinstance state
|
||||
subinstance_state = self.index(model)[subinstance_id]
|
||||
subinstance_state = self.index(model.__name__)[subinstance_id]
|
||||
subinstance_state.setdefault(mirror_attribute, [])
|
||||
subinstance_state[mirror_attribute].append(self.id)
|
||||
|
||||
|
@ -175,22 +172,22 @@ class MemoryModel(BackendModel):
|
|||
old_state = self.index()[self.id]
|
||||
|
||||
# update the mirror attributes of the submodel instances
|
||||
for attribute in self.model_attributes:
|
||||
for attribute in self.attributes:
|
||||
attribute_values = self.listify(old_state.get(attribute, []))
|
||||
for value in attribute_values:
|
||||
self.attribute_index(attribute)[value].remove(self.id)
|
||||
|
||||
# update the mirror attributes of the submodel instances
|
||||
model, mirror_attribute = self.model_attributes[attribute]
|
||||
if not self.index(model) or not mirror_attribute:
|
||||
model, mirror_attribute = self.get_model_annotations(attribute)
|
||||
if not model or not self.index(model.__name__) or not mirror_attribute:
|
||||
continue
|
||||
|
||||
mirror_attribute_index = self.attribute_index(
|
||||
mirror_attribute, model
|
||||
mirror_attribute, model.__name__
|
||||
).setdefault(self.id, set())
|
||||
for subinstance_id in self.index()[self.id].get(attribute, []):
|
||||
# remove the current objet from the subinstance state
|
||||
subinstance_state = self.index(model)[subinstance_id]
|
||||
subinstance_state = self.index(model.__name__)[subinstance_id]
|
||||
subinstance_state[mirror_attribute].remove(self.id)
|
||||
|
||||
# remove the current objet from the subinstance index
|
||||
|
@ -245,45 +242,23 @@ class MemoryModel(BackendModel):
|
|||
|
||||
class User(canaille.core.models.User, MemoryModel):
|
||||
identifier_attribute: ClassVar[str] = "user_name"
|
||||
model_attributes: ClassVar[Dict[str, Optional[str]]] = {
|
||||
"groups": ("Group", "members"),
|
||||
}
|
||||
|
||||
|
||||
class Group(canaille.core.models.Group, MemoryModel):
|
||||
identifier_attribute: ClassVar[str] = "display_name"
|
||||
model_attributes: ClassVar[Dict[str, Optional[str]]] = {
|
||||
"members": ("User", "groups"),
|
||||
}
|
||||
|
||||
|
||||
class Client(canaille.oidc.models.Client, MemoryModel):
|
||||
identifier_attribute: ClassVar[str] = "client_id"
|
||||
model_attributes: ClassVar[Dict[str, Optional[str]]] = {
|
||||
"audience": ("Client", None),
|
||||
}
|
||||
|
||||
|
||||
class AuthorizationCode(canaille.oidc.models.AuthorizationCode, MemoryModel):
|
||||
identifier_attribute: ClassVar[str] = "authorization_code_id"
|
||||
model_attributes: ClassVar[Dict[str, Optional[str]]] = {
|
||||
"client": ("Client", None),
|
||||
"subject": ("User", None),
|
||||
}
|
||||
|
||||
|
||||
class Token(canaille.oidc.models.Token, MemoryModel):
|
||||
identifier_attribute: ClassVar[str] = "token_id"
|
||||
model_attributes: ClassVar[Dict[str, Optional[str]]] = {
|
||||
"client": ("Client", None),
|
||||
"subject": ("User", None),
|
||||
"audience": ("Client", None),
|
||||
}
|
||||
|
||||
|
||||
class Consent(canaille.oidc.models.Consent, MemoryModel):
|
||||
identifier_attribute: ClassVar[str] = "consent_id"
|
||||
model_attributes: ClassVar[Dict[str, Optional[str]]] = {
|
||||
"client": ("Client", None),
|
||||
"subject": ("User", None),
|
||||
}
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
import datetime
|
||||
import inspect
|
||||
import typing
|
||||
from collections import ChainMap
|
||||
from typing import Annotated
|
||||
from typing import ClassVar
|
||||
from typing import ForwardRef
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import _eval_type
|
||||
from typing import get_args
|
||||
from typing import get_origin
|
||||
from typing import get_type_hints
|
||||
|
||||
|
@ -57,7 +56,7 @@ class Model:
|
|||
if not cls._attributes:
|
||||
annotations = ChainMap(
|
||||
*(
|
||||
get_type_hints(klass)
|
||||
get_type_hints(klass, include_extras=True)
|
||||
for klass in reversed(cls.__mro__)
|
||||
if issubclass(klass, Model)
|
||||
)
|
||||
|
@ -161,19 +160,31 @@ class BackendModel:
|
|||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def get_attribute_type(cls, attribute_name):
|
||||
"""Reads the attribute typing and extract the type, possibly burried
|
||||
under list or Optional."""
|
||||
attribute = cls.attributes[attribute_name]
|
||||
core_type = (
|
||||
get_args(attribute)[0] if get_origin(attribute) == list else attribute
|
||||
def get_model_annotations(cls, attribute):
|
||||
annotations = cls.attributes[attribute]
|
||||
|
||||
# Extract the list type from list annotations
|
||||
attribute_type = (
|
||||
typing.get_args(annotations)[0]
|
||||
if typing.get_origin(annotations) is list
|
||||
else annotations
|
||||
)
|
||||
return (
|
||||
_eval_type(core_type, globals(), locals())
|
||||
if isinstance(core_type, ForwardRef)
|
||||
else core_type
|
||||
|
||||
# Extract the Annotated annotation
|
||||
attribute_type, metadata = (
|
||||
typing.get_args(attribute_type)
|
||||
if typing.get_origin(attribute_type) == Annotated
|
||||
else (attribute_type, None)
|
||||
)
|
||||
|
||||
if not inspect.isclass(attribute_type) or not issubclass(attribute_type, Model):
|
||||
return None, None
|
||||
|
||||
if not metadata:
|
||||
return attribute_type, None
|
||||
|
||||
return attribute_type, metadata.get("backref")
|
||||
|
||||
def match_filter(self, filter):
|
||||
if filter is None:
|
||||
return True
|
||||
|
@ -183,18 +194,14 @@ class BackendModel:
|
|||
|
||||
# If attribute are models, resolve the instance
|
||||
for attribute, value in filter.items():
|
||||
attribute_type = self.get_attribute_type(attribute)
|
||||
model, _ = self.get_model_annotations(attribute)
|
||||
|
||||
if (
|
||||
isinstance(value, Model)
|
||||
or not inspect.isclass(attribute_type)
|
||||
or not issubclass(attribute_type, Model)
|
||||
):
|
||||
if not model or isinstance(value, Model):
|
||||
continue
|
||||
|
||||
model = getattr(models, attribute_type.__name__)
|
||||
backend_model = getattr(models, model.__name__)
|
||||
|
||||
if instance := model.get(value):
|
||||
if instance := backend_model.get(value):
|
||||
filter[attribute] = instance
|
||||
|
||||
return all(
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import datetime
|
||||
from typing import Annotated
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
|
@ -211,7 +212,7 @@ class User(Model):
|
|||
department: Optional[str] = None
|
||||
"""Identifies the name of a department."""
|
||||
|
||||
groups: List["Group"] = []
|
||||
groups: List[Annotated["Group", {"backref": "members"}]] = []
|
||||
"""A list of groups to which the user belongs, either through direct
|
||||
membership, through nested groups, or dynamically calculated.
|
||||
|
||||
|
@ -330,7 +331,7 @@ class Group(Model):
|
|||
REQUIRED.
|
||||
"""
|
||||
|
||||
members: List["User"] = []
|
||||
members: List[Annotated["User", {"backref": "groups"}]] = []
|
||||
"""A list of members of the Group.
|
||||
|
||||
While values MAY be added or removed, sub-attributes of members are
|
||||
|
|
Loading…
Reference in a new issue