forked from Github-Mirrors/canaille
refactor: move BackendModel.query to Backend.query
This commit is contained in:
parent
93fa708b1c
commit
8425b2a3b8
29 changed files with 284 additions and 247 deletions
|
@ -15,6 +15,7 @@ from canaille.app.i18n import DEFAULT_LANGUAGE_CODE
|
||||||
from canaille.app.i18n import gettext as _
|
from canaille.app.i18n import gettext as _
|
||||||
from canaille.app.i18n import locale_selector
|
from canaille.app.i18n import locale_selector
|
||||||
from canaille.app.i18n import timezone_selector
|
from canaille.app.i18n import timezone_selector
|
||||||
|
from canaille.backends import BaseBackend
|
||||||
|
|
||||||
from . import validate_uri
|
from . import validate_uri
|
||||||
from .flask import request_is_htmx
|
from .flask import request_is_htmx
|
||||||
|
@ -188,7 +189,7 @@ class TableForm(I18NFormMixin, FlaskForm):
|
||||||
if self.query.data:
|
if self.query.data:
|
||||||
self.items = cls.fuzzy(self.query.data, fields, **filter)
|
self.items = cls.fuzzy(self.query.data, fields, **filter)
|
||||||
else:
|
else:
|
||||||
self.items = cls.query(**filter)
|
self.items = BaseBackend.get().query(cls, **filter)
|
||||||
|
|
||||||
self.page_size = page_size
|
self.page_size = page_size
|
||||||
self.nb_items = len(self.items)
|
self.nb_items = len(self.items)
|
||||||
|
|
|
@ -54,6 +54,25 @@ class BaseBackend:
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def query(self, model, **kwargs):
|
||||||
|
"""
|
||||||
|
Perform a query on the database and return a collection of instances.
|
||||||
|
Parameters can be any valid attribute with the expected value:
|
||||||
|
|
||||||
|
>>> backend.query(User, first_name="George")
|
||||||
|
|
||||||
|
If several arguments are passed, the methods only returns the model
|
||||||
|
instances that return matches all the argument values:
|
||||||
|
|
||||||
|
>>> backend.query(User, first_name="George", last_name="Abitbol")
|
||||||
|
|
||||||
|
If the argument value is a collection, the methods will return the
|
||||||
|
models that matches any of the values:
|
||||||
|
|
||||||
|
>>> backend.query(User, first_name=["George", "Jane"])
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def check_user_password(self, user, password: str) -> bool:
|
def check_user_password(self, user, password: str) -> bool:
|
||||||
"""Check if the password matches the user password in the database."""
|
"""Check if the password matches the user password in the database."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
|
@ -16,6 +16,7 @@ from canaille.app.i18n import gettext as _
|
||||||
from canaille.backends import BaseBackend
|
from canaille.backends import BaseBackend
|
||||||
|
|
||||||
from .utils import listify
|
from .utils import listify
|
||||||
|
from .utils import python_attrs_to_ldap
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
@ -243,13 +244,64 @@ class Backend(BaseBackend):
|
||||||
return result, message
|
return result, message
|
||||||
|
|
||||||
def set_user_password(self, user, password):
|
def set_user_password(self, user, password):
|
||||||
conn = Backend.get().connection
|
conn = self.connection
|
||||||
conn.passwd_s(
|
conn.passwd_s(
|
||||||
user.dn,
|
user.dn,
|
||||||
None,
|
None,
|
||||||
password.encode("utf-8"),
|
password.encode("utf-8"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def query(self, model, dn=None, filter=None, **kwargs):
|
||||||
|
from .ldapobjectquery import LDAPObjectQuery
|
||||||
|
|
||||||
|
base = dn
|
||||||
|
if dn is None:
|
||||||
|
base = f"{model.base},{model.root_dn}"
|
||||||
|
elif "=" not in base:
|
||||||
|
base = ldap.dn.escape_dn_chars(base)
|
||||||
|
base = f"{model.rdn_attribute}={base},{model.base},{model.root_dn}"
|
||||||
|
|
||||||
|
class_filter = (
|
||||||
|
"".join([f"(objectClass={oc})" for oc in model.ldap_object_class])
|
||||||
|
if getattr(model, "ldap_object_class")
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
if class_filter:
|
||||||
|
class_filter = f"(|{class_filter})"
|
||||||
|
|
||||||
|
arg_filter = ""
|
||||||
|
ldap_args = python_attrs_to_ldap(
|
||||||
|
{
|
||||||
|
model.python_attribute_to_ldap(name): values
|
||||||
|
for name, values in kwargs.items()
|
||||||
|
if values is not None
|
||||||
|
},
|
||||||
|
encode=False,
|
||||||
|
)
|
||||||
|
for key, value in ldap_args.items():
|
||||||
|
if len(value) == 1:
|
||||||
|
escaped_value = ldap.filter.escape_filter_chars(value[0])
|
||||||
|
arg_filter += f"({key}={escaped_value})"
|
||||||
|
|
||||||
|
else:
|
||||||
|
values = [ldap.filter.escape_filter_chars(v) for v in value]
|
||||||
|
arg_filter += (
|
||||||
|
"(|" + "".join([f"({key}={value})" for value in values]) + ")"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not filter:
|
||||||
|
filter = ""
|
||||||
|
|
||||||
|
ldapfilter = f"(&{class_filter}{arg_filter}{filter})"
|
||||||
|
base = base or f"{model.base},{model.root_dn}"
|
||||||
|
try:
|
||||||
|
result = self.connection.search_s(
|
||||||
|
base, ldap.SCOPE_SUBTREE, ldapfilter or None, ["+", "*"]
|
||||||
|
)
|
||||||
|
except ldap.NO_SUCH_OBJECT:
|
||||||
|
result = []
|
||||||
|
return LDAPObjectQuery(model, result)
|
||||||
|
|
||||||
|
|
||||||
def setup_ldap_models(config):
|
def setup_ldap_models(config):
|
||||||
from canaille.app import models
|
from canaille.app import models
|
||||||
|
|
|
@ -4,42 +4,15 @@ import ldap.dn
|
||||||
import ldap.filter
|
import ldap.filter
|
||||||
from ldap.controls.readentry import PostReadControl
|
from ldap.controls.readentry import PostReadControl
|
||||||
|
|
||||||
|
from canaille.backends import BaseBackend
|
||||||
from canaille.backends.models import BackendModel
|
from canaille.backends.models import BackendModel
|
||||||
|
|
||||||
from .backend import Backend
|
from .backend import Backend
|
||||||
from .ldapobjectquery import LDAPObjectQuery
|
from .utils import attribute_ldap_syntax
|
||||||
from .utils import cardinalize_attribute
|
from .utils import cardinalize_attribute
|
||||||
from .utils import ldap_to_python
|
from .utils import ldap_to_python
|
||||||
from .utils import listify
|
from .utils import listify
|
||||||
from .utils import python_to_ldap
|
from .utils import python_attrs_to_ldap
|
||||||
|
|
||||||
|
|
||||||
def python_attrs_to_ldap(attrs, encode=True, null_allowed=True):
|
|
||||||
formatted_attrs = {
|
|
||||||
name: [
|
|
||||||
python_to_ldap(value, attribute_ldap_syntax(name), encode=encode)
|
|
||||||
for value in listify(values)
|
|
||||||
]
|
|
||||||
for name, values in attrs.items()
|
|
||||||
}
|
|
||||||
if not null_allowed:
|
|
||||||
formatted_attrs = {
|
|
||||||
key: [value for value in values if value]
|
|
||||||
for key, values in formatted_attrs.items()
|
|
||||||
if values
|
|
||||||
}
|
|
||||||
return formatted_attrs
|
|
||||||
|
|
||||||
|
|
||||||
def attribute_ldap_syntax(attribute_name):
|
|
||||||
ldap_attrs = LDAPObject.ldap_object_attributes()
|
|
||||||
if attribute_name not in ldap_attrs:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if ldap_attrs[attribute_name].syntax:
|
|
||||||
return ldap_attrs[attribute_name].syntax
|
|
||||||
|
|
||||||
return attribute_ldap_syntax(ldap_attrs[attribute_name].sup[0])
|
|
||||||
|
|
||||||
|
|
||||||
class LDAPObjectMetaclass(type):
|
class LDAPObjectMetaclass(type):
|
||||||
|
@ -256,7 +229,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls, identifier=None, /, **kwargs):
|
def get(cls, identifier=None, /, **kwargs):
|
||||||
try:
|
try:
|
||||||
return cls.query(identifier, **kwargs)[0]
|
return BaseBackend.get().query(cls, identifier, **kwargs)[0]
|
||||||
except (IndexError, ldap.NO_SUCH_OBJECT):
|
except (IndexError, ldap.NO_SUCH_OBJECT):
|
||||||
if identifier and cls.base:
|
if identifier and cls.base:
|
||||||
return (
|
return (
|
||||||
|
@ -267,58 +240,6 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def query(cls, dn=None, filter=None, **kwargs):
|
|
||||||
conn = Backend.get().connection
|
|
||||||
|
|
||||||
base = dn
|
|
||||||
if dn is None:
|
|
||||||
base = f"{cls.base},{cls.root_dn}"
|
|
||||||
elif "=" not in base:
|
|
||||||
base = ldap.dn.escape_dn_chars(base)
|
|
||||||
base = f"{cls.rdn_attribute}={base},{cls.base},{cls.root_dn}"
|
|
||||||
|
|
||||||
class_filter = (
|
|
||||||
"".join([f"(objectClass={oc})" for oc in cls.ldap_object_class])
|
|
||||||
if getattr(cls, "ldap_object_class")
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
if class_filter:
|
|
||||||
class_filter = f"(|{class_filter})"
|
|
||||||
|
|
||||||
arg_filter = ""
|
|
||||||
ldap_args = python_attrs_to_ldap(
|
|
||||||
{
|
|
||||||
cls.python_attribute_to_ldap(name): values
|
|
||||||
for name, values in kwargs.items()
|
|
||||||
if values is not None
|
|
||||||
},
|
|
||||||
encode=False,
|
|
||||||
)
|
|
||||||
for key, value in ldap_args.items():
|
|
||||||
if len(value) == 1:
|
|
||||||
escaped_value = ldap.filter.escape_filter_chars(value[0])
|
|
||||||
arg_filter += f"({key}={escaped_value})"
|
|
||||||
|
|
||||||
else:
|
|
||||||
values = [ldap.filter.escape_filter_chars(v) for v in value]
|
|
||||||
arg_filter += (
|
|
||||||
"(|" + "".join([f"({key}={value})" for value in values]) + ")"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not filter:
|
|
||||||
filter = ""
|
|
||||||
|
|
||||||
ldapfilter = f"(&{class_filter}{arg_filter}{filter})"
|
|
||||||
base = base or f"{cls.base},{cls.root_dn}"
|
|
||||||
try:
|
|
||||||
result = conn.search_s(
|
|
||||||
base, ldap.SCOPE_SUBTREE, ldapfilter or None, ["+", "*"]
|
|
||||||
)
|
|
||||||
except ldap.NO_SUCH_OBJECT:
|
|
||||||
result = []
|
|
||||||
return LDAPObjectQuery(cls, result)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def fuzzy(cls, query, attributes=None, **kwargs):
|
def fuzzy(cls, query, attributes=None, **kwargs):
|
||||||
query = ldap.filter.escape_filter_chars(query)
|
query = ldap.filter.escape_filter_chars(query)
|
||||||
|
@ -327,7 +248,7 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
|
||||||
filter = (
|
filter = (
|
||||||
"(|" + "".join(f"({attribute}=*{query}*)" for attribute in attributes) + ")"
|
"(|" + "".join(f"({attribute}=*{query}*)" for attribute in attributes) + ")"
|
||||||
)
|
)
|
||||||
return cls.query(filter=filter, **kwargs)
|
return BaseBackend.get().query(cls, filter=filter, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def update_ldap_attributes(cls):
|
def update_ldap_attributes(cls):
|
||||||
|
|
|
@ -95,3 +95,33 @@ def cardinalize_attribute(python_unique, value):
|
||||||
return value[0]
|
return value[0]
|
||||||
|
|
||||||
return [v for v in value if v is not None]
|
return [v for v in value if v is not None]
|
||||||
|
|
||||||
|
|
||||||
|
def python_attrs_to_ldap(attrs, encode=True, null_allowed=True):
|
||||||
|
formatted_attrs = {
|
||||||
|
name: [
|
||||||
|
python_to_ldap(value, attribute_ldap_syntax(name), encode=encode)
|
||||||
|
for value in listify(values)
|
||||||
|
]
|
||||||
|
for name, values in attrs.items()
|
||||||
|
}
|
||||||
|
if not null_allowed:
|
||||||
|
formatted_attrs = {
|
||||||
|
key: [value for value in values if value]
|
||||||
|
for key, values in formatted_attrs.items()
|
||||||
|
if values
|
||||||
|
}
|
||||||
|
return formatted_attrs
|
||||||
|
|
||||||
|
|
||||||
|
def attribute_ldap_syntax(attribute_name):
|
||||||
|
from .ldapobject import LDAPObject
|
||||||
|
|
||||||
|
ldap_attrs = LDAPObject.ldap_object_attributes()
|
||||||
|
if attribute_name not in ldap_attrs:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if ldap_attrs[attribute_name].syntax:
|
||||||
|
return ldap_attrs[attribute_name].syntax
|
||||||
|
|
||||||
|
return attribute_ldap_syntax(ldap_attrs[attribute_name].sup[0])
|
||||||
|
|
|
@ -40,3 +40,28 @@ class Backend(BaseBackend):
|
||||||
def set_user_password(self, user, password):
|
def set_user_password(self, user, password):
|
||||||
user.password = password
|
user.password = password
|
||||||
user.save()
|
user.save()
|
||||||
|
|
||||||
|
def query(self, model, **kwargs):
|
||||||
|
# if there is no filter, return all models
|
||||||
|
if not kwargs:
|
||||||
|
states = model.index().values()
|
||||||
|
return [model(**state) for state in states]
|
||||||
|
|
||||||
|
# get the ids from the attribute indexes
|
||||||
|
ids = {
|
||||||
|
id
|
||||||
|
for attribute, values in kwargs.items()
|
||||||
|
for value in model.serialize(model.listify(values))
|
||||||
|
for id in model.attribute_index(attribute).get(value, [])
|
||||||
|
}
|
||||||
|
|
||||||
|
# get the states from the ids
|
||||||
|
states = [model.index()[id] for id in ids]
|
||||||
|
|
||||||
|
# initialize instances from the states
|
||||||
|
instances = [model(**state) for state in states]
|
||||||
|
for instance in instances:
|
||||||
|
# TODO: maybe find a way to not initialize the cache in the first place?
|
||||||
|
instance._cache = {}
|
||||||
|
|
||||||
|
return instances
|
||||||
|
|
|
@ -6,6 +6,7 @@ import uuid
|
||||||
import canaille.core.models
|
import canaille.core.models
|
||||||
import canaille.oidc.models
|
import canaille.oidc.models
|
||||||
from canaille.app import models
|
from canaille.app import models
|
||||||
|
from canaille.backends import BaseBackend
|
||||||
from canaille.backends.models import BackendModel
|
from canaille.backends.models import BackendModel
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,31 +26,6 @@ class MemoryModel(BackendModel):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<{self.__class__.__name__} id={self.id}>"
|
return f"<{self.__class__.__name__} id={self.id}>"
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def query(cls, **kwargs):
|
|
||||||
# if there is no filter, return all models
|
|
||||||
if not kwargs:
|
|
||||||
states = cls.index().values()
|
|
||||||
return [cls(**state) for state in states]
|
|
||||||
|
|
||||||
# get the ids from the attribute indexes
|
|
||||||
ids = {
|
|
||||||
id
|
|
||||||
for attribute, values in kwargs.items()
|
|
||||||
for value in cls.serialize(cls.listify(values))
|
|
||||||
for id in cls.attribute_index(attribute).get(value, [])
|
|
||||||
}
|
|
||||||
|
|
||||||
# get the states from the ids
|
|
||||||
states = [cls.index()[id] for id in ids]
|
|
||||||
|
|
||||||
# initialize instances from the states
|
|
||||||
instances = [cls(**state) for state in states]
|
|
||||||
for instance in instances:
|
|
||||||
# TODO: maybe find a way to not initialize the cache in the first place?
|
|
||||||
instance._cache = {}
|
|
||||||
return instances
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def index(cls, class_name=None):
|
def index(cls, class_name=None):
|
||||||
return MemoryModel.indexes.setdefault(class_name or cls.__name__, {})
|
return MemoryModel.indexes.setdefault(class_name or cls.__name__, {})
|
||||||
|
@ -63,7 +39,7 @@ class MemoryModel(BackendModel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def fuzzy(cls, query, attributes=None, **kwargs):
|
def fuzzy(cls, query, attributes=None, **kwargs):
|
||||||
attributes = attributes or cls.attributes
|
attributes = attributes or cls.attributes
|
||||||
instances = cls.query(**kwargs)
|
instances = BaseBackend.get().query(cls, **kwargs)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
instance
|
instance
|
||||||
|
@ -85,7 +61,7 @@ class MemoryModel(BackendModel):
|
||||||
or None
|
or None
|
||||||
)
|
)
|
||||||
|
|
||||||
results = cls.query(**kwargs)
|
results = BaseBackend.get().query(cls, **kwargs)
|
||||||
return results[0] if results else None
|
return results[0] if results else None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -87,37 +87,16 @@ class BackendModel:
|
||||||
implemented for every model and for every backend.
|
implemented for every model and for every backend.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def query(cls, **kwargs):
|
|
||||||
"""Perform a query on the database and return a collection of
|
|
||||||
instances.
|
|
||||||
|
|
||||||
Parameters can be any valid attribute with the expected value:
|
|
||||||
|
|
||||||
>>> User.query(first_name="George")
|
|
||||||
|
|
||||||
If several arguments are passed, the methods only returns the model
|
|
||||||
instances that return matches all the argument values:
|
|
||||||
|
|
||||||
>>> User.query(first_name="George", last_name="Abitbol")
|
|
||||||
|
|
||||||
If the argument value is a collection, the methods will return the
|
|
||||||
models that matches any of the values:
|
|
||||||
|
|
||||||
>>> User.query(first_name=["George", "Jane"])
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def fuzzy(cls, query, attributes=None, **kwargs):
|
def fuzzy(cls, query, attributes=None, **kwargs):
|
||||||
"""Works like :meth:`~canaille.backends.models.BackendModel.query` but
|
"""Works like :meth:`~canaille.backends.BaseBackend.query` but
|
||||||
attribute values loosely be matched."""
|
attribute values loosely be matched."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls, identifier=None, **kwargs):
|
def get(cls, identifier=None, **kwargs):
|
||||||
"""Works like :meth:`~canaille.backends.models.BackendModel.query` but
|
"""Works like :meth:`~canaille.backends.BaseBackend.query` but return
|
||||||
return only one element or :py:data:`None` if no item is matching."""
|
only one element or :py:data:`None` if no item is matching."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy.orm import declarative_base
|
from sqlalchemy.orm import declarative_base
|
||||||
|
|
||||||
|
@ -65,3 +66,15 @@ class Backend(BaseBackend):
|
||||||
def set_user_password(self, user, password):
|
def set_user_password(self, user, password):
|
||||||
user.password = password
|
user.password = password
|
||||||
user.save()
|
user.save()
|
||||||
|
|
||||||
|
def query(self, model, **kwargs):
|
||||||
|
filter = [
|
||||||
|
model.attribute_filter(attribute_name, expected_value)
|
||||||
|
for attribute_name, expected_value in kwargs.items()
|
||||||
|
]
|
||||||
|
return (
|
||||||
|
Backend.get()
|
||||||
|
.db_session.execute(select(model).filter(*filter))
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
|
@ -36,19 +36,6 @@ class SqlAlchemyModel(BackendModel):
|
||||||
f"<{self.__class__.__name__} {self.identifier_attribute}={self.identifier}>"
|
f"<{self.__class__.__name__} {self.identifier_attribute}={self.identifier}>"
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def query(cls, **kwargs):
|
|
||||||
filter = [
|
|
||||||
cls.attribute_filter(attribute_name, expected_value)
|
|
||||||
for attribute_name, expected_value in kwargs.items()
|
|
||||||
]
|
|
||||||
return (
|
|
||||||
Backend.get()
|
|
||||||
.db_session.execute(select(cls).filter(*filter))
|
|
||||||
.scalars()
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def fuzzy(cls, query, attributes=None, **kwargs):
|
def fuzzy(cls, query, attributes=None, **kwargs):
|
||||||
attributes = attributes or cls.attributes
|
attributes = attributes or cls.attributes
|
||||||
|
|
|
@ -86,7 +86,7 @@ def join():
|
||||||
|
|
||||||
form = JoinForm(request.form or None)
|
form = JoinForm(request.form or None)
|
||||||
if request.form and form.validate():
|
if request.form and form.validate():
|
||||||
if models.User.query(emails=form.email.data):
|
if BaseBackend.get().query(models.User, emails=form.email.data):
|
||||||
flash(
|
flash(
|
||||||
_(
|
_(
|
||||||
"You will receive soon an email to continue the registration process."
|
"You will receive soon an email to continue the registration process."
|
||||||
|
@ -295,7 +295,10 @@ def registration(data=None, hash=None):
|
||||||
if "groups" not in form and payload and payload.groups:
|
if "groups" not in form and payload and payload.groups:
|
||||||
form["groups"] = wtforms.SelectMultipleField(
|
form["groups"] = wtforms.SelectMultipleField(
|
||||||
_("Groups"),
|
_("Groups"),
|
||||||
choices=[(group, group.display_name) for group in models.Group.query()],
|
choices=[
|
||||||
|
(group, group.display_name)
|
||||||
|
for group in BaseBackend.get().query(models.Group)
|
||||||
|
],
|
||||||
coerce=IDToModel("Group"),
|
coerce=IDToModel("Group"),
|
||||||
)
|
)
|
||||||
set_readonly(form["groups"])
|
set_readonly(form["groups"])
|
||||||
|
@ -388,7 +391,7 @@ def email_confirmation(data, hash):
|
||||||
)
|
)
|
||||||
return redirect(url_for("core.account.index"))
|
return redirect(url_for("core.account.index"))
|
||||||
|
|
||||||
if models.User.query(emails=confirmation_obj.email):
|
if BaseBackend.get().query(models.User, emails=confirmation_obj.email):
|
||||||
flash(
|
flash(
|
||||||
_("This address email is already associated with another account."),
|
_("This address email is already associated with another account."),
|
||||||
"error",
|
"error",
|
||||||
|
|
|
@ -312,7 +312,10 @@ PROFILE_FORM_FIELDS = dict(
|
||||||
groups=wtforms.SelectMultipleField(
|
groups=wtforms.SelectMultipleField(
|
||||||
_("Groups"),
|
_("Groups"),
|
||||||
default=[],
|
default=[],
|
||||||
choices=lambda: [(group, group.display_name) for group in models.Group.query()],
|
choices=lambda: [
|
||||||
|
(group, group.display_name)
|
||||||
|
for group in BaseBackend.get().query(models.Group)
|
||||||
|
],
|
||||||
render_kw={"placeholder": _("users, admins …")},
|
render_kw={"placeholder": _("users, admins …")},
|
||||||
coerce=IDToModel("Group"),
|
coerce=IDToModel("Group"),
|
||||||
validators=[non_empty_groups],
|
validators=[non_empty_groups],
|
||||||
|
@ -333,7 +336,7 @@ def build_profile_form(write_field_names, readonly_field_names, user=None):
|
||||||
if PROFILE_FORM_FIELDS.get(name)
|
if PROFILE_FORM_FIELDS.get(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
if "groups" in fields and not models.Group.query():
|
if "groups" in fields and not BaseBackend.get().query(models.Group):
|
||||||
del fields["groups"]
|
del fields["groups"]
|
||||||
|
|
||||||
if current_app.backend.get().has_account_lockability(): # pragma: no branch
|
if current_app.backend.get().has_account_lockability(): # pragma: no branch
|
||||||
|
@ -436,7 +439,10 @@ class InvitationForm(Form):
|
||||||
)
|
)
|
||||||
groups = wtforms.SelectMultipleField(
|
groups = wtforms.SelectMultipleField(
|
||||||
_("Groups"),
|
_("Groups"),
|
||||||
choices=lambda: [(group, group.display_name) for group in models.Group.query()],
|
choices=lambda: [
|
||||||
|
(group, group.display_name)
|
||||||
|
for group in BaseBackend.get().query(models.Group)
|
||||||
|
],
|
||||||
render_kw={},
|
render_kw={},
|
||||||
coerce=IDToModel("Group"),
|
coerce=IDToModel("Group"),
|
||||||
)
|
)
|
||||||
|
|
|
@ -5,6 +5,7 @@ from faker.config import AVAILABLE_LOCALES
|
||||||
|
|
||||||
from canaille.app import models
|
from canaille.app import models
|
||||||
from canaille.app.i18n import available_language_codes
|
from canaille.app.i18n import available_language_codes
|
||||||
|
from canaille.backends import BaseBackend
|
||||||
|
|
||||||
|
|
||||||
def fake_users(nb=1):
|
def fake_users(nb=1):
|
||||||
|
@ -47,7 +48,7 @@ def fake_users(nb=1):
|
||||||
|
|
||||||
|
|
||||||
def fake_groups(nb=1, nb_users_max=1):
|
def fake_groups(nb=1, nb_users_max=1):
|
||||||
users = models.User.query()
|
users = BaseBackend.get().query(models.User)
|
||||||
groups = list()
|
groups = list()
|
||||||
fake = faker.Faker(["en_US"])
|
fake = faker.Faker(["en_US"])
|
||||||
for _ in range(nb):
|
for _ in range(nb):
|
||||||
|
|
|
@ -3,6 +3,7 @@ from flask.cli import with_appcontext
|
||||||
|
|
||||||
from canaille.app import models
|
from canaille.app import models
|
||||||
from canaille.app.commands import with_backendcontext
|
from canaille.app.commands import with_backendcontext
|
||||||
|
from canaille.backends import BaseBackend
|
||||||
|
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
|
@ -10,11 +11,11 @@ from canaille.app.commands import with_backendcontext
|
||||||
@with_backendcontext
|
@with_backendcontext
|
||||||
def clean():
|
def clean():
|
||||||
"""Remove expired tokens and authorization codes."""
|
"""Remove expired tokens and authorization codes."""
|
||||||
for t in models.Token.query():
|
for t in BaseBackend.get().query(models.Token):
|
||||||
if t.is_expired():
|
if t.is_expired():
|
||||||
t.delete()
|
t.delete()
|
||||||
|
|
||||||
for a in models.AuthorizationCode.query():
|
for a in BaseBackend.get().query(models.AuthorizationCode):
|
||||||
if a.is_expired():
|
if a.is_expired():
|
||||||
a.delete()
|
a.delete()
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ from canaille.app import models
|
||||||
from canaille.app.flask import user_needed
|
from canaille.app.flask import user_needed
|
||||||
from canaille.app.i18n import gettext as _
|
from canaille.app.i18n import gettext as _
|
||||||
from canaille.app.themes import render_template
|
from canaille.app.themes import render_template
|
||||||
|
from canaille.backends import BaseBackend
|
||||||
|
|
||||||
from ..utils import SCOPE_DETAILS
|
from ..utils import SCOPE_DETAILS
|
||||||
|
|
||||||
|
@ -19,13 +20,13 @@ bp = Blueprint("consents", __name__, url_prefix="/consent")
|
||||||
@bp.route("/")
|
@bp.route("/")
|
||||||
@user_needed()
|
@user_needed()
|
||||||
def consents(user):
|
def consents(user):
|
||||||
consents = models.Consent.query(subject=user)
|
consents = BaseBackend.get().query(models.Consent, subject=user)
|
||||||
clients = {t.client for t in consents}
|
clients = {t.client for t in consents}
|
||||||
|
|
||||||
nb_consents = len(consents)
|
nb_consents = len(consents)
|
||||||
nb_preconsents = sum(
|
nb_preconsents = sum(
|
||||||
1
|
1
|
||||||
for client in models.Client.query()
|
for client in BaseBackend.get().query(models.Client)
|
||||||
if client.preconsent and client not in clients
|
if client.preconsent and client not in clients
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -43,11 +44,11 @@ def consents(user):
|
||||||
@bp.route("/pre-consents")
|
@bp.route("/pre-consents")
|
||||||
@user_needed()
|
@user_needed()
|
||||||
def pre_consents(user):
|
def pre_consents(user):
|
||||||
consents = models.Consent.query(subject=user)
|
consents = BaseBackend.get().query(models.Consent, subject=user)
|
||||||
clients = {t.client for t in consents}
|
clients = {t.client for t in consents}
|
||||||
preconsented = [
|
preconsented = [
|
||||||
client
|
client
|
||||||
for client in models.Client.query()
|
for client in BaseBackend.get().query(models.Client)
|
||||||
if client.preconsent and client not in clients
|
if client.preconsent and client not in clients
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ from canaille.app.forms import email_validator
|
||||||
from canaille.app.forms import is_uri
|
from canaille.app.forms import is_uri
|
||||||
from canaille.app.forms import unique_values
|
from canaille.app.forms import unique_values
|
||||||
from canaille.app.i18n import lazy_gettext as _
|
from canaille.app.i18n import lazy_gettext as _
|
||||||
|
from canaille.backends import BaseBackend
|
||||||
|
|
||||||
|
|
||||||
class AuthorizeForm(Form):
|
class AuthorizeForm(Form):
|
||||||
|
@ -18,7 +19,10 @@ class LogoutForm(Form):
|
||||||
|
|
||||||
|
|
||||||
def client_audiences():
|
def client_audiences():
|
||||||
return [(client, client.client_name) for client in models.Client.query()]
|
return [
|
||||||
|
(client, client.client_name)
|
||||||
|
for client in BaseBackend.get().query(models.Client)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class ClientAddForm(Form):
|
class ClientAddForm(Form):
|
||||||
|
|
|
@ -23,6 +23,7 @@ from canaille.app.flask import logout_user
|
||||||
from canaille.app.flask import set_parameter_in_url_query
|
from canaille.app.flask import set_parameter_in_url_query
|
||||||
from canaille.app.i18n import gettext as _
|
from canaille.app.i18n import gettext as _
|
||||||
from canaille.app.themes import render_template
|
from canaille.app.themes import render_template
|
||||||
|
from canaille.backends import BaseBackend
|
||||||
|
|
||||||
from ..oauth import ClientConfigurationEndpoint
|
from ..oauth import ClientConfigurationEndpoint
|
||||||
from ..oauth import ClientRegistrationEndpoint
|
from ..oauth import ClientRegistrationEndpoint
|
||||||
|
@ -109,7 +110,8 @@ def authorize_login(user):
|
||||||
def authorize_consent(client, user):
|
def authorize_consent(client, user):
|
||||||
requested_scopes = request.args.get("scope", "").split(" ")
|
requested_scopes = request.args.get("scope", "").split(" ")
|
||||||
allowed_scopes = client.get_allowed_scope(requested_scopes).split(" ")
|
allowed_scopes = client.get_allowed_scope(requested_scopes).split(" ")
|
||||||
consents = models.Consent.query(
|
consents = BaseBackend.get().query(
|
||||||
|
models.Consent,
|
||||||
client=client,
|
client=client,
|
||||||
subject=user,
|
subject=user,
|
||||||
)
|
)
|
||||||
|
|
|
@ -8,6 +8,7 @@ from authlib.oauth2.rfc6749 import TokenMixin
|
||||||
from authlib.oauth2.rfc6749 import util
|
from authlib.oauth2.rfc6749 import util
|
||||||
|
|
||||||
from canaille.app import models
|
from canaille.app import models
|
||||||
|
from canaille.backends import BaseBackend
|
||||||
|
|
||||||
from .basemodels import AuthorizationCode as BaseAuthorizationCode
|
from .basemodels import AuthorizationCode as BaseAuthorizationCode
|
||||||
from .basemodels import Client as BaseClient
|
from .basemodels import Client as BaseClient
|
||||||
|
@ -95,13 +96,13 @@ class Client(BaseClient, ClientMixin):
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
def delete(self):
|
def delete(self):
|
||||||
for consent in models.Consent.query(client=self):
|
for consent in BaseBackend.get().query(models.Consent, client=self):
|
||||||
consent.delete()
|
consent.delete()
|
||||||
|
|
||||||
for code in models.AuthorizationCode.query(client=self):
|
for code in BaseBackend.get().query(models.AuthorizationCode, client=self):
|
||||||
code.delete()
|
code.delete()
|
||||||
|
|
||||||
for token in models.Token.query(client=self):
|
for token in BaseBackend.get().query(models.Token, client=self):
|
||||||
token.delete()
|
token.delete()
|
||||||
|
|
||||||
super().delete()
|
super().delete()
|
||||||
|
@ -185,7 +186,8 @@ class Consent(BaseConsent):
|
||||||
self.revokation_date = datetime.datetime.now(datetime.timezone.utc)
|
self.revokation_date = datetime.datetime.now(datetime.timezone.utc)
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
tokens = models.Token.query(
|
tokens = BaseBackend.get().query(
|
||||||
|
models.Token,
|
||||||
client=self.client,
|
client=self.client,
|
||||||
subject=self.subject,
|
subject=self.subject,
|
||||||
)
|
)
|
||||||
|
|
|
@ -112,7 +112,9 @@ def openid_configuration():
|
||||||
|
|
||||||
def exists_nonce(nonce, req):
|
def exists_nonce(nonce, req):
|
||||||
client = models.Client.get(id=req.client_id)
|
client = models.Client.get(id=req.client_id)
|
||||||
exists = models.AuthorizationCode.query(client=client, nonce=nonce)
|
exists = BaseBackend.get().query(
|
||||||
|
models.AuthorizationCode, client=client, nonce=nonce
|
||||||
|
)
|
||||||
return bool(exists)
|
return bool(exists)
|
||||||
|
|
||||||
|
|
||||||
|
@ -237,7 +239,9 @@ class AuthorizationCodeGrant(_AuthorizationCodeGrant):
|
||||||
return save_authorization_code(code, request)
|
return save_authorization_code(code, request)
|
||||||
|
|
||||||
def query_authorization_code(self, code, client):
|
def query_authorization_code(self, code, client):
|
||||||
item = models.AuthorizationCode.query(code=code, client=client)
|
item = BaseBackend.get().query(
|
||||||
|
models.AuthorizationCode, code=code, client=client
|
||||||
|
)
|
||||||
if item and not item[0].is_expired():
|
if item and not item[0].is_expired():
|
||||||
return item[0]
|
return item[0]
|
||||||
|
|
||||||
|
@ -283,7 +287,7 @@ class RefreshTokenGrant(_RefreshTokenGrant):
|
||||||
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]
|
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]
|
||||||
|
|
||||||
def authenticate_refresh_token(self, refresh_token):
|
def authenticate_refresh_token(self, refresh_token):
|
||||||
token = models.Token.query(refresh_token=refresh_token)
|
token = BaseBackend.get().query(models.Token, refresh_token=refresh_token)
|
||||||
if token and token[0].is_refresh_token_active():
|
if token and token[0].is_refresh_token_active():
|
||||||
return token[0]
|
return token[0]
|
||||||
|
|
||||||
|
|
|
@ -81,15 +81,15 @@ def test_special_chars_in_rdn(testclient, backend):
|
||||||
|
|
||||||
|
|
||||||
def test_filter(backend, foo_group, bar_group):
|
def test_filter(backend, foo_group, bar_group):
|
||||||
assert models.Group.query(display_name="foo") == [foo_group]
|
assert backend.query(models.Group, display_name="foo") == [foo_group]
|
||||||
assert models.Group.query(display_name="bar") == [bar_group]
|
assert backend.query(models.Group, display_name="bar") == [bar_group]
|
||||||
|
|
||||||
assert models.Group.query(display_name="foo") != 3
|
assert backend.query(models.Group, display_name="foo") != 3
|
||||||
|
|
||||||
assert models.Group.query(display_name=["foo"]) == [foo_group]
|
assert backend.query(models.Group, display_name=["foo"]) == [foo_group]
|
||||||
assert models.Group.query(display_name=["bar"]) == [bar_group]
|
assert backend.query(models.Group, display_name=["bar"]) == [bar_group]
|
||||||
|
|
||||||
assert set(models.Group.query(display_name=["foo", "bar"])) == {
|
assert set(backend.query(models.Group, display_name=["foo", "bar"])) == {
|
||||||
foo_group,
|
foo_group,
|
||||||
bar_group,
|
bar_group,
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,16 +36,16 @@ def test_model_lifecycle(testclient, backend):
|
||||||
)
|
)
|
||||||
|
|
||||||
assert not user.id
|
assert not user.id
|
||||||
assert not models.User.query()
|
assert not backend.query(models.User)
|
||||||
assert not models.User.query(id=user.id)
|
assert not backend.query(models.User, id=user.id)
|
||||||
assert not models.User.query(id="invalid")
|
assert not backend.query(models.User, id="invalid")
|
||||||
assert not models.User.get(id=user.id)
|
assert not models.User.get(id=user.id)
|
||||||
|
|
||||||
user.save()
|
user.save()
|
||||||
|
|
||||||
assert models.User.query() == [user]
|
assert backend.query(models.User) == [user]
|
||||||
assert models.User.query(id=user.id) == [user]
|
assert backend.query(models.User, id=user.id) == [user]
|
||||||
assert not models.User.query(id="invalid")
|
assert not backend.query(models.User, id="invalid")
|
||||||
assert models.User.get(id=user.id) == user
|
assert models.User.get(id=user.id) == user
|
||||||
|
|
||||||
user.family_name = "new_family_name"
|
user.family_name = "new_family_name"
|
||||||
|
@ -58,7 +58,7 @@ def test_model_lifecycle(testclient, backend):
|
||||||
|
|
||||||
user.delete()
|
user.delete()
|
||||||
|
|
||||||
assert not models.User.query(id=user.id)
|
assert not backend.query(models.User, id=user.id)
|
||||||
assert not models.User.get(id=user.id)
|
assert not models.User.get(id=user.id)
|
||||||
|
|
||||||
user.delete()
|
user.delete()
|
||||||
|
@ -143,7 +143,7 @@ def test_model_indexation(testclient, backend):
|
||||||
|
|
||||||
|
|
||||||
def test_fuzzy_unique_attribute(user, moderator, admin, backend):
|
def test_fuzzy_unique_attribute(user, moderator, admin, backend):
|
||||||
assert set(models.User.query()) == {user, moderator, admin}
|
assert set(backend.query(models.User)) == {user, moderator, admin}
|
||||||
assert set(models.User.fuzzy("Jack")) == {moderator}
|
assert set(models.User.fuzzy("Jack")) == {moderator}
|
||||||
assert set(models.User.fuzzy("Jack", ["formatted_name"])) == {moderator}
|
assert set(models.User.fuzzy("Jack", ["formatted_name"])) == {moderator}
|
||||||
assert set(models.User.fuzzy("Jack", ["user_name"])) == set()
|
assert set(models.User.fuzzy("Jack", ["user_name"])) == set()
|
||||||
|
@ -157,7 +157,7 @@ def test_fuzzy_unique_attribute(user, moderator, admin, backend):
|
||||||
|
|
||||||
|
|
||||||
def test_fuzzy_multiple_attribute(user, moderator, admin, backend):
|
def test_fuzzy_multiple_attribute(user, moderator, admin, backend):
|
||||||
assert set(models.User.query()) == {user, moderator, admin}
|
assert set(backend.query(models.User)) == {user, moderator, admin}
|
||||||
assert set(models.User.fuzzy("jack@doe.com")) == {moderator}
|
assert set(models.User.fuzzy("jack@doe.com")) == {moderator}
|
||||||
assert set(models.User.fuzzy("jack@doe.com", ["emails"])) == {moderator}
|
assert set(models.User.fuzzy("jack@doe.com", ["emails"])) == {moderator}
|
||||||
assert set(models.User.fuzzy("jack@doe.com", ["formatted_name"])) == set()
|
assert set(models.User.fuzzy("jack@doe.com", ["formatted_name"])) == set()
|
||||||
|
@ -171,8 +171,8 @@ def test_fuzzy_multiple_attribute(user, moderator, admin, backend):
|
||||||
def test_model_references(testclient, user, foo_group, admin, bar_group, backend):
|
def test_model_references(testclient, user, foo_group, admin, bar_group, backend):
|
||||||
assert foo_group.members == [user]
|
assert foo_group.members == [user]
|
||||||
assert user.groups == [foo_group]
|
assert user.groups == [foo_group]
|
||||||
assert foo_group in models.Group.query(members=user)
|
assert foo_group in backend.query(models.Group, members=user)
|
||||||
assert user in models.User.query(groups=foo_group)
|
assert user in backend.query(models.User, groups=foo_group)
|
||||||
|
|
||||||
assert user not in bar_group.members
|
assert user not in bar_group.members
|
||||||
assert bar_group not in user.groups
|
assert bar_group not in user.groups
|
||||||
|
|
|
@ -6,11 +6,11 @@ from canaille.core.populate import fake_users
|
||||||
def test_populate_users(testclient, backend):
|
def test_populate_users(testclient, backend):
|
||||||
runner = testclient.app.test_cli_runner()
|
runner = testclient.app.test_cli_runner()
|
||||||
|
|
||||||
assert len(models.User.query()) == 0
|
assert len(backend.query(models.User)) == 0
|
||||||
res = runner.invoke(cli, ["populate", "--nb", "10", "users"])
|
res = runner.invoke(cli, ["populate", "--nb", "10", "users"])
|
||||||
assert res.exit_code == 0, res.stdout
|
assert res.exit_code == 0, res.stdout
|
||||||
assert len(models.User.query()) == 10
|
assert len(backend.query(models.User)) == 10
|
||||||
for user in models.User.query():
|
for user in backend.query(models.User):
|
||||||
user.delete()
|
user.delete()
|
||||||
|
|
||||||
|
|
||||||
|
@ -18,13 +18,13 @@ def test_populate_groups(testclient, backend):
|
||||||
fake_users(10)
|
fake_users(10)
|
||||||
runner = testclient.app.test_cli_runner()
|
runner = testclient.app.test_cli_runner()
|
||||||
|
|
||||||
assert len(models.Group.query()) == 0
|
assert len(backend.query(models.Group)) == 0
|
||||||
res = runner.invoke(cli, ["populate", "--nb", "10", "groups"])
|
res = runner.invoke(cli, ["populate", "--nb", "10", "groups"])
|
||||||
assert res.exit_code == 0, res.stdout
|
assert res.exit_code == 0, res.stdout
|
||||||
assert len(models.Group.query()) == 10
|
assert len(backend.query(models.Group)) == 10
|
||||||
|
|
||||||
for group in models.Group.query():
|
for group in backend.query(models.Group):
|
||||||
group.delete()
|
group.delete()
|
||||||
|
|
||||||
for user in models.User.query():
|
for user in backend.query(models.User):
|
||||||
user.delete()
|
user.delete()
|
||||||
|
|
|
@ -4,7 +4,7 @@ from canaille.core.populate import fake_users
|
||||||
|
|
||||||
|
|
||||||
def test_no_group(app, backend):
|
def test_no_group(app, backend):
|
||||||
assert models.Group.query() == []
|
assert backend.query(models.Group) == []
|
||||||
|
|
||||||
|
|
||||||
def test_group_list_pagination(testclient, logged_admin, foo_group):
|
def test_group_list_pagination(testclient, logged_admin, foo_group):
|
||||||
|
|
|
@ -12,7 +12,7 @@ def test_registration_without_email_validation(testclient, backend, foo_group):
|
||||||
testclient.app.config["CANAILLE"]["ENABLE_REGISTRATION"] = True
|
testclient.app.config["CANAILLE"]["ENABLE_REGISTRATION"] = True
|
||||||
testclient.app.config["CANAILLE"]["EMAIL_CONFIRMATION"] = False
|
testclient.app.config["CANAILLE"]["EMAIL_CONFIRMATION"] = False
|
||||||
|
|
||||||
assert not models.User.query(user_name="newuser")
|
assert not backend.query(models.User, user_name="newuser")
|
||||||
res = testclient.get(url_for("core.account.registration"), status=200)
|
res = testclient.get(url_for("core.account.registration"), status=200)
|
||||||
res.form["user_name"] = "newuser"
|
res.form["user_name"] = "newuser"
|
||||||
res.form["password1"] = "password"
|
res.form["password1"] = "password"
|
||||||
|
@ -60,7 +60,7 @@ def test_registration_with_email_validation(testclient, backend, smtpd, foo_grou
|
||||||
text_mail = smtpd.messages[0].get_payload()[0].get_payload(decode=True).decode()
|
text_mail = smtpd.messages[0].get_payload()[0].get_payload(decode=True).decode()
|
||||||
assert registration_url in text_mail
|
assert registration_url in text_mail
|
||||||
|
|
||||||
assert not models.User.query(user_name="newuser")
|
assert not backend.query(models.User, user_name="newuser")
|
||||||
with time_machine.travel("2020-01-01 02:01:00+00:00", tick=False):
|
with time_machine.travel("2020-01-01 02:01:00+00:00", tick=False):
|
||||||
res = testclient.get(registration_url, status=200)
|
res = testclient.get(registration_url, status=200)
|
||||||
res.form["user_name"] = "newuser"
|
res.form["user_name"] = "newuser"
|
||||||
|
|
|
@ -13,8 +13,10 @@ from canaille.app import models
|
||||||
from . import client_credentials
|
from . import client_credentials
|
||||||
|
|
||||||
|
|
||||||
def test_nominal_case(testclient, logged_user, client, keypair, trusted_client):
|
def test_nominal_case(
|
||||||
assert not models.Consent.query()
|
testclient, logged_user, client, keypair, trusted_client, backend
|
||||||
|
):
|
||||||
|
assert not backend.query(models.Consent)
|
||||||
|
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
"/oauth/authorize",
|
"/oauth/authorize",
|
||||||
|
@ -43,7 +45,7 @@ def test_nominal_case(testclient, logged_user, client, keypair, trusted_client):
|
||||||
"phone",
|
"phone",
|
||||||
}
|
}
|
||||||
|
|
||||||
consents = models.Consent.query(client=client, subject=logged_user)
|
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||||
assert set(consents[0].scope) == {
|
assert set(consents[0].scope) == {
|
||||||
"openid",
|
"openid",
|
||||||
"profile",
|
"profile",
|
||||||
|
@ -112,8 +114,10 @@ def test_invalid_client(testclient, logged_user, keypair):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_redirect_uri(testclient, logged_user, client, keypair, trusted_client):
|
def test_redirect_uri(
|
||||||
assert not models.Consent.query()
|
testclient, logged_user, client, keypair, trusted_client, backend
|
||||||
|
):
|
||||||
|
assert not backend.query(models.Consent)
|
||||||
|
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
"/oauth/authorize",
|
"/oauth/authorize",
|
||||||
|
@ -134,7 +138,7 @@ def test_redirect_uri(testclient, logged_user, client, keypair, trusted_client):
|
||||||
code = params["code"][0]
|
code = params["code"][0]
|
||||||
authcode = models.AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode is not None
|
assert authcode is not None
|
||||||
consents = models.Consent.query(client=client, subject=logged_user)
|
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||||
|
|
||||||
res = testclient.post(
|
res = testclient.post(
|
||||||
"/oauth/token",
|
"/oauth/token",
|
||||||
|
@ -157,8 +161,10 @@ def test_redirect_uri(testclient, logged_user, client, keypair, trusted_client):
|
||||||
consent.delete()
|
consent.delete()
|
||||||
|
|
||||||
|
|
||||||
def test_preconsented_client(testclient, logged_user, client, keypair, trusted_client):
|
def test_preconsented_client(
|
||||||
assert not models.Consent.query()
|
testclient, logged_user, client, keypair, trusted_client, backend
|
||||||
|
):
|
||||||
|
assert not backend.query(models.Consent)
|
||||||
|
|
||||||
client.preconsent = True
|
client.preconsent = True
|
||||||
client.save()
|
client.save()
|
||||||
|
@ -180,7 +186,7 @@ def test_preconsented_client(testclient, logged_user, client, keypair, trusted_c
|
||||||
authcode = models.AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode is not None
|
assert authcode is not None
|
||||||
|
|
||||||
consents = models.Consent.query(client=client, subject=logged_user)
|
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||||
assert not consents
|
assert not consents
|
||||||
|
|
||||||
res = testclient.post(
|
res = testclient.post(
|
||||||
|
@ -214,8 +220,8 @@ def test_preconsented_client(testclient, logged_user, client, keypair, trusted_c
|
||||||
assert res.json["name"] == "John (johnny) Doe"
|
assert res.json["name"] == "John (johnny) Doe"
|
||||||
|
|
||||||
|
|
||||||
def test_logout_login(testclient, logged_user, client):
|
def test_logout_login(testclient, logged_user, client, backend):
|
||||||
assert not models.Consent.query()
|
assert not backend.query(models.Consent)
|
||||||
|
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
"/oauth/authorize",
|
"/oauth/authorize",
|
||||||
|
@ -254,7 +260,7 @@ def test_logout_login(testclient, logged_user, client):
|
||||||
authcode = models.AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode is not None
|
assert authcode is not None
|
||||||
|
|
||||||
consents = models.Consent.query(client=client, subject=logged_user)
|
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||||
assert "profile" in consents[0].scope
|
assert "profile" in consents[0].scope
|
||||||
|
|
||||||
res = testclient.post(
|
res = testclient.post(
|
||||||
|
@ -285,8 +291,8 @@ def test_logout_login(testclient, logged_user, client):
|
||||||
consent.delete()
|
consent.delete()
|
||||||
|
|
||||||
|
|
||||||
def test_deny(testclient, logged_user, client):
|
def test_deny(testclient, logged_user, client, backend):
|
||||||
assert not models.Consent.query()
|
assert not backend.query(models.Consent)
|
||||||
|
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
"/oauth/authorize",
|
"/oauth/authorize",
|
||||||
|
@ -305,11 +311,11 @@ def test_deny(testclient, logged_user, client):
|
||||||
error = params["error"][0]
|
error = params["error"][0]
|
||||||
assert error == "access_denied"
|
assert error == "access_denied"
|
||||||
|
|
||||||
assert not models.Consent.query()
|
assert not backend.query(models.Consent)
|
||||||
|
|
||||||
|
|
||||||
def test_code_challenge(testclient, logged_user, client):
|
def test_code_challenge(testclient, logged_user, client, backend):
|
||||||
assert not models.Consent.query()
|
assert not backend.query(models.Consent)
|
||||||
|
|
||||||
client.token_endpoint_auth_method = "none"
|
client.token_endpoint_auth_method = "none"
|
||||||
client.save()
|
client.save()
|
||||||
|
@ -338,7 +344,7 @@ def test_code_challenge(testclient, logged_user, client):
|
||||||
authcode = models.AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode is not None
|
assert authcode is not None
|
||||||
|
|
||||||
consents = models.Consent.query(client=client, subject=logged_user)
|
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||||
assert "profile" in consents[0].scope
|
assert "profile" in consents[0].scope
|
||||||
|
|
||||||
res = testclient.post(
|
res = testclient.post(
|
||||||
|
@ -373,8 +379,8 @@ def test_code_challenge(testclient, logged_user, client):
|
||||||
consent.delete()
|
consent.delete()
|
||||||
|
|
||||||
|
|
||||||
def test_consent_already_given(testclient, logged_user, client):
|
def test_consent_already_given(testclient, logged_user, client, backend):
|
||||||
assert not models.Consent.query()
|
assert not backend.query(models.Consent)
|
||||||
|
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
"/oauth/authorize",
|
"/oauth/authorize",
|
||||||
|
@ -395,7 +401,7 @@ def test_consent_already_given(testclient, logged_user, client):
|
||||||
authcode = models.AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode is not None
|
assert authcode is not None
|
||||||
|
|
||||||
consents = models.Consent.query(client=client, subject=logged_user)
|
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||||
assert "profile" in consents[0].scope
|
assert "profile" in consents[0].scope
|
||||||
|
|
||||||
res = testclient.post(
|
res = testclient.post(
|
||||||
|
@ -430,9 +436,9 @@ def test_consent_already_given(testclient, logged_user, client):
|
||||||
|
|
||||||
|
|
||||||
def test_when_consent_already_given_but_for_a_smaller_scope(
|
def test_when_consent_already_given_but_for_a_smaller_scope(
|
||||||
testclient, logged_user, client
|
testclient, logged_user, client, backend
|
||||||
):
|
):
|
||||||
assert not models.Consent.query()
|
assert not backend.query(models.Consent)
|
||||||
|
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
"/oauth/authorize",
|
"/oauth/authorize",
|
||||||
|
@ -453,7 +459,7 @@ def test_when_consent_already_given_but_for_a_smaller_scope(
|
||||||
authcode = models.AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode is not None
|
assert authcode is not None
|
||||||
|
|
||||||
consents = models.Consent.query(client=client, subject=logged_user)
|
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||||
assert "profile" in consents[0].scope
|
assert "profile" in consents[0].scope
|
||||||
assert "groups" not in consents[0].scope
|
assert "groups" not in consents[0].scope
|
||||||
|
|
||||||
|
@ -489,7 +495,7 @@ def test_when_consent_already_given_but_for_a_smaller_scope(
|
||||||
authcode = models.AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode is not None
|
assert authcode is not None
|
||||||
|
|
||||||
consents = models.Consent.query(client=client, subject=logged_user)
|
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||||
assert "profile" in consents[0].scope
|
assert "profile" in consents[0].scope
|
||||||
assert "groups" in consents[0].scope
|
assert "groups" in consents[0].scope
|
||||||
|
|
||||||
|
@ -535,8 +541,8 @@ def test_nonce_required_in_oidc_requests(testclient, logged_user, client):
|
||||||
assert res.json.get("error") == "invalid_request"
|
assert res.json.get("error") == "invalid_request"
|
||||||
|
|
||||||
|
|
||||||
def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client):
|
def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client, backend):
|
||||||
assert not models.Consent.query()
|
assert not backend.query(models.Consent)
|
||||||
testclient.app.config["CANAILLE_OIDC"]["REQUIRE_NONCE"] = False
|
testclient.app.config["CANAILLE_OIDC"]["REQUIRE_NONCE"] = False
|
||||||
|
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
|
@ -552,12 +558,12 @@ def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client):
|
||||||
res = res.form.submit(name="answer", value="accept", status=302)
|
res = res.form.submit(name="answer", value="accept", status=302)
|
||||||
|
|
||||||
assert res.location.startswith(client.redirect_uris[0])
|
assert res.location.startswith(client.redirect_uris[0])
|
||||||
for consent in models.Consent.query():
|
for consent in backend.query(models.Consent):
|
||||||
consent.delete()
|
consent.delete()
|
||||||
|
|
||||||
|
|
||||||
def test_request_scope_too_large(testclient, logged_user, keypair, client):
|
def test_request_scope_too_large(testclient, logged_user, keypair, client, backend):
|
||||||
assert not models.Consent.query()
|
assert not backend.query(models.Consent)
|
||||||
client.scope = ["openid", "profile", "groups"]
|
client.scope = ["openid", "profile", "groups"]
|
||||||
client.save()
|
client.save()
|
||||||
|
|
||||||
|
@ -582,7 +588,7 @@ def test_request_scope_too_large(testclient, logged_user, keypair, client):
|
||||||
"profile",
|
"profile",
|
||||||
}
|
}
|
||||||
|
|
||||||
consents = models.Consent.query(client=client, subject=logged_user)
|
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||||
assert set(consents[0].scope) == {
|
assert set(consents[0].scope) == {
|
||||||
"openid",
|
"openid",
|
||||||
"profile",
|
"profile",
|
||||||
|
|
|
@ -95,7 +95,7 @@ def test_someone_else_consent_restoration(
|
||||||
|
|
||||||
|
|
||||||
def test_oidc_authorization_after_revokation(
|
def test_oidc_authorization_after_revokation(
|
||||||
testclient, logged_user, client, keypair, consent
|
testclient, logged_user, client, keypair, consent, backend
|
||||||
):
|
):
|
||||||
consent.revoke()
|
consent.revoke()
|
||||||
|
|
||||||
|
@ -114,7 +114,7 @@ def test_oidc_authorization_after_revokation(
|
||||||
|
|
||||||
res = res.form.submit(name="answer", value="accept", status=302)
|
res = res.form.submit(name="answer", value="accept", status=302)
|
||||||
|
|
||||||
consents = models.Consent.query(client=client, subject=logged_user)
|
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||||
consent.reload()
|
consent.reload()
|
||||||
assert consents[0] == consent
|
assert consents[0] == consent
|
||||||
assert not consent.revoked
|
assert not consent.revoked
|
||||||
|
|
|
@ -5,8 +5,8 @@ from canaille.app import models
|
||||||
# forms.
|
# forms.
|
||||||
|
|
||||||
|
|
||||||
def test_fieldlist_add(testclient, logged_admin):
|
def test_fieldlist_add(testclient, logged_admin, backend):
|
||||||
assert not models.Client.query()
|
assert not backend.query(models.Client)
|
||||||
|
|
||||||
res = testclient.get("/admin/client/add")
|
res = testclient.get("/admin/client/add")
|
||||||
assert "redirect_uris-1" not in res.form.fields
|
assert "redirect_uris-1" not in res.form.fields
|
||||||
|
@ -23,7 +23,7 @@ def test_fieldlist_add(testclient, logged_admin):
|
||||||
res.form[k].force_value(v)
|
res.form[k].force_value(v)
|
||||||
|
|
||||||
res = res.form.submit(status=200, name="fieldlist_add", value="redirect_uris-0")
|
res = res.form.submit(status=200, name="fieldlist_add", value="redirect_uris-0")
|
||||||
assert not models.Client.query()
|
assert not backend.query(models.Client)
|
||||||
|
|
||||||
data["redirect_uris-1"] = "https://foo.bar/callback2"
|
data["redirect_uris-1"] = "https://foo.bar/callback2"
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
|
@ -43,8 +43,8 @@ def test_fieldlist_add(testclient, logged_admin):
|
||||||
client.delete()
|
client.delete()
|
||||||
|
|
||||||
|
|
||||||
def test_fieldlist_delete(testclient, logged_admin):
|
def test_fieldlist_delete(testclient, logged_admin, backend):
|
||||||
assert not models.Client.query()
|
assert not backend.query(models.Client)
|
||||||
res = testclient.get("/admin/client/add")
|
res = testclient.get("/admin/client/add")
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
|
@ -61,7 +61,7 @@ def test_fieldlist_delete(testclient, logged_admin):
|
||||||
|
|
||||||
res.form["redirect_uris-1"] = "https://foo.bar/callback2"
|
res.form["redirect_uris-1"] = "https://foo.bar/callback2"
|
||||||
res = res.form.submit(status=200, name="fieldlist_remove", value="redirect_uris-1")
|
res = res.form.submit(status=200, name="fieldlist_remove", value="redirect_uris-1")
|
||||||
assert not models.Client.query()
|
assert not backend.query(models.Client)
|
||||||
assert "redirect_uris-1" not in res.form.fields
|
assert "redirect_uris-1" not in res.form.fields
|
||||||
|
|
||||||
res = res.form.submit(status=302, name="action", value="edit")
|
res = res.form.submit(status=302, name="action", value="edit")
|
||||||
|
@ -92,8 +92,8 @@ def test_fieldlist_add_invalid_field(testclient, logged_admin):
|
||||||
testclient.post("/admin/client/add", data, status=400)
|
testclient.post("/admin/client/add", data, status=400)
|
||||||
|
|
||||||
|
|
||||||
def test_fieldlist_delete_invalid_field(testclient, logged_admin):
|
def test_fieldlist_delete_invalid_field(testclient, logged_admin, backend):
|
||||||
assert not models.Client.query()
|
assert not backend.query(models.Client)
|
||||||
res = testclient.get("/admin/client/add")
|
res = testclient.get("/admin/client/add")
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
|
|
|
@ -7,8 +7,8 @@ from canaille.app import models
|
||||||
from . import client_credentials
|
from . import client_credentials
|
||||||
|
|
||||||
|
|
||||||
def test_refresh_token(testclient, logged_user, client):
|
def test_refresh_token(testclient, logged_user, client, backend):
|
||||||
assert not models.Consent.query()
|
assert not backend.query(models.Consent)
|
||||||
|
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
"/oauth/authorize",
|
"/oauth/authorize",
|
||||||
|
@ -27,7 +27,7 @@ def test_refresh_token(testclient, logged_user, client):
|
||||||
authcode = models.AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode is not None
|
assert authcode is not None
|
||||||
|
|
||||||
consents = models.Consent.query(client=client, subject=logged_user)
|
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||||
assert "profile" in consents[0].scope
|
assert "profile" in consents[0].scope
|
||||||
|
|
||||||
res = testclient.post(
|
res = testclient.post(
|
||||||
|
|
|
@ -9,7 +9,9 @@ from canaille.oidc.oauth import setup_oauth
|
||||||
from . import client_credentials
|
from . import client_credentials
|
||||||
|
|
||||||
|
|
||||||
def test_token_default_expiration_date(testclient, logged_user, client, keypair):
|
def test_token_default_expiration_date(
|
||||||
|
testclient, logged_user, client, keypair, backend
|
||||||
|
):
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
"/oauth/authorize",
|
"/oauth/authorize",
|
||||||
params=dict(
|
params=dict(
|
||||||
|
@ -52,12 +54,14 @@ def test_token_default_expiration_date(testclient, logged_user, client, keypair)
|
||||||
claims = jwt.decode(id_token, keypair[1])
|
claims = jwt.decode(id_token, keypair[1])
|
||||||
assert claims["exp"] - claims["iat"] == 3600
|
assert claims["exp"] - claims["iat"] == 3600
|
||||||
|
|
||||||
consents = models.Consent.query(client=client, subject=logged_user)
|
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||||
for consent in consents:
|
for consent in consents:
|
||||||
consent.delete()
|
consent.delete()
|
||||||
|
|
||||||
|
|
||||||
def test_token_custom_expiration_date(testclient, logged_user, client, keypair):
|
def test_token_custom_expiration_date(
|
||||||
|
testclient, logged_user, client, keypair, backend
|
||||||
|
):
|
||||||
testclient.app.config["OAUTH2_TOKEN_EXPIRES_IN"] = {
|
testclient.app.config["OAUTH2_TOKEN_EXPIRES_IN"] = {
|
||||||
"authorization_code": 1000,
|
"authorization_code": 1000,
|
||||||
"implicit": 2000,
|
"implicit": 2000,
|
||||||
|
@ -110,6 +114,6 @@ def test_token_custom_expiration_date(testclient, logged_user, client, keypair):
|
||||||
claims = jwt.decode(id_token, keypair[1])
|
claims = jwt.decode(id_token, keypair[1])
|
||||||
assert claims["exp"] - claims["iat"] == 6000
|
assert claims["exp"] - claims["iat"] == 6000
|
||||||
|
|
||||||
consents = models.Consent.query(client=client, subject=logged_user)
|
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||||
for consent in consents:
|
for consent in consents:
|
||||||
consent.delete()
|
consent.delete()
|
||||||
|
|
Loading…
Reference in a new issue