refactor: improve memory model serialization

do not systematically store every attributes as a list
This commit is contained in:
Éloi Rivard 2024-03-31 12:06:13 +02:00
parent 8834c65bea
commit 006bf08b3d
No known key found for this signature in database
GPG key ID: 7EDA204EA57DD184

View file

@ -36,8 +36,8 @@ class MemoryModel(Model):
ids = { ids = {
id id
for attribute, values in kwargs.items() for attribute, values in kwargs.items()
for value in cls.cardinalize(values) for value in cls.serialize(cls.listify(values))
for id in cls.attribute_index(attribute).get(cls.serialize(value), []) for id in cls.attribute_index(attribute).get(value, [])
} }
return [cls(**cls.index()[id]) for id in ids] return [cls(**cls.index()[id]) for id in ids]
@ -62,7 +62,7 @@ class MemoryModel(Model):
if any( if any(
query.lower() in value.lower() query.lower() in value.lower()
for attribute in attributes for attribute in attributes
for value in instance.state.get(attribute, []) for value in cls.listify(instance.state.get(attribute, []))
if isinstance(value, str) if isinstance(value, str)
) )
] ]
@ -76,16 +76,36 @@ class MemoryModel(Model):
return results[0] if results else None return results[0] if results else None
@classmethod @classmethod
def cardinalize(cls, value): def listify(cls, value):
return value if isinstance(value, list) else [value] return value if isinstance(value, list) else [value]
@classmethod @classmethod
def serialize(cls, value): def serialize(cls, value):
if isinstance(value, list):
values = [cls.serialize(item) for item in value]
return [item for item in values if item]
return value.id if isinstance(value, MemoryModel) else value return value.id if isinstance(value, MemoryModel) else value
@classmethod @classmethod
def deserialize(cls, value): def deserialize(cls, attribute_name, value):
return value if isinstance(value, MemoryModel) else cls.get(id=value) if isinstance(value, list):
values = [cls.deserialize(attribute_name, item) for item in value]
return [item for item in values if item]
if not value:
multiple_attribute = (
typing.get_origin(cls.attributes[attribute_name]) is list
)
return [] if multiple_attribute else None
if attribute_name in cls.model_attributes and not isinstance(
value, MemoryModel
):
model = getattr(models, cls.model_attributes[attribute_name][0])
return model.get(id=value)
return value
def save(self): def save(self):
self.last_modified = datetime.datetime.now(datetime.timezone.utc).replace( self.last_modified = datetime.datetime.now(datetime.timezone.utc).replace(
@ -101,7 +121,7 @@ class MemoryModel(Model):
# update the index for each attribute # update the index for each attribute
for attribute in self.attributes: for attribute in self.attributes:
attribute_values = self.cardinalize(getattr(self, attribute)) attribute_values = self.listify(getattr(self, attribute))
for value in attribute_values: for value in attribute_values:
self.attribute_index(attribute).setdefault(value, set()).add(self.id) self.attribute_index(attribute).setdefault(value, set()).add(self.id)
@ -132,7 +152,7 @@ class MemoryModel(Model):
# update the index for each attribute # update the index for each attribute
for attribute in self.model_attributes: for attribute in self.model_attributes:
attribute_values = self.cardinalize(old_state.get(attribute, [])) attribute_values = self.listify(old_state.get(attribute, []))
for value in attribute_values: for value in attribute_values:
if ( if (
value in self.attribute_index(attribute) value in self.attribute_index(attribute)
@ -160,7 +180,7 @@ class MemoryModel(Model):
# update the index for each attribute # update the index for each attribute
for attribute in self.attributes: for attribute in self.attributes:
attribute_values = self.cardinalize(old_state.get(attribute, [])) attribute_values = self.listify(old_state.get(attribute, []))
for value in attribute_values: for value in attribute_values:
if ( if (
value in self.attribute_index(attribute) value in self.attribute_index(attribute)
@ -189,28 +209,14 @@ class MemoryModel(Model):
def __getattr__(self, name): def __getattr__(self, name):
if name in self.attributes: if name in self.attributes:
values = self.cache.get(name, self.state.get(name, [])) return self.deserialize(name, self.cache.get(name, self.state.get(name)))
if name in self.model_attributes:
model = getattr(models, self.model_attributes[name][0])
values = [model.deserialize(value) for value in values]
values = [value for value in values if value]
unique_attribute = typing.get_origin(self.attributes[name]) is not list
if unique_attribute:
return values[0] if values else None
else:
return values or []
raise AttributeError() raise AttributeError()
def __setattr__(self, name, value): def __setattr__(self, name, value):
if name in self.attributes: if name in self.attributes:
values = self.cardinalize(value) self.cache[name] = value
self.cache[name] = [value for value in values if value] self.state[name] = self.serialize(value)
values = [self.serialize(value) for value in values]
values = [value for value in values if value]
self.state[name] = values
else: else:
super().__setattr__(name, value) super().__setattr__(name, value)