refactor: move BackendModel.query to Backend.query

This commit is contained in:
Éloi Rivard 2024-04-10 15:44:11 +02:00
parent 93fa708b1c
commit 8425b2a3b8
No known key found for this signature in database
GPG key ID: 7EDA204EA57DD184
29 changed files with 284 additions and 247 deletions

View file

@ -15,6 +15,7 @@ from canaille.app.i18n import DEFAULT_LANGUAGE_CODE
from canaille.app.i18n import gettext as _ from canaille.app.i18n import gettext as _
from canaille.app.i18n import locale_selector from canaille.app.i18n import locale_selector
from canaille.app.i18n import timezone_selector from canaille.app.i18n import timezone_selector
from canaille.backends import BaseBackend
from . import validate_uri from . import validate_uri
from .flask import request_is_htmx from .flask import request_is_htmx
@ -188,7 +189,7 @@ class TableForm(I18NFormMixin, FlaskForm):
if self.query.data: if self.query.data:
self.items = cls.fuzzy(self.query.data, fields, **filter) self.items = cls.fuzzy(self.query.data, fields, **filter)
else: else:
self.items = cls.query(**filter) self.items = BaseBackend.get().query(cls, **filter)
self.page_size = page_size self.page_size = page_size
self.nb_items = len(self.items) self.nb_items = len(self.items)

View file

@ -54,6 +54,25 @@ class BaseBackend:
""" """
raise NotImplementedError() raise NotImplementedError()
def query(self, model, **kwargs):
"""
Perform a query on the database and return a collection of instances.
Parameters can be any valid attribute with the expected value:
>>> backend.query(User, first_name="George")
If several arguments are passed, the methods only returns the model
instances that return matches all the argument values:
>>> backend.query(User, first_name="George", last_name="Abitbol")
If the argument value is a collection, the methods will return the
models that matches any of the values:
>>> backend.query(User, first_name=["George", "Jane"])
"""
raise NotImplementedError()
def check_user_password(self, user, password: str) -> bool: def check_user_password(self, user, password: str) -> bool:
"""Check if the password matches the user password in the database.""" """Check if the password matches the user password in the database."""
raise NotImplementedError() raise NotImplementedError()

View file

@ -16,6 +16,7 @@ from canaille.app.i18n import gettext as _
from canaille.backends import BaseBackend from canaille.backends import BaseBackend
from .utils import listify from .utils import listify
from .utils import python_attrs_to_ldap
@contextmanager @contextmanager
@ -243,13 +244,64 @@ class Backend(BaseBackend):
return result, message return result, message
def set_user_password(self, user, password): def set_user_password(self, user, password):
conn = Backend.get().connection conn = self.connection
conn.passwd_s( conn.passwd_s(
user.dn, user.dn,
None, None,
password.encode("utf-8"), password.encode("utf-8"),
) )
def query(self, model, dn=None, filter=None, **kwargs):
from .ldapobjectquery import LDAPObjectQuery
base = dn
if dn is None:
base = f"{model.base},{model.root_dn}"
elif "=" not in base:
base = ldap.dn.escape_dn_chars(base)
base = f"{model.rdn_attribute}={base},{model.base},{model.root_dn}"
class_filter = (
"".join([f"(objectClass={oc})" for oc in model.ldap_object_class])
if getattr(model, "ldap_object_class")
else ""
)
if class_filter:
class_filter = f"(|{class_filter})"
arg_filter = ""
ldap_args = python_attrs_to_ldap(
{
model.python_attribute_to_ldap(name): values
for name, values in kwargs.items()
if values is not None
},
encode=False,
)
for key, value in ldap_args.items():
if len(value) == 1:
escaped_value = ldap.filter.escape_filter_chars(value[0])
arg_filter += f"({key}={escaped_value})"
else:
values = [ldap.filter.escape_filter_chars(v) for v in value]
arg_filter += (
"(|" + "".join([f"({key}={value})" for value in values]) + ")"
)
if not filter:
filter = ""
ldapfilter = f"(&{class_filter}{arg_filter}{filter})"
base = base or f"{model.base},{model.root_dn}"
try:
result = self.connection.search_s(
base, ldap.SCOPE_SUBTREE, ldapfilter or None, ["+", "*"]
)
except ldap.NO_SUCH_OBJECT:
result = []
return LDAPObjectQuery(model, result)
def setup_ldap_models(config): def setup_ldap_models(config):
from canaille.app import models from canaille.app import models

View file

