forked from Github-Mirrors/canaille
refactor: BackendModel.get() is now Backend.instance
This commit is contained in:
parent
fa6488bcd1
commit
ccde88b1bf
21 changed files with 71 additions and 66 deletions
|
@ -99,7 +99,7 @@ def setup_flask(app):
|
|||
"has_oidc": "CANAILLE_OIDC" in app.config,
|
||||
"has_password_recovery": app.config["CANAILLE"]["ENABLE_PASSWORD_RECOVERY"],
|
||||
"has_registration": app.config["CANAILLE"]["ENABLE_REGISTRATION"],
|
||||
"has_account_lockability": app.backend.get().has_account_lockability(),
|
||||
"has_account_lockability": app.backend.instance.has_account_lockability(),
|
||||
"logo_url": app.config["CANAILLE"]["LOGO"],
|
||||
"favicon_url": app.config["CANAILLE"]["FAVICON"]
|
||||
or app.config["CANAILLE"]["LOGO"],
|
||||
|
|
|
@ -12,7 +12,7 @@ def with_backendcontext(func):
|
|||
@functools.wraps(func)
|
||||
def _func(*args, **kwargs):
|
||||
if not current_app.config["TESTING"]: # pragma: no cover
|
||||
with BaseBackend.get().session():
|
||||
with BaseBackend.instance.session():
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
else:
|
||||
|
|
|
@ -162,7 +162,7 @@ def validate(config, validate_remote=False):
|
|||
|
||||
from canaille.backends import BaseBackend
|
||||
|
||||
BaseBackend.get().validate(config)
|
||||
BaseBackend.instance.validate(config)
|
||||
validate_smtp_configuration(config["CANAILLE"]["SMTP"])
|
||||
|
||||
|
||||
|
|
|
@ -187,9 +187,11 @@ class TableForm(I18NFormMixin, FlaskForm):
|
|||
filter = filter or {}
|
||||
super().__init__(**kwargs)
|
||||
if self.query.data:
|
||||
self.items = BaseBackend.get().fuzzy(cls, self.query.data, fields, **filter)
|
||||
self.items = BaseBackend.instance.fuzzy(
|
||||
cls, self.query.data, fields, **filter
|
||||
)
|
||||
else:
|
||||
self.items = BaseBackend.get().query(cls, **filter)
|
||||
self.items = BaseBackend.instance.query(cls, **filter)
|
||||
|
||||
self.page_size = page_size
|
||||
self.nb_items = len(self.items)
|
||||
|
|
|
@ -8,4 +8,4 @@ class InstallationException(Exception):
|
|||
|
||||
def install(config, debug=False):
|
||||
install_oidc(config, debug=debug)
|
||||
BaseBackend.get().install(config)
|
||||
BaseBackend.instance.install(config)
|
||||
|
|
|
@ -4,6 +4,8 @@ from contextlib import contextmanager
|
|||
|
||||
from flask import g
|
||||
|
||||
from canaille.app import classproperty
|
||||
|
||||
|
||||
class BaseBackend:
|
||||
instance = None
|
||||
|
@ -13,8 +15,8 @@ class BaseBackend:
|
|||
BaseBackend.instance = self
|
||||
self.register_models()
|
||||
|
||||
@classmethod
|
||||
def get(cls):
|
||||
@classproperty
|
||||
def instance(cls):
|
||||
return cls.instance
|
||||
|
||||
def init_app(self, app):
|
||||
|
|
|
@ -162,7 +162,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
|
|||
|
||||
@classmethod
|
||||
def install(cls):
|
||||
conn = Backend.get().connection
|
||||
conn = Backend.instance.connection
|
||||
cls.ldap_object_classes(conn)
|
||||
cls.ldap_object_attributes(conn)
|
||||
|
||||
|
@ -187,7 +187,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
|
|||
if cls._object_class_by_name and not force:
|
||||
return cls._object_class_by_name
|
||||
|
||||
conn = Backend.get().connection
|
||||
conn = Backend.instance.connection
|
||||
|
||||
res = conn.search_s(
|
||||
"cn=subschema", ldap.SCOPE_BASE, "(objectclass=*)", ["*", "+"]
|
||||
|
@ -209,7 +209,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
|
|||
if cls._attribute_type_by_name and not force:
|
||||
return cls._attribute_type_by_name
|
||||
|
||||
conn = Backend.get().connection
|
||||
conn = Backend.instance.connection
|
||||
|
||||
res = conn.search_s(
|
||||
"cn=subschema", ldap.SCOPE_BASE, "(objectclass=*)", ["*", "+"]
|
||||
|
@ -229,7 +229,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
|
|||
@classmethod
|
||||
def get(cls, identifier=None, /, **kwargs):
|
||||
try:
|
||||
return BaseBackend.get().query(cls, identifier, **kwargs)[0]
|
||||
return BaseBackend.instance.query(cls, identifier, **kwargs)[0]
|
||||
except (IndexError, ldap.NO_SUCH_OBJECT):
|
||||
if identifier and cls.base:
|
||||
return (
|
||||
|
@ -274,13 +274,13 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
|
|||
return cls.attribute_map.get(name, name) if cls.attribute_map else None
|
||||
|
||||
def reload(self):
|
||||
conn = Backend.get().connection
|
||||
conn = Backend.instance.connection
|
||||
result = conn.search_s(self.dn, ldap.SCOPE_SUBTREE, None, ["+", "*"])
|
||||
self.changes = {}
|
||||
self.state = result[0][1]
|
||||
|
||||
def save(self):
|
||||
conn = Backend.get().connection
|
||||
conn = Backend.instance.connection
|
||||
|
||||
current_object_classes = self.get_ldap_attribute("objectClass") or []
|
||||
self.set_ldap_attribute(
|
||||
|
@ -338,7 +338,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
|
|||
self.changes = {}
|
||||
|
||||
def delete(self):
|
||||
conn = Backend.get().connection
|
||||
conn = Backend.instance.connection
|
||||
try:
|
||||
conn.delete_s(self.dn)
|
||||
except ldap.NO_SUCH_OBJECT:
|
||||
|
|
|
@ -38,7 +38,7 @@ class User(canaille.core.models.User, LDAPObject):
|
|||
|
||||
def match_filter(self, filter):
|
||||
if isinstance(filter, str):
|
||||
conn = Backend.get().connection
|
||||
conn = Backend.instance.connection
|
||||
return self.dn and conn.search_s(self.dn, ldap.SCOPE_SUBTREE, filter)
|
||||
|
||||
return super().match_filter(filter)
|
||||
|
|
|
@ -45,7 +45,7 @@ class MemoryModel(BackendModel):
|
|||
or None
|
||||
)
|
||||
|
||||
results = BaseBackend.get().query(cls, **kwargs)
|
||||
results = BaseBackend.instance.query(cls, **kwargs)
|
||||
return results[0] if results else None
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -74,8 +74,7 @@ class Backend(BaseBackend):
|
|||
for attribute_name, expected_value in kwargs.items()
|
||||
]
|
||||
return (
|
||||
Backend.get()
|
||||
.db_session.execute(select(model).filter(*filter))
|
||||
Backend.instance.db_session.execute(select(model).filter(*filter))
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
|
|
@ -61,11 +61,9 @@ class SqlAlchemyModel(BackendModel):
|
|||
cls.attribute_filter(attribute_name, expected_value)
|
||||
for attribute_name, expected_value in kwargs.items()
|
||||
]
|
||||
return (
|
||||
Backend.get()
|
||||
.db_session.execute(select(cls).filter(*filter))
|
||||
.scalar_one_or_none()
|
||||
)
|
||||
return Backend.instance.db_session.execute(
|
||||
select(cls).filter(*filter)
|
||||
).scalar_one_or_none()
|
||||
|
||||
def save(self):
|
||||
self.last_modified = datetime.datetime.now(datetime.timezone.utc).replace(
|
||||
|
@ -74,15 +72,15 @@ class SqlAlchemyModel(BackendModel):
|
|||
if not self.created:
|
||||
self.created = self.last_modified
|
||||
|
||||
Backend.get().db_session.add(self)
|
||||
Backend.get().db_session.commit()
|
||||
Backend.instance.db_session.add(self)
|
||||
Backend.instance.db_session.commit()
|
||||
|
||||
def delete(self):
|
||||
Backend.get().db_session.delete(self)
|
||||
Backend.get().db_session.commit()
|
||||
Backend.instance.db_session.delete(self)
|
||||
Backend.instance.db_session.commit()
|
||||
|
||||
def reload(self):
|
||||
Backend.get().db_session.refresh(self)
|
||||
Backend.instance.db_session.refresh(self)
|
||||
|
||||
|
||||
membership_association_table = Table(
|
||||
|
|
|
@ -86,7 +86,7 @@ def join():
|
|||
|
||||
form = JoinForm(request.form or None)
|
||||
if request.form and form.validate():
|
||||
if BaseBackend.get().query(models.User, emails=form.email.data):
|
||||
if BaseBackend.instance.query(models.User, emails=form.email.data):
|
||||
flash(
|
||||
_(
|
||||
"You will receive soon an email to continue the registration process."
|
||||
|
@ -297,7 +297,7 @@ def registration(data=None, hash=None):
|
|||
_("Groups"),
|
||||
choices=[
|
||||
(group, group.display_name)
|
||||
for group in BaseBackend.get().query(models.Group)
|
||||
for group in BaseBackend.instance.query(models.Group)
|
||||
],
|
||||
coerce=IDToModel("Group"),
|
||||
)
|
||||
|
@ -391,7 +391,7 @@ def email_confirmation(data, hash):
|
|||
)
|
||||
return redirect(url_for("core.account.index"))
|
||||
|
||||
if BaseBackend.get().query(models.User, emails=confirmation_obj.email):
|
||||
if BaseBackend.instance.query(models.User, emails=confirmation_obj.email):
|
||||
flash(
|
||||
_("This address email is already associated with another account."),
|
||||
"error",
|
||||
|
@ -458,7 +458,7 @@ def profile_create(current_app, form):
|
|||
user.save()
|
||||
|
||||
if form["password1"].data:
|
||||
BaseBackend.get().set_user_password(user, form["password1"].data)
|
||||
BaseBackend.instance.set_user_password(user, form["password1"].data)
|
||||
user.save()
|
||||
|
||||
return user
|
||||
|
@ -713,14 +713,14 @@ def profile_settings(user, edited_user):
|
|||
|
||||
if (
|
||||
request.form.get("action") == "confirm-lock"
|
||||
and BaseBackend.get().has_account_lockability()
|
||||
and BaseBackend.instance.has_account_lockability()
|
||||
and not edited_user.locked
|
||||
):
|
||||
return render_template("modals/lock-account.html", edited_user=edited_user)
|
||||
|
||||
if (
|
||||
request.form.get("action") == "lock"
|
||||
and BaseBackend.get().has_account_lockability()
|
||||
and BaseBackend.instance.has_account_lockability()
|
||||
and not edited_user.locked
|
||||
):
|
||||
flash(_("The account has been locked"), "success")
|
||||
|
@ -731,7 +731,7 @@ def profile_settings(user, edited_user):
|
|||
|
||||
if (
|
||||
request.form.get("action") == "unlock"
|
||||
and BaseBackend.get().has_account_lockability()
|
||||
and BaseBackend.instance.has_account_lockability()
|
||||
and edited_user.locked
|
||||
):
|
||||
flash(_("The account has been unlocked"), "success")
|
||||
|
@ -782,7 +782,9 @@ def profile_settings_edit(editor, edited_user):
|
|||
and form["password1"].data
|
||||
and request.form["action"] == "edit-settings"
|
||||
):
|
||||
BaseBackend.get().set_user_password(edited_user, form["password1"].data)
|
||||
BaseBackend.instance.set_user_password(
|
||||
edited_user, form["password1"].data
|
||||
)
|
||||
|
||||
edited_user.save()
|
||||
flash(_("Profile updated successfully."), "success")
|
||||
|
|
|
@ -43,12 +43,12 @@ def login():
|
|||
|
||||
form = LoginForm(request.form or None)
|
||||
form.render_field_macro_file = "partial/login_field.html"
|
||||
form["login"].render_kw["placeholder"] = BaseBackend.get().login_placeholder()
|
||||
form["login"].render_kw["placeholder"] = BaseBackend.instance.login_placeholder()
|
||||
|
||||
if not request.form or form.form_control():
|
||||
return render_template("login.html", form=form)
|
||||
|
||||
user = BaseBackend.get().get_user_from_login(form.login.data)
|
||||
user = BaseBackend.instance.get_user_from_login(form.login.data)
|
||||
if user and not user.has_password():
|
||||
return redirect(url_for("core.auth.firstlogin", user=user))
|
||||
|
||||
|
@ -80,7 +80,7 @@ def password():
|
|||
"password.html", form=form, username=session["attempt_login"]
|
||||
)
|
||||
|
||||
user = BaseBackend.get().get_user_from_login(session["attempt_login"])
|
||||
user = BaseBackend.instance.get_user_from_login(session["attempt_login"])
|
||||
if user and not user.has_password():
|
||||
return redirect(url_for("core.auth.firstlogin", user=user))
|
||||
|
||||
|
@ -91,7 +91,9 @@ def password():
|
|||
"password.html", form=form, username=session["attempt_login"]
|
||||
)
|
||||
|
||||
success, message = BaseBackend.get().check_user_password(user, form.password.data)
|
||||
success, message = BaseBackend.instance.check_user_password(
|
||||
user, form.password.data
|
||||
)
|
||||
request_ip = request.remote_addr or "unknown IP"
|
||||
if not success:
|
||||
logout_user()
|
||||
|
@ -175,7 +177,7 @@ def forgotten():
|
|||
flash(_("Could not send the password reset link."), "error")
|
||||
return render_template("forgotten-password.html", form=form)
|
||||
|
||||
user = BaseBackend.get().get_user_from_login(form.login.data)
|
||||
user = BaseBackend.instance.get_user_from_login(form.login.data)
|
||||
success_message = _(
|
||||
"A password reset link has been sent at your email address. "
|
||||
"You should receive it within a few minutes."
|
||||
|
@ -233,7 +235,7 @@ def reset(user, hash):
|
|||
return redirect(url_for("core.account.index"))
|
||||
|
||||
if request.form and form.validate():
|
||||
BaseBackend.get().set_user_password(user, form.password.data)
|
||||
BaseBackend.instance.set_user_password(user, form.password.data)
|
||||
login_user(user)
|
||||
|
||||
flash(_("Your password has been updated successfully"), "success")
|
||||
|
|
|
@ -51,7 +51,7 @@ def unique_group(form, field):
|
|||
def existing_login(form, field):
|
||||
if not current_app.config["CANAILLE"][
|
||||
"HIDE_INVALID_LOGINS"
|
||||
] and not BaseBackend.get().get_user_from_login(field.data):
|
||||
] and not BaseBackend.instance.get_user_from_login(field.data):
|
||||
raise wtforms.ValidationError(
|
||||
_("The login '{login}' does not exist").format(login=field.data)
|
||||
)
|
||||
|
@ -314,7 +314,7 @@ PROFILE_FORM_FIELDS = dict(
|
|||
default=[],
|
||||
choices=lambda: [
|
||||
(group, group.display_name)
|
||||
for group in BaseBackend.get().query(models.Group)
|
||||
for group in BaseBackend.instance.query(models.Group)
|
||||
],
|
||||
render_kw={"placeholder": _("users, admins …")},
|
||||
coerce=IDToModel("Group"),
|
||||
|
@ -336,10 +336,10 @@ def build_profile_form(write_field_names, readonly_field_names, user=None):
|
|||
if PROFILE_FORM_FIELDS.get(name)
|
||||
}
|
||||
|
||||
if "groups" in fields and not BaseBackend.get().query(models.Group):
|
||||
if "groups" in fields and not BaseBackend.instance.query(models.Group):
|
||||
del fields["groups"]
|
||||
|
||||
if current_app.backend.get().has_account_lockability(): # pragma: no branch
|
||||
if current_app.backend.instance.has_account_lockability(): # pragma: no branch
|
||||
fields["lock_date"] = DateTimeUTCField(
|
||||
_("Account expiration"),
|
||||
validators=[wtforms.validators.Optional()],
|
||||
|
@ -441,7 +441,7 @@ class InvitationForm(Form):
|
|||
_("Groups"),
|
||||
choices=lambda: [
|
||||
(group, group.display_name)
|
||||
for group in BaseBackend.get().query(models.Group)
|
||||
for group in BaseBackend.instance.query(models.Group)
|
||||
],
|
||||
render_kw={},
|
||||
coerce=IDToModel("Group"),
|
||||
|
|
|
@ -48,7 +48,7 @@ def fake_users(nb=1):
|
|||
|
||||
|
||||
def fake_groups(nb=1, nb_users_max=1):
|
||||
users = BaseBackend.get().query(models.User)
|
||||
users = BaseBackend.instance.query(models.User)
|
||||
groups = list()
|
||||
fake = faker.Faker(["en_US"])
|
||||
for _ in range(nb):
|
||||
|
|
|
@ -11,11 +11,11 @@ from canaille.backends import BaseBackend
|
|||
@with_backendcontext
|
||||
def clean():
|
||||
"""Remove expired tokens and authorization codes."""
|
||||
for t in BaseBackend.get().query(models.Token):
|
||||
for t in BaseBackend.instance.query(models.Token):
|
||||
if t.is_expired():
|
||||
t.delete()
|
||||
|
||||
for a in BaseBackend.get().query(models.AuthorizationCode):
|
||||
for a in BaseBackend.instance.query(models.AuthorizationCode):
|
||||
if a.is_expired():
|
||||
a.delete()
|
||||
|
||||
|
|
|
@ -20,13 +20,13 @@ bp = Blueprint("consents", __name__, url_prefix="/consent")
|
|||
@bp.route("/")
|
||||
@user_needed()
|
||||
def consents(user):
|
||||
consents = BaseBackend.get().query(models.Consent, subject=user)
|
||||
consents = BaseBackend.instance.query(models.Consent, subject=user)
|
||||
clients = {t.client for t in consents}
|
||||
|
||||
nb_consents = len(consents)
|
||||
nb_preconsents = sum(
|
||||
1
|
||||
for client in BaseBackend.get().query(models.Client)
|
||||
for client in BaseBackend.instance.query(models.Client)
|
||||
if client.preconsent and client not in clients
|
||||
)
|
||||
|
||||
|
@ -44,11 +44,11 @@ def consents(user):
|
|||
@bp.route("/pre-consents")
|
||||
@user_needed()
|
||||
def pre_consents(user):
|
||||
consents = BaseBackend.get().query(models.Consent, subject=user)
|
||||
consents = BaseBackend.instance.query(models.Consent, subject=user)
|
||||
clients = {t.client for t in consents}
|
||||
preconsented = [
|
||||
client
|
||||
for client in BaseBackend.get().query(models.Client)
|
||||
for client in BaseBackend.instance.query(models.Client)
|
||||
if client.preconsent and client not in clients
|
||||
]
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ class LogoutForm(Form):
|
|||
def client_audiences():
|
||||
return [
|
||||
(client, client.client_name)
|
||||
for client in BaseBackend.get().query(models.Client)
|
||||
for client in BaseBackend.instance.query(models.Client)
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -110,7 +110,7 @@ def authorize_login(user):
|
|||
def authorize_consent(client, user):
|
||||
requested_scopes = request.args.get("scope", "").split(" ")
|
||||
allowed_scopes = client.get_allowed_scope(requested_scopes).split(" ")
|
||||
consents = BaseBackend.get().query(
|
||||
consents = BaseBackend.instance.query(
|
||||
models.Consent,
|
||||
client=client,
|
||||
subject=user,
|
||||
|
|
|
@ -96,13 +96,13 @@ class Client(BaseClient, ClientMixin):
|
|||
return metadata
|
||||
|
||||
def delete(self):
|
||||
for consent in BaseBackend.get().query(models.Consent, client=self):
|
||||
for consent in BaseBackend.instance.query(models.Consent, client=self):
|
||||
consent.delete()
|
||||
|
||||
for code in BaseBackend.get().query(models.AuthorizationCode, client=self):
|
||||
for code in BaseBackend.instance.query(models.AuthorizationCode, client=self):
|
||||
code.delete()
|
||||
|
||||
for token in BaseBackend.get().query(models.Token, client=self):
|
||||
for token in BaseBackend.instance.query(models.Token, client=self):
|
||||
token.delete()
|
||||
|
||||
super().delete()
|
||||
|
@ -186,7 +186,7 @@ class Consent(BaseConsent):
|
|||
self.revokation_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
self.save()
|
||||
|
||||
tokens = BaseBackend.get().query(
|
||||
tokens = BaseBackend.instance.query(
|
||||
models.Token,
|
||||
client=self.client,
|
||||
subject=self.subject,
|
||||
|
|
|
@ -112,7 +112,7 @@ def openid_configuration():
|
|||
|
||||
def exists_nonce(nonce, req):
|
||||
client = models.Client.get(id=req.client_id)
|
||||
exists = BaseBackend.get().query(
|
||||
exists = BaseBackend.instance.query(
|
||||
models.AuthorizationCode, client=client, nonce=nonce
|
||||
)
|
||||
return bool(exists)
|
||||
|
@ -239,7 +239,7 @@ class AuthorizationCodeGrant(_AuthorizationCodeGrant):
|
|||
return save_authorization_code(code, request)
|
||||
|
||||
def query_authorization_code(self, code, client):
|
||||
item = BaseBackend.get().query(
|
||||
item = BaseBackend.instance.query(
|
||||
models.AuthorizationCode, code=code, client=client
|
||||
)
|
||||
if item and not item[0].is_expired():
|
||||
|
@ -272,11 +272,11 @@ class PasswordGrant(_ResourceOwnerPasswordCredentialsGrant):
|
|||
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]
|
||||
|
||||
def authenticate_user(self, username, password):
|
||||
user = BaseBackend.get().get_user_from_login(username)
|
||||
user = BaseBackend.instance.get_user_from_login(username)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
success, _ = BaseBackend.get().check_user_password(user, password)
|
||||
success, _ = BaseBackend.instance.check_user_password(user, password)
|
||||
if not success:
|
||||
return None
|
||||
|
||||
|
@ -287,7 +287,7 @@ class RefreshTokenGrant(_RefreshTokenGrant):
|
|||
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]
|
||||
|
||||
def authenticate_refresh_token(self, refresh_token):
|
||||
token = BaseBackend.get().query(models.Token, refresh_token=refresh_token)
|
||||
token = BaseBackend.instance.query(models.Token, refresh_token=refresh_token)
|
||||
if token and token[0].is_refresh_token_active():
|
||||
return token[0]
|
||||
|
||||
|
|
Loading…
Reference in a new issue