forked from Github-Mirrors/canaille
refactor: move BackendModel.query to Backend.query
This commit is contained in:
parent
93fa708b1c
commit
8425b2a3b8
29 changed files with 284 additions and 247 deletions
|
@ -15,6 +15,7 @@ from canaille.app.i18n import DEFAULT_LANGUAGE_CODE
|
|||
from canaille.app.i18n import gettext as _
|
||||
from canaille.app.i18n import locale_selector
|
||||
from canaille.app.i18n import timezone_selector
|
||||
from canaille.backends import BaseBackend
|
||||
|
||||
from . import validate_uri
|
||||
from .flask import request_is_htmx
|
||||
|
@ -188,7 +189,7 @@ class TableForm(I18NFormMixin, FlaskForm):
|
|||
if self.query.data:
|
||||
self.items = cls.fuzzy(self.query.data, fields, **filter)
|
||||
else:
|
||||
self.items = cls.query(**filter)
|
||||
self.items = BaseBackend.get().query(cls, **filter)
|
||||
|
||||
self.page_size = page_size
|
||||
self.nb_items = len(self.items)
|
||||
|
|
|
@ -54,6 +54,25 @@ class BaseBackend:
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def query(self, model, **kwargs):
|
||||
"""
|
||||
Perform a query on the database and return a collection of instances.
|
||||
Parameters can be any valid attribute with the expected value:
|
||||
|
||||
>>> backend.query(User, first_name="George")
|
||||
|
||||
If several arguments are passed, the methods only returns the model
|
||||
instances that return matches all the argument values:
|
||||
|
||||
>>> backend.query(User, first_name="George", last_name="Abitbol")
|
||||
|
||||
If the argument value is a collection, the methods will return the
|
||||
models that matches any of the values:
|
||||
|
||||
>>> backend.query(User, first_name=["George", "Jane"])
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def check_user_password(self, user, password: str) -> bool:
|
||||
"""Check if the password matches the user password in the database."""
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -16,6 +16,7 @@ from canaille.app.i18n import gettext as _
|
|||
from canaille.backends import BaseBackend
|
||||
|
||||
from .utils import listify
|
||||
from .utils import python_attrs_to_ldap
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
@ -243,13 +244,64 @@ class Backend(BaseBackend):
|
|||
return result, message
|
||||
|
||||
def set_user_password(self, user, password):
|
||||
conn = Backend.get().connection
|
||||
conn = self.connection
|
||||
conn.passwd_s(
|
||||
user.dn,
|
||||
None,
|
||||
password.encode("utf-8"),
|
||||
)
|
||||
|
||||
def query(self, model, dn=None, filter=None, **kwargs):
|
||||
from .ldapobjectquery import LDAPObjectQuery
|
||||
|
||||
base = dn
|
||||
if dn is None:
|
||||
base = f"{model.base},{model.root_dn}"
|
||||
elif "=" not in base:
|
||||
base = ldap.dn.escape_dn_chars(base)
|
||||
base = f"{model.rdn_attribute}={base},{model.base},{model.root_dn}"
|
||||
|
||||
class_filter = (
|
||||
"".join([f"(objectClass={oc})" for oc in model.ldap_object_class])
|
||||
if getattr(model, "ldap_object_class")
|
||||
else ""
|
||||
)
|
||||
if class_filter:
|
||||
class_filter = f"(|{class_filter})"
|
||||
|
||||
arg_filter = ""
|
||||
ldap_args = python_attrs_to_ldap(
|
||||
{
|
||||
model.python_attribute_to_ldap(name): values
|
||||
for name, values in kwargs.items()
|
||||
if values is not None
|
||||
},
|
||||
encode=False,
|
||||
)
|
||||
for key, value in ldap_args.items():
|
||||
if len(value) == 1:
|
||||
escaped_value = ldap.filter.escape_filter_chars(value[0])
|
||||
arg_filter += f"({key}={escaped_value})"
|
||||
|
||||
else:
|
||||
values = [ldap.filter.escape_filter_chars(v) for v in value]
|
||||
arg_filter += (
|
||||
"(|" + "".join([f"({key}={value})" for value in values]) + ")"
|
||||
)
|
||||
|
||||
if not filter:
|
||||
filter = ""
|
||||
|
||||
ldapfilter = f"(&{class_filter}{arg_filter}{filter})"
|
||||
base = base or f"{model.base},{model.root_dn}"
|
||||
try:
|
||||
result = self.connection.search_s(
|
||||
base, ldap.SCOPE_SUBTREE, ldapfilter or None, ["+", "*"]
|
||||
)
|
||||
except ldap.NO_SUCH_OBJECT:
|
||||
result = []
|
||||
return LDAPObjectQuery(model, result)
|
||||
|
||||
|
||||
def setup_ldap_models(config):
|
||||
from canaille.app import models
|
||||
|
|
|
@ -4,42 +4,15 @@ import ldap.dn
|
|||
import ldap.filter
|
||||
from ldap.controls.readentry import PostReadControl
|
||||
|
||||
from canaille.backends import BaseBackend
|
||||
from canaille.backends.models import BackendModel
|
||||
|
||||
from .backend import Backend
|
||||
from .ldapobjectquery import LDAPObjectQuery
|
||||
from .utils import attribute_ldap_syntax
|
||||
from .utils import cardinalize_attribute
|
||||
from .utils import ldap_to_python
|
||||
from .utils import listify
|
||||
from .utils import python_to_ldap
|
||||
|
||||
|
||||
def python_attrs_to_ldap(attrs, encode=True, null_allowed=True):
|
||||
formatted_attrs = {
|
||||
name: [
|
||||
python_to_ldap(value, attribute_ldap_syntax(name), encode=encode)
|
||||
for value in listify(values)
|
||||
]
|
||||
for name, values in attrs.items()
|
||||
}
|
||||
if not null_allowed:
|
||||
formatted_attrs = {
|
||||
key: [value for value in values if value]
|
||||
for key, values in formatted_attrs.items()
|
||||
if values
|
||||
}
|
||||
return formatted_attrs
|
||||
|
||||
|
||||
def attribute_ldap_syntax(attribute_name):
|
||||
ldap_attrs = LDAPObject.ldap_object_attributes()
|
||||
if attribute_name not in ldap_attrs:
|
||||
return None
|
||||
|
||||
if ldap_attrs[attribute_name].syntax:
|
||||
return ldap_attrs[attribute_name].syntax
|
||||
|
||||
return attribute_ldap_syntax(ldap_attrs[attribute_name].sup[0])
|
||||
from .utils import python_attrs_to_ldap
|
||||
|
||||
|
||||
class LDAPObjectMetaclass(type):
|
||||
|
@ -256,7 +229,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
|
|||
@classmethod
|
||||
def get(cls, identifier=None, /, **kwargs):
|
||||
try:
|
||||
return cls.query(identifier, **kwargs)[0]
|
||||
return BaseBackend.get().query(cls, identifier, **kwargs)[0]
|
||||
except (IndexError, ldap.NO_SUCH_OBJECT):
|
||||
if identifier and cls.base:
|
||||
return (
|
||||
|
@ -267,58 +240,6 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
|
|||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def query(cls, dn=None, filter=None, **kwargs):
|
||||
conn = Backend.get().connection
|
||||
|
||||
base = dn
|
||||
if dn is None:
|
||||
base = f"{cls.base},{cls.root_dn}"
|
||||
elif "=" not in base:
|
||||
base = ldap.dn.escape_dn_chars(base)
|
||||
base = f"{cls.rdn_attribute}={base},{cls.base},{cls.root_dn}"
|
||||
|
||||
class_filter = (
|
||||
"".join([f"(objectClass={oc})" for oc in cls.ldap_object_class])
|
||||
if getattr(cls, "ldap_object_class")
|
||||
else ""
|
||||
)
|
||||
if class_filter:
|
||||
class_filter = f"(|{class_filter})"
|
||||
|
||||
arg_filter = ""
|
||||
ldap_args = python_attrs_to_ldap(
|
||||
{
|
||||
cls.python_attribute_to_ldap(name): values
|
||||
for name, values in kwargs.items()
|
||||
if values is not None
|
||||
},
|
||||
encode=False,
|
||||
)
|
||||
for key, value in ldap_args.items():
|
||||
if len(value) == 1:
|
||||
escaped_value = ldap.filter.escape_filter_chars(value[0])
|
||||
arg_filter += f"({key}={escaped_value})"
|
||||
|
||||
else:
|
||||
values = [ldap.filter.escape_filter_chars(v) for v in value]
|
||||
arg_filter += (
|
||||
"(|" + "".join([f"({key}={value})" for value in values]) + ")"
|
||||
)
|
||||
|
||||
if not filter:
|
||||
filter = ""
|
||||
|
||||
ldapfilter = f"(&{class_filter}{arg_filter}{filter})"
|
||||
base = base or f"{cls.base},{cls.root_dn}"
|
||||
try:
|
||||
result = conn.search_s(
|
||||
base, ldap.SCOPE_SUBTREE, ldapfilter or None, ["+", "*"]
|
||||
)
|
||||
except ldap.NO_SUCH_OBJECT:
|
||||
result = []
|
||||
return LDAPObjectQuery(cls, result)
|
||||
|
||||
@classmethod
|
||||
def fuzzy(cls, query, attributes=None, **kwargs):
|
||||
query = ldap.filter.escape_filter_chars(query)
|
||||
|
@ -327,7 +248,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
|
|||
filter = (
|
||||
"(|" + "".join(f"({attribute}=*{query}*)" for attribute in attributes) + ")"
|
||||
)
|
||||
return cls.query(filter=filter, **kwargs)
|
||||
return BaseBackend.get().query(cls, filter=filter, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def update_ldap_attributes(cls):
|
||||
|
|
|
@ -95,3 +95,33 @@ def cardinalize_attribute(python_unique, value):
|
|||
return value[0]
|
||||
|
||||
return [v for v in value if v is not None]
|
||||
|
||||
|
||||
def python_attrs_to_ldap(attrs, encode=True, null_allowed=True):
|
||||
formatted_attrs = {
|
||||
name: [
|
||||
python_to_ldap(value, attribute_ldap_syntax(name), encode=encode)
|
||||
for value in listify(values)
|
||||
]
|
||||
for name, values in attrs.items()
|
||||
}
|
||||
if not null_allowed:
|
||||
formatted_attrs = {
|
||||
key: [value for value in values if value]
|
||||
for key, values in formatted_attrs.items()
|
||||
if values
|
||||
}
|
||||
return formatted_attrs
|
||||
|
||||
|
||||
def attribute_ldap_syntax(attribute_name):
|
||||
from .ldapobject import LDAPObject
|
||||
|
||||
ldap_attrs = LDAPObject.ldap_object_attributes()
|
||||
if attribute_name not in ldap_attrs:
|
||||
return None
|
||||
|
||||
if ldap_attrs[attribute_name].syntax:
|
||||
return ldap_attrs[attribute_name].syntax
|
||||
|
||||
return attribute_ldap_syntax(ldap_attrs[attribute_name].sup[0])
|
||||
|
|
|
@ -40,3 +40,28 @@ class Backend(BaseBackend):
|
|||
def set_user_password(self, user, password):
|
||||
user.password = password
|
||||
user.save()
|
||||
|
||||
def query(self, model, **kwargs):
|
||||
# if there is no filter, return all models
|
||||
if not kwargs:
|
||||
states = model.index().values()
|
||||
return [model(**state) for state in states]
|
||||
|
||||
# get the ids from the attribute indexes
|
||||
ids = {
|
||||
id
|
||||
for attribute, values in kwargs.items()
|
||||
for value in model.serialize(model.listify(values))
|
||||
for id in model.attribute_index(attribute).get(value, [])
|
||||
}
|
||||
|
||||
# get the states from the ids
|
||||
states = [model.index()[id] for id in ids]
|
||||
|
||||
# initialize instances from the states
|
||||
instances = [model(**state) for state in states]
|
||||
for instance in instances:
|
||||
# TODO: maybe find a way to not initialize the cache in the first place?
|
||||
instance._cache = {}
|
||||
|
||||
return instances
|
||||
|
|
|
@ -6,6 +6,7 @@ import uuid
|
|||
import canaille.core.models
|
||||
import canaille.oidc.models
|
||||
from canaille.app import models
|
||||
from canaille.backends import BaseBackend
|
||||
from canaille.backends.models import BackendModel
|
||||
|
||||
|
||||
|
@ -25,31 +26,6 @@ class MemoryModel(BackendModel):
|
|||
def __repr__(self):
|
||||
return f"<{self.__class__.__name__} id={self.id}>"
|
||||
|
||||
@classmethod
|
||||
def query(cls, **kwargs):
|
||||
# if there is no filter, return all models
|
||||
if not kwargs:
|
||||
states = cls.index().values()
|
||||
return [cls(**state) for state in states]
|
||||
|
||||
# get the ids from the attribute indexes
|
||||
ids = {
|
||||
id
|
||||
for attribute, values in kwargs.items()
|
||||
for value in cls.serialize(cls.listify(values))
|
||||
for id in cls.attribute_index(attribute).get(value, [])
|
||||
}
|
||||
|
||||
# get the states from the ids
|
||||
states = [cls.index()[id] for id in ids]
|
||||
|
||||
# initialize instances from the states
|
||||
instances = [cls(**state) for state in states]
|
||||
for instance in instances:
|
||||
# TODO: maybe find a way to not initialize the cache in the first place?
|
||||
instance._cache = {}
|
||||
return instances
|
||||
|
||||
@classmethod
|
||||
def index(cls, class_name=None):
|
||||
return MemoryModel.indexes.setdefault(class_name or cls.__name__, {})
|
||||
|
@ -63,7 +39,7 @@ class MemoryModel(BackendModel):
|
|||
@classmethod
|
||||
def fuzzy(cls, query, attributes=None, **kwargs):
|
||||
attributes = attributes or cls.attributes
|
||||
instances = cls.query(**kwargs)
|
||||
instances = BaseBackend.get().query(cls, **kwargs)
|
||||
|
||||
return [
|
||||
instance
|
||||
|
@ -85,7 +61,7 @@ class MemoryModel(BackendModel):
|
|||
or None
|
||||
)
|
||||
|
||||
results = cls.query(**kwargs)
|
||||
results = BaseBackend.get().query(cls, **kwargs)
|
||||
return results[0] if results else None
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -87,37 +87,16 @@ class BackendModel:
|
|||
implemented for every model and for every backend.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def query(cls, **kwargs):
|
||||
"""Perform a query on the database and return a collection of
|
||||
instances.
|
||||
|
||||
Parameters can be any valid attribute with the expected value:
|
||||
|
||||
>>> User.query(first_name="George")
|
||||
|
||||
If several arguments are passed, the methods only returns the model
|
||||
instances that return matches all the argument values:
|
||||
|
||||
>>> User.query(first_name="George", last_name="Abitbol")
|
||||
|
||||
If the argument value is a collection, the methods will return the
|
||||
models that matches any of the values:
|
||||
|
||||
>>> User.query(first_name=["George", "Jane"])
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def fuzzy(cls, query, attributes=None, **kwargs):
|
||||
"""Works like :meth:`~canaille.backends.models.BackendModel.query` but
|
||||
"""Works like :meth:`~canaille.backends.BaseBackend.query` but
|
||||
attribute values loosely be matched."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def get(cls, identifier=None, **kwargs):
|
||||
"""Works like :meth:`~canaille.backends.models.BackendModel.query` but
|
||||
return only one element or :py:data:`None` if no item is matching."""
|
||||
"""Works like :meth:`~canaille.backends.BaseBackend.query` but return
|
||||
only one element or :py:data:`None` if no item is matching."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def save(self):
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
|
@ -65,3 +66,15 @@ class Backend(BaseBackend):
|
|||
def set_user_password(self, user, password):
|
||||
user.password = password
|
||||
user.save()
|
||||
|
||||
def query(self, model, **kwargs):
|
||||
filter = [
|
||||
model.attribute_filter(attribute_name, expected_value)
|
||||
for attribute_name, expected_value in kwargs.items()
|
||||
]
|
||||
return (
|
||||
Backend.get()
|
||||
.db_session.execute(select(model).filter(*filter))
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
|
|
@ -36,19 +36,6 @@ class SqlAlchemyModel(BackendModel):
|
|||
f"<{self.__class__.__name__} {self.identifier_attribute}={self.identifier}>"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def query(cls, **kwargs):
|
||||
filter = [
|
||||
cls.attribute_filter(attribute_name, expected_value)
|
||||
for attribute_name, expected_value in kwargs.items()
|
||||
]
|
||||
return (
|
||||
Backend.get()
|
||||
.db_session.execute(select(cls).filter(*filter))
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def fuzzy(cls, query, attributes=None, **kwargs):
|
||||
attributes = attributes or cls.attributes
|
||||
|
|
|
@ -86,7 +86,7 @@ def join():
|
|||
|
||||
form = JoinForm(request.form or None)
|
||||
if request.form and form.validate():
|
||||
if models.User.query(emails=form.email.data):
|
||||
if BaseBackend.get().query(models.User, emails=form.email.data):
|
||||
flash(
|
||||
_(
|
||||
"You will receive soon an email to continue the registration process."
|
||||
|
@ -295,7 +295,10 @@ def registration(data=None, hash=None):
|
|||
if "groups" not in form and payload and payload.groups:
|
||||
form["groups"] = wtforms.SelectMultipleField(
|
||||
_("Groups"),
|
||||
choices=[(group, group.display_name) for group in models.Group.query()],
|
||||
choices=[
|
||||
(group, group.display_name)
|
||||
for group in BaseBackend.get().query(models.Group)
|
||||
],
|
||||
coerce=IDToModel("Group"),
|
||||
)
|
||||
set_readonly(form["groups"])
|
||||
|
@ -388,7 +391,7 @@ def email_confirmation(data, hash):
|
|||
)
|
||||
return redirect(url_for("core.account.index"))
|
||||
|
||||
if models.User.query(emails=confirmation_obj.email):
|
||||
if BaseBackend.get().query(models.User, emails=confirmation_obj.email):
|
||||
flash(
|
||||
_("This address email is already associated with another account."),
|
||||
"error",
|
||||
|
|
|
@ -312,7 +312,10 @@ PROFILE_FORM_FIELDS = dict(
|
|||
groups=wtforms.SelectMultipleField(
|
||||
_("Groups"),
|
||||
default=[],
|
||||
choices=lambda: [(group, group.display_name) for group in models.Group.query()],
|
||||
choices=lambda: [
|
||||
(group, group.display_name)
|
||||
for group in BaseBackend.get().query(models.Group)
|
||||
],
|
||||
render_kw={"placeholder": _("users, admins …")},
|
||||
coerce=IDToModel("Group"),
|
||||
validators=[non_empty_groups],
|
||||
|
@ -333,7 +336,7 @@ def build_profile_form(write_field_names, readonly_field_names, user=None):
|
|||
if PROFILE_FORM_FIELDS.get(name)
|
||||
}
|
||||
|
||||
if "groups" in fields and not models.Group.query():
|
||||
if "groups" in fields and not BaseBackend.get().query(models.Group):
|
||||
del fields["groups"]
|
||||
|
||||
if current_app.backend.get().has_account_lockability(): # pragma: no branch
|
||||
|
@ -436,7 +439,10 @@ class InvitationForm(Form):
|
|||
)
|
||||
groups = wtforms.SelectMultipleField(
|
||||
_("Groups"),
|
||||
choices=lambda: [(group, group.display_name) for group in models.Group.query()],
|
||||
choices=lambda: [
|
||||
(group, group.display_name)
|
||||
for group in BaseBackend.get().query(models.Group)
|
||||
],
|
||||
render_kw={},
|
||||
coerce=IDToModel("Group"),
|
||||
)
|
||||
|
|
|
@ -5,6 +5,7 @@ from faker.config import AVAILABLE_LOCALES
|
|||
|
||||
from canaille.app import models
|
||||
from canaille.app.i18n import available_language_codes
|
||||
from canaille.backends import BaseBackend
|
||||
|
||||
|
||||
def fake_users(nb=1):
|
||||
|
@ -47,7 +48,7 @@ def fake_users(nb=1):
|
|||
|
||||
|
||||
def fake_groups(nb=1, nb_users_max=1):
|
||||
users = models.User.query()
|
||||
users = BaseBackend.get().query(models.User)
|
||||
groups = list()
|
||||
fake = faker.Faker(["en_US"])
|
||||
for _ in range(nb):
|
||||
|
|
|
@ -3,6 +3,7 @@ from flask.cli import with_appcontext
|
|||
|
||||
from canaille.app import models
|
||||
from canaille.app.commands import with_backendcontext
|
||||
from canaille.backends import BaseBackend
|
||||
|
||||
|
||||
@click.command()
|
||||
|
@ -10,11 +11,11 @@ from canaille.app.commands import with_backendcontext
|
|||
@with_backendcontext
|
||||
def clean():
|
||||
"""Remove expired tokens and authorization codes."""
|
||||
for t in models.Token.query():
|
||||
for t in BaseBackend.get().query(models.Token):
|
||||
if t.is_expired():
|
||||
t.delete()
|
||||
|
||||
for a in models.AuthorizationCode.query():
|
||||
for a in BaseBackend.get().query(models.AuthorizationCode):
|
||||
if a.is_expired():
|
||||
a.delete()
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ from canaille.app import models
|
|||
from canaille.app.flask import user_needed
|
||||
from canaille.app.i18n import gettext as _
|
||||
from canaille.app.themes import render_template
|
||||
from canaille.backends import BaseBackend
|
||||
|
||||
from ..utils import SCOPE_DETAILS
|
||||
|
||||
|
@ -19,13 +20,13 @@ bp = Blueprint("consents", __name__, url_prefix="/consent")
|
|||
@bp.route("/")
|
||||
@user_needed()
|
||||
def consents(user):
|
||||
consents = models.Consent.query(subject=user)
|
||||
consents = BaseBackend.get().query(models.Consent, subject=user)
|
||||
clients = {t.client for t in consents}
|
||||
|
||||
nb_consents = len(consents)
|
||||
nb_preconsents = sum(
|
||||
1
|
||||
for client in models.Client.query()
|
||||
for client in BaseBackend.get().query(models.Client)
|
||||
if client.preconsent and client not in clients
|
||||
)
|
||||
|
||||
|
@ -43,11 +44,11 @@ def consents(user):
|
|||
@bp.route("/pre-consents")
|
||||
@user_needed()
|
||||
def pre_consents(user):
|
||||
consents = models.Consent.query(subject=user)
|
||||
consents = BaseBackend.get().query(models.Consent, subject=user)
|
||||
clients = {t.client for t in consents}
|
||||
preconsented = [
|
||||
client
|
||||
for client in models.Client.query()
|
||||
for client in BaseBackend.get().query(models.Client)
|
||||
if client.preconsent and client not in clients
|
||||
]
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ from canaille.app.forms import email_validator
|
|||
from canaille.app.forms import is_uri
|
||||
from canaille.app.forms import unique_values
|
||||
from canaille.app.i18n import lazy_gettext as _
|
||||
from canaille.backends import BaseBackend
|
||||
|
||||
|
||||
class AuthorizeForm(Form):
|
||||
|
@ -18,7 +19,10 @@ class LogoutForm(Form):
|
|||
|
||||
|
||||
def client_audiences():
|
||||
return [(client, client.client_name) for client in models.Client.query()]
|
||||
return [
|
||||
(client, client.client_name)
|
||||
for client in BaseBackend.get().query(models.Client)
|
||||
]
|
||||
|
||||
|
||||
class ClientAddForm(Form):
|
||||
|
|
|
@ -23,6 +23,7 @@ from canaille.app.flask import logout_user
|
|||
from canaille.app.flask import set_parameter_in_url_query
|
||||
from canaille.app.i18n import gettext as _
|
||||
from canaille.app.themes import render_template
|
||||
from canaille.backends import BaseBackend
|
||||
|
||||
from ..oauth import ClientConfigurationEndpoint
|
||||
from ..oauth import ClientRegistrationEndpoint
|
||||
|
@ -109,7 +110,8 @@ def authorize_login(user):
|
|||
def authorize_consent(client, user):
|
||||
requested_scopes = request.args.get("scope", "").split(" ")
|
||||
allowed_scopes = client.get_allowed_scope(requested_scopes).split(" ")
|
||||
consents = models.Consent.query(
|
||||
consents = BaseBackend.get().query(
|
||||
models.Consent,
|
||||
client=client,
|
||||
subject=user,
|
||||
)
|
||||
|
|
|
@ -8,6 +8,7 @@ from authlib.oauth2.rfc6749 import TokenMixin
|
|||
from authlib.oauth2.rfc6749 import util
|
||||
|
||||
from canaille.app import models
|
||||
from canaille.backends import BaseBackend
|
||||
|
||||
from .basemodels import AuthorizationCode as BaseAuthorizationCode
|
||||
from .basemodels import Client as BaseClient
|
||||
|
@ -95,13 +96,13 @@ class Client(BaseClient, ClientMixin):
|
|||
return metadata
|
||||
|
||||
def delete(self):
|
||||
for consent in models.Consent.query(client=self):
|
||||
for consent in BaseBackend.get().query(models.Consent, client=self):
|
||||
consent.delete()
|
||||
|
||||
for code in models.AuthorizationCode.query(client=self):
|
||||
for code in BaseBackend.get().query(models.AuthorizationCode, client=self):
|
||||
code.delete()
|
||||
|
||||
for token in models.Token.query(client=self):
|
||||
for token in BaseBackend.get().query(models.Token, client=self):
|
||||
token.delete()
|
||||
|
||||
super().delete()
|
||||
|
@ -185,7 +186,8 @@ class Consent(BaseConsent):
|
|||
self.revokation_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
self.save()
|
||||
|
||||
tokens = models.Token.query(
|
||||
tokens = BaseBackend.get().query(
|
||||
models.Token,
|
||||
client=self.client,
|
||||
subject=self.subject,
|
||||
)
|
||||
|
|
|
@ -112,7 +112,9 @@ def openid_configuration():
|
|||
|
||||
def exists_nonce(nonce, req):
|
||||
client = models.Client.get(id=req.client_id)
|
||||
exists = models.AuthorizationCode.query(client=client, nonce=nonce)
|
||||
exists = BaseBackend.get().query(
|
||||
models.AuthorizationCode, client=client, nonce=nonce
|
||||
)
|
||||
return bool(exists)
|
||||
|
||||
|
||||
|
@ -237,7 +239,9 @@ class AuthorizationCodeGrant(_AuthorizationCodeGrant):
|
|||
return save_authorization_code(code, request)
|
||||
|
||||
def query_authorization_code(self, code, client):
|
||||
item = models.AuthorizationCode.query(code=code, client=client)
|
||||
item = BaseBackend.get().query(
|
||||
models.AuthorizationCode, code=code, client=client
|
||||
)
|
||||
if item and not item[0].is_expired():
|
||||
return item[0]
|
||||
|
||||
|
@ -283,7 +287,7 @@ class RefreshTokenGrant(_RefreshTokenGrant):
|
|||
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]
|
||||
|
||||
def authenticate_refresh_token(self, refresh_token):
|
||||
token = models.Token.query(refresh_token=refresh_token)
|
||||
token = BaseBackend.get().query(models.Token, refresh_token=refresh_token)
|
||||
if token and token[0].is_refresh_token_active():
|
||||
return token[0]
|
||||
|
||||
|
|
|
@ -81,15 +81,15 @@ def test_special_chars_in_rdn(testclient, backend):
|
|||
|
||||
|
||||
def test_filter(backend, foo_group, bar_group):
|
||||
assert models.Group.query(display_name="foo") == [foo_group]
|
||||
assert models.Group.query(display_name="bar") == [bar_group]
|
||||
assert backend.query(models.Group, display_name="foo") == [foo_group]
|
||||
assert backend.query(models.Group, display_name="bar") == [bar_group]
|
||||
|
||||
assert models.Group.query(display_name="foo") != 3
|
||||
assert backend.query(models.Group, display_name="foo") != 3
|
||||
|
||||
assert models.Group.query(display_name=["foo"]) == [foo_group]
|
||||
assert models.Group.query(display_name=["bar"]) == [bar_group]
|
||||
assert backend.query(models.Group, display_name=["foo"]) == [foo_group]
|
||||
assert backend.query(models.Group, display_name=["bar"]) == [bar_group]
|
||||
|
||||
assert set(models.Group.query(display_name=["foo", "bar"])) == {
|
||||
assert set(backend.query(models.Group, display_name=["foo", "bar"])) == {
|
||||
foo_group,
|
||||
bar_group,
|
||||
}
|
||||
|
|
|
@ -36,16 +36,16 @@ def test_model_lifecycle(testclient, backend):
|
|||
)
|
||||
|
||||
assert not user.id
|
||||
assert not models.User.query()
|
||||
assert not models.User.query(id=user.id)
|
||||
assert not models.User.query(id="invalid")
|
||||
assert not backend.query(models.User)
|
||||
assert not backend.query(models.User, id=user.id)
|
||||
assert not backend.query(models.User, id="invalid")
|
||||
assert not models.User.get(id=user.id)
|
||||
|
||||
user.save()
|
||||
|
||||
assert models.User.query() == [user]
|
||||
assert models.User.query(id=user.id) == [user]
|
||||
assert not models.User.query(id="invalid")
|
||||
assert backend.query(models.User) == [user]
|
||||
assert backend.query(models.User, id=user.id) == [user]
|
||||
assert not backend.query(models.User, id="invalid")
|
||||
assert models.User.get(id=user.id) == user
|
||||
|
||||
user.family_name = "new_family_name"
|
||||
|
@ -58,7 +58,7 @@ def test_model_lifecycle(testclient, backend):
|
|||
|
||||
user.delete()
|
||||
|
||||
assert not models.User.query(id=user.id)
|
||||
assert not backend.query(models.User, id=user.id)
|
||||
assert not models.User.get(id=user.id)
|
||||
|
||||
user.delete()
|
||||
|
@ -143,7 +143,7 @@ def test_model_indexation(testclient, backend):
|
|||
|
||||
|
||||
def test_fuzzy_unique_attribute(user, moderator, admin, backend):
|
||||
assert set(models.User.query()) == {user, moderator, admin}
|
||||
assert set(backend.query(models.User)) == {user, moderator, admin}
|
||||
assert set(models.User.fuzzy("Jack")) == {moderator}
|
||||
assert set(models.User.fuzzy("Jack", ["formatted_name"])) == {moderator}
|
||||
assert set(models.User.fuzzy("Jack", ["user_name"])) == set()
|
||||
|
@ -157,7 +157,7 @@ def test_fuzzy_unique_attribute(user, moderator, admin, backend):
|
|||
|
||||
|
||||
def test_fuzzy_multiple_attribute(user, moderator, admin, backend):
|
||||
assert set(models.User.query()) == {user, moderator, admin}
|
||||
assert set(backend.query(models.User)) == {user, moderator, admin}
|
||||
assert set(models.User.fuzzy("jack@doe.com")) == {moderator}
|
||||
assert set(models.User.fuzzy("jack@doe.com", ["emails"])) == {moderator}
|
||||
assert set(models.User.fuzzy("jack@doe.com", ["formatted_name"])) == set()
|
||||
|
@ -171,8 +171,8 @@ def test_fuzzy_multiple_attribute(user, moderator, admin, backend):
|
|||
def test_model_references(testclient, user, foo_group, admin, bar_group, backend):
|
||||
assert foo_group.members == [user]
|
||||
assert user.groups == [foo_group]
|
||||
assert foo_group in models.Group.query(members=user)
|
||||
assert user in models.User.query(groups=foo_group)
|
||||
assert foo_group in backend.query(models.Group, members=user)
|
||||
assert user in backend.query(models.User, groups=foo_group)
|
||||
|
||||
assert user not in bar_group.members
|
||||
assert bar_group not in user.groups
|
||||
|
|
|
@ -6,11 +6,11 @@ from canaille.core.populate import fake_users
|
|||
def test_populate_users(testclient, backend):
|
||||
runner = testclient.app.test_cli_runner()
|
||||
|
||||
assert len(models.User.query()) == 0
|
||||
assert len(backend.query(models.User)) == 0
|
||||
res = runner.invoke(cli, ["populate", "--nb", "10", "users"])
|
||||
assert res.exit_code == 0, res.stdout
|
||||
assert len(models.User.query()) == 10
|
||||
for user in models.User.query():
|
||||
assert len(backend.query(models.User)) == 10
|
||||
for user in backend.query(models.User):
|
||||
user.delete()
|
||||
|
||||
|
||||
|
@ -18,13 +18,13 @@ def test_populate_groups(testclient, backend):
|
|||
fake_users(10)
|
||||
runner = testclient.app.test_cli_runner()
|
||||
|
||||
assert len(models.Group.query()) == 0
|
||||
assert len(backend.query(models.Group)) == 0
|
||||
res = runner.invoke(cli, ["populate", "--nb", "10", "groups"])
|
||||
assert res.exit_code == 0, res.stdout
|
||||
assert len(models.Group.query()) == 10
|
||||
assert len(backend.query(models.Group)) == 10
|
||||
|
||||
for group in models.Group.query():
|
||||
for group in backend.query(models.Group):
|
||||
group.delete()
|
||||
|
||||
for user in models.User.query():
|
||||
for user in backend.query(models.User):
|
||||
user.delete()
|
||||
|
|
|
@ -4,7 +4,7 @@ from canaille.core.populate import fake_users
|
|||
|
||||
|
||||
def test_no_group(app, backend):
|
||||
assert models.Group.query() == []
|
||||
assert backend.query(models.Group) == []
|
||||
|
||||
|
||||
def test_group_list_pagination(testclient, logged_admin, foo_group):
|
||||
|
|
|
@ -12,7 +12,7 @@ def test_registration_without_email_validation(testclient, backend, foo_group):
|
|||
testclient.app.config["CANAILLE"]["ENABLE_REGISTRATION"] = True
|
||||
testclient.app.config["CANAILLE"]["EMAIL_CONFIRMATION"] = False
|
||||
|
||||
assert not models.User.query(user_name="newuser")
|
||||
assert not backend.query(models.User, user_name="newuser")
|
||||
res = testclient.get(url_for("core.account.registration"), status=200)
|
||||
res.form["user_name"] = "newuser"
|
||||
res.form["password1"] = "password"
|
||||
|
@ -60,7 +60,7 @@ def test_registration_with_email_validation(testclient, backend, smtpd, foo_grou
|
|||
text_mail = smtpd.messages[0].get_payload()[0].get_payload(decode=True).decode()
|
||||
assert registration_url in text_mail
|
||||
|
||||
assert not models.User.query(user_name="newuser")
|
||||
assert not backend.query(models.User, user_name="newuser")
|
||||
with time_machine.travel("2020-01-01 02:01:00+00:00", tick=False):
|
||||
res = testclient.get(registration_url, status=200)
|
||||
res.form["user_name"] = "newuser"
|
||||
|
|
|
@ -13,8 +13,10 @@ from canaille.app import models
|
|||
from . import client_credentials
|
||||
|
||||
|
||||
def test_nominal_case(testclient, logged_user, client, keypair, trusted_client):
|
||||
assert not models.Consent.query()
|
||||
def test_nominal_case(
|
||||
testclient, logged_user, client, keypair, trusted_client, backend
|
||||
):
|
||||
assert not backend.query(models.Consent)
|
||||
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
|
@ -43,7 +45,7 @@ def test_nominal_case(testclient, logged_user, client, keypair, trusted_client):
|
|||
"phone",
|
||||
}
|
||||
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
assert set(consents[0].scope) == {
|
||||
"openid",
|
||||
"profile",
|
||||
|
@ -112,8 +114,10 @@ def test_invalid_client(testclient, logged_user, keypair):
|
|||
)
|
||||
|
||||
|
||||
def test_redirect_uri(testclient, logged_user, client, keypair, trusted_client):
|
||||
assert not models.Consent.query()
|
||||
def test_redirect_uri(
|
||||
testclient, logged_user, client, keypair, trusted_client, backend
|
||||
):
|
||||
assert not backend.query(models.Consent)
|
||||
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
|
@ -134,7 +138,7 @@ def test_redirect_uri(testclient, logged_user, client, keypair, trusted_client):
|
|||
code = params["code"][0]
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode is not None
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
|
||||
res = testclient.post(
|
||||
"/oauth/token",
|
||||
|
@ -157,8 +161,10 @@ def test_redirect_uri(testclient, logged_user, client, keypair, trusted_client):
|
|||
consent.delete()
|
||||
|
||||
|
||||
def test_preconsented_client(testclient, logged_user, client, keypair, trusted_client):
|
||||
assert not models.Consent.query()
|
||||
def test_preconsented_client(
|
||||
testclient, logged_user, client, keypair, trusted_client, backend
|
||||
):
|
||||
assert not backend.query(models.Consent)
|
||||
|
||||
client.preconsent = True
|
||||
client.save()
|
||||
|
@ -180,7 +186,7 @@ def test_preconsented_client(testclient, logged_user, client, keypair, trusted_c
|
|||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
assert not consents
|
||||
|
||||
res = testclient.post(
|
||||
|
@ -214,8 +220,8 @@ def test_preconsented_client(testclient, logged_user, client, keypair, trusted_c
|
|||
assert res.json["name"] == "John (johnny) Doe"
|
||||
|
||||
|
||||
def test_logout_login(testclient, logged_user, client):
|
||||
assert not models.Consent.query()
|
||||
def test_logout_login(testclient, logged_user, client, backend):
|
||||
assert not backend.query(models.Consent)
|
||||
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
|
@ -254,7 +260,7 @@ def test_logout_login(testclient, logged_user, client):
|
|||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
assert "profile" in consents[0].scope
|
||||
|
||||
res = testclient.post(
|
||||
|
@ -285,8 +291,8 @@ def test_logout_login(testclient, logged_user, client):
|
|||
consent.delete()
|
||||
|
||||
|
||||
def test_deny(testclient, logged_user, client):
|
||||
assert not models.Consent.query()
|
||||
def test_deny(testclient, logged_user, client, backend):
|
||||
assert not backend.query(models.Consent)
|
||||
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
|
@ -305,11 +311,11 @@ def test_deny(testclient, logged_user, client):
|
|||
error = params["error"][0]
|
||||
assert error == "access_denied"
|
||||
|
||||
assert not models.Consent.query()
|
||||
assert not backend.query(models.Consent)
|
||||
|
||||
|
||||
def test_code_challenge(testclient, logged_user, client):
|
||||
assert not models.Consent.query()
|
||||
def test_code_challenge(testclient, logged_user, client, backend):
|
||||
assert not backend.query(models.Consent)
|
||||
|
||||
client.token_endpoint_auth_method = "none"
|
||||
client.save()
|
||||
|
@ -338,7 +344,7 @@ def test_code_challenge(testclient, logged_user, client):
|
|||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
assert "profile" in consents[0].scope
|
||||
|
||||
res = testclient.post(
|
||||
|
@ -373,8 +379,8 @@ def test_code_challenge(testclient, logged_user, client):
|
|||
consent.delete()
|
||||
|
||||
|
||||
def test_consent_already_given(testclient, logged_user, client):
|
||||
assert not models.Consent.query()
|
||||
def test_consent_already_given(testclient, logged_user, client, backend):
|
||||
assert not backend.query(models.Consent)
|
||||
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
|
@ -395,7 +401,7 @@ def test_consent_already_given(testclient, logged_user, client):
|
|||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
assert "profile" in consents[0].scope
|
||||
|
||||
res = testclient.post(
|
||||
|
@ -430,9 +436,9 @@ def test_consent_already_given(testclient, logged_user, client):
|
|||
|
||||
|
||||
def test_when_consent_already_given_but_for_a_smaller_scope(
|
||||
testclient, logged_user, client
|
||||
testclient, logged_user, client, backend
|
||||
):
|
||||
assert not models.Consent.query()
|
||||
assert not backend.query(models.Consent)
|
||||
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
|
@ -453,7 +459,7 @@ def test_when_consent_already_given_but_for_a_smaller_scope(
|
|||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
assert "profile" in consents[0].scope
|
||||
assert "groups" not in consents[0].scope
|
||||
|
||||
|
@ -489,7 +495,7 @@ def test_when_consent_already_given_but_for_a_smaller_scope(
|
|||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
assert "profile" in consents[0].scope
|
||||
assert "groups" in consents[0].scope
|
||||
|
||||
|
@ -535,8 +541,8 @@ def test_nonce_required_in_oidc_requests(testclient, logged_user, client):
|
|||
assert res.json.get("error") == "invalid_request"
|
||||
|
||||
|
||||
def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client):
|
||||
assert not models.Consent.query()
|
||||
def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client, backend):
|
||||
assert not backend.query(models.Consent)
|
||||
testclient.app.config["CANAILLE_OIDC"]["REQUIRE_NONCE"] = False
|
||||
|
||||
res = testclient.get(
|
||||
|
@ -552,12 +558,12 @@ def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client):
|
|||
res = res.form.submit(name="answer", value="accept", status=302)
|
||||
|
||||
assert res.location.startswith(client.redirect_uris[0])
|
||||
for consent in models.Consent.query():
|
||||
for consent in backend.query(models.Consent):
|
||||
consent.delete()
|
||||
|
||||
|
||||
def test_request_scope_too_large(testclient, logged_user, keypair, client):
|
||||
assert not models.Consent.query()
|
||||
def test_request_scope_too_large(testclient, logged_user, keypair, client, backend):
|
||||
assert not backend.query(models.Consent)
|
||||
client.scope = ["openid", "profile", "groups"]
|
||||
client.save()
|
||||
|
||||
|
@ -582,7 +588,7 @@ def test_request_scope_too_large(testclient, logged_user, keypair, client):
|
|||
"profile",
|
||||
}
|
||||
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
assert set(consents[0].scope) == {
|
||||
"openid",
|
||||
"profile",
|
||||
|
|
|
@ -95,7 +95,7 @@ def test_someone_else_consent_restoration(
|
|||
|
||||
|
||||
def test_oidc_authorization_after_revokation(
|
||||
testclient, logged_user, client, keypair, consent
|
||||
testclient, logged_user, client, keypair, consent, backend
|
||||
):
|
||||
consent.revoke()
|
||||
|
||||
|
@ -114,7 +114,7 @@ def test_oidc_authorization_after_revokation(
|
|||
|
||||
res = res.form.submit(name="answer", value="accept", status=302)
|
||||
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
consent.reload()
|
||||
assert consents[0] == consent
|
||||
assert not consent.revoked
|
||||
|
|
|
@ -5,8 +5,8 @@ from canaille.app import models
|
|||
# forms.
|
||||
|
||||
|
||||
def test_fieldlist_add(testclient, logged_admin):
|
||||
assert not models.Client.query()
|
||||
def test_fieldlist_add(testclient, logged_admin, backend):
|
||||
assert not backend.query(models.Client)
|
||||
|
||||
res = testclient.get("/admin/client/add")
|
||||
assert "redirect_uris-1" not in res.form.fields
|
||||
|
@ -23,7 +23,7 @@ def test_fieldlist_add(testclient, logged_admin):
|
|||
res.form[k].force_value(v)
|
||||
|
||||
res = res.form.submit(status=200, name="fieldlist_add", value="redirect_uris-0")
|
||||
assert not models.Client.query()
|
||||
assert not backend.query(models.Client)
|
||||
|
||||
data["redirect_uris-1"] = "https://foo.bar/callback2"
|
||||
for k, v in data.items():
|
||||
|
@ -43,8 +43,8 @@ def test_fieldlist_add(testclient, logged_admin):
|
|||
client.delete()
|
||||
|
||||
|
||||
def test_fieldlist_delete(testclient, logged_admin):
|
||||
assert not models.Client.query()
|
||||
def test_fieldlist_delete(testclient, logged_admin, backend):
|
||||
assert not backend.query(models.Client)
|
||||
res = testclient.get("/admin/client/add")
|
||||
|
||||
data = {
|
||||
|
@ -61,7 +61,7 @@ def test_fieldlist_delete(testclient, logged_admin):
|
|||
|
||||
res.form["redirect_uris-1"] = "https://foo.bar/callback2"
|
||||
res = res.form.submit(status=200, name="fieldlist_remove", value="redirect_uris-1")
|
||||
assert not models.Client.query()
|
||||
assert not backend.query(models.Client)
|
||||
assert "redirect_uris-1" not in res.form.fields
|
||||
|
||||
res = res.form.submit(status=302, name="action", value="edit")
|
||||
|
@ -92,8 +92,8 @@ def test_fieldlist_add_invalid_field(testclient, logged_admin):
|
|||
testclient.post("/admin/client/add", data, status=400)
|
||||
|
||||
|
||||
def test_fieldlist_delete_invalid_field(testclient, logged_admin):
|
||||
assert not models.Client.query()
|
||||
def test_fieldlist_delete_invalid_field(testclient, logged_admin, backend):
|
||||
assert not backend.query(models.Client)
|
||||
res = testclient.get("/admin/client/add")
|
||||
|
||||
data = {
|
||||
|
|
|
@ -7,8 +7,8 @@ from canaille.app import models
|
|||
from . import client_credentials
|
||||
|
||||
|
||||
def test_refresh_token(testclient, logged_user, client):
|
||||
assert not models.Consent.query()
|
||||
def test_refresh_token(testclient, logged_user, client, backend):
|
||||
assert not backend.query(models.Consent)
|
||||
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
|
@ -27,7 +27,7 @@ def test_refresh_token(testclient, logged_user, client):
|
|||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
assert "profile" in consents[0].scope
|
||||
|
||||
res = testclient.post(
|
||||
|
|
|
@ -9,7 +9,9 @@ from canaille.oidc.oauth import setup_oauth
|
|||
from . import client_credentials
|
||||
|
||||
|
||||
def test_token_default_expiration_date(testclient, logged_user, client, keypair):
|
||||
def test_token_default_expiration_date(
|
||||
testclient, logged_user, client, keypair, backend
|
||||
):
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
params=dict(
|
||||
|
@ -52,12 +54,14 @@ def test_token_default_expiration_date(testclient, logged_user, client, keypair)
|
|||
claims = jwt.decode(id_token, keypair[1])
|
||||
assert claims["exp"] - claims["iat"] == 3600
|
||||
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
for consent in consents:
|
||||
consent.delete()
|
||||
|
||||
|
||||
def test_token_custom_expiration_date(testclient, logged_user, client, keypair):
|
||||
def test_token_custom_expiration_date(
|
||||
testclient, logged_user, client, keypair, backend
|
||||
):
|
||||
testclient.app.config["OAUTH2_TOKEN_EXPIRES_IN"] = {
|
||||
"authorization_code": 1000,
|
||||
"implicit": 2000,
|
||||
|
@ -110,6 +114,6 @@ def test_token_custom_expiration_date(testclient, logged_user, client, keypair):
|
|||
claims = jwt.decode(id_token, keypair[1])
|
||||
assert claims["exp"] - claims["iat"] == 6000
|
||||
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
for consent in consents:
|
||||
consent.delete()
|
||||
|
|
Loading…
Reference in a new issue