diff --git a/canaille/backends/__init__.py b/canaille/backends/__init__.py index 6b5a4835..e2d2a800 100644 --- a/canaille/backends/__init__.py +++ b/canaille/backends/__init__.py @@ -93,6 +93,21 @@ class BaseBackend: """Remove the current instance from the database.""" raise NotImplementedError() + def reload(self, instance): + """Cancel the unsaved modifications. + + >>> user = User.get(user_name="george") + >>> user.display_name + George + >>> user.display_name = "Jane" + >>> user.display_name + Jane + >>> BaseBackend.instance.reload(user) + >>> user.display_name + George + """ + raise NotImplementedError() + def check_user_password(self, user, password: str) -> bool: """Check if the password matches the user password in the database.""" raise NotImplementedError() diff --git a/canaille/backends/ldap/backend.py b/canaille/backends/ldap/backend.py index 132a7142..cd01eb84 100644 --- a/canaille/backends/ldap/backend.py +++ b/canaille/backends/ldap/backend.py @@ -401,6 +401,20 @@ class Backend(BaseBackend): # run the instance delete callback again if existing next(save_callback, None) + def reload(self, instance): + # run the instance reload callback if existing + reload_callback = instance.reload() if hasattr(instance, "reload") else iter([]) + next(reload_callback, None) + + result = self.connection.search_s( + instance.dn, ldap.SCOPE_SUBTREE, None, ["+", "*"] + ) + instance.changes = {} + instance.state = result[0][1] + + # run the instance reload callback again if existing + next(reload_callback, None) + def setup_ldap_models(config): from canaille.app import models diff --git a/canaille/backends/ldap/ldapobject.py b/canaille/backends/ldap/ldapobject.py index 5145d4ac..44527389 100644 --- a/canaille/backends/ldap/ldapobject.py +++ b/canaille/backends/ldap/ldapobject.py @@ -256,9 +256,3 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass): @classmethod def python_attribute_to_ldap(cls, name): return cls.attribute_map.get(name, name) if cls.attribute_map else None - - def reload(self): - conn = Backend.instance.connection - result = conn.search_s(self.dn, ldap.SCOPE_SUBTREE, None, ["+", "*"]) - self.changes = {} - self.state = result[0][1] diff --git a/canaille/backends/memory/backend.py b/canaille/backends/memory/backend.py index 088c62a4..fde4e7d3 100644 --- a/canaille/backends/memory/backend.py +++ b/canaille/backends/memory/backend.py @@ -118,3 +118,16 @@ class Backend(BaseBackend): # run the instance delete callback again if existing next(delete_callback, None) + + def reload(self, instance): + # run the instance reload callback if existing + reload_callback = instance.reload() if hasattr(instance, "reload") else iter([]) + next(reload_callback, None) + + instance._state = BaseBackend.instance.get( + instance.__class__, id=instance.id + )._state + instance._cache = {} + + # run the instance reload callback again if existing + next(reload_callback, None) diff --git a/canaille/backends/memory/models.py b/canaille/backends/memory/models.py index 44ec87a8..8b8ceb89 100644 --- a/canaille/backends/memory/models.py +++ b/canaille/backends/memory/models.py @@ -134,10 +134,6 @@ class MemoryModel(BackendModel): # update the id index del self.index()[self.id] - def reload(self): - self._state = BaseBackend.instance.get(self.__class__, id=self.id)._state - self._cache = {} - def __eq__(self, other): if other is None: return False diff --git a/canaille/backends/models.py b/canaille/backends/models.py index be361684..3354f50e 100644 --- a/canaille/backends/models.py +++ b/canaille/backends/models.py @@ -105,21 +105,6 @@ class BackendModel: for attribute, value in kwargs.items(): setattr(self, attribute, value) - def reload(self): - """Cancel the unsaved modifications. - - >>> user = User.get(user_name="george") - >>> user.display_name - George - >>> user.display_name = "Jane" - >>> user.display_name - Jane - >>> user.reload() - >>> user.display_name - George - """ - raise NotImplementedError() - @classmethod def get_model_annotations(cls, attribute): annotations = cls.attributes[attribute] diff --git a/canaille/backends/sql/backend.py b/canaille/backends/sql/backend.py index 2509e0f5..f1acaaab 100644 --- a/canaille/backends/sql/backend.py +++ b/canaille/backends/sql/backend.py @@ -129,3 +129,13 @@ class Backend(BaseBackend): # run the instance delete callback again if existing next(save_callback, None) + + def reload(self, instance): + # run the instance reload callback if existing + reload_callback = instance.reload() if hasattr(instance, "reload") else iter([]) + next(reload_callback, None) + + Backend.instance.db_session.refresh(instance) + + # run the instance reload callback again if existing + next(reload_callback, None) diff --git a/canaille/backends/sql/models.py b/canaille/backends/sql/models.py index 7cf140ae..d66ded4f 100644 --- a/canaille/backends/sql/models.py +++ b/canaille/backends/sql/models.py @@ -22,7 +22,6 @@ import canaille.core.models import canaille.oidc.models from canaille.backends.models import BackendModel -from .backend import Backend from .backend import Base from .utils import TZDateTime @@ -47,9 +46,6 @@ class SqlAlchemyModel(BackendModel): return getattr(cls, name) == value - def reload(self): - Backend.instance.db_session.refresh(self) - membership_association_table = Table( "membership_association_table", diff --git a/canaille/core/endpoints/account.py b/canaille/core/endpoints/account.py index b33afb4c..6d83281f 100644 --- a/canaille/core/endpoints/account.py +++ b/canaille/core/endpoints/account.py @@ -537,7 +537,7 @@ def profile_edition_main_form_validation(user, edited_user, profile_form): edited_user.preferred_language = None BaseBackend.instance.save(edited_user) - g.user.reload() + BaseBackend.instance.reload(g.user) def profile_edition_emails_form(user, edited_user, has_smtp): diff --git a/canaille/core/models.py b/canaille/core/models.py index 1e2c0652..ce1fb4ed 100644 --- a/canaille/core/models.py +++ b/canaille/core/models.py @@ -296,7 +296,7 @@ class User(Model): self._readable = None self._writable = None self._permissions = None - super().reload() + yield @property def readable_fields(self): diff --git a/tests/app/test_forms.py b/tests/app/test_forms.py index c9b8f97a..e3d556eb 100644 --- a/tests/app/test_forms.py +++ b/tests/app/test_forms.py @@ -187,10 +187,10 @@ def test_datetime_utc_field_invalid_timezone(testclient): ) -def test_fieldlist_add_readonly(testclient, logged_user): +def test_fieldlist_add_readonly(testclient, logged_user, backend): testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["WRITE"].remove("phone_numbers") testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["READ"].append("phone_numbers") - logged_user.reload() + backend.reload(logged_user) res = testclient.get("/profile/user") form = res.forms["baseform"] @@ -209,7 +209,7 @@ def test_fieldlist_add_readonly(testclient, logged_user): def test_fieldlist_remove_readonly(testclient, logged_user, backend): testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["WRITE"].remove("phone_numbers") testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["READ"].append("phone_numbers") - logged_user.reload() + backend.reload(logged_user) logged_user.phone_numbers = ["555-555-000", "555-555-111"] backend.save(logged_user) diff --git a/tests/app/test_i18n.py b/tests/app/test_i18n.py index ac9612c2..57651c47 100644 --- a/tests/app/test_i18n.py +++ b/tests/app/test_i18n.py @@ -17,7 +17,7 @@ def test_preferred_language(testclient, logged_user, backend): assert res.flashes == [("success", "Le profil a été mis à jour avec succès.")] res = res.follow() form = res.forms["baseform"] - logged_user.reload() + backend.reload(logged_user) assert logged_user.preferred_language == "fr" assert form["preferred_language"].value == "fr" assert res.pyquery("html")[0].attrib["lang"] == "fr" @@ -29,7 +29,7 @@ def test_preferred_language(testclient, logged_user, backend): assert res.flashes == [("success", "Profile updated successfully.")] res = res.follow() form = res.forms["baseform"] - logged_user.reload() + backend.reload(logged_user) assert logged_user.preferred_language == "en" assert form["preferred_language"].value == "en" assert res.pyquery("html")[0].attrib["lang"] == "en" @@ -41,7 +41,7 @@ def test_preferred_language(testclient, logged_user, backend): assert res.flashes == [("success", "Profile updated successfully.")] res = res.follow() form = res.forms["baseform"] - logged_user.reload() + backend.reload(logged_user) assert logged_user.preferred_language is None assert form["preferred_language"].value == "auto" assert res.pyquery("html")[0].attrib["lang"] == "en" diff --git a/tests/backends/ldap/test_models.py b/tests/backends/ldap/test_models.py index 8129d1f8..9be9134d 100644 --- a/tests/backends/ldap/test_models.py +++ b/tests/backends/ldap/test_models.py @@ -8,7 +8,7 @@ def test_model_references_set_unsaved_object( exist.""" group = models.Group(members=[user], display_name="foo") backend.save(group) - user.reload() + backend.reload(user) non_existent_user = models.User( formatted_name="foo", family_name="bar", user_name="baz" @@ -19,7 +19,7 @@ def test_model_references_set_unsaved_object( backend.save(group) assert group.members == [user, non_existent_user] - group.reload() + backend.reload(group) assert group.members == [user] testclient.get("/groups/foo", status=200) diff --git a/tests/backends/ldap/test_object_class.py b/tests/backends/ldap/test_object_class.py index 400212b3..b429a828 100644 --- a/tests/backends/ldap/test_object_class.py +++ b/tests/backends/ldap/test_object_class.py @@ -92,7 +92,7 @@ homeDirectory: /home/foobar process = slapd_server.ldapmodify(ldif) assert process.returncode == 0 - user.reload() + backend.reload(user) # saving an object should not raise a ldap.OBJECT_CLASS_VIOLATION exception backend.save(user) diff --git a/tests/backends/ldap/test_permissions.py b/tests/backends/ldap/test_permissions.py index 994c9585..97fca6be 100644 --- a/tests/backends/ldap/test_permissions.py +++ b/tests/backends/ldap/test_permissions.py @@ -1,20 +1,20 @@ -def test_group_permissions_by_dn(testclient, user, foo_group): +def test_group_permissions_by_dn(testclient, user, foo_group, backend): assert not user.can_manage_users testclient.app.config["CANAILLE"]["ACL"]["ADMIN"]["FILTER"] = { "groups": foo_group.dn } - user.reload() + backend.reload(user) assert user.can_manage_users -def test_group_permissions_str(testclient, user, foo_group): +def test_group_permissions_str(testclient, user, foo_group, backend): assert not user.can_manage_users testclient.app.config["CANAILLE"]["ACL"]["ADMIN"]["FILTER"] = ( f"memberOf={foo_group.dn}" ) - user.reload() + backend.reload(user) assert user.can_manage_users diff --git a/tests/backends/test_models.py b/tests/backends/test_models.py index cd145f83..14e83b27 100644 --- a/tests/backends/test_models.py +++ b/tests/backends/test_models.py @@ -52,7 +52,7 @@ def test_model_lifecycle(testclient, backend): assert user.family_name == "new_family_name" - user.reload() + backend.reload(user) assert user.family_name == "family_name" @@ -178,14 +178,14 @@ def test_model_references(testclient, user, foo_group, admin, bar_group, backend assert bar_group not in user.groups user.groups = user.groups + [bar_group] backend.save(user) - bar_group.reload() + backend.reload(bar_group) assert user in bar_group.members assert bar_group in user.groups bar_group.members = [admin] backend.save(bar_group) - user.reload() + backend.reload(user) assert user not in bar_group.members assert bar_group not in user.groups diff --git a/tests/conftest.py b/tests/conftest.py index cedca461..c3a5f8f8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -252,7 +252,7 @@ def foo_group(app, user, backend): display_name="foo", ) backend.save(group) - user.reload() + backend.reload(user) yield group backend.delete(group) @@ -264,7 +264,7 @@ def bar_group(app, admin, backend): display_name="bar", ) backend.save(group) - admin.reload() + backend.reload(admin) yield group backend.delete(group) diff --git a/tests/core/test_account.py b/tests/core/test_account.py index cc0f353d..c54a1a76 100644 --- a/tests/core/test_account.py +++ b/tests/core/test_account.py @@ -5,7 +5,7 @@ from flask import g from canaille.app import models -def test_index(testclient, user): +def test_index(testclient, user, backend): res = testclient.get("/", status=302) assert res.location == "/login" @@ -14,7 +14,7 @@ def test_index(testclient, user): assert res.location == "/profile/user" testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["PERMISSIONS"] = [] - g.user.reload() + backend.reload(g.user) res = testclient.get("/", status=302) assert res.location == "/about" @@ -105,7 +105,7 @@ def test_user_self_deletion(testclient, backend): "delete_account", ] # Simulate an app restart - user.reload() + backend.reload(user) res = testclient.get("/profile/temp/settings") res.mustcontain("Delete my account") diff --git a/tests/core/test_email_confirmation.py b/tests/core/test_email_confirmation.py index cf676234..eca15642 100644 --- a/tests/core/test_email_confirmation.py +++ b/tests/core/test_email_confirmation.py @@ -25,7 +25,7 @@ def test_confirmation_disabled_email_editable(testclient, backend, logged_user): assert res.flashes == [("success", "Profile updated successfully.")] res = res.follow() - logged_user.reload() + backend.reload(logged_user) assert logged_user.emails == ["email1@mydomain.tld", "email2@mydomain.tld"] @@ -51,7 +51,7 @@ def test_confirmation_unset_smtp_disabled_email_editable( assert res.flashes == [("success", "Profile updated successfully.")] res = res.follow() - user.reload() + backend.reload(user) assert user.emails == ["email1@mydomain.tld", "email2@mydomain.tld"] @@ -91,7 +91,7 @@ def test_confirmation_unset_smtp_enabled_email_admin_editable( assert res.flashes == [("success", "Profile updated successfully.")] res = res.follow() - user.reload() + backend.reload(user) assert user.emails == ["email1@mydomain.tld", "email2@mydomain.tld"] @@ -115,7 +115,7 @@ def test_confirmation_enabled_smtp_disabled_admin_editable( assert res.flashes == [("success", "Profile updated successfully.")] res = res.follow() - user.reload() + backend.reload(user) assert user.emails == ["email1@mydomain.tld", "email2@mydomain.tld"] @@ -172,7 +172,7 @@ def test_confirmation_unset_smtp_enabled_email_user_validation( res = testclient.get(email_confirmation_url) assert ("success", "Your email address have been confirmed.") in res.flashes - user.reload() + backend.reload(user) assert "new_email@mydomain.tld" in user.emails @@ -206,7 +206,7 @@ def test_confirmation_mail_form_failed(testclient, backend, user): ) assert res.flashes == [("error", "Email addition failed.")] - user.reload() + backend.reload(user) assert user.emails == ["john@doe.com"] @@ -233,7 +233,7 @@ def test_confirmation_mail_send_failed(SMTP, smtpd, testclient, backend, user): ) assert res.flashes == [("error", "Could not send the verification email")] - user.reload() + backend.reload(user) assert user.emails == ["john@doe.com"] @@ -258,7 +258,7 @@ def test_confirmation_expired_link(testclient, backend, user): "error", "The email confirmation link that brought you here has expired.", ) in res.flashes - user.reload() + backend.reload(user) assert "new_email@mydomain.tld" not in user.emails @@ -283,7 +283,7 @@ def test_confirmation_invalid_hash_link(testclient, backend, user): "error", "The invitation link that brought you here was invalid.", ) in res.flashes - user.reload() + backend.reload(user) assert "new_email@mydomain.tld" not in user.emails @@ -312,7 +312,7 @@ def test_confirmation_invalid_user_link(testclient, backend, user): "error", "The email confirmation link that brought you here is invalid.", ) in res.flashes - user.reload() + backend.reload(user) assert "new_email@mydomain.tld" not in user.emails @@ -337,7 +337,7 @@ def test_confirmation_email_already_confirmed_link(testclient, backend, user, ad "error", "This address email have already been confirmed.", ) in res.flashes - user.reload() + backend.reload(user) assert "new_email@mydomain.tld" not in user.emails @@ -367,7 +367,7 @@ def test_confirmation_email_already_used_link(testclient, backend, user, admin): "error", "This address email is already associated with another account.", ) in res.flashes - user.reload() + backend.reload(user) assert "new_email@mydomain.tld" not in user.emails @@ -387,7 +387,7 @@ def test_delete_email(testclient, logged_user, backend): ) assert res.flashes == [("success", "The email have been successfully deleted.")] - logged_user.reload() + backend.reload(logged_user) assert logged_user.emails == ["john@doe.com"] @@ -408,7 +408,7 @@ def test_delete_wrong_email(testclient, logged_user, backend): ) assert res2.flashes == [("error", "Email deletion failed.")] - logged_user.reload() + backend.reload(logged_user) assert logged_user.emails == ["john@doe.com"] @@ -429,11 +429,11 @@ def test_delete_last_email(testclient, logged_user, backend): ) assert res2.flashes == [("error", "Email deletion failed.")] - logged_user.reload() + backend.reload(logged_user) assert logged_user.emails == ["john@doe.com"] -def test_edition_forced_mail(testclient, logged_user): +def test_edition_forced_mail(testclient, logged_user, backend): """Tests that users that must perform email verification cannot force the profile form.""" res = testclient.get("/profile/user", status=200) @@ -447,7 +447,7 @@ def test_edition_forced_mail(testclient, logged_user): }, ) - logged_user.reload() + backend.reload(logged_user) assert logged_user.emails == ["john@doe.com"] diff --git a/tests/core/test_forgotten_password.py b/tests/core/test_forgotten_password.py index ee4568d5..213e6f8d 100644 --- a/tests/core/test_forgotten_password.py +++ b/tests/core/test_forgotten_password.py @@ -83,9 +83,11 @@ def test_password_forgotten_invalid(smtpd, testclient, user): assert len(smtpd.messages) == 0 -def test_password_forgotten_invalid_when_user_cannot_self_edit(smtpd, testclient, user): +def test_password_forgotten_invalid_when_user_cannot_self_edit( + smtpd, testclient, user, backend +): testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["PERMISSIONS"] = [] - user.reload() + backend.reload(user) testclient.app.config["CANAILLE"]["HIDE_INVALID_LOGINS"] = False res = testclient.get("/reset", status=200) @@ -104,7 +106,7 @@ def test_password_forgotten_invalid_when_user_cannot_self_edit(smtpd, testclient ) in res.flashes testclient.app.config["CANAILLE"]["HIDE_INVALID_LOGINS"] = True - user.reload() + backend.reload(user) res = testclient.get("/reset", status=200) res.form["login"] = "user" diff --git a/tests/core/test_groups.py b/tests/core/test_groups.py index c9ed3514..afab4adc 100644 --- a/tests/core/test_groups.py +++ b/tests/core/test_groups.py @@ -61,11 +61,11 @@ def test_group_deletion(testclient, backend): ) backend.save(group) - user.reload() + backend.reload(user) assert user.groups == [group] backend.delete(group) - user.reload() + backend.reload(user) assert not user.groups backend.delete(user) @@ -93,15 +93,15 @@ def test_set_groups(app, user, foo_group, bar_group, backend): user.groups = [foo_group, bar_group] backend.save(user) - bar_group.reload() + backend.reload(bar_group) assert user in bar_group.members assert bar_group in user.groups user.groups = [foo_group] backend.save(user) - foo_group.reload() - bar_group.reload() + backend.reload(foo_group) + backend.reload(bar_group) assert user in foo_group.members assert user not in bar_group.members @@ -118,13 +118,13 @@ def test_set_groups_with_leading_space_in_user_id_attribute(app, foo_group, back user.groups = [foo_group] backend.save(user) - foo_group.reload() + backend.reload(foo_group) assert user in foo_group.members user.groups = [] backend.save(user) - foo_group.reload() + backend.reload(foo_group) assert user.id not in foo_group.members backend.delete(user) @@ -149,7 +149,7 @@ def test_moderator_can_create_edit_and_delete_group( # Group has been created res = form.submit(status=302).follow(status=200) - logged_moderator.reload() + backend.reload(logged_moderator) bar_group = backend.get(models.Group, display_name="bar") assert bar_group.display_name == "bar" assert bar_group.description == "yolo" @@ -221,13 +221,13 @@ def test_invalid_form_request(testclient, logged_moderator, foo_group): res = form.submit(name="action", value="invalid-action", status=400) -def test_edition_failed(testclient, logged_moderator, foo_group): +def test_edition_failed(testclient, logged_moderator, foo_group, backend): res = testclient.get("/groups/foo") form = res.forms["editgroupform"] form["display_name"] = "" res = form.submit(name="action", value="edit") res.mustcontain("Group edition failed.") - foo_group.reload() + backend.reload(foo_group) assert foo_group.display_name == "foo" diff --git a/tests/core/test_invitation.py b/tests/core/test_invitation.py index c4a11258..a157b226 100644 --- a/tests/core/test_invitation.py +++ b/tests/core/test_invitation.py @@ -47,7 +47,7 @@ def test_invitation(testclient, logged_admin, foo_group, smtpd, backend): res = res.follow(status=200) user = backend.get(models.User, user_name="someone") - foo_group.reload() + backend.reload(foo_group) assert backend.check_user_password(user, "whatever")[0] assert user.groups == [foo_group] @@ -103,7 +103,7 @@ def test_invitation_editable_user_name( res = res.follow(status=200) user = backend.get(models.User, user_name="djorje") - foo_group.reload() + backend.reload(foo_group) assert backend.check_user_password(user, "whatever")[0] assert user.groups == [foo_group] @@ -150,7 +150,7 @@ def test_generate_link(testclient, logged_admin, foo_group, smtpd, backend): res = res.follow(status=200) user = backend.get(models.User, user_name="sometwo") - foo_group.reload() + backend.reload(foo_group) assert backend.check_user_password(user, "whatever")[0] assert user.groups == [foo_group] @@ -332,6 +332,6 @@ def test_groups_are_saved_even_when_user_does_not_have_read_permission( res = res.follow(status=200) user = backend.get(models.User, user_name="someoneelse") - foo_group.reload() + backend.reload(foo_group) assert user.groups == [foo_group] backend.delete(user) diff --git a/tests/core/test_password_reset.py b/tests/core/test_password_reset.py index dc7fec8f..49cf4d49 100644 --- a/tests/core/test_password_reset.py +++ b/tests/core/test_password_reset.py @@ -13,7 +13,7 @@ def test_password_reset(testclient, user, backend): assert ("success", "Your password has been updated successfully") in res.flashes assert res.location == "/profile/user" - user.reload() + backend.reload(user) assert backend.check_user_password(user, "foobarbaz")[0] res = testclient.get("/reset/user/" + hash) @@ -37,7 +37,7 @@ def test_password_reset_multiple_emails(testclient, user, backend): res = res.form.submit() assert ("success", "Your password has been updated successfully") in res.flashes - user.reload() + backend.reload(user) assert backend.check_user_password(user, "foobarbaz")[0] res = testclient.get("/reset/user/" + hash) diff --git a/tests/core/test_permissions.py b/tests/core/test_permissions.py index 1776081f..89a05e26 100644 --- a/tests/core/test_permissions.py +++ b/tests/core/test_permissions.py @@ -1,29 +1,29 @@ -def test_group_permissions_by_id(testclient, user, foo_group): +def test_group_permissions_by_id(testclient, user, foo_group, backend): assert not user.can_manage_users testclient.app.config["CANAILLE"]["ACL"]["ADMIN"]["FILTER"] = { "groups": foo_group.id } - user.reload() + backend.reload(user) assert user.can_manage_users -def test_group_permissions_by_display_name(testclient, user, foo_group): +def test_group_permissions_by_display_name(testclient, user, foo_group, backend): assert not user.can_manage_users testclient.app.config["CANAILLE"]["ACL"]["ADMIN"]["FILTER"] = { "groups": foo_group.display_name } - user.reload() + backend.reload(user) assert user.can_manage_users -def test_invalid_group_permission(testclient, user, foo_group): +def test_invalid_group_permission(testclient, user, foo_group, backend): assert not user.can_manage_users testclient.app.config["CANAILLE"]["ACL"]["ADMIN"]["FILTER"] = {"groups": "invalid"} - user.reload() + backend.reload(user) assert not user.can_manage_users diff --git a/tests/core/test_profile_creation.py b/tests/core/test_profile_creation.py index e5b48be2..6827df3c 100644 --- a/tests/core/test_profile_creation.py +++ b/tests/core/test_profile_creation.py @@ -25,7 +25,7 @@ def test_user_creation_edition_and_deletion( assert ("success", "User account creation succeed.") in res.flashes res = res.follow(status=200) george = backend.get(models.User, user_name="george") - foo_group.reload() + backend.reload(foo_group) assert "George" == george.given_name assert george.groups == [foo_group] assert backend.check_user_password(george, "totoyolo")[0] @@ -49,8 +49,8 @@ def test_user_creation_edition_and_deletion( assert "Georgio" == george.given_name assert backend.check_user_password(george, "totoyolo")[0] - foo_group.reload() - bar_group.reload() + backend.reload(foo_group) + backend.reload(bar_group) assert george in set(foo_group.members) assert george in set(bar_group.members) assert set(george.groups) == {foo_group, bar_group} diff --git a/tests/core/test_profile_edition.py b/tests/core/test_profile_edition.py index 5ab50373..ea11122c 100644 --- a/tests/core/test_profile_edition.py +++ b/tests/core/test_profile_edition.py @@ -71,7 +71,7 @@ def test_user_list_search(testclient, logged_admin, user, moderator): def test_user_list_search_only_allowed_fields( - testclient, logged_admin, user, moderator + testclient, logged_admin, user, moderator, backend ): res = testclient.get("/users") res.mustcontain("3 items") @@ -88,7 +88,7 @@ def test_user_list_search_only_allowed_fields( testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["READ"].remove("user_name") testclient.app.config["CANAILLE"]["ACL"]["ADMIN"]["READ"].remove("user_name") - g.user.reload() + backend.reload(g.user) form = res.forms["search"] form["query"] = "user" @@ -103,13 +103,14 @@ def test_edition_permission( testclient, logged_user, admin, + backend, ): testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["PERMISSIONS"] = [] - logged_user.reload() + backend.reload(logged_user) testclient.get("/profile/user", status=404) testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["PERMISSIONS"] = ["edit_self"] - g.user.reload() + backend.reload(g.user) testclient.get("/profile/user", status=200) @@ -145,7 +146,7 @@ def test_edition( ], res.text res = res.follow() - logged_user.reload() + backend.reload(logged_user) assert logged_user.given_name == "given_name" assert logged_user.family_name == "family_name" @@ -187,7 +188,7 @@ def test_edition_remove_fields( assert res.flashes == [("success", "Profile updated successfully.")], res.text res = res.follow() - logged_user.reload() + backend.reload(logged_user) assert not logged_user.display_name assert not logged_user.phone_numbers @@ -212,7 +213,7 @@ def test_field_permissions_none(testclient, logged_user, backend): "FILTER": None, } - g.user.reload() + backend.reload(g.user) res = testclient.get("/profile/user", status=200) form = res.forms["baseform"] assert "phone_numbers-0" not in form.fields @@ -225,7 +226,7 @@ def test_field_permissions_none(testclient, logged_user, backend): "csrf_token": form["csrf_token"].value, }, ) - logged_user.reload() + backend.reload(logged_user) assert logged_user.phone_numbers == ["555-666-777"] @@ -240,7 +241,7 @@ def test_field_permissions_read(testclient, logged_user, backend): "PERMISSIONS": ["edit_self"], "FILTER": None, } - g.user.reload() + backend.reload(g.user) res = testclient.get("/profile/user", status=200) form = res.forms["baseform"] assert "phone_numbers-0" in form.fields @@ -253,7 +254,7 @@ def test_field_permissions_read(testclient, logged_user, backend): "csrf_token": form["csrf_token"].value, }, ) - logged_user.reload() + backend.reload(logged_user) assert logged_user.phone_numbers == ["555-666-777"] @@ -268,7 +269,7 @@ def test_field_permissions_write(testclient, logged_user, backend): "PERMISSIONS": ["edit_self"], "FILTER": None, } - g.user.reload() + backend.reload(g.user) res = testclient.get("/profile/user", status=200) form = res.forms["baseform"] assert "phone_numbers-0" in form.fields @@ -281,7 +282,7 @@ def test_field_permissions_write(testclient, logged_user, backend): "csrf_token": form["csrf_token"].value, }, ) - logged_user.reload() + backend.reload(logged_user) assert logged_user.phone_numbers == ["000-000-000"] @@ -306,7 +307,7 @@ def test_admin_bad_request(testclient, logged_moderator): testclient.get("/profile/foobar", status=404) -def test_bad_email(testclient, logged_user): +def test_bad_email(testclient, logged_user, backend): res = testclient.get("/profile/user", status=200) form = res.forms["baseform"] @@ -323,12 +324,12 @@ def test_bad_email(testclient, logged_user): res = form.submit(name="action", value="edit-profile", status=200) - logged_user.reload() + backend.reload(logged_user) assert ["john@doe.com"] == logged_user.emails -def test_surname_is_mandatory(testclient, logged_user): +def test_surname_is_mandatory(testclient, logged_user, backend): res = testclient.get("/profile/user", status=200) form = res.forms["baseform"] logged_user.family_name = "Doe" @@ -337,7 +338,7 @@ def test_surname_is_mandatory(testclient, logged_user): res = form.submit(name="action", value="edit-profile", status=200) - logged_user.reload() + backend.reload(logged_user) assert "Doe" == logged_user.family_name @@ -391,12 +392,12 @@ def test_inline_validation(testclient, logged_admin, user): res.mustcontain("The email 'john@doe.com' is already used") -def test_inline_validation_keep_indicators(testclient, logged_admin, user): +def test_inline_validation_keep_indicators(testclient, logged_admin, user, backend): testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["WRITE"].remove("display_name") testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["READ"].append("display_name") testclient.app.config["CANAILLE"]["ACL"]["ADMIN"]["WRITE"].append("display_name") - logged_admin.reload() - user.reload() + backend.reload(logged_admin) + backend.reload(user) res = testclient.get("/profile/admin") form = res.forms["baseform"] diff --git a/tests/core/test_profile_photo.py b/tests/core/test_profile_photo.py index aab384ea..0b7ccc5c 100644 --- a/tests/core/test_profile_photo.py +++ b/tests/core/test_profile_photo.py @@ -8,7 +8,7 @@ from canaille.app import models def test_photo(testclient, user, jpeg_photo, backend): user.photo = jpeg_photo backend.save(user) - user.reload() + backend.reload(user) res = testclient.get("/profile/user/photo") assert res.body == jpeg_photo @@ -52,6 +52,7 @@ def test_photo_on_profile_edition( testclient, logged_user, jpeg_photo, + backend, ): # Add a photo res = testclient.get("/profile/user", status=200) @@ -62,7 +63,7 @@ def test_photo_on_profile_edition( assert ("success", "Profile updated successfully.") in res.flashes res = res.follow() - logged_user.reload() + backend.reload(logged_user) assert logged_user.photo == jpeg_photo @@ -74,7 +75,7 @@ def test_photo_on_profile_edition( assert ("success", "Profile updated successfully.") in res.flashes res = res.follow() - logged_user.reload() + backend.reload(logged_user) assert logged_user.photo == jpeg_photo @@ -86,7 +87,7 @@ def test_photo_on_profile_edition( assert ("success", "Profile updated successfully.") in res.flashes res = res.follow() - logged_user.reload() + backend.reload(logged_user) assert logged_user.photo is None @@ -99,7 +100,7 @@ def test_photo_on_profile_edition( assert ("success", "Profile updated successfully.") in res.flashes res = res.follow() - logged_user.reload() + backend.reload(logged_user) assert logged_user.photo is None diff --git a/tests/core/test_profile_settings.py b/tests/core/test_profile_settings.py index 60104a13..cb802149 100644 --- a/tests/core/test_profile_settings.py +++ b/tests/core/test_profile_settings.py @@ -21,12 +21,12 @@ def test_edition(testclient, logged_user, admin, foo_group, bar_group, backend): res.form["user_name"] = "toto" res = res.form.submit(name="action", value="edit-settings") assert res.flashes == [("error", "Profile edition failed.")] - logged_user.reload() + backend.reload(logged_user) assert logged_user.user_name == "user" - foo_group.reload() - bar_group.reload() + backend.reload(foo_group) + backend.reload(bar_group) assert logged_user.groups == [foo_group] assert foo_group.members == [logged_user] assert bar_group.members == [admin] @@ -41,7 +41,7 @@ def test_group_removal(testclient, logged_admin, user, foo_group, backend): """Tests that one can remove a group from a user.""" foo_group.members = [user, logged_admin] backend.save(foo_group) - user.reload() + backend.reload(user) assert foo_group in user.groups res = testclient.get("/profile/user/settings", status=200) @@ -49,11 +49,11 @@ def test_group_removal(testclient, logged_admin, user, foo_group, backend): res = res.form.submit(name="action", value="edit-settings") assert res.flashes == [("success", "Profile updated successfully.")] - user.reload() + backend.reload(user) assert foo_group not in user.groups - foo_group.reload() - logged_admin.reload() + backend.reload(foo_group) + backend.reload(logged_admin) assert foo_group.members == [logged_admin] @@ -76,7 +76,7 @@ def test_empty_group_removal(testclient, logged_admin, user, foo_group, backend) "The group 'foo' cannot be removed, because it must have at least one user left." ) - user.reload() + backend.reload(user) assert foo_group in user.groups @@ -109,7 +109,7 @@ def test_edition_without_groups( assert res.flashes == [("success", "Profile updated successfully.")] res = res.follow() - logged_user.reload() + backend.reload(logged_user) assert logged_user.user_name == "user" assert backend.check_user_password(logged_user, "correct horse battery staple")[0] @@ -126,7 +126,7 @@ def test_password_change(testclient, logged_user, backend): res = res.form.submit(name="action", value="edit-settings").follow() - logged_user.reload() + backend.reload(logged_user) assert backend.check_user_password(logged_user, "new_password")[0] res = testclient.get("/profile/user/settings", status=200) @@ -138,7 +138,7 @@ def test_password_change(testclient, logged_user, backend): assert ("success", "Profile updated successfully.") in res.flashes res = res.follow() - logged_user.reload() + backend.reload(logged_user) assert backend.check_user_password(logged_user, "correct horse battery staple")[0] @@ -150,7 +150,7 @@ def test_password_change_fail(testclient, logged_user, backend): res = res.form.submit(name="action", value="edit-settings", status=200) - logged_user.reload() + backend.reload(logged_user) assert backend.check_user_password(logged_user, "correct horse battery staple")[0] res = testclient.get("/profile/user/settings", status=200) @@ -160,7 +160,7 @@ def test_password_change_fail(testclient, logged_user, backend): res = res.form.submit(name="action", value="edit-settings", status=200) - logged_user.reload() + backend.reload(logged_user) assert backend.check_user_password(logged_user, "correct horse battery staple")[0] @@ -186,7 +186,7 @@ def test_password_initialization_mail(smtpd, testclient, backend, logged_admin): assert len(smtpd.messages) == 1 assert smtpd.messages[0]["X-RcptTo"] == "john@doe.com" - u.reload() + backend.reload(u) u.password = "correct horse battery staple" backend.save(u) @@ -357,13 +357,14 @@ def test_edition_permission( testclient, logged_user, admin, + backend, ): testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["PERMISSIONS"] = [] - logged_user.reload() + backend.reload(logged_user) testclient.get("/profile/user/settings", status=404) testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["PERMISSIONS"] = ["edit_self"] - g.user.reload() + backend.reload(g.user) testclient.get("/profile/user/settings", status=200) @@ -462,7 +463,7 @@ def test_empty_lock_date( assert res.flashes == [("success", "Profile updated successfully.")] res = res.follow() - user.reload() + backend.reload(user) assert not user.lock_date diff --git a/tests/oidc/test_account.py b/tests/oidc/test_account.py index 29685d05..dff28c66 100644 --- a/tests/oidc/test_account.py +++ b/tests/oidc/test_account.py @@ -1,7 +1,7 @@ from flask import g -def test_index(testclient, user): +def test_index(testclient, user, backend): res = testclient.get("/", status=302) assert res.location == "/login" @@ -10,11 +10,11 @@ def test_index(testclient, user): assert res.location == "/profile/user" testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["PERMISSIONS"] = ["use_oidc"] - g.user.reload() + backend.reload(g.user) res = testclient.get("/", status=302) assert res.location == "/consent/" testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["PERMISSIONS"] = [] - g.user.reload() + backend.reload(g.user) res = testclient.get("/", status=302) assert res.location == "/about" diff --git a/tests/oidc/test_authorization_code_flow.py b/tests/oidc/test_authorization_code_flow.py index 4a6e36be..0918a1ad 100644 --- a/tests/oidc/test_authorization_code_flow.py +++ b/tests/oidc/test_authorization_code_flow.py @@ -503,9 +503,11 @@ def test_when_consent_already_given_but_for_a_smaller_scope( backend.delete(consent) -def test_user_cannot_use_oidc(testclient, user, client, keypair, trusted_client): +def test_user_cannot_use_oidc( + testclient, user, client, keypair, trusted_client, backend +): testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["PERMISSIONS"] = [] - user.reload() + backend.reload(user) res = testclient.get( "/oauth/authorize", diff --git a/tests/oidc/test_client_admin.py b/tests/oidc/test_client_admin.py index 7ae27c3e..f4b17500 100644 --- a/tests/oidc/test_client_admin.py +++ b/tests/oidc/test_client_admin.py @@ -147,7 +147,7 @@ def test_add_missing_fields(testclient, logged_admin): ) in res.flashes -def test_client_edit(testclient, client, logged_admin, trusted_client): +def test_client_edit(testclient, client, logged_admin, trusted_client, backend): res = testclient.get("/admin/client/edit/" + client.client_id) data = { "client_name": "foobar", @@ -179,7 +179,7 @@ def test_client_edit(testclient, client, logged_admin, trusted_client): ) not in res.flashes assert ("success", "The client has been edited.") in res.flashes - client.reload() + backend.reload(client) assert client.client_name == "foobar" assert client.contacts == ["foo@bar.com"] @@ -204,7 +204,9 @@ def test_client_edit(testclient, client, logged_admin, trusted_client): assert client.post_logout_redirect_uris == ["https://foo.bar/disconnected"] -def test_client_edit_missing_fields(testclient, client, logged_admin, trusted_client): +def test_client_edit_missing_fields( + testclient, client, logged_admin, trusted_client, backend +): res = testclient.get("/admin/client/edit/" + client.client_id) res.forms["clientaddform"]["client_name"] = "" res = res.forms["clientaddform"].submit(name="action", value="edit") @@ -212,7 +214,7 @@ def test_client_edit_missing_fields(testclient, client, logged_admin, trusted_cl "error", "The client has not been edited. Please check your information.", ) in res.flashes - client.reload() + backend.reload(client) assert client.client_name @@ -258,7 +260,7 @@ def test_client_delete_invalid_client(testclient, logged_admin, client): ) -def test_client_edit_preauth(testclient, client, logged_admin, trusted_client): +def test_client_edit_preauth(testclient, client, logged_admin, trusted_client, backend): assert not client.preconsent res = testclient.get("/admin/client/edit/" + client.client_id) @@ -266,7 +268,7 @@ def test_client_edit_preauth(testclient, client, logged_admin, trusted_client): res = res.forms["clientaddform"].submit(name="action", value="edit") assert ("success", "The client has been edited.") in res.flashes - client.reload() + backend.reload(client) assert client.preconsent res = testclient.get("/admin/client/edit/" + client.client_id) @@ -274,7 +276,7 @@ def test_client_edit_preauth(testclient, client, logged_admin, trusted_client): res = res.forms["clientaddform"].submit(name="action", value="edit") assert ("success", "The client has been edited.") in res.flashes - client.reload() + backend.reload(client) assert not client.preconsent diff --git a/tests/oidc/test_consent.py b/tests/oidc/test_consent.py index 8edc86ab..f3d1cce1 100644 --- a/tests/oidc/test_consent.py +++ b/tests/oidc/test_consent.py @@ -10,7 +10,7 @@ def test_no_logged_no_access(testclient): testclient.get("/consent", status=403) -def test_revokation(testclient, client, consent, logged_user, token): +def test_revokation(testclient, client, consent, logged_user, token, backend): res = testclient.get("/consent", status=200) res.mustcontain(client.client_name) res.mustcontain("Revoke access") @@ -24,13 +24,13 @@ def test_revokation(testclient, client, consent, logged_user, token): res.mustcontain(no="Revoke access") res.mustcontain("Restore access") - consent.reload() + backend.reload(consent) assert consent.revoked - token.reload() + backend.reload(token) assert token.revoked -def test_revokation_already_revoked(testclient, client, consent, logged_user): +def test_revokation_already_revoked(testclient, client, consent, logged_user, backend): consent.revoke() assert consent.revoked @@ -39,24 +39,24 @@ def test_revokation_already_revoked(testclient, client, consent, logged_user): assert ("error", "The access is already revoked") in res.flashes res = res.follow(status=200) - consent.reload() + backend.reload(consent) assert consent.revoked -def test_restoration(testclient, client, consent, logged_user, token): +def test_restoration(testclient, client, consent, logged_user, token, backend): consent.revoke() assert consent.revoked - token.reload() + backend.reload(token) assert token.revoked res = testclient.get(f"/consent/restore/{consent.consent_id}", status=302) assert ("success", "The access has been restored") in res.flashes res = res.follow(status=200) - consent.reload() + backend.reload(consent) assert not consent.revoked - token.reload() + backend.reload(token) assert token.revoked @@ -115,7 +115,7 @@ def test_oidc_authorization_after_revokation( res = res.form.submit(name="answer", value="accept", status=302) consents = backend.query(models.Consent, client=client, subject=logged_user) - consent.reload() + backend.reload(consent) assert consents[0] == consent assert not consent.revoked @@ -167,21 +167,21 @@ def test_revoke_preconsented_client(testclient, client, logged_user, token, back assert consent.subject == logged_user assert consent.scope == ["openid", "email", "profile", "groups", "address", "phone"] assert not consent.issue_date - token.reload() + backend.reload(token) assert token.revoked res = testclient.get(f"/consent/restore/{consent.consent_id}", status=302) assert ("success", "The access has been restored") in res.flashes - consent.reload() + backend.reload(consent) assert not consent.revoked assert consent.issue_date - token.reload() + backend.reload(token) assert token.revoked res = testclient.get(f"/consent/revoke/{consent.consent_id}", status=302) assert ("success", "The access has been revoked") in res.flashes - consent.reload() + backend.reload(consent) assert consent.revoked assert consent.issue_date diff --git a/tests/oidc/test_refresh_token.py b/tests/oidc/test_refresh_token.py index 88075631..6f693c0c 100644 --- a/tests/oidc/test_refresh_token.py +++ b/tests/oidc/test_refresh_token.py @@ -60,7 +60,7 @@ def test_refresh_token(testclient, logged_user, client, backend): assert new_token is not None assert old_token.access_token != new_token.access_token - old_token.reload() + backend.reload(old_token) assert old_token.revokation_date res = testclient.get( diff --git a/tests/oidc/test_token_admin.py b/tests/oidc/test_token_admin.py index ec005058..f25e00bf 100644 --- a/tests/oidc/test_token_admin.py +++ b/tests/oidc/test_token_admin.py @@ -134,7 +134,7 @@ def test_revoke_bad_request(testclient, token, logged_admin): res = res.form.submit(name="action", value="invalid", status=400) -def test_revoke_token(testclient, token, logged_admin): +def test_revoke_token(testclient, token, logged_admin, backend): assert not token.revoked res = testclient.get(f"/admin/token/{token.token_id}") @@ -142,7 +142,7 @@ def test_revoke_token(testclient, token, logged_admin): res = res.form.submit(name="action", value="revoke") assert ("success", "The token has successfully been revoked.") in res.flashes - token.reload() + backend.reload(token) assert token.revoked diff --git a/tests/oidc/test_token_revocation.py b/tests/oidc/test_token_revocation.py index 193613ae..70ea49cb 100644 --- a/tests/oidc/test_token_revocation.py +++ b/tests/oidc/test_token_revocation.py @@ -3,7 +3,7 @@ import datetime from . import client_credentials -def test_revoke_access_token(testclient, user, client, token): +def test_revoke_access_token(testclient, user, client, token, backend): assert not token.revokation_date res = testclient.post( @@ -14,11 +14,11 @@ def test_revoke_access_token(testclient, user, client, token): ) assert {} == res.json - token.reload() + backend.reload(token) assert token.revokation_date -def test_revoke_access_token_with_hint(testclient, user, client, token): +def test_revoke_access_token_with_hint(testclient, user, client, token, backend): assert not token.revokation_date res = testclient.post( @@ -29,11 +29,11 @@ def test_revoke_access_token_with_hint(testclient, user, client, token): ) assert {} == res.json - token.reload() + backend.reload(token) assert token.revokation_date -def test_revoke_refresh_token(testclient, user, client, token): +def test_revoke_refresh_token(testclient, user, client, token, backend): assert not token.revokation_date res = testclient.post( @@ -44,11 +44,11 @@ def test_revoke_refresh_token(testclient, user, client, token): ) assert {} == res.json - token.reload() + backend.reload(token) assert token.revokation_date -def test_revoke_refresh_token_with_hint(testclient, user, client, token): +def test_revoke_refresh_token_with_hint(testclient, user, client, token, backend): assert not token.revokation_date res = testclient.post( @@ -59,7 +59,7 @@ def test_revoke_refresh_token_with_hint(testclient, user, client, token): ) assert {} == res.json - token.reload() + backend.reload(token) assert token.revokation_date