refactor: BackendModel.get() is now Backend.instance

This commit is contained in:
Éloi Rivard 2024-04-10 15:59:25 +02:00
parent fa6488bcd1
commit ccde88b1bf
No known key found for this signature in database
GPG key ID: 7EDA204EA57DD184
21 changed files with 71 additions and 66 deletions

View file

@ -99,7 +99,7 @@ def setup_flask(app):
"has_oidc": "CANAILLE_OIDC" in app.config,
"has_password_recovery": app.config["CANAILLE"]["ENABLE_PASSWORD_RECOVERY"],
"has_registration": app.config["CANAILLE"]["ENABLE_REGISTRATION"],
"has_account_lockability": app.backend.get().has_account_lockability(),
"has_account_lockability": app.backend.instance.has_account_lockability(),
"logo_url": app.config["CANAILLE"]["LOGO"],
"favicon_url": app.config["CANAILLE"]["FAVICON"]
or app.config["CANAILLE"]["LOGO"],

View file

@ -12,7 +12,7 @@ def with_backendcontext(func):
@functools.wraps(func)
def _func(*args, **kwargs):
if not current_app.config["TESTING"]: # pragma: no cover
with BaseBackend.get().session():
with BaseBackend.instance.session():
result = func(*args, **kwargs)
else:

View file

@ -162,7 +162,7 @@ def validate(config, validate_remote=False):
from canaille.backends import BaseBackend
BaseBackend.get().validate(config)
BaseBackend.instance.validate(config)
validate_smtp_configuration(config["CANAILLE"]["SMTP"])

View file

@ -187,9 +187,11 @@ class TableForm(I18NFormMixin, FlaskForm):
filter = filter or {}
super().__init__(**kwargs)
if self.query.data:
self.items = BaseBackend.get().fuzzy(cls, self.query.data, fields, **filter)
self.items = BaseBackend.instance.fuzzy(
cls, self.query.data, fields, **filter
)
else:
self.items = BaseBackend.get().query(cls, **filter)
self.items = BaseBackend.instance.query(cls, **filter)
self.page_size = page_size
self.nb_items = len(self.items)

View file

@ -8,4 +8,4 @@ class InstallationException(Exception):
def install(config, debug=False):
install_oidc(config, debug=debug)
BaseBackend.get().install(config)
BaseBackend.instance.install(config)

View file

@ -4,6 +4,8 @@ from contextlib import contextmanager
from flask import g
from canaille.app import classproperty
class BaseBackend:
instance = None
@ -13,8 +15,8 @@ class BaseBackend:
BaseBackend.instance = self
self.register_models()
@classmethod
def get(cls):
@classproperty
def instance(cls):
return cls.instance
def init_app(self, app):

View file

