diff --git a/canaille/app/flask.py b/canaille/app/flask.py index de96c000..134d656d 100644 --- a/canaille/app/flask.py +++ b/canaille/app/flask.py @@ -20,7 +20,7 @@ def current_user(): return g.user for user_id in session.get("user_id", [])[::-1]: - user = models.User.get(user_id) + user = current_app.backend.instance.get(models.User, user_id) if user and ( not current_app.backend.has_account_lockability() or not user.locked ): @@ -147,7 +147,7 @@ def model_converter(model): def to_python(self, identifier): current_app.backend.setup() - instance = model.get(identifier) + instance = current_app.backend.get(model, identifier) if self.required and not instance: abort(404) diff --git a/canaille/app/forms.py b/canaille/app/forms.py index 8d65d30b..65db468e 100644 --- a/canaille/app/forms.py +++ b/canaille/app/forms.py @@ -260,7 +260,9 @@ class IDToModel: def __call__(self, data): model = getattr(models, self.model_name) - instance = data if isinstance(data, model) else model.get(data) + instance = ( + data if isinstance(data, model) else BaseBackend.instance.get(model, data) + ) if instance: return instance diff --git a/canaille/backends/__init__.py b/canaille/backends/__init__.py index b2dd5e76..a7bfcad1 100644 --- a/canaille/backends/__init__.py +++ b/canaille/backends/__init__.py @@ -80,6 +80,11 @@ class BaseBackend: attribute values loosely be matched.""" raise NotImplementedError() + def get(self, model, identifier=None, **kwargs): + """Works like :meth:`~canaille.backends.BaseBackend.query` but return + only one element or :py:data:`None` if no item is matching.""" + raise NotImplementedError() + def check_user_password(self, user, password: str) -> bool: """Check if the password matches the user password in the database.""" raise NotImplementedError() diff --git a/canaille/backends/ldap/backend.py b/canaille/backends/ldap/backend.py index 9630d902..e48f51e1 100644 --- a/canaille/backends/ldap/backend.py +++ b/canaille/backends/ldap/backend.py @@ -202,7 +202,7 @@ class Backend(BaseBackend): if login else None ) - return User.get(filter=filter) + return self.get(User, filter=filter) def check_user_password(self, user, password): conn = ldap.initialize(current_app.config["CANAILLE_LDAP"]["URI"]) @@ -311,6 +311,19 @@ class Backend(BaseBackend): ) return self.query(model, filter=filter, **kwargs) + def get(self, model, identifier=None, /, **kwargs): + try: + return self.query(model, identifier, **kwargs)[0] + except (IndexError, ldap.NO_SUCH_OBJECT): + if identifier and model.base: + return ( + self.get(model, **{model.identifier_attribute: identifier}) + or self.get(model, id=identifier) + or None + ) + + return None + def setup_ldap_models(config): from canaille.app import models diff --git a/canaille/backends/ldap/ldapobject.py b/canaille/backends/ldap/ldapobject.py index 7e2688bf..2eacee7b 100644 --- a/canaille/backends/ldap/ldapobject.py +++ b/canaille/backends/ldap/ldapobject.py @@ -4,7 +4,6 @@ import ldap.dn import ldap.filter from ldap.controls.readentry import PostReadControl -from canaille.backends import BaseBackend from canaille.backends.models import BackendModel from .backend import Backend @@ -226,20 +225,6 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass): return cls._attribute_type_by_name - @classmethod - def get(cls, identifier=None, /, **kwargs): - try: - return BaseBackend.instance.query(cls, identifier, **kwargs)[0] - except (IndexError, ldap.NO_SUCH_OBJECT): - if identifier and cls.base: - return ( - cls.get(**{cls.identifier_attribute: identifier}) - or cls.get(id=identifier) - or None - ) - - return None - @classmethod def update_ldap_attributes(cls): all_object_classes = cls.ldap_object_classes() diff --git a/canaille/backends/ldap/utils.py b/canaille/backends/ldap/utils.py index 63186976..9cc13a61 100644 --- a/canaille/backends/ldap/utils.py +++ b/canaille/backends/ldap/utils.py @@ -1,6 +1,8 @@ import datetime from enum import Enum +from canaille.backends import BaseBackend + LDAP_NULL_DATE = "000001010000Z" @@ -50,7 +52,7 @@ def ldap_to_python(value, syntax): return value.decode("utf-8").upper() == "TRUE" if syntax == Syntax.DISTINGUISHED_NAME: - return LDAPObject.get(value.decode("utf-8")) + return BaseBackend.instance.get(LDAPObject, value.decode("utf-8")) return value.decode("utf-8") diff --git a/canaille/backends/memory/backend.py b/canaille/backends/memory/backend.py index 4b24da2e..fbeb3020 100644 --- a/canaille/backends/memory/backend.py +++ b/canaille/backends/memory/backend.py @@ -26,7 +26,7 @@ class Backend(BaseBackend): def get_user_from_login(self, login): from .models import User - return User.get(user_name=login) + return self.get(User, user_name=login) def check_user_password(self, user, password): if password != user.password: @@ -80,3 +80,14 @@ class Backend(BaseBackend): if isinstance(value, str) ) ] + + def get(self, model, identifier=None, /, **kwargs): + if identifier: + return ( + self.get(model, **{model.identifier_attribute: identifier}) + or self.get(model, id=identifier) + or None + ) + + results = self.query(model, **kwargs) + return results[0] if results else None diff --git a/canaille/backends/memory/models.py b/canaille/backends/memory/models.py index 59f8630e..41ebb8ce 100644 --- a/canaille/backends/memory/models.py +++ b/canaille/backends/memory/models.py @@ -36,18 +36,6 @@ class MemoryModel(BackendModel): class_name or cls.__name__, {} ).setdefault(attribute, {}) - @classmethod - def get(cls, identifier=None, /, **kwargs): - if identifier: - return ( - cls.get(**{cls.identifier_attribute: identifier}) - or cls.get(id=identifier) - or None - ) - - results = BaseBackend.instance.query(cls, **kwargs) - return results[0] if results else None - @classmethod def listify(cls, value): return value if isinstance(value, list) else [value] @@ -75,7 +63,7 @@ class MemoryModel(BackendModel): model, _ = cls.get_model_annotations(attribute_name) if model and not isinstance(value, model): backend_model = getattr(models, model.__name__) - return backend_model.get(id=value) + return BaseBackend.instance.get(backend_model, id=value) return value @@ -166,7 +154,7 @@ class MemoryModel(BackendModel): del self.index()[self.id] def reload(self): - self._state = self.__class__.get(id=self.id)._state + self._state = BaseBackend.instance.get(self.__class__, id=self.id)._state self._cache = {} def __eq__(self, other): @@ -174,7 +162,7 @@ class MemoryModel(BackendModel): return False if not isinstance(other, MemoryModel): - return self == self.__class__.get(id=other) + return self == BaseBackend.instance.get(self.__class__, id=other) return self._state == other._state diff --git a/canaille/backends/models.py b/canaille/backends/models.py index 3a954b0d..3e8998b6 100644 --- a/canaille/backends/models.py +++ b/canaille/backends/models.py @@ -11,6 +11,7 @@ from typing import get_type_hints from canaille.app import classproperty from canaille.app import models +from canaille.backends import BaseBackend class Model: @@ -87,12 +88,6 @@ class BackendModel: implemented for every model and for every backend. """ - @classmethod - def get(cls, identifier=None, **kwargs): - """Works like :meth:`~canaille.backends.BaseBackend.query` but return - only one element or :py:data:`None` if no item is matching.""" - raise NotImplementedError() - def save(self): """Validate the current modifications in the database.""" raise NotImplementedError() @@ -175,7 +170,7 @@ class BackendModel: backend_model = getattr(models, model.__name__) - if instance := backend_model.get(value): + if instance := BaseBackend.instance.get(backend_model, value): filter[attribute] = instance return all( diff --git a/canaille/backends/sql/backend.py b/canaille/backends/sql/backend.py index 76f867a5..39866386 100644 --- a/canaille/backends/sql/backend.py +++ b/canaille/backends/sql/backend.py @@ -53,7 +53,7 @@ class Backend(BaseBackend): def get_user_from_login(self, login): from .models import User - return User.get(user_name=login) + return self.get(User, user_name=login) def check_user_password(self, user, password): if password != user.password: @@ -90,3 +90,19 @@ class Backend(BaseBackend): ) return self.db_session.execute(select(model).filter(filter)).scalars().all() + + def get(self, model, identifier=None, /, **kwargs): + if identifier: + return ( + self.get(model, **{model.identifier_attribute: identifier}) + or self.get(model, id=identifier) + or None + ) + + filter = [ + model.attribute_filter(attribute_name, expected_value) + for attribute_name, expected_value in kwargs.items() + ] + return Backend.instance.db_session.execute( + select(model).filter(*filter) + ).scalar_one_or_none() diff --git a/canaille/backends/sql/models.py b/canaille/backends/sql/models.py index 9bd78445..bd14f85a 100644 --- a/canaille/backends/sql/models.py +++ b/canaille/backends/sql/models.py @@ -11,7 +11,6 @@ from sqlalchemy import LargeBinary from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import or_ -from sqlalchemy import select from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship @@ -48,23 +47,6 @@ class SqlAlchemyModel(BackendModel): return getattr(cls, name) == value - @classmethod - def get(cls, identifier=None, /, **kwargs): - if identifier: - return ( - cls.get(**{cls.identifier_attribute: identifier}) - or cls.get(id=identifier) - or None - ) - - filter = [ - cls.attribute_filter(attribute_name, expected_value) - for attribute_name, expected_value in kwargs.items() - ] - return Backend.instance.db_session.execute( - select(cls).filter(*filter) - ).scalar_one_or_none() - def save(self): self.last_modified = datetime.datetime.now(datetime.timezone.utc).replace( microsecond=0 diff --git a/canaille/core/endpoints/account.py b/canaille/core/endpoints/account.py index 708a5dcd..a34af0da 100644 --- a/canaille/core/endpoints/account.py +++ b/canaille/core/endpoints/account.py @@ -257,7 +257,9 @@ def registration(data=None, hash=None): ) return redirect(url_for("core.account.index")) - if payload.user_name and models.User.get(user_name=payload.user_name): + if payload.user_name and BaseBackend.instance.get( + models.User, user_name=payload.user_name + ): flash( _("Your account has already been created."), "error", @@ -282,7 +284,10 @@ def registration(data=None, hash=None): data = { "user_name": payload.user_name, "emails": [payload.email], - "groups": [models.Group.get(id=group_id) for group_id in payload.groups], + "groups": [ + BaseBackend.instance.get(models.Group, id=group_id) + for group_id in payload.groups + ], } has_smtp = "SMTP" in current_app.config["CANAILLE"] @@ -376,7 +381,7 @@ def email_confirmation(data, hash): ) return redirect(url_for("core.account.index")) - user = models.User.get(confirmation_obj.identifier) + user = BaseBackend.instance.get(models.User, confirmation_obj.identifier) if not user: flash( _("The email confirmation link that brought you here is invalid."), diff --git a/canaille/core/endpoints/forms.py b/canaille/core/endpoints/forms.py index 5deb8711..26b6ae57 100644 --- a/canaille/core/endpoints/forms.py +++ b/canaille/core/endpoints/forms.py @@ -24,7 +24,7 @@ MINIMUM_PASSWORD_LENGTH = 8 def unique_user_name(form, field): - if models.User.get(user_name=field.data) and ( + if BaseBackend.instance.get(models.User, user_name=field.data) and ( not getattr(form, "user", None) or form.user.user_name != field.data ): raise wtforms.ValidationError( @@ -33,7 +33,7 @@ def unique_user_name(form, field): def unique_email(form, field): - if models.User.get(emails=field.data) and ( + if BaseBackend.instance.get(models.User, emails=field.data) and ( not getattr(form, "user", None) or field.data not in form.user.emails ): raise wtforms.ValidationError( @@ -42,7 +42,7 @@ def unique_email(form, field): def unique_group(form, field): - if models.Group.get(display_name=field.data): + if BaseBackend.instance.get(models.Group, display_name=field.data): raise wtforms.ValidationError( _("The group '{group}' already exists").format(group=field.data) ) diff --git a/canaille/core/models.py b/canaille/core/models.py index caad8ae9..1e2c0652 100644 --- a/canaille/core/models.py +++ b/canaille/core/models.py @@ -258,12 +258,12 @@ class User(Model): def preferred_email(self): return self.emails[0] if self.emails else None - def __getattr__(self, name): + def __getattribute__(self, name): prefix = "can_" if name.startswith(prefix) and name != "can_read": return self.can(name[len(prefix) :]) - return super().__getattr__(name) + return super().__getattribute__(name) def can(self, *permissions: Permission): """Wether or not the user has the diff --git a/canaille/oidc/endpoints/consents.py b/canaille/oidc/endpoints/consents.py index b5fdd35a..1dcff0be 100644 --- a/canaille/oidc/endpoints/consents.py +++ b/canaille/oidc/endpoints/consents.py @@ -108,7 +108,7 @@ def revoke_preconsent(user, client): flash(_("Could not revoke this access"), "error") return redirect(url_for("oidc.consents.consents")) - consent = models.Consent.get(client=client, subject=user) + consent = BaseBackend.instance.get(models.Consent, client=client, subject=user) if consent: return redirect(url_for("oidc.consents.revoke", consent=consent)) diff --git a/canaille/oidc/endpoints/oauth.py b/canaille/oidc/endpoints/oauth.py index af1bb400..9810993c 100644 --- a/canaille/oidc/endpoints/oauth.py +++ b/canaille/oidc/endpoints/oauth.py @@ -50,7 +50,9 @@ def authorize(): request.form.to_dict(flat=False), ) - client = models.Client.get(client_id=request.args["client_id"]) + client = BaseBackend.instance.get( + models.Client, client_id=request.args["client_id"] + ) user = current_user() if response := authorize_guards(client): @@ -276,7 +278,7 @@ def end_session(): valid_uris = [] if "client_id" in data: - client = models.Client.get(client_id=data["client_id"]) + client = BaseBackend.instance.get(models.Client, client_id=data["client_id"]) if client: valid_uris = client.post_logout_redirect_uris @@ -328,7 +330,7 @@ def end_session(): else [id_token["aud"]] ) for client_id in client_ids: - client = models.Client.get(client_id=client_id) + client = BaseBackend.instance.get(models.Client, client_id=client_id) if client: valid_uris.extend(client.post_logout_redirect_uris or []) diff --git a/canaille/oidc/oauth.py b/canaille/oidc/oauth.py index 1cc7a670..a07efb6e 100644 --- a/canaille/oidc/oauth.py +++ b/canaille/oidc/oauth.py @@ -111,7 +111,7 @@ def openid_configuration(): def exists_nonce(nonce, req): - client = models.Client.get(id=req.client_id) + client = BaseBackend.instance.get(models.Client, id=req.client_id) exists = BaseBackend.instance.query( models.AuthorizationCode, client=client, nonce=nonce ) @@ -334,7 +334,7 @@ class OpenIDHybridGrant(_OpenIDHybridGrant): def query_client(client_id): - return models.Client.get(client_id=client_id) + return BaseBackend.instance.get(models.Client, client_id=client_id) def save_token(token, request): @@ -356,20 +356,20 @@ def save_token(token, request): class BearerTokenValidator(_BearerTokenValidator): def authenticate_token(self, token_string): - return models.Token.get(access_token=token_string) + return BaseBackend.instance.get(models.Token, access_token=token_string) def query_token(token, token_type_hint): if token_type_hint == "access_token": - return models.Token.get(access_token=token) + return BaseBackend.instance.get(models.Token, access_token=token) elif token_type_hint == "refresh_token": - return models.Token.get(refresh_token=token) + return BaseBackend.instance.get(models.Token, refresh_token=token) - item = models.Token.get(access_token=token) + item = BaseBackend.instance.get(models.Token, access_token=token) if item: return item - item = models.Token.get(refresh_token=token) + item = BaseBackend.instance.get(models.Token, refresh_token=token) if item: return item @@ -472,7 +472,7 @@ class ClientRegistrationEndpoint(ClientManagementMixin, _ClientRegistrationEndpo class ClientConfigurationEndpoint(ClientManagementMixin, _ClientConfigurationEndpoint): def authenticate_client(self, request): client_id = request.uri.split("/")[-1] - return models.Client.get(client_id=client_id) + return BaseBackend.instance.get(models.Client, client_id=client_id) def revoke_access_token(self, request, token): pass diff --git a/tests/backends/ldap/test_object_class.py b/tests/backends/ldap/test_object_class.py index 84844c02..b7863459 100644 --- a/tests/backends/ldap/test_object_class.py +++ b/tests/backends/ldap/test_object_class.py @@ -7,7 +7,7 @@ def test_guess_object_from_dn(backend, testclient, foo_group): foo_group.members = [foo_group] foo_group.save() dn = foo_group.dn - g = LDAPObject.get(dn) + g = backend.get(LDAPObject, dn) assert isinstance(g, models.Group) assert g == foo_group assert g.display_name == foo_group.display_name @@ -21,9 +21,9 @@ def test_object_class_update(backend, testclient): user1.save() assert set(user1.get_ldap_attribute("objectClass")) == {"inetOrgPerson"} - assert set(models.User.get(id=user1.id).get_ldap_attribute("objectClass")) == { - "inetOrgPerson" - } + assert set( + backend.get(models.User, id=user1.id).get_ldap_attribute("objectClass") + ) == {"inetOrgPerson"} testclient.app.config["CANAILLE_LDAP"]["USER_CLASS"] = [ "inetOrgPerson", @@ -38,12 +38,14 @@ def test_object_class_update(backend, testclient): "inetOrgPerson", "extensibleObject", } - assert set(models.User.get(id=user2.id).get_ldap_attribute("objectClass")) == { + assert set( + backend.get(models.User, id=user2.id).get_ldap_attribute("objectClass") + ) == { "inetOrgPerson", "extensibleObject", } - user1 = models.User.get(id=user1.id) + user1 = backend.get(models.User, id=user1.id) assert user1.get_ldap_attribute("objectClass") == ["inetOrgPerson"] user1.save() @@ -51,7 +53,9 @@ def test_object_class_update(backend, testclient): "inetOrgPerson", "extensibleObject", } - assert set(models.User.get(id=user1.id).get_ldap_attribute("objectClass")) == { + assert set( + backend.get(models.User, id=user1.id).get_ldap_attribute("objectClass") + ) == { "inetOrgPerson", "extensibleObject", } diff --git a/tests/backends/ldap/test_utils.py b/tests/backends/ldap/test_utils.py index faec2a55..36b87084 100644 --- a/tests/backends/ldap/test_utils.py +++ b/tests/backends/ldap/test_utils.py @@ -27,7 +27,7 @@ def test_object_creation(app, backend): user.save() assert user.exists - user = models.User.get(id=user.id) + user = backend.get(models.User, id=user.id) assert user.exists user.delete() @@ -52,9 +52,9 @@ def test_dn_when_leading_space_in_id_attribute(testclient, backend): assert ldap.dn.is_dn(dn) assert ldap.dn.dn2str(ldap.dn.str2dn(dn)) == dn - assert user == models.User.get(user.user_name) - assert user == models.User.get(user_name=user.user_name) - assert user == models.User.get(dn) + assert user == backend.get(models.User, user.user_name) + assert user == backend.get(models.User, user_name=user.user_name) + assert user == backend.get(models.User, dn) user.delete() @@ -73,9 +73,9 @@ def test_special_chars_in_rdn(testclient, backend): assert ldap.dn.dn2str(ldap.dn.str2dn(dn)) == dn assert dn == "uid=\\#user,ou=users,dc=mydomain,dc=tld" - assert user == models.User.get(user.user_name) - assert user == models.User.get(user_name=user.user_name) - assert user == models.User.get(dn) + assert user == backend.get(models.User, user.user_name) + assert user == backend.get(models.User, user_name=user.user_name) + assert user == backend.get(models.User, dn) user.delete() diff --git a/tests/backends/test_models.py b/tests/backends/test_models.py index 0d5c6723..57086a9f 100644 --- a/tests/backends/test_models.py +++ b/tests/backends/test_models.py @@ -19,7 +19,7 @@ def test_model_comparison(testclient, backend): formatted_name="bar", ) bar.save() - foo2 = models.User.get(id=foo1.id) + foo2 = backend.get(models.User, id=foo1.id) assert foo1 == foo2 assert foo1 != bar @@ -39,14 +39,14 @@ def test_model_lifecycle(testclient, backend): assert not backend.query(models.User) assert not backend.query(models.User, id=user.id) assert not backend.query(models.User, id="invalid") - assert not models.User.get(id=user.id) + assert not backend.get(models.User, id=user.id) user.save() assert backend.query(models.User) == [user] assert backend.query(models.User, id=user.id) == [user] assert not backend.query(models.User, id="invalid") - assert models.User.get(id=user.id) == user + assert backend.get(models.User, id=user.id) == user user.family_name = "new_family_name" @@ -59,7 +59,7 @@ def test_model_lifecycle(testclient, backend): user.delete() assert not backend.query(models.User, id=user.id) - assert not models.User.get(id=user.id) + assert not backend.get(models.User, id=user.id) user.delete() @@ -78,7 +78,7 @@ def test_model_attribute_edition(testclient, backend): assert user.family_name == "family_name" assert user.emails == ["email1@user.com", "email2@user.com"] - user = models.User.get(id=user.id) + user = backend.get(models.User, id=user.id) assert user.user_name == "user_name" assert user.family_name == "family_name" assert user.emails == ["email1@user.com", "email2@user.com"] @@ -90,7 +90,7 @@ def test_model_attribute_edition(testclient, backend): assert user.family_name == "new_family_name" assert user.emails == ["email1@user.com"] - user = models.User.get(id=user.id) + user = backend.get(models.User, id=user.id) assert user.family_name == "new_family_name" assert user.emails == ["email1@user.com"] @@ -112,34 +112,34 @@ def test_model_indexation(testclient, backend): ) user.save() - assert models.User.get(family_name="family_name") == user - assert not models.User.get(family_name="new_family_name") - assert models.User.get(emails=["email1@user.com"]) == user - assert models.User.get(emails=["email2@user.com"]) == user - assert not models.User.get(emails=["email3@user.com"]) + assert backend.get(models.User, family_name="family_name") == user + assert not backend.get(models.User, family_name="new_family_name") + assert backend.get(models.User, emails=["email1@user.com"]) == user + assert backend.get(models.User, emails=["email2@user.com"]) == user + assert not backend.get(models.User, emails=["email3@user.com"]) user.family_name = "new_family_name" user.emails = ["email2@user.com"] - assert models.User.get(family_name="family_name") != user - assert models.User.get(emails=["email1@user.com"]) != user - assert not models.User.get(emails=["email3@user.com"]) + assert backend.get(models.User, family_name="family_name") != user + assert backend.get(models.User, emails=["email1@user.com"]) != user + assert not backend.get(models.User, emails=["email3@user.com"]) user.save() - assert not models.User.get(family_name="family_name") - assert models.User.get(family_name="new_family_name") == user - assert not models.User.get(emails=["email1@user.com"]) - assert models.User.get(emails=["email2@user.com"]) == user - assert not models.User.get(emails=["email3@user.com"]) + assert not backend.get(models.User, family_name="family_name") + assert backend.get(models.User, family_name="new_family_name") == user + assert not backend.get(models.User, emails=["email1@user.com"]) + assert backend.get(models.User, emails=["email2@user.com"]) == user + assert not backend.get(models.User, emails=["email3@user.com"]) user.delete() - assert not models.User.get(family_name="family_name") - assert not models.User.get(family_name="new_family_name") - assert not models.User.get(emails=["email1@user.com"]) - assert not models.User.get(emails=["email2@user.com"]) - assert not models.User.get(emails=["email3@user.com"]) + assert not backend.get(models.User, family_name="family_name") + assert not backend.get(models.User, family_name="new_family_name") + assert not backend.get(models.User, emails=["email1@user.com"]) + assert not backend.get(models.User, emails=["email2@user.com"]) + assert not backend.get(models.User, emails=["email3@user.com"]) def test_fuzzy_unique_attribute(user, moderator, admin, backend): diff --git a/tests/core/test_account.py b/tests/core/test_account.py index ff07c6a1..133d2374 100644 --- a/tests/core/test_account.py +++ b/tests/core/test_account.py @@ -78,7 +78,7 @@ def test_admin_self_deletion(testclient, backend): .follow(status=200) ) - assert models.User.get(user_name="temp") is None + assert backend.get(models.User, user_name="temp") is None with testclient.session_transaction() as sess: assert not sess.get("user_id") @@ -116,7 +116,7 @@ def test_user_self_deletion(testclient, backend): .follow(status=200) ) - assert models.User.get(user_name="temp") is None + assert backend.get(models.User, user_name="temp") is None with testclient.session_transaction() as sess: assert not sess.get("user_id") @@ -136,7 +136,7 @@ def test_account_locking(user, backend): assert user.locked user.save() assert user.locked - assert models.User.get(id=user.id).locked + assert backend.get(models.User, id=user.id).locked assert backend.check_user_password(user, "correct horse battery staple") == ( False, "Your account has been locked.", @@ -145,7 +145,7 @@ def test_account_locking(user, backend): user.lock_date = None user.save() assert not user.locked - assert not models.User.get(id=user.id).locked + assert not backend.get(models.User, id=user.id).locked assert backend.check_user_password(user, "correct horse battery staple") == ( True, None, @@ -165,7 +165,7 @@ def test_account_locking_past_date(user, backend): ) - datetime.timedelta(days=30) user.save() assert user.locked - assert models.User.get(id=user.id).locked + assert backend.get(models.User, id=user.id).locked assert backend.check_user_password(user, "correct horse battery staple") == ( False, "Your account has been locked.", @@ -185,7 +185,7 @@ def test_account_locking_future_date(user, backend): ) + datetime.timedelta(days=365 * 4) user.save() assert not user.locked - assert not models.User.get(id=user.id).locked + assert not backend.get(models.User, id=user.id).locked assert backend.check_user_password(user, "correct horse battery staple") == ( True, None, diff --git a/tests/core/test_groups.py b/tests/core/test_groups.py index f0087d27..bd9d6bb3 100644 --- a/tests/core/test_groups.py +++ b/tests/core/test_groups.py @@ -131,12 +131,12 @@ def test_set_groups_with_leading_space_in_user_id_attribute(app, foo_group): def test_moderator_can_create_edit_and_delete_group( - testclient, logged_moderator, foo_group + testclient, logged_moderator, foo_group, backend ): # The group does not exist res = testclient.get("/groups", status=200) - assert models.Group.get(display_name="bar") is None - assert models.Group.get(display_name="foo") == foo_group + assert backend.get(models.Group, display_name="bar") is None + assert backend.get(models.Group, display_name="foo") == foo_group res.mustcontain(no="bar") res.mustcontain("foo") @@ -150,7 +150,7 @@ def test_moderator_can_create_edit_and_delete_group( res = form.submit(status=302).follow(status=200) logged_moderator.reload() - bar_group = models.Group.get(display_name="bar") + bar_group = backend.get(models.Group, display_name="bar") assert bar_group.display_name == "bar" assert bar_group.description == "yolo" assert bar_group.members == [ @@ -168,10 +168,10 @@ def test_moderator_can_create_edit_and_delete_group( assert res.flashes == [("error", "Group edition failed.")] res.mustcontain("This field cannot be edited") - bar_group = models.Group.get(display_name="bar") + bar_group = backend.get(models.Group, display_name="bar") assert bar_group.display_name == "bar" assert bar_group.description == "yolo" - assert models.Group.get(display_name="bar2") is None + assert backend.get(models.Group, display_name="bar2") is None # Group description can be edited res = testclient.get("/groups/bar", status=200) @@ -182,14 +182,14 @@ def test_moderator_can_create_edit_and_delete_group( assert res.flashes == [("success", "The group bar has been sucessfully edited.")] res = res.follow() - bar_group = models.Group.get(display_name="bar") + bar_group = backend.get(models.Group, display_name="bar") assert bar_group.display_name == "bar" assert bar_group.description == "yolo2" # Group is deleted res = res.forms["editgroupform"].submit(name="action", value="confirm-delete") res = res.form.submit(name="action", value="delete", status=302) - assert models.Group.get(display_name="bar") is None + assert backend.get(models.Group, display_name="bar") is None assert ("success", "The group bar has been sucessfully deleted") in res.flashes diff --git a/tests/core/test_invitation.py b/tests/core/test_invitation.py index 0bc96033..537eb310 100644 --- a/tests/core/test_invitation.py +++ b/tests/core/test_invitation.py @@ -7,7 +7,7 @@ from canaille.core.endpoints.account import RegistrationPayload def test_invitation(testclient, logged_admin, foo_group, smtpd, backend): - assert models.User.get(user_name="someone") is None + assert backend.get(models.User, user_name="someone") is None res = testclient.get("/invite", status=200) @@ -46,7 +46,7 @@ def test_invitation(testclient, logged_admin, foo_group, smtpd, backend): assert ("success", "Your account has been created successfully.") in res.flashes res = res.follow(status=200) - user = models.User.get(user_name="someone") + user = backend.get(models.User, user_name="someone") foo_group.reload() assert backend.check_user_password(user, "whatever")[0] assert user.groups == [foo_group] @@ -62,8 +62,8 @@ def test_invitation(testclient, logged_admin, foo_group, smtpd, backend): def test_invitation_editable_user_name( testclient, logged_admin, foo_group, smtpd, backend ): - assert models.User.get(user_name="jackyjack") is None - assert models.User.get(user_name="djorje") is None + assert backend.get(models.User, user_name="jackyjack") is None + assert backend.get(models.User, user_name="djorje") is None res = testclient.get("/invite", status=200) @@ -102,7 +102,7 @@ def test_invitation_editable_user_name( assert ("success", "Your account has been created successfully.") in res.flashes res = res.follow(status=200) - user = models.User.get(user_name="djorje") + user = backend.get(models.User, user_name="djorje") foo_group.reload() assert backend.check_user_password(user, "whatever")[0] assert user.groups == [foo_group] @@ -114,7 +114,7 @@ def test_invitation_editable_user_name( def test_generate_link(testclient, logged_admin, foo_group, smtpd, backend): - assert models.User.get(user_name="sometwo") is None + assert backend.get(models.User, user_name="sometwo") is None res = testclient.get("/invite", status=200) @@ -149,7 +149,7 @@ def test_generate_link(testclient, logged_admin, foo_group, smtpd, backend): res = res.form.submit(status=302) res = res.follow(status=200) - user = models.User.get(user_name="sometwo") + user = backend.get(models.User, user_name="sometwo") foo_group.reload() assert backend.check_user_password(user, "whatever")[0] assert user.groups == [foo_group] @@ -245,7 +245,7 @@ def test_registration_more_than_48_hours_after_invitation(testclient, foo_group) testclient.get(f"/register/{b64}/{hash}", status=302) -def test_registration_no_password(testclient, foo_group): +def test_registration_no_password(testclient, foo_group, backend): payload = RegistrationPayload( datetime.datetime.now(datetime.timezone.utc).isoformat(), "someoneelse", @@ -264,7 +264,7 @@ def test_registration_no_password(testclient, foo_group): res = res.form.submit(status=200) res.mustcontain("This field is required.") - assert not models.User.get(user_name="someoneelse") + assert not backend.get(models.User, user_name="someoneelse") with testclient.session_transaction() as sess: assert "user_id" not in sess @@ -302,7 +302,7 @@ def test_unavailable_if_no_smtp(testclient, logged_admin): def test_groups_are_saved_even_when_user_does_not_have_read_permission( - testclient, foo_group + testclient, foo_group, backend ): testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["READ"] = [ "user_name" @@ -331,7 +331,7 @@ def test_groups_are_saved_even_when_user_does_not_have_read_permission( res = res.form.submit(status=302) res = res.follow(status=200) - user = models.User.get(user_name="someoneelse") + user = backend.get(models.User, user_name="someoneelse") foo_group.reload() assert user.groups == [foo_group] user.delete() diff --git a/tests/core/test_profile_creation.py b/tests/core/test_profile_creation.py index 5bb28f1d..c85ed61a 100644 --- a/tests/core/test_profile_creation.py +++ b/tests/core/test_profile_creation.py @@ -6,7 +6,7 @@ def test_user_creation_edition_and_deletion( ): # The user does not exist. res = testclient.get("/users", status=200) - assert models.User.get(user_name="george") is None + assert backend.get(models.User, user_name="george") is None res.mustcontain(no="george") # Fill the profile for a new user. @@ -24,7 +24,7 @@ def test_user_creation_edition_and_deletion( res = res.form.submit(name="action", value="create-profile", status=302) assert ("success", "User account creation succeed.") in res.flashes res = res.follow(status=200) - george = models.User.get(user_name="george") + george = backend.get(models.User, user_name="george") foo_group.reload() assert "George" == george.given_name assert george.groups == [foo_group] @@ -45,7 +45,7 @@ def test_user_creation_edition_and_deletion( res.form["groups"] = [foo_group.id, bar_group.id] res = res.form.submit(name="action", value="edit-settings").follow() - george = models.User.get(user_name="george") + george = backend.get(models.User, user_name="george") assert "Georgio" == george.given_name assert backend.check_user_password(george, "totoyolo")[0] @@ -62,7 +62,7 @@ def test_user_creation_edition_and_deletion( res = res.form.submit(name="action", value="confirm-delete", status=200) res = res.form.submit(name="action", value="delete", status=302) res = res.follow(status=200) - assert models.User.get(user_name="george") is None + assert backend.get(models.User, user_name="george") is None res.mustcontain(no="george") @@ -82,7 +82,7 @@ def test_profile_creation_dynamic_validation(testclient, logged_admin, user): res.mustcontain("The email 'john@doe.com' is already used") -def test_user_creation_without_password(testclient, logged_moderator): +def test_user_creation_without_password(testclient, logged_moderator, backend): res = testclient.get("/profile", status=200) res.form["user_name"] = "george" res.form["family_name"] = "Abitbol" @@ -91,7 +91,7 @@ def test_user_creation_without_password(testclient, logged_moderator): res = res.form.submit(name="action", value="create-profile", status=302) assert ("success", "User account creation succeed.") in res.flashes res = res.follow(status=200) - george = models.User.get(user_name="george") + george = backend.get(models.User, user_name="george") assert george.user_name == "george" assert not george.has_password() @@ -99,16 +99,16 @@ def test_user_creation_without_password(testclient, logged_moderator): def test_user_creation_form_validation_failed( - testclient, logged_moderator, foo_group, bar_group + testclient, logged_moderator, foo_group, bar_group, backend ): res = testclient.get("/users", status=200) - assert models.User.get(user_name="george") is None + assert backend.get(models.User, user_name="george") is None res.mustcontain(no="george") res = testclient.get("/profile", status=200) res = res.form.submit(name="action", value="create-profile") assert ("error", "User account creation failed.") in res.flashes - assert models.User.get(user_name="george") is None + assert backend.get(models.User, user_name="george") is None def test_username_already_taken( @@ -133,7 +133,7 @@ def test_email_already_taken(testclient, logged_moderator, user, foo_group, bar_ res.mustcontain("The email 'john@doe.com' is already used") -def test_cn_setting_with_given_name_and_surname(testclient, logged_moderator): +def test_cn_setting_with_given_name_and_surname(testclient, logged_moderator, backend): res = testclient.get("/profile", status=200) res.form["user_name"] = "george" res.form["given_name"] = "George" @@ -144,12 +144,12 @@ def test_cn_setting_with_given_name_and_surname(testclient, logged_moderator): status=200 ) - george = models.User.get(user_name="george") + george = backend.get(models.User, user_name="george") assert george.formatted_name == "George Abitbol" george.delete() -def test_cn_setting_with_surname_only(testclient, logged_moderator): +def test_cn_setting_with_surname_only(testclient, logged_moderator, backend): res = testclient.get("/profile", status=200) res.form["user_name"] = "george" res.form["family_name"] = "Abitbol" @@ -159,7 +159,7 @@ def test_cn_setting_with_surname_only(testclient, logged_moderator): status=200 ) - george = models.User.get(user_name="george") + george = backend.get(models.User, user_name="george") assert george.formatted_name == "Abitbol" george.delete() diff --git a/tests/core/test_profile_photo.py b/tests/core/test_profile_photo.py index bce81091..d173565a 100644 --- a/tests/core/test_profile_photo.py +++ b/tests/core/test_profile_photo.py @@ -104,9 +104,9 @@ def test_photo_on_profile_edition( assert logged_user.photo is None -def test_photo_on_profile_creation(testclient, jpeg_photo, logged_admin): +def test_photo_on_profile_creation(testclient, jpeg_photo, logged_admin, backend): res = testclient.get("/users", status=200) - assert models.User.get(user_name="foobar") is None + assert backend.get(models.User, user_name="foobar") is None res.mustcontain(no="foobar") res = testclient.get("/profile", status=200) @@ -119,14 +119,16 @@ def test_photo_on_profile_creation(testclient, jpeg_photo, logged_admin): status=200 ) - user = models.User.get(user_name="foobar") + user = backend.get(models.User, user_name="foobar") assert user.photo == jpeg_photo user.delete() -def test_photo_deleted_on_profile_creation(testclient, jpeg_photo, logged_admin): +def test_photo_deleted_on_profile_creation( + testclient, jpeg_photo, logged_admin, backend +): res = testclient.get("/users", status=200) - assert models.User.get(user_name="foobar") is None + assert backend.get(models.User, user_name="foobar") is None res.mustcontain(no="foobar") res = testclient.get("/profile", status=200) @@ -140,6 +142,6 @@ def test_photo_deleted_on_profile_creation(testclient, jpeg_photo, logged_admin) status=200 ) - user = models.User.get(user_name="foobar") + user = backend.get(models.User, user_name="foobar") assert user.photo is None user.delete() diff --git a/tests/core/test_profile_settings.py b/tests/core/test_profile_settings.py index 462c1e1b..6c159c34 100644 --- a/tests/core/test_profile_settings.py +++ b/tests/core/test_profile_settings.py @@ -381,7 +381,7 @@ def test_account_locking( res = res.form.submit(name="action", value="confirm-lock") res = res.form.submit(name="action", value="lock") - user = models.User.get(id=user.id) + user = backend.get(models.User, id=user.id) assert user.lock_date <= datetime.datetime.now(datetime.timezone.utc) assert user.locked res.mustcontain("The account has been locked") @@ -389,7 +389,7 @@ def test_account_locking( res.mustcontain("Unlock") res = res.form.submit(name="action", value="unlock") - user = models.User.get(id=user.id) + user = backend.get(models.User, id=user.id) assert not user.lock_date assert not user.locked res.mustcontain("The account has been unlocked") @@ -415,7 +415,7 @@ def test_past_lock_date( assert res.flashes == [("success", "Profile updated successfully.")] res = res.follow() - user = models.User.get(id=user.id) + user = backend.get(models.User, id=user.id) assert user.lock_date == expiration_datetime assert user.locked @@ -438,7 +438,7 @@ def test_future_lock_date( assert res.flashes == [("success", "Profile updated successfully.")] res = res.follow() - user = models.User.get(id=user.id) + user = backend.get(models.User, id=user.id) assert user.lock_date == expiration_datetime assert not user.locked assert res.form["lock_date"].value == expiration_datetime.strftime("%Y-%m-%d %H:%M") @@ -484,7 +484,7 @@ def test_account_limit_values( assert res.flashes == [("success", "Profile updated successfully.")] res = res.follow() - user = models.User.get(id=user.id) + user = backend.get(models.User, id=user.id) assert user.lock_date == expiration_datetime assert not user.locked diff --git a/tests/core/test_registration.py b/tests/core/test_registration.py index 7d41ef95..b445ca74 100644 --- a/tests/core/test_registration.py +++ b/tests/core/test_registration.py @@ -22,7 +22,7 @@ def test_registration_without_email_validation(testclient, backend, foo_group): res = res.form.submit() assert ("success", "Your account has been created successfully.") in res.flashes - user = models.User.get(user_name="newuser") + user = backend.get(models.User, user_name="newuser") assert user user.delete() @@ -73,7 +73,7 @@ def test_registration_with_email_validation(testclient, backend, smtpd, foo_grou ("success", "Your account has been created successfully."), ] - user = models.User.get(user_name="newuser") + user = backend.get(models.User, user_name="newuser") assert user user.delete() diff --git a/tests/oidc/commands/test_clean.py b/tests/oidc/commands/test_clean.py index 1dbecae7..69daf252 100644 --- a/tests/oidc/commands/test_clean.py +++ b/tests/oidc/commands/test_clean.py @@ -69,8 +69,8 @@ def test_clean_command(testclient, backend, client, user): ) expired_token.save() - assert models.AuthorizationCode.get(code="my-expired-code") - assert models.Token.get(access_token="my-expired-token") + assert backend.get(models.AuthorizationCode, code="my-expired-code") + assert backend.get(models.Token, access_token="my-expired-token") assert expired_code.is_expired() assert expired_token.is_expired() @@ -78,5 +78,5 @@ def test_clean_command(testclient, backend, client, user): res = runner.invoke(cli, ["clean"]) assert res.exit_code == 0, res.stdout - assert models.AuthorizationCode.get() == valid_code - assert models.Token.get() == valid_token + assert backend.get(models.AuthorizationCode) == valid_code + assert backend.get(models.Token) == valid_token diff --git a/tests/oidc/test_authorization_code_flow.py b/tests/oidc/test_authorization_code_flow.py index b9d5f45d..9c8d30e5 100644 --- a/tests/oidc/test_authorization_code_flow.py +++ b/tests/oidc/test_authorization_code_flow.py @@ -34,7 +34,7 @@ def test_nominal_case( assert res.location.startswith(client.redirect_uris[0]) params = parse_qs(urlsplit(res.location).query) code = params["code"][0] - authcode = models.AuthorizationCode.get(code=code) + authcode = backend.get(models.AuthorizationCode, code=code) assert authcode is not None assert set(authcode.scope) == { "openid", @@ -68,7 +68,7 @@ def test_nominal_case( ) access_token = res.json["access_token"] - token = models.Token.get(access_token=access_token) + token = backend.get(models.Token, access_token=access_token) assert token.client == client assert token.subject == logged_user assert set(token.scope) == { @@ -136,7 +136,7 @@ def test_redirect_uri( assert res.location.startswith(client.redirect_uris[1]) params = parse_qs(urlsplit(res.location).query) code = params["code"][0] - authcode = models.AuthorizationCode.get(code=code) + authcode = backend.get(models.AuthorizationCode, code=code) assert authcode is not None consents = backend.query(models.Consent, client=client, subject=logged_user) @@ -153,7 +153,7 @@ def test_redirect_uri( ) access_token = res.json["access_token"] - token = models.Token.get(access_token=access_token) + token = backend.get(models.Token, access_token=access_token) assert token.client == client assert token.subject == logged_user @@ -183,7 +183,7 @@ def test_preconsented_client( assert res.location.startswith(client.redirect_uris[0]) params = parse_qs(urlsplit(res.location).query) code = params["code"][0] - authcode = models.AuthorizationCode.get(code=code) + authcode = backend.get(models.AuthorizationCode, code=code) assert authcode is not None consents = backend.query(models.Consent, client=client, subject=logged_user) @@ -202,7 +202,7 @@ def test_preconsented_client( ) access_token = res.json["access_token"] - token = models.Token.get(access_token=access_token) + token = backend.get(models.Token, access_token=access_token) assert token.client == client assert token.subject == logged_user @@ -257,7 +257,7 @@ def test_logout_login(testclient, logged_user, client, backend): assert res.location.startswith(client.redirect_uris[0]) params = parse_qs(urlsplit(res.location).query) code = params["code"][0] - authcode = models.AuthorizationCode.get(code=code) + authcode = backend.get(models.AuthorizationCode, code=code) assert authcode is not None consents = backend.query(models.Consent, client=client, subject=logged_user) @@ -276,7 +276,7 @@ def test_logout_login(testclient, logged_user, client, backend): ) access_token = res.json["access_token"] - token = models.Token.get(access_token=access_token) + token = backend.get(models.Token, access_token=access_token) assert token.client == client assert token.subject == logged_user @@ -341,7 +341,7 @@ def test_code_challenge(testclient, logged_user, client, backend): assert res.location.startswith(client.redirect_uris[0]) params = parse_qs(urlsplit(res.location).query) code = params["code"][0] - authcode = models.AuthorizationCode.get(code=code) + authcode = backend.get(models.AuthorizationCode, code=code) assert authcode is not None consents = backend.query(models.Consent, client=client, subject=logged_user) @@ -361,7 +361,7 @@ def test_code_challenge(testclient, logged_user, client, backend): ) access_token = res.json["access_token"] - token = models.Token.get(access_token=access_token) + token = backend.get(models.Token, access_token=access_token) assert token.client == client assert token.subject == logged_user @@ -398,7 +398,7 @@ def test_consent_already_given(testclient, logged_user, client, backend): assert res.location.startswith(client.redirect_uris[0]) params = parse_qs(urlsplit(res.location).query) code = params["code"][0] - authcode = models.AuthorizationCode.get(code=code) + authcode = backend.get(models.AuthorizationCode, code=code) assert authcode is not None consents = backend.query(models.Consent, client=client, subject=logged_user) @@ -456,7 +456,7 @@ def test_when_consent_already_given_but_for_a_smaller_scope( assert res.location.startswith(client.redirect_uris[0]) params = parse_qs(urlsplit(res.location).query) code = params["code"][0] - authcode = models.AuthorizationCode.get(code=code) + authcode = backend.get(models.AuthorizationCode, code=code) assert authcode is not None consents = backend.query(models.Consent, client=client, subject=logged_user) @@ -492,7 +492,7 @@ def test_when_consent_already_given_but_for_a_smaller_scope( assert res.location.startswith(client.redirect_uris[0]) params = parse_qs(urlsplit(res.location).query) code = params["code"][0] - authcode = models.AuthorizationCode.get(code=code) + authcode = backend.get(models.AuthorizationCode, code=code) assert authcode is not None consents = backend.query(models.Consent, client=client, subject=logged_user) @@ -582,7 +582,7 @@ def test_request_scope_too_large(testclient, logged_user, keypair, client, backe params = parse_qs(urlsplit(res.location).query) code = params["code"][0] - authcode = models.AuthorizationCode.get(code=code) + authcode = backend.get(models.AuthorizationCode, code=code) assert set(authcode.scope) == { "openid", "profile", @@ -607,7 +607,7 @@ def test_request_scope_too_large(testclient, logged_user, keypair, client, backe ) access_token = res.json["access_token"] - token = models.Token.get(access_token=access_token) + token = backend.get(models.Token, access_token=access_token) assert token.client == client assert token.subject == logged_user assert set(token.scope) == { @@ -671,7 +671,7 @@ def test_code_expired(testclient, user, client): } -def test_code_with_invalid_user(testclient, admin, client): +def test_code_with_invalid_user(testclient, admin, client, backend): user = models.User( formatted_name="John Doe", family_name="Doe", @@ -699,7 +699,7 @@ def test_code_with_invalid_user(testclient, admin, client): res = res.form.submit(name="answer", value="accept", status=302) params = parse_qs(urlsplit(res.location).query) code = params["code"][0] - authcode = models.AuthorizationCode.get(code=code) + authcode = backend.get(models.AuthorizationCode, code=code) user.delete() @@ -721,7 +721,9 @@ def test_code_with_invalid_user(testclient, admin, client): authcode.delete() -def test_locked_account(testclient, logged_user, client, keypair, trusted_client): +def test_locked_account( + testclient, logged_user, client, keypair, trusted_client, backend +): """Users with a locked account should not be able to exchange code against tokens.""" res = testclient.get( @@ -743,7 +745,7 @@ def test_locked_account(testclient, logged_user, client, keypair, trusted_client assert res.location.startswith(client.redirect_uris[0]) params = parse_qs(urlsplit(res.location).query) code = params["code"][0] - authcode = models.AuthorizationCode.get(code=code) + authcode = backend.get(models.AuthorizationCode, code=code) assert authcode is not None res = testclient.post( diff --git a/tests/oidc/test_client_admin.py b/tests/oidc/test_client_admin.py index b5c4a3bd..c42bf219 100644 --- a/tests/oidc/test_client_admin.py +++ b/tests/oidc/test_client_admin.py @@ -83,7 +83,7 @@ def test_client_list_search(testclient, logged_admin, client, trusted_client): res.mustcontain(no=client.client_name) -def test_client_add(testclient, logged_admin): +def test_client_add(testclient, logged_admin, backend): res = testclient.get("/admin/client/add") data = { "client_name": "foobar", @@ -112,7 +112,7 @@ def test_client_add(testclient, logged_admin): res = res.follow(status=200) client_id = res.forms["readonly"]["client_id"].value - client = models.Client.get(client_id=client_id) + client = backend.get(models.Client, client_id=client_id) assert client.client_name == "foobar" assert client.contacts == ["foo@bar.com"] @@ -214,7 +214,7 @@ def test_client_edit_missing_fields(testclient, client, logged_admin, trusted_cl assert client.client_name -def test_client_delete(testclient, logged_admin): +def test_client_delete(testclient, logged_admin, backend): client = models.Client(client_id="client_id") client.save() token = models.Token( @@ -238,10 +238,10 @@ def test_client_delete(testclient, logged_admin): res = res.form.submit(name="action", value="delete") res = res.follow() - assert not models.Client.get() - assert not models.Token.get() - assert not models.AuthorizationCode.get() - assert not models.Consent.get() + assert not backend.get(models.Client) + assert not backend.get(models.Token) + assert not backend.get(models.AuthorizationCode) + assert not backend.get(models.Consent) def test_client_delete_invalid_client(testclient, logged_admin, client): diff --git a/tests/oidc/test_consent.py b/tests/oidc/test_consent.py index b4547a3a..16787021 100644 --- a/tests/oidc/test_consent.py +++ b/tests/oidc/test_consent.py @@ -134,7 +134,7 @@ def test_oidc_authorization_after_revokation( ) access_token = res.json["access_token"] - token = models.Token.get(access_token=access_token) + token = backend.get(models.Token, access_token=access_token) assert token.client == client assert token.subject == logged_user @@ -151,16 +151,16 @@ def test_preconsented_client_appears_in_consent_list(testclient, client, logged_ res.mustcontain(client.client_name) -def test_revoke_preconsented_client(testclient, client, logged_user, token): +def test_revoke_preconsented_client(testclient, client, logged_user, token, backend): client.preconsent = True client.save() - assert not models.Consent.get() + assert not backend.get(models.Consent) assert not token.revoked res = testclient.get(f"/consent/revoke-preconsent/{client.client_id}", status=302) assert ("success", "The access has been revoked") in res.flashes - consent = models.Consent.get() + consent = backend.get(models.Consent) assert consent.client == client assert consent.subject == logged_user assert consent.scope == ["openid", "email", "profile", "groups", "address", "phone"] diff --git a/tests/oidc/test_dynamic_client_registration.py b/tests/oidc/test_dynamic_client_registration.py index 9e9d336f..5754cb70 100644 --- a/tests/oidc/test_dynamic_client_registration.py +++ b/tests/oidc/test_dynamic_client_registration.py @@ -33,7 +33,7 @@ def test_client_registration_with_authentication_static_token( headers = {"Authorization": "Bearer static-token"} res = testclient.post_json("/oauth/register", payload, headers=headers, status=201) - client = models.Client.get(client_id=res.json["client_id"]) + client = backend.get(models.Client, client_id=res.json["client_id"]) assert res.json == { "client_id": client.client_id, @@ -154,7 +154,7 @@ def test_client_registration_with_software_statement(testclient, backend, keypai } res = testclient.post_json("/oauth/register", payload, status=201) - client = models.Client.get(client_id=res.json["client_id"]) + client = backend.get(models.Client, client_id=res.json["client_id"]) assert res.json == { "client_id": client.client_id, "client_secret": client.client_secret, @@ -205,7 +205,7 @@ def test_client_registration_without_authentication_ok(testclient, backend): res = testclient.post_json("/oauth/register", payload, status=201) - client = models.Client.get(client_id=res.json["client_id"]) + client = backend.get(models.Client, client_id=res.json["client_id"]) assert res.json == { "client_id": mock.ANY, "client_secret": mock.ANY, diff --git a/tests/oidc/test_dynamic_client_registration_management.py b/tests/oidc/test_dynamic_client_registration_management.py index 08594027..3e6bd3f5 100644 --- a/tests/oidc/test_dynamic_client_registration_management.py +++ b/tests/oidc/test_dynamic_client_registration_management.py @@ -95,7 +95,7 @@ def test_update(testclient, backend, client, user): res = testclient.put_json( f"/oauth/register/{client.client_id}", payload, headers=headers, status=200 ) - client = models.Client.get(client_id=res.json["client_id"]) + client = backend.get(models.Client, client_id=res.json["client_id"]) assert res.json == { "client_id": client.client_id, @@ -153,7 +153,7 @@ def test_delete(testclient, backend, user): testclient.delete( f"/oauth/register/{client.client_id}", headers=headers, status=204 ) - assert not models.Client.get(client_id=client.client_id) + assert not backend.get(models.Client, client_id=client.client_id) def test_invalid_client(testclient, backend, user): diff --git a/tests/oidc/test_forms.py b/tests/oidc/test_forms.py index dae00e7e..a3924e85 100644 --- a/tests/oidc/test_forms.py +++ b/tests/oidc/test_forms.py @@ -33,7 +33,7 @@ def test_fieldlist_add(testclient, logged_admin, backend): res = res.follow(status=200) client_id = res.forms["readonly"]["client_id"].value - client = models.Client.get(client_id=client_id) + client = backend.get(models.Client, client_id=client_id) assert client.redirect_uris == [ "https://foo.bar/callback", @@ -68,7 +68,7 @@ def test_fieldlist_delete(testclient, logged_admin, backend): res = res.follow(status=200) client_id = res.forms["readonly"]["client_id"].value - client = models.Client.get(client_id=client_id) + client = backend.get(models.Client, client_id=client_id) assert client.redirect_uris == [ "https://foo.bar/callback1", @@ -128,7 +128,7 @@ def test_fieldlist_duplicate_value(testclient, logged_admin, client): res.mustcontain("This value is a duplicate") -def test_fieldlist_empty_value(testclient, logged_admin): +def test_fieldlist_empty_value(testclient, logged_admin, backend): res = testclient.get("/admin/client/add") data = { "client_name": "foobar", @@ -145,7 +145,7 @@ def test_fieldlist_empty_value(testclient, logged_admin): status=200, name="fieldlist_add", value="post_logout_redirect_uris-0" ) res.form.submit(status=302, name="action", value="edit") - client = models.Client.get() + client = backend.get(models.Client) client.delete() diff --git a/tests/oidc/test_hybrid_flow.py b/tests/oidc/test_hybrid_flow.py index 92c15762..3f37a24b 100644 --- a/tests/oidc/test_hybrid_flow.py +++ b/tests/oidc/test_hybrid_flow.py @@ -32,11 +32,11 @@ def test_oauth_hybrid(testclient, backend, user, client): params = parse_qs(urlsplit(res.location).fragment) code = params["code"][0] - authcode = models.AuthorizationCode.get(code=code) + authcode = backend.get(models.AuthorizationCode, code=code) assert authcode is not None access_token = params["access_token"][0] - token = models.Token.get(access_token=access_token) + token = backend.get(models.Token, access_token=access_token) assert token is not None res = testclient.get( @@ -65,11 +65,11 @@ def test_oidc_hybrid(testclient, backend, logged_user, client, keypair, trusted_ params = parse_qs(urlsplit(res.location).fragment) code = params["code"][0] - authcode = models.AuthorizationCode.get(code=code) + authcode = backend.get(models.AuthorizationCode, code=code) assert authcode is not None access_token = params["access_token"][0] - token = models.Token.get(access_token=access_token) + token = backend.get(models.Token, access_token=access_token) assert token is not None id_token = params["id_token"][0] diff --git a/tests/oidc/test_implicit_flow.py b/tests/oidc/test_implicit_flow.py index 87f125bf..ba1a08ab 100644 --- a/tests/oidc/test_implicit_flow.py +++ b/tests/oidc/test_implicit_flow.py @@ -6,7 +6,7 @@ from authlib.jose import jwt from canaille.app import models -def test_oauth_implicit(testclient, user, client): +def test_oauth_implicit(testclient, user, client, backend): client.grant_types = ["token"] client.token_endpoint_auth_method = "none" @@ -37,7 +37,7 @@ def test_oauth_implicit(testclient, user, client): params = parse_qs(urlsplit(res.location).fragment) access_token = params["access_token"][0] - token = models.Token.get(access_token=access_token) + token = backend.get(models.Token, access_token=access_token) assert token is not None res = testclient.get( @@ -51,7 +51,7 @@ def test_oauth_implicit(testclient, user, client): client.save() -def test_oidc_implicit(testclient, keypair, user, client, trusted_client): +def test_oidc_implicit(testclient, keypair, user, client, trusted_client, backend): client.grant_types = ["token id_token"] client.token_endpoint_auth_method = "none" @@ -82,7 +82,7 @@ def test_oidc_implicit(testclient, keypair, user, client, trusted_client): params = parse_qs(urlsplit(res.location).fragment) access_token = params["access_token"][0] - token = models.Token.get(access_token=access_token) + token = backend.get(models.Token, access_token=access_token) assert token is not None id_token = params["id_token"][0] @@ -105,7 +105,7 @@ def test_oidc_implicit(testclient, keypair, user, client, trusted_client): def test_oidc_implicit_with_group( - testclient, keypair, user, client, foo_group, trusted_client + testclient, keypair, user, client, foo_group, trusted_client, backend ): client.grant_types = ["token id_token"] client.token_endpoint_auth_method = "none" @@ -137,7 +137,7 @@ def test_oidc_implicit_with_group( params = parse_qs(urlsplit(res.location).fragment) access_token = params["access_token"][0] - token = models.Token.get(access_token=access_token) + token = backend.get(models.Token, access_token=access_token) assert token is not None id_token = params["id_token"][0] diff --git a/tests/oidc/test_password_flow.py b/tests/oidc/test_password_flow.py index 47c49ca0..daaad5d3 100644 --- a/tests/oidc/test_password_flow.py +++ b/tests/oidc/test_password_flow.py @@ -3,7 +3,7 @@ from canaille.app import models from . import client_credentials -def test_password_flow_basic(testclient, user, client): +def test_password_flow_basic(testclient, user, client, backend): res = testclient.post( "/oauth/token", params=dict( @@ -20,7 +20,7 @@ def test_password_flow_basic(testclient, user, client): assert res.json["token_type"] == "Bearer" access_token = res.json["access_token"] - token = models.Token.get(access_token=access_token) + token = backend.get(models.Token, access_token=access_token) assert token is not None res = testclient.get( @@ -31,7 +31,7 @@ def test_password_flow_basic(testclient, user, client): assert res.json["name"] == "John (johnny) Doe" -def test_password_flow_post(testclient, user, client): +def test_password_flow_post(testclient, user, client, backend): client.token_endpoint_auth_method = "client_secret_post" client.save() @@ -52,7 +52,7 @@ def test_password_flow_post(testclient, user, client): assert res.json["token_type"] == "Bearer" access_token = res.json["access_token"] - token = models.Token.get(access_token=access_token) + token = backend.get(models.Token, access_token=access_token) assert token is not None res = testclient.get( diff --git a/tests/oidc/test_refresh_token.py b/tests/oidc/test_refresh_token.py index 548ecd42..7a99bcd7 100644 --- a/tests/oidc/test_refresh_token.py +++ b/tests/oidc/test_refresh_token.py @@ -24,7 +24,7 @@ def test_refresh_token(testclient, logged_user, client, backend): assert res.location.startswith(client.redirect_uris[0]) params = parse_qs(urlsplit(res.location).query) code = params["code"][0] - authcode = models.AuthorizationCode.get(code=code) + authcode = backend.get(models.AuthorizationCode, code=code) assert authcode is not None consents = backend.query(models.Consent, client=client, subject=logged_user) @@ -42,7 +42,7 @@ def test_refresh_token(testclient, logged_user, client, backend): status=200, ) access_token = res.json["access_token"] - old_token = models.Token.get(access_token=access_token) + old_token = backend.get(models.Token, access_token=access_token) assert old_token is not None assert not old_token.revokation_date @@ -56,7 +56,7 @@ def test_refresh_token(testclient, logged_user, client, backend): status=200, ) access_token = res.json["access_token"] - new_token = models.Token.get(access_token=access_token) + new_token = backend.get(models.Token, access_token=access_token) assert new_token is not None assert old_token.access_token != new_token.access_token @@ -74,7 +74,7 @@ def test_refresh_token(testclient, logged_user, client, backend): consent.delete() -def test_refresh_token_with_invalid_user(testclient, client): +def test_refresh_token_with_invalid_user(testclient, client, backend): user = models.User( formatted_name="John Doe", family_name="Doe", @@ -103,7 +103,7 @@ def test_refresh_token_with_invalid_user(testclient, client): params = parse_qs(urlsplit(res.location).query) code = params["code"][0] - models.AuthorizationCode.get(code=code) + backend.get(models.AuthorizationCode, code=code) res = testclient.post( "/oauth/token", @@ -134,7 +134,7 @@ def test_refresh_token_with_invalid_user(testclient, client): "error": "invalid_request", "error_description": 'There is no "user" for this token.', } - models.Token.get(access_token=access_token).delete() + backend.get(models.Token, access_token=access_token).delete() def test_cannot_refresh_token_for_locked_users(testclient, logged_user, client): diff --git a/tests/oidc/test_token_expiration.py b/tests/oidc/test_token_expiration.py index d4883dcd..dc212ca0 100644 --- a/tests/oidc/test_token_expiration.py +++ b/tests/oidc/test_token_expiration.py @@ -26,7 +26,7 @@ def test_token_default_expiration_date( res = res.form.submit(name="answer", value="accept", status=302) params = parse_qs(urlsplit(res.location).query) code = params["code"][0] - authcode = models.AuthorizationCode.get(code=code) + authcode = backend.get(models.AuthorizationCode, code=code) assert authcode.lifetime == 84400 res = testclient.post( @@ -44,7 +44,7 @@ def test_token_default_expiration_date( assert res.json["expires_in"] == 864000 access_token = res.json["access_token"] - token = models.Token.get(access_token=access_token) + token = backend.get(models.Token, access_token=access_token) assert token.lifetime == 864000 claims = jwt.decode(access_token, keypair[1]) @@ -86,7 +86,7 @@ def test_token_custom_expiration_date( res = res.form.submit(name="answer", value="accept", status=302) params = parse_qs(urlsplit(res.location).query) code = params["code"][0] - authcode = models.AuthorizationCode.get(code=code) + authcode = backend.get(models.AuthorizationCode, code=code) assert authcode.lifetime == 84400 res = testclient.post( @@ -104,7 +104,7 @@ def test_token_custom_expiration_date( assert res.json["expires_in"] == 1000 access_token = res.json["access_token"] - token = models.Token.get(access_token=access_token) + token = backend.get(models.Token, access_token=access_token) assert token.lifetime == 1000 claims = jwt.decode(access_token, keypair[1]) diff --git a/tests/oidc/test_token_introspection.py b/tests/oidc/test_token_introspection.py index c3334a6b..62a4984b 100644 --- a/tests/oidc/test_token_introspection.py +++ b/tests/oidc/test_token_introspection.py @@ -58,7 +58,7 @@ def test_token_invalid(testclient, client): assert {"active": False} == res.json -def test_full_flow(testclient, logged_user, client, user, trusted_client): +def test_full_flow(testclient, logged_user, client, user, trusted_client, backend): res = testclient.get( "/oauth/authorize", params=dict( @@ -75,7 +75,7 @@ def test_full_flow(testclient, logged_user, client, user, trusted_client): assert res.location.startswith(client.redirect_uris[0]) params = parse_qs(urlsplit(res.location).query) code = params["code"][0] - authcode = models.AuthorizationCode.get(code=code) + authcode = backend.get(models.AuthorizationCode, code=code) assert authcode is not None res = testclient.post( @@ -91,7 +91,7 @@ def test_full_flow(testclient, logged_user, client, user, trusted_client): ) access_token = res.json["access_token"] - token = models.Token.get(access_token=access_token) + token = backend.get(models.Token, access_token=access_token) assert token.client == client assert token.subject == logged_user