refactor: move memory backend methods as classmethods

This commit is contained in:
Éloi Rivard 2024-03-31 01:05:48 +01:00
parent fa45ef6907
commit d2df12236d
No known key found for this signature in database
GPG key ID: 7EDA204EA57DD184

View file

@ -11,14 +11,6 @@ from canaille.app import models
from canaille.backends.models import Model from canaille.backends.models import Model
def listify(value):
return value if isinstance(value, list) else [value]
def serialize(value):
return value.id if isinstance(value, MemoryModel) else value
class MemoryModel(Model): class MemoryModel(Model):
indexes = {} indexes = {}
attribute_indexes = {} attribute_indexes = {}
@ -44,17 +36,14 @@ class MemoryModel(Model):
ids = { ids = {
id id
for attribute, values in kwargs.items() for attribute, values in kwargs.items()
for value in listify(values) for value in cls.cardinalize(values)
for id in cls.attribute_index(attribute).get(serialize(value), []) for id in cls.attribute_index(attribute).get(cls.serialize(value), [])
} }
return [cls(**cls.index()[id]) for id in ids] return [cls(**cls.index()[id]) for id in ids]
@classmethod @classmethod
def index(cls, class_name=None): def index(cls, class_name=None):
if not class_name: return MemoryModel.indexes.setdefault(class_name or cls.__name__, {})
class_name = cls.__name__
return MemoryModel.indexes.setdefault(class_name, {})
@classmethod @classmethod
def attribute_index(cls, attribute="id", class_name=None): def attribute_index(cls, attribute="id", class_name=None):
@ -82,9 +71,18 @@ class MemoryModel(Model):
def get(cls, identifier=None, **kwargs): def get(cls, identifier=None, **kwargs):
if identifier: if identifier:
kwargs[cls.identifier_attribute] = identifier kwargs[cls.identifier_attribute] = identifier
results = cls.query(**kwargs) results = cls.query(**kwargs)
return results[0] if results else None return results[0] if results else None
@classmethod
def cardinalize(cls, value):
return value if isinstance(value, list) else [value]
@classmethod
def serialize(cls, value):
return value.id if isinstance(value, MemoryModel) else value
def save(self): def save(self):
self.last_modified = datetime.datetime.now(datetime.timezone.utc).replace( self.last_modified = datetime.datetime.now(datetime.timezone.utc).replace(
microsecond=0 microsecond=0
@ -99,7 +97,7 @@ class MemoryModel(Model):
# update the index for each attribute # update the index for each attribute
for attribute in self.attributes: for attribute in self.attributes:
attribute_values = listify(getattr(self, attribute)) attribute_values = self.cardinalize(getattr(self, attribute))
for value in attribute_values: for value in attribute_values:
self.attribute_index(attribute).setdefault(value, set()).add(self.id) self.attribute_index(attribute).setdefault(value, set()).add(self.id)
@ -130,7 +128,7 @@ class MemoryModel(Model):
# update the index for each attribute # update the index for each attribute
for attribute in self.model_attributes: for attribute in self.model_attributes:
attribute_values = listify(old_state.get(attribute, [])) attribute_values = self.cardinalize(old_state.get(attribute, []))
for value in attribute_values: for value in attribute_values:
if ( if (
value in self.attribute_index(attribute) value in self.attribute_index(attribute)
@ -158,7 +156,7 @@ class MemoryModel(Model):
# update the index for each attribute # update the index for each attribute
for attribute in self.attributes: for attribute in self.attributes:
attribute_values = listify(old_state.get(attribute, [])) attribute_values = self.cardinalize(old_state.get(attribute, []))
for value in attribute_values: for value in attribute_values:
if ( if (
value in self.attribute_index(attribute) value in self.attribute_index(attribute)
@ -207,9 +205,9 @@ class MemoryModel(Model):
def __setattr__(self, name, value): def __setattr__(self, name, value):
if name in self.attributes: if name in self.attributes:
values = listify(value) values = self.cardinalize(value)
self.cache[name] = [value for value in values if value] self.cache[name] = [value for value in values if value]
values = [serialize(value) for value in values] values = [self.serialize(value) for value in values]
values = [value for value in values if value] values = [value for value in values if value]
self.state[name] = values self.state[name] = values
else: else: