refactor: move BackendModel.get to Backend.get

This commit is contained in:
Éloi Rivard 2024-04-14 17:30:59 +02:00
parent ccde88b1bf
commit 44573713ed
No known key found for this signature in database
GPG key ID: 7EDA204EA57DD184
40 changed files with 255 additions and 241 deletions

View file

@ -20,7 +20,7 @@ def current_user():
return g.user
for user_id in session.get("user_id", [])[::-1]:
user = models.User.get(user_id)
user = current_app.backend.instance.get(models.User, user_id)
if user and (
not current_app.backend.has_account_lockability() or not user.locked
):
@ -147,7 +147,7 @@ def model_converter(model):
def to_python(self, identifier):
current_app.backend.setup()
instance = model.get(identifier)
instance = current_app.backend.get(model, identifier)
if self.required and not instance:
abort(404)

View file

@ -260,7 +260,9 @@ class IDToModel:
def __call__(self, data):
model = getattr(models, self.model_name)
instance = data if isinstance(data, model) else model.get(data)
instance = (
data if isinstance(data, model) else BaseBackend.instance.get(model, data)
)
if instance:
return instance

View file

@ -80,6 +80,11 @@ class BaseBackend:
attribute values loosely be matched."""
raise NotImplementedError()
def get(self, model, identifier=None, **kwargs):
"""Works like :meth:`~canaille.backends.BaseBackend.query` but return
only one element or :py:data:`None` if no item is matching."""
raise NotImplementedError()
def check_user_password(self, user, password: str) -> bool:
"""Check if the password matches the user password in the database."""
raise NotImplementedError()

View file

@ -202,7 +202,7 @@ class Backend(BaseBackend):
if login
else None
)
return User.get(filter=filter)
return self.get(User, filter=filter)
def check_user_password(self, user, password):
conn = ldap.initialize(current_app.config["CANAILLE_LDAP"]["URI"])
@ -311,6 +311,19 @@ class Backend(BaseBackend):
)
return self.query(model, filter=filter, **kwargs)
def get(self, model, identifier=None, /, **kwargs):
try:
return self.query(model, identifier, **kwargs)[0]
except (IndexError, ldap.NO_SUCH_OBJECT):
if identifier and model.base:
return (
self.get(model, **{model.identifier_attribute: identifier})
or self.get(model, id=identifier)
or None
)
return None
def setup_ldap_models(config):
from canaille.app import models

View file

@ -4,7 +4,6 @@ import ldap.dn
import ldap.filter
from ldap.controls.readentry import PostReadControl
from canaille.backends import BaseBackend
from canaille.backends.models import BackendModel
from .backend import Backend
@ -226,20 +225,6 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
return cls._attribute_type_by_name
@classmethod
def get(cls, identifier=None, /, **kwargs):
try:
return BaseBackend.instance.query(cls, identifier, **kwargs)[0]
except (IndexError, ldap.NO_SUCH_OBJECT):
if identifier and cls.base:
return (
cls.get(**{cls.identifier_attribute: identifier})
or cls.get(id=identifier)
or None
)
return None
@classmethod
def update_ldap_attributes(cls):
all_object_classes = cls.ldap_object_classes()

View file

@ -1,6 +1,8 @@
import datetime
from enum import Enum
from canaille.backends import BaseBackend
LDAP_NULL_DATE = "000001010000Z"
@ -50,7 +52,7 @@ def ldap_to_python(value, syntax):
return value.decode("utf-8").upper() == "TRUE"
if syntax == Syntax.DISTINGUISHED_NAME:
return LDAPObject.get(value.decode("utf-8"))
return BaseBackend.instance.get(LDAPObject, value.decode("utf-8"))
return value.decode("utf-8")

View file

@ -26,7 +26,7 @@ class Backend(BaseBackend):
def get_user_from_login(self, login):
from .models import User
return User.get(user_name=login)
return self.get(User, user_name=login)
def check_user_password(self, user, password):
if password != user.password:
@ -80,3 +80,14 @@ class Backend(BaseBackend):
if isinstance(value, str)
)
]
def get(self, model, identifier=None, /, **kwargs):
if identifier:
return (
self.get(model, **{model.identifier_attribute: identifier})
or self.get(model, id=identifier)
or None
)
results = self.query(model, **kwargs)
return results[0] if results else None

View file

@ -36,18 +36,6 @@ class MemoryModel(BackendModel):
class_name or cls.__name__, {}
).setdefault(attribute, {})
@classmethod
def get(cls, identifier=None, /, **kwargs):
if identifier:
return (
cls.get(**{cls.identifier_attribute: identifier})
or cls.get(id=identifier)
or None
)
results = BaseBackend.instance.query(cls, **kwargs)
return results[0] if results else None
@classmethod
def listify(cls, value):
return value if isinstance(value, list) else [value]
@ -75,7 +63,7 @@ class MemoryModel(BackendModel):
model, _ = cls.get_model_annotations(attribute_name)
if model and not isinstance(value, model):
backend_model = getattr(models, model.__name__)
return backend_model.get(id=value)
return BaseBackend.instance.get(backend_model, id=value)
return value
@ -166,7 +154,7 @@ class MemoryModel(BackendModel):
del self.index()[self.id]
def reload(self):
self._state = self.__class__.get(id=self.id)._state
self._state = BaseBackend.instance.get(self.__class__, id=self.id)._state
self._cache = {}
def __eq__(self, other):
@ -174,7 +162,7 @@ class MemoryModel(BackendModel):
return False
if not isinstance(other, MemoryModel):
return self == self.__class__.get(id=other)
return self == BaseBackend.instance.get(self.__class__, id=other)
return self._state == other._state

View file

@ -11,6 +11,7 @@ from typing import get_type_hints
from canaille.app import classproperty
from canaille.app import models
from canaille.backends import BaseBackend
class Model:
@ -87,12 +88,6 @@ class BackendModel:
implemented for every model and for every backend.
"""
@classmethod
def get(cls, identifier=None, **kwargs):
"""Works like :meth:`~canaille.backends.BaseBackend.query` but return
only one element or :py:data:`None` if no item is matching."""
raise NotImplementedError()
def save(self):
"""Validate the current modifications in the database."""
raise NotImplementedError()
@ -175,7 +170,7 @@ class BackendModel:
backend_model = getattr(models, model.__name__)
if instance := backend_model.get(value):
if instance := BaseBackend.instance.get(backend_model, value):
filter[attribute] = instance
return all(

View file

@ -53,7 +53,7 @@ class Backend(BaseBackend):
def get_user_from_login(self, login):
from .models import User
return User.get(user_name=login)
return self.get(User, user_name=login)
def check_user_password(self, user, password):
if password != user.password:
@ -90,3 +90,19 @@ class Backend(BaseBackend):
)
return self.db_session.execute(select(model).filter(filter)).scalars().all()
def get(self, model, identifier=None, /, **kwargs):
if identifier:
return (
self.get(model, **{model.identifier_attribute: identifier})
or self.get(model, id=identifier)
or None
)
filter = [
model.attribute_filter(attribute_name, expected_value)
for attribute_name, expected_value in kwargs.items()
]
return Backend.instance.db_session.execute(
select(model).filter(*filter)
).scalar_one_or_none()

View file

@ -11,7 +11,6 @@ from sqlalchemy import LargeBinary
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
@ -48,23 +47,6 @@ class SqlAlchemyModel(BackendModel):
return getattr(cls, name) == value
@classmethod
def get(cls, identifier=None, /, **kwargs):
if identifier:
return (
cls.get(**{cls.identifier_attribute: identifier})
or cls.get(id=identifier)
or None
)
filter = [
cls.attribute_filter(attribute_name, expected_value)
for attribute_name, expected_value in kwargs.items()
]
return Backend.instance.db_session.execute(
select(cls).filter(*filter)
).scalar_one_or_none()
def save(self):
self.last_modified = datetime.datetime.now(datetime.timezone.utc).replace(
microsecond=0

View file

@ -257,7 +257,9 @@ def registration(data=None, hash=None):
)
return redirect(url_for("core.account.index"))
if payload.user_name and models.User.get(user_name=payload.user_name):
if payload.user_name and BaseBackend.instance.get(
models.User, user_name=payload.user_name
):
flash(
_("Your account has already been created."),
"error",
@ -282,7 +284,10 @@ def registration(data=None, hash=None):
data = {
"user_name": payload.user_name,
"emails": [payload.email],
"groups": [models.Group.get(id=group_id) for group_id in payload.groups],
"groups": [
BaseBackend.instance.get(models.Group, id=group_id)
for group_id in payload.groups
],
}
has_smtp = "SMTP" in current_app.config["CANAILLE"]
@ -376,7 +381,7 @@ def email_confirmation(data, hash):
)
return redirect(url_for("core.account.index"))
user = models.User.get(confirmation_obj.identifier)
user = BaseBackend.instance.get(models.User, confirmation_obj.identifier)
if not user:
flash(
_("The email confirmation link that brought you here is invalid."),

View file

@ -24,7 +24,7 @@ MINIMUM_PASSWORD_LENGTH = 8
def unique_user_name(form, field):
if models.User.get(user_name=field.data) and (
if BaseBackend.instance.get(models.User, user_name=field.data) and (
not getattr(form, "user", None) or form.user.user_name != field.data
):
raise wtforms.ValidationError(
@ -33,7 +33,7 @@ def unique_user_name(form, field):
def unique_email(form, field):
if models.User.get(emails=field.data) and (
if BaseBackend.instance.get(models.User, emails=field.data) and (
not getattr(form, "user", None) or field.data not in form.user.emails
):
raise wtforms.ValidationError(
@ -42,7 +42,7 @@ def unique_email(form, field):
def unique_group(form, field):
if models.Group.get(display_name=field.data):
if BaseBackend.instance.get(models.Group, display_name=field.data):
raise wtforms.ValidationError(
_("The group '{group}' already exists").format(group=field.data)
)

View file

@ -258,12 +258,12 @@ class User(Model):
def preferred_email(self):
return self.emails[0] if self.emails else None
def __getattr__(self, name):
def __getattribute__(self, name):
prefix = "can_"
if name.startswith(prefix) and name != "can_read":
return self.can(name[len(prefix) :])
return super().__getattr__(name)
return super().__getattribute__(name)
def can(self, *permissions: Permission):
"""Wether or not the user has the

View file

@ -108,7 +108,7 @@ def revoke_preconsent(user, client):
flash(_("Could not revoke this access"), "error")
return redirect(url_for("oidc.consents.consents"))
consent = models.Consent.get(client=client, subject=user)
consent = BaseBackend.instance.get(models.Consent, client=client, subject=user)
if consent:
return redirect(url_for("oidc.consents.revoke", consent=consent))

View file

@ -50,7 +50,9 @@ def authorize():
request.form.to_dict(flat=False),
)
client = models.Client.get(client_id=request.args["client_id"])
client = BaseBackend.instance.get(
models.Client, client_id=request.args["client_id"]
)
user = current_user()
if response := authorize_guards(client):
@ -276,7 +278,7 @@ def end_session():
valid_uris = []
if "client_id" in data:
client = models.Client.get(client_id=data["client_id"])
client = BaseBackend.instance.get(models.Client, client_id=data["client_id"])
if client:
valid_uris = client.post_logout_redirect_uris
@ -328,7 +330,7 @@ def end_session():
else [id_token["aud"]]
)
for client_id in client_ids:
client = models.Client.get(client_id=client_id)
client = BaseBackend.instance.get(models.Client, client_id=client_id)
if client:
valid_uris.extend(client.post_logout_redirect_uris or [])

View file

@ -111,7 +111,7 @@ def openid_configuration():
def exists_nonce(nonce, req):
client = models.Client.get(id=req.client_id)
client = BaseBackend.instance.get(models.Client, id=req.client_id)
exists = BaseBackend.instance.query(
models.AuthorizationCode, client=client, nonce=nonce
)
@ -334,7 +334,7 @@ class OpenIDHybridGrant(_OpenIDHybridGrant):
def query_client(client_id):
return models.Client.get(client_id=client_id)
return BaseBackend.instance.get(models.Client, client_id=client_id)
def save_token(token, request):
@ -356,20 +356,20 @@ def save_token(token, request):
class BearerTokenValidator(_BearerTokenValidator):
def authenticate_token(self, token_string):
return models.Token.get(access_token=token_string)
return BaseBackend.instance.get(models.Token, access_token=token_string)
def query_token(token, token_type_hint):
if token_type_hint == "access_token":
return models.Token.get(access_token=token)
return BaseBackend.instance.get(models.Token, access_token=token)
elif token_type_hint == "refresh_token":
return models.Token.get(refresh_token=token)
return BaseBackend.instance.get(models.Token, refresh_token=token)
item = models.Token.get(access_token=token)
item = BaseBackend.instance.get(models.Token, access_token=token)
if item:
return item
item = models.Token.get(refresh_token=token)
item = BaseBackend.instance.get(models.Token, refresh_token=token)
if item:
return item
@ -472,7 +472,7 @@ class ClientRegistrationEndpoint(ClientManagementMixin, _ClientRegistrationEndpo
class ClientConfigurationEndpoint(ClientManagementMixin, _ClientConfigurationEndpoint):
def authenticate_client(self, request):
client_id = request.uri.split("/")[-1]
return models.Client.get(client_id=client_id)
return BaseBackend.instance.get(models.Client, client_id=client_id)
def revoke_access_token(self, request, token):
pass

View file

@ -7,7 +7,7 @@ def test_guess_object_from_dn(backend, testclient, foo_group):
foo_group.members = [foo_group]
foo_group.save()
dn = foo_group.dn
g = LDAPObject.get(dn)
g = backend.get(LDAPObject, dn)
assert isinstance(g, models.Group)
assert g == foo_group
assert g.display_name == foo_group.display_name
@ -21,9 +21,9 @@ def test_object_class_update(backend, testclient):
user1.save()
assert set(user1.get_ldap_attribute("objectClass")) == {"inetOrgPerson"}
assert set(models.User.get(id=user1.id).get_ldap_attribute("objectClass")) == {
"inetOrgPerson"
}
assert set(
backend.get(models.User, id=user1.id).get_ldap_attribute("objectClass")
) == {"inetOrgPerson"}
testclient.app.config["CANAILLE_LDAP"]["USER_CLASS"] = [
"inetOrgPerson",
@ -38,12 +38,14 @@ def test_object_class_update(backend, testclient):
"inetOrgPerson",
"extensibleObject",
}
assert set(models.User.get(id=user2.id).get_ldap_attribute("objectClass")) == {
assert set(
backend.get(models.User, id=user2.id).get_ldap_attribute("objectClass")
) == {
"inetOrgPerson",
"extensibleObject",
}
user1 = models.User.get(id=user1.id)
user1 = backend.get(models.User, id=user1.id)
assert user1.get_ldap_attribute("objectClass") == ["inetOrgPerson"]
user1.save()
@ -51,7 +53,9 @@ def test_object_class_update(backend, testclient):
"inetOrgPerson",
"extensibleObject",
}
assert set(models.User.get(id=user1.id).get_ldap_attribute("objectClass")) == {
assert set(
backend.get(models.User, id=user1.id).get_ldap_attribute("objectClass")
) == {
"inetOrgPerson",
"extensibleObject",
}

View file

@ -27,7 +27,7 @@ def test_object_creation(app, backend):
user.save()
assert user.exists
user = models.User.get(id=user.id)
user = backend.get(models.User, id=user.id)
assert user.exists
user.delete()
@ -52,9 +52,9 @@ def test_dn_when_leading_space_in_id_attribute(testclient, backend):
assert ldap.dn.is_dn(dn)
assert ldap.dn.dn2str(ldap.dn.str2dn(dn)) == dn
assert user == models.User.get(user.user_name)
assert user == models.User.get(user_name=user.user_name)
assert user == models.User.get(dn)
assert user == backend.get(models.User, user.user_name)
assert user == backend.get(models.User, user_name=user.user_name)
assert user == backend.get(models.User, dn)
user.delete()
@ -73,9 +73,9 @@ def test_special_chars_in_rdn(testclient, backend):
assert ldap.dn.dn2str(ldap.dn.str2dn(dn)) == dn
assert dn == "uid=\\#user,ou=users,dc=mydomain,dc=tld"
assert user == models.User.get(user.user_name)
assert user == models.User.get(user_name=user.user_name)
assert user == models.User.get(dn)
assert user == backend.get(models.User, user.user_name)
assert user == backend.get(models.User, user_name=user.user_name)
assert user == backend.get(models.User, dn)
user.delete()

View file

@ -19,7 +19,7 @@ def test_model_comparison(testclient, backend):
formatted_name="bar",
)
bar.save()
foo2 = models.User.get(id=foo1.id)
foo2 = backend.get(models.User, id=foo1.id)
assert foo1 == foo2
assert foo1 != bar
@ -39,14 +39,14 @@ def test_model_lifecycle(testclient, backend):
assert not backend.query(models.User)
assert not backend.query(models.User, id=user.id)
assert not backend.query(models.User, id="invalid")
assert not models.User.get(id=user.id)
assert not backend.get(models.User, id=user.id)
user.save()
assert backend.query(models.User) == [user]
assert backend.query(models.User, id=user.id) == [user]
assert not backend.query(models.User, id="invalid")
assert models.User.get(id=user.id) == user
assert backend.get(models.User, id=user.id) == user
user.family_name = "new_family_name"
@ -59,7 +59,7 @@ def test_model_lifecycle(testclient, backend):
user.delete()
assert not backend.query(models.User, id=user.id)
assert not models.User.get(id=user.id)
assert not backend.get(models.User, id=user.id)
user.delete()
@ -78,7 +78,7 @@ def test_model_attribute_edition(testclient, backend):
assert user.family_name == "family_name"
assert user.emails == ["email1@user.com", "email2@user.com"]
user = models.User.get(id=user.id)
user = backend.get(models.User, id=user.id)
assert user.user_name == "user_name"
assert user.family_name == "family_name"
assert user.emails == ["email1@user.com", "email2@user.com"]
@ -90,7 +90,7 @@ def test_model_attribute_edition(testclient, backend):
assert user.family_name == "new_family_name"
assert user.emails == ["email1@user.com"]
user = models.User.get(id=user.id)
user = backend.get(models.User, id=user.id)
assert user.family_name == "new_family_name"
assert user.emails == ["email1@user.com"]
@ -112,34 +112,34 @@ def test_model_indexation(testclient, backend):
)
user.save()
assert models.User.get(family_name="family_name") == user
assert not models.User.get(family_name="new_family_name")
assert models.User.get(emails=["email1@user.com"]) == user
assert models.User.get(emails=["email2@user.com"]) == user
assert not models.User.get(emails=["email3@user.com"])
assert backend.get(models.User, family_name="family_name") == user
assert not backend.get(models.User, family_name="new_family_name")
assert backend.get(models.User, emails=["email1@user.com"]) == user
assert backend.get(models.User, emails=["email2@user.com"]) == user
assert not backend.get(models.User, emails=["email3@user.com"])
user.family_name = "new_family_name"
user.emails = ["email2@user.com"]
assert models.User.get(family_name="family_name") != user
assert models.User.get(emails=["email1@user.com"]) != user
assert not models.User.get(emails=["email3@user.com"])
assert backend.get(models.User, family_name="family_name") != user
assert backend.get(models.User, emails=["email1@user.com"]) != user
assert not backend.get(models.User, emails=["email3@user.com"])
user.save()
assert not models.User.get(family_name="family_name")
assert models.User.get(family_name="new_family_name") == user
assert not models.User.get(emails=["email1@user.com"])
assert models.User.get(emails=["email2@user.com"]) == user
assert not models.User.get(emails=["email3@user.com"])
assert not backend.get(models.User, family_name="family_name")
assert backend.get(models.User, family_name="new_family_name") == user
assert not backend.get(models.User, emails=["email1@user.com"])
assert backend.get(models.User, emails=["email2@user.com"]) == user
assert not backend.get(models.User, emails=["email3@user.com"])
user.delete()
assert not models.User.get(family_name="family_name")
assert not models.User.get(family_name="new_family_name")
assert not models.User.get(emails=["email1@user.com"])
assert not models.User.get(emails=["email2@user.com"])
assert not models.User.get(emails=["email3@user.com"])
assert not backend.get(models.User, family_name="family_name")
assert not backend.get(models.User, family_name="new_family_name")
assert not backend.get(models.User, emails=["email1@user.com"])
assert not backend.get(models.User, emails=["email2@user.com"])
assert not backend.get(models.User, emails=["email3@user.com"])
def test_fuzzy_unique_attribute(user, moderator, admin, backend):

View file

@ -78,7 +78,7 @@ def test_admin_self_deletion(testclient, backend):
.follow(status=200)
)
assert models.User.get(user_name="temp") is None
assert backend.get(models.User, user_name="temp") is None
with testclient.session_transaction() as sess:
assert not sess.get("user_id")
@ -116,7 +116,7 @@ def test_user_self_deletion(testclient, backend):
.follow(status=200)
)
assert models.User.get(user_name="temp") is None
assert backend.get(models.User, user_name="temp") is None
with testclient.session_transaction() as sess:
assert not sess.get("user_id")
@ -136,7 +136,7 @@ def test_account_locking(user, backend):
assert user.locked
user.save()
assert user.locked
assert models.User.get(id=user.id).locked
assert backend.get(models.User, id=user.id).locked
assert backend.check_user_password(user, "correct horse battery staple") == (
False,
"Your account has been locked.",
@ -145,7 +145,7 @@ def test_account_locking(user, backend):
user.lock_date = None
user.save()
assert not user.locked
assert not models.User.get(id=user.id).locked
assert not backend.get(models.User, id=user.id).locked
assert backend.check_user_password(user, "correct horse battery staple") == (
True,
None,
@ -165,7 +165,7 @@ def test_account_locking_past_date(user, backend):
) - datetime.timedelta(days=30)
user.save()
assert user.locked
assert models.User.get(id=user.id).locked
assert backend.get(models.User, id=user.id).locked
assert backend.check_user_password(user, "correct horse battery staple") == (
False,
"Your account has been locked.",
@ -185,7 +185,7 @@ def test_account_locking_future_date(user, backend):
) + datetime.timedelta(days=365 * 4)
user.save()
assert not user.locked
assert not models.User.get(id=user.id).locked
assert not backend.get(models.User, id=user.id).locked
assert backend.check_user_password(user, "correct horse battery staple") == (
True,
None,

View file

@ -131,12 +131,12 @@ def test_set_groups_with_leading_space_in_user_id_attribute(app, foo_group):
def test_moderator_can_create_edit_and_delete_group(
testclient, logged_moderator, foo_group
testclient, logged_moderator, foo_group, backend
):
# The group does not exist
res = testclient.get("/groups", status=200)
assert models.Group.get(display_name="bar") is None
assert models.Group.get(display_name="foo") == foo_group
assert backend.get(models.Group, display_name="bar") is None
assert backend.get(models.Group, display_name="foo") == foo_group
res.mustcontain(no="bar")
res.mustcontain("foo")
@ -150,7 +150,7 @@ def test_moderator_can_create_edit_and_delete_group(
res = form.submit(status=302).follow(status=200)
logged_moderator.reload()
bar_group = models.Group.get(display_name="bar")
bar_group = backend.get(models.Group, display_name="bar")
assert bar_group.display_name == "bar"
assert bar_group.description == "yolo"
assert bar_group.members == [
@ -168,10 +168,10 @@ def test_moderator_can_create_edit_and_delete_group(
assert res.flashes == [("error", "Group edition failed.")]
res.mustcontain("This field cannot be edited")
bar_group = models.Group.get(display_name="bar")
bar_group = backend.get(models.Group, display_name="bar")
assert bar_group.display_name == "bar"
assert bar_group.description == "yolo"
assert models.Group.get(display_name="bar2") is None
assert backend.get(models.Group, display_name="bar2") is None
# Group description can be edited
res = testclient.get("/groups/bar", status=200)
@ -182,14 +182,14 @@ def test_moderator_can_create_edit_and_delete_group(
assert res.flashes == [("success", "The group bar has been sucessfully edited.")]
res = res.follow()
bar_group = models.Group.get(display_name="bar")
bar_group = backend.get(models.Group, display_name="bar")
assert bar_group.display_name == "bar"
assert bar_group.description == "yolo2"
# Group is deleted
res = res.forms["editgroupform"].submit(name="action", value="confirm-delete")
res = res.form.submit(name="action", value="delete", status=302)
assert models.Group.get(display_name="bar") is None
assert backend.get(models.Group, display_name="bar") is None
assert ("success", "The group bar has been sucessfully deleted") in res.flashes

View file

@ -7,7 +7,7 @@ from canaille.core.endpoints.account import RegistrationPayload
def test_invitation(testclient, logged_admin, foo_group, smtpd, backend):
assert models.User.get(user_name="someone") is None
assert backend.get(models.User, user_name="someone") is None
res = testclient.get("/invite", status=200)
@ -46,7 +46,7 @@ def test_invitation(testclient, logged_admin, foo_group, smtpd, backend):
assert ("success", "Your account has been created successfully.") in res.flashes
res = res.follow(status=200)
user = models.User.get(user_name="someone")
user = backend.get(models.User, user_name="someone")
foo_group.reload()
assert backend.check_user_password(user, "whatever")[0]
assert user.groups == [foo_group]
@ -62,8 +62,8 @@ def test_invitation(testclient, logged_admin, foo_group, smtpd, backend):
def test_invitation_editable_user_name(
testclient, logged_admin, foo_group, smtpd, backend
):
assert models.User.get(user_name="jackyjack") is None
assert models.User.get(user_name="djorje") is None
assert backend.get(models.User, user_name="jackyjack") is None
assert backend.get(models.User, user_name="djorje") is None
res = testclient.get("/invite", status=200)
@ -102,7 +102,7 @@ def test_invitation_editable_user_name(
assert ("success", "Your account has been created successfully.") in res.flashes
res = res.follow(status=200)
user = models.User.get(user_name="djorje")
user = backend.get(models.User, user_name="djorje")
foo_group.reload()
assert backend.check_user_password(user, "whatever")[0]
assert user.groups == [foo_group]
@ -114,7 +114,7 @@ def test_invitation_editable_user_name(
def test_generate_link(testclient, logged_admin, foo_group, smtpd, backend):
assert models.User.get(user_name="sometwo") is None
assert backend.get(models.User, user_name="sometwo") is None
res = testclient.get("/invite", status=200)
@ -149,7 +149,7 @@ def test_generate_link(testclient, logged_admin, foo_group, smtpd, backend):
res = res.form.submit(status=302)
res = res.follow(status=200)
user = models.User.get(user_name="sometwo")
user = backend.get(models.User, user_name="sometwo")
foo_group.reload()
assert backend.check_user_password(user, "whatever")[0]
assert user.groups == [foo_group]
@ -245,7 +245,7 @@ def test_registration_more_than_48_hours_after_invitation(testclient, foo_group)
testclient.get(f"/register/{b64}/{hash}", status=302)
def test_registration_no_password(testclient, foo_group):
def test_registration_no_password(testclient, foo_group, backend):
payload = RegistrationPayload(
datetime.datetime.now(datetime.timezone.utc).isoformat(),
"someoneelse",
@ -264,7 +264,7 @@ def test_registration_no_password(testclient, foo_group):
res = res.form.submit(status=200)
res.mustcontain("This field is required.")
assert not models.User.get(user_name="someoneelse")
assert not backend.get(models.User, user_name="someoneelse")
with testclient.session_transaction() as sess:
assert "user_id" not in sess
@ -302,7 +302,7 @@ def test_unavailable_if_no_smtp(testclient, logged_admin):
def test_groups_are_saved_even_when_user_does_not_have_read_permission(
testclient, foo_group
testclient, foo_group, backend
):
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["READ"] = [
"user_name"
@ -331,7 +331,7 @@ def test_groups_are_saved_even_when_user_does_not_have_read_permission(
res = res.form.submit(status=302)
res = res.follow(status=200)
user = models.User.get(user_name="someoneelse")
user = backend.get(models.User, user_name="someoneelse")
foo_group.reload()
assert user.groups == [foo_group]
user.delete()

View file

@ -6,7 +6,7 @@ def test_user_creation_edition_and_deletion(
):
# The user does not exist.
res = testclient.get("/users", status=200)
assert models.User.get(user_name="george") is None
assert backend.get(models.User, user_name="george") is None
res.mustcontain(no="george")
# Fill the profile for a new user.
@ -24,7 +24,7 @@ def test_user_creation_edition_and_deletion(
res = res.form.submit(name="action", value="create-profile", status=302)
assert ("success", "User account creation succeed.") in res.flashes
res = res.follow(status=200)
george = models.User.get(user_name="george")
george = backend.get(models.User, user_name="george")
foo_group.reload()
assert "George" == george.given_name
assert george.groups == [foo_group]
@ -45,7 +45,7 @@ def test_user_creation_edition_and_deletion(
res.form["groups"] = [foo_group.id, bar_group.id]
res = res.form.submit(name="action", value="edit-settings").follow()
george = models.User.get(user_name="george")
george = backend.get(models.User, user_name="george")
assert "Georgio" == george.given_name
assert backend.check_user_password(george, "totoyolo")[0]
@ -62,7 +62,7 @@ def test_user_creation_edition_and_deletion(
res = res.form.submit(name="action", value="confirm-delete", status=200)
res = res.form.submit(name="action", value="delete", status=302)
res = res.follow(status=200)
assert models.User.get(user_name="george") is None
assert backend.get(models.User, user_name="george") is None
res.mustcontain(no="george")
@ -82,7 +82,7 @@ def test_profile_creation_dynamic_validation(testclient, logged_admin, user):
res.mustcontain("The email 'john@doe.com' is already used")
def test_user_creation_without_password(testclient, logged_moderator):
def test_user_creation_without_password(testclient, logged_moderator, backend):
res = testclient.get("/profile", status=200)
res.form["user_name"] = "george"
res.form["family_name"] = "Abitbol"
@ -91,7 +91,7 @@ def test_user_creation_without_password(testclient, logged_moderator):
res = res.form.submit(name="action", value="create-profile", status=302)
assert ("success", "User account creation succeed.") in res.flashes
res = res.follow(status=200)
george = models.User.get(user_name="george")
george = backend.get(models.User, user_name="george")
assert george.user_name == "george"
assert not george.has_password()
@ -99,16 +99,16 @@ def test_user_creation_without_password(testclient, logged_moderator):
def test_user_creation_form_validation_failed(
testclient, logged_moderator, foo_group, bar_group
testclient, logged_moderator, foo_group, bar_group, backend
):
res = testclient.get("/users", status=200)
assert models.User.get(user_name="george") is None
assert backend.get(models.User, user_name="george") is None
res.mustcontain(no="george")
res = testclient.get("/profile", status=200)
res = res.form.submit(name="action", value="create-profile")
assert ("error", "User account creation failed.") in res.flashes
assert models.User.get(user_name="george") is None
assert backend.get(models.User, user_name="george") is None
def test_username_already_taken(
@ -133,7 +133,7 @@ def test_email_already_taken(testclient, logged_moderator, user, foo_group, bar_
res.mustcontain("The email 'john@doe.com' is already used")
def test_cn_setting_with_given_name_and_surname(testclient, logged_moderator):
def test_cn_setting_with_given_name_and_surname(testclient, logged_moderator, backend):
res = testclient.get("/profile", status=200)
res.form["user_name"] = "george"
res.form["given_name"] = "George"
@ -144,12 +144,12 @@ def test_cn_setting_with_given_name_and_surname(testclient, logged_moderator):
status=200
)
george = models.User.get(user_name="george")
george = backend.get(models.User, user_name="george")
assert george.formatted_name == "George Abitbol"
george.delete()
def test_cn_setting_with_surname_only(testclient, logged_moderator):
def test_cn_setting_with_surname_only(testclient, logged_moderator, backend):
res = testclient.get("/profile", status=200)
res.form["user_name"] = "george"
res.form["family_name"] = "Abitbol"
@ -159,7 +159,7 @@ def test_cn_setting_with_surname_only(testclient, logged_moderator):
status=200
)
george = models.User.get(user_name="george")
george = backend.get(models.User, user_name="george")
assert george.formatted_name == "Abitbol"
george.delete()

View file

@ -104,9 +104,9 @@ def test_photo_on_profile_edition(
assert logged_user.photo is None
def test_photo_on_profile_creation(testclient, jpeg_photo, logged_admin):
def test_photo_on_profile_creation(testclient, jpeg_photo, logged_admin, backend):
res = testclient.get("/users", status=200)
assert models.User.get(user_name="foobar") is None
assert backend.get(models.User, user_name="foobar") is None
res.mustcontain(no="foobar")
res = testclient.get("/profile", status=200)
@ -119,14 +119,16 @@ def test_photo_on_profile_creation(testclient, jpeg_photo, logged_admin):
status=200
)
user = models.User.get(user_name="foobar")
user = backend.get(models.User, user_name="foobar")
assert user.photo == jpeg_photo
user.delete()
def test_photo_deleted_on_profile_creation(testclient, jpeg_photo, logged_admin):
def test_photo_deleted_on_profile_creation(
testclient, jpeg_photo, logged_admin, backend
):
res = testclient.get("/users", status=200)
assert models.User.get(user_name="foobar") is None
assert backend.get(models.User, user_name="foobar") is None
res.mustcontain(no="foobar")
res = testclient.get("/profile", status=200)
@ -140,6 +142,6 @@ def test_photo_deleted_on_profile_creation(testclient, jpeg_photo, logged_admin)
status=200
)
user = models.User.get(user_name="foobar")
user = backend.get(models.User, user_name="foobar")
assert user.photo is None
user.delete()

View file

@ -381,7 +381,7 @@ def test_account_locking(
res = res.form.submit(name="action", value="confirm-lock")
res = res.form.submit(name="action", value="lock")
user = models.User.get(id=user.id)
user = backend.get(models.User, id=user.id)
assert user.lock_date <= datetime.datetime.now(datetime.timezone.utc)
assert user.locked
res.mustcontain("The account has been locked")
@ -389,7 +389,7 @@ def test_account_locking(
res.mustcontain("Unlock")
res = res.form.submit(name="action", value="unlock")
user = models.User.get(id=user.id)
user = backend.get(models.User, id=user.id)
assert not user.lock_date
assert not user.locked
res.mustcontain("The account has been unlocked")
@ -415,7 +415,7 @@ def test_past_lock_date(
assert res.flashes == [("success", "Profile updated successfully.")]
res = res.follow()
user = models.User.get(id=user.id)
user = backend.get(models.User, id=user.id)
assert user.lock_date == expiration_datetime
assert user.locked
@ -438,7 +438,7 @@ def test_future_lock_date(
assert res.flashes == [("success", "Profile updated successfully.")]
res = res.follow()
user = models.User.get(id=user.id)
user = backend.get(models.User, id=user.id)
assert user.lock_date == expiration_datetime
assert not user.locked
assert res.form["lock_date"].value == expiration_datetime.strftime("%Y-%m-%d %H:%M")
@ -484,7 +484,7 @@ def test_account_limit_values(
assert res.flashes == [("success", "Profile updated successfully.")]
res = res.follow()
user = models.User.get(id=user.id)
user = backend.get(models.User, id=user.id)
assert user.lock_date == expiration_datetime
assert not user.locked

View file

@ -22,7 +22,7 @@ def test_registration_without_email_validation(testclient, backend, foo_group):
res = res.form.submit()
assert ("success", "Your account has been created successfully.") in res.flashes
user = models.User.get(user_name="newuser")
user = backend.get(models.User, user_name="newuser")
assert user
user.delete()
@ -73,7 +73,7 @@ def test_registration_with_email_validation(testclient, backend, smtpd, foo_grou
("success", "Your account has been created successfully."),
]
user = models.User.get(user_name="newuser")
user = backend.get(models.User, user_name="newuser")
assert user
user.delete()

View file

@ -69,8 +69,8 @@ def test_clean_command(testclient, backend, client, user):
)
expired_token.save()
assert models.AuthorizationCode.get(code="my-expired-code")
assert models.Token.get(access_token="my-expired-token")
assert backend.get(models.AuthorizationCode, code="my-expired-code")
assert backend.get(models.Token, access_token="my-expired-token")
assert expired_code.is_expired()
assert expired_token.is_expired()
@ -78,5 +78,5 @@ def test_clean_command(testclient, backend, client, user):
res = runner.invoke(cli, ["clean"])
assert res.exit_code == 0, res.stdout
assert models.AuthorizationCode.get() == valid_code
assert models.Token.get() == valid_token
assert backend.get(models.AuthorizationCode) == valid_code
assert backend.get(models.Token) == valid_token

View file

@ -34,7 +34,7 @@ def test_nominal_case(
assert res.location.startswith(client.redirect_uris[0])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = models.AuthorizationCode.get(code=code)
authcode = backend.get(models.AuthorizationCode, code=code)
assert authcode is not None
assert set(authcode.scope) == {
"openid",
@ -68,7 +68,7 @@ def test_nominal_case(
)
access_token = res.json["access_token"]
token = models.Token.get(access_token=access_token)
token = backend.get(models.Token, access_token=access_token)
assert token.client == client
assert token.subject == logged_user
assert set(token.scope) == {
@ -136,7 +136,7 @@ def test_redirect_uri(
assert res.location.startswith(client.redirect_uris[1])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = models.AuthorizationCode.get(code=code)
authcode = backend.get(models.AuthorizationCode, code=code)
assert authcode is not None
consents = backend.query(models.Consent, client=client, subject=logged_user)
@ -153,7 +153,7 @@ def test_redirect_uri(
)
access_token = res.json["access_token"]
token = models.Token.get(access_token=access_token)
token = backend.get(models.Token, access_token=access_token)
assert token.client == client
assert token.subject == logged_user
@ -183,7 +183,7 @@ def test_preconsented_client(
assert res.location.startswith(client.redirect_uris[0])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = models.AuthorizationCode.get(code=code)
authcode = backend.get(models.AuthorizationCode, code=code)
assert authcode is not None
consents = backend.query(models.Consent, client=client, subject=logged_user)
@ -202,7 +202,7 @@ def test_preconsented_client(
)
access_token = res.json["access_token"]
token = models.Token.get(access_token=access_token)
token = backend.get(models.Token, access_token=access_token)
assert token.client == client
assert token.subject == logged_user
@ -257,7 +257,7 @@ def test_logout_login(testclient, logged_user, client, backend):
assert res.location.startswith(client.redirect_uris[0])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = models.AuthorizationCode.get(code=code)
authcode = backend.get(models.AuthorizationCode, code=code)
assert authcode is not None
consents = backend.query(models.Consent, client=client, subject=logged_user)
@ -276,7 +276,7 @@ def test_logout_login(testclient, logged_user, client, backend):
)
access_token = res.json["access_token"]
token = models.Token.get(access_token=access_token)
token = backend.get(models.Token, access_token=access_token)
assert token.client == client
assert token.subject == logged_user
@ -341,7 +341,7 @@ def test_code_challenge(testclient, logged_user, client, backend):
assert res.location.startswith(client.redirect_uris[0])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = models.AuthorizationCode.get(code=code)
authcode = backend.get(models.AuthorizationCode, code=code)
assert authcode is not None
consents = backend.query(models.Consent, client=client, subject=logged_user)
@ -361,7 +361,7 @@ def test_code_challenge(testclient, logged_user, client, backend):
)
access_token = res.json["access_token"]
token = models.Token.get(access_token=access_token)
token = backend.get(models.Token, access_token=access_token)
assert token.client == client
assert token.subject == logged_user
@ -398,7 +398,7 @@ def test_consent_already_given(testclient, logged_user, client, backend):
assert res.location.startswith(client.redirect_uris[0])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = models.AuthorizationCode.get(code=code)
authcode = backend.get(models.AuthorizationCode, code=code)
assert authcode is not None
consents = backend.query(models.Consent, client=client, subject=logged_user)
@ -456,7 +456,7 @@ def test_when_consent_already_given_but_for_a_smaller_scope(
assert res.location.startswith(client.redirect_uris[0])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = models.AuthorizationCode.get(code=code)
authcode = backend.get(models.AuthorizationCode, code=code)
assert authcode is not None
consents = backend.query(models.Consent, client=client, subject=logged_user)
@ -492,7 +492,7 @@ def test_when_consent_already_given_but_for_a_smaller_scope(
assert res.location.startswith(client.redirect_uris[0])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = models.AuthorizationCode.get(code=code)
authcode = backend.get(models.AuthorizationCode, code=code)
assert authcode is not None
consents = backend.query(models.Consent, client=client, subject=logged_user)
@ -582,7 +582,7 @@ def test_request_scope_too_large(testclient, logged_user, keypair, client, backe
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = models.AuthorizationCode.get(code=code)
authcode = backend.get(models.AuthorizationCode, code=code)
assert set(authcode.scope) == {
"openid",
"profile",
@ -607,7 +607,7 @@ def test_request_scope_too_large(testclient, logged_user, keypair, client, backe
)
access_token = res.json["access_token"]
token = models.Token.get(access_token=access_token)
token = backend.get(models.Token, access_token=access_token)
assert token.client == client
assert token.subject == logged_user
assert set(token.scope) == {
@ -671,7 +671,7 @@ def test_code_expired(testclient, user, client):
}
def test_code_with_invalid_user(testclient, admin, client):
def test_code_with_invalid_user(testclient, admin, client, backend):
user = models.User(
formatted_name="John Doe",
family_name="Doe",
@ -699,7 +699,7 @@ def test_code_with_invalid_user(testclient, admin, client):
res = res.form.submit(name="answer", value="accept", status=302)
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = models.AuthorizationCode.get(code=code)
authcode = backend.get(models.AuthorizationCode, code=code)
user.delete()
@ -721,7 +721,9 @@ def test_code_with_invalid_user(testclient, admin, client):
authcode.delete()
def test_locked_account(testclient, logged_user, client, keypair, trusted_client):
def test_locked_account(
testclient, logged_user, client, keypair, trusted_client, backend
):
"""Users with a locked account should not be able to exchange code against
tokens."""
res = testclient.get(
@ -743,7 +745,7 @@ def test_locked_account(testclient, logged_user, client, keypair, trusted_client
assert res.location.startswith(client.redirect_uris[0])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = models.AuthorizationCode.get(code=code)
authcode = backend.get(models.AuthorizationCode, code=code)
assert authcode is not None
res = testclient.post(

View file

@ -83,7 +83,7 @@ def test_client_list_search(testclient, logged_admin, client, trusted_client):
res.mustcontain(no=client.client_name)
def test_client_add(testclient, logged_admin):
def test_client_add(testclient, logged_admin, backend):
res = testclient.get("/admin/client/add")
data = {
"client_name": "foobar",
@ -112,7 +112,7 @@ def test_client_add(testclient, logged_admin):
res = res.follow(status=200)
client_id = res.forms["readonly"]["client_id"].value
client = models.Client.get(client_id=client_id)
client = backend.get(models.Client, client_id=client_id)
assert client.client_name == "foobar"
assert client.contacts == ["foo@bar.com"]
@ -214,7 +214,7 @@ def test_client_edit_missing_fields(testclient, client, logged_admin, trusted_cl
assert client.client_name
def test_client_delete(testclient, logged_admin):
def test_client_delete(testclient, logged_admin, backend):
client = models.Client(client_id="client_id")
client.save()
token = models.Token(
@ -238,10 +238,10 @@ def test_client_delete(testclient, logged_admin):
res = res.form.submit(name="action", value="delete")
res = res.follow()
assert not models.Client.get()
assert not models.Token.get()
assert not models.AuthorizationCode.get()
assert not models.Consent.get()
assert not backend.get(models.Client)
assert not backend.get(models.Token)
assert not backend.get(models.AuthorizationCode)
assert not backend.get(models.Consent)
def test_client_delete_invalid_client(testclient, logged_admin, client):

View file

@ -134,7 +134,7 @@ def test_oidc_authorization_after_revokation(
)
access_token = res.json["access_token"]
token = models.Token.get(access_token=access_token)
token = backend.get(models.Token, access_token=access_token)
assert token.client == client
assert token.subject == logged_user
@ -151,16 +151,16 @@ def test_preconsented_client_appears_in_consent_list(testclient, client, logged_
res.mustcontain(client.client_name)
def test_revoke_preconsented_client(testclient, client, logged_user, token):
def test_revoke_preconsented_client(testclient, client, logged_user, token, backend):
client.preconsent = True
client.save()
assert not models.Consent.get()
assert not backend.get(models.Consent)
assert not token.revoked
res = testclient.get(f"/consent/revoke-preconsent/{client.client_id}", status=302)
assert ("success", "The access has been revoked") in res.flashes
consent = models.Consent.get()
consent = backend.get(models.Consent)
assert consent.client == client
assert consent.subject == logged_user
assert consent.scope == ["openid", "email", "profile", "groups", "address", "phone"]

View file

@ -33,7 +33,7 @@ def test_client_registration_with_authentication_static_token(
headers = {"Authorization": "Bearer static-token"}
res = testclient.post_json("/oauth/register", payload, headers=headers, status=201)
client = models.Client.get(client_id=res.json["client_id"])
client = backend.get(models.Client, client_id=res.json["client_id"])
assert res.json == {
"client_id": client.client_id,
@ -154,7 +154,7 @@ def test_client_registration_with_software_statement(testclient, backend, keypai
}
res = testclient.post_json("/oauth/register", payload, status=201)
client = models.Client.get(client_id=res.json["client_id"])
client = backend.get(models.Client, client_id=res.json["client_id"])
assert res.json == {
"client_id": client.client_id,
"client_secret": client.client_secret,
@ -205,7 +205,7 @@ def test_client_registration_without_authentication_ok(testclient, backend):
res = testclient.post_json("/oauth/register", payload, status=201)
client = models.Client.get(client_id=res.json["client_id"])
client = backend.get(models.Client, client_id=res.json["client_id"])
assert res.json == {
"client_id": mock.ANY,
"client_secret": mock.ANY,

View file

@ -95,7 +95,7 @@ def test_update(testclient, backend, client, user):
res = testclient.put_json(
f"/oauth/register/{client.client_id}", payload, headers=headers, status=200
)
client = models.Client.get(client_id=res.json["client_id"])
client = backend.get(models.Client, client_id=res.json["client_id"])
assert res.json == {
"client_id": client.client_id,
@ -153,7 +153,7 @@ def test_delete(testclient, backend, user):
testclient.delete(
f"/oauth/register/{client.client_id}", headers=headers, status=204
)
assert not models.Client.get(client_id=client.client_id)
assert not backend.get(models.Client, client_id=client.client_id)
def test_invalid_client(testclient, backend, user):

View file

@ -33,7 +33,7 @@ def test_fieldlist_add(testclient, logged_admin, backend):
res = res.follow(status=200)
client_id = res.forms["readonly"]["client_id"].value
client = models.Client.get(client_id=client_id)
client = backend.get(models.Client, client_id=client_id)
assert client.redirect_uris == [
"https://foo.bar/callback",
@ -68,7 +68,7 @@ def test_fieldlist_delete(testclient, logged_admin, backend):
res = res.follow(status=200)
client_id = res.forms["readonly"]["client_id"].value
client = models.Client.get(client_id=client_id)
client = backend.get(models.Client, client_id=client_id)
assert client.redirect_uris == [
"https://foo.bar/callback1",
@ -128,7 +128,7 @@ def test_fieldlist_duplicate_value(testclient, logged_admin, client):
res.mustcontain("This value is a duplicate")
def test_fieldlist_empty_value(testclient, logged_admin):
def test_fieldlist_empty_value(testclient, logged_admin, backend):
res = testclient.get("/admin/client/add")
data = {
"client_name": "foobar",
@ -145,7 +145,7 @@ def test_fieldlist_empty_value(testclient, logged_admin):
status=200, name="fieldlist_add", value="post_logout_redirect_uris-0"
)
res.form.submit(status=302, name="action", value="edit")
client = models.Client.get()
client = backend.get(models.Client)
client.delete()

View file

@ -32,11 +32,11 @@ def test_oauth_hybrid(testclient, backend, user, client):
params = parse_qs(urlsplit(res.location).fragment)
code = params["code"][0]
authcode = models.AuthorizationCode.get(code=code)
authcode = backend.get(models.AuthorizationCode, code=code)
assert authcode is not None
access_token = params["access_token"][0]
token = models.Token.get(access_token=access_token)
token = backend.get(models.Token, access_token=access_token)
assert token is not None
res = testclient.get(
@ -65,11 +65,11 @@ def test_oidc_hybrid(testclient, backend, logged_user, client, keypair, trusted_
params = parse_qs(urlsplit(res.location).fragment)
code = params["code"][0]
authcode = models.AuthorizationCode.get(code=code)
authcode = backend.get(models.AuthorizationCode, code=code)
assert authcode is not None
access_token = params["access_token"][0]
token = models.Token.get(access_token=access_token)
token = backend.get(models.Token, access_token=access_token)
assert token is not None
id_token = params["id_token"][0]

View file

@ -6,7 +6,7 @@ from authlib.jose import jwt
from canaille.app import models
def test_oauth_implicit(testclient, user, client):
def test_oauth_implicit(testclient, user, client, backend):
client.grant_types = ["token"]
client.token_endpoint_auth_method = "none"
@ -37,7 +37,7 @@ def test_oauth_implicit(testclient, user, client):
params = parse_qs(urlsplit(res.location).fragment)
access_token = params["access_token"][0]
token = models.Token.get(access_token=access_token)
token = backend.get(models.Token, access_token=access_token)
assert token is not None
res = testclient.get(
@ -51,7 +51,7 @@ def test_oauth_implicit(testclient, user, client):
client.save()
def test_oidc_implicit(testclient, keypair, user, client, trusted_client):
def test_oidc_implicit(testclient, keypair, user, client, trusted_client, backend):
client.grant_types = ["token id_token"]
client.token_endpoint_auth_method = "none"
@ -82,7 +82,7 @@ def test_oidc_implicit(testclient, keypair, user, client, trusted_client):
params = parse_qs(urlsplit(res.location).fragment)
access_token = params["access_token"][0]
token = models.Token.get(access_token=access_token)
token = backend.get(models.Token, access_token=access_token)
assert token is not None
id_token = params["id_token"][0]
@ -105,7 +105,7 @@ def test_oidc_implicit(testclient, keypair, user, client, trusted_client):
def test_oidc_implicit_with_group(
testclient, keypair, user, client, foo_group, trusted_client
testclient, keypair, user, client, foo_group, trusted_client, backend
):
client.grant_types = ["token id_token"]
client.token_endpoint_auth_method = "none"
@ -137,7 +137,7 @@ def test_oidc_implicit_with_group(
params = parse_qs(urlsplit(res.location).fragment)
access_token = params["access_token"][0]
token = models.Token.get(access_token=access_token)
token = backend.get(models.Token, access_token=access_token)
assert token is not None
id_token = params["id_token"][0]

View file

@ -3,7 +3,7 @@ from canaille.app import models
from . import client_credentials
def test_password_flow_basic(testclient, user, client):
def test_password_flow_basic(testclient, user, client, backend):
res = testclient.post(
"/oauth/token",
params=dict(
@ -20,7 +20,7 @@ def test_password_flow_basic(testclient, user, client):
assert res.json["token_type"] == "Bearer"
access_token = res.json["access_token"]
token = models.Token.get(access_token=access_token)
token = backend.get(models.Token, access_token=access_token)
assert token is not None
res = testclient.get(
@ -31,7 +31,7 @@ def test_password_flow_basic(testclient, user, client):
assert res.json["name"] == "John (johnny) Doe"
def test_password_flow_post(testclient, user, client):
def test_password_flow_post(testclient, user, client, backend):
client.token_endpoint_auth_method = "client_secret_post"
client.save()
@ -52,7 +52,7 @@ def test_password_flow_post(testclient, user, client):
assert res.json["token_type"] == "Bearer"
access_token = res.json["access_token"]
token = models.Token.get(access_token=access_token)
token = backend.get(models.Token, access_token=access_token)
assert token is not None
res = testclient.get(

View file

@ -24,7 +24,7 @@ def test_refresh_token(testclient, logged_user, client, backend):
assert res.location.startswith(client.redirect_uris[0])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = models.AuthorizationCode.get(code=code)
authcode = backend.get(models.AuthorizationCode, code=code)
assert authcode is not None
consents = backend.query(models.Consent, client=client, subject=logged_user)
@ -42,7 +42,7 @@ def test_refresh_token(testclient, logged_user, client, backend):
status=200,
)
access_token = res.json["access_token"]
old_token = models.Token.get(access_token=access_token)
old_token = backend.get(models.Token, access_token=access_token)
assert old_token is not None
assert not old_token.revokation_date
@ -56,7 +56,7 @@ def test_refresh_token(testclient, logged_user, client, backend):
status=200,
)
access_token = res.json["access_token"]
new_token = models.Token.get(access_token=access_token)
new_token = backend.get(models.Token, access_token=access_token)
assert new_token is not None
assert old_token.access_token != new_token.access_token
@ -74,7 +74,7 @@ def test_refresh_token(testclient, logged_user, client, backend):
consent.delete()
def test_refresh_token_with_invalid_user(testclient, client):
def test_refresh_token_with_invalid_user(testclient, client, backend):
user = models.User(
formatted_name="John Doe",
family_name="Doe",
@ -103,7 +103,7 @@ def test_refresh_token_with_invalid_user(testclient, client):
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
models.AuthorizationCode.get(code=code)
backend.get(models.AuthorizationCode, code=code)
res = testclient.post(
"/oauth/token",
@ -134,7 +134,7 @@ def test_refresh_token_with_invalid_user(testclient, client):
"error": "invalid_request",
"error_description": 'There is no "user" for this token.',
}
models.Token.get(access_token=access_token).delete()
backend.get(models.Token, access_token=access_token).delete()
def test_cannot_refresh_token_for_locked_users(testclient, logged_user, client):

View file

@ -26,7 +26,7 @@ def test_token_default_expiration_date(
res = res.form.submit(name="answer", value="accept", status=302)
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = models.AuthorizationCode.get(code=code)
authcode = backend.get(models.AuthorizationCode, code=code)
assert authcode.lifetime == 84400
res = testclient.post(
@ -44,7 +44,7 @@ def test_token_default_expiration_date(
assert res.json["expires_in"] == 864000
access_token = res.json["access_token"]
token = models.Token.get(access_token=access_token)
token = backend.get(models.Token, access_token=access_token)
assert token.lifetime == 864000
claims = jwt.decode(access_token, keypair[1])
@ -86,7 +86,7 @@ def test_token_custom_expiration_date(
res = res.form.submit(name="answer", value="accept", status=302)
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = models.AuthorizationCode.get(code=code)
authcode = backend.get(models.AuthorizationCode, code=code)
assert authcode.lifetime == 84400
res = testclient.post(
@ -104,7 +104,7 @@ def test_token_custom_expiration_date(
assert res.json["expires_in"] == 1000
access_token = res.json["access_token"]
token = models.Token.get(access_token=access_token)
token = backend.get(models.Token, access_token=access_token)
assert token.lifetime == 1000
claims = jwt.decode(access_token, keypair[1])

View file

@ -58,7 +58,7 @@ def test_token_invalid(testclient, client):
assert {"active": False} == res.json
def test_full_flow(testclient, logged_user, client, user, trusted_client):
def test_full_flow(testclient, logged_user, client, user, trusted_client, backend):
res = testclient.get(
"/oauth/authorize",
params=dict(
@ -75,7 +75,7 @@ def test_full_flow(testclient, logged_user, client, user, trusted_client):
assert res.location.startswith(client.redirect_uris[0])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = models.AuthorizationCode.get(code=code)
authcode = backend.get(models.AuthorizationCode, code=code)
assert authcode is not None
res = testclient.post(
@ -91,7 +91,7 @@ def test_full_flow(testclient, logged_user, client, user, trusted_client):
)
access_token = res.json["access_token"]
token = models.Token.get(access_token=access_token)
token = backend.get(models.Token, access_token=access_token)
assert token.client == client
assert token.subject == logged_user