refactor: improve memory model serialization

do not systematically store every attributes as a list
This commit is contained in:
Éloi Rivard 2024-03-31 12:06:13 +02:00
parent 8834c65bea
commit 006bf08b3d
No known key found for this signature in database
GPG key ID: 7EDA204EA57DD184

View file

@ -36,8 +36,8 @@ class MemoryModel(Model):
ids = {
id
for attribute, values in kwargs.items()
for value in cls.cardinalize(values)
for id in cls.attribute_index(attribute).get(cls.serialize(value), [])
for value in cls.serialize(cls.listify(values))
for id in cls.attribute_index(attribute).get(value, [])
}
return [cls(**cls.index()[id]) for id in ids]
@ -62,7 +62,7 @@ class MemoryModel(Model):
if any(
query.lower() in value.lower()
for attribute in attributes
for value in instance.state.get(attribute, [])
for value in cls.listify(instance.state.get(attribute, []))
if isinstance(value, str)
)
]
@ -76,16 +76,36 @@ class MemoryModel(Model):
return results[0] if results else None
@classmethod
def cardinalize(cls, value):
def listify(cls, value):
return value if isinstance(value, list) else [value]
@classmethod
def serialize(cls, value):
if isinstance(value, list):
values = [cls.serialize(item) for item in value]
return [item for item in values if item]
return value.id if isinstance(value, MemoryModel) else value
@classmethod
def deserialize(cls, value):
return value if isinstance(value, MemoryModel) else cls.get(id=value)
def deserialize(cls, attribute_name, value):
if isinstance(value, list):
values = [cls.deserialize(attribute_name, item) for item in value]
return [item for item in values if item]
if not value:
multiple_attribute = (
typing.get_origin(cls.attributes[attribute_name]) is list
)
return [] if multiple_attribute else None
if attribute_name in cls.model_attributes and not isinstance(
value, MemoryModel
):
model = getattr(models, cls.model_attributes[attribute_name][0])
return model.get(id=value)
return value
def save(self):
self.last_modified = datetime.datetime.now(datetime.timezone.utc).replace(
@ -101,7 +121,7 @@ class MemoryModel(Model):
# update the index for each attribute
for attribute in self.attributes:
attribute_values = self.cardinalize(getattr(self, attribute))
attribute_values = self.listify(getattr(self, attribute))
for value in attribute_values:
self.attribute_index(attribute).setdefault(value, set()).add(self.id)
@ -132,7 +152,7 @@ class MemoryModel(Model):
# update the index for each attribute
for attribute in self.model_attributes:
attribute_values = self.cardinalize(old_state.get(attribute, []))
attribute_values = self.listify(old_state.get(attribute, []))
for value in attribute_values:
if (
value in self.attribute_index(attribute)
@ -160,7 +180,7 @@ class MemoryModel(Model):
# update the index for each attribute
for attribute in self.attributes:
attribute_values = self.cardinalize(old_state.get(attribute, []))
attribute_values = self.listify(old_state.get(attribute, []))
for value in attribute_values:
if (
value in self.attribute_index(attribute)
@ -189,28 +209,14 @@ class MemoryModel(Model):
def __getattr__(self, name):
if name in self.attributes:
values = self.cache.get(name, self.state.get(name, []))
if name in self.model_attributes:
model = getattr(models, self.model_attributes[name][0])
values = [model.deserialize(value) for value in values]
values = [value for value in values if value]
unique_attribute = typing.get_origin(self.attributes[name]) is not list
if unique_attribute:
return values[0] if values else None
else:
return values or []
return self.deserialize(name, self.cache.get(name, self.state.get(name)))
raise AttributeError()
def __setattr__(self, name, value):
if name in self.attributes:
values = self.cardinalize(value)
self.cache[name] = [value for value in values if value]
values = [self.serialize(value) for value in values]
values = [value for value in values if value]
self.state[name] = values
self.cache[name] = value
self.state[name] = self.serialize(value)
else:
super().__setattr__(name, value)