diff --git a/canaille/app/forms.py b/canaille/app/forms.py index 6f40a7f5..ed8605c6 100644 --- a/canaille/app/forms.py +++ b/canaille/app/forms.py @@ -15,6 +15,7 @@ from canaille.app.i18n import DEFAULT_LANGUAGE_CODE from canaille.app.i18n import gettext as _ from canaille.app.i18n import locale_selector from canaille.app.i18n import timezone_selector +from canaille.backends import BaseBackend from . import validate_uri from .flask import request_is_htmx @@ -188,7 +189,7 @@ class TableForm(I18NFormMixin, FlaskForm): if self.query.data: self.items = cls.fuzzy(self.query.data, fields, **filter) else: - self.items = cls.query(**filter) + self.items = BaseBackend.get().query(cls, **filter) self.page_size = page_size self.nb_items = len(self.items) diff --git a/canaille/backends/__init__.py b/canaille/backends/__init__.py index 93908f34..0a3ef313 100644 --- a/canaille/backends/__init__.py +++ b/canaille/backends/__init__.py @@ -54,6 +54,25 @@ class BaseBackend: """ raise NotImplementedError() + def query(self, model, **kwargs): + """ + Perform a query on the database and return a collection of instances. + Parameters can be any valid attribute with the expected value: + + >>> backend.query(User, first_name="George") + + If several arguments are passed, the methods only returns the model + instances that return matches all the argument values: + + >>> backend.query(User, first_name="George", last_name="Abitbol") + + If the argument value is a collection, the methods will return the + models that matches any of the values: + + >>> backend.query(User, first_name=["George", "Jane"]) + """ + raise NotImplementedError() + def check_user_password(self, user, password: str) -> bool: """Check if the password matches the user password in the database.""" raise NotImplementedError() diff --git a/canaille/backends/ldap/backend.py b/canaille/backends/ldap/backend.py index b25050ae..8486808a 100644 --- a/canaille/backends/ldap/backend.py +++ b/canaille/backends/ldap/backend.py @@ -16,6 +16,7 @@ from canaille.app.i18n import gettext as _ from canaille.backends import BaseBackend from .utils import listify +from .utils import python_attrs_to_ldap @contextmanager @@ -243,13 +244,64 @@ class Backend(BaseBackend): return result, message def set_user_password(self, user, password): - conn = Backend.get().connection + conn = self.connection conn.passwd_s( user.dn, None, password.encode("utf-8"), ) + def query(self, model, dn=None, filter=None, **kwargs): + from .ldapobjectquery import LDAPObjectQuery + + base = dn + if dn is None: + base = f"{model.base},{model.root_dn}" + elif "=" not in base: + base = ldap.dn.escape_dn_chars(base) + base = f"{model.rdn_attribute}={base},{model.base},{model.root_dn}" + + class_filter = ( + "".join([f"(objectClass={oc})" for oc in model.ldap_object_class]) + if getattr(model, "ldap_object_class") + else "" + ) + if class_filter: + class_filter = f"(|{class_filter})" + + arg_filter = "" + ldap_args = python_attrs_to_ldap( + { + model.python_attribute_to_ldap(name): values + for name, values in kwargs.items() + if values is not None + }, + encode=False, + ) + for key, value in ldap_args.items(): + if len(value) == 1: + escaped_value = ldap.filter.escape_filter_chars(value[0]) + arg_filter += f"({key}={escaped_value})" + + else: + values = [ldap.filter.escape_filter_chars(v) for v in value] + arg_filter += ( + "(|" + "".join([f"({key}={value})" for value in values]) + ")" + ) + + if not filter: + filter = "" + + ldapfilter = f"(&{class_filter}{arg_filter}{filter})" + base = base or f"{model.base},{model.root_dn}" + try: + result = self.connection.search_s( + base, ldap.SCOPE_SUBTREE, ldapfilter or None, ["+", "*"] + ) + except ldap.NO_SUCH_OBJECT: + result = [] + return LDAPObjectQuery(model, result) + def setup_ldap_models(config): from canaille.app import models diff --git a/canaille/backends/ldap/ldapobject.py b/canaille/backends/ldap/ldapobject.py index fbd091ad..df41fabe 100644 --- a/canaille/backends/ldap/ldapobject.py +++ b/canaille/backends/ldap/ldapobject.py @@ -4,42 +4,15 @@ import ldap.dn import ldap.filter from ldap.controls.readentry import PostReadControl +from canaille.backends import BaseBackend from canaille.backends.models import BackendModel from .backend import Backend -from .ldapobjectquery import LDAPObjectQuery +from .utils import attribute_ldap_syntax from .utils import cardinalize_attribute from .utils import ldap_to_python from .utils import listify -from .utils import python_to_ldap - - -def python_attrs_to_ldap(attrs, encode=True, null_allowed=True): - formatted_attrs = { - name: [ - python_to_ldap(value, attribute_ldap_syntax(name), encode=encode) - for value in listify(values) - ] - for name, values in attrs.items() - } - if not null_allowed: - formatted_attrs = { - key: [value for value in values if value] - for key, values in formatted_attrs.items() - if values - } - return formatted_attrs - - -def attribute_ldap_syntax(attribute_name): - ldap_attrs = LDAPObject.ldap_object_attributes() - if attribute_name not in ldap_attrs: - return None - - if ldap_attrs[attribute_name].syntax: - return ldap_attrs[attribute_name].syntax - - return attribute_ldap_syntax(ldap_attrs[attribute_name].sup[0]) +from .utils import python_attrs_to_ldap class LDAPObjectMetaclass(type): @@ -256,7 +229,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass): @classmethod def get(cls, identifier=None, /, **kwargs): try: - return cls.query(identifier, **kwargs)[0] + return BaseBackend.get().query(cls, identifier, **kwargs)[0] except (IndexError, ldap.NO_SUCH_OBJECT): if identifier and cls.base: return ( @@ -267,58 +240,6 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass): return None - @classmethod - def query(cls, dn=None, filter=None, **kwargs): - conn = Backend.get().connection - - base = dn - if dn is None: - base = f"{cls.base},{cls.root_dn}" - elif "=" not in base: - base = ldap.dn.escape_dn_chars(base) - base = f"{cls.rdn_attribute}={base},{cls.base},{cls.root_dn}" - - class_filter = ( - "".join([f"(objectClass={oc})" for oc in cls.ldap_object_class]) - if getattr(cls, "ldap_object_class") - else "" - ) - if class_filter: - class_filter = f"(|{class_filter})" - - arg_filter = "" - ldap_args = python_attrs_to_ldap( - { - cls.python_attribute_to_ldap(name): values - for name, values in kwargs.items() - if values is not None - }, - encode=False, - ) - for key, value in ldap_args.items(): - if len(value) == 1: - escaped_value = ldap.filter.escape_filter_chars(value[0]) - arg_filter += f"({key}={escaped_value})" - - else: - values = [ldap.filter.escape_filter_chars(v) for v in value] - arg_filter += ( - "(|" + "".join([f"({key}={value})" for value in values]) + ")" - ) - - if not filter: - filter = "" - - ldapfilter = f"(&{class_filter}{arg_filter}{filter})" - base = base or f"{cls.base},{cls.root_dn}" - try: - result = conn.search_s( - base, ldap.SCOPE_SUBTREE, ldapfilter or None, ["+", "*"] - ) - except ldap.NO_SUCH_OBJECT: - result = [] - return LDAPObjectQuery(cls, result) - @classmethod def fuzzy(cls, query, attributes=None, **kwargs): query = ldap.filter.escape_filter_chars(query) @@ -327,7 +248,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass): filter = ( "(|" + "".join(f"({attribute}=*{query}*)" for attribute in attributes) + ")" ) - return cls.query(filter=filter, **kwargs) + return BaseBackend.get().query(cls, filter=filter, **kwargs) @classmethod def update_ldap_attributes(cls): diff --git a/canaille/backends/ldap/utils.py b/canaille/backends/ldap/utils.py index a21a4375..63186976 100644 --- a/canaille/backends/ldap/utils.py +++ b/canaille/backends/ldap/utils.py @@ -95,3 +95,33 @@ def cardinalize_attribute(python_unique, value): return value[0] return [v for v in value if v is not None] + + +def python_attrs_to_ldap(attrs, encode=True, null_allowed=True): + formatted_attrs = { + name: [ + python_to_ldap(value, attribute_ldap_syntax(name), encode=encode) + for value in listify(values) + ] + for name, values in attrs.items() + } + if not null_allowed: + formatted_attrs = { + key: [value for value in values if value] + for key, values in formatted_attrs.items() + if values + } + return formatted_attrs + + +def attribute_ldap_syntax(attribute_name): + from .ldapobject import LDAPObject + + ldap_attrs = LDAPObject.ldap_object_attributes() + if attribute_name not in ldap_attrs: + return None + + if ldap_attrs[attribute_name].syntax: + return ldap_attrs[attribute_name].syntax + + return attribute_ldap_syntax(ldap_attrs[attribute_name].sup[0]) diff --git a/canaille/backends/memory/backend.py b/canaille/backends/memory/backend.py index c0662094..0ba676d9 100644 --- a/canaille/backends/memory/backend.py +++ b/canaille/backends/memory/backend.py @@ -40,3 +40,28 @@ class Backend(BaseBackend): def set_user_password(self, user, password): user.password = password user.save() + + def query(self, model, **kwargs): + # if there is no filter, return all models + if not kwargs: + states = model.index().values() + return [model(**state) for state in states] + + # get the ids from the attribute indexes + ids = { + id + for attribute, values in kwargs.items() + for value in model.serialize(model.listify(values)) + for id in model.attribute_index(attribute).get(value, []) + } + + # get the states from the ids + states = [model.index()[id] for id in ids] + + # initialize instances from the states + instances = [model(**state) for state in states] + for instance in instances: + # TODO: maybe find a way to not initialize the cache in the first place? + instance._cache = {} + + return instances diff --git a/canaille/backends/memory/models.py b/canaille/backends/memory/models.py index 8c9317d0..ac772821 100644 --- a/canaille/backends/memory/models.py +++ b/canaille/backends/memory/models.py @@ -6,6 +6,7 @@ import uuid import canaille.core.models import canaille.oidc.models from canaille.app import models +from canaille.backends import BaseBackend from canaille.backends.models import BackendModel @@ -25,31 +26,6 @@ class MemoryModel(BackendModel): def __repr__(self): return f"<{self.__class__.__name__} id={self.id}>" - @classmethod - def query(cls, **kwargs): - # if there is no filter, return all models - if not kwargs: - states = cls.index().values() - return [cls(**state) for state in states] - - # get the ids from the attribute indexes - ids = { - id - for attribute, values in kwargs.items() - for value in cls.serialize(cls.listify(values)) - for id in cls.attribute_index(attribute).get(value, []) - } - - # get the states from the ids - states = [cls.index()[id] for id in ids] - - # initialize instances from the states - instances = [cls(**state) for state in states] - for instance in instances: - # TODO: maybe find a way to not initialize the cache in the first place? - instance._cache = {} - return instances - @classmethod def index(cls, class_name=None): return MemoryModel.indexes.setdefault(class_name or cls.__name__, {}) @@ -63,7 +39,7 @@ class MemoryModel(BackendModel): @classmethod def fuzzy(cls, query, attributes=None, **kwargs): attributes = attributes or cls.attributes - instances = cls.query(**kwargs) + instances = BaseBackend.get().query(cls, **kwargs) return [ instance @@ -85,7 +61,7 @@ class MemoryModel(BackendModel): or None ) - results = cls.query(**kwargs) + results = BaseBackend.get().query(cls, **kwargs) return results[0] if results else None @classmethod diff --git a/canaille/backends/models.py b/canaille/backends/models.py index c115dd3c..c3237e77 100644 --- a/canaille/backends/models.py +++ b/canaille/backends/models.py @@ -87,37 +87,16 @@ class BackendModel: implemented for every model and for every backend. """ - @classmethod - def query(cls, **kwargs): - """Perform a query on the database and return a collection of - instances. - - Parameters can be any valid attribute with the expected value: - - >>> User.query(first_name="George") - - If several arguments are passed, the methods only returns the model - instances that return matches all the argument values: - - >>> User.query(first_name="George", last_name="Abitbol") - - If the argument value is a collection, the methods will return the - models that matches any of the values: - - >>> User.query(first_name=["George", "Jane"]) - """ - raise NotImplementedError() - @classmethod def fuzzy(cls, query, attributes=None, **kwargs): - """Works like :meth:`~canaille.backends.models.BackendModel.query` but + """Works like :meth:`~canaille.backends.BaseBackend.query` but attribute values loosely be matched.""" raise NotImplementedError() @classmethod def get(cls, identifier=None, **kwargs): - """Works like :meth:`~canaille.backends.models.BackendModel.query` but - return only one element or :py:data:`None` if no item is matching.""" + """Works like :meth:`~canaille.backends.BaseBackend.query` but return + only one element or :py:data:`None` if no item is matching.""" raise NotImplementedError() def save(self): diff --git a/canaille/backends/sql/backend.py b/canaille/backends/sql/backend.py index adf9e415..310d9200 100644 --- a/canaille/backends/sql/backend.py +++ b/canaille/backends/sql/backend.py @@ -1,4 +1,5 @@ from sqlalchemy import create_engine +from sqlalchemy import select from sqlalchemy.orm import Session from sqlalchemy.orm import declarative_base @@ -65,3 +66,15 @@ class Backend(BaseBackend): def set_user_password(self, user, password): user.password = password user.save() + + def query(self, model, **kwargs): + filter = [ + model.attribute_filter(attribute_name, expected_value) + for attribute_name, expected_value in kwargs.items() + ] + return ( + Backend.get() + .db_session.execute(select(model).filter(*filter)) + .scalars() + .all() + ) diff --git a/canaille/backends/sql/models.py b/canaille/backends/sql/models.py index c50a7b5d..5968814e 100644 --- a/canaille/backends/sql/models.py +++ b/canaille/backends/sql/models.py @@ -36,19 +36,6 @@ class SqlAlchemyModel(BackendModel): f"<{self.__class__.__name__} {self.identifier_attribute}={self.identifier}>" ) - @classmethod - def query(cls, **kwargs): - filter = [ - cls.attribute_filter(attribute_name, expected_value) - for attribute_name, expected_value in kwargs.items() - ] - return ( - Backend.get() - .db_session.execute(select(cls).filter(*filter)) - .scalars() - .all() - ) - @classmethod def fuzzy(cls, query, attributes=None, **kwargs): attributes = attributes or cls.attributes diff --git a/canaille/core/endpoints/account.py b/canaille/core/endpoints/account.py index a10bb1e1..d597463f 100644 --- a/canaille/core/endpoints/account.py +++ b/canaille/core/endpoints/account.py @@ -86,7 +86,7 @@ def join(): form = JoinForm(request.form or None) if request.form and form.validate(): - if models.User.query(emails=form.email.data): + if BaseBackend.get().query(models.User, emails=form.email.data): flash( _( "You will receive soon an email to continue the registration process." @@ -295,7 +295,10 @@ def registration(data=None, hash=None): if "groups" not in form and payload and payload.groups: form["groups"] = wtforms.SelectMultipleField( _("Groups"), - choices=[(group, group.display_name) for group in models.Group.query()], + choices=[ + (group, group.display_name) + for group in BaseBackend.get().query(models.Group) + ], coerce=IDToModel("Group"), ) set_readonly(form["groups"]) @@ -388,7 +391,7 @@ def email_confirmation(data, hash): ) return redirect(url_for("core.account.index")) - if models.User.query(emails=confirmation_obj.email): + if BaseBackend.get().query(models.User, emails=confirmation_obj.email): flash( _("This address email is already associated with another account."), "error", diff --git a/canaille/core/endpoints/forms.py b/canaille/core/endpoints/forms.py index ad61d59d..375fa78e 100644 --- a/canaille/core/endpoints/forms.py +++ b/canaille/core/endpoints/forms.py @@ -312,7 +312,10 @@ PROFILE_FORM_FIELDS = dict( groups=wtforms.SelectMultipleField( _("Groups"), default=[], - choices=lambda: [(group, group.display_name) for group in models.Group.query()], + choices=lambda: [ + (group, group.display_name) + for group in BaseBackend.get().query(models.Group) + ], render_kw={"placeholder": _("users, admins …")}, coerce=IDToModel("Group"), validators=[non_empty_groups], @@ -333,7 +336,7 @@ def build_profile_form(write_field_names, readonly_field_names, user=None): if PROFILE_FORM_FIELDS.get(name) } - if "groups" in fields and not models.Group.query(): + if "groups" in fields and not BaseBackend.get().query(models.Group): del fields["groups"] if current_app.backend.get().has_account_lockability(): # pragma: no branch @@ -436,7 +439,10 @@ class InvitationForm(Form): ) groups = wtforms.SelectMultipleField( _("Groups"), - choices=lambda: [(group, group.display_name) for group in models.Group.query()], + choices=lambda: [ + (group, group.display_name) + for group in BaseBackend.get().query(models.Group) + ], render_kw={}, coerce=IDToModel("Group"), ) diff --git a/canaille/core/populate.py b/canaille/core/populate.py index f7c38f6d..631b8a0e 100644 --- a/canaille/core/populate.py +++ b/canaille/core/populate.py @@ -5,6 +5,7 @@ from faker.config import AVAILABLE_LOCALES from canaille.app import models from canaille.app.i18n import available_language_codes +from canaille.backends import BaseBackend def fake_users(nb=1): @@ -47,7 +48,7 @@ def fake_users(nb=1): def fake_groups(nb=1, nb_users_max=1): - users = models.User.query() + users = BaseBackend.get().query(models.User) groups = list() fake = faker.Faker(["en_US"]) for _ in range(nb): diff --git a/canaille/oidc/commands.py b/canaille/oidc/commands.py index b8285572..f2442cc9 100644 --- a/canaille/oidc/commands.py +++ b/canaille/oidc/commands.py @@ -3,6 +3,7 @@ from flask.cli import with_appcontext from canaille.app import models from canaille.app.commands import with_backendcontext +from canaille.backends import BaseBackend @click.command() @@ -10,11 +11,11 @@ from canaille.app.commands import with_backendcontext @with_backendcontext def clean(): """Remove expired tokens and authorization codes.""" - for t in models.Token.query(): + for t in BaseBackend.get().query(models.Token): if t.is_expired(): t.delete() - for a in models.AuthorizationCode.query(): + for a in BaseBackend.get().query(models.AuthorizationCode): if a.is_expired(): a.delete() diff --git a/canaille/oidc/endpoints/consents.py b/canaille/oidc/endpoints/consents.py index 1bf2db3b..f6c42d7f 100644 --- a/canaille/oidc/endpoints/consents.py +++ b/canaille/oidc/endpoints/consents.py @@ -10,6 +10,7 @@ from canaille.app import models from canaille.app.flask import user_needed from canaille.app.i18n import gettext as _ from canaille.app.themes import render_template +from canaille.backends import BaseBackend from ..utils import SCOPE_DETAILS @@ -19,13 +20,13 @@ bp = Blueprint("consents", __name__, url_prefix="/consent") @bp.route("/") @user_needed() def consents(user): - consents = models.Consent.query(subject=user) + consents = BaseBackend.get().query(models.Consent, subject=user) clients = {t.client for t in consents} nb_consents = len(consents) nb_preconsents = sum( 1 - for client in models.Client.query() + for client in BaseBackend.get().query(models.Client) if client.preconsent and client not in clients ) @@ -43,11 +44,11 @@ def consents(user): @bp.route("/pre-consents") @user_needed() def pre_consents(user): - consents = models.Consent.query(subject=user) + consents = BaseBackend.get().query(models.Consent, subject=user) clients = {t.client for t in consents} preconsented = [ client - for client in models.Client.query() + for client in BaseBackend.get().query(models.Client) if client.preconsent and client not in clients ] diff --git a/canaille/oidc/endpoints/forms.py b/canaille/oidc/endpoints/forms.py index 4de65a36..317fc991 100644 --- a/canaille/oidc/endpoints/forms.py +++ b/canaille/oidc/endpoints/forms.py @@ -7,6 +7,7 @@ from canaille.app.forms import email_validator from canaille.app.forms import is_uri from canaille.app.forms import unique_values from canaille.app.i18n import lazy_gettext as _ +from canaille.backends import BaseBackend class AuthorizeForm(Form): @@ -18,7 +19,10 @@ class LogoutForm(Form): def client_audiences(): - return [(client, client.client_name) for client in models.Client.query()] + return [ + (client, client.client_name) + for client in BaseBackend.get().query(models.Client) + ] class ClientAddForm(Form): diff --git a/canaille/oidc/endpoints/oauth.py b/canaille/oidc/endpoints/oauth.py index c3bce4b5..36b5e75f 100644 --- a/canaille/oidc/endpoints/oauth.py +++ b/canaille/oidc/endpoints/oauth.py @@ -23,6 +23,7 @@ from canaille.app.flask import logout_user from canaille.app.flask import set_parameter_in_url_query from canaille.app.i18n import gettext as _ from canaille.app.themes import render_template +from canaille.backends import BaseBackend from ..oauth import ClientConfigurationEndpoint from ..oauth import ClientRegistrationEndpoint @@ -109,7 +110,8 @@ def authorize_login(user): def authorize_consent(client, user): requested_scopes = request.args.get("scope", "").split(" ") allowed_scopes = client.get_allowed_scope(requested_scopes).split(" ") - consents = models.Consent.query( + consents = BaseBackend.get().query( + models.Consent, client=client, subject=user, ) diff --git a/canaille/oidc/models.py b/canaille/oidc/models.py index 0c7a84c0..4e32c5af 100644 --- a/canaille/oidc/models.py +++ b/canaille/oidc/models.py @@ -8,6 +8,7 @@ from authlib.oauth2.rfc6749 import TokenMixin from authlib.oauth2.rfc6749 import util from canaille.app import models +from canaille.backends import BaseBackend from .basemodels import AuthorizationCode as BaseAuthorizationCode from .basemodels import Client as BaseClient @@ -95,13 +96,13 @@ class Client(BaseClient, ClientMixin): return metadata def delete(self): - for consent in models.Consent.query(client=self): + for consent in BaseBackend.get().query(models.Consent, client=self): consent.delete() - for code in models.AuthorizationCode.query(client=self): + for code in BaseBackend.get().query(models.AuthorizationCode, client=self): code.delete() - for token in models.Token.query(client=self): + for token in BaseBackend.get().query(models.Token, client=self): token.delete() super().delete() @@ -185,7 +186,8 @@ class Consent(BaseConsent): self.revokation_date = datetime.datetime.now(datetime.timezone.utc) self.save() - tokens = models.Token.query( + tokens = BaseBackend.get().query( + models.Token, client=self.client, subject=self.subject, ) diff --git a/canaille/oidc/oauth.py b/canaille/oidc/oauth.py index b4031ad3..30b57294 100644 --- a/canaille/oidc/oauth.py +++ b/canaille/oidc/oauth.py @@ -112,7 +112,9 @@ def openid_configuration(): def exists_nonce(nonce, req): client = models.Client.get(id=req.client_id) - exists = models.AuthorizationCode.query(client=client, nonce=nonce) + exists = BaseBackend.get().query( + models.AuthorizationCode, client=client, nonce=nonce + ) return bool(exists) @@ -237,7 +239,9 @@ class AuthorizationCodeGrant(_AuthorizationCodeGrant): return save_authorization_code(code, request) def query_authorization_code(self, code, client): - item = models.AuthorizationCode.query(code=code, client=client) + item = BaseBackend.get().query( + models.AuthorizationCode, code=code, client=client + ) if item and not item[0].is_expired(): return item[0] @@ -283,7 +287,7 @@ class RefreshTokenGrant(_RefreshTokenGrant): TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"] def authenticate_refresh_token(self, refresh_token): - token = models.Token.query(refresh_token=refresh_token) + token = BaseBackend.get().query(models.Token, refresh_token=refresh_token) if token and token[0].is_refresh_token_active(): return token[0] diff --git a/tests/backends/ldap/test_utils.py b/tests/backends/ldap/test_utils.py index 83830b75..faec2a55 100644 --- a/tests/backends/ldap/test_utils.py +++ b/tests/backends/ldap/test_utils.py @@ -81,15 +81,15 @@ def test_special_chars_in_rdn(testclient, backend): def test_filter(backend, foo_group, bar_group): - assert models.Group.query(display_name="foo") == [foo_group] - assert models.Group.query(display_name="bar") == [bar_group] + assert backend.query(models.Group, display_name="foo") == [foo_group] + assert backend.query(models.Group, display_name="bar") == [bar_group] - assert models.Group.query(display_name="foo") != 3 + assert backend.query(models.Group, display_name="foo") != 3 - assert models.Group.query(display_name=["foo"]) == [foo_group] - assert models.Group.query(display_name=["bar"]) == [bar_group] + assert backend.query(models.Group, display_name=["foo"]) == [foo_group] + assert backend.query(models.Group, display_name=["bar"]) == [bar_group] - assert set(models.Group.query(display_name=["foo", "bar"])) == { + assert set(backend.query(models.Group, display_name=["foo", "bar"])) == { foo_group, bar_group, } diff --git a/tests/backends/test_models.py b/tests/backends/test_models.py index ef8c1e0d..7ca216e0 100644 --- a/tests/backends/test_models.py +++ b/tests/backends/test_models.py @@ -36,16 +36,16 @@ def test_model_lifecycle(testclient, backend): ) assert not user.id - assert not models.User.query() - assert not models.User.query(id=user.id) - assert not models.User.query(id="invalid") + assert not backend.query(models.User) + assert not backend.query(models.User, id=user.id) + assert not backend.query(models.User, id="invalid") assert not models.User.get(id=user.id) user.save() - assert models.User.query() == [user] - assert models.User.query(id=user.id) == [user] - assert not models.User.query(id="invalid") + assert backend.query(models.User) == [user] + assert backend.query(models.User, id=user.id) == [user] + assert not backend.query(models.User, id="invalid") assert models.User.get(id=user.id) == user user.family_name = "new_family_name" @@ -58,7 +58,7 @@ def test_model_lifecycle(testclient, backend): user.delete() - assert not models.User.query(id=user.id) + assert not backend.query(models.User, id=user.id) assert not models.User.get(id=user.id) user.delete() @@ -143,7 +143,7 @@ def test_model_indexation(testclient, backend): def test_fuzzy_unique_attribute(user, moderator, admin, backend): - assert set(models.User.query()) == {user, moderator, admin} + assert set(backend.query(models.User)) == {user, moderator, admin} assert set(models.User.fuzzy("Jack")) == {moderator} assert set(models.User.fuzzy("Jack", ["formatted_name"])) == {moderator} assert set(models.User.fuzzy("Jack", ["user_name"])) == set() @@ -157,7 +157,7 @@ def test_fuzzy_unique_attribute(user, moderator, admin, backend): def test_fuzzy_multiple_attribute(user, moderator, admin, backend): - assert set(models.User.query()) == {user, moderator, admin} + assert set(backend.query(models.User)) == {user, moderator, admin} assert set(models.User.fuzzy("jack@doe.com")) == {moderator} assert set(models.User.fuzzy("jack@doe.com", ["emails"])) == {moderator} assert set(models.User.fuzzy("jack@doe.com", ["formatted_name"])) == set() @@ -171,8 +171,8 @@ def test_fuzzy_multiple_attribute(user, moderator, admin, backend): def test_model_references(testclient, user, foo_group, admin, bar_group, backend): assert foo_group.members == [user] assert user.groups == [foo_group] - assert foo_group in models.Group.query(members=user) - assert user in models.User.query(groups=foo_group) + assert foo_group in backend.query(models.Group, members=user) + assert user in backend.query(models.User, groups=foo_group) assert user not in bar_group.members assert bar_group not in user.groups diff --git a/tests/core/commands/test_populate.py b/tests/core/commands/test_populate.py index 26989680..93fad3e5 100644 --- a/tests/core/commands/test_populate.py +++ b/tests/core/commands/test_populate.py @@ -6,11 +6,11 @@ from canaille.core.populate import fake_users def test_populate_users(testclient, backend): runner = testclient.app.test_cli_runner() - assert len(models.User.query()) == 0 + assert len(backend.query(models.User)) == 0 res = runner.invoke(cli, ["populate", "--nb", "10", "users"]) assert res.exit_code == 0, res.stdout - assert len(models.User.query()) == 10 - for user in models.User.query(): + assert len(backend.query(models.User)) == 10 + for user in backend.query(models.User): user.delete() @@ -18,13 +18,13 @@ def test_populate_groups(testclient, backend): fake_users(10) runner = testclient.app.test_cli_runner() - assert len(models.Group.query()) == 0 + assert len(backend.query(models.Group)) == 0 res = runner.invoke(cli, ["populate", "--nb", "10", "groups"]) assert res.exit_code == 0, res.stdout - assert len(models.Group.query()) == 10 + assert len(backend.query(models.Group)) == 10 - for group in models.Group.query(): + for group in backend.query(models.Group): group.delete() - for user in models.User.query(): + for user in backend.query(models.User): user.delete() diff --git a/tests/core/test_groups.py b/tests/core/test_groups.py index c8f24744..f0087d27 100644 --- a/tests/core/test_groups.py +++ b/tests/core/test_groups.py @@ -4,7 +4,7 @@ from canaille.core.populate import fake_users def test_no_group(app, backend): - assert models.Group.query() == [] + assert backend.query(models.Group) == [] def test_group_list_pagination(testclient, logged_admin, foo_group): diff --git a/tests/core/test_registration.py b/tests/core/test_registration.py index b989db4b..7d41ef95 100644 --- a/tests/core/test_registration.py +++ b/tests/core/test_registration.py @@ -12,7 +12,7 @@ def test_registration_without_email_validation(testclient, backend, foo_group): testclient.app.config["CANAILLE"]["ENABLE_REGISTRATION"] = True testclient.app.config["CANAILLE"]["EMAIL_CONFIRMATION"] = False - assert not models.User.query(user_name="newuser") + assert not backend.query(models.User, user_name="newuser") res = testclient.get(url_for("core.account.registration"), status=200) res.form["user_name"] = "newuser" res.form["password1"] = "password" @@ -60,7 +60,7 @@ def test_registration_with_email_validation(testclient, backend, smtpd, foo_grou text_mail = smtpd.messages[0].get_payload()[0].get_payload(decode=True).decode() assert registration_url in text_mail - assert not models.User.query(user_name="newuser") + assert not backend.query(models.User, user_name="newuser") with time_machine.travel("2020-01-01 02:01:00+00:00", tick=False): res = testclient.get(registration_url, status=200) res.form["user_name"] = "newuser" diff --git a/tests/oidc/test_authorization_code_flow.py b/tests/oidc/test_authorization_code_flow.py index 6b101e7e..b9d5f45d 100644 --- a/tests/oidc/test_authorization_code_flow.py +++ b/tests/oidc/test_authorization_code_flow.py @@ -13,8 +13,10 @@ from canaille.app import models from . import client_credentials -def test_nominal_case(testclient, logged_user, client, keypair, trusted_client): - assert not models.Consent.query() +def test_nominal_case( + testclient, logged_user, client, keypair, trusted_client, backend +): + assert not backend.query(models.Consent) res = testclient.get( "/oauth/authorize", @@ -43,7 +45,7 @@ def test_nominal_case(testclient, logged_user, client, keypair, trusted_client): "phone", } - consents = models.Consent.query(client=client, subject=logged_user) + consents = backend.query(models.Consent, client=client, subject=logged_user) assert set(consents[0].scope) == { "openid", "profile", @@ -112,8 +114,10 @@ def test_invalid_client(testclient, logged_user, keypair): ) -def test_redirect_uri(testclient, logged_user, client, keypair, trusted_client): - assert not models.Consent.query() +def test_redirect_uri( + testclient, logged_user, client, keypair, trusted_client, backend +): + assert not backend.query(models.Consent) res = testclient.get( "/oauth/authorize", @@ -134,7 +138,7 @@ def test_redirect_uri(testclient, logged_user, client, keypair, trusted_client): code = params["code"][0] authcode = models.AuthorizationCode.get(code=code) assert authcode is not None - consents = models.Consent.query(client=client, subject=logged_user) + consents = backend.query(models.Consent, client=client, subject=logged_user) res = testclient.post( "/oauth/token", @@ -157,8 +161,10 @@ def test_redirect_uri(testclient, logged_user, client, keypair, trusted_client): consent.delete() -def test_preconsented_client(testclient, logged_user, client, keypair, trusted_client): - assert not models.Consent.query() +def test_preconsented_client( + testclient, logged_user, client, keypair, trusted_client, backend +): + assert not backend.query(models.Consent) client.preconsent = True client.save() @@ -180,7 +186,7 @@ def test_preconsented_client(testclient, logged_user, client, keypair, trusted_c authcode = models.AuthorizationCode.get(code=code) assert authcode is not None - consents = models.Consent.query(client=client, subject=logged_user) + consents = backend.query(models.Consent, client=client, subject=logged_user) assert not consents res = testclient.post( @@ -214,8 +220,8 @@ def test_preconsented_client(testclient, logged_user, client, keypair, trusted_c assert res.json["name"] == "John (johnny) Doe" -def test_logout_login(testclient, logged_user, client): - assert not models.Consent.query() +def test_logout_login(testclient, logged_user, client, backend): + assert not backend.query(models.Consent) res = testclient.get( "/oauth/authorize", @@ -254,7 +260,7 @@ def test_logout_login(testclient, logged_user, client): authcode = models.AuthorizationCode.get(code=code) assert authcode is not None - consents = models.Consent.query(client=client, subject=logged_user) + consents = backend.query(models.Consent, client=client, subject=logged_user) assert "profile" in consents[0].scope res = testclient.post( @@ -285,8 +291,8 @@ def test_logout_login(testclient, logged_user, client): consent.delete() -def test_deny(testclient, logged_user, client): - assert not models.Consent.query() +def test_deny(testclient, logged_user, client, backend): + assert not backend.query(models.Consent) res = testclient.get( "/oauth/authorize", @@ -305,11 +311,11 @@ def test_deny(testclient, logged_user, client): error = params["error"][0] assert error == "access_denied" - assert not models.Consent.query() + assert not backend.query(models.Consent) -def test_code_challenge(testclient, logged_user, client): - assert not models.Consent.query() +def test_code_challenge(testclient, logged_user, client, backend): + assert not backend.query(models.Consent) client.token_endpoint_auth_method = "none" client.save() @@ -338,7 +344,7 @@ def test_code_challenge(testclient, logged_user, client): authcode = models.AuthorizationCode.get(code=code) assert authcode is not None - consents = models.Consent.query(client=client, subject=logged_user) + consents = backend.query(models.Consent, client=client, subject=logged_user) assert "profile" in consents[0].scope res = testclient.post( @@ -373,8 +379,8 @@ def test_code_challenge(testclient, logged_user, client): consent.delete() -def test_consent_already_given(testclient, logged_user, client): - assert not models.Consent.query() +def test_consent_already_given(testclient, logged_user, client, backend): + assert not backend.query(models.Consent) res = testclient.get( "/oauth/authorize", @@ -395,7 +401,7 @@ def test_consent_already_given(testclient, logged_user, client): authcode = models.AuthorizationCode.get(code=code) assert authcode is not None - consents = models.Consent.query(client=client, subject=logged_user) + consents = backend.query(models.Consent, client=client, subject=logged_user) assert "profile" in consents[0].scope res = testclient.post( @@ -430,9 +436,9 @@ def test_consent_already_given(testclient, logged_user, client): def test_when_consent_already_given_but_for_a_smaller_scope( - testclient, logged_user, client + testclient, logged_user, client, backend ): - assert not models.Consent.query() + assert not backend.query(models.Consent) res = testclient.get( "/oauth/authorize", @@ -453,7 +459,7 @@ def test_when_consent_already_given_but_for_a_smaller_scope( authcode = models.AuthorizationCode.get(code=code) assert authcode is not None - consents = models.Consent.query(client=client, subject=logged_user) + consents = backend.query(models.Consent, client=client, subject=logged_user) assert "profile" in consents[0].scope assert "groups" not in consents[0].scope @@ -489,7 +495,7 @@ def test_when_consent_already_given_but_for_a_smaller_scope( authcode = models.AuthorizationCode.get(code=code) assert authcode is not None - consents = models.Consent.query(client=client, subject=logged_user) + consents = backend.query(models.Consent, client=client, subject=logged_user) assert "profile" in consents[0].scope assert "groups" in consents[0].scope @@ -535,8 +541,8 @@ def test_nonce_required_in_oidc_requests(testclient, logged_user, client): assert res.json.get("error") == "invalid_request" -def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client): - assert not models.Consent.query() +def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client, backend): + assert not backend.query(models.Consent) testclient.app.config["CANAILLE_OIDC"]["REQUIRE_NONCE"] = False res = testclient.get( @@ -552,12 +558,12 @@ def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client): res = res.form.submit(name="answer", value="accept", status=302) assert res.location.startswith(client.redirect_uris[0]) - for consent in models.Consent.query(): + for consent in backend.query(models.Consent): consent.delete() -def test_request_scope_too_large(testclient, logged_user, keypair, client): - assert not models.Consent.query() +def test_request_scope_too_large(testclient, logged_user, keypair, client, backend): + assert not backend.query(models.Consent) client.scope = ["openid", "profile", "groups"] client.save() @@ -582,7 +588,7 @@ def test_request_scope_too_large(testclient, logged_user, keypair, client): "profile", } - consents = models.Consent.query(client=client, subject=logged_user) + consents = backend.query(models.Consent, client=client, subject=logged_user) assert set(consents[0].scope) == { "openid", "profile", diff --git a/tests/oidc/test_consent.py b/tests/oidc/test_consent.py index fd1f2c19..b4547a3a 100644 --- a/tests/oidc/test_consent.py +++ b/tests/oidc/test_consent.py @@ -95,7 +95,7 @@ def test_someone_else_consent_restoration( def test_oidc_authorization_after_revokation( - testclient, logged_user, client, keypair, consent + testclient, logged_user, client, keypair, consent, backend ): consent.revoke() @@ -114,7 +114,7 @@ def test_oidc_authorization_after_revokation( res = res.form.submit(name="answer", value="accept", status=302) - consents = models.Consent.query(client=client, subject=logged_user) + consents = backend.query(models.Consent, client=client, subject=logged_user) consent.reload() assert consents[0] == consent assert not consent.revoked diff --git a/tests/oidc/test_forms.py b/tests/oidc/test_forms.py index ce76aeac..dae00e7e 100644 --- a/tests/oidc/test_forms.py +++ b/tests/oidc/test_forms.py @@ -5,8 +5,8 @@ from canaille.app import models # forms. -def test_fieldlist_add(testclient, logged_admin): - assert not models.Client.query() +def test_fieldlist_add(testclient, logged_admin, backend): + assert not backend.query(models.Client) res = testclient.get("/admin/client/add") assert "redirect_uris-1" not in res.form.fields @@ -23,7 +23,7 @@ def test_fieldlist_add(testclient, logged_admin): res.form[k].force_value(v) res = res.form.submit(status=200, name="fieldlist_add", value="redirect_uris-0") - assert not models.Client.query() + assert not backend.query(models.Client) data["redirect_uris-1"] = "https://foo.bar/callback2" for k, v in data.items(): @@ -43,8 +43,8 @@ def test_fieldlist_add(testclient, logged_admin): client.delete() -def test_fieldlist_delete(testclient, logged_admin): - assert not models.Client.query() +def test_fieldlist_delete(testclient, logged_admin, backend): + assert not backend.query(models.Client) res = testclient.get("/admin/client/add") data = { @@ -61,7 +61,7 @@ def test_fieldlist_delete(testclient, logged_admin): res.form["redirect_uris-1"] = "https://foo.bar/callback2" res = res.form.submit(status=200, name="fieldlist_remove", value="redirect_uris-1") - assert not models.Client.query() + assert not backend.query(models.Client) assert "redirect_uris-1" not in res.form.fields res = res.form.submit(status=302, name="action", value="edit") @@ -92,8 +92,8 @@ def test_fieldlist_add_invalid_field(testclient, logged_admin): testclient.post("/admin/client/add", data, status=400) -def test_fieldlist_delete_invalid_field(testclient, logged_admin): - assert not models.Client.query() +def test_fieldlist_delete_invalid_field(testclient, logged_admin, backend): + assert not backend.query(models.Client) res = testclient.get("/admin/client/add") data = { diff --git a/tests/oidc/test_refresh_token.py b/tests/oidc/test_refresh_token.py index 3e17bd7e..548ecd42 100644 --- a/tests/oidc/test_refresh_token.py +++ b/tests/oidc/test_refresh_token.py @@ -7,8 +7,8 @@ from canaille.app import models from . import client_credentials -def test_refresh_token(testclient, logged_user, client): - assert not models.Consent.query() +def test_refresh_token(testclient, logged_user, client, backend): + assert not backend.query(models.Consent) res = testclient.get( "/oauth/authorize", @@ -27,7 +27,7 @@ def test_refresh_token(testclient, logged_user, client): authcode = models.AuthorizationCode.get(code=code) assert authcode is not None - consents = models.Consent.query(client=client, subject=logged_user) + consents = backend.query(models.Consent, client=client, subject=logged_user) assert "profile" in consents[0].scope res = testclient.post( diff --git a/tests/oidc/test_token_expiration.py b/tests/oidc/test_token_expiration.py index b92b7afe..d4883dcd 100644 --- a/tests/oidc/test_token_expiration.py +++ b/tests/oidc/test_token_expiration.py @@ -9,7 +9,9 @@ from canaille.oidc.oauth import setup_oauth from . import client_credentials -def test_token_default_expiration_date(testclient, logged_user, client, keypair): +def test_token_default_expiration_date( + testclient, logged_user, client, keypair, backend +): res = testclient.get( "/oauth/authorize", params=dict( @@ -52,12 +54,14 @@ def test_token_default_expiration_date(testclient, logged_user, client, keypair) claims = jwt.decode(id_token, keypair[1]) assert claims["exp"] - claims["iat"] == 3600 - consents = models.Consent.query(client=client, subject=logged_user) + consents = backend.query(models.Consent, client=client, subject=logged_user) for consent in consents: consent.delete() -def test_token_custom_expiration_date(testclient, logged_user, client, keypair): +def test_token_custom_expiration_date( + testclient, logged_user, client, keypair, backend +): testclient.app.config["OAUTH2_TOKEN_EXPIRES_IN"] = { "authorization_code": 1000, "implicit": 2000, @@ -110,6 +114,6 @@ def test_token_custom_expiration_date(testclient, logged_user, client, keypair): claims = jwt.decode(id_token, keypair[1]) assert claims["exp"] - claims["iat"] == 6000 - consents = models.Consent.query(client=client, subject=logged_user) + consents = backend.query(models.Consent, client=client, subject=logged_user) for consent in consents: consent.delete()