refactor: move BackendModel.query to Backend.query

This commit is contained in:
Éloi Rivard 2024-04-10 15:44:11 +02:00
parent 93fa708b1c
commit 8425b2a3b8
No known key found for this signature in database
GPG key ID: 7EDA204EA57DD184
29 changed files with 284 additions and 247 deletions

View file

@ -15,6 +15,7 @@ from canaille.app.i18n import DEFAULT_LANGUAGE_CODE
from canaille.app.i18n import gettext as _
from canaille.app.i18n import locale_selector
from canaille.app.i18n import timezone_selector
from canaille.backends import BaseBackend
from . import validate_uri
from .flask import request_is_htmx
@ -188,7 +189,7 @@ class TableForm(I18NFormMixin, FlaskForm):
if self.query.data:
self.items = cls.fuzzy(self.query.data, fields, **filter)
else:
self.items = cls.query(**filter)
self.items = BaseBackend.get().query(cls, **filter)
self.page_size = page_size
self.nb_items = len(self.items)

View file

@ -54,6 +54,25 @@ class BaseBackend:
"""
raise NotImplementedError()
def query(self, model, **kwargs):
"""
Perform a query on the database and return a collection of instances.
Parameters can be any valid attribute with the expected value:
>>> backend.query(User, first_name="George")
If several arguments are passed, the methods only returns the model
instances that return matches all the argument values:
>>> backend.query(User, first_name="George", last_name="Abitbol")
If the argument value is a collection, the methods will return the
models that matches any of the values:
>>> backend.query(User, first_name=["George", "Jane"])
"""
raise NotImplementedError()
def check_user_password(self, user, password: str) -> bool:
"""Check if the password matches the user password in the database."""
raise NotImplementedError()

View file

@ -16,6 +16,7 @@ from canaille.app.i18n import gettext as _
from canaille.backends import BaseBackend
from .utils import listify
from .utils import python_attrs_to_ldap
@contextmanager
@ -243,13 +244,64 @@ class Backend(BaseBackend):
return result, message
def set_user_password(self, user, password):
conn = Backend.get().connection
conn = self.connection
conn.passwd_s(
user.dn,
None,
password.encode("utf-8"),
)
def query(self, model, dn=None, filter=None, **kwargs):
from .ldapobjectquery import LDAPObjectQuery
base = dn
if dn is None:
base = f"{model.base},{model.root_dn}"
elif "=" not in base:
base = ldap.dn.escape_dn_chars(base)
base = f"{model.rdn_attribute}={base},{model.base},{model.root_dn}"
class_filter = (
"".join([f"(objectClass={oc})" for oc in model.ldap_object_class])
if getattr(model, "ldap_object_class")
else ""
)
if class_filter:
class_filter = f"(|{class_filter})"
arg_filter = ""
ldap_args = python_attrs_to_ldap(
{
model.python_attribute_to_ldap(name): values
for name, values in kwargs.items()
if values is not None
},
encode=False,
)
for key, value in ldap_args.items():
if len(value) == 1:
escaped_value = ldap.filter.escape_filter_chars(value[0])
arg_filter += f"({key}={escaped_value})"
else:
values = [ldap.filter.escape_filter_chars(v) for v in value]
arg_filter += (
"(|" + "".join([f"({key}={value})" for value in values]) + ")"
)
if not filter:
filter = ""
ldapfilter = f"(&{class_filter}{arg_filter}{filter})"
base = base or f"{model.base},{model.root_dn}"
try:
result = self.connection.search_s(
base, ldap.SCOPE_SUBTREE, ldapfilter or None, ["+", "*"]
)
except ldap.NO_SUCH_OBJECT:
result = []
return LDAPObjectQuery(model, result)
def setup_ldap_models(config):
from canaille.app import models

View file

@ -4,42 +4,15 @@ import ldap.dn
import ldap.filter
from ldap.controls.readentry import PostReadControl
from canaille.backends import BaseBackend
from canaille.backends.models import BackendModel
from .backend import Backend
from .ldapobjectquery import LDAPObjectQuery
from .utils import attribute_ldap_syntax
from .utils import cardinalize_attribute
from .utils import ldap_to_python
from .utils import listify
from .utils import python_to_ldap
def python_attrs_to_ldap(attrs, encode=True, null_allowed=True):
formatted_attrs = {
name: [
python_to_ldap(value, attribute_ldap_syntax(name), encode=encode)
for value in listify(values)
]
for name, values in attrs.items()
}
if not null_allowed:
formatted_attrs = {
key: [value for value in values if value]
for key, values in formatted_attrs.items()
if values
}
return formatted_attrs
def attribute_ldap_syntax(attribute_name):
ldap_attrs = LDAPObject.ldap_object_attributes()
if attribute_name not in ldap_attrs:
return None
if ldap_attrs[attribute_name].syntax:
return ldap_attrs[attribute_name].syntax
return attribute_ldap_syntax(ldap_attrs[attribute_name].sup[0])
from .utils import python_attrs_to_ldap
class LDAPObjectMetaclass(type):
@ -256,7 +229,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
@classmethod
def get(cls, identifier=None, /, **kwargs):
try:
return cls.query(identifier, **kwargs)[0]
return BaseBackend.get().query(cls, identifier, **kwargs)[0]
except (IndexError, ldap.NO_SUCH_OBJECT):
if identifier and cls.base:
return (
@ -267,58 +240,6 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
return None
@classmethod
def query(cls, dn=None, filter=None, **kwargs):
conn = Backend.get().connection
base = dn
if dn is None:
base = f"{cls.base},{cls.root_dn}"
elif "=" not in base:
base = ldap.dn.escape_dn_chars(base)
base = f"{cls.rdn_attribute}={base},{cls.base},{cls.root_dn}"
class_filter = (
"".join([f"(objectClass={oc})" for oc in cls.ldap_object_class])
if getattr(cls, "ldap_object_class")
else ""
)
if class_filter:
class_filter = f"(|{class_filter})"
arg_filter = ""
ldap_args = python_attrs_to_ldap(
{
cls.python_attribute_to_ldap(name): values
for name, values in kwargs.items()
if values is not None
},
encode=False,
)
for key, value in ldap_args.items():
if len(value) == 1:
escaped_value = ldap.filter.escape_filter_chars(value[0])
arg_filter += f"({key}={escaped_value})"
else:
values = [ldap.filter.escape_filter_chars(v) for v in value]
arg_filter += (
"(|" + "".join([f"({key}={value})" for value in values]) + ")"
)
if not filter:
filter = ""
ldapfilter = f"(&{class_filter}{arg_filter}{filter})"
base = base or f"{cls.base},{cls.root_dn}"
try:
result = conn.search_s(
base, ldap.SCOPE_SUBTREE, ldapfilter or None, ["+", "*"]
)
except ldap.NO_SUCH_OBJECT:
result = []
return LDAPObjectQuery(cls, result)
@classmethod
def fuzzy(cls, query, attributes=None, **kwargs):
query = ldap.filter.escape_filter_chars(query)
@ -327,7 +248,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
filter = (
"(|" + "".join(f"({attribute}=*{query}*)" for attribute in attributes) + ")"
)
return cls.query(filter=filter, **kwargs)
return BaseBackend.get().query(cls, filter=filter, **kwargs)
@classmethod
def update_ldap_attributes(cls):

View file

@ -95,3 +95,33 @@ def cardinalize_attribute(python_unique, value):
return value[0]
return [v for v in value if v is not None]
def python_attrs_to_ldap(attrs, encode=True, null_allowed=True):
formatted_attrs = {
name: [
python_to_ldap(value, attribute_ldap_syntax(name), encode=encode)
for value in listify(values)
]
for name, values in attrs.items()
}
if not null_allowed:
formatted_attrs = {
key: [value for value in values if value]
for key, values in formatted_attrs.items()
if values
}
return formatted_attrs
def attribute_ldap_syntax(attribute_name):
from .ldapobject import LDAPObject
ldap_attrs = LDAPObject.ldap_object_attributes()
if attribute_name not in ldap_attrs:
return None
if ldap_attrs[attribute_name].syntax:
return ldap_attrs[attribute_name].syntax
return attribute_ldap_syntax(ldap_attrs[attribute_name].sup[0])

View file

@ -40,3 +40,28 @@ class Backend(BaseBackend):
def set_user_password(self, user, password):
user.password = password
user.save()
def query(self, model, **kwargs):
# if there is no filter, return all models
if not kwargs:
states = model.index().values()
return [model(**state) for state in states]
# get the ids from the attribute indexes
ids = {
id
for attribute, values in kwargs.items()
for value in model.serialize(model.listify(values))
for id in model.attribute_index(attribute).get(value, [])
}
# get the states from the ids
states = [model.index()[id] for id in ids]
# initialize instances from the states
instances = [model(**state) for state in states]
for instance in instances:
# TODO: maybe find a way to not initialize the cache in the first place?
instance._cache = {}
return instances

View file

@ -6,6 +6,7 @@ import uuid
import canaille.core.models
import canaille.oidc.models
from canaille.app import models
from canaille.backends import BaseBackend
from canaille.backends.models import BackendModel
@ -25,31 +26,6 @@ class MemoryModel(BackendModel):
def __repr__(self):
return f"<{self.__class__.__name__} id={self.id}>"
@classmethod
def query(cls, **kwargs):
# if there is no filter, return all models
if not kwargs:
states = cls.index().values()
return [cls(**state) for state in states]
# get the ids from the attribute indexes
ids = {
id
for attribute, values in kwargs.items()
for value in cls.serialize(cls.listify(values))
for id in cls.attribute_index(attribute).get(value, [])
}
# get the states from the ids
states = [cls.index()[id] for id in ids]
# initialize instances from the states
instances = [cls(**state) for state in states]
for instance in instances:
# TODO: maybe find a way to not initialize the cache in the first place?
instance._cache = {}
return instances
@classmethod
def index(cls, class_name=None):
return MemoryModel.indexes.setdefault(class_name or cls.__name__, {})
@ -63,7 +39,7 @@ class MemoryModel(BackendModel):
@classmethod
def fuzzy(cls, query, attributes=None, **kwargs):
attributes = attributes or cls.attributes
instances = cls.query(**kwargs)
instances = BaseBackend.get().query(cls, **kwargs)
return [
instance
@ -85,7 +61,7 @@ class MemoryModel(BackendModel):
or None
)
results = cls.query(**kwargs)
results = BaseBackend.get().query(cls, **kwargs)
return results[0] if results else None
@classmethod

View file

@ -87,37 +87,16 @@ class BackendModel:
implemented for every model and for every backend.
"""
@classmethod
def query(cls, **kwargs):
"""Perform a query on the database and return a collection of
instances.
Parameters can be any valid attribute with the expected value:
>>> User.query(first_name="George")
If several arguments are passed, the methods only returns the model
instances that return matches all the argument values:
>>> User.query(first_name="George", last_name="Abitbol")
If the argument value is a collection, the methods will return the
models that matches any of the values:
>>> User.query(first_name=["George", "Jane"])
"""
raise NotImplementedError()
@classmethod
def fuzzy(cls, query, attributes=None, **kwargs):
"""Works like :meth:`~canaille.backends.models.BackendModel.query` but
"""Works like :meth:`~canaille.backends.BaseBackend.query` but
attribute values loosely be matched."""
raise NotImplementedError()
@classmethod
def get(cls, identifier=None, **kwargs):
"""Works like :meth:`~canaille.backends.models.BackendModel.query` but
return only one element or :py:data:`None` if no item is matching."""
"""Works like :meth:`~canaille.backends.BaseBackend.query` but return
only one element or :py:data:`None` if no item is matching."""
raise NotImplementedError()
def save(self):

View file

@ -1,4 +1,5 @@
from sqlalchemy import create_engine
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import declarative_base
@ -65,3 +66,15 @@ class Backend(BaseBackend):
def set_user_password(self, user, password):
user.password = password
user.save()
def query(self, model, **kwargs):
filter = [
model.attribute_filter(attribute_name, expected_value)
for attribute_name, expected_value in kwargs.items()
]
return (
Backend.get()
.db_session.execute(select(model).filter(*filter))
.scalars()
.all()
)

View file

@ -36,19 +36,6 @@ class SqlAlchemyModel(BackendModel):
f"<{self.__class__.__name__} {self.identifier_attribute}={self.identifier}>"
)
@classmethod
def query(cls, **kwargs):
filter = [
cls.attribute_filter(attribute_name, expected_value)
for attribute_name, expected_value in kwargs.items()
]
return (
Backend.get()
.db_session.execute(select(cls).filter(*filter))
.scalars()
.all()
)
@classmethod
def fuzzy(cls, query, attributes=None, **kwargs):
attributes = attributes or cls.attributes

View file

@ -86,7 +86,7 @@ def join():
form = JoinForm(request.form or None)
if request.form and form.validate():
if models.User.query(emails=form.email.data):
if BaseBackend.get().query(models.User, emails=form.email.data):
flash(
_(
"You will receive soon an email to continue the registration process."
@ -295,7 +295,10 @@ def registration(data=None, hash=None):
if "groups" not in form and payload and payload.groups:
form["groups"] = wtforms.SelectMultipleField(
_("Groups"),
choices=[(group, group.display_name) for group in models.Group.query()],
choices=[
(group, group.display_name)
for group in BaseBackend.get().query(models.Group)
],
coerce=IDToModel("Group"),
)
set_readonly(form["groups"])
@ -388,7 +391,7 @@ def email_confirmation(data, hash):
)
return redirect(url_for("core.account.index"))
if models.User.query(emails=confirmation_obj.email):
if BaseBackend.get().query(models.User, emails=confirmation_obj.email):
flash(
_("This address email is already associated with another account."),
"error",

View file

@ -312,7 +312,10 @@ PROFILE_FORM_FIELDS = dict(
groups=wtforms.SelectMultipleField(
_("Groups"),
default=[],
choices=lambda: [(group, group.display_name) for group in models.Group.query()],
choices=lambda: [
(group, group.display_name)
for group in BaseBackend.get().query(models.Group)
],
render_kw={"placeholder": _("users, admins …")},
coerce=IDToModel("Group"),
validators=[non_empty_groups],
@ -333,7 +336,7 @@ def build_profile_form(write_field_names, readonly_field_names, user=None):
if PROFILE_FORM_FIELDS.get(name)
}
if "groups" in fields and not models.Group.query():
if "groups" in fields and not BaseBackend.get().query(models.Group):
del fields["groups"]
if current_app.backend.get().has_account_lockability(): # pragma: no branch
@ -436,7 +439,10 @@ class InvitationForm(Form):
)
groups = wtforms.SelectMultipleField(
_("Groups"),
choices=lambda: [(group, group.display_name) for group in models.Group.query()],
choices=lambda: [
(group, group.display_name)
for group in BaseBackend.get().query(models.Group)
],
render_kw={},
coerce=IDToModel("Group"),
)

View file

@ -5,6 +5,7 @@ from faker.config import AVAILABLE_LOCALES
from canaille.app import models
from canaille.app.i18n import available_language_codes
from canaille.backends import BaseBackend
def fake_users(nb=1):
@ -47,7 +48,7 @@ def fake_users(nb=1):
def fake_groups(nb=1, nb_users_max=1):
users = models.User.query()
users = BaseBackend.get().query(models.User)
groups = list()
fake = faker.Faker(["en_US"])
for _ in range(nb):

View file

@ -3,6 +3,7 @@ from flask.cli import with_appcontext
from canaille.app import models
from canaille.app.commands import with_backendcontext
from canaille.backends import BaseBackend
@click.command()
@ -10,11 +11,11 @@ from canaille.app.commands import with_backendcontext
@with_backendcontext
def clean():
"""Remove expired tokens and authorization codes."""
for t in models.Token.query():
for t in BaseBackend.get().query(models.Token):
if t.is_expired():
t.delete()
for a in models.AuthorizationCode.query():
for a in BaseBackend.get().query(models.AuthorizationCode):
if a.is_expired():
a.delete()

View file

@ -10,6 +10,7 @@ from canaille.app import models
from canaille.app.flask import user_needed
from canaille.app.i18n import gettext as _
from canaille.app.themes import render_template
from canaille.backends import BaseBackend
from ..utils import SCOPE_DETAILS
@ -19,13 +20,13 @@ bp = Blueprint("consents", __name__, url_prefix="/consent")
@bp.route("/")
@user_needed()
def consents(user):
consents = models.Consent.query(subject=user)
consents = BaseBackend.get().query(models.Consent, subject=user)
clients = {t.client for t in consents}
nb_consents = len(consents)
nb_preconsents = sum(
1
for client in models.Client.query()
for client in BaseBackend.get().query(models.Client)
if client.preconsent and client not in clients
)
@ -43,11 +44,11 @@ def consents(user):
@bp.route("/pre-consents")
@user_needed()
def pre_consents(user):
consents = models.Consent.query(subject=user)
consents = BaseBackend.get().query(models.Consent, subject=user)
clients = {t.client for t in consents}
preconsented = [
client
for client in models.Client.query()
for client in BaseBackend.get().query(models.Client)
if client.preconsent and client not in clients
]

View file

@ -7,6 +7,7 @@ from canaille.app.forms import email_validator
from canaille.app.forms import is_uri
from canaille.app.forms import unique_values
from canaille.app.i18n import lazy_gettext as _
from canaille.backends import BaseBackend
class AuthorizeForm(Form):
@ -18,7 +19,10 @@ class LogoutForm(Form):
def client_audiences():
return [(client, client.client_name) for client in models.Client.query()]
return [
(client, client.client_name)
for client in BaseBackend.get().query(models.Client)
]
class ClientAddForm(Form):

View file

@ -23,6 +23,7 @@ from canaille.app.flask import logout_user
from canaille.app.flask import set_parameter_in_url_query
from canaille.app.i18n import gettext as _
from canaille.app.themes import render_template
from canaille.backends import BaseBackend
from ..oauth import ClientConfigurationEndpoint
from ..oauth import ClientRegistrationEndpoint
@ -109,7 +110,8 @@ def authorize_login(user):
def authorize_consent(client, user):
requested_scopes = request.args.get("scope", "").split(" ")
allowed_scopes = client.get_allowed_scope(requested_scopes).split(" ")
consents = models.Consent.query(
consents = BaseBackend.get().query(
models.Consent,
client=client,
subject=user,
)

View file

@ -8,6 +8,7 @@ from authlib.oauth2.rfc6749 import TokenMixin
from authlib.oauth2.rfc6749 import util
from canaille.app import models
from canaille.backends import BaseBackend
from .basemodels import AuthorizationCode as BaseAuthorizationCode
from .basemodels import Client as BaseClient
@ -95,13 +96,13 @@ class Client(BaseClient, ClientMixin):
return metadata
def delete(self):
for consent in models.Consent.query(client=self):
for consent in BaseBackend.get().query(models.Consent, client=self):
consent.delete()
for code in models.AuthorizationCode.query(client=self):
for code in BaseBackend.get().query(models.AuthorizationCode, client=self):
code.delete()
for token in models.Token.query(client=self):
for token in BaseBackend.get().query(models.Token, client=self):
token.delete()
super().delete()
@ -185,7 +186,8 @@ class Consent(BaseConsent):
self.revokation_date = datetime.datetime.now(datetime.timezone.utc)
self.save()
tokens = models.Token.query(
tokens = BaseBackend.get().query(
models.Token,
client=self.client,
subject=self.subject,
)

View file

@ -112,7 +112,9 @@ def openid_configuration():
def exists_nonce(nonce, req):
client = models.Client.get(id=req.client_id)
exists = models.AuthorizationCode.query(client=client, nonce=nonce)
exists = BaseBackend.get().query(
models.AuthorizationCode, client=client, nonce=nonce
)
return bool(exists)
@ -237,7 +239,9 @@ class AuthorizationCodeGrant(_AuthorizationCodeGrant):
return save_authorization_code(code, request)
def query_authorization_code(self, code, client):
item = models.AuthorizationCode.query(code=code, client=client)
item = BaseBackend.get().query(
models.AuthorizationCode, code=code, client=client
)
if item and not item[0].is_expired():
return item[0]
@ -283,7 +287,7 @@ class RefreshTokenGrant(_RefreshTokenGrant):
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]
def authenticate_refresh_token(self, refresh_token):
token = models.Token.query(refresh_token=refresh_token)
token = BaseBackend.get().query(models.Token, refresh_token=refresh_token)
if token and token[0].is_refresh_token_active():
return token[0]

View file

@ -81,15 +81,15 @@ def test_special_chars_in_rdn(testclient, backend):
def test_filter(backend, foo_group, bar_group):
assert models.Group.query(display_name="foo") == [foo_group]
assert models.Group.query(display_name="bar") == [bar_group]
assert backend.query(models.Group, display_name="foo") == [foo_group]
assert backend.query(models.Group, display_name="bar") == [bar_group]
assert models.Group.query(display_name="foo") != 3
assert backend.query(models.Group, display_name="foo") != 3
assert models.Group.query(display_name=["foo"]) == [foo_group]
assert models.Group.query(display_name=["bar"]) == [bar_group]
assert backend.query(models.Group, display_name=["foo"]) == [foo_group]
assert backend.query(models.Group, display_name=["bar"]) == [bar_group]
assert set(models.Group.query(display_name=["foo", "bar"])) == {
assert set(backend.query(models.Group, display_name=["foo", "bar"])) == {
foo_group,
bar_group,
}

View file

@ -36,16 +36,16 @@ def test_model_lifecycle(testclient, backend):
)
assert not user.id
assert not models.User.query()
assert not models.User.query(id=user.id)
assert not models.User.query(id="invalid")
assert not backend.query(models.User)
assert not backend.query(models.User, id=user.id)
assert not backend.query(models.User, id="invalid")
assert not models.User.get(id=user.id)
user.save()
assert models.User.query() == [user]
assert models.User.query(id=user.id) == [user]
assert not models.User.query(id="invalid")
assert backend.query(models.User) == [user]
assert backend.query(models.User, id=user.id) == [user]
assert not backend.query(models.User, id="invalid")
assert models.User.get(id=user.id) == user
user.family_name = "new_family_name"
@ -58,7 +58,7 @@ def test_model_lifecycle(testclient, backend):
user.delete()
assert not models.User.query(id=user.id)
assert not backend.query(models.User, id=user.id)
assert not models.User.get(id=user.id)
user.delete()
@ -143,7 +143,7 @@ def test_model_indexation(testclient, backend):
def test_fuzzy_unique_attribute(user, moderator, admin, backend):
assert set(models.User.query()) == {user, moderator, admin}
assert set(backend.query(models.User)) == {user, moderator, admin}
assert set(models.User.fuzzy("Jack")) == {moderator}
assert set(models.User.fuzzy("Jack", ["formatted_name"])) == {moderator}
assert set(models.User.fuzzy("Jack", ["user_name"])) == set()
@ -157,7 +157,7 @@ def test_fuzzy_unique_attribute(user, moderator, admin, backend):
def test_fuzzy_multiple_attribute(user, moderator, admin, backend):
assert set(models.User.query()) == {user, moderator, admin}
assert set(backend.query(models.User)) == {user, moderator, admin}
assert set(models.User.fuzzy("jack@doe.com")) == {moderator}
assert set(models.User.fuzzy("jack@doe.com", ["emails"])) == {moderator}
assert set(models.User.fuzzy("jack@doe.com", ["formatted_name"])) == set()
@ -171,8 +171,8 @@ def test_fuzzy_multiple_attribute(user, moderator, admin, backend):
def test_model_references(testclient, user, foo_group, admin, bar_group, backend):
assert foo_group.members == [user]
assert user.groups == [foo_group]
assert foo_group in models.Group.query(members=user)
assert user in models.User.query(groups=foo_group)
assert foo_group in backend.query(models.Group, members=user)
assert user in backend.query(models.User, groups=foo_group)
assert user not in bar_group.members
assert bar_group not in user.groups

View file

@ -6,11 +6,11 @@ from canaille.core.populate import fake_users
def test_populate_users(testclient, backend):
runner = testclient.app.test_cli_runner()
assert len(models.User.query()) == 0
assert len(backend.query(models.User)) == 0
res = runner.invoke(cli, ["populate", "--nb", "10", "users"])
assert res.exit_code == 0, res.stdout
assert len(models.User.query()) == 10
for user in models.User.query():
assert len(backend.query(models.User)) == 10
for user in backend.query(models.User):
user.delete()
@ -18,13 +18,13 @@ def test_populate_groups(testclient, backend):
fake_users(10)
runner = testclient.app.test_cli_runner()
assert len(models.Group.query()) == 0
assert len(backend.query(models.Group)) == 0
res = runner.invoke(cli, ["populate", "--nb", "10", "groups"])
assert res.exit_code == 0, res.stdout
assert len(models.Group.query()) == 10
assert len(backend.query(models.Group)) == 10
for group in models.Group.query():
for group in backend.query(models.Group):
group.delete()
for user in models.User.query():
for user in backend.query(models.User):
user.delete()

View file

@ -4,7 +4,7 @@ from canaille.core.populate import fake_users
def test_no_group(app, backend):
assert models.Group.query() == []
assert backend.query(models.Group) == []
def test_group_list_pagination(testclient, logged_admin, foo_group):

View file

@ -12,7 +12,7 @@ def test_registration_without_email_validation(testclient, backend, foo_group):
testclient.app.config["CANAILLE"]["ENABLE_REGISTRATION"] = True
testclient.app.config["CANAILLE"]["EMAIL_CONFIRMATION"] = False
assert not models.User.query(user_name="newuser")
assert not backend.query(models.User, user_name="newuser")
res = testclient.get(url_for("core.account.registration"), status=200)
res.form["user_name"] = "newuser"
res.form["password1"] = "password"
@ -60,7 +60,7 @@ def test_registration_with_email_validation(testclient, backend, smtpd, foo_grou
text_mail = smtpd.messages[0].get_payload()[0].get_payload(decode=True).decode()
assert registration_url in text_mail
assert not models.User.query(user_name="newuser")
assert not backend.query(models.User, user_name="newuser")
with time_machine.travel("2020-01-01 02:01:00+00:00", tick=False):
res = testclient.get(registration_url, status=200)
res.form["user_name"] = "newuser"

View file

@ -13,8 +13,10 @@ from canaille.app import models
from . import client_credentials
def test_nominal_case(testclient, logged_user, client, keypair, trusted_client):
assert not models.Consent.query()
def test_nominal_case(
testclient, logged_user, client, keypair, trusted_client, backend
):
assert not backend.query(models.Consent)
res = testclient.get(
"/oauth/authorize",
@ -43,7 +45,7 @@ def test_nominal_case(testclient, logged_user, client, keypair, trusted_client):
"phone",
}
consents = models.Consent.query(client=client, subject=logged_user)
consents = backend.query(models.Consent, client=client, subject=logged_user)
assert set(consents[0].scope) == {
"openid",
"profile",
@ -112,8 +114,10 @@ def test_invalid_client(testclient, logged_user, keypair):
)
def test_redirect_uri(testclient, logged_user, client, keypair, trusted_client):
assert not models.Consent.query()
def test_redirect_uri(
testclient, logged_user, client, keypair, trusted_client, backend
):
assert not backend.query(models.Consent)
res = testclient.get(
"/oauth/authorize",
@ -134,7 +138,7 @@ def test_redirect_uri(testclient, logged_user, client, keypair, trusted_client):
code = params["code"][0]
authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None
consents = models.Consent.query(client=client, subject=logged_user)
consents = backend.query(models.Consent, client=client, subject=logged_user)
res = testclient.post(
"/oauth/token",
@ -157,8 +161,10 @@ def test_redirect_uri(testclient, logged_user, client, keypair, trusted_client):
consent.delete()
def test_preconsented_client(testclient, logged_user, client, keypair, trusted_client):
assert not models.Consent.query()
def test_preconsented_client(
testclient, logged_user, client, keypair, trusted_client, backend
):
assert not backend.query(models.Consent)
client.preconsent = True
client.save()
@ -180,7 +186,7 @@ def test_preconsented_client(testclient, logged_user, client, keypair, trusted_c
authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None
consents = models.Consent.query(client=client, subject=logged_user)
consents = backend.query(models.Consent, client=client, subject=logged_user)
assert not consents
res = testclient.post(
@ -214,8 +220,8 @@ def test_preconsented_client(testclient, logged_user, client, keypair, trusted_c
assert res.json["name"] == "John (johnny) Doe"
def test_logout_login(testclient, logged_user, client):
assert not models.Consent.query()
def test_logout_login(testclient, logged_user, client, backend):
assert not backend.query(models.Consent)
res = testclient.get(
"/oauth/authorize",
@ -254,7 +260,7 @@ def test_logout_login(testclient, logged_user, client):
authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None
consents = models.Consent.query(client=client, subject=logged_user)
consents = backend.query(models.Consent, client=client, subject=logged_user)
assert "profile" in consents[0].scope
res = testclient.post(
@ -285,8 +291,8 @@ def test_logout_login(testclient, logged_user, client):
consent.delete()
def test_deny(testclient, logged_user, client):
assert not models.Consent.query()
def test_deny(testclient, logged_user, client, backend):
assert not backend.query(models.Consent)
res = testclient.get(
"/oauth/authorize",
@ -305,11 +311,11 @@ def test_deny(testclient, logged_user, client):
error = params["error"][0]
assert error == "access_denied"
assert not models.Consent.query()
assert not backend.query(models.Consent)
def test_code_challenge(testclient, logged_user, client):
assert not models.Consent.query()
def test_code_challenge(testclient, logged_user, client, backend):
assert not backend.query(models.Consent)
client.token_endpoint_auth_method = "none"
client.save()
@ -338,7 +344,7 @@ def test_code_challenge(testclient, logged_user, client):
authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None
consents = models.Consent.query(client=client, subject=logged_user)
consents = backend.query(models.Consent, client=client, subject=logged_user)
assert "profile" in consents[0].scope
res = testclient.post(
@ -373,8 +379,8 @@ def test_code_challenge(testclient, logged_user, client):
consent.delete()
def test_consent_already_given(testclient, logged_user, client):
assert not models.Consent.query()
def test_consent_already_given(testclient, logged_user, client, backend):
assert not backend.query(models.Consent)
res = testclient.get(
"/oauth/authorize",
@ -395,7 +401,7 @@ def test_consent_already_given(testclient, logged_user, client):
authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None
consents = models.Consent.query(client=client, subject=logged_user)
consents = backend.query(models.Consent, client=client, subject=logged_user)
assert "profile" in consents[0].scope
res = testclient.post(
@ -430,9 +436,9 @@ def test_consent_already_given(testclient, logged_user, client):
def test_when_consent_already_given_but_for_a_smaller_scope(
testclient, logged_user, client
testclient, logged_user, client, backend
):
assert not models.Consent.query()
assert not backend.query(models.Consent)
res = testclient.get(
"/oauth/authorize",
@ -453,7 +459,7 @@ def test_when_consent_already_given_but_for_a_smaller_scope(
authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None
consents = models.Consent.query(client=client, subject=logged_user)
consents = backend.query(models.Consent, client=client, subject=logged_user)
assert "profile" in consents[0].scope
assert "groups" not in consents[0].scope
@ -489,7 +495,7 @@ def test_when_consent_already_given_but_for_a_smaller_scope(
authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None
consents = models.Consent.query(client=client, subject=logged_user)
consents = backend.query(models.Consent, client=client, subject=logged_user)
assert "profile" in consents[0].scope
assert "groups" in consents[0].scope
@ -535,8 +541,8 @@ def test_nonce_required_in_oidc_requests(testclient, logged_user, client):
assert res.json.get("error") == "invalid_request"
def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client):
assert not models.Consent.query()
def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client, backend):
assert not backend.query(models.Consent)
testclient.app.config["CANAILLE_OIDC"]["REQUIRE_NONCE"] = False
res = testclient.get(
@ -552,12 +558,12 @@ def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client):
res = res.form.submit(name="answer", value="accept", status=302)
assert res.location.startswith(client.redirect_uris[0])
for consent in models.Consent.query():
for consent in backend.query(models.Consent):
consent.delete()
def test_request_scope_too_large(testclient, logged_user, keypair, client):
assert not models.Consent.query()
def test_request_scope_too_large(testclient, logged_user, keypair, client, backend):
assert not backend.query(models.Consent)
client.scope = ["openid", "profile", "groups"]
client.save()
@ -582,7 +588,7 @@ def test_request_scope_too_large(testclient, logged_user, keypair, client):
"profile",
}
consents = models.Consent.query(client=client, subject=logged_user)
consents = backend.query(models.Consent, client=client, subject=logged_user)
assert set(consents[0].scope) == {
"openid",
"profile",

View file

@ -95,7 +95,7 @@ def test_someone_else_consent_restoration(
def test_oidc_authorization_after_revokation(
testclient, logged_user, client, keypair, consent
testclient, logged_user, client, keypair, consent, backend
):
consent.revoke()
@ -114,7 +114,7 @@ def test_oidc_authorization_after_revokation(
res = res.form.submit(name="answer", value="accept", status=302)
consents = models.Consent.query(client=client, subject=logged_user)
consents = backend.query(models.Consent, client=client, subject=logged_user)
consent.reload()
assert consents[0] == consent
assert not consent.revoked

View file

@ -5,8 +5,8 @@ from canaille.app import models
# forms.
def test_fieldlist_add(testclient, logged_admin):
assert not models.Client.query()
def test_fieldlist_add(testclient, logged_admin, backend):
assert not backend.query(models.Client)
res = testclient.get("/admin/client/add")
assert "redirect_uris-1" not in res.form.fields
@ -23,7 +23,7 @@ def test_fieldlist_add(testclient, logged_admin):
res.form[k].force_value(v)
res = res.form.submit(status=200, name="fieldlist_add", value="redirect_uris-0")
assert not models.Client.query()
assert not backend.query(models.Client)
data["redirect_uris-1"] = "https://foo.bar/callback2"
for k, v in data.items():
@ -43,8 +43,8 @@ def test_fieldlist_add(testclient, logged_admin):
client.delete()
def test_fieldlist_delete(testclient, logged_admin):
assert not models.Client.query()
def test_fieldlist_delete(testclient, logged_admin, backend):
assert not backend.query(models.Client)
res = testclient.get("/admin/client/add")
data = {
@ -61,7 +61,7 @@ def test_fieldlist_delete(testclient, logged_admin):
res.form["redirect_uris-1"] = "https://foo.bar/callback2"
res = res.form.submit(status=200, name="fieldlist_remove", value="redirect_uris-1")
assert not models.Client.query()
assert not backend.query(models.Client)
assert "redirect_uris-1" not in res.form.fields
res = res.form.submit(status=302, name="action", value="edit")
@ -92,8 +92,8 @@ def test_fieldlist_add_invalid_field(testclient, logged_admin):
testclient.post("/admin/client/add", data, status=400)
def test_fieldlist_delete_invalid_field(testclient, logged_admin):
assert not models.Client.query()
def test_fieldlist_delete_invalid_field(testclient, logged_admin, backend):
assert not backend.query(models.Client)
res = testclient.get("/admin/client/add")
data = {

View file

@ -7,8 +7,8 @@ from canaille.app import models
from . import client_credentials
def test_refresh_token(testclient, logged_user, client):
assert not models.Consent.query()
def test_refresh_token(testclient, logged_user, client, backend):
assert not backend.query(models.Consent)
res = testclient.get(
"/oauth/authorize",
@ -27,7 +27,7 @@ def test_refresh_token(testclient, logged_user, client):
authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None
consents = models.Consent.query(client=client, subject=logged_user)
consents = backend.query(models.Consent, client=client, subject=logged_user)
assert "profile" in consents[0].scope
res = testclient.post(

View file

@ -9,7 +9,9 @@ from canaille.oidc.oauth import setup_oauth
from . import client_credentials
def test_token_default_expiration_date(testclient, logged_user, client, keypair):
def test_token_default_expiration_date(
testclient, logged_user, client, keypair, backend
):
res = testclient.get(
"/oauth/authorize",
params=dict(
@ -52,12 +54,14 @@ def test_token_default_expiration_date(testclient, logged_user, client, keypair)
claims = jwt.decode(id_token, keypair[1])
assert claims["exp"] - claims["iat"] == 3600
consents = models.Consent.query(client=client, subject=logged_user)
consents = backend.query(models.Consent, client=client, subject=logged_user)
for consent in consents:
consent.delete()
def test_token_custom_expiration_date(testclient, logged_user, client, keypair):
def test_token_custom_expiration_date(
testclient, logged_user, client, keypair, backend
):
testclient.app.config["OAUTH2_TOKEN_EXPIRES_IN"] = {
"authorization_code": 1000,
"implicit": 2000,
@ -110,6 +114,6 @@ def test_token_custom_expiration_date(testclient, logged_user, client, keypair):
claims = jwt.decode(id_token, keypair[1])
assert claims["exp"] - claims["iat"] == 6000
consents = models.Consent.query(client=client, subject=logged_user)
consents = backend.query(models.Consent, client=client, subject=logged_user)
for consent in consents:
consent.delete()