refactor: fix coverage

This commit is contained in:
Éloi Rivard 2024-04-05 15:21:35 +02:00
parent ec7a721336
commit 47ef573917
No known key found for this signature in database
GPG key ID: 7EDA204EA57DD184
2 changed files with 7 additions and 17 deletions

View file

@ -150,9 +150,6 @@ class MemoryModel(BackendModel):
mirror_attribute, model
).setdefault(self.id, set())
for subinstance_id in self.listify(self._state.get(attribute, [])):
if not subinstance_id or subinstance_id not in self.index(model):
continue
# add the current objet in the subinstance state
subinstance_state = self.index(model)[subinstance_id]
subinstance_state.setdefault(mirror_attribute, [])
@ -171,10 +168,6 @@ class MemoryModel(BackendModel):
for attribute in self.model_attributes:
attribute_values = self.listify(old_state.get(attribute, []))
for value in attribute_values:
if (
value in self.attribute_index(attribute)
and self.id in self.attribute_index(attribute)[value]
):
self.attribute_index(attribute)[value].remove(self.id)
# update the mirror attributes of the submodel instances
@ -186,15 +179,11 @@ class MemoryModel(BackendModel):
mirror_attribute, model
).setdefault(self.id, set())
for subinstance_id in self.index()[self.id].get(attribute, []):
if subinstance_id not in self.index(model):
continue
# remove the current objet from the subinstance state
subinstance_state = self.index(model)[subinstance_id]
subinstance_state[mirror_attribute].remove(self.id)
# remove the current objet from the subinstance index
if subinstance_id in mirror_attribute_index:
mirror_attribute_index.remove(subinstance_id)
# update the index for each attribute

View file

@ -193,10 +193,11 @@ class User(canaille.core.models.User, Base, SqlAlchemyModel):
# not super generic, but we can improve this when we have
# type checking and/or pydantic for the models
if attribute == "groups":
if models.Group.get(id=value):
return models.Group.get(id=value)
elif models.Group.get(display_name=value):
return models.Group.get(display_name=value)
return (
models.Group.get(id=value)
or models.Group.get(display_name=value)
or None
)
return value
def match_filter(self, filter):