Group methods refactoring

This commit is contained in:
Éloi Rivard 2023-04-08 00:31:22 +02:00
parent 52f7276527
commit e738faf52b
3 changed files with 21 additions and 30 deletions

View file

@ -406,7 +406,7 @@ def profile_create(current_app, form):
if "groups" in form:
groups = [Group.get(group_id) for group_id in form["groups"].data]
for group in groups:
group.add_member(user)
group.members = group.members + [user]
group.save()
if form["password1"].data:
@ -589,7 +589,7 @@ def profile_settings_edit(editor, edited_user):
else:
for attribute in form:
if attribute.name == "groups" and "groups" in editor.write:
edited_user.set_groups(attribute.data)
edited_user.groups = attribute.data
if (
"password1" in request.form

View file

@ -135,16 +135,17 @@ class User(LDAPObject):
self.load_groups()
return self._groups
def set_groups(self, values):
@groups.setter
def groups(self, values):
before = self._groups
after = [v if isinstance(v, Group) else Group.get(id=v) for v in values]
to_add = set(after) - set(before)
to_del = set(before) - set(after)
for group in to_add:
group.add_member(self)
group.members = group.members + [self]
group.save()
for group in to_del:
group.remove_member(self)
group.members = [member for member in group.members if member != self]
group.save()
self._groups = after
@ -211,12 +212,3 @@ class Group(LDAPObject):
"GROUP_NAME_ATTRIBUTE", Group.DEFAULT_NAME_ATTRIBUTE
)
return self[attribute][0]
def get_members(self):
return [member for member in self.members if member]
def add_member(self, user):
self.members = self.members + [user]
def remove_member(self, user):
self.members = [m for m in self.members if m != user]

View file

@ -87,23 +87,23 @@ def test_group_list_search(testclient, logged_admin, foo_group, bar_group):
def test_set_groups(app, user, foo_group, bar_group):
assert user in foo_group.get_members()
assert user in foo_group.members
assert user.groups[0] == foo_group
user.load_groups()
user.set_groups([foo_group, bar_group])
user.groups = [foo_group, bar_group]
bar_group = Group.get(bar_group.id)
assert user in bar_group.get_members()
assert user in bar_group.members
assert user.groups[1] == bar_group
user.load_groups()
user.set_groups([foo_group])
user.groups = [foo_group]
foo_group = Group.get(foo_group.id)
bar_group = Group.get(bar_group.id)
assert user in foo_group.get_members()
assert user not in bar_group.get_members()
assert user in foo_group.members
assert user not in bar_group.members
def test_set_groups_with_leading_space_in_user_id_attribute(app, foo_group):
@ -116,15 +116,15 @@ def test_set_groups_with_leading_space_in_user_id_attribute(app, foo_group):
user.save()
user.load_groups()
user.set_groups([foo_group])
user.groups = [foo_group]
assert user in foo_group.get_members()
assert user in foo_group.members
user.load_groups()
user.set_groups([])
user.groups = []
foo_group = Group.get(foo_group.id)
assert user.id not in foo_group.get_members()
assert user.id not in foo_group.members
user.delete()
@ -151,7 +151,7 @@ def test_moderator_can_create_edit_and_delete_group(
bar_group = Group.get("bar")
assert bar_group.display_name == "bar"
assert bar_group.description == ["yolo"]
assert bar_group.get_members() == [
assert bar_group.members == [
logged_moderator
] # Group cannot be empty so creator is added in it
res.mustcontain("bar")
@ -168,7 +168,7 @@ def test_moderator_can_create_edit_and_delete_group(
assert bar_group.display_name == "bar"
assert bar_group.description == ["yolo2"]
assert Group.get("bar2") is None
members = bar_group.get_members()
members = bar_group.members
for member in members:
res.mustcontain(member.formatted_name[0])
@ -208,7 +208,7 @@ def test_get_members_filters_non_existent_user(
foo_group.member = foo_group.member + [non_existent_user]
foo_group.save()
foo_group.get_members()
foo_group.members
assert foo_group.member == [user, non_existent_user]
@ -237,7 +237,7 @@ def test_user_list_pagination(testclient, logged_admin, foo_group):
users = fake_users(25)
for user in users:
foo_group.add_member(user)
foo_group.members = foo_group.members + [user]
foo_group.save()
res = testclient.get("/groups/foo")
@ -275,8 +275,7 @@ def test_user_list_bad_pages(testclient, logged_admin, foo_group):
def test_user_list_search(testclient, logged_admin, foo_group, user, moderator):
foo_group.add_member(logged_admin)
foo_group.add_member(moderator)
foo_group.members = foo_group.members + [logged_admin, moderator]
foo_group.save()
res = testclient.get("/groups/foo")