refactor: move User.get_from_login method to Backend

This commit is contained in:
Éloi Rivard 2024-04-07 19:56:52 +02:00
parent 2082e19480
commit 5a6ce24074
No known key found for this signature in database
GPG key ID: 7EDA204EA57DD184
11 changed files with 36 additions and 39 deletions

View file

@ -185,6 +185,21 @@ class Backend(BaseBackend):
except ldap.SERVER_DOWN: # pragma: no cover except ldap.SERVER_DOWN: # pragma: no cover
return False return False
def get_user_from_login(self, login=None):
from .models import User
raw_filter = current_app.config["CANAILLE_LDAP"]["USER_FILTER"]
filter = (
(
current_app.jinja_env.from_string(raw_filter).render(
login=ldap.filter.escape_filter_chars(login)
)
)
if login
else None
)
return User.get(filter=filter)
def setup_ldap_models(config): def setup_ldap_models(config):
from canaille.app import models from canaille.app import models

View file

@ -40,20 +40,6 @@ class User(canaille.core.models.User, LDAPObject):
"lock_date": "pwdEndTime", "lock_date": "pwdEndTime",
} }
@classmethod
def get_from_login(cls, login=None, **kwargs):
raw_filter = current_app.config["CANAILLE_LDAP"]["USER_FILTER"]
filter = (
(
current_app.jinja_env.from_string(raw_filter).render(
login=ldap.filter.escape_filter_chars(login)
)
)
if login
else None
)
return cls.get(filter=filter, **kwargs)
def match_filter(self, filter): def match_filter(self, filter):
if isinstance(filter, str): if isinstance(filter, str):
conn = Backend.get().connection conn = Backend.get().connection

View file

@ -22,3 +22,8 @@ class Backend(BaseBackend):
def has_account_lockability(self): def has_account_lockability(self):
return True return True
def get_user_from_login(self, login):
from .models import User
return User.get(user_name=login)

View file

@ -246,10 +246,6 @@ class User(canaille.core.models.User, MemoryModel):
"groups": ("Group", "members"), "groups": ("Group", "members"),
} }
@classmethod
def get_from_login(cls, login=None, **kwargs):
return User.get(user_name=login)
def check_password(self, password): def check_password(self, password):
if password != self.password: if password != self.password:
return (False, None) return (False, None)

View file

@ -47,3 +47,8 @@ class Backend(BaseBackend):
def has_account_lockability(self): def has_account_lockability(self):
return True return True
def get_user_from_login(self, login):
from .models import User
return User.get(user_name=login)

View file

@ -171,10 +171,6 @@ class User(canaille.core.models.User, Base, SqlAlchemyModel):
TZDateTime(timezone=True), nullable=True TZDateTime(timezone=True), nullable=True
) )
@classmethod
def get_from_login(cls, login=None, **kwargs):
return User.get(user_name=login)
def check_password(self, password): def check_password(self, password):
if password != self.password: if password != self.password:
return (False, None) return (False, None)

View file

