diff --git a/canaille/backends/__init__.py b/canaille/backends/__init__.py index a7bfcad1..dba168c2 100644 --- a/canaille/backends/__init__.py +++ b/canaille/backends/__init__.py @@ -85,6 +85,10 @@ class BaseBackend: only one element or :py:data:`None` if no item is matching.""" raise NotImplementedError() + def save(self, instance): + """Validate the current modifications in the database.""" + raise NotImplementedError() + def check_user_password(self, user, password: str) -> bool: """Check if the password matches the user password in the database.""" raise NotImplementedError() diff --git a/canaille/backends/ldap/backend.py b/canaille/backends/ldap/backend.py index e48f51e1..44c893a4 100644 --- a/canaille/backends/ldap/backend.py +++ b/canaille/backends/ldap/backend.py @@ -9,6 +9,7 @@ from flask import current_app from ldap.controls import DecodeControlTuples from ldap.controls.ppolicy import PasswordPolicyControl from ldap.controls.ppolicy import PasswordPolicyError +from ldap.controls.readentry import PostReadControl from canaille.app import models from canaille.app.configuration import ConfigurationException @@ -128,7 +129,7 @@ class Backend(BaseBackend): emails=f"canaille_{uuid.uuid4()}@mydomain.tld", password="correct horse battery staple", ) - user.save() + BaseBackend.instance.save(user) user.delete() except ldap.INSUFFICIENT_ACCESS as exc: @@ -147,13 +148,13 @@ class Backend(BaseBackend): emails=f"canaille_{uuid.uuid4()}@mydomain.tld", password="correct horse battery staple", ) - user.save() + BaseBackend.instance.save(user) group = models.Group( display_name=f"canaille_{uuid.uuid4()}", members=[user], ) - group.save() + BaseBackend.instance.save(group) group.delete() except ldap.INSUFFICIENT_ACCESS as exc: @@ -324,6 +325,69 @@ class Backend(BaseBackend): return None + def save(self, instance): + # run the instance save callback if existing + save_callback = instance.save() if hasattr(instance, "save") else iter([]) + next(save_callback, None) + + current_object_classes = instance.get_ldap_attribute("objectClass") or [] + instance.set_ldap_attribute( + "objectClass", + list(set(instance.ldap_object_class) | set(current_object_classes)), + ) + + # PostReadControl allows to read the updated object attributes on creation/edition + attributes = ["objectClass"] + [ + instance.python_attribute_to_ldap(name) for name in instance.attributes + ] + read_post_control = PostReadControl(criticality=True, attrList=attributes) + + # Object already exists in the LDAP database + if instance.exists: + deletions = [ + name + for name, value in instance.changes.items() + if ( + value is None + or value == [] + or (isinstance(value, list) and len(value) == 1 and not value[0]) + ) + and name in instance.state + ] + changes = { + name: value + for name, value in instance.changes.items() + if name not in deletions and instance.state.get(name) != value + } + formatted_changes = python_attrs_to_ldap(changes, null_allowed=False) + modlist = [(ldap.MOD_DELETE, name, None) for name in deletions] + [ + (ldap.MOD_REPLACE, name, values) + for name, values in formatted_changes.items() + ] + _, _, _, [result] = self.connection.modify_ext_s( + instance.dn, modlist, serverctrls=[read_post_control] + ) + + # Object does not exist yet in the LDAP database + else: + changes = { + name: value + for name, value in {**instance.state, **instance.changes}.items() + if value and value[0] + } + formatted_changes = python_attrs_to_ldap(changes, null_allowed=False) + modlist = [(name, values) for name, values in formatted_changes.items()] + _, _, _, [result] = self.connection.add_ext_s( + instance.dn, modlist, serverctrls=[read_post_control] + ) + + instance.exists = True + instance.state = {**result.entry, **instance.changes} + instance.changes = {} + + # run the instance save callback again if existing + next(save_callback, None) + def setup_ldap_models(config): from canaille.app import models diff --git a/canaille/backends/ldap/ldapobject.py b/canaille/backends/ldap/ldapobject.py index 2eacee7b..2955a0da 100644 --- a/canaille/backends/ldap/ldapobject.py +++ b/canaille/backends/ldap/ldapobject.py @@ -2,7 +2,6 @@ import typing import ldap.dn import ldap.filter -from ldap.controls.readentry import PostReadControl from canaille.backends.models import BackendModel @@ -264,64 +263,6 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass): self.changes = {} self.state = result[0][1] - def save(self): - conn = Backend.instance.connection - - current_object_classes = self.get_ldap_attribute("objectClass") or [] - self.set_ldap_attribute( - "objectClass", - list(set(self.ldap_object_class) | set(current_object_classes)), - ) - - # PostReadControl allows to read the updated object attributes on creation/edition - attributes = ["objectClass"] + [ - self.python_attribute_to_ldap(name) for name in self.attributes - ] - read_post_control = PostReadControl(criticality=True, attrList=attributes) - - # Object already exists in the LDAP database - if self.exists: - deletions = [ - name - for name, value in self.changes.items() - if ( - value is None - or value == [] - or (isinstance(value, list) and len(value) == 1 and not value[0]) - ) - and name in self.state - ] - changes = { - name: value - for name, value in self.changes.items() - if name not in deletions and self.state.get(name) != value - } - formatted_changes = python_attrs_to_ldap(changes, null_allowed=False) - modlist = [(ldap.MOD_DELETE, name, None) for name in deletions] + [ - (ldap.MOD_REPLACE, name, values) - for name, values in formatted_changes.items() - ] - _, _, _, [result] = conn.modify_ext_s( - self.dn, modlist, serverctrls=[read_post_control] - ) - - # Object does not exist yet in the LDAP database - else: - changes = { - name: value - for name, value in {**self.state, **self.changes}.items() - if value and value[0] - } - formatted_changes = python_attrs_to_ldap(changes, null_allowed=False) - modlist = [(name, values) for name, values in formatted_changes.items()] - _, _, _, [result] = conn.add_ext_s( - self.dn, modlist, serverctrls=[read_post_control] - ) - - self.exists = True - self.state = {**result.entry, **self.changes} - self.changes = {} - def delete(self): conn = Backend.instance.connection try: diff --git a/canaille/backends/ldap/models.py b/canaille/backends/ldap/models.py index c931ffcc..c4e99a3d 100644 --- a/canaille/backends/ldap/models.py +++ b/canaille/backends/ldap/models.py @@ -43,10 +43,10 @@ class User(canaille.core.models.User, LDAPObject): return super().match_filter(filter) - def save(self, *args, **kwargs): + def save(self): group_attr = self.python_attribute_to_ldap("groups") if group_attr not in self.changes: - return super().save(*args, **kwargs) + return # The LDAP attribute memberOf cannot directly be edited, # so this is needed to update the Group.member attribute @@ -60,11 +60,11 @@ class User(canaille.core.models.User, LDAPObject): to_del = set(old_groups) - set(new_groups) del self.changes[group_attr] - super().save(*args, **kwargs) + yield for group in to_add: group.members = group.members + [self] - group.save() + Backend.instance.save(group) for group in to_del: # LDAP groups cannot be empty because groupOfNames.member @@ -73,7 +73,7 @@ class User(canaille.core.models.User, LDAPObject): # TODO: properly manage the situation where one wants to # remove the last member of a group group.members = [member for member in group.members if member != self] - group.save() + Backend.instance.save(group) self.state[group_attr] = new_groups diff --git a/canaille/backends/memory/backend.py b/canaille/backends/memory/backend.py index fbeb3020..038bb0e2 100644 --- a/canaille/backends/memory/backend.py +++ b/canaille/backends/memory/backend.py @@ -1,3 +1,6 @@ +import datetime +import uuid + from canaille.backends import BaseBackend @@ -39,7 +42,7 @@ class Backend(BaseBackend): def set_user_password(self, user, password): user.password = password - user.save() + self.save(user) def query(self, model, **kwargs): # if there is no filter, return all models @@ -91,3 +94,17 @@ class Backend(BaseBackend): results = self.query(model, **kwargs) return results[0] if results else None + + def save(self, instance): + if not instance.id: + instance.id = str(uuid.uuid4()) + + instance.last_modified = datetime.datetime.now(datetime.timezone.utc).replace( + microsecond=0 + ) + if not instance.created: + instance.created = instance.last_modified + + instance.index_delete() + instance.index_save() + instance._cache = {} diff --git a/canaille/backends/memory/models.py b/canaille/backends/memory/models.py index 41ebb8ce..cc2cc33e 100644 --- a/canaille/backends/memory/models.py +++ b/canaille/backends/memory/models.py @@ -1,7 +1,5 @@ import copy -import datetime import typing -import uuid import canaille.core.models import canaille.oidc.models @@ -67,20 +65,6 @@ class MemoryModel(BackendModel): return value - def save(self): - if not self.id: - self.id = str(uuid.uuid4()) - - self.last_modified = datetime.datetime.now(datetime.timezone.utc).replace( - microsecond=0 - ) - if not self.created: - self.created = self.last_modified - - self.index_delete() - self.index_save() - self._cache = {} - def delete(self): self.index_delete() diff --git a/canaille/backends/models.py b/canaille/backends/models.py index 3e8998b6..b07da30e 100644 --- a/canaille/backends/models.py +++ b/canaille/backends/models.py @@ -88,10 +88,6 @@ class BackendModel: implemented for every model and for every backend. """ - def save(self): - """Validate the current modifications in the database.""" - raise NotImplementedError() - def delete(self): """Remove the current instance from the database.""" raise NotImplementedError() diff --git a/canaille/backends/sql/backend.py b/canaille/backends/sql/backend.py index 39866386..4f8ecaf0 100644 --- a/canaille/backends/sql/backend.py +++ b/canaille/backends/sql/backend.py @@ -1,3 +1,5 @@ +import datetime + from sqlalchemy import create_engine from sqlalchemy import or_ from sqlalchemy import select @@ -66,7 +68,7 @@ class Backend(BaseBackend): def set_user_password(self, user, password): user.password = password - user.save() + self.save(user) def query(self, model, **kwargs): filter = [ @@ -106,3 +108,13 @@ class Backend(BaseBackend): return Backend.instance.db_session.execute( select(model).filter(*filter) ).scalar_one_or_none() + + def save(self, instance): + instance.last_modified = datetime.datetime.now(datetime.timezone.utc).replace( + microsecond=0 + ) + if not instance.created: + instance.created = instance.last_modified + + Backend.instance.db_session.add(instance) + Backend.instance.db_session.commit() diff --git a/canaille/backends/sql/models.py b/canaille/backends/sql/models.py index bd14f85a..302f4d37 100644 --- a/canaille/backends/sql/models.py +++ b/canaille/backends/sql/models.py @@ -47,16 +47,6 @@ class SqlAlchemyModel(BackendModel): return getattr(cls, name) == value - def save(self): - self.last_modified = datetime.datetime.now(datetime.timezone.utc).replace( - microsecond=0 - ) - if not self.created: - self.created = self.last_modified - - Backend.instance.db_session.add(self) - Backend.instance.db_session.commit() - def delete(self): Backend.instance.db_session.delete(self) Backend.instance.db_session.commit() diff --git a/canaille/core/endpoints/account.py b/canaille/core/endpoints/account.py index a34af0da..fd74436e 100644 --- a/canaille/core/endpoints/account.py +++ b/canaille/core/endpoints/account.py @@ -404,7 +404,7 @@ def email_confirmation(data, hash): return redirect(url_for("core.account.index")) user.emails = user.emails + [confirmation_obj.email] - user.save() + BaseBackend.instance.save(user) flash(_("Your email address have been confirmed."), "success") return redirect(url_for("core.account.index")) @@ -460,11 +460,11 @@ def profile_create(current_app, form): given_name = user.given_name if user.given_name else "" family_name = user.family_name if user.family_name else "" user.formatted_name = f"{given_name} {family_name}".strip() - user.save() + BaseBackend.instance.save(user) if form["password1"].data: BaseBackend.instance.set_user_password(user, form["password1"].data) - user.save() + BaseBackend.instance.save(user) return user @@ -536,7 +536,7 @@ def profile_edition_main_form_validation(user, edited_user, profile_form): if profile_form["preferred_language"].data == "auto": edited_user.preferred_language = None - edited_user.save() + BaseBackend.instance.save(edited_user) g.user.reload() @@ -574,7 +574,7 @@ def profile_edition_remove_email(user, edited_user, email): return False edited_user.emails = [m for m in edited_user.emails if m != email] - edited_user.save() + BaseBackend.instance.save(edited_user) return True @@ -730,7 +730,7 @@ def profile_settings(user, edited_user): ): flash(_("The account has been locked"), "success") edited_user.lock_date = datetime.datetime.now(datetime.timezone.utc) - edited_user.save() + BaseBackend.instance.save(edited_user) return profile_settings_edit(user, edited_user) @@ -741,7 +741,7 @@ def profile_settings(user, edited_user): ): flash(_("The account has been unlocked"), "success") edited_user.lock_date = None - edited_user.save() + BaseBackend.instance.save(edited_user) return profile_settings_edit(user, edited_user) @@ -791,7 +791,7 @@ def profile_settings_edit(editor, edited_user): edited_user, form["password1"].data ) - edited_user.save() + BaseBackend.instance.save(edited_user) flash(_("Profile updated successfully."), "success") return redirect( url_for("core.account.profile_settings", edited_user=edited_user) diff --git a/canaille/core/endpoints/groups.py b/canaille/core/endpoints/groups.py index 3aeeb3a6..3edad29c 100644 --- a/canaille/core/endpoints/groups.py +++ b/canaille/core/endpoints/groups.py @@ -11,6 +11,7 @@ from canaille.app.flask import render_htmx_template from canaille.app.forms import TableForm from canaille.app.i18n import gettext as _ from canaille.app.themes import render_template +from canaille.backends import BaseBackend from .forms import CreateGroupForm from .forms import DeleteGroupMemberForm @@ -42,7 +43,7 @@ def create_group(user): group.members = [user] group.display_name = form.display_name.data group.description = form.description.data - group.save() + BaseBackend.instance.save(group) flash( _( "The group %(group)s has been sucessfully created", @@ -102,7 +103,7 @@ def edit_group(group): ): if form.validate(): group.description = form.description.data - group.save() + BaseBackend.instance.save(group) flash( _( "The group %(group)s has been sucessfully edited.", @@ -151,7 +152,7 @@ def delete_member(group): group.members = [ member for member in group.members if member != form.member.data ] - group.save() + BaseBackend.instance.save(group) return edit_group(group) diff --git a/canaille/core/populate.py b/canaille/core/populate.py index dbba5195..7dd0e9ec 100644 --- a/canaille/core/populate.py +++ b/canaille/core/populate.py @@ -40,7 +40,7 @@ def fake_users(nb=1): password=fake.password(), preferred_language=fake._locales[0], ) - user.save() + BaseBackend.instance.save(user) users.append(user) except Exception: # pragma: no cover pass @@ -59,7 +59,7 @@ def fake_groups(nb=1, nb_users_max=1): ) nb_users = random.randrange(1, nb_users_max + 1) group.members = list({random.choice(users) for _ in range(nb_users)}) - group.save() + BaseBackend.instance.save(group) groups.append(group) except Exception: # pragma: no cover pass diff --git a/canaille/oidc/endpoints/clients.py b/canaille/oidc/endpoints/clients.py index 2cc541af..3f223a52 100644 --- a/canaille/oidc/endpoints/clients.py +++ b/canaille/oidc/endpoints/clients.py @@ -14,6 +14,7 @@ from canaille.app.flask import render_htmx_template from canaille.app.forms import TableForm from canaille.app.i18n import gettext as _ from canaille.app.themes import render_template +from canaille.backends import BaseBackend from .forms import ClientAddForm @@ -73,9 +74,9 @@ def add(user): if form["token_endpoint_auth_method"].data == "none" else gen_salt(48), ) - client.save() + BaseBackend.instance.save(client) client.audience = [client] - client.save() + BaseBackend.instance.save(client) flash( _("The client has been created."), "success", @@ -137,7 +138,7 @@ def client_edit(client): audience=form["audience"].data, preconsent=form["preconsent"].data, ) - client.save() + BaseBackend.instance.save(client) flash( _("The client has been edited."), "success", diff --git a/canaille/oidc/endpoints/consents.py b/canaille/oidc/endpoints/consents.py index 1dcff0be..bb2f9eb9 100644 --- a/canaille/oidc/endpoints/consents.py +++ b/canaille/oidc/endpoints/consents.py @@ -95,7 +95,7 @@ def restore(user, consent): consent.restore() if not consent.issue_date: consent.issue_date = datetime.datetime.now(datetime.timezone.utc) - consent.save() + BaseBackend.instance.save(consent) flash(_("The access has been restored"), "success") return redirect(url_for("oidc.consents.consents")) @@ -119,7 +119,7 @@ def revoke_preconsent(user, client): scope=client.scope, ) consent.revoke() - consent.save() + BaseBackend.instance.save(consent) flash(_("The access has been revoked"), "success") return redirect(url_for("oidc.consents.consents")) diff --git a/canaille/oidc/endpoints/oauth.py b/canaille/oidc/endpoints/oauth.py index 9810993c..c116ac81 100644 --- a/canaille/oidc/endpoints/oauth.py +++ b/canaille/oidc/endpoints/oauth.py @@ -177,7 +177,7 @@ def authorize_consent(client, user): scope=allowed_scopes, issue_date=datetime.datetime.now(datetime.timezone.utc), ) - consent.save() + BaseBackend.instance.save(consent) response = authorization.create_authorization_response(grant_user=grant_user) current_app.logger.debug("authorization endpoint response: %s", response.location) diff --git a/canaille/oidc/endpoints/tokens.py b/canaille/oidc/endpoints/tokens.py index 234a255d..80587fcd 100644 --- a/canaille/oidc/endpoints/tokens.py +++ b/canaille/oidc/endpoints/tokens.py @@ -11,6 +11,7 @@ from canaille.app.flask import render_htmx_template from canaille.app.forms import TableForm from canaille.app.i18n import gettext as _ from canaille.app.themes import render_template +from canaille.backends import BaseBackend from .forms import TokenRevokationForm @@ -40,7 +41,7 @@ def view(user, token): elif request.form.get("action") == "revoke": token.revokation_date = datetime.datetime.now(datetime.timezone.utc) - token.save() + BaseBackend.instance.save(token) flash(_("The token has successfully been revoked."), "success") else: diff --git a/canaille/oidc/models.py b/canaille/oidc/models.py index b4edbd8f..7e4da07b 100644 --- a/canaille/oidc/models.py +++ b/canaille/oidc/models.py @@ -184,7 +184,7 @@ class Consent(BaseConsent): def revoke(self): self.revokation_date = datetime.datetime.now(datetime.timezone.utc) - self.save() + BaseBackend.instance.save(self) tokens = BaseBackend.instance.query( models.Token, @@ -194,8 +194,8 @@ class Consent(BaseConsent): tokens = [token for token in tokens if not token.revoked] for t in tokens: t.revokation_date = self.revokation_date - t.save() + BaseBackend.instance.save(t) def restore(self): self.revokation_date = None - self.save() + BaseBackend.instance.save(self) diff --git a/canaille/oidc/oauth.py b/canaille/oidc/oauth.py index a07efb6e..848140ec 100644 --- a/canaille/oidc/oauth.py +++ b/canaille/oidc/oauth.py @@ -228,7 +228,7 @@ def save_authorization_code(code, request): challenge=request.data.get("code_challenge"), challenge_method=request.data.get("code_challenge_method"), ) - code.save() + BaseBackend.instance.save(code) return code.code @@ -297,7 +297,7 @@ class RefreshTokenGrant(_RefreshTokenGrant): def revoke_old_credential(self, credential): credential.revokation_date = datetime.datetime.now(datetime.timezone.utc) - credential.save() + BaseBackend.instance.save(credential) class OpenIDImplicitGrant(_OpenIDImplicitGrant): @@ -351,7 +351,7 @@ def save_token(token, request): subject=request.user, audience=request.client.audience, ) - t.save() + BaseBackend.instance.save(t) class BearerTokenValidator(_BearerTokenValidator): @@ -382,7 +382,7 @@ class RevocationEndpoint(_RevocationEndpoint): def revoke_token(self, token, request): token.revokation_date = datetime.datetime.now(datetime.timezone.utc) - token.save() + BaseBackend.instance.save(token) class IntrospectionEndpoint(_IntrospectionEndpoint): @@ -463,9 +463,9 @@ class ClientRegistrationEndpoint(ClientManagementMixin, _ClientRegistrationEndpo post_logout_redirect_uris=request.data.get("post_logout_redirect_uris"), **self.client_convert_data(**client_info, **client_metadata), ) - client.save() + BaseBackend.instance.save(client) client.audience = [client] - client.save() + BaseBackend.instance.save(client) return client @@ -485,7 +485,7 @@ class ClientConfigurationEndpoint(ClientManagementMixin, _ClientConfigurationEnd def update_client(self, client, client_metadata, request): client.update(**self.client_convert_data(**client_metadata)) - client.save() + BaseBackend.instance.save(client) return client def generate_client_registration_info(self, client, request): diff --git a/tests/app/test_forms.py b/tests/app/test_forms.py index e4afd25f..c9b8f97a 100644 --- a/tests/app/test_forms.py +++ b/tests/app/test_forms.py @@ -206,13 +206,13 @@ def test_fieldlist_add_readonly(testclient, logged_user): testclient.post("/profile/user", data, status=403) -def test_fieldlist_remove_readonly(testclient, logged_user): +def test_fieldlist_remove_readonly(testclient, logged_user, backend): testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["WRITE"].remove("phone_numbers") testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["READ"].append("phone_numbers") logged_user.reload() logged_user.phone_numbers = ["555-555-000", "555-555-111"] - logged_user.save() + backend.save(logged_user) res = testclient.get("/profile/user") form = res.forms["baseform"] diff --git a/tests/app/test_i18n.py b/tests/app/test_i18n.py index 7185d705..ac9612c2 100644 --- a/tests/app/test_i18n.py +++ b/tests/app/test_i18n.py @@ -1,9 +1,9 @@ from flask_babel import refresh -def test_preferred_language(testclient, logged_user): +def test_preferred_language(testclient, logged_user, backend): logged_user.preferred_language = None - logged_user.save() + backend.save(logged_user) res = testclient.get("/profile/user", status=200) form = res.forms["baseform"] @@ -49,9 +49,9 @@ def test_preferred_language(testclient, logged_user): res.mustcontain(no="Mon profil") -def test_form_translations(testclient, logged_user): +def test_form_translations(testclient, logged_user, backend): logged_user.preferred_language = "fr" - logged_user.save() + backend.save(logged_user) res = testclient.get("/profile/user", status=200) form = res.forms["baseform"] @@ -62,9 +62,9 @@ def test_form_translations(testclient, logged_user): res.mustcontain("N’est pas un numéro de téléphone valide") -def test_language_config(testclient, logged_user): +def test_language_config(testclient, logged_user, backend): logged_user.preferred_language = None - logged_user.save() + backend.save(logged_user) res = testclient.get("/profile/user", status=200) assert res.pyquery("html")[0].attrib["lang"] == "en" diff --git a/tests/backends/ldap/test_models.py b/tests/backends/ldap/test_models.py index baea5416..8ee5cb91 100644 --- a/tests/backends/ldap/test_models.py +++ b/tests/backends/ldap/test_models.py @@ -7,7 +7,7 @@ def test_model_references_set_unsaved_object( """LDAP groups can be inconsistent by containing members which doesn't exist.""" group = models.Group(members=[user], display_name="foo") - group.save() + backend.save(group) user.reload() non_existent_user = models.User( @@ -16,7 +16,7 @@ def test_model_references_set_unsaved_object( group.members = group.members + [non_existent_user] assert group.members == [user, non_existent_user] - group.save() + backend.save(group) assert group.members == [user, non_existent_user] group.reload() diff --git a/tests/backends/ldap/test_object_class.py b/tests/backends/ldap/test_object_class.py index b7863459..8152cc2e 100644 --- a/tests/backends/ldap/test_object_class.py +++ b/tests/backends/ldap/test_object_class.py @@ -5,7 +5,7 @@ from canaille.backends.ldap.ldapobject import LDAPObject def test_guess_object_from_dn(backend, testclient, foo_group): foo_group.members = [foo_group] - foo_group.save() + backend.save(foo_group) dn = foo_group.dn g = backend.get(LDAPObject, dn) assert isinstance(g, models.Group) @@ -18,7 +18,7 @@ def test_object_class_update(backend, testclient): setup_ldap_models(testclient.app.config) user1 = models.User(cn="foo1", sn="bar1", user_name="baz1") - user1.save() + backend.save(user1) assert set(user1.get_ldap_attribute("objectClass")) == {"inetOrgPerson"} assert set( @@ -32,7 +32,7 @@ def test_object_class_update(backend, testclient): setup_ldap_models(testclient.app.config) user2 = models.User(cn="foo2", sn="bar2", user_name="baz2") - user2.save() + backend.save(user2) assert set(user2.get_ldap_attribute("objectClass")) == { "inetOrgPerson", @@ -48,7 +48,7 @@ def test_object_class_update(backend, testclient): user1 = backend.get(models.User, id=user1.id) assert user1.get_ldap_attribute("objectClass") == ["inetOrgPerson"] - user1.save() + backend.save(user1) assert set(user1.get_ldap_attribute("objectClass")) == { "inetOrgPerson", "extensibleObject", @@ -72,7 +72,7 @@ def test_keep_old_object_classes(backend, testclient, slapd_server): attributes. """ user = models.User(cn="foo", sn="bar", user_name="baz") - user.save() + backend.save(user) ldif = f"""dn: {user.dn} changetype: modify @@ -95,6 +95,6 @@ homeDirectory: /home/foobar user.reload() # saving an object should not raise a ldap.OBJECT_CLASS_VIOLATION exception - user.save() + backend.save(user) user.delete() diff --git a/tests/backends/ldap/test_utils.py b/tests/backends/ldap/test_utils.py index 36b87084..f14bfc0a 100644 --- a/tests/backends/ldap/test_utils.py +++ b/tests/backends/ldap/test_utils.py @@ -24,7 +24,7 @@ def test_object_creation(app, backend): emails=["john@doe.com"], ) assert not user.exists - user.save() + backend.save(user) assert user.exists user = backend.get(models.User, id=user.id) @@ -45,7 +45,7 @@ def test_dn_when_leading_space_in_id_attribute(testclient, backend): user_name=" user", emails=["john@doe.com"], ) - user.save() + backend.save(user) dn = user.dn assert dn == "uid=user,ou=users,dc=mydomain,dc=tld" @@ -66,7 +66,7 @@ def test_special_chars_in_rdn(testclient, backend): user_name="#user", # special char emails=["john@doe.com"], ) - user.save() + backend.save(user) dn = user.dn assert ldap.dn.is_dn(dn) diff --git a/tests/backends/test_models.py b/tests/backends/test_models.py index 57086a9f..9fe606e8 100644 --- a/tests/backends/test_models.py +++ b/tests/backends/test_models.py @@ -12,13 +12,13 @@ def test_model_comparison(testclient, backend): family_name="foo", formatted_name="foo", ) - foo1.save() + backend.save(foo1) bar = models.User( user_name="bar", family_name="bar", formatted_name="bar", ) - bar.save() + backend.save(bar) foo2 = backend.get(models.User, id=foo1.id) assert foo1 == foo2 @@ -41,7 +41,7 @@ def test_model_lifecycle(testclient, backend): assert not backend.query(models.User, id="invalid") assert not backend.get(models.User, id=user.id) - user.save() + backend.save(user) assert backend.query(models.User) == [user] assert backend.query(models.User, id=user.id) == [user] @@ -72,7 +72,7 @@ def test_model_attribute_edition(testclient, backend): display_name="display_name", emails=["email1@user.com", "email2@user.com"], ) - user.save() + backend.save(user) assert user.user_name == "user_name" assert user.family_name == "family_name" @@ -85,7 +85,7 @@ def test_model_attribute_edition(testclient, backend): user.family_name = "new_family_name" user.emails = ["email1@user.com"] - user.save() + backend.save(user) assert user.family_name == "new_family_name" assert user.emails == ["email1@user.com"] @@ -97,7 +97,7 @@ def test_model_attribute_edition(testclient, backend): user.display_name = "" assert not user.display_name - user.save() + backend.save(user) assert not user.display_name user.delete() @@ -110,7 +110,7 @@ def test_model_indexation(testclient, backend): formatted_name="formatted_name", emails=["email1@user.com", "email2@user.com"], ) - user.save() + backend.save(user) assert backend.get(models.User, family_name="family_name") == user assert not backend.get(models.User, family_name="new_family_name") @@ -125,7 +125,7 @@ def test_model_indexation(testclient, backend): assert backend.get(models.User, emails=["email1@user.com"]) != user assert not backend.get(models.User, emails=["email3@user.com"]) - user.save() + backend.save(user) assert not backend.get(models.User, family_name="family_name") assert backend.get(models.User, family_name="new_family_name") == user @@ -177,14 +177,14 @@ def test_model_references(testclient, user, foo_group, admin, bar_group, backend assert user not in bar_group.members assert bar_group not in user.groups user.groups = user.groups + [bar_group] - user.save() + backend.save(user) bar_group.reload() assert user in bar_group.members assert bar_group in user.groups bar_group.members = [admin] - bar_group.save() + backend.save(bar_group) user.reload() assert user not in bar_group.members @@ -201,7 +201,7 @@ def test_model_creation_edition_datetime(testclient, backend): family_name="foo", formatted_name="foo", ) - user.save() + backend.save(user) assert user.created == datetime.datetime( 2020, 1, 1, 2, tzinfo=datetime.timezone.utc ) @@ -211,7 +211,7 @@ def test_model_creation_edition_datetime(testclient, backend): with time_machine.travel("2021-01-01 02:00:00+00:00", tick=False): user.family_name = "bar" - user.save() + backend.save(user) assert user.created == datetime.datetime( 2020, 1, 1, 2, tzinfo=datetime.timezone.utc ) diff --git a/tests/conftest.py b/tests/conftest.py index 8143126f..fa0a5319 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -191,7 +191,7 @@ def user(app, backend): profile_url="https://john.example", formatted_address="1235, somewhere", ) - u.save() + backend.save(u) yield u u.delete() @@ -205,7 +205,7 @@ def admin(app, backend): emails=["jane@doe.com"], password="admin", ) - u.save() + backend.save(u) yield u u.delete() @@ -219,7 +219,7 @@ def moderator(app, backend): emails=["jack@doe.com"], password="moderator", ) - u.save() + backend.save(u) yield u u.delete() @@ -251,7 +251,7 @@ def foo_group(app, user, backend): members=[user], display_name="foo", ) - group.save() + backend.save(group) user.reload() yield group group.delete() @@ -263,7 +263,7 @@ def bar_group(app, admin, backend): members=[admin], display_name="bar", ) - group.save() + backend.save(group) admin.reload() yield group group.delete() diff --git a/tests/core/test_account.py b/tests/core/test_account.py index 133d2374..a7fb293a 100644 --- a/tests/core/test_account.py +++ b/tests/core/test_account.py @@ -27,7 +27,7 @@ def test_user_deleted_in_session(testclient, backend): emails=["jake@doe.com"], password="correct horse battery staple", ) - u.save() + backend.save(u) testclient.get("/profile/jake", status=403) with testclient.session_transaction() as session: @@ -66,7 +66,7 @@ def test_admin_self_deletion(testclient, backend): emails=["temp@temp.com"], password="admin", ) - admin.save() + backend.save(admin) with testclient.session_transaction() as sess: sess["user_id"] = [admin.id] @@ -92,7 +92,7 @@ def test_user_self_deletion(testclient, backend): emails=["temp@temp.com"], password="correct horse battery staple", ) - user.save() + backend.save(user) with testclient.session_transaction() as sess: sess["user_id"] = [user.id] @@ -134,7 +134,7 @@ def test_account_locking(user, backend): user.lock_date = datetime.datetime.now(datetime.timezone.utc) assert user.locked - user.save() + backend.save(user) assert user.locked assert backend.get(models.User, id=user.id).locked assert backend.check_user_password(user, "correct horse battery staple") == ( @@ -143,7 +143,7 @@ def test_account_locking(user, backend): ) user.lock_date = None - user.save() + backend.save(user) assert not user.locked assert not backend.get(models.User, id=user.id).locked assert backend.check_user_password(user, "correct horse battery staple") == ( @@ -163,7 +163,7 @@ def test_account_locking_past_date(user, backend): user.lock_date = datetime.datetime.now(datetime.timezone.utc).replace( microsecond=0 ) - datetime.timedelta(days=30) - user.save() + backend.save(user) assert user.locked assert backend.get(models.User, id=user.id).locked assert backend.check_user_password(user, "correct horse battery staple") == ( @@ -183,7 +183,7 @@ def test_account_locking_future_date(user, backend): user.lock_date = datetime.datetime.now(datetime.timezone.utc).replace( microsecond=0 ) + datetime.timedelta(days=365 * 4) - user.save() + backend.save(user) assert not user.locked assert not backend.get(models.User, id=user.id).locked assert backend.check_user_password(user, "correct horse battery staple") == ( @@ -192,7 +192,7 @@ def test_account_locking_future_date(user, backend): ) -def test_account_locked_during_session(testclient, logged_user): +def test_account_locked_during_session(testclient, logged_user, backend): logged_user.lock_date = datetime.datetime.now(datetime.timezone.utc) - logged_user.save() + backend.save(logged_user) testclient.get("/profile/user/settings", status=403) diff --git a/tests/core/test_auth.py b/tests/core/test_auth.py index 2b873616..b6e5718e 100644 --- a/tests/core/test_auth.py +++ b/tests/core/test_auth.py @@ -154,7 +154,7 @@ def test_user_without_password_first_login(testclient, backend, smtpd): user_name="temp", emails=["john@doe.com", "johhny@doe.com"], ) - u.save() + backend.save(u) res = testclient.get("/login", status=200) res.form["login"] = "temp" @@ -189,7 +189,7 @@ def test_first_login_account_initialization_mail_sending_failed( user_name="temp", emails=["john@doe.com"], ) - u.save() + backend.save(u) res = testclient.get("/firstlogin/temp") res = res.form.submit(name="action", value="sendmail", expect_errors=True) @@ -211,7 +211,7 @@ def test_first_login_form_error(testclient, backend, smtpd): user_name="temp", emails=["john@doe.com"], ) - u.save() + backend.save(u) res = testclient.get("/firstlogin/temp", status=200) res.form["csrf_token"] = "invalid" @@ -236,7 +236,7 @@ def test_user_password_deleted_during_login(testclient, backend): emails=["john@doe.com"], password="correct horse battery staple", ) - u.save() + backend.save(u) res = testclient.get("/login") res.form["login"] = "temp" @@ -244,7 +244,7 @@ def test_user_password_deleted_during_login(testclient, backend): res.form["password"] = "correct horse battery staple" u.password = None - u.save() + backend.save(u) res = res.form.submit(status=302) assert res.location == "/firstlogin/temp" @@ -272,12 +272,12 @@ def test_wrong_login(testclient, user): res.mustcontain("The login 'invalid' does not exist") -def test_signin_locked_account(testclient, user): +def test_signin_locked_account(testclient, user, backend): with testclient.session_transaction() as session: assert not session.get("user_id") user.lock_date = datetime.datetime.now(datetime.timezone.utc) - user.save() + backend.save(user) res = testclient.get("/login", status=200) res.form["login"] = "user" @@ -289,4 +289,4 @@ def test_signin_locked_account(testclient, user): res.mustcontain("Your account has been locked.") user.lock_date = None - user.save() + backend.save(user) diff --git a/tests/core/test_email_confirmation.py b/tests/core/test_email_confirmation.py index c8ff1730..cf676234 100644 --- a/tests/core/test_email_confirmation.py +++ b/tests/core/test_email_confirmation.py @@ -371,14 +371,14 @@ def test_confirmation_email_already_used_link(testclient, backend, user, admin): assert "new_email@mydomain.tld" not in user.emails -def test_delete_email(testclient, logged_user): +def test_delete_email(testclient, logged_user, backend): """Tests that user can deletes its emails unless they have only one left.""" res = testclient.get("/profile/user") assert "email_remove" not in res.forms["emailconfirmationform"].fields logged_user.emails = logged_user.emails + ["new@email.com"] - logged_user.save() + backend.save(logged_user) res = testclient.get("/profile/user") assert "email_remove" in res.forms["emailconfirmationform"].fields @@ -391,10 +391,10 @@ def test_delete_email(testclient, logged_user): assert logged_user.emails == ["john@doe.com"] -def test_delete_wrong_email(testclient, logged_user): +def test_delete_wrong_email(testclient, logged_user, backend): """Tests that removing an already removed email do not produce anything.""" logged_user.emails = logged_user.emails + ["new@email.com"] - logged_user.save() + backend.save(logged_user) res = testclient.get("/profile/user") @@ -412,10 +412,10 @@ def test_delete_wrong_email(testclient, logged_user): assert logged_user.emails == ["john@doe.com"] -def test_delete_last_email(testclient, logged_user): +def test_delete_last_email(testclient, logged_user, backend): """Tests that users cannot remove their last email address.""" logged_user.emails = logged_user.emails + ["new@email.com"] - logged_user.save() + backend.save(logged_user) res = testclient.get("/profile/user") diff --git a/tests/core/test_forgotten_password.py b/tests/core/test_forgotten_password.py index 5c2342d1..ee4568d5 100644 --- a/tests/core/test_forgotten_password.py +++ b/tests/core/test_forgotten_password.py @@ -26,9 +26,9 @@ def test_password_forgotten(smtpd, testclient, user): assert len(smtpd.messages) == 1 -def test_password_forgotten_multiple_mails(smtpd, testclient, user): +def test_password_forgotten_multiple_mails(smtpd, testclient, user, backend): user.emails = ["foo@bar.com", "foo@baz.com", "foo@foo.com"] - user.save() + backend.save(user) res = testclient.get("/reset", status=200) diff --git a/tests/core/test_groups.py b/tests/core/test_groups.py index bd9d6bb3..01605086 100644 --- a/tests/core/test_groups.py +++ b/tests/core/test_groups.py @@ -53,13 +53,13 @@ def test_group_deletion(testclient, backend): user_name="foobar", emails=["foo@bar.com"], ) - user.save() + backend.save(user) group = models.Group( members=[user], display_name="foobar", ) - group.save() + backend.save(group) user.reload() assert user.groups == [group] @@ -86,19 +86,19 @@ def test_group_list_search(testclient, logged_admin, foo_group, bar_group): res.mustcontain(no=bar_group.display_name) -def test_set_groups(app, user, foo_group, bar_group): +def test_set_groups(app, user, foo_group, bar_group, backend): assert user in foo_group.members assert user.groups == [foo_group] user.groups = [foo_group, bar_group] - user.save() + backend.save(user) bar_group.reload() assert user in bar_group.members assert bar_group in user.groups user.groups = [foo_group] - user.save() + backend.save(user) foo_group.reload() bar_group.reload() @@ -106,23 +106,23 @@ def test_set_groups(app, user, foo_group, bar_group): assert user not in bar_group.members -def test_set_groups_with_leading_space_in_user_id_attribute(app, foo_group): +def test_set_groups_with_leading_space_in_user_id_attribute(app, foo_group, backend): user = models.User( formatted_name=" Doe", # leading space in id attribute family_name="Doe", user_name="user2", emails=["john@doe.com"], ) - user.save() + backend.save(user) user.groups = [foo_group] - user.save() + backend.save(user) foo_group.reload() assert user in foo_group.members user.groups = [] - user.save() + backend.save(user) foo_group.reload() assert user.id not in foo_group.members @@ -231,14 +231,14 @@ def test_edition_failed(testclient, logged_moderator, foo_group): assert foo_group.display_name == "foo" -def test_user_list_pagination(testclient, logged_admin, foo_group): +def test_user_list_pagination(testclient, logged_admin, foo_group, backend): res = testclient.get("/groups/foo") res.mustcontain("1 item") users = fake_users(25) for user in users: foo_group.members = foo_group.members + [user] - foo_group.save() + backend.save(foo_group) assert len(foo_group.members) == 26 res = testclient.get("/groups/foo") @@ -274,9 +274,11 @@ def test_user_list_bad_pages(testclient, logged_admin, foo_group): ) -def test_user_list_search(testclient, logged_admin, foo_group, user, moderator): +def test_user_list_search( + testclient, logged_admin, foo_group, user, moderator, backend +): foo_group.members = foo_group.members + [logged_admin, moderator] - foo_group.save() + backend.save(foo_group) res = testclient.get("/groups/foo") res.mustcontain("3 items") @@ -294,9 +296,9 @@ def test_user_list_search(testclient, logged_admin, foo_group, user, moderator): res.mustcontain(no=moderator.formatted_name) -def test_remove_member(testclient, logged_admin, foo_group, user, moderator): +def test_remove_member(testclient, logged_admin, foo_group, user, moderator, backend): foo_group.members = [user, moderator] - foo_group.save() + backend.save(foo_group) res = testclient.get("/groups/foo") form = res.forms[f"deletegroupmemberform-{user.id}"] @@ -313,15 +315,15 @@ def test_remove_member(testclient, logged_admin, foo_group, user, moderator): def test_remove_member_already_remove_from_group( - testclient, logged_admin, foo_group, user, moderator + testclient, logged_admin, foo_group, user, moderator, backend ): foo_group.members = [user, moderator] - foo_group.save() + backend.save(foo_group) res = testclient.get("/groups/foo") form = res.forms[f"deletegroupmemberform-{user.id}"] foo_group.members = [moderator] - foo_group.save() + backend.save(foo_group) res = form.submit(name="action", value="confirm-remove-member") assert ( @@ -331,17 +333,17 @@ def test_remove_member_already_remove_from_group( def test_confirm_remove_member_already_removed_from_group( - testclient, logged_admin, foo_group, user, moderator + testclient, logged_admin, foo_group, user, moderator, backend ): foo_group.members = [user, moderator] - foo_group.save() + backend.save(foo_group) res = testclient.get("/groups/foo") form = res.forms[f"deletegroupmemberform-{user.id}"] res = form.submit(name="action", value="confirm-remove-member") foo_group.members = [moderator] - foo_group.save() + backend.save(foo_group) res = res.form.submit(name="action", value="remove-member") assert ( @@ -351,10 +353,10 @@ def test_confirm_remove_member_already_removed_from_group( def test_remove_member_already_deleted( - testclient, logged_admin, foo_group, user, moderator + testclient, logged_admin, foo_group, user, moderator, backend ): foo_group.members = [user, moderator] - foo_group.save() + backend.save(foo_group) res = testclient.get("/groups/foo") form = res.forms[f"deletegroupmemberform-{user.id}"] @@ -368,10 +370,10 @@ def test_remove_member_already_deleted( def test_confirm_remove_member_already_deleted( - testclient, logged_admin, foo_group, user, moderator + testclient, logged_admin, foo_group, user, moderator, backend ): foo_group.members = [user, moderator] - foo_group.save() + backend.save(foo_group) res = testclient.get("/groups/foo") form = res.forms[f"deletegroupmemberform-{user.id}"] diff --git a/tests/core/test_models.py b/tests/core/test_models.py index 9d352164..8fc66247 100644 --- a/tests/core/test_models.py +++ b/tests/core/test_models.py @@ -13,13 +13,13 @@ def test_user_has_password(testclient, backend): user_name="temp", emails=["john@doe.com"], ) - user.save() + backend.save(user) assert user.password is None assert not user.has_password() user.password = "foobar" - user.save() + backend.save(user) assert user.password is not None assert user.has_password() diff --git a/tests/core/test_password_reset.py b/tests/core/test_password_reset.py index 3249a31d..dc7fec8f 100644 --- a/tests/core/test_password_reset.py +++ b/tests/core/test_password_reset.py @@ -25,7 +25,7 @@ def test_password_reset(testclient, user, backend): def test_password_reset_multiple_emails(testclient, user, backend): user.emails = ["foo@bar.com", "foo@baz.com"] - user.save() + backend.save(user) assert not backend.check_user_password(user, "foobarbaz")[0] hash = build_hash("user", "foo@baz.com", user.password) diff --git a/tests/core/test_profile_edition.py b/tests/core/test_profile_edition.py index e22fc50b..6f8cfc96 100644 --- a/tests/core/test_profile_edition.py +++ b/tests/core/test_profile_edition.py @@ -118,6 +118,7 @@ def test_edition( logged_user, admin, jpeg_photo, + backend, ): res = testclient.get("/profile/user", status=200) form = res.forms["baseform"] @@ -168,13 +169,14 @@ def test_edition( logged_user.emails = ["john@doe.com"] logged_user.given_name = None logged_user.photo = None - logged_user.save() + backend.save(logged_user) def test_edition_remove_fields( testclient, logged_user, admin, + backend, ): res = testclient.get("/profile/user", status=200) form = res.forms["baseform"] @@ -195,13 +197,13 @@ def test_edition_remove_fields( logged_user.emails = ["john@doe.com"] logged_user.given_name = None logged_user.photo = None - logged_user.save() + backend.save(logged_user) -def test_field_permissions_none(testclient, logged_user): +def test_field_permissions_none(testclient, logged_user, backend): testclient.get("/profile/user", status=200) logged_user.phone_numbers = ["555-666-777"] - logged_user.save() + backend.save(logged_user) testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"] = { "READ": ["user_name"], @@ -227,10 +229,10 @@ def test_field_permissions_none(testclient, logged_user): assert logged_user.phone_numbers == ["555-666-777"] -def test_field_permissions_read(testclient, logged_user): +def test_field_permissions_read(testclient, logged_user, backend): testclient.get("/profile/user", status=200) logged_user.phone_numbers = ["555-666-777"] - logged_user.save() + backend.save(logged_user) testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"] = { "READ": ["user_name", "phone_numbers"], @@ -255,10 +257,10 @@ def test_field_permissions_read(testclient, logged_user): assert logged_user.phone_numbers == ["555-666-777"] -def test_field_permissions_write(testclient, logged_user): +def test_field_permissions_write(testclient, logged_user, backend): testclient.get("/profile/user", status=200) logged_user.phone_numbers = ["555-666-777"] - logged_user.save() + backend.save(logged_user) testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"] = { "READ": ["user_name"], diff --git a/tests/core/test_profile_photo.py b/tests/core/test_profile_photo.py index d173565a..466c4bf6 100644 --- a/tests/core/test_profile_photo.py +++ b/tests/core/test_profile_photo.py @@ -5,9 +5,9 @@ from webtest import Upload from canaille.app import models -def test_photo(testclient, user, jpeg_photo): +def test_photo(testclient, user, jpeg_photo, backend): user.photo = jpeg_photo - user.save() + backend.save(user) user.reload() res = testclient.get("/profile/user/photo") diff --git a/tests/core/test_profile_settings.py b/tests/core/test_profile_settings.py index 6c159c34..547e77cf 100644 --- a/tests/core/test_profile_settings.py +++ b/tests/core/test_profile_settings.py @@ -34,13 +34,13 @@ def test_edition(testclient, logged_user, admin, foo_group, bar_group, backend): assert backend.check_user_password(logged_user, "correct horse battery staple")[0] logged_user.user_name = "user" - logged_user.save() + backend.save(logged_user) def test_group_removal(testclient, logged_admin, user, foo_group, backend): """Tests that one can remove a group from a user.""" foo_group.members = [user, logged_admin] - foo_group.save() + backend.save(foo_group) user.reload() assert foo_group in user.groups @@ -115,7 +115,7 @@ def test_edition_without_groups( assert backend.check_user_password(logged_user, "correct horse battery staple")[0] logged_user.user_name = "user" - logged_user.save() + backend.save(logged_user) def test_password_change(testclient, logged_user, backend): @@ -171,7 +171,7 @@ def test_password_initialization_mail(smtpd, testclient, backend, logged_admin): user_name="temp", emails=["john@doe.com"], ) - u.save() + backend.save(u) res = testclient.get("/profile/temp/settings", status=200) res.mustcontain("This user does not have a password yet") @@ -188,7 +188,7 @@ def test_password_initialization_mail(smtpd, testclient, backend, logged_admin): u.reload() u.password = "correct horse battery staple" - u.save() + backend.save(u) res = testclient.get("/profile/temp/settings", status=200) res.mustcontain(no="This user does not have a password yet") @@ -207,7 +207,7 @@ def test_password_initialization_mail_send_fail( user_name="temp", emails=["john@doe.com"], ) - u.save() + backend.save(u) res = testclient.get("/profile/temp/settings", status=200) res.mustcontain("This user does not have a password yet") @@ -272,7 +272,7 @@ def test_impersonate_locked_user(testclient, backend, logged_admin, user): user.lock_date = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta( days=1 ) - user.save() + backend.save(user) assert user.locked res = testclient.get("/profile/user/settings") @@ -295,7 +295,7 @@ def test_password_reset_email(smtpd, testclient, backend, logged_admin): emails=["john@doe.com"], password="correct horse battery staple", ) - u.save() + backend.save(u) res = testclient.get("/profile/temp/settings", status=200) res.mustcontain("If the user has forgotten his password") @@ -323,7 +323,7 @@ def test_password_reset_email_failed(SMTP, smtpd, testclient, backend, logged_ad emails=["john@doe.com"], password="correct horse battery staple", ) - u.save() + backend.save(u) res = testclient.get("/profile/temp/settings", status=200) res.mustcontain("If the user has forgotten his password") @@ -454,7 +454,7 @@ def test_empty_lock_date( second=0, microsecond=0 ) + datetime.timedelta(days=30) user.lock_date = expiration_datetime - user.save() + backend.save(user) res = testclient.get("/profile/user/settings", status=200) res.form["lock_date"] = "" diff --git a/tests/oidc/commands/test_clean.py b/tests/oidc/commands/test_clean.py index 69daf252..dba69479 100644 --- a/tests/oidc/commands/test_clean.py +++ b/tests/oidc/commands/test_clean.py @@ -21,7 +21,7 @@ def test_clean_command(testclient, backend, client, user): challenge="challenge", challenge_method="method", ) - valid_code.save() + backend.save(valid_code) expired_code = models.AuthorizationCode( authorization_code_id=gen_salt(48), code="my-expired-code", @@ -39,7 +39,7 @@ def test_clean_command(testclient, backend, client, user): challenge="challenge", challenge_method="method", ) - expired_code.save() + backend.save(expired_code) valid_token = models.Token( token_id=gen_salt(48), @@ -53,7 +53,7 @@ def test_clean_command(testclient, backend, client, user): ), lifetime=3600, ) - valid_token.save() + backend.save(valid_token) expired_token = models.Token( token_id=gen_salt(48), access_token="my-expired-token", @@ -67,7 +67,7 @@ def test_clean_command(testclient, backend, client, user): ), lifetime=3600, ) - expired_token.save() + backend.save(expired_token) assert backend.get(models.AuthorizationCode, code="my-expired-code") assert backend.get(models.Token, access_token="my-expired-token") diff --git a/tests/oidc/conftest.py b/tests/oidc/conftest.py index f4cc958d..c83baf6e 100644 --- a/tests/oidc/conftest.py +++ b/tests/oidc/conftest.py @@ -67,9 +67,9 @@ def client(testclient, trusted_client, backend): token_endpoint_auth_method="client_secret_basic", post_logout_redirect_uris=["https://mydomain.tld/disconnected"], ) - c.save() + backend.save(c) c.audience = [c, trusted_client] - c.save() + backend.save(c) yield c c.delete() @@ -105,9 +105,9 @@ def trusted_client(testclient, backend): post_logout_redirect_uris=["https://myotherdomain.tld/disconnected"], preconsent=True, ) - c.save() + backend.save(c) c.audience = [c] - c.save() + backend.save(c) yield c c.delete() @@ -129,7 +129,7 @@ def authorization(testclient, user, client, backend): challenge="challenge", challenge_method="method", ) - a.save() + backend.save(a) yield a a.delete() @@ -147,7 +147,7 @@ def token(testclient, client, user, backend): issue_date=datetime.datetime.now(datetime.timezone.utc), lifetime=3600, ) - t.save() + backend.save(t) yield t t.delete() @@ -171,6 +171,6 @@ def consent(testclient, client, user, backend): scope=["openid", "profile"], issue_date=datetime.datetime.now(datetime.timezone.utc), ) - t.save() + backend.save(t) yield t t.delete() diff --git a/tests/oidc/test_authorization_code_flow.py b/tests/oidc/test_authorization_code_flow.py index 9c8d30e5..bdee6ec7 100644 --- a/tests/oidc/test_authorization_code_flow.py +++ b/tests/oidc/test_authorization_code_flow.py @@ -167,7 +167,7 @@ def test_preconsented_client( assert not backend.query(models.Consent) client.preconsent = True - client.save() + backend.save(client) res = testclient.get( "/oauth/authorize", @@ -318,7 +318,7 @@ def test_code_challenge(testclient, logged_user, client, backend): assert not backend.query(models.Consent) client.token_endpoint_auth_method = "none" - client.save() + backend.save(client) code_verifier = gen_salt(48) code_challenge = create_s256_code_challenge(code_verifier) @@ -373,7 +373,7 @@ def test_code_challenge(testclient, logged_user, client, backend): assert res.json["name"] == "John (johnny) Doe" client.token_endpoint_auth_method = "client_secret_basic" - client.save() + backend.save(client) for consent in consents: consent.delete() @@ -565,7 +565,7 @@ def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client, b def test_request_scope_too_large(testclient, logged_user, keypair, client, backend): assert not backend.query(models.Consent) client.scope = ["openid", "profile", "groups"] - client.save() + backend.save(client) res = testclient.get( "/oauth/authorize", @@ -679,7 +679,7 @@ def test_code_with_invalid_user(testclient, admin, client, backend): emails=["temp@temp.com"], password="correct horse battery staple", ) - user.save() + backend.save(user) res = testclient.get( "/oauth/authorize", @@ -738,7 +738,7 @@ def test_locked_account( ) logged_user.lock_date = datetime.datetime.now(datetime.timezone.utc) - logged_user.save() + backend.save(logged_user) res = res.form.submit(name="answer", value="accept", status=302) diff --git a/tests/oidc/test_authorization_prompt.py b/tests/oidc/test_authorization_prompt.py index 4e20ecc9..05275eb9 100644 --- a/tests/oidc/test_authorization_prompt.py +++ b/tests/oidc/test_authorization_prompt.py @@ -14,7 +14,7 @@ from canaille.app import models from canaille.core.endpoints.account import RegistrationPayload -def test_prompt_none(testclient, logged_user, client): +def test_prompt_none(testclient, logged_user, client, backend): """Nominal case with prompt=none.""" consent = models.Consent( consent_id=str(uuid.uuid4()), @@ -22,7 +22,7 @@ def test_prompt_none(testclient, logged_user, client): subject=logged_user, scope=["openid", "profile"], ) - consent.save() + backend.save(consent) res = testclient.get( "/oauth/authorize", @@ -42,7 +42,7 @@ def test_prompt_none(testclient, logged_user, client): consent.delete() -def test_prompt_not_logged(testclient, user, client): +def test_prompt_not_logged(testclient, user, client, backend): """Prompt=none should return a login_required error when no user is logged in. @@ -58,7 +58,7 @@ def test_prompt_not_logged(testclient, user, client): subject=user, scope=["openid", "profile"], ) - consent.save() + backend.save(consent) res = testclient.get( "/oauth/authorize", @@ -100,7 +100,7 @@ def test_prompt_no_consent(testclient, logged_user, client): assert "consent_required" == res.json.get("error") -def test_prompt_create_logged(testclient, logged_user, client): +def test_prompt_create_logged(testclient, logged_user, client, backend): """If prompt=create and user is already logged in, then go straight to the consent page.""" testclient.app.config["CANAILLE"]["ENABLE_REGISTRATION"] = True @@ -111,7 +111,7 @@ def test_prompt_create_logged(testclient, logged_user, client): subject=logged_user, scope=["openid", "profile"], ) - consent.save() + backend.save(consent) res = testclient.get( "/oauth/authorize", diff --git a/tests/oidc/test_client_admin.py b/tests/oidc/test_client_admin.py index c42bf219..cbba9dc0 100644 --- a/tests/oidc/test_client_admin.py +++ b/tests/oidc/test_client_admin.py @@ -22,13 +22,15 @@ def test_client_list(testclient, client, logged_admin): res.mustcontain(client.client_name) -def test_client_list_pagination(testclient, logged_admin, client, trusted_client): +def test_client_list_pagination( + testclient, logged_admin, client, trusted_client, backend +): res = testclient.get("/admin/client") res.mustcontain("2 items") clients = [] for _ in range(25): client = models.Client(client_id=gen_salt(48), client_name=gen_salt(48)) - client.save() + backend.save(client) clients.append(client) res = testclient.get("/admin/client") @@ -216,22 +218,22 @@ def test_client_edit_missing_fields(testclient, client, logged_admin, trusted_cl def test_client_delete(testclient, logged_admin, backend): client = models.Client(client_id="client_id") - client.save() + backend.save(client) token = models.Token( token_id="id", client=client, subject=logged_admin, issue_date=datetime.datetime.now(datetime.timezone.utc), ) - token.save() + backend.save(token) consent = models.Consent( consent_id="consent_id", subject=logged_admin, client=client, scope=["openid"] ) - consent.save() + backend.save(consent) authorization_code = models.AuthorizationCode( authorization_code_id="id", client=client, subject=logged_admin ) - authorization_code.save() + backend.save(authorization_code) res = testclient.get("/admin/client/edit/" + client.client_id) res = res.forms["clientaddform"].submit(name="action", value="confirm-delete") diff --git a/tests/oidc/test_code_admin.py b/tests/oidc/test_code_admin.py index d41cb870..9a79ecb4 100644 --- a/tests/oidc/test_code_admin.py +++ b/tests/oidc/test_code_admin.py @@ -16,7 +16,7 @@ def test_authorizaton_list(testclient, authorization, logged_admin): res.mustcontain(authorization.authorization_code_id) -def test_authorization_list_pagination(testclient, logged_admin, client): +def test_authorization_list_pagination(testclient, logged_admin, client, backend): res = testclient.get("/admin/authorization") res.mustcontain("0 items") authorizations = [] @@ -24,7 +24,7 @@ def test_authorization_list_pagination(testclient, logged_admin, client): code = models.AuthorizationCode( authorization_code_id=gen_salt(48), client=client, subject=logged_admin ) - code.save() + backend.save(code) authorizations.append(code) res = testclient.get("/admin/authorization") @@ -64,18 +64,18 @@ def test_authorization_list_bad_pages(testclient, logged_admin): ) -def test_authorization_list_search(testclient, logged_admin, client): +def test_authorization_list_search(testclient, logged_admin, client, backend): id1 = gen_salt(48) auth1 = models.AuthorizationCode( authorization_code_id=id1, client=client, subject=logged_admin ) - auth1.save() + backend.save(auth1) id2 = gen_salt(48) auth2 = models.AuthorizationCode( authorization_code_id=id2, client=client, subject=logged_admin ) - auth2.save() + backend.save(auth2) res = testclient.get("/admin/authorization") res.mustcontain("2 items") diff --git a/tests/oidc/test_consent.py b/tests/oidc/test_consent.py index 16787021..8edc86ab 100644 --- a/tests/oidc/test_consent.py +++ b/tests/oidc/test_consent.py @@ -139,13 +139,15 @@ def test_oidc_authorization_after_revokation( assert token.subject == logged_user -def test_preconsented_client_appears_in_consent_list(testclient, client, logged_user): +def test_preconsented_client_appears_in_consent_list( + testclient, client, logged_user, backend +): assert not client.preconsent res = testclient.get("/consent/pre-consents") res.mustcontain(no=client.client_name) client.preconsent = True - client.save() + backend.save(client) res = testclient.get("/consent/pre-consents") res.mustcontain(client.client_name) @@ -153,7 +155,7 @@ def test_preconsented_client_appears_in_consent_list(testclient, client, logged_ def test_revoke_preconsented_client(testclient, client, logged_user, token, backend): client.preconsent = True - client.save() + backend.save(client) assert not backend.get(models.Consent) assert not token.revoked @@ -190,22 +192,22 @@ def test_revoke_invalid_preconsented_client(testclient, logged_user): def test_revoke_preconsented_client_with_manual_consent( - testclient, logged_user, client, consent + testclient, logged_user, client, consent, backend ): client.preconsent = True - client.save() + backend.save(client) res = testclient.get(f"/consent/revoke-preconsent/{client.client_id}", status=302) res = res.follow() assert ("success", "The access has been revoked") in res.flashes def test_revoke_preconsented_client_with_manual_revokation( - testclient, logged_user, client, consent + testclient, logged_user, client, consent, backend ): client.preconsent = True - client.save() + backend.save(client) consent.revoke() - consent.save() + backend.save(consent) res = testclient.get(f"/consent/revoke-preconsent/{client.client_id}", status=302) res = res.follow() diff --git a/tests/oidc/test_dynamic_client_registration_management.py b/tests/oidc/test_dynamic_client_registration_management.py index 3e6bd3f5..71934db5 100644 --- a/tests/oidc/test_dynamic_client_registration_management.py +++ b/tests/oidc/test_dynamic_client_registration_management.py @@ -146,7 +146,7 @@ def test_delete(testclient, backend, user): ] client = models.Client(client_id="foobar", client_name="Some client") - client.save() + backend.save(client) headers = {"Authorization": "Bearer static-token"} with warnings.catch_warnings(record=True): diff --git a/tests/oidc/test_implicit_flow.py b/tests/oidc/test_implicit_flow.py index ba1a08ab..dd3f5aba 100644 --- a/tests/oidc/test_implicit_flow.py +++ b/tests/oidc/test_implicit_flow.py @@ -10,7 +10,7 @@ def test_oauth_implicit(testclient, user, client, backend): client.grant_types = ["token"] client.token_endpoint_auth_method = "none" - client.save() + backend.save(client) res = testclient.get( "/oauth/authorize", @@ -48,14 +48,14 @@ def test_oauth_implicit(testclient, user, client, backend): client.grant_types = ["code"] client.token_endpoint_auth_method = "client_secret_basic" - client.save() + backend.save(client) def test_oidc_implicit(testclient, keypair, user, client, trusted_client, backend): client.grant_types = ["token id_token"] client.token_endpoint_auth_method = "none" - client.save() + backend.save(client) res = testclient.get( "/oauth/authorize", @@ -101,7 +101,7 @@ def test_oidc_implicit(testclient, keypair, user, client, trusted_client, backen client.grant_types = ["code"] client.token_endpoint_auth_method = "client_secret_basic" - client.save() + backend.save(client) def test_oidc_implicit_with_group( @@ -110,7 +110,7 @@ def test_oidc_implicit_with_group( client.grant_types = ["token id_token"] client.token_endpoint_auth_method = "none" - client.save() + backend.save(client) res = testclient.get( "/oauth/authorize", @@ -157,4 +157,4 @@ def test_oidc_implicit_with_group( client.grant_types = ["code"] client.token_endpoint_auth_method = "client_secret_basic" - client.save() + backend.save(client) diff --git a/tests/oidc/test_password_flow.py b/tests/oidc/test_password_flow.py index daaad5d3..725f1503 100644 --- a/tests/oidc/test_password_flow.py +++ b/tests/oidc/test_password_flow.py @@ -33,7 +33,7 @@ def test_password_flow_basic(testclient, user, client, backend): def test_password_flow_post(testclient, user, client, backend): client.token_endpoint_auth_method = "client_secret_post" - client.save() + backend.save(client) res = testclient.post( "/oauth/token", diff --git a/tests/oidc/test_refresh_token.py b/tests/oidc/test_refresh_token.py index 7a99bcd7..a0516a5d 100644 --- a/tests/oidc/test_refresh_token.py +++ b/tests/oidc/test_refresh_token.py @@ -82,7 +82,7 @@ def test_refresh_token_with_invalid_user(testclient, client, backend): emails=["temp@temp.com"], password="correct horse battery staple", ) - user.save() + backend.save(user) res = testclient.get( "/oauth/authorize", @@ -137,7 +137,9 @@ def test_refresh_token_with_invalid_user(testclient, client, backend): backend.get(models.Token, access_token=access_token).delete() -def test_cannot_refresh_token_for_locked_users(testclient, logged_user, client): +def test_cannot_refresh_token_for_locked_users( + testclient, logged_user, client, backend +): """Canaille should not issue new tokens for locked users.""" res = testclient.get( "/oauth/authorize", @@ -167,7 +169,7 @@ def test_cannot_refresh_token_for_locked_users(testclient, logged_user, client): ) logged_user.lock_date = datetime.datetime.now(datetime.timezone.utc) - logged_user.save() + backend.save(logged_user) res = testclient.post( "/oauth/token", diff --git a/tests/oidc/test_token_admin.py b/tests/oidc/test_token_admin.py index 7eca4e9c..24136136 100644 --- a/tests/oidc/test_token_admin.py +++ b/tests/oidc/test_token_admin.py @@ -18,7 +18,7 @@ def test_token_list(testclient, token, logged_admin): res.mustcontain(token.token_id) -def test_token_list_pagination(testclient, logged_admin, client): +def test_token_list_pagination(testclient, logged_admin, client, backend): res = testclient.get("/admin/token") res.mustcontain("0 items") tokens = [] @@ -36,7 +36,7 @@ def test_token_list_pagination(testclient, logged_admin, client): ), lifetime=3600, ) - token.save() + backend.save(token) tokens.append(token) res = testclient.get("/admin/token") @@ -74,7 +74,7 @@ def test_token_list_bad_pages(testclient, logged_admin): ) -def test_token_list_search(testclient, logged_admin, client): +def test_token_list_search(testclient, logged_admin, client, backend): token1 = models.Token( token_id=gen_salt(48), access_token="this-token-is-ok", @@ -88,7 +88,7 @@ def test_token_list_search(testclient, logged_admin, client): ), lifetime=3600, ) - token1.save() + backend.save(token1) token2 = models.Token( token_id=gen_salt(48), access_token="this-token-is-valid", @@ -102,7 +102,7 @@ def test_token_list_search(testclient, logged_admin, client): ), lifetime=3600, ) - token2.save() + backend.save(token2) res = testclient.get("/admin/token") res.mustcontain("2 items") diff --git a/tests/oidc/test_token_revocation.py b/tests/oidc/test_token_revocation.py index fe119f66..193613ae 100644 --- a/tests/oidc/test_token_revocation.py +++ b/tests/oidc/test_token_revocation.py @@ -63,11 +63,11 @@ def test_revoke_refresh_token_with_hint(testclient, user, client, token): assert token.revokation_date -def test_cannot_refresh_after_revocation(testclient, user, client, token): +def test_cannot_refresh_after_revocation(testclient, user, client, token, backend): token.revokation_date = datetime.datetime.now( datetime.timezone.utc ) - datetime.timedelta(days=7) - token.save() + backend.save(token) res = testclient.post( "/oauth/token", diff --git a/tests/oidc/test_userinfo.py b/tests/oidc/test_userinfo.py index f7744a89..5a3cfca1 100644 --- a/tests/oidc/test_userinfo.py +++ b/tests/oidc/test_userinfo.py @@ -146,9 +146,9 @@ def test_generate_user_claims(user, foo_group): } -def test_userinfo(testclient, token, user, foo_group): +def test_userinfo(testclient, token, user, foo_group, backend): token.scope = ["openid"] - token.save() + backend.save(token) testclient.get( "/oauth/userinfo", headers={"Authorization": f"Bearer {token.access_token}"}, @@ -156,7 +156,7 @@ def test_userinfo(testclient, token, user, foo_group): ) token.scope = ["openid", "profile"] - token.save() + backend.save(token) res = testclient.get( "/oauth/userinfo", headers={"Authorization": f"Bearer {token.access_token}"}, @@ -172,7 +172,7 @@ def test_userinfo(testclient, token, user, foo_group): } token.scope = ["openid", "profile", "email"] - token.save() + backend.save(token) res = testclient.get( "/oauth/userinfo", headers={"Authorization": f"Bearer {token.access_token}"}, @@ -189,7 +189,7 @@ def test_userinfo(testclient, token, user, foo_group): } token.scope = ["openid", "profile", "address"] - token.save() + backend.save(token) res = testclient.get( "/oauth/userinfo", headers={"Authorization": f"Bearer {token.access_token}"}, @@ -206,7 +206,7 @@ def test_userinfo(testclient, token, user, foo_group): } token.scope = ["openid", "profile", "phone"] - token.save() + backend.save(token) res = testclient.get( "/oauth/userinfo", headers={"Authorization": f"Bearer {token.access_token}"}, @@ -223,7 +223,7 @@ def test_userinfo(testclient, token, user, foo_group): } token.scope = ["openid", "profile", "groups"] - token.save() + backend.save(token) res = testclient.get( "/oauth/userinfo", headers={"Authorization": f"Bearer {token.access_token}"}, @@ -296,7 +296,7 @@ def test_claim_is_omitted_if_empty(testclient, backend, user): # According to https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse # it's better to not insert a null or empty string value user.emails = [] - user.save() + backend.save(user) default_jwt_mapping = JWTSettings().model_dump() data = generate_user_claims(user, STANDARD_CLAIMS, default_jwt_mapping)