diff --git a/canaille/backends/__init__.py b/canaille/backends/__init__.py index 90351a44..a03f4416 100644 --- a/canaille/backends/__init__.py +++ b/canaille/backends/__init__.py @@ -55,6 +55,14 @@ class BaseBackend: """ raise NotImplementedError() + def check_user_password(self, user, password: str) -> bool: + """Checks if the password matches the user password in the database.""" + raise NotImplementedError() + + def set_user_password(self, user, password: str): + """Sets a password for the user.""" + raise NotImplementedError() + def has_account_lockability(self): """Indicates wether the backend supports locking user accounts.""" raise NotImplementedError() diff --git a/canaille/backends/ldap/backend.py b/canaille/backends/ldap/backend.py index 554ec935..b25050ae 100644 --- a/canaille/backends/ldap/backend.py +++ b/canaille/backends/ldap/backend.py @@ -6,6 +6,9 @@ from contextlib import contextmanager import ldap.modlist import ldif from flask import current_app +from ldap.controls import DecodeControlTuples +from ldap.controls.ppolicy import PasswordPolicyControl +from ldap.controls.ppolicy import PasswordPolicyError from canaille.app import models from canaille.app.configuration import ConfigurationException @@ -200,6 +203,53 @@ class Backend(BaseBackend): ) return User.get(filter=filter) + def check_user_password(self, user, password): + conn = ldap.initialize(current_app.config["CANAILLE_LDAP"]["URI"]) + + conn.set_option( + ldap.OPT_NETWORK_TIMEOUT, + current_app.config["CANAILLE_LDAP"]["TIMEOUT"], + ) + + message = None + try: + res = conn.simple_bind_s( + user.dn, password, serverctrls=[PasswordPolicyControl()] + ) + controls = res[3] + result = True + except ldap.INVALID_CREDENTIALS as exc: + controls = DecodeControlTuples(exc.args[0]["ctrls"]) + result = False + finally: + conn.unbind_s() + + for control in controls: + + def gettext(x): + return x + + if ( + control.controlType == PasswordPolicyControl.controlType + and control.error == PasswordPolicyError.namedValues["accountLocked"] + ): + message = gettext("Your account has been locked.") + elif ( + control.controlType == PasswordPolicyControl.controlType + and control.error == PasswordPolicyError.namedValues["changeAfterReset"] + ): + message = gettext("You should change your password.") + + return result, message + + def set_user_password(self, user, password): + conn = Backend.get().connection + conn.passwd_s( + user.dn, + None, + password.encode("utf-8"), + ) + def setup_ldap_models(config): from canaille.app import models diff --git a/canaille/backends/ldap/models.py b/canaille/backends/ldap/models.py index acf71154..c6771bc1 100644 --- a/canaille/backends/ldap/models.py +++ b/canaille/backends/ldap/models.py @@ -1,8 +1,4 @@ import ldap.filter -from flask import current_app -from ldap.controls import DecodeControlTuples -from ldap.controls.ppolicy import PasswordPolicyControl -from ldap.controls.ppolicy import PasswordPolicyError import canaille.core.models import canaille.oidc.models @@ -51,53 +47,6 @@ class User(canaille.core.models.User, LDAPObject): def identifier(self): return self.rdn_value - def check_password(self, password): - conn = ldap.initialize(current_app.config["CANAILLE_LDAP"]["URI"]) - - conn.set_option( - ldap.OPT_NETWORK_TIMEOUT, - current_app.config["CANAILLE_LDAP"]["TIMEOUT"], - ) - - message = None - try: - res = conn.simple_bind_s( - self.dn, password, serverctrls=[PasswordPolicyControl()] - ) - controls = res[3] - result = True - except ldap.INVALID_CREDENTIALS as exc: - controls = DecodeControlTuples(exc.args[0]["ctrls"]) - result = False - finally: - conn.unbind_s() - - for control in controls: - - def gettext(x): - return x - - if ( - control.controlType == PasswordPolicyControl.controlType - and control.error == PasswordPolicyError.namedValues["accountLocked"] - ): - message = gettext("Your account has been locked.") - elif ( - control.controlType == PasswordPolicyControl.controlType - and control.error == PasswordPolicyError.namedValues["changeAfterReset"] - ): - message = gettext("You should change your password.") - - return result, message - - def set_password(self, password): - conn = Backend.get().connection - conn.passwd_s( - self.dn, - None, - password.encode("utf-8"), - ) - def save(self, *args, **kwargs): group_attr = self.python_attribute_to_ldap("groups") new_groups = self.changes.get(group_attr) diff --git a/canaille/backends/memory/backend.py b/canaille/backends/memory/backend.py index 59e8e669..c0662094 100644 --- a/canaille/backends/memory/backend.py +++ b/canaille/backends/memory/backend.py @@ -27,3 +27,16 @@ class Backend(BaseBackend): from .models import User return User.get(user_name=login) + + def check_user_password(self, user, password): + if password != user.password: + return (False, None) + + if user.locked: + return (False, "Your account has been locked.") + + return (True, None) + + def set_user_password(self, user, password): + user.password = password + user.save() diff --git a/canaille/backends/memory/models.py b/canaille/backends/memory/models.py index 711872d1..ab63a2f1 100644 --- a/canaille/backends/memory/models.py +++ b/canaille/backends/memory/models.py @@ -246,19 +246,6 @@ class User(canaille.core.models.User, MemoryModel): "groups": ("Group", "members"), } - def check_password(self, password): - if password != self.password: - return (False, None) - - if self.locked: - return (False, "Your account has been locked.") - - return (True, None) - - def set_password(self, password): - self.password = password - self.save() - class Group(canaille.core.models.Group, MemoryModel): model_attributes: ClassVar[Dict[str, str]] = { diff --git a/canaille/backends/sql/backend.py b/canaille/backends/sql/backend.py index 91001d18..adf9e415 100644 --- a/canaille/backends/sql/backend.py +++ b/canaille/backends/sql/backend.py @@ -52,3 +52,16 @@ class Backend(BaseBackend): from .models import User return User.get(user_name=login) + + def check_user_password(self, user, password): + if password != user.password: + return (False, None) + + if user.locked: + return (False, "Your account has been locked.") + + return (True, None) + + def set_user_password(self, user, password): + user.password = password + user.save() diff --git a/canaille/backends/sql/models.py b/canaille/backends/sql/models.py index 2feeb678..06bbcc76 100644 --- a/canaille/backends/sql/models.py +++ b/canaille/backends/sql/models.py @@ -171,19 +171,6 @@ class User(canaille.core.models.User, Base, SqlAlchemyModel): TZDateTime(timezone=True), nullable=True ) - def check_password(self, password): - if password != self.password: - return (False, None) - - if self.locked: - return (False, "Your account has been locked.") - - return (True, None) - - def set_password(self, password): - self.password = password - self.save() - class Group(canaille.core.models.Group, Base, SqlAlchemyModel): __tablename__ = "group" diff --git a/canaille/core/endpoints/account.py b/canaille/core/endpoints/account.py index f33a8e85..0158ae7a 100644 --- a/canaille/core/endpoints/account.py +++ b/canaille/core/endpoints/account.py @@ -455,7 +455,7 @@ def profile_create(current_app, form): user.save() if form["password1"].data: - user.set_password(form["password1"].data) + BaseBackend.get().set_user_password(user, form["password1"].data) user.save() return user @@ -780,7 +780,7 @@ def profile_settings_edit(editor, edited_user): and form["password1"].data and request.form["action"] == "edit-settings" ): - edited_user.set_password(form["password1"].data) + BaseBackend.get().set_user_password(edited_user, form["password1"].data) edited_user.save() flash(_("Profile updated successfully."), "success") diff --git a/canaille/core/endpoints/auth.py b/canaille/core/endpoints/auth.py index 8cdfd742..1082f3a9 100644 --- a/canaille/core/endpoints/auth.py +++ b/canaille/core/endpoints/auth.py @@ -78,7 +78,7 @@ def password(): "password.html", form=form, username=session["attempt_login"] ) - success, message = user.check_password(form.password.data) + success, message = BaseBackend.get().check_user_password(user, form.password.data) if not success: logout_user() flash(message or _("Login failed, please check your information"), "error") @@ -210,7 +210,7 @@ def reset(user, hash): return redirect(url_for("core.account.index")) if request.form and form.validate(): - user.set_password(form.password.data) + BaseBackend.get().set_user_password(user, form.password.data) login_user(user) flash(_("Your password has been updated successfully"), "success") diff --git a/canaille/core/models.py b/canaille/core/models.py index 8155fe69..f1385f78 100644 --- a/canaille/core/models.py +++ b/canaille/core/models.py @@ -248,14 +248,6 @@ class User(Model): """Checks wether a password has been set for the user.""" return self.password is not None - def check_password(self, password: str) -> bool: - """Checks if the password matches the user password in the database.""" - raise NotImplementedError() - - def set_password(self, password: str): - """Sets a password for the user.""" - raise NotImplementedError() - def can_read(self, field: str): return field in self._readable_fields | self._writable_fields diff --git a/canaille/oidc/oauth.py b/canaille/oidc/oauth.py index 36b2c08f..180deb82 100644 --- a/canaille/oidc/oauth.py +++ b/canaille/oidc/oauth.py @@ -271,7 +271,7 @@ class PasswordGrant(_ResourceOwnerPasswordCredentialsGrant): if not user: return None - success, _ = user.check_password(password) + success, _ = BaseBackend.get().check_user_password(user, password) if not success: return None diff --git a/tests/core/test_account.py b/tests/core/test_account.py index 1139d457..185648e7 100644 --- a/tests/core/test_account.py +++ b/tests/core/test_account.py @@ -372,14 +372,17 @@ def test_user_self_deletion(testclient, backend): def test_account_locking(user, backend): assert not user.locked assert not user.lock_date - assert user.check_password("correct horse battery staple") == (True, None) + assert backend.check_user_password(user, "correct horse battery staple") == ( + True, + None, + ) user.lock_date = datetime.datetime.now(datetime.timezone.utc) assert user.locked user.save() assert user.locked assert models.User.get(id=user.id).locked - assert user.check_password("correct horse battery staple") == ( + assert backend.check_user_password(user, "correct horse battery staple") == ( False, "Your account has been locked.", ) @@ -388,13 +391,19 @@ def test_account_locking(user, backend): user.save() assert not user.locked assert not models.User.get(id=user.id).locked - assert user.check_password("correct horse battery staple") == (True, None) + assert backend.check_user_password(user, "correct horse battery staple") == ( + True, + None, + ) def test_account_locking_past_date(user, backend): assert not user.locked assert not user.lock_date - assert user.check_password("correct horse battery staple") == (True, None) + assert backend.check_user_password(user, "correct horse battery staple") == ( + True, + None, + ) user.lock_date = datetime.datetime.now(datetime.timezone.utc).replace( microsecond=0 @@ -402,7 +411,7 @@ def test_account_locking_past_date(user, backend): user.save() assert user.locked assert models.User.get(id=user.id).locked - assert user.check_password("correct horse battery staple") == ( + assert backend.check_user_password(user, "correct horse battery staple") == ( False, "Your account has been locked.", ) @@ -411,7 +420,10 @@ def test_account_locking_past_date(user, backend): def test_account_locking_future_date(user, backend): assert not user.locked assert not user.lock_date - assert user.check_password("correct horse battery staple") == (True, None) + assert backend.check_user_password(user, "correct horse battery staple") == ( + True, + None, + ) user.lock_date = datetime.datetime.now(datetime.timezone.utc).replace( microsecond=0 @@ -419,7 +431,10 @@ def test_account_locking_future_date(user, backend): user.save() assert not user.locked assert not models.User.get(id=user.id).locked - assert user.check_password("correct horse battery staple") == (True, None) + assert backend.check_user_password(user, "correct horse battery staple") == ( + True, + None, + ) def test_signin_locked_account(testclient, user): diff --git a/tests/core/test_invitation.py b/tests/core/test_invitation.py index 8fcdfc03..0bc96033 100644 --- a/tests/core/test_invitation.py +++ b/tests/core/test_invitation.py @@ -6,7 +6,7 @@ from canaille.app import models from canaille.core.endpoints.account import RegistrationPayload -def test_invitation(testclient, logged_admin, foo_group, smtpd): +def test_invitation(testclient, logged_admin, foo_group, smtpd, backend): assert models.User.get(user_name="someone") is None res = testclient.get("/invite", status=200) @@ -48,7 +48,7 @@ def test_invitation(testclient, logged_admin, foo_group, smtpd): user = models.User.get(user_name="someone") foo_group.reload() - assert user.check_password("whatever")[0] + assert backend.check_user_password(user, "whatever")[0] assert user.groups == [foo_group] with testclient.session_transaction() as sess: @@ -59,7 +59,9 @@ def test_invitation(testclient, logged_admin, foo_group, smtpd): user.delete() -def test_invitation_editable_user_name(testclient, logged_admin, foo_group, smtpd): +def test_invitation_editable_user_name( + testclient, logged_admin, foo_group, smtpd, backend +): assert models.User.get(user_name="jackyjack") is None assert models.User.get(user_name="djorje") is None @@ -102,7 +104,7 @@ def test_invitation_editable_user_name(testclient, logged_admin, foo_group, smtp user = models.User.get(user_name="djorje") foo_group.reload() - assert user.check_password("whatever")[0] + assert backend.check_user_password(user, "whatever")[0] assert user.groups == [foo_group] with testclient.session_transaction() as sess: @@ -111,7 +113,7 @@ def test_invitation_editable_user_name(testclient, logged_admin, foo_group, smtp user.delete() -def test_generate_link(testclient, logged_admin, foo_group, smtpd): +def test_generate_link(testclient, logged_admin, foo_group, smtpd, backend): assert models.User.get(user_name="sometwo") is None res = testclient.get("/invite", status=200) @@ -149,7 +151,7 @@ def test_generate_link(testclient, logged_admin, foo_group, smtpd): user = models.User.get(user_name="sometwo") foo_group.reload() - assert user.check_password("whatever")[0] + assert backend.check_user_password(user, "whatever")[0] assert user.groups == [foo_group] with testclient.session_transaction() as sess: diff --git a/tests/core/test_models.py b/tests/core/test_models.py index 10af8577..9d352164 100644 --- a/tests/core/test_models.py +++ b/tests/core/test_models.py @@ -1,20 +1,4 @@ -import pytest - from canaille.app import models -from canaille.core.models import Group -from canaille.core.models import User - - -def test_required_methods(testclient, backend): - user = User() - - with pytest.raises(NotImplementedError): - user.check_password("password") - - with pytest.raises(NotImplementedError): - user.set_password("password") - - Group() def test_user_get_user_from_login(testclient, user, backend): @@ -44,10 +28,10 @@ def test_user_has_password(testclient, backend): def test_user_set_and_check_password(testclient, user, backend): - assert not user.check_password("something else")[0] - assert user.check_password("correct horse battery staple")[0] + assert not backend.check_user_password(user, "something else")[0] + assert backend.check_user_password(user, "correct horse battery staple")[0] - user.set_password("something else") + backend.set_user_password(user, "something else") - assert user.check_password("something else")[0] - assert not user.check_password("correct horse battery staple")[0] + assert backend.check_user_password(user, "something else")[0] + assert not backend.check_user_password(user, "correct horse battery staple")[0] diff --git a/tests/core/test_password_reset.py b/tests/core/test_password_reset.py index 142d9dbd..3249a31d 100644 --- a/tests/core/test_password_reset.py +++ b/tests/core/test_password_reset.py @@ -1,8 +1,8 @@ from canaille.core.endpoints.account import build_hash -def test_password_reset(testclient, user): - assert not user.check_password("foobarbaz")[0] +def test_password_reset(testclient, user, backend): + assert not backend.check_user_password(user, "foobarbaz")[0] hash = build_hash("user", user.preferred_email, user.password) res = testclient.get("/reset/user/" + hash, status=200) @@ -14,7 +14,7 @@ def test_password_reset(testclient, user): assert res.location == "/profile/user" user.reload() - assert user.check_password("foobarbaz")[0] + assert backend.check_user_password(user, "foobarbaz")[0] res = testclient.get("/reset/user/" + hash) assert ( @@ -23,11 +23,11 @@ def test_password_reset(testclient, user): ) in res.flashes -def test_password_reset_multiple_emails(testclient, user): +def test_password_reset_multiple_emails(testclient, user, backend): user.emails = ["foo@bar.com", "foo@baz.com"] user.save() - assert not user.check_password("foobarbaz")[0] + assert not backend.check_user_password(user, "foobarbaz")[0] hash = build_hash("user", "foo@baz.com", user.password) res = testclient.get("/reset/user/" + hash, status=200) @@ -38,7 +38,7 @@ def test_password_reset_multiple_emails(testclient, user): assert ("success", "Your password has been updated successfully") in res.flashes user.reload() - assert user.check_password("foobarbaz")[0] + assert backend.check_user_password(user, "foobarbaz")[0] res = testclient.get("/reset/user/" + hash) assert ( @@ -55,7 +55,7 @@ def test_password_reset_bad_link(testclient, user): ) in res.flashes -def test_password_reset_bad_password(testclient, user): +def test_password_reset_bad_password(testclient, user, backend): hash = build_hash("user", user.preferred_email, user.password) res = testclient.get("/reset/user/" + hash, status=200) @@ -64,7 +64,7 @@ def test_password_reset_bad_password(testclient, user): res.form["confirmation"] = "typo" res = res.form.submit(status=200) - assert user.check_password("correct horse battery staple")[0] + assert backend.check_user_password(user, "correct horse battery staple")[0] def test_unavailable_if_no_smtp(testclient, user): diff --git a/tests/core/test_profile_creation.py b/tests/core/test_profile_creation.py index acde32ea..5bb28f1d 100644 --- a/tests/core/test_profile_creation.py +++ b/tests/core/test_profile_creation.py @@ -2,7 +2,7 @@ from canaille.app import models def test_user_creation_edition_and_deletion( - testclient, logged_moderator, foo_group, bar_group + testclient, logged_moderator, foo_group, bar_group, backend ): # The user does not exist. res = testclient.get("/users", status=200) @@ -28,7 +28,7 @@ def test_user_creation_edition_and_deletion( foo_group.reload() assert "George" == george.given_name assert george.groups == [foo_group] - assert george.check_password("totoyolo")[0] + assert backend.check_user_password(george, "totoyolo")[0] res = testclient.get("/users", status=200) res.mustcontain("george") @@ -47,7 +47,7 @@ def test_user_creation_edition_and_deletion( george = models.User.get(user_name="george") assert "Georgio" == george.given_name - assert george.check_password("totoyolo")[0] + assert backend.check_user_password(george, "totoyolo")[0] foo_group.reload() bar_group.reload() diff --git a/tests/core/test_profile_settings.py b/tests/core/test_profile_settings.py index 766e10f9..e8e3ed99 100644 --- a/tests/core/test_profile_settings.py +++ b/tests/core/test_profile_settings.py @@ -6,13 +6,7 @@ from flask import g from canaille.app import models -def test_edition( - testclient, - logged_user, - admin, - foo_group, - bar_group, -): +def test_edition(testclient, logged_user, admin, foo_group, bar_group, backend): res = testclient.get("/profile/user/settings", status=200) assert set(res.form["groups"].options) == { (foo_group.id, True, "foo"), @@ -37,7 +31,7 @@ def test_edition( assert foo_group.members == [logged_user] assert bar_group.members == [admin] - assert logged_user.check_password("correct horse battery staple")[0] + assert backend.check_user_password(logged_user, "correct horse battery staple")[0] logged_user.user_name = "user" logged_user.save() @@ -63,6 +57,7 @@ def test_edition_without_groups( testclient, logged_user, admin, + backend, ): res = testclient.get("/profile/user/settings", status=200) testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["READ"] = [] @@ -74,13 +69,13 @@ def test_edition_without_groups( logged_user.reload() assert logged_user.user_name == "user" - assert logged_user.check_password("correct horse battery staple")[0] + assert backend.check_user_password(logged_user, "correct horse battery staple")[0] logged_user.user_name = "user" logged_user.save() -def test_password_change(testclient, logged_user): +def test_password_change(testclient, logged_user, backend): res = testclient.get("/profile/user/settings", status=200) res.form["password1"] = "new_password" @@ -89,7 +84,7 @@ def test_password_change(testclient, logged_user): res = res.form.submit(name="action", value="edit-settings").follow() logged_user.reload() - assert logged_user.check_password("new_password")[0] + assert backend.check_user_password(logged_user, "new_password")[0] res = testclient.get("/profile/user/settings", status=200) @@ -101,10 +96,10 @@ def test_password_change(testclient, logged_user): res = res.follow() logged_user.reload() - assert logged_user.check_password("correct horse battery staple")[0] + assert backend.check_user_password(logged_user, "correct horse battery staple")[0] -def test_password_change_fail(testclient, logged_user): +def test_password_change_fail(testclient, logged_user, backend): res = testclient.get("/profile/user/settings", status=200) res.form["password1"] = "new_password" @@ -113,7 +108,7 @@ def test_password_change_fail(testclient, logged_user): res = res.form.submit(name="action", value="edit-settings", status=200) logged_user.reload() - assert logged_user.check_password("correct horse battery staple")[0] + assert backend.check_user_password(logged_user, "correct horse battery staple")[0] res = testclient.get("/profile/user/settings", status=200) @@ -123,7 +118,7 @@ def test_password_change_fail(testclient, logged_user): res = res.form.submit(name="action", value="edit-settings", status=200) logged_user.reload() - assert logged_user.check_password("correct horse battery staple")[0] + assert backend.check_user_password(logged_user, "correct horse battery staple")[0] def test_password_initialization_mail(smtpd, testclient, backend, logged_admin):