@ -4,42 +4,15 @@ import ldap.dn
import ldap.filter import ldap.filter
from ldap.controls.readentry import PostReadControl from ldap.controls.readentry import PostReadControl
from canaille.backends import BaseBackend
from canaille.backends.models import BackendModel from canaille.backends.models import BackendModel
from .backend import Backend from .backend import Backend
from .ldapobjectquery import LDAPObjectQuery from .utils import attribute_ldap_syntax
from .utils import cardinalize_attribute from .utils import cardinalize_attribute
from .utils import ldap_to_python from .utils import ldap_to_python
from .utils import listify from .utils import listify
from .utils import python_to_ldap from .utils import python_attrs_to_ldap
def python_attrs_to_ldap(attrs, encode=True, null_allowed=True):
formatted_attrs = {
name: [
python_to_ldap(value, attribute_ldap_syntax(name), encode=encode)
for value in listify(values)
]
for name, values in attrs.items()
}
if not null_allowed:
formatted_attrs = {
key: [value for value in values if value]
for key, values in formatted_attrs.items()
if values
}
return formatted_attrs
def attribute_ldap_syntax(attribute_name):
ldap_attrs = LDAPObject.ldap_object_attributes()
if attribute_name not in ldap_attrs:
return None
if ldap_attrs[attribute_name].syntax:
return ldap_attrs[attribute_name].syntax
return attribute_ldap_syntax(ldap_attrs[attribute_name].sup[0])
class LDAPObjectMetaclass(type): class LDAPObjectMetaclass(type):
@ -256,7 +229,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
@classmethod @classmethod
def get(cls, identifier=None, /, **kwargs): def get(cls, identifier=None, /, **kwargs):
try: try:
return cls.query(identifier, **kwargs)[0] return BaseBackend.get().query(cls, identifier, **kwargs)[0]
except (IndexError, ldap.NO_SUCH_OBJECT): except (IndexError, ldap.NO_SUCH_OBJECT):
if identifier and cls.base: if identifier and cls.base:
return ( return (
@ -267,58 +240,6 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
return None return None
@classmethod
def query(cls, dn=None, filter=None, **kwargs):
conn = Backend.get().connection
base = dn
if dn is None:
base = f"{cls.base},{cls.root_dn}"
elif "=" not in base:
base = ldap.dn.escape_dn_chars(base)
base = f"{cls.rdn_attribute}={base},{cls.base},{cls.root_dn}"
class_filter = (
"".join([f"(objectClass={oc})" for oc in cls.ldap_object_class])
if getattr(cls, "ldap_object_class")
else ""
)
if class_filter:
class_filter = f"(|{class_filter})"
arg_filter = ""
ldap_args = python_attrs_to_ldap(
{
cls.python_attribute_to_ldap(name): values
for name, values in kwargs.items()
if values is not None
},
encode=False,
)
for key, value in ldap_args.items():
if len(value) == 1:
escaped_value = ldap.filter.escape_filter_chars(value[0])
arg_filter += f"({key}={escaped_value})"
else:
values = [ldap.filter.escape_filter_chars(v) for v in value]
arg_filter += (
"(|" + "".join([f"({key}={value})" for value in values]) + ")"
)
if not filter:
filter = ""
ldapfilter = f"(&{class_filter}{arg_filter}{filter})"
base = base or f"{cls.base},{cls.root_dn}"
try:
result = conn.search_s(
base, ldap.SCOPE_SUBTREE, ldapfilter or None, ["+", "*"]
)
except ldap.NO_SUCH_OBJECT:
result = []
return LDAPObjectQuery(cls, result)
@classmethod @classmethod
def fuzzy(cls, query, attributes=None, **kwargs): def fuzzy(cls, query, attributes=None, **kwargs):
query = ldap.filter.escape_filter_chars(query) query = ldap.filter.escape_filter_chars(query)
@ -327,7 +248,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
filter = ( filter = (
"(|" + "".join(f"({attribute}=*{query}*)" for attribute in attributes) + ")" "(|" + "".join(f"({attribute}=*{query}*)" for attribute in attributes) + ")"
) )
return cls.query(filter=filter, **kwargs) return BaseBackend.get().query(cls, filter=filter, **kwargs)
@classmethod @classmethod
def update_ldap_attributes(cls): def update_ldap_attributes(cls):

View file

@ -95,3 +95,33 @@ def cardinalize_attribute(python_unique, value):
return value[0] return value[0]
return [v for v in value if v is not None] return [v for v in value if v is not None]
def python_attrs_to_ldap(attrs, encode=True, null_allowed=True):
formatted_attrs = {
name: [
python_to_ldap(value, attribute_ldap_syntax(name), encode=encode)
for value in listify(values)
]
for name, values in attrs.items()
}
if not null_allowed:
formatted_attrs = {
key: [value for value in values if value]
for key, values in formatted_attrs.items()
if values
}
return formatted_attrs
def attribute_ldap_syntax(attribute_name):
from .ldapobject import LDAPObject
ldap_attrs = LDAPObject.ldap_object_attributes()
if attribute_name not in ldap_attrs:
return None
if ldap_attrs[attribute_name].syntax:
return ldap_attrs[attribute_name].syntax
return attribute_ldap_syntax(ldap_attrs[attribute_name].sup[0])

View file

@ -40,3 +40,28 @@ class Backend(BaseBackend):
def set_user_password(self, user, password): def set_user_password(self, user, password):
user.password = password user.password = password
user.save() user.save()
def query(self, model, **kwargs):
# if there is no filter, return all models
if not kwargs:
states = model.index().values()
return [model(**state) for state in states]
# get the ids from the attribute indexes
ids = {
id
for attribute, values in kwargs.items()
for value in model.serialize(model.listify(values))
for id in model.attribute_index(attribute).get(value, [])
}
# get the states from the ids
states = [model.index()[id] for id in ids]
# initialize instances from the states
instances = [model(**state) for state in states]
for instance in instances:
# TODO: maybe find a way to not initialize the cache in the first place?
instance._cache = {}
return instances

View file

@ -6,6 +6,7 @@ import uuid
import canaille.core.models import canaille.core.models
import canaille.oidc.models import canaille.oidc.models
from canaille.app import models from canaille.app import models
from canaille.backends import BaseBackend
from canaille.backends.models import BackendModel from canaille.backends.models import BackendModel
@ -25,31 +26,6 @@ class MemoryModel(BackendModel):
def __repr__(self): def __repr__(self):
return f"<{self.__class__.__name__} id={self.id}>" return f"<{self.__class__.__name__} id={self.id}>"
@classmethod
def query(cls, **kwargs):
# if there is no filter, return all models
if not kwargs:
states = cls.index().values()
return [cls(**state) for state in states]
# get the ids from the attribute indexes
ids = {
id
for attribute, values in kwargs.items()
for value in cls.serialize(cls.listify(values))
for id in cls.attribute_index(attribute).get(value, [])
}
# get the states from the ids
states = [cls.index()[id] for id in ids]
# initialize instances from the states
instances = [cls(**state) for state in states]
for instance in instances:
# TODO: maybe find a way to not initialize the cache in the first place?
instance._cache = {}
return instances
@classmethod @classmethod
def index(cls, class_name=None): def index(cls, class_name=None):
return MemoryModel.indexes.setdefault(class_name or cls.__name__, {}) return MemoryModel.indexes.setdefault(class_name or cls.__name__, {})
@ -63,7 +39,7 @@ class MemoryModel(BackendModel):
@classmethod @classmethod
def fuzzy(cls, query, attributes=None, **kwargs): def fuzzy(cls, query, attributes=None, **kwargs):
attributes = attributes or cls.attributes attributes = attributes or cls.attributes
instances = cls.query(**kwargs) instances = BaseBackend.get().query(cls, **kwargs)
return [ return [
instance instance
@ -85,7 +61,7 @@ class MemoryModel(BackendModel):
or None or None
) )
results = cls.query(**kwargs) results = BaseBackend.get().query(cls, **kwargs)
return results[0] if results else None return results[0] if results else None
@classmethod @classmethod

View file

@ -87,37 +87,16 @@ class BackendModel:
implemented for every model and for every backend. implemented for every model and for every backend.
""" """
@classmethod
def query(cls, **kwargs):
"""Perform a query on the database and return a collection of
instances.
Parameters can be any valid attribute with the expected value:
>>> User.query(first_name="George")
If several arguments are passed, the methods only returns the model
instances that return matches all the argument values:
>>> User.query(first_name="George", last_name="Abitbol")
If the argument value is a collection, the methods will return the
models that matches any of the values:
>>> User.query(first_name=["George", "Jane"])
"""
raise NotImplementedError()
@classmethod @classmethod
def fuzzy(cls, query, attributes=None, **kwargs): def fuzzy(cls, query, attributes=None, **kwargs):
"""Works like :meth:`~canaille.backends.models.BackendModel.query` but """Works like :meth:`~canaille.backends.BaseBackend.query` but
attribute values loosely be matched.""" attribute values loosely be matched."""
raise NotImplementedError() raise NotImplementedError()
@classmethod @classmethod
def get(cls, identifier=None, **kwargs): def get(cls, identifier=None, **kwargs):
"""Works like :meth:`~canaille.backends.models.BackendModel.query` but """Works like :meth:`~canaille.backends.BaseBackend.query` but return
return only one element or :py:data:`None` if no item is matching.""" only one element or :py:data:`None` if no item is matching."""
raise NotImplementedError() raise NotImplementedError()
def save(self): def save(self):

View file

@ -1,4 +1,5 @@
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy.orm import declarative_base from sqlalchemy.orm import declarative_base
@ -65,3 +66,15 @@ class Backend(BaseBackend):
def set_user_password(self, user, password): def set_user_password(self, user, password):
user.password = password user.password = password
user.save() user.save()
def query(self, model, **kwargs):
filter = [
model.attribute_filter(attribute_name, expected_value)
for attribute_name, expected_value in kwargs.items()
]
return (
Backend.get()
.db_session.execute(select(model).filter(*filter))
.scalars()
.all()
)

View file

@ -36,19 +36,6 @@ class SqlAlchemyModel(BackendModel):
f"<{self.__class__.__name__} {self.identifier_attribute}={self.identifier}>" f"<{self.__class__.__name__} {self.identifier_attribute}={self.identifier}>"
) )
@classmethod
def query(cls, **kwargs):
filter = [
cls.attribute_filter(attribute_name, expected_value)
for attribute_name, expected_value in kwargs.items()
]
return (
Backend.get()
.db_session.execute(select(cls).filter(*filter))
.scalars()
.all()
)
@classmethod @classmethod
def fuzzy(cls, query, attributes=None, **kwargs): def fuzzy(cls, query, attributes=None, **kwargs):
attributes = attributes or cls.attributes attributes = attributes or cls.attributes

