import copy
import typing

import canaille.core.models
import canaille.oidc.models
from canaille.app import models
from canaille.backends import BaseBackend
from canaille.backends.models import BackendModel


class MemoryModel(BackendModel):
    """In-memory implementation of :class:`BackendModel`.

    Each instance keeps two dicts: ``_cache`` holds the raw values that were
    assigned, and ``_state`` holds their serialized form (model instances are
    replaced by their ids). Saved states are copied into the class-level
    ``indexes`` registry, and reverse value-to-ids lookups are kept in
    ``attribute_indexes``; both are keyed by class name.
    """

    indexes = {}
    """Associates ids and states."""

    attribute_indexes = {}
    """Associates attribute values and ids."""

    def __init__(self, *args, **kwargs):
        """Build an instance by assigning each keyword argument as an attribute.

        Positional arguments are accepted but ignored.
        """
        self._state = {}
        self._cache = {}
        for attribute, value in kwargs.items():
            setattr(self, attribute, value)

    def __repr__(self):
        return f"<{self.__class__.__name__} id={self.id}>"

    @classmethod
    def index(cls, class_name=None):
        """Return the ``{id: state}`` mapping of ``class_name`` (default: this class)."""
        return MemoryModel.indexes.setdefault(class_name or cls.__name__, {})

    @classmethod
    def attribute_index(cls, attribute="id", class_name=None):
        """Return the ``{value: set of ids}`` mapping of ``attribute`` for ``class_name``."""
        return MemoryModel.attribute_indexes.setdefault(
            class_name or cls.__name__, {}
        ).setdefault(attribute, {})

    @classmethod
    def listify(cls, value):
        """Return ``value`` unchanged if it is a list, else wrap it in one."""
        return value if isinstance(value, list) else [value]

    @classmethod
    def serialize(cls, value):
        """Convert ``value`` to its stored form.

        Model instances become their ids; lists are converted item by item and
        falsy items are dropped.
        """
        if isinstance(value, list):
            values = [cls.serialize(item) for item in value]
            return [item for item in values if item]

        return value.id if isinstance(value, MemoryModel) else value

    @classmethod
    def deserialize(cls, attribute_name, value):
        """Convert a stored value of ``attribute_name`` back to its Python form.

        Ids of model attributes are resolved to instances through the backend.
        A missing value (``None`` or an empty string/bytes) becomes ``[]`` for
        list attributes and ``None`` otherwise. Other falsy values such as
        ``False`` and ``0`` are returned as-is. Falsy items are dropped from
        lists.
        """
        if isinstance(value, list):
            values = [cls.deserialize(attribute_name, item) for item in value]
            return [item for item in values if item]

        # A bare ``not value`` test would also map ``False`` and ``0`` to None.
        # The isinstance check avoids ``==`` on model instances, whose
        # ``__eq__`` queries the backend.
        if value is None or (isinstance(value, (str, bytes)) and not value):
            multiple_attribute = (
                typing.get_origin(cls.attributes[attribute_name]) is list
            )
            return [] if multiple_attribute else None

        model, _ = cls.get_model_annotations(attribute_name)
        if model and not isinstance(value, model):
            backend_model = getattr(models, model.__name__)
            return BaseBackend.instance.get(backend_model, id=value)

        return value

    def delete(self):
        """Remove this object from the in-memory indexes."""
        self.index_delete()

    def index_save(self):
        """Store the current state in the indexes.

        The serialized state is copied into the id index, every attribute value
        is recorded in the attribute index, and for model attributes that
        declare a mirror attribute, this object's id is added to the mirror
        attribute of each referenced instance.
        """
        # update the id index
        self.index()[self.id] = copy.deepcopy(self._state)

        # update the index for each attribute
        for attribute in self.attributes:
            attribute_values = self.listify(self._state.get(attribute, []))
            for value in attribute_values:
                self.attribute_index(attribute).setdefault(value, set()).add(self.id)

        # update the mirror attributes of the submodel instances
        for attribute in self.attributes:
            model, mirror_attribute = self.get_model_annotations(attribute)
            if not model or not self.index(model.__name__) or not mirror_attribute:
                continue

            mirror_attribute_index = self.attribute_index(
                mirror_attribute, model.__name__
            ).setdefault(self.id, set())
            for subinstance_id in self.listify(self._state.get(attribute, [])):
                # a reference attribute that was set to None is stored as None
                if subinstance_id is None:
                    continue

                # add the current object to the subinstance state, only once,
                # so that saving again does not duplicate the back-reference
                subinstance_state = self.index(model.__name__)[subinstance_id]
                mirror_values = subinstance_state.get(mirror_attribute) or []
                if self.id not in mirror_values:
                    mirror_values.append(self.id)
                subinstance_state[mirror_attribute] = mirror_values

                # add the current object to the subinstance index
                mirror_attribute_index.add(subinstance_id)

    def index_delete(self):
        """Remove this object from the indexes; the reverse of :meth:`index_save`.

        The state last stored in the id index (not the in-memory state)
        determines which entries are removed. Entries that are already gone
        are skipped. Does nothing if the object is not in the id index.
        """
        if self.id not in self.index():
            return

        # Work on a snapshot: the mirror loop below edits index states in
        # place, and for a self-referencing model that would alter old_state
        # while it is being read.
        old_state = copy.deepcopy(self.index()[self.id])

        # update the index for each attribute
        for attribute in self.attributes:
            attribute_index = self.attribute_index(attribute)
            for value in self.listify(old_state.get(attribute, [])):
                if value in attribute_index:
                    attribute_index[value].discard(self.id)

        # update the mirror attributes of the submodel instances
        for attribute in self.attributes:
            model, mirror_attribute = self.get_model_annotations(attribute)
            if not model or not self.index(model.__name__) or not mirror_attribute:
                continue

            mirror_attribute_index = self.attribute_index(
                mirror_attribute, model.__name__
            ).setdefault(self.id, set())
            # listify as index_save does: a single-valued reference is a bare
            # id, which would otherwise be iterated character by character
            for subinstance_id in self.listify(old_state.get(attribute, [])):
                # remove the current object from the subinstance state
                subinstance_state = self.index(model.__name__).get(subinstance_id)
                mirror_values = (subinstance_state or {}).get(mirror_attribute) or []
                if self.id in mirror_values:
                    mirror_values.remove(self.id)

                # remove the current object from the subinstance index
                mirror_attribute_index.discard(subinstance_id)

        # update the id index
        del self.index()[self.id]

    def reload(self):
        """Replace the in-memory state with the backend's copy and clear the cache."""
        self._state = BaseBackend.instance.get(self.__class__, id=self.id)._state
        self._cache = {}

    def __eq__(self, other):
        """Compare serialized states.

        A non-model ``other`` is treated as an id and looked up through the
        backend before comparing.
        """
        if other is None:
            return False

        if not isinstance(other, MemoryModel):
            return self == BaseBackend.instance.get(self.__class__, id=other)

        return self._state == other._state

    def __hash__(self):
        return hash(self.id)

    def __getattribute__(self, name):
        """Read model attributes through ``deserialize``.

        The raw value in ``_cache`` is preferred, falling back to ``_state``.
        Every other name uses normal attribute lookup.
        """
        # "attributes" itself must bypass this path to avoid infinite recursion
        if name != "attributes" and name in self.attributes:
            return self.deserialize(name, self._cache.get(name, self._state.get(name)))

        return super().__getattribute__(name)

    def __setattr__(self, name, value):
        """Store model attributes both raw (``_cache``) and serialized (``_state``)."""
        if name in self.attributes:
            self._cache[name] = value
            self._state[name] = self.serialize(value)
        else:
            super().__setattr__(name, value)


class User(canaille.core.models.User, MemoryModel):
    pass


class Group(canaille.core.models.Group, MemoryModel):
    pass


class Client(canaille.oidc.models.Client, MemoryModel):
    pass


class AuthorizationCode(canaille.oidc.models.AuthorizationCode, MemoryModel):
    pass


class Token(canaille.oidc.models.Token, MemoryModel):
    pass


class Consent(canaille.oidc.models.Consent, MemoryModel):
    pass