refactor: Rename BaseBackend in Backend

This commit is contained in:
Éloi Rivard 2024-04-16 22:42:29 +02:00
parent 6ff591b91c
commit 16c3021a8f
No known key found for this signature in database
GPG key ID: 7EDA204EA57DD184
31 changed files with 186 additions and 195 deletions

View file

@ -5,14 +5,14 @@ import click
from flask import current_app
from flask.cli import with_appcontext
from canaille.backends import BaseBackend
from canaille.backends import Backend
def with_backendcontext(func):
@functools.wraps(func)
def _func(*args, **kwargs):
if not current_app.config["TESTING"]: # pragma: no cover
with BaseBackend.instance.session():
with Backend.instance.session():
result = func(*args, **kwargs)
else:

View file

@ -160,9 +160,9 @@ def validate(config, validate_remote=False):
if not validate_remote:
return
from canaille.backends import BaseBackend
from canaille.backends import Backend
BaseBackend.instance.validate(config)
Backend.instance.validate(config)
validate_smtp_configuration(config["CANAILLE"]["SMTP"])

View file

@ -15,7 +15,7 @@ from canaille.app.i18n import DEFAULT_LANGUAGE_CODE
from canaille.app.i18n import gettext as _
from canaille.app.i18n import locale_selector
from canaille.app.i18n import timezone_selector
from canaille.backends import BaseBackend
from canaille.backends import Backend
from . import validate_uri
from .flask import request_is_htmx
@ -187,11 +187,9 @@ class TableForm(I18NFormMixin, FlaskForm):
filter = filter or {}
super().__init__(**kwargs)
if self.query.data:
self.items = BaseBackend.instance.fuzzy(
cls, self.query.data, fields, **filter
)
self.items = Backend.instance.fuzzy(cls, self.query.data, fields, **filter)
else:
self.items = BaseBackend.instance.query(cls, **filter)
self.items = Backend.instance.query(cls, **filter)
self.page_size = page_size
self.nb_items = len(self.items)
@ -261,7 +259,7 @@ class IDToModel:
def __call__(self, data):
model = getattr(models, self.model_name)
instance = (
data if isinstance(data, model) else BaseBackend.instance.get(model, data)
data if isinstance(data, model) else Backend.instance.get(model, data)
)
if instance:
return instance

View file

@ -1,4 +1,4 @@
from canaille.backends import BaseBackend
from canaille.backends import Backend
from canaille.oidc.installation import install as install_oidc
@ -8,4 +8,4 @@ class InstallationException(Exception):
def install(config, debug=False):
install_oidc(config, debug=debug)
BaseBackend.instance.install(config)
Backend.instance.install(config)

View file