@ -8,7 +8,6 @@ from flask import session
from flask import url_for from flask import url_for
from canaille.app import build_hash from canaille.app import build_hash
from canaille.app import models
from canaille.app.flask import current_user from canaille.app.flask import current_user
from canaille.app.flask import login_user from canaille.app.flask import login_user
from canaille.app.flask import logout_user from canaille.app.flask import logout_user
@ -42,7 +41,7 @@ def login():
if not request.form or form.form_control(): if not request.form or form.form_control():
return render_template("login.html", form=form) return render_template("login.html", form=form)
user = models.User.get_from_login(form.login.data) user = BaseBackend.get().get_user_from_login(form.login.data)
if user and not user.has_password(): if user and not user.has_password():
return redirect(url_for("core.auth.firstlogin", user=user)) return redirect(url_for("core.auth.firstlogin", user=user))
@ -68,7 +67,7 @@ def password():
"password.html", form=form, username=session["attempt_login"] "password.html", form=form, username=session["attempt_login"]
) )
user = models.User.get_from_login(session["attempt_login"]) user = BaseBackend.get().get_user_from_login(session["attempt_login"])
if user and not user.has_password(): if user and not user.has_password():
return redirect(url_for("core.auth.firstlogin", user=user)) return redirect(url_for("core.auth.firstlogin", user=user))
@ -153,7 +152,7 @@ def forgotten():
flash(_("Could not send the password reset link."), "error") flash(_("Could not send the password reset link."), "error")
return render_template("forgotten-password.html", form=form) return render_template("forgotten-password.html", form=form)
user = models.User.get_from_login(form.login.data) user = BaseBackend.get().get_user_from_login(form.login.data)
success_message = _( success_message = _(
"A password reset link has been sent at your email address. " "A password reset link has been sent at your email address. "
"You should receive it within a few minutes." "You should receive it within a few minutes."

View file

@ -17,6 +17,7 @@ from canaille.app.forms import set_readonly
from canaille.app.forms import unique_values from canaille.app.forms import unique_values
from canaille.app.i18n import lazy_gettext as _ from canaille.app.i18n import lazy_gettext as _
from canaille.app.i18n import native_language_name_from_code from canaille.app.i18n import native_language_name_from_code
from canaille.backends import BaseBackend
MINIMUM_PASSWORD_LENGTH = 8 MINIMUM_PASSWORD_LENGTH = 8
@ -49,7 +50,7 @@ def unique_group(form, field):
def existing_login(form, field): def existing_login(form, field):
if not current_app.config["CANAILLE"][ if not current_app.config["CANAILLE"][
"HIDE_INVALID_LOGINS" "HIDE_INVALID_LOGINS"
] and not models.User.get_from_login(field.data): ] and not BaseBackend.get().get_user_from_login(field.data):
raise wtforms.ValidationError( raise wtforms.ValidationError(
_("The login '{login}' does not exist").format(login=field.data) _("The login '{login}' does not exist").format(login=field.data)
) )

View file

@ -244,10 +244,6 @@ class User(Model):
_writable_fields = None _writable_fields = None
_permissions = None _permissions = None
@classmethod
def get_from_login(cls, login=None, **kwargs) -> Optional["User"]:
raise NotImplementedError()
def has_password(self) -> bool: def has_password(self) -> bool:
"""Checks wether a password has been set for the user.""" """Checks wether a password has been set for the user."""
return self.password is not None return self.password is not None

View file

@ -34,6 +34,7 @@ from flask import url_for
from werkzeug.security import gen_salt from werkzeug.security import gen_salt
from canaille.app import models from canaille.app import models
from canaille.backends import BaseBackend
AUTHORIZATION_CODE_LIFETIME = 84400 AUTHORIZATION_CODE_LIFETIME = 84400
@ -266,7 +267,7 @@ class PasswordGrant(_ResourceOwnerPasswordCredentialsGrant):
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"] TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]
def authenticate_user(self, username, password): def authenticate_user(self, username, password):
user = models.User.get_from_login(username) user = BaseBackend.get().get_user_from_login(username)
if not user: if not user:
return None return None

View file

@ -5,10 +5,7 @@ from canaille.core.models import Group
from canaille.core.models import User from canaille.core.models import User
def test_required_methods(testclient): def test_required_methods(testclient, backend):
with pytest.raises(NotImplementedError):
User.get_from_login()
user = User() user = User()
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
@ -20,9 +17,9 @@ def test_required_methods(testclient):
Group() Group()
def test_user_get_from_login(testclient, user, backend): def test_user_get_user_from_login(testclient, user, backend):
assert models.User.get_from_login(login="invalid") is None assert backend.get_user_from_login(login="invalid") is None
assert models.User.get_from_login(login="user") == user assert backend.get_user_from_login(login="user") == user
def test_user_has_password(testclient, backend): def test_user_has_password(testclient, backend):