View file

@ -86,7 +86,7 @@ def join():
form = JoinForm(request.form or None) form = JoinForm(request.form or None)
if request.form and form.validate(): if request.form and form.validate():
if models.User.query(emails=form.email.data): if BaseBackend.get().query(models.User, emails=form.email.data):
flash( flash(
_( _(
"You will receive soon an email to continue the registration process." "You will receive soon an email to continue the registration process."
@ -295,7 +295,10 @@ def registration(data=None, hash=None):
if "groups" not in form and payload and payload.groups: if "groups" not in form and payload and payload.groups:
form["groups"] = wtforms.SelectMultipleField( form["groups"] = wtforms.SelectMultipleField(
_("Groups"), _("Groups"),
choices=[(group, group.display_name) for group in models.Group.query()], choices=[
(group, group.display_name)
for group in BaseBackend.get().query(models.Group)
],
coerce=IDToModel("Group"), coerce=IDToModel("Group"),
) )
set_readonly(form["groups"]) set_readonly(form["groups"])
@ -388,7 +391,7 @@ def email_confirmation(data, hash):
) )
return redirect(url_for("core.account.index")) return redirect(url_for("core.account.index"))
if models.User.query(emails=confirmation_obj.email): if BaseBackend.get().query(models.User, emails=confirmation_obj.email):
flash( flash(
_("This address email is already associated with another account."), _("This address email is already associated with another account."),
"error", "error",

View file

@ -312,7 +312,10 @@ PROFILE_FORM_FIELDS = dict(
groups=wtforms.SelectMultipleField( groups=wtforms.SelectMultipleField(
_("Groups"), _("Groups"),
default=[], default=[],
choices=lambda: [(group, group.display_name) for group in models.Group.query()], choices=lambda: [
(group, group.display_name)
for group in BaseBackend.get().query(models.Group)
],
render_kw={"placeholder": _("users, admins …")}, render_kw={"placeholder": _("users, admins …")},
coerce=IDToModel("Group"), coerce=IDToModel("Group"),
validators=[non_empty_groups], validators=[non_empty_groups],
@ -333,7 +336,7 @@ def build_profile_form(write_field_names, readonly_field_names, user=None):
if PROFILE_FORM_FIELDS.get(name) if PROFILE_FORM_FIELDS.get(name)
} }
if "groups" in fields and not models.Group.query(): if "groups" in fields and not BaseBackend.get().query(models.Group):
del fields["groups"] del fields["groups"]
if current_app.backend.get().has_account_lockability(): # pragma: no branch if current_app.backend.get().has_account_lockability(): # pragma: no branch
@ -436,7 +439,10 @@ class InvitationForm(Form):
) )
groups = wtforms.SelectMultipleField( groups = wtforms.SelectMultipleField(
_("Groups"), _("Groups"),
choices=lambda: [(group, group.display_name) for group in models.Group.query()], choices=lambda: [
(group, group.display_name)
for group in BaseBackend.get().query(models.Group)
],
render_kw={}, render_kw={},
coerce=IDToModel("Group"), coerce=IDToModel("Group"),
) )

View file

@ -5,6 +5,7 @@ from faker.config import AVAILABLE_LOCALES
from canaille.app import models from canaille.app import models
from canaille.app.i18n import available_language_codes from canaille.app.i18n import available_language_codes
from canaille.backends import BaseBackend
def fake_users(nb=1): def fake_users(nb=1):
@ -47,7 +48,7 @@ def fake_users(nb=1):
def fake_groups(nb=1, nb_users_max=1): def fake_groups(nb=1, nb_users_max=1):
users = models.User.query() users = BaseBackend.get().query(models.User)
groups = list() groups = list()
fake = faker.Faker(["en_US"]) fake = faker.Faker(["en_US"])
for _ in range(nb): for _ in range(nb):

View file

@ -3,6 +3,7 @@ from flask.cli import with_appcontext
from canaille.app import models from canaille.app import models
from canaille.app.commands import with_backendcontext from canaille.app.commands import with_backendcontext
from canaille.backends import BaseBackend
@click.command() @click.command()
@ -10,11 +11,11 @@ from canaille.app.commands import with_backendcontext
@with_backendcontext @with_backendcontext
def clean(): def clean():
"""Remove expired tokens and authorization codes.""" """Remove expired tokens and authorization codes."""
for t in models.Token.query(): for t in BaseBackend.get().query(models.Token):
if t.is_expired(): if t.is_expired():
t.delete() t.delete()
for a in models.AuthorizationCode.query(): for a in BaseBackend.get().query(models.AuthorizationCode):
if a.is_expired(): if a.is_expired():
a.delete() a.delete()

View file

@ -10,6 +10,7 @@ from canaille.app import models
from canaille.app.flask import user_needed from canaille.app.flask import user_needed
from canaille.app.i18n import gettext as _ from canaille.app.i18n import gettext as _
from canaille.app.themes import render_template from canaille.app.themes import render_template
from canaille.backends import BaseBackend
from ..utils import SCOPE_DETAILS from ..utils import SCOPE_DETAILS
@ -19,13 +20,13 @@ bp = Blueprint("consents", __name__, url_prefix="/consent")
@bp.route("/") @bp.route("/")
@user_needed() @user_needed()
def consents(user): def consents(user):
consents = models.Consent.query(subject=user) consents = BaseBackend.get().query(models.Consent, subject=user)
clients = {t.client for t in consents} clients = {t.client for t in consents}
nb_consents = len(consents) nb_consents = len(consents)
nb_preconsents = sum( nb_preconsents = sum(
1 1
for client in models.Client.query() for client in BaseBackend.get().query(models.Client)
if client.preconsent and client not in clients if client.preconsent and client not in clients
) )
@ -43,11 +44,11 @@ def consents(user):
@bp.route("/pre-consents") @bp.route("/pre-consents")
@user_needed() @user_needed()
def pre_consents(user): def pre_consents(user):
consents = models.Consent.query(subject=user) consents = BaseBackend.get().query(models.Consent, subject=user)
clients = {t.client for t in consents} clients = {t.client for t in consents}
preconsented = [ preconsented = [
client client
for client in models.Client.query() for client in BaseBackend.get().query(models.Client)
if client.preconsent and client not in clients if client.preconsent and client not in clients
] ]

View file

@ -7,6 +7,7 @@ from canaille.app.forms import email_validator
from canaille.app.forms import is_uri from canaille.app.forms import is_uri
from canaille.app.forms import unique_values from canaille.app.forms import unique_values
from canaille.app.i18n import lazy_gettext as _ from canaille.app.i18n import lazy_gettext as _
from canaille.backends import BaseBackend
class AuthorizeForm(Form): class AuthorizeForm(Form):
@ -18,7 +19,10 @@ class LogoutForm(Form):
def client_audiences(): def client_audiences():
return [(client, client.client_name) for client in models.Client.query()] return [
(client, client.client_name)
for client in BaseBackend.get().query(models.Client)
]
class ClientAddForm(Form): class ClientAddForm(Form):

View file

@ -23,6 +23,7 @@ from canaille.app.flask import logout_user
from canaille.app.flask import set_parameter_in_url_query from canaille.app.flask import set_parameter_in_url_query
from canaille.app.i18n import gettext as _ from canaille.app.i18n import gettext as _
from canaille.app.themes import render_template from canaille.app.themes import render_template
from canaille.backends import BaseBackend
from ..oauth import ClientConfigurationEndpoint from ..oauth import ClientConfigurationEndpoint
from ..oauth import ClientRegistrationEndpoint from ..oauth import ClientRegistrationEndpoint
@ -109,7 +110,8 @@ def authorize_login(user):
def authorize_consent(client, user): def authorize_consent(client, user):
requested_scopes = request.args.get("scope", "").split(" ") requested_scopes = request.args.get("scope", "").split(" ")
allowed_scopes = client.get_allowed_scope(requested_scopes).split(" ") allowed_scopes = client.get_allowed_scope(requested_scopes).split(" ")
consents = models.Consent.query( consents = BaseBackend.get().query(
models.Consent,
client=client, client=client,
subject=user, subject=user,
) )

View file

@ -8,6 +8,7 @@ from authlib.oauth2.rfc6749 import TokenMixin
from authlib.oauth2.rfc6749 import util from authlib.oauth2.rfc6749 import util
from canaille.app import models from canaille.app import models
from canaille.backends import BaseBackend
from .basemodels import AuthorizationCode as BaseAuthorizationCode from .basemodels import AuthorizationCode as BaseAuthorizationCode
from .basemodels import Client as BaseClient from .basemodels import Client as BaseClient
@ -95,13 +96,13 @@ class Client(BaseClient, ClientMixin):
return metadata return metadata
def delete(self): def delete(self):
for consent in models.Consent.query(client=self): for consent in BaseBackend.get().query(models.Consent, client=self):
consent.delete() consent.delete()
for code in models.AuthorizationCode.query(client=self): for code in BaseBackend.get().query(models.AuthorizationCode, client=self):
code.delete() code.delete()
for token in models.Token.query(client=self): for token in BaseBackend.get().query(models.Token, client=self):
token.delete() token.delete()
super().delete() super().delete()
@ -185,7 +186,8 @@ class Consent(BaseConsent):
self.revokation_date = datetime.datetime.now(datetime.timezone.utc) self.revokation_date = datetime.datetime.now(datetime.timezone.utc)
self.save() self.save()
tokens = models.Token.query( tokens = BaseBackend.get().query(
models.Token,
client=self.client, client=self.client,
subject=self.subject, subject=self.subject,
) )

View file

@ -112,7 +112,9 @@ def openid_configuration():
def exists_nonce(nonce, req): def exists_nonce(nonce, req):
client = models.Client.get(id=req.client_id) client = models.Client.get(id=req.client_id)
exists = models.AuthorizationCode.query(client=client, nonce=nonce) exists = BaseBackend.get().query(
models.AuthorizationCode, client=client, nonce=nonce
)
return bool(exists) return bool(exists)
@ -237,7 +239,9 @@ class AuthorizationCodeGrant(_AuthorizationCodeGrant):
return save_authorization_code(code, request) return save_authorization_code(code, request)
def query_authorization_code(self, code, client): def query_authorization_code(self, code, client):
item = models.AuthorizationCode.query(code=code, client=client) item = BaseBackend.get().query(
models.AuthorizationCode, code=code, client=client
)
if item and not item[0].is_expired(): if item and not item[0].is_expired():
return item[0] return item[0]
@ -283,7 +287,7 @@ class RefreshTokenGrant(_RefreshTokenGrant):
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"] TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]
def authenticate_refresh_token(self, refresh_token): def authenticate_refresh_token(self, refresh_token):
token = models.Token.query(refresh_token=refresh_token) token = BaseBackend.get().query(models.Token, refresh_token=refresh_token)
if token and token[0].is_refresh_token_active(): if token and token[0].is_refresh_token_active():
return token[0] return token[0]

View file

@ -81,15 +81,15 @@ def test_special_chars_in_rdn(testclient, backend):
def test_filter(backend, foo_group, bar_group): def test_filter(backend, foo_group, bar_group):
assert models.Group.query(display_name="foo") == [foo_group] assert backend.query(models.Group, display_name="foo") == [foo_group]
assert models.Group.query(display_name="bar") == [bar_group] assert backend.query(models.Group, display_name="bar") == [bar_group]
assert models.Group.query(display_name="foo") != 3 assert backend.query(models.Group, display_name="foo") != 3
assert models.Group.query(display_name=["foo"]) == [foo_group] assert backend.query(models.Group, display_name=["foo"]) == [foo_group]
assert models.Group.query(display_name=["bar"]) == [bar_group] assert backend.query(models.Group, display_name=["bar"]) == [bar_group]
assert set(models.Group.query(display_name=["foo", "bar"])) == { assert set(backend.query(models.Group, display_name=["foo", "bar"])) == {
foo_group, foo_group,
bar_group, bar_group,
} }

View file

@ -36,16 +36,16 @@ def test_model_lifecycle(testclient, backend):
) )
assert not user.id assert not user.id
assert not models.User.query() assert not backend.query(models.User)
assert not models.User.query(id=user.id) assert not backend.query(models.User, id=user.id)
assert not models.User.query(id="invalid") assert not backend.query(models.User, id="invalid")
assert not models.User.get(id=user.id) assert not models.User.get(id=user.id)
user.save() user.save()
assert models.User.query() == [user] assert backend.query(models.User) == [user]
assert models.User.query(id=user.id) == [user] assert backend.query(models.User, id=user.id) == [user]
assert not models.User.query(id="invalid") assert not backend.query(models.User, id="invalid")
assert models.User.get(id=user.id) == user assert models.User.get(id=user.id) == user
user.family_name = "new_family_name" user.family_name = "new_family_name"
@ -58,7 +58,7 @@ def test_model_lifecycle(testclient, backend):
user.delete() user.delete()
assert not models.User.query(id=user.id) assert not backend.query(models.User, id=user.id)
assert not models.User.get(id=user.id) assert not models.User.get(id=user.id)
user.delete() user.delete()
@ -143,7 +143,7 @@ def test_model_indexation(testclient, backend):
def test_fuzzy_unique_attribute(user, moderator, admin, backend): def test_fuzzy_unique_attribute(user, moderator, admin, backend):
assert set(models.User.query()) == {user, moderator, admin} assert set(backend.query(models.User)) == {user, moderator, admin}
assert set(models.User.fuzzy("Jack")) == {moderator} assert set(models.User.fuzzy("Jack")) == {moderator}
assert set(models.User.fuzzy("Jack", ["formatted_name"])) == {moderator} assert set(models.User.fuzzy("Jack", ["formatted_name"])) == {moderator}
assert set(models.User.fuzzy("Jack", ["user_name"])) == set() assert set(models.User.fuzzy("Jack", ["user_name"])) == set()
@ -157,7 +157,7 @@ def test_fuzzy_unique_attribute(user, moderator, admin, backend):
def test_fuzzy_multiple_attribute(user, moderator, admin, backend): def test_fuzzy_multiple_attribute(user, moderator, admin, backend):
assert set(models.User.query()) == {user, moderator, admin} assert set(backend.query(models.User)) == {user, moderator, admin}
assert set(models.User.fuzzy("jack@doe.com")) == {moderator} assert set(models.User.fuzzy("jack@doe.com")) == {moderator}
assert set(models.User.fuzzy("jack@doe.com", ["emails"])) == {moderator} assert set(models.User.fuzzy("jack@doe.com", ["emails"])) == {moderator}
assert set(models.User.fuzzy("jack@doe.com", ["formatted_name"])) == set() assert set(models.User.fuzzy("jack@doe.com", ["formatted_name"])) == set()
@ -171,8 +171,8 @@ def test_fuzzy_multiple_attribute(user, moderator, admin, backend):
def test_model_references(testclient, user, foo_group, admin, bar_group, backend): def test_model_references(testclient, user, foo_group, admin, bar_group, backend):
assert foo_group.members == [user] assert foo_group.members == [user]
assert user.groups == [foo_group] assert user.groups == [foo_group]
assert foo_group in models.Group.query(members=user) assert foo_group in backend.query(models.Group, members=user)
assert user in models.User.query(groups=foo_group) assert user in backend.query(models.User, groups=foo_group)
assert user not in bar_group.members assert user not in bar_group.members
assert bar_group not in user.groups assert bar_group not in user.groups

View file

@ -6,11 +6,11 @@ from canaille.core.populate import fake_users
def test_populate_users(testclient, backend): def test_populate_users(testclient, backend):
runner = testclient.app.test_cli_runner() runner = testclient.app.test_cli_runner()
assert len(models.User.query()) == 0 assert len(backend.query(models.User)) == 0
res = runner.invoke(cli, ["populate", "--nb", "10", "users"]) res = runner.invoke(cli, ["populate", "--nb", "10", "users"])
assert res.exit_code == 0, res.stdout assert res.exit_code == 0, res.stdout
assert len(models.User.query()) == 10 assert len(backend.query(models.User)) == 10
for user in models.User.query(): for user in backend.query(models.User):
user.delete() user.delete()
@ -18,13 +18,13 @@ def test_populate_groups(testclient, backend):
fake_users(10) fake_users(10)
runner = testclient.app.test_cli_runner() runner = testclient.app.test_cli_runner()
assert len(models.Group.query()) == 0 assert len(backend.query(models.Group)) == 0
res = runner.invoke(cli, ["populate", "--nb", "10", "groups"]) res = runner.invoke(cli, ["populate", "--nb", "10", "groups"])
assert res.exit_code == 0, res.stdout assert res.exit_code == 0, res.stdout
assert len(models.Group.query()) == 10 assert len(backend.query(models.Group)) == 10
for group in models.Group.query(): for group in backend.query(models.Group):
group.delete() group.delete()
for user in models.User.query(): for user in backend.query(models.User):
user.delete() user.delete()

View file

@ -4,7 +4,7 @@ from canaille.core.populate import fake_users
def test_no_group(app, backend): def test_no_group(app, backend):
assert models.Group.query() == [] assert backend.query(models.Group) == []
def test_group_list_pagination(testclient, logged_admin, foo_group): def test_group_list_pagination(testclient, logged_admin, foo_group):

View file

@ -12,7 +12,7 @@ def test_registration_without_email_validation(testclient, backend, foo_group):
testclient.app.config["CANAILLE"]["ENABLE_REGISTRATION"] = True testclient.app.config["CANAILLE"]["ENABLE_REGISTRATION"] = True
testclient.app.config["CANAILLE"]["EMAIL_CONFIRMATION"] = False testclient.app.config["CANAILLE"]["EMAIL_CONFIRMATION"] = False
assert not models.User.query(user_name="newuser") assert not backend.query(models.User, user_name="newuser")
res = testclient.get(url_for("core.account.registration"), status=200) res = testclient.get(url_for("core.account.registration"), status=200)
res.form["user_name"] = "newuser" res.form["user_name"] = "newuser"
res.form["password1"] = "password" res.form["password1"] = "password"
@ -60,7 +60,7 @@ def test_registration_with_email_validation(testclient, backend, smtpd, foo_grou
text_mail = smtpd.messages[0].get_payload()[0].get_payload(decode=True).decode() text_mail = smtpd.messages[0].get_payload()[0].get_payload(decode=True).decode()
assert registration_url in text_mail assert registration_url in text_mail
assert not models.User.query(user_name="newuser") assert not backend.query(models.User, user_name="newuser")
with time_machine.travel("2020-01-01 02:01:00+00:00", tick=False): with time_machine.travel("2020-01-01 02:01:00+00:00", tick=False):
res = testclient.get(registration_url, status=200) res = testclient.get(registration_url, status=200)
res.form["user_name"] = "newuser" res.form["user_name"] = "newuser"

View file

@ -13,8 +13,10 @@ from canaille.app import models
from . import client_credentials from . import client_credentials
def test_nominal_case(testclient, logged_user, client, keypair, trusted_client): def test_nominal_case(
assert not models.Consent.query() testclient, logged_user, client, keypair, trusted_client, backend
):
assert not backend.query(models.Consent)
res = testclient.get( res = testclient.get(
"/oauth/authorize", "/oauth/authorize",
@ -43,7 +45,7 @@ def test_nominal_case(testclient, logged_user, client, keypair, trusted_client):
"phone", "phone",
} }
consents = models.Consent.query(client=client, subject=logged_user) consents = backend.query(models.Consent, client=client, subject=logged_user)
assert set(consents[0].scope) == { assert set(consents[0].scope) == {
"openid", "openid",
"profile", "profile",
@ -112,8 +114,10 @@ def test_invalid_client(testclient, logged_user, keypair):
) )
def test_redirect_uri(testclient, logged_user, client, keypair, trusted_client): def test_redirect_uri(
assert not models.Consent.query() testclient, logged_user, client, keypair, trusted_client, backend
):
assert not backend.query(models.Consent)
res = testclient.get( res = testclient.get(
"/oauth/authorize", "/oauth/authorize",
@ -134,7 +138,7 @@ def test_redirect_uri(testclient, logged_user, client, keypair, trusted_client):
code = params["code"][0] code = params["code"][0]
authcode = models.AuthorizationCode.get(code=code) authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None assert authcode is not None
consents = models.Consent.query(client=client, subject=logged_user) consents = backend.query(models.Consent, client=client, subject=logged_user)
res = testclient.post( res = testclient.post(
"/oauth/token", "/oauth/token",
@ -157,8 +161,10 @@ def test_redirect_uri(testclient, logged_user, client, keypair, trusted_client):
consent.delete() consent.delete()
def test_preconsented_client(testclient, logged_user, client, keypair, trusted_client): def test_preconsented_client(
assert not models.Consent.query() testclient, logged_user, client, keypair, trusted_client, backend
):
assert not backend.query(models.Consent)
client.preconsent = True client.preconsent = True
client.save() client.save()
@ -180,7 +186,7 @@ def test_preconsented_client(testclient, logged_user, client, keypair, trusted_c
authcode = models.AuthorizationCode.get(code=code) authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None assert authcode is not None
consents = models.Consent.query(client=client, subject=logged_user) consents = backend.query(models.Consent, client=client, subject=logged_user)
assert not consents assert not consents
res = testclient.post( res = testclient.post(
@ -214,8 +220,8 @@ def test_preconsented_client(testclient, logged_user, client, keypair, trusted_c
assert res.json["name"] == "John (johnny) Doe" assert res.json["name"] == "John (johnny) Doe"
def test_logout_login(testclient, logged_user, client): def test_logout_login(testclient, logged_user, client, backend):
assert not models.Consent.query() assert not backend.query(models.Consent)
res = testclient.get( res = testclient.get(
"/oauth/authorize", "/oauth/authorize",
@ -254,7 +260,7 @@ def test_logout_login(testclient, logged_user, client):
authcode = models.AuthorizationCode.get(code=code) authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None assert authcode is not None
consents = models.Consent.query(client=client, subject=logged_user) consents = backend.query(models.Consent, client=client, subject=logged_user)
assert "profile" in consents[0].scope assert "profile" in consents[0].scope
res = testclient.post( res = testclient.post(
@ -285,8 +291,8 @@ def test_logout_login(testclient, logged_user, client):
consent.delete() consent.delete()
def test_deny(testclient, logged_user, client): def test_deny(testclient, logged_user, client, backend):
assert not models.Consent.query() assert not backend.query(models.Consent)
res = testclient.get( res = testclient.get(
"/oauth/authorize", "/oauth/authorize",
@ -305,11 +311,11 @@ def test_deny(testclient, logged_user, client):
error = params["error"][0] error = params["error"][0]
assert error == "access_denied" assert error == "access_denied"
assert not models.Consent.query() assert not backend.query(models.Consent)
def test_code_challenge(testclient, logged_user, client): def test_code_challenge(testclient, logged_user, client, backend):
assert not models.Consent.query() assert not backend.query(models.Consent)
client.token_endpoint_auth_method = "none" client.token_endpoint_auth_method = "none"
client.save() client.save()
@ -338,7 +344,7 @@ def test_code_challenge(testclient, logged_user, client):
authcode = models.AuthorizationCode.get(code=code) authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None assert authcode is not None
consents = models.Consent.query(client=client, subject=logged_user) consents = backend.query(models.Consent, client=client, subject=logged_user)
assert "profile" in consents[0].scope assert "profile" in consents[0].scope
res = testclient.post( res = testclient.post(
@ -373,8 +379,8 @@ def test_code_challenge(testclient, logged_user, client):
consent.delete() consent.delete()
def test_consent_already_given(testclient, logged_user, client): def test_consent_already_given(testclient, logged_user, client, backend):
assert not models.Consent.query() assert not backend.query(models.Consent)
res = testclient.get( res = testclient.get(
"/oauth/authorize", "/oauth/authorize",
@ -395,7 +401,7 @@ def test_consent_already_given(testclient, logged_user, client):
authcode = models.AuthorizationCode.get(code=code) authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None assert authcode is not None
consents = models.Consent.query(client=client, subject=logged_user) consents = backend.query(models.Consent, client=client, subject=logged_user)
assert "profile" in consents[0].scope assert "profile" in consents[0].scope
res = testclient.post( res = testclient.post(
@ -430,9 +436,9 @@ def test_consent_already_given(testclient, logged_user, client):
def test_when_consent_already_given_but_for_a_smaller_scope( def test_when_consent_already_given_but_for_a_smaller_scope(
testclient, logged_user, client testclient, logged_user, client, backend
): ):
assert not models.Consent.query() assert not backend.query(models.Consent)
res = testclient.get( res = testclient.get(
"/oauth/authorize", "/oauth/authorize",
@ -453,7 +459,7 @@ def test_when_consent_already_given_but_for_a_smaller_scope(
authcode = models.AuthorizationCode.get(code=code) authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None assert authcode is not None
consents = models.Consent.query(client=client, subject=logged_user) consents = backend.query(models.Consent, client=client, subject=logged_user)
assert "profile" in consents[0].scope assert "profile" in consents[0].scope
assert "groups" not in consents[0].scope assert "groups" not in consents[0].scope
@ -489,7 +495,7 @@ def test_when_consent_already_given_but_for_a_smaller_scope(
authcode = models.AuthorizationCode.get(code=code) authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None assert authcode is not None
consents = models.Consent.query(client=client, subject=logged_user) consents = backend.query(models.Consent, client=client, subject=logged_user)
assert "profile" in consents[0].scope assert "profile" in consents[0].scope
assert "groups" in consents[0].scope assert "groups" in consents[0].scope
@ -535,8 +541,8 @@ def test_nonce_required_in_oidc_requests(testclient, logged_user, client):
assert res.json.get("error") == "invalid_request" assert res.json.get("error") == "invalid_request"
def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client): def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client, backend):
assert not models.Consent.query() assert not backend.query(models.Consent)
testclient.app.config["CANAILLE_OIDC"]["REQUIRE_NONCE"] = False testclient.app.config["CANAILLE_OIDC"]["REQUIRE_NONCE"] = False
res = testclient.get( res = testclient.get(
@ -552,12 +558,12 @@ def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client):
res = res.form.submit(name="answer", value="accept", status=302) res = res.form.submit(name="answer", value="accept", status=302)
assert res.location.startswith(client.redirect_uris[0]) assert res.location.startswith(client.redirect_uris[0])
for consent in models.Consent.query(): for consent in backend.query(models.Consent):
consent.delete() consent.delete()
def test_request_scope_too_large(testclient, logged_user, keypair, client): def test_request_scope_too_large(testclient, logged_user, keypair, client, backend):
assert not models.Consent.query() assert not backend.query(models.Consent)
client.scope = ["openid", "profile", "groups"] client.scope = ["openid", "profile", "groups"]
client.save() client.save()
@ -582,7 +588,7 @@ def test_request_scope_too_large(testclient, logged_user, keypair, client):
"profile", "profile",
} }
consents = models.Consent.query(client=client, subject=logged_user) consents = backend.query(models.Consent, client=client, subject=logged_user)
assert set(consents[0].scope) == { assert set(consents[0].scope) == {
"openid", "openid",
"profile", "profile",

View file

@ -95,7 +95,7 @@ def test_someone_else_consent_restoration(
def test_oidc_authorization_after_revokation( def test_oidc_authorization_after_revokation(
testclient, logged_user, client, keypair, consent testclient, logged_user, client, keypair, consent, backend
): ):
consent.revoke() consent.revoke()
@ -114,7 +114,7 @@ def test_oidc_authorization_after_revokation(
res = res.form.submit(name="answer", value="accept", status=302) res = res.form.submit(name="answer", value="accept", status=302)
consents = models.Consent.query(client=client, subject=logged_user) consents = backend.query(models.Consent, client=client, subject=logged_user)
consent.reload() consent.reload()
assert consents[0] == consent assert consents[0] == consent
assert not consent.revoked assert not consent.revoked

View file

@ -5,8 +5,8 @@ from canaille.app import models
# forms. # forms.
def test_fieldlist_add(testclient, logged_admin): def test_fieldlist_add(testclient, logged_admin, backend):
assert not models.Client.query() assert not backend.query(models.Client)
res = testclient.get("/admin/client/add") res = testclient.get("/admin/client/add")
assert "redirect_uris-1" not in res.form.fields assert "redirect_uris-1" not in res.form.fields
@ -23,7 +23,7 @@ def test_fieldlist_add(testclient, logged_admin):
res.form[k].force_value(v) res.form[k].force_value(v)
res = res.form.submit(status=200, name="fieldlist_add", value="redirect_uris-0") res = res.form.submit(status=200, name="fieldlist_add", value="redirect_uris-0")
assert not models.Client.query() assert not backend.query(models.Client)
data["redirect_uris-1"] = "https://foo.bar/callback2" data["redirect_uris-1"] = "https://foo.bar/callback2"
for k, v in data.items(): for k, v in data.items():
@ -43,8 +43,8 @@ def test_fieldlist_add(testclient, logged_admin):
client.delete() client.delete()
def test_fieldlist_delete(testclient, logged_admin): def test_fieldlist_delete(testclient, logged_admin, backend):
assert not models.Client.query() assert not backend.query(models.Client)
res = testclient.get("/admin/client/add") res = testclient.get("/admin/client/add")
data = { data = {
@ -61,7 +61,7 @@ def test_fieldlist_delete(testclient, logged_admin):
res.form["redirect_uris-1"] = "https://foo.bar/callback2" res.form["redirect_uris-1"] = "https://foo.bar/callback2"
res = res.form.submit(status=200, name="fieldlist_remove", value="redirect_uris-1") res = res.form.submit(status=200, name="fieldlist_remove", value="redirect_uris-1")
assert not models.Client.query() assert not backend.query(models.Client)
assert "redirect_uris-1" not in res.form.fields assert "redirect_uris-1" not in res.form.fields
res = res.form.submit(status=302, name="action", value="edit") res = res.form.submit(status=302, name="action", value="edit")
@ -92,8 +92,8 @@ def test_fieldlist_add_invalid_field(testclient, logged_admin):
testclient.post("/admin/client/add", data, status=400) testclient.post("/admin/client/add", data, status=400)
def test_fieldlist_delete_invalid_field(testclient, logged_admin): def test_fieldlist_delete_invalid_field(testclient, logged_admin, backend):
assert not models.Client.query() assert not backend.query(models.Client)
res = testclient.get("/admin/client/add") res = testclient.get("/admin/client/add")
data = { data = {

View file

@ -7,8 +7,8 @@ from canaille.app import models
from . import client_credentials from . import client_credentials
def test_refresh_token(testclient, logged_user, client): def test_refresh_token(testclient, logged_user, client, backend):
assert not models.Consent.query() assert not backend.query(models.Consent)
res = testclient.get( res = testclient.get(
"/oauth/authorize", "/oauth/authorize",
@ -27,7 +27,7 @@ def test_refresh_token(testclient, logged_user, client):
authcode = models.AuthorizationCode.get(code=code) authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None assert authcode is not None
consents = models.Consent.query(client=client, subject=logged_user) consents = backend.query(models.Consent, client=client, subject=logged_user)
assert "profile" in consents[0].scope assert "profile" in consents[0].scope
res = testclient.post( res = testclient.post(

View file

@ -9,7 +9,9 @@ from canaille.oidc.oauth import setup_oauth
from . import client_credentials from . import client_credentials
def test_token_default_expiration_date(testclient, logged_user, client, keypair): def test_token_default_expiration_date(
testclient, logged_user, client, keypair, backend
):
res = testclient.get( res = testclient.get(
"/oauth/authorize", "/oauth/authorize",
params=dict( params=dict(
@ -52,12 +54,14 @@ def test_token_default_expiration_date(testclient, logged_user, client, keypair)
claims = jwt.decode(id_token, keypair[1]) claims = jwt.decode(id_token, keypair[1])
assert claims["exp"] - claims["iat"] == 3600 assert claims["exp"] - claims["iat"] == 3600
consents = models.Consent.query(client=client, subject=logged_user) consents = backend.query(models.Consent, client=client, subject=logged_user)
for consent in consents: for consent in consents:
consent.delete() consent.delete()
def test_token_custom_expiration_date(testclient, logged_user, client, keypair): def test_token_custom_expiration_date(
testclient, logged_user, client, keypair, backend
):
testclient.app.config["OAUTH2_TOKEN_EXPIRES_IN"] = { testclient.app.config["OAUTH2_TOKEN_EXPIRES_IN"] = {
"authorization_code": 1000, "authorization_code": 1000,
"implicit": 2000, "implicit": 2000,
@ -110,6 +114,6 @@ def test_token_custom_expiration_date(testclient, logged_user, client, keypair):
claims = jwt.decode(id_token, keypair[1]) claims = jwt.decode(id_token, keypair[1])
assert claims["exp"] - claims["iat"] == 6000 assert claims["exp"] - claims["iat"] == 6000
consents = models.Consent.query(client=client, subject=logged_user) consents = backend.query(models.Consent, client=client, subject=logged_user)
for consent in consents: for consent in consents:
consent.delete() consent.delete()