@ -7,12 +7,12 @@ from flask import g
from canaille.app import classproperty
class BaseBackend:
class Backend:
instance = None
def __init__(self, config):
self.config = config
BaseBackend.instance = self
Backend.instance = self
self.register_models()
@classproperty
@ -76,13 +76,13 @@ class BaseBackend:
raise NotImplementedError()
def fuzzy(self, model, query, attributes=None, **kwargs):
"""Works like :meth:`~canaille.backends.BaseBackend.query` but
attribute values loosely be matched."""
"""Works like :meth:`~canaille.backends.Backend.query` but attribute
values loosely be matched."""
raise NotImplementedError()
def get(self, model, identifier=None, **kwargs):
"""Works like :meth:`~canaille.backends.BaseBackend.query` but return
only one element or :py:data:`None` if no item is matching."""
"""Works like :meth:`~canaille.backends.Backend.query` but return only
one element or :py:data:`None` if no item is matching."""
raise NotImplementedError()
def save(self, instance):
@ -102,7 +102,7 @@ class BaseBackend:
>>> user.display_name = "Jane"
>>> user.display_name
Jane
>>> BaseBackend.instance.reload(user)
>>> Backend.instance.reload(user)
>>> user.display_name
George
"""
@ -176,7 +176,9 @@ def setup_backend(app, backend=None):
else "memory"
)
module = importlib.import_module(f"canaille.backends.{backend_name}.backend")
backend_class = getattr(module, "Backend")
backend_class = getattr(
module, f"{backend_name.title()}Backend", None
) or getattr(module, f"{backend_name.upper()}Backend", None)
backend = backend_class(app.config)
backend.init_app(app)

View file

@ -14,7 +14,7 @@ from ldap.controls.readentry import PostReadControl
from canaille.app import models
from canaille.app.configuration import ConfigurationException
from canaille.app.i18n import gettext as _
from canaille.backends import BaseBackend
from canaille.backends import Backend
from .utils import listify
from .utils import python_attrs_to_ldap
@ -53,7 +53,7 @@ def install_schema(config, schema_path):
) from exc
class Backend(BaseBackend):
class LDAPBackend(Backend):
def __init__(self, config):
super().__init__(config)
self._connection = None
@ -129,8 +129,8 @@ class Backend(BaseBackend):
emails=f"canaille_{uuid.uuid4()}@mydomain.tld",
password="correct horse battery staple",
)
BaseBackend.instance.save(user)
BaseBackend.instance.delete(user)
Backend.instance.save(user)
Backend.instance.delete(user)
except ldap.INSUFFICIENT_ACCESS as exc:
raise ConfigurationException(
@ -148,14 +148,14 @@ class Backend(BaseBackend):
emails=f"canaille_{uuid.uuid4()}@mydomain.tld",
password="correct horse battery staple",
)
BaseBackend.instance.save(user)
Backend.instance.save(user)
group = models.Group(
display_name=f"canaille_{uuid.uuid4()}",
members=[user],
)
BaseBackend.instance.save(group)
BaseBackend.instance.delete(group)
Backend.instance.save(group)
Backend.instance.delete(group)
except ldap.INSUFFICIENT_ACCESS as exc:
raise ConfigurationException(
@ -164,7 +164,7 @@ class Backend(BaseBackend):
) from exc
finally:
BaseBackend.instance.delete(user)
Backend.instance.delete(user)
@classmethod
def login_placeholder(cls):

View file

@ -5,7 +5,7 @@ import ldap.filter
from canaille.backends.models import BackendModel
from .backend import Backend
from .backend import LDAPBackend
from .utils import attribute_ldap_syntax
from .utils import cardinalize_attribute
from .utils import ldap_to_python
@ -160,7 +160,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
@classmethod
def install(cls):
conn = Backend.instance.connection
conn = LDAPBackend.instance.connection
cls.ldap_object_classes(conn)
cls.ldap_object_attributes(conn)
@ -185,7 +185,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
if cls._object_class_by_name and not force:
return cls._object_class_by_name
conn = Backend.instance.connection
conn = LDAPBackend.instance.connection
res = conn.search_s(
"cn=subschema", ldap.SCOPE_BASE, "(objectclass=*)", ["*", "+"]
@ -207,7 +207,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
if cls._attribute_type_by_name and not force:
return cls._attribute_type_by_name
conn = Backend.instance.connection
conn = LDAPBackend.instance.connection
res = conn.search_s(
"cn=subschema", ldap.SCOPE_BASE, "(objectclass=*)", ["*", "+"]

View file

@ -3,7 +3,7 @@ import ldap.filter
import canaille.core.models
import canaille.oidc.models
from .backend import Backend
from .backend import LDAPBackend
from .ldapobject import LDAPObject
@ -38,7 +38,7 @@ class User(canaille.core.models.User, LDAPObject):
def match_filter(self, filter):
if isinstance(filter, str):
conn = Backend.instance.connection
conn = LDAPBackend.instance.connection
return self.dn and conn.search_s(self.dn, ldap.SCOPE_SUBTREE, filter)
return super().match_filter(filter)
@ -64,7 +64,7 @@ class User(canaille.core.models.User, LDAPObject):
for group in to_add:
group.members = group.members + [self]
Backend.instance.save(group)
LDAPBackend.instance.save(group)
for group in to_del:
# LDAP groups cannot be empty because groupOfNames.member
@ -73,7 +73,7 @@ class User(canaille.core.models.User, LDAPObject):
# TODO: properly manage the situation where one wants to
# remove the last member of a group
group.members = [member for member in group.members if member != self]
Backend.instance.save(group)
LDAPBackend.instance.save(group)
self.state[group_attr] = new_groups

View file

@ -1,7 +1,7 @@
import datetime
from enum import Enum
from canaille.backends import BaseBackend
from canaille.backends import Backend
LDAP_NULL_DATE = "000001010000Z"
@ -52,7 +52,7 @@ def ldap_to_python(value, syntax):
return value.decode("utf-8").upper() == "TRUE"
if syntax == Syntax.DISTINGUISHED_NAME:
return BaseBackend.instance.get(LDAPObject, value.decode("utf-8"))
return Backend.instance.get(LDAPObject, value.decode("utf-8"))
return value.decode("utf-8")

View file

@ -1,10 +1,10 @@
import datetime
import uuid
from canaille.backends import BaseBackend
from canaille.backends import Backend
class Backend(BaseBackend):
class MemoryBackend(Backend):
@classmethod
def install(cls, config):
pass
@ -124,7 +124,7 @@ class Backend(BaseBackend):
reload_callback = instance.reload() if hasattr(instance, "reload") else iter([])
next(reload_callback, None)
instance._state = BaseBackend.instance.get(
instance._state = Backend.instance.get(
instance.__class__, id=instance.id
)._state
instance._cache = {}

View file

@ -4,7 +4,7 @@ import typing
import canaille.core.models
import canaille.oidc.models
from canaille.app import models
from canaille.backends import BaseBackend
from canaille.backends import Backend
from canaille.backends.models import BackendModel
@ -61,7 +61,7 @@ class MemoryModel(BackendModel):
model, _ = cls.get_model_annotations(attribute_name)
if model and not isinstance(value, model):
backend_model = getattr(models, model.__name__)
return BaseBackend.instance.get(backend_model, id=value)
return Backend.instance.get(backend_model, id=value)
return value
@ -139,7 +139,7 @@ class MemoryModel(BackendModel):
return False
if not isinstance(other, MemoryModel):
return self == BaseBackend.instance.get(self.__class__, id=other)
return self == Backend.instance.get(self.__class__, id=other)
return self._state == other._state

View file

@ -11,7 +11,7 @@ from typing import get_type_hints
from canaille.app import classproperty
from canaille.app import models
from canaille.backends import BaseBackend
from canaille.backends import Backend
class Model:
@ -133,7 +133,7 @@ class BackendModel:
backend_model = getattr(models, model.__name__)
if instance := BaseBackend.instance.get(backend_model, value):
if instance := Backend.instance.get(backend_model, value):
filter[attribute] = instance
return all(

View file

@ -6,7 +6,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import declarative_base
from canaille.backends import BaseBackend
from canaille.backends import Backend
Base = declarative_base()
@ -19,7 +19,7 @@ def db_session(db_uri=None, init=False):
return session
class Backend(BaseBackend):
class SQLBackend(Backend):
db_session = None
@classmethod
@ -76,7 +76,7 @@ class Backend(BaseBackend):
for attribute_name, expected_value in kwargs.items()
]
return (
Backend.instance.db_session.execute(select(model).filter(*filter))
SQLBackend.instance.db_session.execute(select(model).filter(*filter))
.scalars()
.all()
)
@ -105,7 +105,7 @@ class Backend(BaseBackend):
model.attribute_filter(attribute_name, expected_value)
for attribute_name, expected_value in kwargs.items()
]
return Backend.instance.db_session.execute(
return SQLBackend.instance.db_session.execute(
select(model).filter(*filter)
).scalar_one_or_none()
@ -116,16 +116,16 @@ class Backend(BaseBackend):
if not instance.created:
instance.created = instance.last_modified
Backend.instance.db_session.add(instance)
Backend.instance.db_session.commit()
SQLBackend.instance.db_session.add(instance)
SQLBackend.instance.db_session.commit()
def delete(self, instance):
# run the instance delete callback if existing
save_callback = instance.delete() if hasattr(instance, "delete") else iter([])
next(save_callback, None)
Backend.instance.db_session.delete(instance)
Backend.instance.db_session.commit()
SQLBackend.instance.db_session.delete(instance)
SQLBackend.instance.db_session.commit()
# run the instance delete callback again if existing
next(save_callback, None)
@ -135,7 +135,7 @@ class Backend(BaseBackend):
reload_callback = instance.reload() if hasattr(instance, "reload") else iter([])
next(reload_callback, None)
Backend.instance.db_session.refresh(instance)
SQLBackend.instance.db_session.refresh(instance)
# run the instance reload callback again if existing
next(reload_callback, None)

View file

@ -41,7 +41,7 @@ from canaille.app.forms import set_writable
from canaille.app.i18n import gettext as _
from canaille.app.i18n import reload_translations
from canaille.app.themes import render_template
from canaille.backends import BaseBackend
from canaille.backends import Backend
from ..mails import send_confirmation_email
from ..mails import send_invitation_mail
@ -86,7 +86,7 @@ def join():
form = JoinForm(request.form or None)
if request.form and form.validate():
if BaseBackend.instance.query(models.User, emails=form.email.data):
if Backend.instance.query(models.User, emails=form.email.data):
flash(
_(
"You will receive soon an email to continue the registration process."
@ -257,7 +257,7 @@ def registration(data=None, hash=None):
)
return redirect(url_for("core.account.index"))
if payload.user_name and BaseBackend.instance.get(
if payload.user_name and Backend.instance.get(
models.User, user_name=payload.user_name
):
flash(
@ -285,7 +285,7 @@ def registration(data=None, hash=None):
"user_name": payload.user_name,
"emails": [payload.email],
"groups": [
BaseBackend.instance.get(models.Group, id=group_id)
Backend.instance.get(models.Group, id=group_id)
for group_id in payload.groups
],
}
@ -302,7 +302,7 @@ def registration(data=None, hash=None):
_("Groups"),
choices=[
(group, group.display_name)
for group in BaseBackend.instance.query(models.Group)
for group in Backend.instance.query(models.Group)
],
coerce=IDToModel("Group"),
)
@ -381,7 +381,7 @@ def email_confirmation(data, hash):
)
return redirect(url_for("core.account.index"))
user = BaseBackend.instance.get(models.User, confirmation_obj.identifier)
user = Backend.instance.get(models.User, confirmation_obj.identifier)
if not user:
flash(
_("The email confirmation link that brought you here is invalid."),
@ -396,7 +396,7 @@ def email_confirmation(data, hash):
)
return redirect(url_for("core.account.index"))
if BaseBackend.instance.query(models.User, emails=confirmation_obj.email):
if Backend.instance.query(models.User, emails=confirmation_obj.email):
flash(
_("This address email is already associated with another account."),
"error",
@ -404,7 +404,7 @@ def email_confirmation(data, hash):
return redirect(url_for("core.account.index"))
user.emails = user.emails + [confirmation_obj.email]
BaseBackend.instance.save(user)
Backend.instance.save(user)
flash(_("Your email address have been confirmed."), "success")
return redirect(url_for("core.account.index"))
@ -460,11 +460,11 @@ def profile_create(current_app, form):
given_name = user.given_name if user.given_name else ""
family_name = user.family_name if user.family_name else ""
user.formatted_name = f"{given_name} {family_name}".strip()
BaseBackend.instance.save(user)
Backend.instance.save(user)
if form["password1"].data:
BaseBackend.instance.set_user_password(user, form["password1"].data)
BaseBackend.instance.save(user)
Backend.instance.set_user_password(user, form["password1"].data)
Backend.instance.save(user)
return user
@ -536,8 +536,8 @@ def profile_edition_main_form_validation(user, edited_user, profile_form):
if profile_form["preferred_language"].data == "auto":
edited_user.preferred_language = None
BaseBackend.instance.save(edited_user)
BaseBackend.instance.reload(g.user)
Backend.instance.save(edited_user)
Backend.instance.reload(g.user)
def profile_edition_emails_form(user, edited_user, has_smtp):
@ -574,7 +574,7 @@ def profile_edition_remove_email(user, edited_user, email):
return False
edited_user.emails = [m for m in edited_user.emails if m != email]
BaseBackend.instance.save(edited_user)
Backend.instance.save(edited_user)
return True
@ -718,30 +718,30 @@ def profile_settings(user, edited_user):
if (
request.form.get("action") == "confirm-lock"
and BaseBackend.instance.has_account_lockability()
and Backend.instance.has_account_lockability()
and not edited_user.locked
):
return render_template("modals/lock-account.html", edited_user=edited_user)
if (
request.form.get("action") == "lock"
and BaseBackend.instance.has_account_lockability()
and Backend.instance.has_account_lockability()
and not edited_user.locked
):
flash(_("The account has been locked"), "success")
edited_user.lock_date = datetime.datetime.now(datetime.timezone.utc)
BaseBackend.instance.save(edited_user)
Backend.instance.save(edited_user)
return profile_settings_edit(user, edited_user)
if (
request.form.get("action") == "unlock"
and BaseBackend.instance.has_account_lockability()
and Backend.instance.has_account_lockability()
and edited_user.locked
):
flash(_("The account has been unlocked"), "success")
edited_user.lock_date = None
BaseBackend.instance.save(edited_user)
Backend.instance.save(edited_user)
return profile_settings_edit(user, edited_user)
@ -787,11 +787,9 @@ def profile_settings_edit(editor, edited_user):
and form["password1"].data
and request.form["action"] == "edit-settings"
):
BaseBackend.instance.set_user_password(
edited_user, form["password1"].data
)
Backend.instance.set_user_password(edited_user, form["password1"].data)
BaseBackend.instance.save(edited_user)
Backend.instance.save(edited_user)
flash(_("Profile updated successfully."), "success")
return redirect(
url_for("core.account.profile_settings", edited_user=edited_user)
@ -818,7 +816,7 @@ def profile_delete(user, edited_user):
),
"success",
)
BaseBackend.instance.delete(edited_user)
Backend.instance.delete(edited_user)
if self_deletion:
return redirect(url_for("core.account.index"))

View file

@ -14,7 +14,7 @@ from canaille.app.flask import logout_user
from canaille.app.flask import smtp_needed
from canaille.app.i18n import gettext as _
from canaille.app.themes import render_template
from canaille.backends import BaseBackend
from canaille.backends import Backend
from ..mails import send_password_initialization_mail
from ..mails import send_password_reset_mail
@ -43,12 +43,12 @@ def login():
form = LoginForm(request.form or None)
form.render_field_macro_file = "partial/login_field.html"
form["login"].render_kw["placeholder"] = BaseBackend.instance.login_placeholder()
form["login"].render_kw["placeholder"] = Backend.instance.login_placeholder()
if not request.form or form.form_control():
return render_template("login.html", form=form)
user = BaseBackend.instance.get_user_from_login(form.login.data)
user = Backend.instance.get_user_from_login(form.login.data)
if user and not user.has_password():
return redirect(url_for("core.auth.firstlogin", user=user))
@ -80,7 +80,7 @@ def password():
"password.html", form=form, username=session["attempt_login"]
)
user = BaseBackend.instance.get_user_from_login(session["attempt_login"])
user = Backend.instance.get_user_from_login(session["attempt_login"])
if user and not user.has_password():
return redirect(url_for("core.auth.firstlogin", user=user))
@ -91,9 +91,7 @@ def password():
"password.html", form=form, username=session["attempt_login"]
)
success, message = BaseBackend.instance.check_user_password(
user, form.password.data
)
success, message = Backend.instance.check_user_password(user, form.password.data)
request_ip = request.remote_addr or "unknown IP"
if not success:
logout_user()
@ -177,7 +175,7 @@ def forgotten():
flash(_("Could not send the password reset link."), "error")
return render_template("forgotten-password.html", form=form)
user = BaseBackend.instance.get_user_from_login(form.login.data)
user = Backend.instance.get_user_from_login(form.login.data)
success_message = _(
"A password reset link has been sent at your email address. "
"You should receive it within a few minutes."
@ -235,7 +233,7 @@ def reset(user, hash):
return redirect(url_for("core.account.index"))
if request.form and form.validate():
BaseBackend.instance.set_user_password(user, form.password.data)
Backend.instance.set_user_password(user, form.password.data)
login_user(user)
flash(_("Your password has been updated successfully"), "success")

View file

@ -18,13 +18,13 @@ from canaille.app.forms import unique_values
from canaille.app.i18n import gettext
from canaille.app.i18n import lazy_gettext as _
from canaille.app.i18n import native_language_name_from_code
from canaille.backends import BaseBackend
from canaille.backends import Backend
MINIMUM_PASSWORD_LENGTH = 8
def unique_user_name(form, field):
if BaseBackend.instance.get(models.User, user_name=field.data) and (
if Backend.instance.get(models.User, user_name=field.data) and (
not getattr(form, "user", None) or form.user.user_name != field.data
):
raise wtforms.ValidationError(
@ -33,7 +33,7 @@ def unique_user_name(form, field):
def unique_email(form, field):
if BaseBackend.instance.get(models.User, emails=field.data) and (
if Backend.instance.get(models.User, emails=field.data) and (
not getattr(form, "user", None) or field.data not in form.user.emails
):
raise wtforms.ValidationError(
@ -42,7 +42,7 @@ def unique_email(form, field):
def unique_group(form, field):
if BaseBackend.instance.get(models.Group, display_name=field.data):
if Backend.instance.get(models.Group, display_name=field.data):
raise wtforms.ValidationError(
_("The group '{group}' already exists").format(group=field.data)
)
@ -51,7 +51,7 @@ def unique_group(form, field):
def existing_login(form, field):
if not current_app.config["CANAILLE"][
"HIDE_INVALID_LOGINS"
] and not BaseBackend.instance.get_user_from_login(field.data):
] and not Backend.instance.get_user_from_login(field.data):
raise wtforms.ValidationError(
_("The login '{login}' does not exist").format(login=field.data)
)
@ -314,7 +314,7 @@ PROFILE_FORM_FIELDS = dict(
default=[],
choices=lambda: [
(group, group.display_name)
for group in BaseBackend.instance.query(models.Group)
for group in Backend.instance.query(models.Group)
],
render_kw={"placeholder": _("users, admins …")},
coerce=IDToModel("Group"),
@ -336,7 +336,7 @@ def build_profile_form(write_field_names, readonly_field_names, user=None):
if PROFILE_FORM_FIELDS.get(name)
}
if "groups" in fields and not BaseBackend.instance.query(models.Group):
if "groups" in fields and not Backend.instance.query(models.Group):
del fields["groups"]
if current_app.backend.instance.has_account_lockability(): # pragma: no branch
@ -441,7 +441,7 @@ class InvitationForm(Form):
_("Groups"),
choices=lambda: [
(group, group.display_name)
for group in BaseBackend.instance.query(models.Group)
for group in Backend.instance.query(models.Group)
],
render_kw={},
coerce=IDToModel("Group"),

View file

@ -11,7 +11,7 @@ from canaille.app.flask import render_htmx_template
from canaille.app.forms import TableForm
from canaille.app.i18n import gettext as _
from canaille.app.themes import render_template
from canaille.backends import BaseBackend
from canaille.backends import Backend
from .forms import CreateGroupForm
from .forms import DeleteGroupMemberForm
@ -43,7 +43,7 @@ def create_group(user):
group.members = [user]
group.display_name = form.display_name.data
group.description = form.description.data
BaseBackend.instance.save(group)
Backend.instance.save(group)
flash(
_(
"The group %(group)s has been sucessfully created",
@ -103,7 +103,7 @@ def edit_group(group):
):
if form.validate():
group.description = form.description.data
BaseBackend.instance.save(group)
Backend.instance.save(group)
flash(
_(
"The group %(group)s has been sucessfully edited.",
@ -162,5 +162,5 @@ def delete_group(group):
_("The group %(group)s has been sucessfully deleted", group=group.display_name),
"success",
)
BaseBackend.instance.delete(group)
Backend.instance.delete(group)
return redirect(url_for("core.groups.groups"))

View file

@ -5,7 +5,7 @@ from faker.config import AVAILABLE_LOCALES
from canaille.app import models
from canaille.app.i18n import available_language_codes
from canaille.backends import BaseBackend
from canaille.backends import Backend
def fake_users(nb=1):
@ -40,7 +40,7 @@ def fake_users(nb=1):
password=fake.password(),
preferred_language=fake._locales[0],
)
BaseBackend.instance.save(user)
Backend.instance.save(user)
users.append(user)
except Exception: # pragma: no cover
pass
@ -48,7 +48,7 @@ def fake_users(nb=1):
def fake_groups(nb=1, nb_users_max=1):
users = BaseBackend.instance.query(models.User)
users = Backend.instance.query(models.User)
groups = list()
fake = faker.Faker(["en_US"])
for _ in range(nb):
@ -59,7 +59,7 @@ def fake_groups(nb=1, nb_users_max=1):
)
nb_users = random.randrange(1, nb_users_max + 1)
group.members = list({random.choice(users) for _ in range(nb_users)})
BaseBackend.instance.save(group)
Backend.instance.save(group)
groups.append(group)
except Exception: # pragma: no cover
pass

View file

@ -3,7 +3,7 @@ from flask.cli import with_appcontext
from canaille.app import models
from canaille.app.commands import with_backendcontext
from canaille.backends import BaseBackend
from canaille.backends import Backend
@click.command()
@ -11,13 +11,13 @@ from canaille.backends import BaseBackend
@with_backendcontext
def clean():
"""Remove expired tokens and authorization codes."""
for t in BaseBackend.instance.query(models.Token):
for t in Backend.instance.query(models.Token):
if t.is_expired():
BaseBackend.instance.delete(t)
Backend.instance.delete(t)
for a in BaseBackend.instance.query(models.AuthorizationCode):
for a in Backend.instance.query(models.AuthorizationCode):
if a.is_expired():
BaseBackend.instance.delete(a)
Backend.instance.delete(a)
def register(cli):

View file

@ -14,7 +14,7 @@ from canaille.app.flask import render_htmx_template
from canaille.app.forms import TableForm
from canaille.app.i18n import gettext as _
from canaille.app.themes import render_template
from canaille.backends import BaseBackend
from canaille.backends import Backend
from .forms import ClientAddForm
@ -74,9 +74,9 @@ def add(user):
if form["token_endpoint_auth_method"].data == "none"
else gen_salt(48),
)
BaseBackend.instance.save(client)
Backend.instance.save(client)
client.audience = [client]
BaseBackend.instance.save(client)
Backend.instance.save(client)
flash(
_("The client has been created."),
"success",
@ -118,7 +118,7 @@ def client_edit(client):
"client_edit.html", form=form, client=client, menuitem="admin"
)
BaseBackend.instance.update(
Backend.instance.update(
client,
client_name=form["client_name"].data,
contacts=form["contacts"].data,
@ -139,7 +139,7 @@ def client_edit(client):
audience=form["audience"].data,
preconsent=form["preconsent"].data,
)
BaseBackend.instance.save(client)
Backend.instance.save(client)
flash(
_("The client has been edited."),
"success",
@ -152,5 +152,5 @@ def client_delete(client):
_("The client has been deleted."),
"success",
)
BaseBackend.instance.delete(client)
Backend.instance.delete(client)
return redirect(url_for("oidc.clients.index"))

View file

@ -10,7 +10,7 @@ from canaille.app import models
from canaille.app.flask import user_needed
from canaille.app.i18n import gettext as _
from canaille.app.themes import render_template
from canaille.backends import BaseBackend
from canaille.backends import Backend
from ..utils import SCOPE_DETAILS
@ -20,13 +20,13 @@ bp = Blueprint("consents", __name__, url_prefix="/consent")
@bp.route("/")
@user_needed()
def consents(user):
consents = BaseBackend.instance.query(models.Consent, subject=user)
consents = Backend.instance.query(models.Consent, subject=user)
clients = {t.client for t in consents}
nb_consents = len(consents)
nb_preconsents = sum(
1
for client in BaseBackend.instance.query(models.Client)
for client in Backend.instance.query(models.Client)
if client.preconsent and client not in clients
)
@ -44,11 +44,11 @@ def consents(user):
@bp.route("/pre-consents")
@user_needed()
def pre_consents(user):
consents = BaseBackend.instance.query(models.Consent, subject=user)
consents = Backend.instance.query(models.Consent, subject=user)
clients = {t.client for t in consents}
preconsented = [
client
for client in BaseBackend.instance.query(models.Client)
for client in Backend.instance.query(models.Client)
if client.preconsent and client not in clients
]
@ -95,7 +95,7 @@ def restore(user, consent):
consent.restore()
if not consent.issue_date:
consent.issue_date = datetime.datetime.now(datetime.timezone.utc)
BaseBackend.instance.save(consent)
Backend.instance.save(consent)
flash(_("The access has been restored"), "success")
return redirect(url_for("oidc.consents.consents"))
@ -108,7 +108,7 @@ def revoke_preconsent(user, client):
flash(_("Could not revoke this access"), "error")
return redirect(url_for("oidc.consents.consents"))
consent = BaseBackend.instance.get(models.Consent, client=client, subject=user)
consent = Backend.instance.get(models.Consent, client=client, subject=user)
if consent:
return redirect(url_for("oidc.consents.revoke", consent=consent))
@ -119,7 +119,7 @@ def revoke_preconsent(user, client):
scope=client.scope,
)
consent.revoke()
BaseBackend.instance.save(consent)
Backend.instance.save(consent)
flash(_("The access has been revoked"), "success")
return redirect(url_for("oidc.consents.consents"))

View file

@ -7,7 +7,7 @@ from canaille.app.forms import email_validator
from canaille.app.forms import is_uri
from canaille.app.forms import unique_values
from canaille.app.i18n import lazy_gettext as _
from canaille.backends import BaseBackend
from canaille.backends import Backend
class AuthorizeForm(Form):
@ -20,8 +20,7 @@ class LogoutForm(Form):
def client_audiences():
return [
(client, client.client_name)
for client in BaseBackend.instance.query(models.Client)
(client, client.client_name) for client in Backend.instance.query(models.Client)
]

View file

@ -23,7 +23,7 @@ from canaille.app.flask import logout_user
from canaille.app.flask import set_parameter_in_url_query
from canaille.app.i18n import gettext as _
from canaille.app.themes import render_template
from canaille.backends import BaseBackend
from canaille.backends import Backend
from ..oauth import ClientConfigurationEndpoint
from ..oauth import ClientRegistrationEndpoint
@ -50,9 +50,7 @@ def authorize():
request.form.to_dict(flat=False),
)
client = BaseBackend.instance.get(
models.Client, client_id=request.args["client_id"]
)
client = Backend.instance.get(models.Client, client_id=request.args["client_id"])
user = current_user()
if response := authorize_guards(client):
@ -112,7 +110,7 @@ def authorize_login(user):
def authorize_consent(client, user):
requested_scopes = request.args.get("scope", "").split(" ")
allowed_scopes = client.get_allowed_scope(requested_scopes).split(" ")
consents = BaseBackend.instance.query(
consents = Backend.instance.query(
models.Consent,
client=client,
subject=user,
@ -177,7 +175,7 @@ def authorize_consent(client, user):
scope=allowed_scopes,
issue_date=datetime.datetime.now(datetime.timezone.utc),
)
BaseBackend.instance.save(consent)
Backend.instance.save(consent)
response = authorization.create_authorization_response(grant_user=grant_user)
current_app.logger.debug("authorization endpoint response: %s", response.location)
@ -278,7 +276,7 @@ def end_session():
valid_uris = []
if "client_id" in data:
client = BaseBackend.instance.get(models.Client, client_id=data["client_id"])
client = Backend.instance.get(models.Client, client_id=data["client_id"])
if client:
valid_uris = client.post_logout_redirect_uris
@ -330,7 +328,7 @@ def end_session():
else [id_token["aud"]]
)
for client_id in client_ids:
client = BaseBackend.instance.get(models.Client, client_id=client_id)
client = Backend.instance.get(models.Client, client_id=client_id)
if client:
valid_uris.extend(client.post_logout_redirect_uris or [])

View file

@ -11,7 +11,7 @@ from canaille.app.flask import render_htmx_template
from canaille.app.forms import TableForm
from canaille.app.i18n import gettext as _
from canaille.app.themes import render_template
from canaille.backends import BaseBackend
from canaille.backends import Backend
from .forms import TokenRevokationForm
@ -41,7 +41,7 @@ def view(user, token):
elif request.form.get("action") == "revoke":
token.revokation_date = datetime.datetime.now(datetime.timezone.utc)
BaseBackend.instance.save(token)
Backend.instance.save(token)
flash(_("The token has successfully been revoked."), "success")
else:

View file

@ -8,7 +8,7 @@ from authlib.oauth2.rfc6749 import TokenMixin
from authlib.oauth2.rfc6749 import util
from canaille.app import models
from canaille.backends import BaseBackend
from canaille.backends import Backend
from .basemodels import AuthorizationCode as BaseAuthorizationCode
from .basemodels import Client as BaseClient
@ -96,14 +96,14 @@ class Client(BaseClient, ClientMixin):
return metadata
def delete(self):
for consent in BaseBackend.instance.query(models.Consent, client=self):
BaseBackend.instance.delete(consent)
for consent in Backend.instance.query(models.Consent, client=self):
Backend.instance.delete(consent)
for code in BaseBackend.instance.query(models.AuthorizationCode, client=self):
BaseBackend.instance.delete(code)
for code in Backend.instance.query(models.AuthorizationCode, client=self):
Backend.instance.delete(code)
for token in BaseBackend.instance.query(models.Token, client=self):
BaseBackend.instance.delete(token)
for token in Backend.instance.query(models.Token, client=self):
Backend.instance.delete(token)
yield
@ -184,9 +184,9 @@ class Consent(BaseConsent):
def revoke(self):
self.revokation_date = datetime.datetime.now(datetime.timezone.utc)
BaseBackend.instance.save(self)
Backend.instance.save(self)
tokens = BaseBackend.instance.query(
tokens = Backend.instance.query(
models.Token,
client=self.client,
subject=self.subject,
@ -194,8 +194,8 @@ class Consent(BaseConsent):
tokens = [token for token in tokens if not token.revoked]
for t in tokens:
t.revokation_date = self.revokation_date
BaseBackend.instance.save(t)
Backend.instance.save(t)
def restore(self):
self.revokation_date = None
BaseBackend.instance.save(self)
Backend.instance.save(self)

View file

@ -34,7 +34,7 @@ from flask import url_for
from werkzeug.security import gen_salt
from canaille.app import models
from canaille.backends import BaseBackend
from canaille.backends import Backend
AUTHORIZATION_CODE_LIFETIME = 84400
@ -111,8 +111,8 @@ def openid_configuration():
def exists_nonce(nonce, req):
client = BaseBackend.instance.get(models.Client, id=req.client_id)
exists = BaseBackend.instance.query(
client = Backend.instance.get(models.Client, id=req.client_id)
exists = Backend.instance.query(
models.AuthorizationCode, client=client, nonce=nonce
)
return bool(exists)
@ -228,7 +228,7 @@ def save_authorization_code(code, request):
challenge=request.data.get("code_challenge"),
challenge_method=request.data.get("code_challenge_method"),
)
BaseBackend.instance.save(code)
Backend.instance.save(code)
return code.code
@ -239,14 +239,14 @@ class AuthorizationCodeGrant(_AuthorizationCodeGrant):
return save_authorization_code(code, request)
def query_authorization_code(self, code, client):
item = BaseBackend.instance.query(
item = Backend.instance.query(
models.AuthorizationCode, code=code, client=client
)
if item and not item[0].is_expired():
return item[0]
def delete_authorization_code(self, authorization_code):
BaseBackend.instance.delete(authorization_code)
Backend.instance.delete(authorization_code)
def authenticate_user(self, authorization_code):
if authorization_code.subject and not authorization_code.subject.locked:
@ -272,11 +272,11 @@ class PasswordGrant(_ResourceOwnerPasswordCredentialsGrant):
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]
def authenticate_user(self, username, password):
user = BaseBackend.instance.get_user_from_login(username)
user = Backend.instance.get_user_from_login(username)
if not user:
return None
success, _ = BaseBackend.instance.check_user_password(user, password)
success, _ = Backend.instance.check_user_password(user, password)
if not success:
return None
@ -287,7 +287,7 @@ class RefreshTokenGrant(_RefreshTokenGrant):
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]
def authenticate_refresh_token(self, refresh_token):
token = BaseBackend.instance.query(models.Token, refresh_token=refresh_token)
token = Backend.instance.query(models.Token, refresh_token=refresh_token)
if token and token[0].is_refresh_token_active():
return token[0]
@ -297,7 +297,7 @@ class RefreshTokenGrant(_RefreshTokenGrant):
def revoke_old_credential(self, credential):
credential.revokation_date = datetime.datetime.now(datetime.timezone.utc)
BaseBackend.instance.save(credential)
Backend.instance.save(credential)
class OpenIDImplicitGrant(_OpenIDImplicitGrant):
@ -334,7 +334,7 @@ class OpenIDHybridGrant(_OpenIDHybridGrant):
def query_client(client_id):
return BaseBackend.instance.get(models.Client, client_id=client_id)
return Backend.instance.get(models.Client, client_id=client_id)
def save_token(token, request):
@ -351,25 +351,25 @@ def save_token(token, request):
subject=request.user,
audience=request.client.audience,
)
BaseBackend.instance.save(t)
Backend.instance.save(t)
class BearerTokenValidator(_BearerTokenValidator):
def authenticate_token(self, token_string):
return BaseBackend.instance.get(models.Token, access_token=token_string)
return Backend.instance.get(models.Token, access_token=token_string)
def query_token(token, token_type_hint):
if token_type_hint == "access_token":
return BaseBackend.instance.get(models.Token, access_token=token)
return Backend.instance.get(models.Token, access_token=token)
elif token_type_hint == "refresh_token":
return BaseBackend.instance.get(models.Token, refresh_token=token)
return Backend.instance.get(models.Token, refresh_token=token)
item = BaseBackend.instance.get(models.Token, access_token=token)
item = Backend.instance.get(models.Token, access_token=token)
if item:
return item
item = BaseBackend.instance.get(models.Token, refresh_token=token)
item = Backend.instance.get(models.Token, refresh_token=token)
if item:
return item
@ -382,7 +382,7 @@ class RevocationEndpoint(_RevocationEndpoint):
def revoke_token(self, token, request):
token.revokation_date = datetime.datetime.now(datetime.timezone.utc)
BaseBackend.instance.save(token)
Backend.instance.save(token)
class IntrospectionEndpoint(_IntrospectionEndpoint):
@ -463,16 +463,16 @@ class ClientRegistrationEndpoint(ClientManagementMixin, _ClientRegistrationEndpo
post_logout_redirect_uris=request.data.get("post_logout_redirect_uris"),
**self.client_convert_data(**client_info, **client_metadata),
)
BaseBackend.instance.save(client)
Backend.instance.save(client)
client.audience = [client]
BaseBackend.instance.save(client)
Backend.instance.save(client)
return client
class ClientConfigurationEndpoint(ClientManagementMixin, _ClientConfigurationEndpoint):
def authenticate_client(self, request):
client_id = request.uri.split("/")[-1]
return BaseBackend.instance.get(models.Client, client_id=client_id)
return Backend.instance.get(models.Client, client_id=client_id)
def revoke_access_token(self, request, token):
pass
@ -481,13 +481,11 @@ class ClientConfigurationEndpoint(ClientManagementMixin, _ClientConfigurationEnd
return True
def delete_client(self, client, request):
BaseBackend.instance.delete(client)
Backend.instance.delete(client)
def update_client(self, client, client_metadata, request):
BaseBackend.instance.update(
client, **self.client_convert_data(**client_metadata)
)
BaseBackend.instance.save(client)
Backend.instance.update(client, **self.client_convert_data(**client_metadata))
Backend.instance.save(client)
return client
def generate_client_registration_info(self, client, request):

View file

@ -1,7 +1,7 @@
import pytest
from canaille.app.configuration import settings_factory
from canaille.backends.ldap.backend import Backend
from canaille.backends.ldap.backend import LDAPBackend
from tests.backends.ldap import CustomSlapdObject
@ -47,6 +47,6 @@ def ldap_configuration(configuration, slapd_server):
def ldap_backend(slapd_server, ldap_configuration):
config_obj = settings_factory(ldap_configuration)
config_dict = config_obj.model_dump()
backend = Backend(config_dict)
backend = LDAPBackend(config_dict)
with backend.session():
yield backend

View file

@ -4,7 +4,7 @@ from flask_webtest import TestApp
from canaille import create_app
from canaille.app.configuration import settings_factory
from canaille.app.installation import InstallationException
from canaille.backends.ldap.backend import Backend
from canaille.backends.ldap.backend import LDAPBackend
from canaille.backends.ldap.ldapobject import LDAPObject
from canaille.commands import cli
@ -54,12 +54,12 @@ def test_install_schemas(configuration, slapd_server):
config_obj = settings_factory(configuration)
config_dict = config_obj.model_dump()
with Backend(config_dict).session():
with LDAPBackend(config_dict).session():
assert "oauthClient" not in LDAPObject.ldap_object_classes(force=True)
Backend.setup_schemas(config_dict)
LDAPBackend.setup_schemas(config_dict)
with Backend(config_dict).session():
with LDAPBackend(config_dict).session():
assert "oauthClient" in LDAPObject.ldap_object_classes(force=True)
@ -71,15 +71,15 @@ def test_install_schemas_twice(configuration, slapd_server):
config_obj = settings_factory(configuration)
config_dict = config_obj.model_dump()
with Backend(config_dict).session():
with LDAPBackend(config_dict).session():
assert "oauthClient" not in LDAPObject.ldap_object_classes(force=True)
Backend.setup_schemas(config_dict)
LDAPBackend.setup_schemas(config_dict)
with Backend(config_dict).session():
with LDAPBackend(config_dict).session():
assert "oauthClient" in LDAPObject.ldap_object_classes(force=True)
Backend.setup_schemas(config_dict)
LDAPBackend.setup_schemas(config_dict)
def test_install_no_permissions_to_install_schemas(configuration, slapd_server):
@ -90,11 +90,11 @@ def test_install_no_permissions_to_install_schemas(configuration, slapd_server):
config_obj = settings_factory(configuration)
config_dict = config_obj.model_dump()
with Backend(config_dict).session():
with LDAPBackend(config_dict).session():
assert "oauthClient" not in LDAPObject.ldap_object_classes(force=True)
with pytest.raises(InstallationException):
Backend.setup_schemas(config_dict)
LDAPBackend.setup_schemas(config_dict)
assert "oauthClient" not in LDAPObject.ldap_object_classes(force=True)
@ -107,7 +107,7 @@ def test_install_schemas_command(configuration, slapd_server):
config_obj = settings_factory(configuration)
config_dict = config_obj.model_dump()
with Backend(config_dict).session():
with LDAPBackend(config_dict).session():
assert "oauthClient" not in LDAPObject.ldap_object_classes(force=True)
testclient = TestApp(create_app(configuration, validate=False))
@ -115,5 +115,5 @@ def test_install_schemas_command(configuration, slapd_server):
res = runner.invoke(cli, ["install"])
assert res.exit_code == 0, res.stdout
with Backend(config_dict).session():
with LDAPBackend(config_dict).session():
assert "oauthClient" in LDAPObject.ldap_object_classes(force=True)

View file

@ -1,10 +1,10 @@
import pytest
from canaille.backends.memory.backend import Backend
from canaille.backends.memory.backend import MemoryBackend
@pytest.fixture
def memory_backend(configuration):
backend = Backend(configuration)
backend = MemoryBackend(configuration)
with backend.session():
yield backend

View file

@ -1,7 +1,7 @@
import pytest
from canaille.app.configuration import settings_factory
from canaille.backends.sql.backend import Backend
from canaille.backends.sql.backend import SQLBackend
@pytest.fixture
@ -15,6 +15,6 @@ def sqlalchemy_configuration(configuration):
def sql_backend(sqlalchemy_configuration):
config_obj = settings_factory(sqlalchemy_configuration)
config_dict = config_obj.model_dump()
backend = Backend(config_dict)
backend = SQLBackend(config_dict)
with backend.session(init=True):
yield backend

View file

@ -1,16 +1,16 @@
import pytest
from canaille.backends import BaseBackend
from canaille.backends import Backend
def test_required_methods(testclient):
with pytest.raises(NotImplementedError):
BaseBackend.install(config=None)
Backend.install(config=None)
with pytest.raises(NotImplementedError):
BaseBackend.validate({})
Backend.validate({})
backend = BaseBackend(testclient.app.config["CANAILLE"])
backend = Backend(testclient.app.config["CANAILLE"])
with pytest.raises(NotImplementedError):
backend.has_account_lockability()