refactor: move BackendModel.reload to Backend.reload

This commit is contained in:
Éloi Rivard 2024-04-14 22:51:58 +02:00
parent 2ccdaeadf6
commit 473a262ea2
No known key found for this signature in database
GPG key ID: 7EDA204EA57DD184
35 changed files with 207 additions and 175 deletions

View file

@ -93,6 +93,21 @@ class BaseBackend:
"""Remove the current instance from the database."""
raise NotImplementedError()
def reload(self, instance):
"""Cancel the unsaved modifications.
>>> user = User.get(user_name="george")
>>> user.display_name
George
>>> user.display_name = "Jane"
>>> user.display_name
Jane
>>> BaseBackend.instance.reload(user)
>>> user.display_name
George
"""
raise NotImplementedError()
def check_user_password(self, user, password: str) -> bool:
"""Check if the password matches the user password in the database."""
raise NotImplementedError()

View file

@ -401,6 +401,20 @@ class Backend(BaseBackend):
# run the instance delete callback again if existing
next(save_callback, None)
def reload(self, instance):
# run the instance reload callback if existing
reload_callback = instance.reload() if hasattr(instance, "reload") else iter([])
next(reload_callback, None)
result = self.connection.search_s(
instance.dn, ldap.SCOPE_SUBTREE, None, ["+", "*"]
)
instance.changes = {}
instance.state = result[0][1]
# run the instance reload callback again if existing
next(reload_callback, None)
def setup_ldap_models(config):
from canaille.app import models

View file

@ -256,9 +256,3 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
@classmethod
def python_attribute_to_ldap(cls, name):
return cls.attribute_map.get(name, name) if cls.attribute_map else None
def reload(self):
conn = Backend.instance.connection
result = conn.search_s(self.dn, ldap.SCOPE_SUBTREE, None, ["+", "*"])
self.changes = {}
self.state = result[0][1]

View file

@ -118,3 +118,16 @@ class Backend(BaseBackend):
# run the instance delete callback again if existing
next(delete_callback, None)
def reload(self, instance):
# run the instance reload callback if existing
reload_callback = instance.reload() if hasattr(instance, "reload") else iter([])
next(reload_callback, None)
instance._state = BaseBackend.instance.get(
instance.__class__, id=instance.id
)._state
instance._cache = {}
# run the instance reload callback again if existing
next(reload_callback, None)

View file

@ -134,10 +134,6 @@ class MemoryModel(BackendModel):
# update the id index
del self.index()[self.id]
def reload(self):
self._state = BaseBackend.instance.get(self.__class__, id=self.id)._state
self._cache = {}
def __eq__(self, other):
if other is None:
return False

View file

@ -105,21 +105,6 @@ class BackendModel:
for attribute, value in kwargs.items():
setattr(self, attribute, value)
def reload(self):
"""Cancel the unsaved modifications.
>>> user = User.get(user_name="george")
>>> user.display_name
George
>>> user.display_name = "Jane"
>>> user.display_name
Jane
>>> user.reload()
>>> user.display_name
George
"""
raise NotImplementedError()
@classmethod
def get_model_annotations(cls, attribute):
annotations = cls.attributes[attribute]

View file

@ -129,3 +129,13 @@ class Backend(BaseBackend):
# run the instance delete callback again if existing
next(save_callback, None)
def reload(self, instance):
# run the instance reload callback if existing
reload_callback = instance.reload() if hasattr(instance, "reload") else iter([])
next(reload_callback, None)
Backend.instance.db_session.refresh(instance)
# run the instance reload callback again if existing
next(reload_callback, None)

View file

@ -22,7 +22,6 @@ import canaille.core.models
import canaille.oidc.models
from canaille.backends.models import BackendModel
from .backend import Backend
from .backend import Base
from .utils import TZDateTime
@ -47,9 +46,6 @@ class SqlAlchemyModel(BackendModel):
return getattr(cls, name) == value
def reload(self):
Backend.instance.db_session.refresh(self)
membership_association_table = Table(
"membership_association_table",

View file

@ -537,7 +537,7 @@ def profile_edition_main_form_validation(user, edited_user, profile_form):
edited_user.preferred_language = None
BaseBackend.instance.save(edited_user)
g.user.reload()
BaseBackend.instance.reload(g.user)
def profile_edition_emails_form(user, edited_user, has_smtp):

View file

@ -296,7 +296,7 @@ class User(Model):
self._readable = None
self._writable = None
self._permissions = None
super().reload()
yield
@property
def readable_fields(self):

View file

@ -187,10 +187,10 @@ def test_datetime_utc_field_invalid_timezone(testclient):
)
def test_fieldlist_add_readonly(testclient, logged_user):
def test_fieldlist_add_readonly(testclient, logged_user, backend):
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["WRITE"].remove("phone_numbers")
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["READ"].append("phone_numbers")
logged_user.reload()
backend.reload(logged_user)
res = testclient.get("/profile/user")
form = res.forms["baseform"]
@ -209,7 +209,7 @@ def test_fieldlist_add_readonly(testclient, logged_user):
def test_fieldlist_remove_readonly(testclient, logged_user, backend):
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["WRITE"].remove("phone_numbers")
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["READ"].append("phone_numbers")
logged_user.reload()
backend.reload(logged_user)
logged_user.phone_numbers = ["555-555-000", "555-555-111"]
backend.save(logged_user)