@ -162,7 +162,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
@classmethod
def install(cls):
conn = Backend.get().connection
conn = Backend.instance.connection
cls.ldap_object_classes(conn)
cls.ldap_object_attributes(conn)
@ -187,7 +187,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
if cls._object_class_by_name and not force:
return cls._object_class_by_name
conn = Backend.get().connection
conn = Backend.instance.connection
res = conn.search_s(
"cn=subschema", ldap.SCOPE_BASE, "(objectclass=*)", ["*", "+"]
@ -209,7 +209,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
if cls._attribute_type_by_name and not force:
return cls._attribute_type_by_name
conn = Backend.get().connection
conn = Backend.instance.connection
res = conn.search_s(
"cn=subschema", ldap.SCOPE_BASE, "(objectclass=*)", ["*", "+"]
@ -229,7 +229,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
@classmethod
def get(cls, identifier=None, /, **kwargs):
try:
return BaseBackend.get().query(cls, identifier, **kwargs)[0]
return BaseBackend.instance.query(cls, identifier, **kwargs)[0]
except (IndexError, ldap.NO_SUCH_OBJECT):
if identifier and cls.base:
return (
@ -274,13 +274,13 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
return cls.attribute_map.get(name, name) if cls.attribute_map else None
def reload(self):
conn = Backend.get().connection
conn = Backend.instance.connection
result = conn.search_s(self.dn, ldap.SCOPE_SUBTREE, None, ["+", "*"])
self.changes = {}
self.state = result[0][1]
def save(self):
conn = Backend.get().connection
conn = Backend.instance.connection
current_object_classes = self.get_ldap_attribute("objectClass") or []
self.set_ldap_attribute(
@ -338,7 +338,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
self.changes = {}
def delete(self):
conn = Backend.get().connection
conn = Backend.instance.connection
try:
conn.delete_s(self.dn)
except ldap.NO_SUCH_OBJECT:

View file

@ -38,7 +38,7 @@ class User(canaille.core.models.User, LDAPObject):
def match_filter(self, filter):
if isinstance(filter, str):
conn = Backend.get().connection
conn = Backend.instance.connection
return self.dn and conn.search_s(self.dn, ldap.SCOPE_SUBTREE, filter)
return super().match_filter(filter)

View file

@ -45,7 +45,7 @@ class MemoryModel(BackendModel):
or None
)
results = BaseBackend.get().query(cls, **kwargs)
results = BaseBackend.instance.query(cls, **kwargs)
return results[0] if results else None
@classmethod

View file

@ -74,8 +74,7 @@ class Backend(BaseBackend):
for attribute_name, expected_value in kwargs.items()
]
return (
Backend.get()
.db_session.execute(select(model).filter(*filter))
Backend.instance.db_session.execute(select(model).filter(*filter))
.scalars()
.all()
)

View file

@ -61,11 +61,9 @@ class SqlAlchemyModel(BackendModel):
cls.attribute_filter(attribute_name, expected_value)
for attribute_name, expected_value in kwargs.items()
]
return (
Backend.get()
.db_session.execute(select(cls).filter(*filter))
.scalar_one_or_none()
)
return Backend.instance.db_session.execute(
select(cls).filter(*filter)
).scalar_one_or_none()
def save(self):
self.last_modified = datetime.datetime.now(datetime.timezone.utc).replace(
@ -74,15 +72,15 @@ class SqlAlchemyModel(BackendModel):
if not self.created:
self.created = self.last_modified
Backend.get().db_session.add(self)
Backend.get().db_session.commit()
Backend.instance.db_session.add(self)
Backend.instance.db_session.commit()
def delete(self):
Backend.get().db_session.delete(self)
Backend.get().db_session.commit()
Backend.instance.db_session.delete(self)
Backend.instance.db_session.commit()
def reload(self):
Backend.get().db_session.refresh(self)
Backend.instance.db_session.refresh(self)
membership_association_table = Table(

View file

@ -86,7 +86,7 @@ def join():
form = JoinForm(request.form or None)
if request.form and form.validate():
if BaseBackend.get().query(models.User, emails=form.email.data):
if BaseBackend.instance.query(models.User, emails=form.email.data):
flash(
_(
"You will receive soon an email to continue the registration process."
@ -297,7 +297,7 @@ def registration(data=None, hash=None):
_("Groups"),
choices=[
(group, group.display_name)
for group in BaseBackend.get().query(models.Group)
for group in BaseBackend.instance.query(models.Group)
],
coerce=IDToModel("Group"),
)
@ -391,7 +391,7 @@ def email_confirmation(data, hash):
)
return redirect(url_for("core.account.index"))
if BaseBackend.get().query(models.User, emails=confirmation_obj.email):
if BaseBackend.instance.query(models.User, emails=confirmation_obj.email):
flash(
_("This address email is already associated with another account."),
"error",
@ -458,7 +458,7 @@ def profile_create(current_app, form):
user.save()
if form["password1"].data:
BaseBackend.get().set_user_password(user, form["password1"].data)
BaseBackend.instance.set_user_password(user, form["password1"].data)
user.save()
return user
@ -713,14 +713,14 @@ def profile_settings(user, edited_user):
if (
request.form.get("action") == "confirm-lock"
and BaseBackend.get().has_account_lockability()
and BaseBackend.instance.has_account_lockability()
and not edited_user.locked
):
return render_template("modals/lock-account.html", edited_user=edited_user)
if (
request.form.get("action") == "lock"
and BaseBackend.get().has_account_lockability()
and BaseBackend.instance.has_account_lockability()
and not edited_user.locked
):
flash(_("The account has been locked"), "success")
@ -731,7 +731,7 @@ def profile_settings(user, edited_user):
if (
request.form.get("action") == "unlock"
and BaseBackend.get().has_account_lockability()
and BaseBackend.instance.has_account_lockability()
and edited_user.locked
):
flash(_("The account has been unlocked"), "success")
@ -782,7 +782,9 @@ def profile_settings_edit(editor, edited_user):
and form["password1"].data
and request.form["action"] == "edit-settings"
):
BaseBackend.get().set_user_password(edited_user, form["password1"].data)
BaseBackend.instance.set_user_password(
edited_user, form["password1"].data
)
edited_user.save()
flash(_("Profile updated successfully."), "success")

View file

@ -43,12 +43,12 @@ def login():
form = LoginForm(request.form or None)
form.render_field_macro_file = "partial/login_field.html"
form["login"].render_kw["placeholder"] = BaseBackend.get().login_placeholder()
form["login"].render_kw["placeholder"] = BaseBackend.instance.login_placeholder()
if not request.form or form.form_control():
return render_template("login.html", form=form)
user = BaseBackend.get().get_user_from_login(form.login.data)
user = BaseBackend.instance.get_user_from_login(form.login.data)
if user and not user.has_password():
return redirect(url_for("core.auth.firstlogin", user=user))
@ -80,7 +80,7 @@ def password():
"password.html", form=form, username=session["attempt_login"]
)
user = BaseBackend.get().get_user_from_login(session["attempt_login"])
user = BaseBackend.instance.get_user_from_login(session["attempt_login"])
if user and not user.has_password():
return redirect(url_for("core.auth.firstlogin", user=user))
@ -91,7 +91,9 @@ def password():
"password.html", form=form, username=session["attempt_login"]
)
success, message = BaseBackend.get().check_user_password(user, form.password.data)
success, message = BaseBackend.instance.check_user_password(
user, form.password.data
)
request_ip = request.remote_addr or "unknown IP"
if not success:
logout_user()
@ -175,7 +177,7 @@ def forgotten():
flash(_("Could not send the password reset link."), "error")
return render_template("forgotten-password.html", form=form)
user = BaseBackend.get().get_user_from_login(form.login.data)
user = BaseBackend.instance.get_user_from_login(form.login.data)
success_message = _(
"A password reset link has been sent at your email address. "
"You should receive it within a few minutes."
@ -233,7 +235,7 @@ def reset(user, hash):
return redirect(url_for("core.account.index"))
if request.form and form.validate():
BaseBackend.get().set_user_password(user, form.password.data)
BaseBackend.instance.set_user_password(user, form.password.data)
login_user(user)
flash(_("Your password has been updated successfully"), "success")

View file

@ -51,7 +51,7 @@ def unique_group(form, field):
def existing_login(form, field):
if not current_app.config["CANAILLE"][
"HIDE_INVALID_LOGINS"
] and not BaseBackend.get().get_user_from_login(field.data):
] and not BaseBackend.instance.get_user_from_login(field.data):
raise wtforms.ValidationError(
_("The login '{login}' does not exist").format(login=field.data)
)
@ -314,7 +314,7 @@ PROFILE_FORM_FIELDS = dict(
default=[],
choices=lambda: [
(group, group.display_name)
for group in BaseBackend.get().query(models.Group)
for group in BaseBackend.instance.query(models.Group)
],
render_kw={"placeholder": _("users, admins …")},
coerce=IDToModel("Group"),
@ -336,10 +336,10 @@ def build_profile_form(write_field_names, readonly_field_names, user=None):
if PROFILE_FORM_FIELDS.get(name)
}
if "groups" in fields and not BaseBackend.get().query(models.Group):
if "groups" in fields and not BaseBackend.instance.query(models.Group):
del fields["groups"]
if current_app.backend.get().has_account_lockability(): # pragma: no branch
if current_app.backend.instance.has_account_lockability(): # pragma: no branch
fields["lock_date"] = DateTimeUTCField(
_("Account expiration"),
validators=[wtforms.validators.Optional()],
@ -441,7 +441,7 @@ class InvitationForm(Form):
_("Groups"),
choices=lambda: [
(group, group.display_name)
for group in BaseBackend.get().query(models.Group)
for group in BaseBackend.instance.query(models.Group)
],
render_kw={},
coerce=IDToModel("Group"),

View file

@ -48,7 +48,7 @@ def fake_users(nb=1):
def fake_groups(nb=1, nb_users_max=1):
users = BaseBackend.get().query(models.User)
users = BaseBackend.instance.query(models.User)
groups = list()
fake = faker.Faker(["en_US"])
for _ in range(nb):

View file

@ -11,11 +11,11 @@ from canaille.backends import BaseBackend
@with_backendcontext
def clean():
"""Remove expired tokens and authorization codes."""
for t in BaseBackend.get().query(models.Token):
for t in BaseBackend.instance.query(models.Token):
if t.is_expired():
t.delete()
for a in BaseBackend.get().query(models.AuthorizationCode):
for a in BaseBackend.instance.query(models.AuthorizationCode):
if a.is_expired():
a.delete()

View file

@ -20,13 +20,13 @@ bp = Blueprint("consents", __name__, url_prefix="/consent")
@bp.route("/")
@user_needed()
def consents(user):
consents = BaseBackend.get().query(models.Consent, subject=user)
consents = BaseBackend.instance.query(models.Consent, subject=user)
clients = {t.client for t in consents}
nb_consents = len(consents)
nb_preconsents = sum(
1
for client in BaseBackend.get().query(models.Client)
for client in BaseBackend.instance.query(models.Client)
if client.preconsent and client not in clients
)
@ -44,11 +44,11 @@ def consents(user):
@bp.route("/pre-consents")
@user_needed()
def pre_consents(user):
consents = BaseBackend.get().query(models.Consent, subject=user)
consents = BaseBackend.instance.query(models.Consent, subject=user)
clients = {t.client for t in consents}
preconsented = [
client
for client in BaseBackend.get().query(models.Client)
for client in BaseBackend.instance.query(models.Client)
if client.preconsent and client not in clients
]

View file

@ -21,7 +21,7 @@ class LogoutForm(Form):
def client_audiences():
return [
(client, client.client_name)
for client in BaseBackend.get().query(models.Client)
for client in BaseBackend.instance.query(models.Client)
]

View file

@ -110,7 +110,7 @@ def authorize_login(user):
def authorize_consent(client, user):
requested_scopes = request.args.get("scope", "").split(" ")
allowed_scopes = client.get_allowed_scope(requested_scopes).split(" ")
consents = BaseBackend.get().query(
consents = BaseBackend.instance.query(
models.Consent,
client=client,
subject=user,

View file

@ -96,13 +96,13 @@ class Client(BaseClient, ClientMixin):
return metadata
def delete(self):
for consent in BaseBackend.get().query(models.Consent, client=self):
for consent in BaseBackend.instance.query(models.Consent, client=self):
consent.delete()
for code in BaseBackend.get().query(models.AuthorizationCode, client=self):
for code in BaseBackend.instance.query(models.AuthorizationCode, client=self):
code.delete()
for token in BaseBackend.get().query(models.Token, client=self):
for token in BaseBackend.instance.query(models.Token, client=self):
token.delete()
super().delete()
@ -186,7 +186,7 @@ class Consent(BaseConsent):
self.revokation_date = datetime.datetime.now(datetime.timezone.utc)
self.save()
tokens = BaseBackend.get().query(
tokens = BaseBackend.instance.query(
models.Token,
client=self.client,
subject=self.subject,

View file

@ -112,7 +112,7 @@ def openid_configuration():
def exists_nonce(nonce, req):
client = models.Client.get(id=req.client_id)
exists = BaseBackend.get().query(
exists = BaseBackend.instance.query(
models.AuthorizationCode, client=client, nonce=nonce
)
return bool(exists)
@ -239,7 +239,7 @@ class AuthorizationCodeGrant(_AuthorizationCodeGrant):
return save_authorization_code(code, request)
def query_authorization_code(self, code, client):
item = BaseBackend.get().query(
item = BaseBackend.instance.query(
models.AuthorizationCode, code=code, client=client
)
if item and not item[0].is_expired():
@ -272,11 +272,11 @@ class PasswordGrant(_ResourceOwnerPasswordCredentialsGrant):
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]
def authenticate_user(self, username, password):
user = BaseBackend.get().get_user_from_login(username)
user = BaseBackend.instance.get_user_from_login(username)
if not user:
return None
success, _ = BaseBackend.get().check_user_password(user, password)
success, _ = BaseBackend.instance.check_user_password(user, password)
if not success:
return None
@ -287,7 +287,7 @@ class RefreshTokenGrant(_RefreshTokenGrant):
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]
def authenticate_refresh_token(self, refresh_token):
token = BaseBackend.get().query(models.Token, refresh_token=refresh_token)
token = BaseBackend.instance.query(models.Token, refresh_token=refresh_token)
if token and token[0].is_refresh_token_active():
return token[0]