forked from Github-Mirrors/canaille
refactor: move BackendModel.get to Backend.get
This commit is contained in:
parent
ccde88b1bf
commit
44573713ed
40 changed files with 255 additions and 241 deletions
|
@ -20,7 +20,7 @@ def current_user():
|
|||
return g.user
|
||||
|
||||
for user_id in session.get("user_id", [])[::-1]:
|
||||
user = models.User.get(user_id)
|
||||
user = current_app.backend.instance.get(models.User, user_id)
|
||||
if user and (
|
||||
not current_app.backend.has_account_lockability() or not user.locked
|
||||
):
|
||||
|
@ -147,7 +147,7 @@ def model_converter(model):
|
|||
|
||||
def to_python(self, identifier):
|
||||
current_app.backend.setup()
|
||||
instance = model.get(identifier)
|
||||
instance = current_app.backend.get(model, identifier)
|
||||
if self.required and not instance:
|
||||
abort(404)
|
||||
|
||||
|
|
|
@ -260,7 +260,9 @@ class IDToModel:
|
|||
|
||||
def __call__(self, data):
|
||||
model = getattr(models, self.model_name)
|
||||
instance = data if isinstance(data, model) else model.get(data)
|
||||
instance = (
|
||||
data if isinstance(data, model) else BaseBackend.instance.get(model, data)
|
||||
)
|
||||
if instance:
|
||||
return instance
|
||||
|
||||
|
|
|
@ -80,6 +80,11 @@ class BaseBackend:
|
|||
attribute values loosely be matched."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get(self, model, identifier=None, **kwargs):
|
||||
"""Works like :meth:`~canaille.backends.BaseBackend.query` but return
|
||||
only one element or :py:data:`None` if no item is matching."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def check_user_password(self, user, password: str) -> bool:
|
||||
"""Check if the password matches the user password in the database."""
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -202,7 +202,7 @@ class Backend(BaseBackend):
|
|||
if login
|
||||
else None
|
||||
)
|
||||
return User.get(filter=filter)
|
||||
return self.get(User, filter=filter)
|
||||
|
||||
def check_user_password(self, user, password):
|
||||
conn = ldap.initialize(current_app.config["CANAILLE_LDAP"]["URI"])
|
||||
|
@ -311,6 +311,19 @@ class Backend(BaseBackend):
|
|||
)
|
||||
return self.query(model, filter=filter, **kwargs)
|
||||
|
||||
def get(self, model, identifier=None, /, **kwargs):
|
||||
try:
|
||||
return self.query(model, identifier, **kwargs)[0]
|
||||
except (IndexError, ldap.NO_SUCH_OBJECT):
|
||||
if identifier and model.base:
|
||||
return (
|
||||
self.get(model, **{model.identifier_attribute: identifier})
|
||||
or self.get(model, id=identifier)
|
||||
or None
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def setup_ldap_models(config):
|
||||
from canaille.app import models
|
||||
|
|
|
@ -4,7 +4,6 @@ import ldap.dn
|
|||
import ldap.filter
|
||||
from ldap.controls.readentry import PostReadControl
|
||||
|
||||
from canaille.backends import BaseBackend
|
||||
from canaille.backends.models import BackendModel
|
||||
|
||||
from .backend import Backend
|
||||
|
@ -226,20 +225,6 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
|
|||
|
||||
return cls._attribute_type_by_name
|
||||
|
||||
@classmethod
|
||||
def get(cls, identifier=None, /, **kwargs):
|
||||
try:
|
||||
return BaseBackend.instance.query(cls, identifier, **kwargs)[0]
|
||||
except (IndexError, ldap.NO_SUCH_OBJECT):
|
||||
if identifier and cls.base:
|
||||
return (
|
||||
cls.get(**{cls.identifier_attribute: identifier})
|
||||
or cls.get(id=identifier)
|
||||
or None
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def update_ldap_attributes(cls):
|
||||
all_object_classes = cls.ldap_object_classes()
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import datetime
|
||||
from enum import Enum
|
||||
|
||||
from canaille.backends import BaseBackend
|
||||
|
||||
LDAP_NULL_DATE = "000001010000Z"
|
||||
|
||||
|
||||
|
@ -50,7 +52,7 @@ def ldap_to_python(value, syntax):
|
|||
return value.decode("utf-8").upper() == "TRUE"
|
||||
|
||||
if syntax == Syntax.DISTINGUISHED_NAME:
|
||||
return LDAPObject.get(value.decode("utf-8"))
|
||||
return BaseBackend.instance.get(LDAPObject, value.decode("utf-8"))
|
||||
|
||||
return value.decode("utf-8")
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ class Backend(BaseBackend):
|
|||
def get_user_from_login(self, login):
|
||||
from .models import User
|
||||
|
||||
return User.get(user_name=login)
|
||||
return self.get(User, user_name=login)
|
||||
|
||||
def check_user_password(self, user, password):
|
||||
if password != user.password:
|
||||
|
@ -80,3 +80,14 @@ class Backend(BaseBackend):
|
|||
if isinstance(value, str)
|
||||
)
|
||||
]
|
||||
|
||||
def get(self, model, identifier=None, /, **kwargs):
|
||||
if identifier:
|
||||
return (
|
||||
self.get(model, **{model.identifier_attribute: identifier})
|
||||
or self.get(model, id=identifier)
|
||||
or None
|
||||
)
|
||||
|
||||
results = self.query(model, **kwargs)
|
||||
return results[0] if results else None
|
||||
|
|
|
@ -36,18 +36,6 @@ class MemoryModel(BackendModel):
|
|||
class_name or cls.__name__, {}
|
||||
).setdefault(attribute, {})
|
||||
|
||||
@classmethod
|
||||
def get(cls, identifier=None, /, **kwargs):
|
||||
if identifier:
|
||||
return (
|
||||
cls.get(**{cls.identifier_attribute: identifier})
|
||||
or cls.get(id=identifier)
|
||||
or None
|
||||
)
|
||||
|
||||
results = BaseBackend.instance.query(cls, **kwargs)
|
||||
return results[0] if results else None
|
||||
|
||||
@classmethod
|
||||
def listify(cls, value):
|
||||
return value if isinstance(value, list) else [value]
|
||||
|
@ -75,7 +63,7 @@ class MemoryModel(BackendModel):
|
|||
model, _ = cls.get_model_annotations(attribute_name)
|
||||
if model and not isinstance(value, model):
|
||||
backend_model = getattr(models, model.__name__)
|
||||
return backend_model.get(id=value)
|
||||
return BaseBackend.instance.get(backend_model, id=value)
|
||||
|
||||
return value
|
||||
|
||||
|
@ -166,7 +154,7 @@ class MemoryModel(BackendModel):
|
|||
del self.index()[self.id]
|
||||
|
||||
def reload(self):
|
||||
self._state = self.__class__.get(id=self.id)._state
|
||||
self._state = BaseBackend.instance.get(self.__class__, id=self.id)._state
|
||||
self._cache = {}
|
||||
|
||||
def __eq__(self, other):
|
||||
|
@ -174,7 +162,7 @@ class MemoryModel(BackendModel):
|
|||
return False
|
||||
|
||||
if not isinstance(other, MemoryModel):
|
||||
return self == self.__class__.get(id=other)
|
||||
return self == BaseBackend.instance.get(self.__class__, id=other)
|
||||
|
||||
return self._state == other._state
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ from typing import get_type_hints
|
|||
|
||||
from canaille.app import classproperty
|
||||
from canaille.app import models
|
||||
from canaille.backends import BaseBackend
|
||||
|
||||
|
||||
class Model:
|
||||
|
@ -87,12 +88,6 @@ class BackendModel:
|
|||
implemented for every model and for every backend.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get(cls, identifier=None, **kwargs):
|
||||
"""Works like :meth:`~canaille.backends.BaseBackend.query` but return
|
||||
only one element or :py:data:`None` if no item is matching."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def save(self):
|
||||
"""Validate the current modifications in the database."""
|
||||
raise NotImplementedError()
|
||||
|
@ -175,7 +170,7 @@ class BackendModel:
|
|||
|
||||
backend_model = getattr(models, model.__name__)
|
||||
|
||||
if instance := backend_model.get(value):
|
||||
if instance := BaseBackend.instance.get(backend_model, value):
|
||||
filter[attribute] = instance
|
||||
|
||||
return all(
|
||||
|
|
|
@ -53,7 +53,7 @@ class Backend(BaseBackend):
|
|||
def get_user_from_login(self, login):
|
||||
from .models import User
|
||||
|
||||
return User.get(user_name=login)
|
||||
return self.get(User, user_name=login)
|
||||
|
||||
def check_user_password(self, user, password):
|
||||
if password != user.password:
|
||||
|
@ -90,3 +90,19 @@ class Backend(BaseBackend):
|
|||
)
|
||||
|
||||
return self.db_session.execute(select(model).filter(filter)).scalars().all()
|
||||
|
||||
def get(self, model, identifier=None, /, **kwargs):
|
||||
if identifier:
|
||||
return (
|
||||
self.get(model, **{model.identifier_attribute: identifier})
|
||||
or self.get(model, id=identifier)
|
||||
or None
|
||||
)
|
||||
|
||||
filter = [
|
||||
model.attribute_filter(attribute_name, expected_value)
|
||||
for attribute_name, expected_value in kwargs.items()
|
||||
]
|
||||
return Backend.instance.db_session.execute(
|
||||
select(model).filter(*filter)
|
||||
).scalar_one_or_none()
|
||||
|
|
|
@ -11,7 +11,6 @@ from sqlalchemy import LargeBinary
|
|||
from sqlalchemy import String
|
||||
from sqlalchemy import Table
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlalchemy.orm import mapped_column
|
||||
from sqlalchemy.orm import relationship
|
||||
|
@ -48,23 +47,6 @@ class SqlAlchemyModel(BackendModel):
|
|||
|
||||
return getattr(cls, name) == value
|
||||
|
||||
@classmethod
|
||||
def get(cls, identifier=None, /, **kwargs):
|
||||
if identifier:
|
||||
return (
|
||||
cls.get(**{cls.identifier_attribute: identifier})
|
||||
or cls.get(id=identifier)
|
||||
or None
|
||||
)
|
||||
|
||||
filter = [
|
||||
cls.attribute_filter(attribute_name, expected_value)
|
||||
for attribute_name, expected_value in kwargs.items()
|
||||
]
|
||||
return Backend.instance.db_session.execute(
|
||||
select(cls).filter(*filter)
|
||||
).scalar_one_or_none()
|
||||
|
||||
def save(self):
|
||||
self.last_modified = datetime.datetime.now(datetime.timezone.utc).replace(
|
||||
microsecond=0
|
||||
|
|
|
@ -257,7 +257,9 @@ def registration(data=None, hash=None):
|
|||
)
|
||||
return redirect(url_for("core.account.index"))
|
||||
|
||||
if payload.user_name and models.User.get(user_name=payload.user_name):
|
||||
if payload.user_name and BaseBackend.instance.get(
|
||||
models.User, user_name=payload.user_name
|
||||
):
|
||||
flash(
|
||||
_("Your account has already been created."),
|
||||
"error",
|
||||
|
@ -282,7 +284,10 @@ def registration(data=None, hash=None):
|
|||
data = {
|
||||
"user_name": payload.user_name,
|
||||
"emails": [payload.email],
|
||||
"groups": [models.Group.get(id=group_id) for group_id in payload.groups],
|
||||
"groups": [
|
||||
BaseBackend.instance.get(models.Group, id=group_id)
|
||||
for group_id in payload.groups
|
||||
],
|
||||
}
|
||||
|
||||
has_smtp = "SMTP" in current_app.config["CANAILLE"]
|
||||
|
@ -376,7 +381,7 @@ def email_confirmation(data, hash):
|
|||
)
|
||||
return redirect(url_for("core.account.index"))
|
||||
|
||||
user = models.User.get(confirmation_obj.identifier)
|
||||
user = BaseBackend.instance.get(models.User, confirmation_obj.identifier)
|
||||
if not user:
|
||||
flash(
|
||||
_("The email confirmation link that brought you here is invalid."),
|
||||
|
|
|
@ -24,7 +24,7 @@ MINIMUM_PASSWORD_LENGTH = 8
|
|||
|
||||
|
||||
def unique_user_name(form, field):
|
||||
if models.User.get(user_name=field.data) and (
|
||||
if BaseBackend.instance.get(models.User, user_name=field.data) and (
|
||||
not getattr(form, "user", None) or form.user.user_name != field.data
|
||||
):
|
||||
raise wtforms.ValidationError(
|
||||
|
@ -33,7 +33,7 @@ def unique_user_name(form, field):
|
|||
|
||||
|
||||
def unique_email(form, field):
|
||||
if models.User.get(emails=field.data) and (
|
||||
if BaseBackend.instance.get(models.User, emails=field.data) and (
|
||||
not getattr(form, "user", None) or field.data not in form.user.emails
|
||||
):
|
||||
raise wtforms.ValidationError(
|
||||
|
@ -42,7 +42,7 @@ def unique_email(form, field):
|
|||
|
||||
|
||||
def unique_group(form, field):
|
||||
if models.Group.get(display_name=field.data):
|
||||
if BaseBackend.instance.get(models.Group, display_name=field.data):
|
||||
raise wtforms.ValidationError(
|
||||
_("The group '{group}' already exists").format(group=field.data)
|
||||
)
|
||||
|
|
|
@ -258,12 +258,12 @@ class User(Model):
|
|||
def preferred_email(self):
|
||||
return self.emails[0] if self.emails else None
|
||||
|
||||
def __getattr__(self, name):
|
||||
def __getattribute__(self, name):
|
||||
prefix = "can_"
|
||||
if name.startswith(prefix) and name != "can_read":
|
||||
return self.can(name[len(prefix) :])
|
||||
|
||||
return super().__getattr__(name)
|
||||
return super().__getattribute__(name)
|
||||
|
||||
def can(self, *permissions: Permission):
|
||||
"""Wether or not the user has the
|
||||
|
|
|
@ -108,7 +108,7 @@ def revoke_preconsent(user, client):
|
|||
flash(_("Could not revoke this access"), "error")
|
||||
return redirect(url_for("oidc.consents.consents"))
|
||||
|
||||
consent = models.Consent.get(client=client, subject=user)
|
||||
consent = BaseBackend.instance.get(models.Consent, client=client, subject=user)
|
||||
if consent:
|
||||
return redirect(url_for("oidc.consents.revoke", consent=consent))
|
||||
|
||||
|
|
|
@ -50,7 +50,9 @@ def authorize():
|
|||
request.form.to_dict(flat=False),
|
||||
)
|
||||
|
||||
client = models.Client.get(client_id=request.args["client_id"])
|
||||
client = BaseBackend.instance.get(
|
||||
models.Client, client_id=request.args["client_id"]
|
||||
)
|
||||
user = current_user()
|
||||
|
||||
if response := authorize_guards(client):
|
||||
|
@ -276,7 +278,7 @@ def end_session():
|
|||
valid_uris = []
|
||||
|
||||
if "client_id" in data:
|
||||
client = models.Client.get(client_id=data["client_id"])
|
||||
client = BaseBackend.instance.get(models.Client, client_id=data["client_id"])
|
||||
if client:
|
||||
valid_uris = client.post_logout_redirect_uris
|
||||
|
||||
|
@ -328,7 +330,7 @@ def end_session():
|
|||
else [id_token["aud"]]
|
||||
)
|
||||
for client_id in client_ids:
|
||||
client = models.Client.get(client_id=client_id)
|
||||
client = BaseBackend.instance.get(models.Client, client_id=client_id)
|
||||
if client:
|
||||
valid_uris.extend(client.post_logout_redirect_uris or [])
|
||||
|
||||
|
|
|
@ -111,7 +111,7 @@ def openid_configuration():
|
|||
|
||||
|
||||
def exists_nonce(nonce, req):
|
||||
client = models.Client.get(id=req.client_id)
|
||||
client = BaseBackend.instance.get(models.Client, id=req.client_id)
|
||||
exists = BaseBackend.instance.query(
|
||||
models.AuthorizationCode, client=client, nonce=nonce
|
||||
)
|
||||
|
@ -334,7 +334,7 @@ class OpenIDHybridGrant(_OpenIDHybridGrant):
|
|||
|
||||
|
||||
def query_client(client_id):
|
||||
return models.Client.get(client_id=client_id)
|
||||
return BaseBackend.instance.get(models.Client, client_id=client_id)
|
||||
|
||||
|
||||
def save_token(token, request):
|
||||
|
@ -356,20 +356,20 @@ def save_token(token, request):
|
|||
|
||||
class BearerTokenValidator(_BearerTokenValidator):
|
||||
def authenticate_token(self, token_string):
|
||||
return models.Token.get(access_token=token_string)
|
||||
return BaseBackend.instance.get(models.Token, access_token=token_string)
|
||||
|
||||
|
||||
def query_token(token, token_type_hint):
|
||||
if token_type_hint == "access_token":
|
||||
return models.Token.get(access_token=token)
|
||||
return BaseBackend.instance.get(models.Token, access_token=token)
|
||||
elif token_type_hint == "refresh_token":
|
||||
return models.Token.get(refresh_token=token)
|
||||
return BaseBackend.instance.get(models.Token, refresh_token=token)
|
||||
|
||||
item = models.Token.get(access_token=token)
|
||||
item = BaseBackend.instance.get(models.Token, access_token=token)
|
||||
if item:
|
||||
return item
|
||||
|
||||
item = models.Token.get(refresh_token=token)
|
||||
item = BaseBackend.instance.get(models.Token, refresh_token=token)
|
||||
if item:
|
||||
return item
|
||||
|
||||
|
@ -472,7 +472,7 @@ class ClientRegistrationEndpoint(ClientManagementMixin, _ClientRegistrationEndpo
|
|||
class ClientConfigurationEndpoint(ClientManagementMixin, _ClientConfigurationEndpoint):
|
||||
def authenticate_client(self, request):
|
||||
client_id = request.uri.split("/")[-1]
|
||||
return models.Client.get(client_id=client_id)
|
||||
return BaseBackend.instance.get(models.Client, client_id=client_id)
|
||||
|
||||
def revoke_access_token(self, request, token):
|
||||
pass
|
||||
|
|
|
@ -7,7 +7,7 @@ def test_guess_object_from_dn(backend, testclient, foo_group):
|
|||
foo_group.members = [foo_group]
|
||||
foo_group.save()
|
||||
dn = foo_group.dn
|
||||
g = LDAPObject.get(dn)
|
||||
g = backend.get(LDAPObject, dn)
|
||||
assert isinstance(g, models.Group)
|
||||
assert g == foo_group
|
||||
assert g.display_name == foo_group.display_name
|
||||
|
@ -21,9 +21,9 @@ def test_object_class_update(backend, testclient):
|
|||
user1.save()
|
||||
|
||||
assert set(user1.get_ldap_attribute("objectClass")) == {"inetOrgPerson"}
|
||||
assert set(models.User.get(id=user1.id).get_ldap_attribute("objectClass")) == {
|
||||
"inetOrgPerson"
|
||||
}
|
||||
assert set(
|
||||
backend.get(models.User, id=user1.id).get_ldap_attribute("objectClass")
|
||||
) == {"inetOrgPerson"}
|
||||
|
||||
testclient.app.config["CANAILLE_LDAP"]["USER_CLASS"] = [
|
||||
"inetOrgPerson",
|
||||
|
@ -38,12 +38,14 @@ def test_object_class_update(backend, testclient):
|
|||
"inetOrgPerson",
|
||||
"extensibleObject",
|
||||
}
|
||||
assert set(models.User.get(id=user2.id).get_ldap_attribute("objectClass")) == {
|
||||
assert set(
|
||||
backend.get(models.User, id=user2.id).get_ldap_attribute("objectClass")
|
||||
) == {
|
||||
"inetOrgPerson",
|
||||
"extensibleObject",
|
||||
}
|
||||
|
||||
user1 = models.User.get(id=user1.id)
|
||||
user1 = backend.get(models.User, id=user1.id)
|
||||
assert user1.get_ldap_attribute("objectClass") == ["inetOrgPerson"]
|
||||
|
||||
user1.save()
|
||||
|
@ -51,7 +53,9 @@ def test_object_class_update(backend, testclient):
|
|||
"inetOrgPerson",
|
||||
"extensibleObject",
|
||||
}
|
||||
assert set(models.User.get(id=user1.id).get_ldap_attribute("objectClass")) == {
|
||||
assert set(
|
||||
backend.get(models.User, id=user1.id).get_ldap_attribute("objectClass")
|
||||
) == {
|
||||
"inetOrgPerson",
|
||||
"extensibleObject",
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ def test_object_creation(app, backend):
|
|||
user.save()
|
||||
assert user.exists
|
||||
|
||||
user = models.User.get(id=user.id)
|
||||
user = backend.get(models.User, id=user.id)
|
||||
assert user.exists
|
||||
|
||||
user.delete()
|
||||
|
@ -52,9 +52,9 @@ def test_dn_when_leading_space_in_id_attribute(testclient, backend):
|
|||
assert ldap.dn.is_dn(dn)
|
||||
assert ldap.dn.dn2str(ldap.dn.str2dn(dn)) == dn
|
||||
|
||||
assert user == models.User.get(user.user_name)
|
||||
assert user == models.User.get(user_name=user.user_name)
|
||||
assert user == models.User.get(dn)
|
||||
assert user == backend.get(models.User, user.user_name)
|
||||
assert user == backend.get(models.User, user_name=user.user_name)
|
||||
assert user == backend.get(models.User, dn)
|
||||
|
||||
user.delete()
|
||||
|
||||
|
@ -73,9 +73,9 @@ def test_special_chars_in_rdn(testclient, backend):
|
|||
assert ldap.dn.dn2str(ldap.dn.str2dn(dn)) == dn
|
||||
assert dn == "uid=\\#user,ou=users,dc=mydomain,dc=tld"
|
||||
|
||||
assert user == models.User.get(user.user_name)
|
||||
assert user == models.User.get(user_name=user.user_name)
|
||||
assert user == models.User.get(dn)
|
||||
assert user == backend.get(models.User, user.user_name)
|
||||
assert user == backend.get(models.User, user_name=user.user_name)
|
||||
assert user == backend.get(models.User, dn)
|
||||
|
||||
user.delete()
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ def test_model_comparison(testclient, backend):
|
|||
formatted_name="bar",
|
||||
)
|
||||
bar.save()
|
||||
foo2 = models.User.get(id=foo1.id)
|
||||
foo2 = backend.get(models.User, id=foo1.id)
|
||||
|
||||
assert foo1 == foo2
|
||||
assert foo1 != bar
|
||||
|
@ -39,14 +39,14 @@ def test_model_lifecycle(testclient, backend):
|
|||
assert not backend.query(models.User)
|
||||
assert not backend.query(models.User, id=user.id)
|
||||
assert not backend.query(models.User, id="invalid")
|
||||
assert not models.User.get(id=user.id)
|
||||
assert not backend.get(models.User, id=user.id)
|
||||
|
||||
user.save()
|
||||
|
||||
assert backend.query(models.User) == [user]
|
||||
assert backend.query(models.User, id=user.id) == [user]
|
||||
assert not backend.query(models.User, id="invalid")
|
||||
assert models.User.get(id=user.id) == user
|
||||
assert backend.get(models.User, id=user.id) == user
|
||||
|
||||
user.family_name = "new_family_name"
|
||||
|
||||
|
@ -59,7 +59,7 @@ def test_model_lifecycle(testclient, backend):
|
|||
user.delete()
|
||||
|
||||
assert not backend.query(models.User, id=user.id)
|
||||
assert not models.User.get(id=user.id)
|
||||
assert not backend.get(models.User, id=user.id)
|
||||
|
||||
user.delete()
|
||||
|
||||
|
@ -78,7 +78,7 @@ def test_model_attribute_edition(testclient, backend):
|
|||
assert user.family_name == "family_name"
|
||||
assert user.emails == ["email1@user.com", "email2@user.com"]
|
||||
|
||||
user = models.User.get(id=user.id)
|
||||
user = backend.get(models.User, id=user.id)
|
||||
assert user.user_name == "user_name"
|
||||
assert user.family_name == "family_name"
|
||||
assert user.emails == ["email1@user.com", "email2@user.com"]
|
||||
|
@ -90,7 +90,7 @@ def test_model_attribute_edition(testclient, backend):
|
|||
assert user.family_name == "new_family_name"
|
||||
assert user.emails == ["email1@user.com"]
|
||||
|
||||
user = models.User.get(id=user.id)
|
||||
user = backend.get(models.User, id=user.id)
|
||||
assert user.family_name == "new_family_name"
|
||||
assert user.emails == ["email1@user.com"]
|
||||
|
||||
|
@ -112,34 +112,34 @@ def test_model_indexation(testclient, backend):
|
|||
)
|
||||
user.save()
|
||||
|
||||
assert models.User.get(family_name="family_name") == user
|
||||
assert not models.User.get(family_name="new_family_name")
|
||||
assert models.User.get(emails=["email1@user.com"]) == user
|
||||
assert models.User.get(emails=["email2@user.com"]) == user
|
||||
assert not models.User.get(emails=["email3@user.com"])
|
||||
assert backend.get(models.User, family_name="family_name") == user
|
||||
assert not backend.get(models.User, family_name="new_family_name")
|
||||
assert backend.get(models.User, emails=["email1@user.com"]) == user
|
||||
assert backend.get(models.User, emails=["email2@user.com"]) == user
|
||||
assert not backend.get(models.User, emails=["email3@user.com"])
|
||||
|
||||
user.family_name = "new_family_name"
|
||||
user.emails = ["email2@user.com"]
|
||||
|
||||
assert models.User.get(family_name="family_name") != user
|
||||
assert models.User.get(emails=["email1@user.com"]) != user
|
||||
assert not models.User.get(emails=["email3@user.com"])
|
||||
assert backend.get(models.User, family_name="family_name") != user
|
||||
assert backend.get(models.User, emails=["email1@user.com"]) != user
|
||||
assert not backend.get(models.User, emails=["email3@user.com"])
|
||||
|
||||
user.save()
|
||||
|
||||
assert not models.User.get(family_name="family_name")
|
||||
assert models.User.get(family_name="new_family_name") == user
|
||||
assert not models.User.get(emails=["email1@user.com"])
|
||||
assert models.User.get(emails=["email2@user.com"]) == user
|
||||
assert not models.User.get(emails=["email3@user.com"])
|
||||
assert not backend.get(models.User, family_name="family_name")
|
||||
assert backend.get(models.User, family_name="new_family_name") == user
|
||||
assert not backend.get(models.User, emails=["email1@user.com"])
|
||||
assert backend.get(models.User, emails=["email2@user.com"]) == user
|
||||
assert not backend.get(models.User, emails=["email3@user.com"])
|
||||
|
||||
user.delete()
|
||||
|
||||
assert not models.User.get(family_name="family_name")
|
||||
assert not models.User.get(family_name="new_family_name")
|
||||
assert not models.User.get(emails=["email1@user.com"])
|
||||
assert not models.User.get(emails=["email2@user.com"])
|
||||
assert not models.User.get(emails=["email3@user.com"])
|
||||
assert not backend.get(models.User, family_name="family_name")
|
||||
assert not backend.get(models.User, family_name="new_family_name")
|
||||
assert not backend.get(models.User, emails=["email1@user.com"])
|
||||
assert not backend.get(models.User, emails=["email2@user.com"])
|
||||
assert not backend.get(models.User, emails=["email3@user.com"])
|
||||
|
||||
|
||||
def test_fuzzy_unique_attribute(user, moderator, admin, backend):
|
||||
|
|
|
@ -78,7 +78,7 @@ def test_admin_self_deletion(testclient, backend):
|
|||
.follow(status=200)
|
||||
)
|
||||
|
||||
assert models.User.get(user_name="temp") is None
|
||||
assert backend.get(models.User, user_name="temp") is None
|
||||
|
||||
with testclient.session_transaction() as sess:
|
||||
assert not sess.get("user_id")
|
||||
|
@ -116,7 +116,7 @@ def test_user_self_deletion(testclient, backend):
|
|||
.follow(status=200)
|
||||
)
|
||||
|
||||
assert models.User.get(user_name="temp") is None
|
||||
assert backend.get(models.User, user_name="temp") is None
|
||||
|
||||
with testclient.session_transaction() as sess:
|
||||
assert not sess.get("user_id")
|
||||
|
@ -136,7 +136,7 @@ def test_account_locking(user, backend):
|
|||
assert user.locked
|
||||
user.save()
|
||||
assert user.locked
|
||||
assert models.User.get(id=user.id).locked
|
||||
assert backend.get(models.User, id=user.id).locked
|
||||
assert backend.check_user_password(user, "correct horse battery staple") == (
|
||||
False,
|
||||
"Your account has been locked.",
|
||||
|
@ -145,7 +145,7 @@ def test_account_locking(user, backend):
|
|||
user.lock_date = None
|
||||
user.save()
|
||||
assert not user.locked
|
||||
assert not models.User.get(id=user.id).locked
|
||||
assert not backend.get(models.User, id=user.id).locked
|
||||
assert backend.check_user_password(user, "correct horse battery staple") == (
|
||||
True,
|
||||
None,
|
||||
|
@ -165,7 +165,7 @@ def test_account_locking_past_date(user, backend):
|
|||
) - datetime.timedelta(days=30)
|
||||
user.save()
|
||||
assert user.locked
|
||||
assert models.User.get(id=user.id).locked
|
||||
assert backend.get(models.User, id=user.id).locked
|
||||
assert backend.check_user_password(user, "correct horse battery staple") == (
|
||||
False,
|
||||
"Your account has been locked.",
|
||||
|
@ -185,7 +185,7 @@ def test_account_locking_future_date(user, backend):
|
|||
) + datetime.timedelta(days=365 * 4)
|
||||
user.save()
|
||||
assert not user.locked
|
||||
assert not models.User.get(id=user.id).locked
|
||||
assert not backend.get(models.User, id=user.id).locked
|
||||
assert backend.check_user_password(user, "correct horse battery staple") == (
|
||||
True,
|
||||
None,
|
||||
|
|
|
@ -131,12 +131,12 @@ def test_set_groups_with_leading_space_in_user_id_attribute(app, foo_group):
|
|||
|
||||
|
||||
def test_moderator_can_create_edit_and_delete_group(
|
||||
testclient, logged_moderator, foo_group
|
||||
testclient, logged_moderator, foo_group, backend
|
||||
):
|
||||
# The group does not exist
|
||||
res = testclient.get("/groups", status=200)
|
||||
assert models.Group.get(display_name="bar") is None
|
||||
assert models.Group.get(display_name="foo") == foo_group
|
||||
assert backend.get(models.Group, display_name="bar") is None
|
||||
assert backend.get(models.Group, display_name="foo") == foo_group
|
||||
res.mustcontain(no="bar")
|
||||
res.mustcontain("foo")
|
||||
|
||||
|
@ -150,7 +150,7 @@ def test_moderator_can_create_edit_and_delete_group(
|
|||
res = form.submit(status=302).follow(status=200)
|
||||
|
||||
logged_moderator.reload()
|
||||
bar_group = models.Group.get(display_name="bar")
|
||||
bar_group = backend.get(models.Group, display_name="bar")
|
||||
assert bar_group.display_name == "bar"
|
||||
assert bar_group.description == "yolo"
|
||||
assert bar_group.members == [
|
||||
|
@ -168,10 +168,10 @@ def test_moderator_can_create_edit_and_delete_group(
|
|||
assert res.flashes == [("error", "Group edition failed.")]
|
||||
res.mustcontain("This field cannot be edited")
|
||||
|
||||
bar_group = models.Group.get(display_name="bar")
|
||||
bar_group = backend.get(models.Group, display_name="bar")
|
||||
assert bar_group.display_name == "bar"
|
||||
assert bar_group.description == "yolo"
|
||||
assert models.Group.get(display_name="bar2") is None
|
||||
assert backend.get(models.Group, display_name="bar2") is None
|
||||
|
||||
# Group description can be edited
|
||||
res = testclient.get("/groups/bar", status=200)
|
||||
|
@ -182,14 +182,14 @@ def test_moderator_can_create_edit_and_delete_group(
|
|||
assert res.flashes == [("success", "The group bar has been sucessfully edited.")]
|
||||
res = res.follow()
|
||||
|
||||
bar_group = models.Group.get(display_name="bar")
|
||||
bar_group = backend.get(models.Group, display_name="bar")
|
||||
assert bar_group.display_name == "bar"
|
||||
assert bar_group.description == "yolo2"
|
||||
|
||||
# Group is deleted
|
||||
res = res.forms["editgroupform"].submit(name="action", value="confirm-delete")
|
||||
res = res.form.submit(name="action", value="delete", status=302)
|
||||
assert models.Group.get(display_name="bar") is None
|
||||
assert backend.get(models.Group, display_name="bar") is None
|
||||
assert ("success", "The group bar has been sucessfully deleted") in res.flashes
|
||||
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from canaille.core.endpoints.account import RegistrationPayload
|
|||
|
||||
|
||||
def test_invitation(testclient, logged_admin, foo_group, smtpd, backend):
|
||||
assert models.User.get(user_name="someone") is None
|
||||
assert backend.get(models.User, user_name="someone") is None
|
||||
|
||||
res = testclient.get("/invite", status=200)
|
||||
|
||||
|
@ -46,7 +46,7 @@ def test_invitation(testclient, logged_admin, foo_group, smtpd, backend):
|
|||
assert ("success", "Your account has been created successfully.") in res.flashes
|
||||
res = res.follow(status=200)
|
||||
|
||||
user = models.User.get(user_name="someone")
|
||||
user = backend.get(models.User, user_name="someone")
|
||||
foo_group.reload()
|
||||
assert backend.check_user_password(user, "whatever")[0]
|
||||
assert user.groups == [foo_group]
|
||||
|
@ -62,8 +62,8 @@ def test_invitation(testclient, logged_admin, foo_group, smtpd, backend):
|
|||
def test_invitation_editable_user_name(
|
||||
testclient, logged_admin, foo_group, smtpd, backend
|
||||
):
|
||||
assert models.User.get(user_name="jackyjack") is None
|
||||
assert models.User.get(user_name="djorje") is None
|
||||
assert backend.get(models.User, user_name="jackyjack") is None
|
||||
assert backend.get(models.User, user_name="djorje") is None
|
||||
|
||||
res = testclient.get("/invite", status=200)
|
||||
|
||||
|
@ -102,7 +102,7 @@ def test_invitation_editable_user_name(
|
|||
assert ("success", "Your account has been created successfully.") in res.flashes
|
||||
res = res.follow(status=200)
|
||||
|
||||
user = models.User.get(user_name="djorje")
|
||||
user = backend.get(models.User, user_name="djorje")
|
||||
foo_group.reload()
|
||||
assert backend.check_user_password(user, "whatever")[0]
|
||||
assert user.groups == [foo_group]
|
||||
|
@ -114,7 +114,7 @@ def test_invitation_editable_user_name(
|
|||
|
||||
|
||||
def test_generate_link(testclient, logged_admin, foo_group, smtpd, backend):
|
||||
assert models.User.get(user_name="sometwo") is None
|
||||
assert backend.get(models.User, user_name="sometwo") is None
|
||||
|
||||
res = testclient.get("/invite", status=200)
|
||||
|
||||
|
@ -149,7 +149,7 @@ def test_generate_link(testclient, logged_admin, foo_group, smtpd, backend):
|
|||
res = res.form.submit(status=302)
|
||||
res = res.follow(status=200)
|
||||
|
||||
user = models.User.get(user_name="sometwo")
|
||||
user = backend.get(models.User, user_name="sometwo")
|
||||
foo_group.reload()
|
||||
assert backend.check_user_password(user, "whatever")[0]
|
||||
assert user.groups == [foo_group]
|
||||
|
@ -245,7 +245,7 @@ def test_registration_more_than_48_hours_after_invitation(testclient, foo_group)
|
|||
testclient.get(f"/register/{b64}/{hash}", status=302)
|
||||
|
||||
|
||||
def test_registration_no_password(testclient, foo_group):
|
||||
def test_registration_no_password(testclient, foo_group, backend):
|
||||
payload = RegistrationPayload(
|
||||
datetime.datetime.now(datetime.timezone.utc).isoformat(),
|
||||
"someoneelse",
|
||||
|
@ -264,7 +264,7 @@ def test_registration_no_password(testclient, foo_group):
|
|||
res = res.form.submit(status=200)
|
||||
res.mustcontain("This field is required.")
|
||||
|
||||
assert not models.User.get(user_name="someoneelse")
|
||||
assert not backend.get(models.User, user_name="someoneelse")
|
||||
|
||||
with testclient.session_transaction() as sess:
|
||||
assert "user_id" not in sess
|
||||
|
@ -302,7 +302,7 @@ def test_unavailable_if_no_smtp(testclient, logged_admin):
|
|||
|
||||
|
||||
def test_groups_are_saved_even_when_user_does_not_have_read_permission(
|
||||
testclient, foo_group
|
||||
testclient, foo_group, backend
|
||||
):
|
||||
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["READ"] = [
|
||||
"user_name"
|
||||
|
@ -331,7 +331,7 @@ def test_groups_are_saved_even_when_user_does_not_have_read_permission(
|
|||
res = res.form.submit(status=302)
|
||||
res = res.follow(status=200)
|
||||
|
||||
user = models.User.get(user_name="someoneelse")
|
||||
user = backend.get(models.User, user_name="someoneelse")
|
||||
foo_group.reload()
|
||||
assert user.groups == [foo_group]
|
||||
user.delete()
|
||||
|
|
|
@ -6,7 +6,7 @@ def test_user_creation_edition_and_deletion(
|
|||
):
|
||||
# The user does not exist.
|
||||
res = testclient.get("/users", status=200)
|
||||
assert models.User.get(user_name="george") is None
|
||||
assert backend.get(models.User, user_name="george") is None
|
||||
res.mustcontain(no="george")
|
||||
|
||||
# Fill the profile for a new user.
|
||||
|
@ -24,7 +24,7 @@ def test_user_creation_edition_and_deletion(
|
|||
res = res.form.submit(name="action", value="create-profile", status=302)
|
||||
assert ("success", "User account creation succeed.") in res.flashes
|
||||
res = res.follow(status=200)
|
||||
george = models.User.get(user_name="george")
|
||||
george = backend.get(models.User, user_name="george")
|
||||
foo_group.reload()
|
||||
assert "George" == george.given_name
|
||||
assert george.groups == [foo_group]
|
||||
|
@ -45,7 +45,7 @@ def test_user_creation_edition_and_deletion(
|
|||
res.form["groups"] = [foo_group.id, bar_group.id]
|
||||
res = res.form.submit(name="action", value="edit-settings").follow()
|
||||
|
||||
george = models.User.get(user_name="george")
|
||||
george = backend.get(models.User, user_name="george")
|
||||
assert "Georgio" == george.given_name
|
||||
assert backend.check_user_password(george, "totoyolo")[0]
|
||||
|
||||
|
@ -62,7 +62,7 @@ def test_user_creation_edition_and_deletion(
|
|||
res = res.form.submit(name="action", value="confirm-delete", status=200)
|
||||
res = res.form.submit(name="action", value="delete", status=302)
|
||||
res = res.follow(status=200)
|
||||
assert models.User.get(user_name="george") is None
|
||||
assert backend.get(models.User, user_name="george") is None
|
||||
res.mustcontain(no="george")
|
||||
|
||||
|
||||
|
@ -82,7 +82,7 @@ def test_profile_creation_dynamic_validation(testclient, logged_admin, user):
|
|||
res.mustcontain("The email 'john@doe.com' is already used")
|
||||
|
||||
|
||||
def test_user_creation_without_password(testclient, logged_moderator):
|
||||
def test_user_creation_without_password(testclient, logged_moderator, backend):
|
||||
res = testclient.get("/profile", status=200)
|
||||
res.form["user_name"] = "george"
|
||||
res.form["family_name"] = "Abitbol"
|
||||
|
@ -91,7 +91,7 @@ def test_user_creation_without_password(testclient, logged_moderator):
|
|||
res = res.form.submit(name="action", value="create-profile", status=302)
|
||||
assert ("success", "User account creation succeed.") in res.flashes
|
||||
res = res.follow(status=200)
|
||||
george = models.User.get(user_name="george")
|
||||
george = backend.get(models.User, user_name="george")
|
||||
assert george.user_name == "george"
|
||||
assert not george.has_password()
|
||||
|
||||
|
@ -99,16 +99,16 @@ def test_user_creation_without_password(testclient, logged_moderator):
|
|||
|
||||
|
||||
def test_user_creation_form_validation_failed(
|
||||
testclient, logged_moderator, foo_group, bar_group
|
||||
testclient, logged_moderator, foo_group, bar_group, backend
|
||||
):
|
||||
res = testclient.get("/users", status=200)
|
||||
assert models.User.get(user_name="george") is None
|
||||
assert backend.get(models.User, user_name="george") is None
|
||||
res.mustcontain(no="george")
|
||||
|
||||
res = testclient.get("/profile", status=200)
|
||||
res = res.form.submit(name="action", value="create-profile")
|
||||
assert ("error", "User account creation failed.") in res.flashes
|
||||
assert models.User.get(user_name="george") is None
|
||||
assert backend.get(models.User, user_name="george") is None
|
||||
|
||||
|
||||
def test_username_already_taken(
|
||||
|
@ -133,7 +133,7 @@ def test_email_already_taken(testclient, logged_moderator, user, foo_group, bar_
|
|||
res.mustcontain("The email 'john@doe.com' is already used")
|
||||
|
||||
|
||||
def test_cn_setting_with_given_name_and_surname(testclient, logged_moderator):
|
||||
def test_cn_setting_with_given_name_and_surname(testclient, logged_moderator, backend):
|
||||
res = testclient.get("/profile", status=200)
|
||||
res.form["user_name"] = "george"
|
||||
res.form["given_name"] = "George"
|
||||
|
@ -144,12 +144,12 @@ def test_cn_setting_with_given_name_and_surname(testclient, logged_moderator):
|
|||
status=200
|
||||
)
|
||||
|
||||
george = models.User.get(user_name="george")
|
||||
george = backend.get(models.User, user_name="george")
|
||||
assert george.formatted_name == "George Abitbol"
|
||||
george.delete()
|
||||
|
||||
|
||||
def test_cn_setting_with_surname_only(testclient, logged_moderator):
|
||||
def test_cn_setting_with_surname_only(testclient, logged_moderator, backend):
|
||||
res = testclient.get("/profile", status=200)
|
||||
res.form["user_name"] = "george"
|
||||
res.form["family_name"] = "Abitbol"
|
||||
|
@ -159,7 +159,7 @@ def test_cn_setting_with_surname_only(testclient, logged_moderator):
|
|||
status=200
|
||||
)
|
||||
|
||||
george = models.User.get(user_name="george")
|
||||
george = backend.get(models.User, user_name="george")
|
||||
assert george.formatted_name == "Abitbol"
|
||||
george.delete()
|
||||
|
||||
|
|
|
@ -104,9 +104,9 @@ def test_photo_on_profile_edition(
|
|||
assert logged_user.photo is None
|
||||
|
||||
|
||||
def test_photo_on_profile_creation(testclient, jpeg_photo, logged_admin):
|
||||
def test_photo_on_profile_creation(testclient, jpeg_photo, logged_admin, backend):
|
||||
res = testclient.get("/users", status=200)
|
||||
assert models.User.get(user_name="foobar") is None
|
||||
assert backend.get(models.User, user_name="foobar") is None
|
||||
res.mustcontain(no="foobar")
|
||||
|
||||
res = testclient.get("/profile", status=200)
|
||||
|
@ -119,14 +119,16 @@ def test_photo_on_profile_creation(testclient, jpeg_photo, logged_admin):
|
|||
status=200
|
||||
)
|
||||
|
||||
user = models.User.get(user_name="foobar")
|
||||
user = backend.get(models.User, user_name="foobar")
|
||||
assert user.photo == jpeg_photo
|
||||
user.delete()
|
||||
|
||||
|
||||
def test_photo_deleted_on_profile_creation(testclient, jpeg_photo, logged_admin):
|
||||
def test_photo_deleted_on_profile_creation(
|
||||
testclient, jpeg_photo, logged_admin, backend
|
||||
):
|
||||
res = testclient.get("/users", status=200)
|
||||
assert models.User.get(user_name="foobar") is None
|
||||
assert backend.get(models.User, user_name="foobar") is None
|
||||
res.mustcontain(no="foobar")
|
||||
|
||||
res = testclient.get("/profile", status=200)
|
||||
|
@ -140,6 +142,6 @@ def test_photo_deleted_on_profile_creation(testclient, jpeg_photo, logged_admin)
|
|||
status=200
|
||||
)
|
||||
|
||||
user = models.User.get(user_name="foobar")
|
||||
user = backend.get(models.User, user_name="foobar")
|
||||
assert user.photo is None
|
||||
user.delete()
|
||||
|
|
|
@ -381,7 +381,7 @@ def test_account_locking(
|
|||
|
||||
res = res.form.submit(name="action", value="confirm-lock")
|
||||
res = res.form.submit(name="action", value="lock")
|
||||
user = models.User.get(id=user.id)
|
||||
user = backend.get(models.User, id=user.id)
|
||||
assert user.lock_date <= datetime.datetime.now(datetime.timezone.utc)
|
||||
assert user.locked
|
||||
res.mustcontain("The account has been locked")
|
||||
|
@ -389,7 +389,7 @@ def test_account_locking(
|
|||
res.mustcontain("Unlock")
|
||||
|
||||
res = res.form.submit(name="action", value="unlock")
|
||||
user = models.User.get(id=user.id)
|
||||
user = backend.get(models.User, id=user.id)
|
||||
assert not user.lock_date
|
||||
assert not user.locked
|
||||
res.mustcontain("The account has been unlocked")
|
||||
|
@ -415,7 +415,7 @@ def test_past_lock_date(
|
|||
assert res.flashes == [("success", "Profile updated successfully.")]
|
||||
|
||||
res = res.follow()
|
||||
user = models.User.get(id=user.id)
|
||||
user = backend.get(models.User, id=user.id)
|
||||
assert user.lock_date == expiration_datetime
|
||||
assert user.locked
|
||||
|
||||
|
@ -438,7 +438,7 @@ def test_future_lock_date(
|
|||
assert res.flashes == [("success", "Profile updated successfully.")]
|
||||
|
||||
res = res.follow()
|
||||
user = models.User.get(id=user.id)
|
||||
user = backend.get(models.User, id=user.id)
|
||||
assert user.lock_date == expiration_datetime
|
||||
assert not user.locked
|
||||
assert res.form["lock_date"].value == expiration_datetime.strftime("%Y-%m-%d %H:%M")
|
||||
|
@ -484,7 +484,7 @@ def test_account_limit_values(
|
|||
assert res.flashes == [("success", "Profile updated successfully.")]
|
||||
|
||||
res = res.follow()
|
||||
user = models.User.get(id=user.id)
|
||||
user = backend.get(models.User, id=user.id)
|
||||
assert user.lock_date == expiration_datetime
|
||||
assert not user.locked
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ def test_registration_without_email_validation(testclient, backend, foo_group):
|
|||
res = res.form.submit()
|
||||
assert ("success", "Your account has been created successfully.") in res.flashes
|
||||
|
||||
user = models.User.get(user_name="newuser")
|
||||
user = backend.get(models.User, user_name="newuser")
|
||||
assert user
|
||||
user.delete()
|
||||
|
||||
|
@ -73,7 +73,7 @@ def test_registration_with_email_validation(testclient, backend, smtpd, foo_grou
|
|||
("success", "Your account has been created successfully."),
|
||||
]
|
||||
|
||||
user = models.User.get(user_name="newuser")
|
||||
user = backend.get(models.User, user_name="newuser")
|
||||
assert user
|
||||
user.delete()
|
||||
|
||||
|
|
|
@ -69,8 +69,8 @@ def test_clean_command(testclient, backend, client, user):
|
|||
)
|
||||
expired_token.save()
|
||||
|
||||
assert models.AuthorizationCode.get(code="my-expired-code")
|
||||
assert models.Token.get(access_token="my-expired-token")
|
||||
assert backend.get(models.AuthorizationCode, code="my-expired-code")
|
||||
assert backend.get(models.Token, access_token="my-expired-token")
|
||||
assert expired_code.is_expired()
|
||||
assert expired_token.is_expired()
|
||||
|
||||
|
@ -78,5 +78,5 @@ def test_clean_command(testclient, backend, client, user):
|
|||
res = runner.invoke(cli, ["clean"])
|
||||
assert res.exit_code == 0, res.stdout
|
||||
|
||||
assert models.AuthorizationCode.get() == valid_code
|
||||
assert models.Token.get() == valid_token
|
||||
assert backend.get(models.AuthorizationCode) == valid_code
|
||||
assert backend.get(models.Token) == valid_token
|
||||
|
|
|
@ -34,7 +34,7 @@ def test_nominal_case(
|
|||
assert res.location.startswith(client.redirect_uris[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
authcode = backend.get(models.AuthorizationCode, code=code)
|
||||
assert authcode is not None
|
||||
assert set(authcode.scope) == {
|
||||
"openid",
|
||||
|
@ -68,7 +68,7 @@ def test_nominal_case(
|
|||
)
|
||||
|
||||
access_token = res.json["access_token"]
|
||||
token = models.Token.get(access_token=access_token)
|
||||
token = backend.get(models.Token, access_token=access_token)
|
||||
assert token.client == client
|
||||
assert token.subject == logged_user
|
||||
assert set(token.scope) == {
|
||||
|
@ -136,7 +136,7 @@ def test_redirect_uri(
|
|||
assert res.location.startswith(client.redirect_uris[1])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
authcode = backend.get(models.AuthorizationCode, code=code)
|
||||
assert authcode is not None
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
|
||||
|
@ -153,7 +153,7 @@ def test_redirect_uri(
|
|||
)
|
||||
|
||||
access_token = res.json["access_token"]
|
||||
token = models.Token.get(access_token=access_token)
|
||||
token = backend.get(models.Token, access_token=access_token)
|
||||
assert token.client == client
|
||||
assert token.subject == logged_user
|
||||
|
||||
|
@ -183,7 +183,7 @@ def test_preconsented_client(
|
|||
assert res.location.startswith(client.redirect_uris[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
authcode = backend.get(models.AuthorizationCode, code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
|
@ -202,7 +202,7 @@ def test_preconsented_client(
|
|||
)
|
||||
|
||||
access_token = res.json["access_token"]
|
||||
token = models.Token.get(access_token=access_token)
|
||||
token = backend.get(models.Token, access_token=access_token)
|
||||
assert token.client == client
|
||||
assert token.subject == logged_user
|
||||
|
||||
|
@ -257,7 +257,7 @@ def test_logout_login(testclient, logged_user, client, backend):
|
|||
assert res.location.startswith(client.redirect_uris[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
authcode = backend.get(models.AuthorizationCode, code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
|
@ -276,7 +276,7 @@ def test_logout_login(testclient, logged_user, client, backend):
|
|||
)
|
||||
|
||||
access_token = res.json["access_token"]
|
||||
token = models.Token.get(access_token=access_token)
|
||||
token = backend.get(models.Token, access_token=access_token)
|
||||
assert token.client == client
|
||||
assert token.subject == logged_user
|
||||
|
||||
|
@ -341,7 +341,7 @@ def test_code_challenge(testclient, logged_user, client, backend):
|
|||
assert res.location.startswith(client.redirect_uris[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
authcode = backend.get(models.AuthorizationCode, code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
|
@ -361,7 +361,7 @@ def test_code_challenge(testclient, logged_user, client, backend):
|
|||
)
|
||||
access_token = res.json["access_token"]
|
||||
|
||||
token = models.Token.get(access_token=access_token)
|
||||
token = backend.get(models.Token, access_token=access_token)
|
||||
assert token.client == client
|
||||
assert token.subject == logged_user
|
||||
|
||||
|
@ -398,7 +398,7 @@ def test_consent_already_given(testclient, logged_user, client, backend):
|
|||
assert res.location.startswith(client.redirect_uris[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
authcode = backend.get(models.AuthorizationCode, code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
|
@ -456,7 +456,7 @@ def test_when_consent_already_given_but_for_a_smaller_scope(
|
|||
assert res.location.startswith(client.redirect_uris[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
authcode = backend.get(models.AuthorizationCode, code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
|
@ -492,7 +492,7 @@ def test_when_consent_already_given_but_for_a_smaller_scope(
|
|||
assert res.location.startswith(client.redirect_uris[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
authcode = backend.get(models.AuthorizationCode, code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
|
@ -582,7 +582,7 @@ def test_request_scope_too_large(testclient, logged_user, keypair, client, backe
|
|||
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
authcode = backend.get(models.AuthorizationCode, code=code)
|
||||
assert set(authcode.scope) == {
|
||||
"openid",
|
||||
"profile",
|
||||
|
@ -607,7 +607,7 @@ def test_request_scope_too_large(testclient, logged_user, keypair, client, backe
|
|||
)
|
||||
|
||||
access_token = res.json["access_token"]
|
||||
token = models.Token.get(access_token=access_token)
|
||||
token = backend.get(models.Token, access_token=access_token)
|
||||
assert token.client == client
|
||||
assert token.subject == logged_user
|
||||
assert set(token.scope) == {
|
||||
|
@ -671,7 +671,7 @@ def test_code_expired(testclient, user, client):
|
|||
}
|
||||
|
||||
|
||||
def test_code_with_invalid_user(testclient, admin, client):
|
||||
def test_code_with_invalid_user(testclient, admin, client, backend):
|
||||
user = models.User(
|
||||
formatted_name="John Doe",
|
||||
family_name="Doe",
|
||||
|
@ -699,7 +699,7 @@ def test_code_with_invalid_user(testclient, admin, client):
|
|||
res = res.form.submit(name="answer", value="accept", status=302)
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
authcode = backend.get(models.AuthorizationCode, code=code)
|
||||
|
||||
user.delete()
|
||||
|
||||
|
@ -721,7 +721,9 @@ def test_code_with_invalid_user(testclient, admin, client):
|
|||
authcode.delete()
|
||||
|
||||
|
||||
def test_locked_account(testclient, logged_user, client, keypair, trusted_client):
|
||||
def test_locked_account(
|
||||
testclient, logged_user, client, keypair, trusted_client, backend
|
||||
):
|
||||
"""Users with a locked account should not be able to exchange code against
|
||||
tokens."""
|
||||
res = testclient.get(
|
||||
|
@ -743,7 +745,7 @@ def test_locked_account(testclient, logged_user, client, keypair, trusted_client
|
|||
assert res.location.startswith(client.redirect_uris[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
authcode = backend.get(models.AuthorizationCode, code=code)
|
||||
assert authcode is not None
|
||||
|
||||
res = testclient.post(
|
||||
|
|
|
@ -83,7 +83,7 @@ def test_client_list_search(testclient, logged_admin, client, trusted_client):
|
|||
res.mustcontain(no=client.client_name)
|
||||
|
||||
|
||||
def test_client_add(testclient, logged_admin):
|
||||
def test_client_add(testclient, logged_admin, backend):
|
||||
res = testclient.get("/admin/client/add")
|
||||
data = {
|
||||
"client_name": "foobar",
|
||||
|
@ -112,7 +112,7 @@ def test_client_add(testclient, logged_admin):
|
|||
res = res.follow(status=200)
|
||||
|
||||
client_id = res.forms["readonly"]["client_id"].value
|
||||
client = models.Client.get(client_id=client_id)
|
||||
client = backend.get(models.Client, client_id=client_id)
|
||||
|
||||
assert client.client_name == "foobar"
|
||||
assert client.contacts == ["foo@bar.com"]
|
||||
|
@ -214,7 +214,7 @@ def test_client_edit_missing_fields(testclient, client, logged_admin, trusted_cl
|
|||
assert client.client_name
|
||||
|
||||
|
||||
def test_client_delete(testclient, logged_admin):
|
||||
def test_client_delete(testclient, logged_admin, backend):
|
||||
client = models.Client(client_id="client_id")
|
||||
client.save()
|
||||
token = models.Token(
|
||||
|
@ -238,10 +238,10 @@ def test_client_delete(testclient, logged_admin):
|
|||
res = res.form.submit(name="action", value="delete")
|
||||
res = res.follow()
|
||||
|
||||
assert not models.Client.get()
|
||||
assert not models.Token.get()
|
||||
assert not models.AuthorizationCode.get()
|
||||
assert not models.Consent.get()
|
||||
assert not backend.get(models.Client)
|
||||
assert not backend.get(models.Token)
|
||||
assert not backend.get(models.AuthorizationCode)
|
||||
assert not backend.get(models.Consent)
|
||||
|
||||
|
||||
def test_client_delete_invalid_client(testclient, logged_admin, client):
|
||||
|
|
|
@ -134,7 +134,7 @@ def test_oidc_authorization_after_revokation(
|
|||
)
|
||||
|
||||
access_token = res.json["access_token"]
|
||||
token = models.Token.get(access_token=access_token)
|
||||
token = backend.get(models.Token, access_token=access_token)
|
||||
assert token.client == client
|
||||
assert token.subject == logged_user
|
||||
|
||||
|
@ -151,16 +151,16 @@ def test_preconsented_client_appears_in_consent_list(testclient, client, logged_
|
|||
res.mustcontain(client.client_name)
|
||||
|
||||
|
||||
def test_revoke_preconsented_client(testclient, client, logged_user, token):
|
||||
def test_revoke_preconsented_client(testclient, client, logged_user, token, backend):
|
||||
client.preconsent = True
|
||||
client.save()
|
||||
assert not models.Consent.get()
|
||||
assert not backend.get(models.Consent)
|
||||
assert not token.revoked
|
||||
|
||||
res = testclient.get(f"/consent/revoke-preconsent/{client.client_id}", status=302)
|
||||
assert ("success", "The access has been revoked") in res.flashes
|
||||
|
||||
consent = models.Consent.get()
|
||||
consent = backend.get(models.Consent)
|
||||
assert consent.client == client
|
||||
assert consent.subject == logged_user
|
||||
assert consent.scope == ["openid", "email", "profile", "groups", "address", "phone"]
|
||||
|
|
|
@ -33,7 +33,7 @@ def test_client_registration_with_authentication_static_token(
|
|||
headers = {"Authorization": "Bearer static-token"}
|
||||
|
||||
res = testclient.post_json("/oauth/register", payload, headers=headers, status=201)
|
||||
client = models.Client.get(client_id=res.json["client_id"])
|
||||
client = backend.get(models.Client, client_id=res.json["client_id"])
|
||||
|
||||
assert res.json == {
|
||||
"client_id": client.client_id,
|
||||
|
@ -154,7 +154,7 @@ def test_client_registration_with_software_statement(testclient, backend, keypai
|
|||
}
|
||||
res = testclient.post_json("/oauth/register", payload, status=201)
|
||||
|
||||
client = models.Client.get(client_id=res.json["client_id"])
|
||||
client = backend.get(models.Client, client_id=res.json["client_id"])
|
||||
assert res.json == {
|
||||
"client_id": client.client_id,
|
||||
"client_secret": client.client_secret,
|
||||
|
@ -205,7 +205,7 @@ def test_client_registration_without_authentication_ok(testclient, backend):
|
|||
|
||||
res = testclient.post_json("/oauth/register", payload, status=201)
|
||||
|
||||
client = models.Client.get(client_id=res.json["client_id"])
|
||||
client = backend.get(models.Client, client_id=res.json["client_id"])
|
||||
assert res.json == {
|
||||
"client_id": mock.ANY,
|
||||
"client_secret": mock.ANY,
|
||||
|
|
|
@ -95,7 +95,7 @@ def test_update(testclient, backend, client, user):
|
|||
res = testclient.put_json(
|
||||
f"/oauth/register/{client.client_id}", payload, headers=headers, status=200
|
||||
)
|
||||
client = models.Client.get(client_id=res.json["client_id"])
|
||||
client = backend.get(models.Client, client_id=res.json["client_id"])
|
||||
|
||||
assert res.json == {
|
||||
"client_id": client.client_id,
|
||||
|
@ -153,7 +153,7 @@ def test_delete(testclient, backend, user):
|
|||
testclient.delete(
|
||||
f"/oauth/register/{client.client_id}", headers=headers, status=204
|
||||
)
|
||||
assert not models.Client.get(client_id=client.client_id)
|
||||
assert not backend.get(models.Client, client_id=client.client_id)
|
||||
|
||||
|
||||
def test_invalid_client(testclient, backend, user):
|
||||
|
|
|
@ -33,7 +33,7 @@ def test_fieldlist_add(testclient, logged_admin, backend):
|
|||
res = res.follow(status=200)
|
||||
|
||||
client_id = res.forms["readonly"]["client_id"].value
|
||||
client = models.Client.get(client_id=client_id)
|
||||
client = backend.get(models.Client, client_id=client_id)
|
||||
|
||||
assert client.redirect_uris == [
|
||||
"https://foo.bar/callback",
|
||||
|
@ -68,7 +68,7 @@ def test_fieldlist_delete(testclient, logged_admin, backend):
|
|||
res = res.follow(status=200)
|
||||
|
||||
client_id = res.forms["readonly"]["client_id"].value
|
||||
client = models.Client.get(client_id=client_id)
|
||||
client = backend.get(models.Client, client_id=client_id)
|
||||
|
||||
assert client.redirect_uris == [
|
||||
"https://foo.bar/callback1",
|
||||
|
@ -128,7 +128,7 @@ def test_fieldlist_duplicate_value(testclient, logged_admin, client):
|
|||
res.mustcontain("This value is a duplicate")
|
||||
|
||||
|
||||
def test_fieldlist_empty_value(testclient, logged_admin):
|
||||
def test_fieldlist_empty_value(testclient, logged_admin, backend):
|
||||
res = testclient.get("/admin/client/add")
|
||||
data = {
|
||||
"client_name": "foobar",
|
||||
|
@ -145,7 +145,7 @@ def test_fieldlist_empty_value(testclient, logged_admin):
|
|||
status=200, name="fieldlist_add", value="post_logout_redirect_uris-0"
|
||||
)
|
||||
res.form.submit(status=302, name="action", value="edit")
|
||||
client = models.Client.get()
|
||||
client = backend.get(models.Client)
|
||||
client.delete()
|
||||
|
||||
|
||||
|
|
|
@ -32,11 +32,11 @@ def test_oauth_hybrid(testclient, backend, user, client):
|
|||
params = parse_qs(urlsplit(res.location).fragment)
|
||||
|
||||
code = params["code"][0]
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
authcode = backend.get(models.AuthorizationCode, code=code)
|
||||
assert authcode is not None
|
||||
|
||||
access_token = params["access_token"][0]
|
||||
token = models.Token.get(access_token=access_token)
|
||||
token = backend.get(models.Token, access_token=access_token)
|
||||
assert token is not None
|
||||
|
||||
res = testclient.get(
|
||||
|
@ -65,11 +65,11 @@ def test_oidc_hybrid(testclient, backend, logged_user, client, keypair, trusted_
|
|||
params = parse_qs(urlsplit(res.location).fragment)
|
||||
|
||||
code = params["code"][0]
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
authcode = backend.get(models.AuthorizationCode, code=code)
|
||||
assert authcode is not None
|
||||
|
||||
access_token = params["access_token"][0]
|
||||
token = models.Token.get(access_token=access_token)
|
||||
token = backend.get(models.Token, access_token=access_token)
|
||||
assert token is not None
|
||||
|
||||
id_token = params["id_token"][0]
|
||||
|
|
|
@ -6,7 +6,7 @@ from authlib.jose import jwt
|
|||
from canaille.app import models
|
||||
|
||||
|
||||
def test_oauth_implicit(testclient, user, client):
|
||||
def test_oauth_implicit(testclient, user, client, backend):
|
||||
client.grant_types = ["token"]
|
||||
client.token_endpoint_auth_method = "none"
|
||||
|
||||
|
@ -37,7 +37,7 @@ def test_oauth_implicit(testclient, user, client):
|
|||
params = parse_qs(urlsplit(res.location).fragment)
|
||||
|
||||
access_token = params["access_token"][0]
|
||||
token = models.Token.get(access_token=access_token)
|
||||
token = backend.get(models.Token, access_token=access_token)
|
||||
assert token is not None
|
||||
|
||||
res = testclient.get(
|
||||
|
@ -51,7 +51,7 @@ def test_oauth_implicit(testclient, user, client):
|
|||
client.save()
|
||||
|
||||
|
||||
def test_oidc_implicit(testclient, keypair, user, client, trusted_client):
|
||||
def test_oidc_implicit(testclient, keypair, user, client, trusted_client, backend):
|
||||
client.grant_types = ["token id_token"]
|
||||
client.token_endpoint_auth_method = "none"
|
||||
|
||||
|
@ -82,7 +82,7 @@ def test_oidc_implicit(testclient, keypair, user, client, trusted_client):
|
|||
params = parse_qs(urlsplit(res.location).fragment)
|
||||
|
||||
access_token = params["access_token"][0]
|
||||
token = models.Token.get(access_token=access_token)
|
||||
token = backend.get(models.Token, access_token=access_token)
|
||||
assert token is not None
|
||||
|
||||
id_token = params["id_token"][0]
|
||||
|
@ -105,7 +105,7 @@ def test_oidc_implicit(testclient, keypair, user, client, trusted_client):
|
|||
|
||||
|
||||
def test_oidc_implicit_with_group(
|
||||
testclient, keypair, user, client, foo_group, trusted_client
|
||||
testclient, keypair, user, client, foo_group, trusted_client, backend
|
||||
):
|
||||
client.grant_types = ["token id_token"]
|
||||
client.token_endpoint_auth_method = "none"
|
||||
|
@ -137,7 +137,7 @@ def test_oidc_implicit_with_group(
|
|||
params = parse_qs(urlsplit(res.location).fragment)
|
||||
|
||||
access_token = params["access_token"][0]
|
||||
token = models.Token.get(access_token=access_token)
|
||||
token = backend.get(models.Token, access_token=access_token)
|
||||
assert token is not None
|
||||
|
||||
id_token = params["id_token"][0]
|
||||
|
|
|
@ -3,7 +3,7 @@ from canaille.app import models
|
|||
from . import client_credentials
|
||||
|
||||
|
||||
def test_password_flow_basic(testclient, user, client):
|
||||
def test_password_flow_basic(testclient, user, client, backend):
|
||||
res = testclient.post(
|
||||
"/oauth/token",
|
||||
params=dict(
|
||||
|
@ -20,7 +20,7 @@ def test_password_flow_basic(testclient, user, client):
|
|||
assert res.json["token_type"] == "Bearer"
|
||||
access_token = res.json["access_token"]
|
||||
|
||||
token = models.Token.get(access_token=access_token)
|
||||
token = backend.get(models.Token, access_token=access_token)
|
||||
assert token is not None
|
||||
|
||||
res = testclient.get(
|
||||
|
@ -31,7 +31,7 @@ def test_password_flow_basic(testclient, user, client):
|
|||
assert res.json["name"] == "John (johnny) Doe"
|
||||
|
||||
|
||||
def test_password_flow_post(testclient, user, client):
|
||||
def test_password_flow_post(testclient, user, client, backend):
|
||||
client.token_endpoint_auth_method = "client_secret_post"
|
||||
client.save()
|
||||
|
||||
|
@ -52,7 +52,7 @@ def test_password_flow_post(testclient, user, client):
|
|||
assert res.json["token_type"] == "Bearer"
|
||||
access_token = res.json["access_token"]
|
||||
|
||||
token = models.Token.get(access_token=access_token)
|
||||
token = backend.get(models.Token, access_token=access_token)
|
||||
assert token is not None
|
||||
|
||||
res = testclient.get(
|
||||
|
|
|
@ -24,7 +24,7 @@ def test_refresh_token(testclient, logged_user, client, backend):
|
|||
assert res.location.startswith(client.redirect_uris[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
authcode = backend.get(models.AuthorizationCode, code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = backend.query(models.Consent, client=client, subject=logged_user)
|
||||
|
@ -42,7 +42,7 @@ def test_refresh_token(testclient, logged_user, client, backend):
|
|||
status=200,
|
||||
)
|
||||
access_token = res.json["access_token"]
|
||||
old_token = models.Token.get(access_token=access_token)
|
||||
old_token = backend.get(models.Token, access_token=access_token)
|
||||
assert old_token is not None
|
||||
assert not old_token.revokation_date
|
||||
|
||||
|
@ -56,7 +56,7 @@ def test_refresh_token(testclient, logged_user, client, backend):
|
|||
status=200,
|
||||
)
|
||||
access_token = res.json["access_token"]
|
||||
new_token = models.Token.get(access_token=access_token)
|
||||
new_token = backend.get(models.Token, access_token=access_token)
|
||||
assert new_token is not None
|
||||
assert old_token.access_token != new_token.access_token
|
||||
|
||||
|
@ -74,7 +74,7 @@ def test_refresh_token(testclient, logged_user, client, backend):
|
|||
consent.delete()
|
||||
|
||||
|
||||
def test_refresh_token_with_invalid_user(testclient, client):
|
||||
def test_refresh_token_with_invalid_user(testclient, client, backend):
|
||||
user = models.User(
|
||||
formatted_name="John Doe",
|
||||
family_name="Doe",
|
||||
|
@ -103,7 +103,7 @@ def test_refresh_token_with_invalid_user(testclient, client):
|
|||
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
models.AuthorizationCode.get(code=code)
|
||||
backend.get(models.AuthorizationCode, code=code)
|
||||
|
||||
res = testclient.post(
|
||||
"/oauth/token",
|
||||
|
@ -134,7 +134,7 @@ def test_refresh_token_with_invalid_user(testclient, client):
|
|||
"error": "invalid_request",
|
||||
"error_description": 'There is no "user" for this token.',
|
||||
}
|
||||
models.Token.get(access_token=access_token).delete()
|
||||
backend.get(models.Token, access_token=access_token).delete()
|
||||
|
||||
|
||||
def test_cannot_refresh_token_for_locked_users(testclient, logged_user, client):
|
||||
|
|
|
@ -26,7 +26,7 @@ def test_token_default_expiration_date(
|
|||
res = res.form.submit(name="answer", value="accept", status=302)
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
authcode = backend.get(models.AuthorizationCode, code=code)
|
||||
assert authcode.lifetime == 84400
|
||||
|
||||
res = testclient.post(
|
||||
|
@ -44,7 +44,7 @@ def test_token_default_expiration_date(
|
|||
assert res.json["expires_in"] == 864000
|
||||
|
||||
access_token = res.json["access_token"]
|
||||
token = models.Token.get(access_token=access_token)
|
||||
token = backend.get(models.Token, access_token=access_token)
|
||||
assert token.lifetime == 864000
|
||||
|
||||
claims = jwt.decode(access_token, keypair[1])
|
||||
|
@ -86,7 +86,7 @@ def test_token_custom_expiration_date(
|
|||
res = res.form.submit(name="answer", value="accept", status=302)
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
authcode = backend.get(models.AuthorizationCode, code=code)
|
||||
assert authcode.lifetime == 84400
|
||||
|
||||
res = testclient.post(
|
||||
|
@ -104,7 +104,7 @@ def test_token_custom_expiration_date(
|
|||
assert res.json["expires_in"] == 1000
|
||||
|
||||
access_token = res.json["access_token"]
|
||||
token = models.Token.get(access_token=access_token)
|
||||
token = backend.get(models.Token, access_token=access_token)
|
||||
assert token.lifetime == 1000
|
||||
|
||||
claims = jwt.decode(access_token, keypair[1])
|
||||
|
|
|
@ -58,7 +58,7 @@ def test_token_invalid(testclient, client):
|
|||
assert {"active": False} == res.json
|
||||
|
||||
|
||||
def test_full_flow(testclient, logged_user, client, user, trusted_client):
|
||||
def test_full_flow(testclient, logged_user, client, user, trusted_client, backend):
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
params=dict(
|
||||
|
@ -75,7 +75,7 @@ def test_full_flow(testclient, logged_user, client, user, trusted_client):
|
|||
assert res.location.startswith(client.redirect_uris[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
authcode = backend.get(models.AuthorizationCode, code=code)
|
||||
assert authcode is not None
|
||||
|
||||
res = testclient.post(
|
||||
|
@ -91,7 +91,7 @@ def test_full_flow(testclient, logged_user, client, user, trusted_client):
|
|||
)
|
||||
access_token = res.json["access_token"]
|
||||
|
||||
token = models.Token.get(access_token=access_token)
|
||||
token = backend.get(models.Token, access_token=access_token)
|
||||
assert token.client == client
|
||||
assert token.subject == logged_user
|
||||
|
||||
|
|
Loading…
Reference in a new issue