View file

@ -17,7 +17,7 @@ def test_preferred_language(testclient, logged_user, backend):
assert res.flashes == [("success", "Le profil a été mis à jour avec succès.")]
res = res.follow()
form = res.forms["baseform"]
logged_user.reload()
backend.reload(logged_user)
assert logged_user.preferred_language == "fr"
assert form["preferred_language"].value == "fr"
assert res.pyquery("html")[0].attrib["lang"] == "fr"
@ -29,7 +29,7 @@ def test_preferred_language(testclient, logged_user, backend):
assert res.flashes == [("success", "Profile updated successfully.")]
res = res.follow()
form = res.forms["baseform"]
logged_user.reload()
backend.reload(logged_user)
assert logged_user.preferred_language == "en"
assert form["preferred_language"].value == "en"
assert res.pyquery("html")[0].attrib["lang"] == "en"
@ -41,7 +41,7 @@ def test_preferred_language(testclient, logged_user, backend):
assert res.flashes == [("success", "Profile updated successfully.")]
res = res.follow()
form = res.forms["baseform"]
logged_user.reload()
backend.reload(logged_user)
assert logged_user.preferred_language is None
assert form["preferred_language"].value == "auto"
assert res.pyquery("html")[0].attrib["lang"] == "en"

View file

@ -8,7 +8,7 @@ def test_model_references_set_unsaved_object(
exist."""
group = models.Group(members=[user], display_name="foo")
backend.save(group)
user.reload()
backend.reload(user)
non_existent_user = models.User(
formatted_name="foo", family_name="bar", user_name="baz"
@ -19,7 +19,7 @@ def test_model_references_set_unsaved_object(
backend.save(group)
assert group.members == [user, non_existent_user]
group.reload()
backend.reload(group)
assert group.members == [user]
testclient.get("/groups/foo", status=200)

View file

@ -92,7 +92,7 @@ homeDirectory: /home/foobar
process = slapd_server.ldapmodify(ldif)
assert process.returncode == 0
user.reload()
backend.reload(user)
# saving an object should not raise a ldap.OBJECT_CLASS_VIOLATION exception
backend.save(user)

View file

@ -1,20 +1,20 @@
def test_group_permissions_by_dn(testclient, user, foo_group):
def test_group_permissions_by_dn(testclient, user, foo_group, backend):
assert not user.can_manage_users
testclient.app.config["CANAILLE"]["ACL"]["ADMIN"]["FILTER"] = {
"groups": foo_group.dn
}
user.reload()
backend.reload(user)
assert user.can_manage_users
def test_group_permissions_str(testclient, user, foo_group):
def test_group_permissions_str(testclient, user, foo_group, backend):
assert not user.can_manage_users
testclient.app.config["CANAILLE"]["ACL"]["ADMIN"]["FILTER"] = (
f"memberOf={foo_group.dn}"
)
user.reload()
backend.reload(user)
assert user.can_manage_users

View file

@ -52,7 +52,7 @@ def test_model_lifecycle(testclient, backend):
assert user.family_name == "new_family_name"
user.reload()
backend.reload(user)
assert user.family_name == "family_name"
@ -178,14 +178,14 @@ def test_model_references(testclient, user, foo_group, admin, bar_group, backend
assert bar_group not in user.groups
user.groups = user.groups + [bar_group]
backend.save(user)
bar_group.reload()
backend.reload(bar_group)
assert user in bar_group.members
assert bar_group in user.groups
bar_group.members = [admin]
backend.save(bar_group)
user.reload()
backend.reload(user)
assert user not in bar_group.members
assert bar_group not in user.groups

View file

@ -252,7 +252,7 @@ def foo_group(app, user, backend):
display_name="foo",
)
backend.save(group)
user.reload()
backend.reload(user)
yield group
backend.delete(group)
@ -264,7 +264,7 @@ def bar_group(app, admin, backend):
display_name="bar",
)
backend.save(group)
admin.reload()
backend.reload(admin)
yield group
backend.delete(group)

View file

@ -5,7 +5,7 @@ from flask import g
from canaille.app import models
def test_index(testclient, user):
def test_index(testclient, user, backend):
res = testclient.get("/", status=302)
assert res.location == "/login"
@ -14,7 +14,7 @@ def test_index(testclient, user):
assert res.location == "/profile/user"
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["PERMISSIONS"] = []
g.user.reload()
backend.reload(g.user)
res = testclient.get("/", status=302)
assert res.location == "/about"
@ -105,7 +105,7 @@ def test_user_self_deletion(testclient, backend):
"delete_account",
]
# Simulate an app restart
user.reload()
backend.reload(user)
res = testclient.get("/profile/temp/settings")
res.mustcontain("Delete my account")

View file

@ -25,7 +25,7 @@ def test_confirmation_disabled_email_editable(testclient, backend, logged_user):
assert res.flashes == [("success", "Profile updated successfully.")]
res = res.follow()
logged_user.reload()
backend.reload(logged_user)
assert logged_user.emails == ["email1@mydomain.tld", "email2@mydomain.tld"]
@ -51,7 +51,7 @@ def test_confirmation_unset_smtp_disabled_email_editable(
assert res.flashes == [("success", "Profile updated successfully.")]
res = res.follow()
user.reload()
backend.reload(user)
assert user.emails == ["email1@mydomain.tld", "email2@mydomain.tld"]
@ -91,7 +91,7 @@ def test_confirmation_unset_smtp_enabled_email_admin_editable(
assert res.flashes == [("success", "Profile updated successfully.")]
res = res.follow()
user.reload()
backend.reload(user)
assert user.emails == ["email1@mydomain.tld", "email2@mydomain.tld"]
@ -115,7 +115,7 @@ def test_confirmation_enabled_smtp_disabled_admin_editable(
assert res.flashes == [("success", "Profile updated successfully.")]
res = res.follow()
user.reload()
backend.reload(user)
assert user.emails == ["email1@mydomain.tld", "email2@mydomain.tld"]
@ -172,7 +172,7 @@ def test_confirmation_unset_smtp_enabled_email_user_validation(
res = testclient.get(email_confirmation_url)
assert ("success", "Your email address have been confirmed.") in res.flashes
user.reload()
backend.reload(user)
assert "new_email@mydomain.tld" in user.emails
@ -206,7 +206,7 @@ def test_confirmation_mail_form_failed(testclient, backend, user):
)
assert res.flashes == [("error", "Email addition failed.")]
user.reload()
backend.reload(user)
assert user.emails == ["john@doe.com"]
@ -233,7 +233,7 @@ def test_confirmation_mail_send_failed(SMTP, smtpd, testclient, backend, user):
)
assert res.flashes == [("error", "Could not send the verification email")]
user.reload()
backend.reload(user)
assert user.emails == ["john@doe.com"]
@ -258,7 +258,7 @@ def test_confirmation_expired_link(testclient, backend, user):
"error",
"The email confirmation link that brought you here has expired.",
) in res.flashes
user.reload()
backend.reload(user)
assert "new_email@mydomain.tld" not in user.emails
@ -283,7 +283,7 @@ def test_confirmation_invalid_hash_link(testclient, backend, user):
"error",
"The invitation link that brought you here was invalid.",
) in res.flashes
user.reload()
backend.reload(user)
assert "new_email@mydomain.tld" not in user.emails
@ -312,7 +312,7 @@ def test_confirmation_invalid_user_link(testclient, backend, user):
"error",
"The email confirmation link that brought you here is invalid.",
) in res.flashes
user.reload()
backend.reload(user)
assert "new_email@mydomain.tld" not in user.emails
@ -337,7 +337,7 @@ def test_confirmation_email_already_confirmed_link(testclient, backend, user, ad
"error",
"This address email have already been confirmed.",
) in res.flashes
user.reload()
backend.reload(user)
assert "new_email@mydomain.tld" not in user.emails
@ -367,7 +367,7 @@ def test_confirmation_email_already_used_link(testclient, backend, user, admin):
"error",
"This address email is already associated with another account.",
) in res.flashes
user.reload()
backend.reload(user)
assert "new_email@mydomain.tld" not in user.emails
@ -387,7 +387,7 @@ def test_delete_email(testclient, logged_user, backend):
)
assert res.flashes == [("success", "The email have been successfully deleted.")]
logged_user.reload()
backend.reload(logged_user)
assert logged_user.emails == ["john@doe.com"]
@ -408,7 +408,7 @@ def test_delete_wrong_email(testclient, logged_user, backend):
)
assert res2.flashes == [("error", "Email deletion failed.")]
logged_user.reload()
backend.reload(logged_user)
assert logged_user.emails == ["john@doe.com"]
@ -429,11 +429,11 @@ def test_delete_last_email(testclient, logged_user, backend):
)
assert res2.flashes == [("error", "Email deletion failed.")]
logged_user.reload()
backend.reload(logged_user)
assert logged_user.emails == ["john@doe.com"]
def test_edition_forced_mail(testclient, logged_user):
def test_edition_forced_mail(testclient, logged_user, backend):
"""Tests that users that must perform email verification cannot force the
profile form."""
res = testclient.get("/profile/user", status=200)
@ -447,7 +447,7 @@ def test_edition_forced_mail(testclient, logged_user):
},
)
logged_user.reload()
backend.reload(logged_user)
assert logged_user.emails == ["john@doe.com"]

View file

@ -83,9 +83,11 @@ def test_password_forgotten_invalid(smtpd, testclient, user):
assert len(smtpd.messages) == 0
def test_password_forgotten_invalid_when_user_cannot_self_edit(smtpd, testclient, user):
def test_password_forgotten_invalid_when_user_cannot_self_edit(
smtpd, testclient, user, backend
):
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["PERMISSIONS"] = []
user.reload()
backend.reload(user)
testclient.app.config["CANAILLE"]["HIDE_INVALID_LOGINS"] = False
res = testclient.get("/reset", status=200)
@ -104,7 +106,7 @@ def test_password_forgotten_invalid_when_user_cannot_self_edit(smtpd, testclient
) in res.flashes
testclient.app.config["CANAILLE"]["HIDE_INVALID_LOGINS"] = True
user.reload()
backend.reload(user)
res = testclient.get("/reset", status=200)
res.form["login"] = "user"

View file

@ -61,11 +61,11 @@ def test_group_deletion(testclient, backend):
)
backend.save(group)
user.reload()
backend.reload(user)
assert user.groups == [group]
backend.delete(group)
user.reload()
backend.reload(user)
assert not user.groups
backend.delete(user)
@ -93,15 +93,15 @@ def test_set_groups(app, user, foo_group, bar_group, backend):
user.groups = [foo_group, bar_group]
backend.save(user)
bar_group.reload()
backend.reload(bar_group)
assert user in bar_group.members
assert bar_group in user.groups
user.groups = [foo_group]
backend.save(user)
foo_group.reload()
bar_group.reload()
backend.reload(foo_group)
backend.reload(bar_group)
assert user in foo_group.members
assert user not in bar_group.members
@ -118,13 +118,13 @@ def test_set_groups_with_leading_space_in_user_id_attribute(app, foo_group, back
user.groups = [foo_group]
backend.save(user)
foo_group.reload()
backend.reload(foo_group)
assert user in foo_group.members
user.groups = []
backend.save(user)
foo_group.reload()
backend.reload(foo_group)
assert user.id not in foo_group.members
backend.delete(user)
@ -149,7 +149,7 @@ def test_moderator_can_create_edit_and_delete_group(
# Group has been created
res = form.submit(status=302).follow(status=200)
logged_moderator.reload()
backend.reload(logged_moderator)
bar_group = backend.get(models.Group, display_name="bar")
assert bar_group.display_name == "bar"
assert bar_group.description == "yolo"
@ -221,13 +221,13 @@ def test_invalid_form_request(testclient, logged_moderator, foo_group):
res = form.submit(name="action", value="invalid-action", status=400)
def test_edition_failed(testclient, logged_moderator, foo_group):
def test_edition_failed(testclient, logged_moderator, foo_group, backend):
res = testclient.get("/groups/foo")
form = res.forms["editgroupform"]
form["display_name"] = ""
res = form.submit(name="action", value="edit")
res.mustcontain("Group edition failed.")
foo_group.reload()
backend.reload(foo_group)
assert foo_group.display_name == "foo"

View file

@ -47,7 +47,7 @@ def test_invitation(testclient, logged_admin, foo_group, smtpd, backend):
res = res.follow(status=200)
user = backend.get(models.User, user_name="someone")
foo_group.reload()
backend.reload(foo_group)
assert backend.check_user_password(user, "whatever")[0]
assert user.groups == [foo_group]
@ -103,7 +103,7 @@ def test_invitation_editable_user_name(
res = res.follow(status=200)
user = backend.get(models.User, user_name="djorje")
foo_group.reload()
backend.reload(foo_group)
assert backend.check_user_password(user, "whatever")[0]
assert user.groups == [foo_group]
@ -150,7 +150,7 @@ def test_generate_link(testclient, logged_admin, foo_group, smtpd, backend):
res = res.follow(status=200)
user = backend.get(models.User, user_name="sometwo")
foo_group.reload()
backend.reload(foo_group)
assert backend.check_user_password(user, "whatever")[0]
assert user.groups == [foo_group]
@ -332,6 +332,6 @@ def test_groups_are_saved_even_when_user_does_not_have_read_permission(
res = res.follow(status=200)
user = backend.get(models.User, user_name="someoneelse")
foo_group.reload()
backend.reload(foo_group)
assert user.groups == [foo_group]
backend.delete(user)

View file

@ -13,7 +13,7 @@ def test_password_reset(testclient, user, backend):
assert ("success", "Your password has been updated successfully") in res.flashes
assert res.location == "/profile/user"
user.reload()
backend.reload(user)
assert backend.check_user_password(user, "foobarbaz")[0]
res = testclient.get("/reset/user/" + hash)
@ -37,7 +37,7 @@ def test_password_reset_multiple_emails(testclient, user, backend):
res = res.form.submit()
assert ("success", "Your password has been updated successfully") in res.flashes
user.reload()
backend.reload(user)
assert backend.check_user_password(user, "foobarbaz")[0]
res = testclient.get("/reset/user/" + hash)

View file

@ -1,29 +1,29 @@
def test_group_permissions_by_id(testclient, user, foo_group):
def test_group_permissions_by_id(testclient, user, foo_group, backend):
assert not user.can_manage_users
testclient.app.config["CANAILLE"]["ACL"]["ADMIN"]["FILTER"] = {
"groups": foo_group.id
}
user.reload()
backend.reload(user)
assert user.can_manage_users
def test_group_permissions_by_display_name(testclient, user, foo_group):
def test_group_permissions_by_display_name(testclient, user, foo_group, backend):
assert not user.can_manage_users
testclient.app.config["CANAILLE"]["ACL"]["ADMIN"]["FILTER"] = {
"groups": foo_group.display_name
}
user.reload()
backend.reload(user)
assert user.can_manage_users
def test_invalid_group_permission(testclient, user, foo_group):
def test_invalid_group_permission(testclient, user, foo_group, backend):
assert not user.can_manage_users
testclient.app.config["CANAILLE"]["ACL"]["ADMIN"]["FILTER"] = {"groups": "invalid"}
user.reload()
backend.reload(user)
assert not user.can_manage_users

View file

@ -25,7 +25,7 @@ def test_user_creation_edition_and_deletion(
assert ("success", "User account creation succeed.") in res.flashes
res = res.follow(status=200)
george = backend.get(models.User, user_name="george")
foo_group.reload()
backend.reload(foo_group)
assert "George" == george.given_name
assert george.groups == [foo_group]
assert backend.check_user_password(george, "totoyolo")[0]
@ -49,8 +49,8 @@ def test_user_creation_edition_and_deletion(
assert "Georgio" == george.given_name
assert backend.check_user_password(george, "totoyolo")[0]
foo_group.reload()
bar_group.reload()
backend.reload(foo_group)
backend.reload(bar_group)
assert george in set(foo_group.members)
assert george in set(bar_group.members)
assert set(george.groups) == {foo_group, bar_group}

View file

@ -71,7 +71,7 @@ def test_user_list_search(testclient, logged_admin, user, moderator):
def test_user_list_search_only_allowed_fields(
testclient, logged_admin, user, moderator
testclient, logged_admin, user, moderator, backend
):
res = testclient.get("/users")
res.mustcontain("3 items")
@ -88,7 +88,7 @@ def test_user_list_search_only_allowed_fields(
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["READ"].remove("user_name")
testclient.app.config["CANAILLE"]["ACL"]["ADMIN"]["READ"].remove("user_name")
g.user.reload()
backend.reload(g.user)
form = res.forms["search"]
form["query"] = "user"
@ -103,13 +103,14 @@ def test_edition_permission(
testclient,
logged_user,
admin,
backend,
):
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["PERMISSIONS"] = []
logged_user.reload()
backend.reload(logged_user)
testclient.get("/profile/user", status=404)
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["PERMISSIONS"] = ["edit_self"]
g.user.reload()
backend.reload(g.user)
testclient.get("/profile/user", status=200)
@ -145,7 +146,7 @@ def test_edition(
], res.text
res = res.follow()
logged_user.reload()
backend.reload(logged_user)
assert logged_user.given_name == "given_name"
assert logged_user.family_name == "family_name"
@ -187,7 +188,7 @@ def test_edition_remove_fields(
assert res.flashes == [("success", "Profile updated successfully.")], res.text
res = res.follow()
logged_user.reload()
backend.reload(logged_user)
assert not logged_user.display_name
assert not logged_user.phone_numbers
@ -212,7 +213,7 @@ def test_field_permissions_none(testclient, logged_user, backend):
"FILTER": None,
}
g.user.reload()
backend.reload(g.user)
res = testclient.get("/profile/user", status=200)
form = res.forms["baseform"]
assert "phone_numbers-0" not in form.fields
@ -225,7 +226,7 @@ def test_field_permissions_none(testclient, logged_user, backend):
"csrf_token": form["csrf_token"].value,
},
)
logged_user.reload()
backend.reload(logged_user)
assert logged_user.phone_numbers == ["555-666-777"]
@ -240,7 +241,7 @@ def test_field_permissions_read(testclient, logged_user, backend):
"PERMISSIONS": ["edit_self"],
"FILTER": None,
}
g.user.reload()
backend.reload(g.user)
res = testclient.get("/profile/user", status=200)
form = res.forms["baseform"]
assert "phone_numbers-0" in form.fields
@ -253,7 +254,7 @@ def test_field_permissions_read(testclient, logged_user, backend):
"csrf_token": form["csrf_token"].value,
},
)
logged_user.reload()
backend.reload(logged_user)
assert logged_user.phone_numbers == ["555-666-777"]
@ -268,7 +269,7 @@ def test_field_permissions_write(testclient, logged_user, backend):
"PERMISSIONS": ["edit_self"],
"FILTER": None,
}
g.user.reload()
backend.reload(g.user)
res = testclient.get("/profile/user", status=200)
form = res.forms["baseform"]
assert "phone_numbers-0" in form.fields
@ -281,7 +282,7 @@ def test_field_permissions_write(testclient, logged_user, backend):
"csrf_token": form["csrf_token"].value,
},
)
logged_user.reload()
backend.reload(logged_user)
assert logged_user.phone_numbers == ["000-000-000"]
@ -306,7 +307,7 @@ def test_admin_bad_request(testclient, logged_moderator):
testclient.get("/profile/foobar", status=404)
def test_bad_email(testclient, logged_user):
def test_bad_email(testclient, logged_user, backend):
res = testclient.get("/profile/user", status=200)
form = res.forms["baseform"]
@ -323,12 +324,12 @@ def test_bad_email(testclient, logged_user):
res = form.submit(name="action", value="edit-profile", status=200)
logged_user.reload()
backend.reload(logged_user)
assert ["john@doe.com"] == logged_user.emails
def test_surname_is_mandatory(testclient, logged_user):
def test_surname_is_mandatory(testclient, logged_user, backend):
res = testclient.get("/profile/user", status=200)
form = res.forms["baseform"]
logged_user.family_name = "Doe"
@ -337,7 +338,7 @@ def test_surname_is_mandatory(testclient, logged_user):
res = form.submit(name="action", value="edit-profile", status=200)
logged_user.reload()
backend.reload(logged_user)
assert "Doe" == logged_user.family_name
@ -391,12 +392,12 @@ def test_inline_validation(testclient, logged_admin, user):
res.mustcontain("The email 'john@doe.com' is already used")
def test_inline_validation_keep_indicators(testclient, logged_admin, user):
def test_inline_validation_keep_indicators(testclient, logged_admin, user, backend):
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["WRITE"].remove("display_name")
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["READ"].append("display_name")
testclient.app.config["CANAILLE"]["ACL"]["ADMIN"]["WRITE"].append("display_name")
logged_admin.reload()
user.reload()
backend.reload(logged_admin)
backend.reload(user)
res = testclient.get("/profile/admin")
form = res.forms["baseform"]

View file

@ -8,7 +8,7 @@ from canaille.app import models
def test_photo(testclient, user, jpeg_photo, backend):
user.photo = jpeg_photo
backend.save(user)
user.reload()
backend.reload(user)
res = testclient.get("/profile/user/photo")
assert res.body == jpeg_photo
@ -52,6 +52,7 @@ def test_photo_on_profile_edition(
testclient,
logged_user,
jpeg_photo,
backend,
):
# Add a photo
res = testclient.get("/profile/user", status=200)
@ -62,7 +63,7 @@ def test_photo_on_profile_edition(
assert ("success", "Profile updated successfully.") in res.flashes
res = res.follow()
logged_user.reload()
backend.reload(logged_user)
assert logged_user.photo == jpeg_photo
@ -74,7 +75,7 @@ def test_photo_on_profile_edition(
assert ("success", "Profile updated successfully.") in res.flashes
res = res.follow()
logged_user.reload()
backend.reload(logged_user)
assert logged_user.photo == jpeg_photo
@ -86,7 +87,7 @@ def test_photo_on_profile_edition(
assert ("success", "Profile updated successfully.") in res.flashes
res = res.follow()
logged_user.reload()
backend.reload(logged_user)
assert logged_user.photo is None
@ -99,7 +100,7 @@ def test_photo_on_profile_edition(
assert ("success", "Profile updated successfully.") in res.flashes
res = res.follow()
logged_user.reload()
backend.reload(logged_user)
assert logged_user.photo is None

View file

@ -21,12 +21,12 @@ def test_edition(testclient, logged_user, admin, foo_group, bar_group, backend):
res.form["user_name"] = "toto"
res = res.form.submit(name="action", value="edit-settings")
assert res.flashes == [("error", "Profile edition failed.")]
logged_user.reload()
backend.reload(logged_user)
assert logged_user.user_name == "user"
foo_group.reload()
bar_group.reload()
backend.reload(foo_group)
backend.reload(bar_group)
assert logged_user.groups == [foo_group]
assert foo_group.members == [logged_user]
assert bar_group.members == [admin]
@ -41,7 +41,7 @@ def test_group_removal(testclient, logged_admin, user, foo_group, backend):
"""Tests that one can remove a group from a user."""
foo_group.members = [user, logged_admin]
backend.save(foo_group)
user.reload()
backend.reload(user)
assert foo_group in user.groups
res = testclient.get("/profile/user/settings", status=200)
@ -49,11 +49,11 @@ def test_group_removal(testclient, logged_admin, user, foo_group, backend):
res = res.form.submit(name="action", value="edit-settings")
assert res.flashes == [("success", "Profile updated successfully.")]
user.reload()
backend.reload(user)
assert foo_group not in user.groups
foo_group.reload()
logged_admin.reload()
backend.reload(foo_group)
backend.reload(logged_admin)
assert foo_group.members == [logged_admin]
@ -76,7 +76,7 @@ def test_empty_group_removal(testclient, logged_admin, user, foo_group, backend)
"The group 'foo' cannot be removed, because it must have at least one user left."
)
user.reload()
backend.reload(user)
assert foo_group in user.groups
@ -109,7 +109,7 @@ def test_edition_without_groups(
assert res.flashes == [("success", "Profile updated successfully.")]
res = res.follow()
logged_user.reload()
backend.reload(logged_user)
assert logged_user.user_name == "user"
assert backend.check_user_password(logged_user, "correct horse battery staple")[0]
@ -126,7 +126,7 @@ def test_password_change(testclient, logged_user, backend):
res = res.form.submit(name="action", value="edit-settings").follow()
logged_user.reload()
backend.reload(logged_user)
assert backend.check_user_password(logged_user, "new_password")[0]
res = testclient.get("/profile/user/settings", status=200)
@ -138,7 +138,7 @@ def test_password_change(testclient, logged_user, backend):
assert ("success", "Profile updated successfully.") in res.flashes
res = res.follow()
logged_user.reload()
backend.reload(logged_user)
assert backend.check_user_password(logged_user, "correct horse battery staple")[0]
@ -150,7 +150,7 @@ def test_password_change_fail(testclient, logged_user, backend):
res = res.form.submit(name="action", value="edit-settings", status=200)
logged_user.reload()
backend.reload(logged_user)
assert backend.check_user_password(logged_user, "correct horse battery staple")[0]
res = testclient.get("/profile/user/settings", status=200)
@ -160,7 +160,7 @@ def test_password_change_fail(testclient, logged_user, backend):
res = res.form.submit(name="action", value="edit-settings", status=200)
logged_user.reload()
backend.reload(logged_user)
assert backend.check_user_password(logged_user, "correct horse battery staple")[0]
@ -186,7 +186,7 @@ def test_password_initialization_mail(smtpd, testclient, backend, logged_admin):
assert len(smtpd.messages) == 1
assert smtpd.messages[0]["X-RcptTo"] == "john@doe.com"
u.reload()
backend.reload(u)
u.password = "correct horse battery staple"
backend.save(u)
@ -357,13 +357,14 @@ def test_edition_permission(
testclient,
logged_user,
admin,
backend,
):
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["PERMISSIONS"] = []
logged_user.reload()
backend.reload(logged_user)
testclient.get("/profile/user/settings", status=404)
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["PERMISSIONS"] = ["edit_self"]
g.user.reload()
backend.reload(g.user)
testclient.get("/profile/user/settings", status=200)
@ -462,7 +463,7 @@ def test_empty_lock_date(
assert res.flashes == [("success", "Profile updated successfully.")]
res = res.follow()
user.reload()
backend.reload(user)
assert not user.lock_date

View file

@ -1,7 +1,7 @@
from flask import g
def test_index(testclient, user):
def test_index(testclient, user, backend):
res = testclient.get("/", status=302)
assert res.location == "/login"
@ -10,11 +10,11 @@ def test_index(testclient, user):
assert res.location == "/profile/user"
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["PERMISSIONS"] = ["use_oidc"]
g.user.reload()
backend.reload(g.user)
res = testclient.get("/", status=302)
assert res.location == "/consent/"
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["PERMISSIONS"] = []
g.user.reload()
backend.reload(g.user)
res = testclient.get("/", status=302)
assert res.location == "/about"

View file

@ -503,9 +503,11 @@ def test_when_consent_already_given_but_for_a_smaller_scope(
backend.delete(consent)
def test_user_cannot_use_oidc(testclient, user, client, keypair, trusted_client):
def test_user_cannot_use_oidc(
testclient, user, client, keypair, trusted_client, backend
):
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["PERMISSIONS"] = []
user.reload()
backend.reload(user)
res = testclient.get(
"/oauth/authorize",

View file

@ -147,7 +147,7 @@ def test_add_missing_fields(testclient, logged_admin):
) in res.flashes
def test_client_edit(testclient, client, logged_admin, trusted_client):
def test_client_edit(testclient, client, logged_admin, trusted_client, backend):
res = testclient.get("/admin/client/edit/" + client.client_id)
data = {
"client_name": "foobar",
@ -179,7 +179,7 @@ def test_client_edit(testclient, client, logged_admin, trusted_client):
) not in res.flashes
assert ("success", "The client has been edited.") in res.flashes
client.reload()
backend.reload(client)
assert client.client_name == "foobar"
assert client.contacts == ["foo@bar.com"]
@ -204,7 +204,9 @@ def test_client_edit(testclient, client, logged_admin, trusted_client):
assert client.post_logout_redirect_uris == ["https://foo.bar/disconnected"]
def test_client_edit_missing_fields(testclient, client, logged_admin, trusted_client):
def test_client_edit_missing_fields(
testclient, client, logged_admin, trusted_client, backend
):
res = testclient.get("/admin/client/edit/" + client.client_id)
res.forms["clientaddform"]["client_name"] = ""
res = res.forms["clientaddform"].submit(name="action", value="edit")
@ -212,7 +214,7 @@ def test_client_edit_missing_fields(testclient, client, logged_admin, trusted_cl
"error",
"The client has not been edited. Please check your information.",
) in res.flashes
client.reload()
backend.reload(client)
assert client.client_name
@ -258,7 +260,7 @@ def test_client_delete_invalid_client(testclient, logged_admin, client):
)
def test_client_edit_preauth(testclient, client, logged_admin, trusted_client):
def test_client_edit_preauth(testclient, client, logged_admin, trusted_client, backend):
assert not client.preconsent
res = testclient.get("/admin/client/edit/" + client.client_id)
@ -266,7 +268,7 @@ def test_client_edit_preauth(testclient, client, logged_admin, trusted_client):
res = res.forms["clientaddform"].submit(name="action", value="edit")
assert ("success", "The client has been edited.") in res.flashes
client.reload()
backend.reload(client)
assert client.preconsent
res = testclient.get("/admin/client/edit/" + client.client_id)
@ -274,7 +276,7 @@ def test_client_edit_preauth(testclient, client, logged_admin, trusted_client):
res = res.forms["clientaddform"].submit(name="action", value="edit")
assert ("success", "The client has been edited.") in res.flashes
client.reload()
backend.reload(client)
assert not client.preconsent

View file

@ -10,7 +10,7 @@ def test_no_logged_no_access(testclient):
testclient.get("/consent", status=403)
def test_revokation(testclient, client, consent, logged_user, token):
def test_revokation(testclient, client, consent, logged_user, token, backend):
res = testclient.get("/consent", status=200)
res.mustcontain(client.client_name)
res.mustcontain("Revoke access")
@ -24,13 +24,13 @@ def test_revokation(testclient, client, consent, logged_user, token):
res.mustcontain(no="Revoke access")
res.mustcontain("Restore access")
consent.reload()
backend.reload(consent)
assert consent.revoked
token.reload()
backend.reload(token)
assert token.revoked
def test_revokation_already_revoked(testclient, client, consent, logged_user):
def test_revokation_already_revoked(testclient, client, consent, logged_user, backend):
consent.revoke()
assert consent.revoked
@ -39,24 +39,24 @@ def test_revokation_already_revoked(testclient, client, consent, logged_user):
assert ("error", "The access is already revoked") in res.flashes
res = res.follow(status=200)
consent.reload()
backend.reload(consent)
assert consent.revoked
def test_restoration(testclient, client, consent, logged_user, token):
def test_restoration(testclient, client, consent, logged_user, token, backend):
consent.revoke()
assert consent.revoked
token.reload()
backend.reload(token)
assert token.revoked
res = testclient.get(f"/consent/restore/{consent.consent_id}", status=302)
assert ("success", "The access has been restored") in res.flashes
res = res.follow(status=200)
consent.reload()
backend.reload(consent)
assert not consent.revoked
token.reload()
backend.reload(token)
assert token.revoked
@ -115,7 +115,7 @@ def test_oidc_authorization_after_revokation(
res = res.form.submit(name="answer", value="accept", status=302)
consents = backend.query(models.Consent, client=client, subject=logged_user)
consent.reload()
backend.reload(consent)
assert consents[0] == consent
assert not consent.revoked
@ -167,21 +167,21 @@ def test_revoke_preconsented_client(testclient, client, logged_user, token, back
assert consent.subject == logged_user
assert consent.scope == ["openid", "email", "profile", "groups", "address", "phone"]
assert not consent.issue_date
token.reload()
backend.reload(token)
assert token.revoked
res = testclient.get(f"/consent/restore/{consent.consent_id}", status=302)
assert ("success", "The access has been restored") in res.flashes
consent.reload()
backend.reload(consent)
assert not consent.revoked
assert consent.issue_date
token.reload()
backend.reload(token)
assert token.revoked
res = testclient.get(f"/consent/revoke/{consent.consent_id}", status=302)
assert ("success", "The access has been revoked") in res.flashes
consent.reload()
backend.reload(consent)
assert consent.revoked
assert consent.issue_date

View file

@ -60,7 +60,7 @@ def test_refresh_token(testclient, logged_user, client, backend):
assert new_token is not None
assert old_token.access_token != new_token.access_token
old_token.reload()
backend.reload(old_token)
assert old_token.revokation_date
res = testclient.get(

View file

@ -134,7 +134,7 @@ def test_revoke_bad_request(testclient, token, logged_admin):
res = res.form.submit(name="action", value="invalid", status=400)
def test_revoke_token(testclient, token, logged_admin):
def test_revoke_token(testclient, token, logged_admin, backend):
assert not token.revoked
res = testclient.get(f"/admin/token/{token.token_id}")
@ -142,7 +142,7 @@ def test_revoke_token(testclient, token, logged_admin):
res = res.form.submit(name="action", value="revoke")
assert ("success", "The token has successfully been revoked.") in res.flashes
token.reload()
backend.reload(token)
assert token.revoked

View file

@ -3,7 +3,7 @@ import datetime
from . import client_credentials
def test_revoke_access_token(testclient, user, client, token):
def test_revoke_access_token(testclient, user, client, token, backend):
assert not token.revokation_date
res = testclient.post(
@ -14,11 +14,11 @@ def test_revoke_access_token(testclient, user, client, token):
)
assert {} == res.json
token.reload()
backend.reload(token)
assert token.revokation_date
def test_revoke_access_token_with_hint(testclient, user, client, token):
def test_revoke_access_token_with_hint(testclient, user, client, token, backend):
assert not token.revokation_date
res = testclient.post(
@ -29,11 +29,11 @@ def test_revoke_access_token_with_hint(testclient, user, client, token):
)
assert {} == res.json
token.reload()
backend.reload(token)
assert token.revokation_date
def test_revoke_refresh_token(testclient, user, client, token):
def test_revoke_refresh_token(testclient, user, client, token, backend):
assert not token.revokation_date
res = testclient.post(
@ -44,11 +44,11 @@ def test_revoke_refresh_token(testclient, user, client, token):
)
assert {} == res.json
token.reload()
backend.reload(token)
assert token.revokation_date
def test_revoke_refresh_token_with_hint(testclient, user, client, token):
def test_revoke_refresh_token_with_hint(testclient, user, client, token, backend):
assert not token.revokation_date
res = testclient.post(
@ -59,7 +59,7 @@ def test_revoke_refresh_token_with_hint(testclient, user, client, token):
)
assert {} == res.json
token.reload()
backend.reload(token)
assert token.revokation_date