diff --git a/canaille/account.py b/canaille/account.py index b41fc355..c603ec71 100644 --- a/canaille/account.py +++ b/canaille/account.py @@ -406,7 +406,7 @@ def profile_create(current_app, form): if "groups" in form: groups = [Group.get(group_id) for group_id in form["groups"].data] for group in groups: - group.add_member(user) + group.members = group.members + [user] group.save() if form["password1"].data: @@ -589,7 +589,7 @@ def profile_settings_edit(editor, edited_user): else: for attribute in form: if attribute.name == "groups" and "groups" in editor.write: - edited_user.set_groups(attribute.data) + edited_user.groups = attribute.data if ( "password1" in request.form diff --git a/canaille/models.py b/canaille/models.py index 5eddf92a..201082d6 100644 --- a/canaille/models.py +++ b/canaille/models.py @@ -135,16 +135,17 @@ class User(LDAPObject): self.load_groups() return self._groups - def set_groups(self, values): + @groups.setter + def groups(self, values): before = self._groups after = [v if isinstance(v, Group) else Group.get(id=v) for v in values] to_add = set(after) - set(before) to_del = set(before) - set(after) for group in to_add: - group.add_member(self) + group.members = group.members + [self] group.save() for group in to_del: - group.remove_member(self) + group.members = [member for member in group.members if member != self] group.save() self._groups = after @@ -211,12 +212,3 @@ class Group(LDAPObject): "GROUP_NAME_ATTRIBUTE", Group.DEFAULT_NAME_ATTRIBUTE ) return self[attribute][0] - - def get_members(self): - return [member for member in self.members if member] - - def add_member(self, user): - self.members = self.members + [user] - - def remove_member(self, user): - self.members = [m for m in self.members if m != user] diff --git a/tests/test_groups.py b/tests/test_groups.py index ebcce6df..f2d29b84 100644 --- a/tests/test_groups.py +++ b/tests/test_groups.py @@ -87,23 +87,23 @@ def test_group_list_search(testclient, logged_admin, foo_group, bar_group): def test_set_groups(app, user, foo_group, bar_group): - assert user in foo_group.get_members() + assert user in foo_group.members assert user.groups[0] == foo_group user.load_groups() - user.set_groups([foo_group, bar_group]) + user.groups = [foo_group, bar_group] bar_group = Group.get(bar_group.id) - assert user in bar_group.get_members() + assert user in bar_group.members assert user.groups[1] == bar_group user.load_groups() - user.set_groups([foo_group]) + user.groups = [foo_group] foo_group = Group.get(foo_group.id) bar_group = Group.get(bar_group.id) - assert user in foo_group.get_members() - assert user not in bar_group.get_members() + assert user in foo_group.members + assert user not in bar_group.members def test_set_groups_with_leading_space_in_user_id_attribute(app, foo_group): @@ -116,15 +116,15 @@ def test_set_groups_with_leading_space_in_user_id_attribute(app, foo_group): user.save() user.load_groups() - user.set_groups([foo_group]) + user.groups = [foo_group] - assert user in foo_group.get_members() + assert user in foo_group.members user.load_groups() - user.set_groups([]) + user.groups = [] foo_group = Group.get(foo_group.id) - assert user.id not in foo_group.get_members() + assert user.id not in foo_group.members user.delete() @@ -151,7 +151,7 @@ def test_moderator_can_create_edit_and_delete_group( bar_group = Group.get("bar") assert bar_group.display_name == "bar" assert bar_group.description == ["yolo"] - assert bar_group.get_members() == [ + assert bar_group.members == [ logged_moderator ] # Group cannot be empty so creator is added in it res.mustcontain("bar") @@ -168,7 +168,7 @@ def test_moderator_can_create_edit_and_delete_group( assert bar_group.display_name == "bar" assert bar_group.description == ["yolo2"] assert Group.get("bar2") is None - members = bar_group.get_members() + members = bar_group.members for member in members: res.mustcontain(member.formatted_name[0]) @@ -208,7 +208,7 @@ def test_get_members_filters_non_existent_user( foo_group.member = foo_group.member + [non_existent_user] foo_group.save() - foo_group.get_members() + foo_group.members assert foo_group.member == [user, non_existent_user] @@ -237,7 +237,7 @@ def test_user_list_pagination(testclient, logged_admin, foo_group): users = fake_users(25) for user in users: - foo_group.add_member(user) + foo_group.members = foo_group.members + [user] foo_group.save() res = testclient.get("/groups/foo") @@ -275,8 +275,7 @@ def test_user_list_bad_pages(testclient, logged_admin, foo_group): def test_user_list_search(testclient, logged_admin, foo_group, user, moderator): - foo_group.add_member(logged_admin) - foo_group.add_member(moderator) + foo_group.members = foo_group.members + [logged_admin, moderator] foo_group.save() res = testclient.get("/groups/foo")