refactor: move BackendModel.save to Backend.save

This commit is contained in:
Éloi Rivard 2024-04-14 20:31:43 +02:00
parent 44573713ed
commit 09588e0f48
No known key found for this signature in database
GPG key ID: 7EDA204EA57DD184
49 changed files with 327 additions and 306 deletions

View file

@ -85,6 +85,10 @@ class BaseBackend:
only one element or :py:data:`None` if no item is matching."""
raise NotImplementedError()
def save(self, instance):
"""Validate the current modifications in the database."""
raise NotImplementedError()
def check_user_password(self, user, password: str) -> bool:
"""Check if the password matches the user password in the database."""
raise NotImplementedError()

View file

@ -9,6 +9,7 @@ from flask import current_app
from ldap.controls import DecodeControlTuples
from ldap.controls.ppolicy import PasswordPolicyControl
from ldap.controls.ppolicy import PasswordPolicyError
from ldap.controls.readentry import PostReadControl
from canaille.app import models
from canaille.app.configuration import ConfigurationException
@ -128,7 +129,7 @@ class Backend(BaseBackend):
emails=f"canaille_{uuid.uuid4()}@mydomain.tld",
password="correct horse battery staple",
)
user.save()
BaseBackend.instance.save(user)
user.delete()
except ldap.INSUFFICIENT_ACCESS as exc:
@ -147,13 +148,13 @@ class Backend(BaseBackend):
emails=f"canaille_{uuid.uuid4()}@mydomain.tld",
password="correct horse battery staple",
)
user.save()
BaseBackend.instance.save(user)
group = models.Group(
display_name=f"canaille_{uuid.uuid4()}",
members=[user],
)
group.save()
BaseBackend.instance.save(group)
group.delete()
except ldap.INSUFFICIENT_ACCESS as exc:
@ -324,6 +325,69 @@ class Backend(BaseBackend):
return None
def save(self, instance):
# run the instance save callback if existing
save_callback = instance.save() if hasattr(instance, "save") else iter([])
next(save_callback, None)
current_object_classes = instance.get_ldap_attribute("objectClass") or []
instance.set_ldap_attribute(
"objectClass",
list(set(instance.ldap_object_class) | set(current_object_classes)),
)
# PostReadControl allows to read the updated object attributes on creation/edition
attributes = ["objectClass"] + [
instance.python_attribute_to_ldap(name) for name in instance.attributes
]
read_post_control = PostReadControl(criticality=True, attrList=attributes)
# Object already exists in the LDAP database
if instance.exists:
deletions = [
name
for name, value in instance.changes.items()
if (
value is None
or value == []
or (isinstance(value, list) and len(value) == 1 and not value[0])
)
and name in instance.state
]
changes = {
name: value
for name, value in instance.changes.items()
if name not in deletions and instance.state.get(name) != value
}
formatted_changes = python_attrs_to_ldap(changes, null_allowed=False)
modlist = [(ldap.MOD_DELETE, name, None) for name in deletions] + [
(ldap.MOD_REPLACE, name, values)
for name, values in formatted_changes.items()
]
_, _, _, [result] = self.connection.modify_ext_s(
instance.dn, modlist, serverctrls=[read_post_control]
)
# Object does not exist yet in the LDAP database
else:
changes = {
name: value
for name, value in {**instance.state, **instance.changes}.items()
if value and value[0]
}
formatted_changes = python_attrs_to_ldap(changes, null_allowed=False)
modlist = [(name, values) for name, values in formatted_changes.items()]
_, _, _, [result] = self.connection.add_ext_s(
instance.dn, modlist, serverctrls=[read_post_control]
)
instance.exists = True
instance.state = {**result.entry, **instance.changes}
instance.changes = {}
# run the instance save callback again if existing
next(save_callback, None)
def setup_ldap_models(config):
from canaille.app import models

View file

@ -2,7 +2,6 @@ import typing
import ldap.dn
import ldap.filter
from ldap.controls.readentry import PostReadControl
from canaille.backends.models import BackendModel
@ -264,64 +263,6 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
self.changes = {}
self.state = result[0][1]
def save(self):
conn = Backend.instance.connection
current_object_classes = self.get_ldap_attribute("objectClass") or []
self.set_ldap_attribute(
"objectClass",
list(set(self.ldap_object_class) | set(current_object_classes)),
)
# PostReadControl allows to read the updated object attributes on creation/edition
attributes = ["objectClass"] + [
self.python_attribute_to_ldap(name) for name in self.attributes
]
read_post_control = PostReadControl(criticality=True, attrList=attributes)
# Object already exists in the LDAP database
if self.exists:
deletions = [
name
for name, value in self.changes.items()
if (
value is None
or value == []
or (isinstance(value, list) and len(value) == 1 and not value[0])
)
and name in self.state
]
changes = {
name: value
for name, value in self.changes.items()
if name not in deletions and self.state.get(name) != value
}
formatted_changes = python_attrs_to_ldap(changes, null_allowed=False)
modlist = [(ldap.MOD_DELETE, name, None) for name in deletions] + [
(ldap.MOD_REPLACE, name, values)
for name, values in formatted_changes.items()
]
_, _, _, [result] = conn.modify_ext_s(
self.dn, modlist, serverctrls=[read_post_control]
)
# Object does not exist yet in the LDAP database
else:
changes = {
name: value
for name, value in {**self.state, **self.changes}.items()
if value and value[0]
}
formatted_changes = python_attrs_to_ldap(changes, null_allowed=False)
modlist = [(name, values) for name, values in formatted_changes.items()]
_, _, _, [result] = conn.add_ext_s(
self.dn, modlist, serverctrls=[read_post_control]
)
self.exists = True
self.state = {**result.entry, **self.changes}
self.changes = {}
def delete(self):
conn = Backend.instance.connection
try:

View file

@ -43,10 +43,10 @@ class User(canaille.core.models.User, LDAPObject):
return super().match_filter(filter)
def save(self, *args, **kwargs):
def save(self):
group_attr = self.python_attribute_to_ldap("groups")
if group_attr not in self.changes:
return super().save(*args, **kwargs)
return
# The LDAP attribute memberOf cannot directly be edited,
# so this is needed to update the Group.member attribute
@ -60,11 +60,11 @@ class User(canaille.core.models.User, LDAPObject):
to_del = set(old_groups) - set(new_groups)
del self.changes[group_attr]
super().save(*args, **kwargs)
yield
for group in to_add:
group.members = group.members + [self]
group.save()
Backend.instance.save(group)
for group in to_del:
# LDAP groups cannot be empty because groupOfNames.member
@ -73,7 +73,7 @@ class User(canaille.core.models.User, LDAPObject):
# TODO: properly manage the situation where one wants to
# remove the last member of a group
group.members = [member for member in group.members if member != self]
group.save()
Backend.instance.save(group)
self.state[group_attr] = new_groups

View file

@ -1,3 +1,6 @@
import datetime
import uuid
from canaille.backends import BaseBackend
@ -39,7 +42,7 @@ class Backend(BaseBackend):
def set_user_password(self, user, password):
user.password = password
user.save()
self.save(user)
def query(self, model, **kwargs):
# if there is no filter, return all models
@ -91,3 +94,17 @@ class Backend(BaseBackend):
results = self.query(model, **kwargs)
return results[0] if results else None
def save(self, instance):
if not instance.id:
instance.id = str(uuid.uuid4())
instance.last_modified = datetime.datetime.now(datetime.timezone.utc).replace(
microsecond=0
)
if not instance.created:
instance.created = instance.last_modified
instance.index_delete()
instance.index_save()
instance._cache = {}

View file

@ -1,7 +1,5 @@
import copy
import datetime
import typing
import uuid
import canaille.core.models
import canaille.oidc.models
@ -67,20 +65,6 @@ class MemoryModel(BackendModel):
return value
def save(self):
if not self.id:
self.id = str(uuid.uuid4())
self.last_modified = datetime.datetime.now(datetime.timezone.utc).replace(
microsecond=0
)
if not self.created:
self.created = self.last_modified
self.index_delete()
self.index_save()
self._cache = {}
def delete(self):
self.index_delete()

View file

@ -88,10 +88,6 @@ class BackendModel:
implemented for every model and for every backend.
"""
def save(self):
"""Validate the current modifications in the database."""
raise NotImplementedError()
def delete(self):
"""Remove the current instance from the database."""
raise NotImplementedError()

View file

@ -1,3 +1,5 @@
import datetime
from sqlalchemy import create_engine
from sqlalchemy import or_
from sqlalchemy import select
@ -66,7 +68,7 @@ class Backend(BaseBackend):
def set_user_password(self, user, password):
user.password = password
user.save()
self.save(user)
def query(self, model, **kwargs):
filter = [
@ -106,3 +108,13 @@ class Backend(BaseBackend):
return Backend.instance.db_session.execute(
select(model).filter(*filter)
).scalar_one_or_none()
def save(self, instance):
instance.last_modified = datetime.datetime.now(datetime.timezone.utc).replace(
microsecond=0
)
if not instance.created:
instance.created = instance.last_modified
Backend.instance.db_session.add(instance)
Backend.instance.db_session.commit()

View file

@ -47,16 +47,6 @@ class SqlAlchemyModel(BackendModel):
return getattr(cls, name) == value
def save(self):
self.last_modified = datetime.datetime.now(datetime.timezone.utc).replace(
microsecond=0
)
if not self.created:
self.created = self.last_modified
Backend.instance.db_session.add(self)
Backend.instance.db_session.commit()
def delete(self):
Backend.instance.db_session.delete(self)
Backend.instance.db_session.commit()

View file

@ -404,7 +404,7 @@ def email_confirmation(data, hash):
return redirect(url_for("core.account.index"))
user.emails = user.emails + [confirmation_obj.email]
user.save()
BaseBackend.instance.save(user)
flash(_("Your email address have been confirmed."), "success")
return redirect(url_for("core.account.index"))
@ -460,11 +460,11 @@ def profile_create(current_app, form):
given_name = user.given_name if user.given_name else ""
family_name = user.family_name if user.family_name else ""
user.formatted_name = f"{given_name} {family_name}".strip()
user.save()
BaseBackend.instance.save(user)
if form["password1"].data:
BaseBackend.instance.set_user_password(user, form["password1"].data)
user.save()
BaseBackend.instance.save(user)
return user
@ -536,7 +536,7 @@ def profile_edition_main_form_validation(user, edited_user, profile_form):
if profile_form["preferred_language"].data == "auto":
edited_user.preferred_language = None
edited_user.save()
BaseBackend.instance.save(edited_user)
g.user.reload()
@ -574,7 +574,7 @@ def profile_edition_remove_email(user, edited_user, email):
return False
edited_user.emails = [m for m in edited_user.emails if m != email]
edited_user.save()
BaseBackend.instance.save(edited_user)
return True
@ -730,7 +730,7 @@ def profile_settings(user, edited_user):
):
flash(_("The account has been locked"), "success")
edited_user.lock_date = datetime.datetime.now(datetime.timezone.utc)
edited_user.save()
BaseBackend.instance.save(edited_user)
return profile_settings_edit(user, edited_user)
@ -741,7 +741,7 @@ def profile_settings(user, edited_user):
):
flash(_("The account has been unlocked"), "success")
edited_user.lock_date = None
edited_user.save()
BaseBackend.instance.save(edited_user)
return profile_settings_edit(user, edited_user)
@ -791,7 +791,7 @@ def profile_settings_edit(editor, edited_user):
edited_user, form["password1"].data
)
edited_user.save()
BaseBackend.instance.save(edited_user)
flash(_("Profile updated successfully."), "success")
return redirect(
url_for("core.account.profile_settings", edited_user=edited_user)

View file

@ -11,6 +11,7 @@ from canaille.app.flask import render_htmx_template
from canaille.app.forms import TableForm
from canaille.app.i18n import gettext as _
from canaille.app.themes import render_template
from canaille.backends import BaseBackend
from .forms import CreateGroupForm
from .forms import DeleteGroupMemberForm
@ -42,7 +43,7 @@ def create_group(user):
group.members = [user]
group.display_name = form.display_name.data
group.description = form.description.data
group.save()
BaseBackend.instance.save(group)
flash(
_(
"The group %(group)s has been sucessfully created",
@ -102,7 +103,7 @@ def edit_group(group):
):
if form.validate():
group.description = form.description.data
group.save()
BaseBackend.instance.save(group)
flash(
_(
"The group %(group)s has been sucessfully edited.",
@ -151,7 +152,7 @@ def delete_member(group):
group.members = [
member for member in group.members if member != form.member.data
]
group.save()
BaseBackend.instance.save(group)
return edit_group(group)

View file

@ -40,7 +40,7 @@ def fake_users(nb=1):
password=fake.password(),
preferred_language=fake._locales[0],
)
user.save()
BaseBackend.instance.save(user)
users.append(user)
except Exception: # pragma: no cover
pass
@ -59,7 +59,7 @@ def fake_groups(nb=1, nb_users_max=1):
)
nb_users = random.randrange(1, nb_users_max + 1)
group.members = list({random.choice(users) for _ in range(nb_users)})
group.save()
BaseBackend.instance.save(group)
groups.append(group)
except Exception: # pragma: no cover
pass

View file

@ -14,6 +14,7 @@ from canaille.app.flask import render_htmx_template
from canaille.app.forms import TableForm
from canaille.app.i18n import gettext as _
from canaille.app.themes import render_template
from canaille.backends import BaseBackend
from .forms import ClientAddForm
@ -73,9 +74,9 @@ def add(user):
if form["token_endpoint_auth_method"].data == "none"
else gen_salt(48),
)
client.save()
BaseBackend.instance.save(client)
client.audience = [client]
client.save()
BaseBackend.instance.save(client)
flash(
_("The client has been created."),
"success",
@ -137,7 +138,7 @@ def client_edit(client):
audience=form["audience"].data,
preconsent=form["preconsent"].data,
)
client.save()
BaseBackend.instance.save(client)
flash(
_("The client has been edited."),
"success",

View file

@ -95,7 +95,7 @@ def restore(user, consent):
consent.restore()
if not consent.issue_date:
consent.issue_date = datetime.datetime.now(datetime.timezone.utc)
consent.save()
BaseBackend.instance.save(consent)
flash(_("The access has been restored"), "success")
return redirect(url_for("oidc.consents.consents"))
@ -119,7 +119,7 @@ def revoke_preconsent(user, client):
scope=client.scope,
)
consent.revoke()
consent.save()
BaseBackend.instance.save(consent)
flash(_("The access has been revoked"), "success")
return redirect(url_for("oidc.consents.consents"))

View file

@ -177,7 +177,7 @@ def authorize_consent(client, user):
scope=allowed_scopes,
issue_date=datetime.datetime.now(datetime.timezone.utc),
)
consent.save()
BaseBackend.instance.save(consent)
response = authorization.create_authorization_response(grant_user=grant_user)
current_app.logger.debug("authorization endpoint response: %s", response.location)

View file

@ -11,6 +11,7 @@ from canaille.app.flask import render_htmx_template
from canaille.app.forms import TableForm
from canaille.app.i18n import gettext as _
from canaille.app.themes import render_template
from canaille.backends import BaseBackend
from .forms import TokenRevokationForm
@ -40,7 +41,7 @@ def view(user, token):
elif request.form.get("action") == "revoke":
token.revokation_date = datetime.datetime.now(datetime.timezone.utc)
token.save()
BaseBackend.instance.save(token)
flash(_("The token has successfully been revoked."), "success")
else:

View file

@ -184,7 +184,7 @@ class Consent(BaseConsent):
def revoke(self):
self.revokation_date = datetime.datetime.now(datetime.timezone.utc)
self.save()
BaseBackend.instance.save(self)
tokens = BaseBackend.instance.query(
models.Token,
@ -194,8 +194,8 @@ class Consent(BaseConsent):
tokens = [token for token in tokens if not token.revoked]
for t in tokens:
t.revokation_date = self.revokation_date
t.save()
BaseBackend.instance.save(t)
def restore(self):
self.revokation_date = None
self.save()
BaseBackend.instance.save(self)

View file

@ -228,7 +228,7 @@ def save_authorization_code(code, request):
challenge=request.data.get("code_challenge"),
challenge_method=request.data.get("code_challenge_method"),
)
code.save()
BaseBackend.instance.save(code)
return code.code
@ -297,7 +297,7 @@ class RefreshTokenGrant(_RefreshTokenGrant):
def revoke_old_credential(self, credential):
credential.revokation_date = datetime.datetime.now(datetime.timezone.utc)
credential.save()
BaseBackend.instance.save(credential)
class OpenIDImplicitGrant(_OpenIDImplicitGrant):
@ -351,7 +351,7 @@ def save_token(token, request):
subject=request.user,
audience=request.client.audience,
)
t.save()
BaseBackend.instance.save(t)
class BearerTokenValidator(_BearerTokenValidator):
@ -382,7 +382,7 @@ class RevocationEndpoint(_RevocationEndpoint):
def revoke_token(self, token, request):
token.revokation_date = datetime.datetime.now(datetime.timezone.utc)
token.save()
BaseBackend.instance.save(token)
class IntrospectionEndpoint(_IntrospectionEndpoint):
@ -463,9 +463,9 @@ class ClientRegistrationEndpoint(ClientManagementMixin, _ClientRegistrationEndpo
post_logout_redirect_uris=request.data.get("post_logout_redirect_uris"),
**self.client_convert_data(**client_info, **client_metadata),
)
client.save()
BaseBackend.instance.save(client)
client.audience = [client]
client.save()
BaseBackend.instance.save(client)
return client
@ -485,7 +485,7 @@ class ClientConfigurationEndpoint(ClientManagementMixin, _ClientConfigurationEnd
def update_client(self, client, client_metadata, request):
client.update(**self.client_convert_data(**client_metadata))
client.save()
BaseBackend.instance.save(client)
return client
def generate_client_registration_info(self, client, request):

View file

@ -206,13 +206,13 @@ def test_fieldlist_add_readonly(testclient, logged_user):
testclient.post("/profile/user", data, status=403)
def test_fieldlist_remove_readonly(testclient, logged_user):
def test_fieldlist_remove_readonly(testclient, logged_user, backend):
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["WRITE"].remove("phone_numbers")
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["READ"].append("phone_numbers")
logged_user.reload()
logged_user.phone_numbers = ["555-555-000", "555-555-111"]
logged_user.save()
backend.save(logged_user)
res = testclient.get("/profile/user")
form = res.forms["baseform"]

View file

@ -1,9 +1,9 @@
from flask_babel import refresh
def test_preferred_language(testclient, logged_user):
def test_preferred_language(testclient, logged_user, backend):
logged_user.preferred_language = None
logged_user.save()
backend.save(logged_user)
res = testclient.get("/profile/user", status=200)
form = res.forms["baseform"]
@ -49,9 +49,9 @@ def test_preferred_language(testclient, logged_user):
res.mustcontain(no="Mon profil")
def test_form_translations(testclient, logged_user):
def test_form_translations(testclient, logged_user, backend):
logged_user.preferred_language = "fr"
logged_user.save()
backend.save(logged_user)
res = testclient.get("/profile/user", status=200)
form = res.forms["baseform"]
@ -62,9 +62,9 @@ def test_form_translations(testclient, logged_user):
res.mustcontain("Nest pas un numéro de téléphone valide")
def test_language_config(testclient, logged_user):
def test_language_config(testclient, logged_user, backend):
logged_user.preferred_language = None
logged_user.save()
backend.save(logged_user)
res = testclient.get("/profile/user", status=200)
assert res.pyquery("html")[0].attrib["lang"] == "en"

View file

@ -7,7 +7,7 @@ def test_model_references_set_unsaved_object(
"""LDAP groups can be inconsistent by containing members which doesn't
exist."""
group = models.Group(members=[user], display_name="foo")
group.save()
backend.save(group)
user.reload()
non_existent_user = models.User(
@ -16,7 +16,7 @@ def test_model_references_set_unsaved_object(
group.members = group.members + [non_existent_user]
assert group.members == [user, non_existent_user]
group.save()
backend.save(group)
assert group.members == [user, non_existent_user]
group.reload()

View file

@ -5,7 +5,7 @@ from canaille.backends.ldap.ldapobject import LDAPObject
def test_guess_object_from_dn(backend, testclient, foo_group):
foo_group.members = [foo_group]
foo_group.save()
backend.save(foo_group)
dn = foo_group.dn
g = backend.get(LDAPObject, dn)
assert isinstance(g, models.Group)
@ -18,7 +18,7 @@ def test_object_class_update(backend, testclient):
setup_ldap_models(testclient.app.config)
user1 = models.User(cn="foo1", sn="bar1", user_name="baz1")
user1.save()
backend.save(user1)
assert set(user1.get_ldap_attribute("objectClass")) == {"inetOrgPerson"}
assert set(
@ -32,7 +32,7 @@ def test_object_class_update(backend, testclient):
setup_ldap_models(testclient.app.config)
user2 = models.User(cn="foo2", sn="bar2", user_name="baz2")
user2.save()
backend.save(user2)
assert set(user2.get_ldap_attribute("objectClass")) == {
"inetOrgPerson",
@ -48,7 +48,7 @@ def test_object_class_update(backend, testclient):
user1 = backend.get(models.User, id=user1.id)
assert user1.get_ldap_attribute("objectClass") == ["inetOrgPerson"]
user1.save()
backend.save(user1)
assert set(user1.get_ldap_attribute("objectClass")) == {
"inetOrgPerson",
"extensibleObject",
@ -72,7 +72,7 @@ def test_keep_old_object_classes(backend, testclient, slapd_server):
attributes.
"""
user = models.User(cn="foo", sn="bar", user_name="baz")
user.save()
backend.save(user)
ldif = f"""dn: {user.dn}
changetype: modify
@ -95,6 +95,6 @@ homeDirectory: /home/foobar
user.reload()
# saving an object should not raise a ldap.OBJECT_CLASS_VIOLATION exception
user.save()
backend.save(user)
user.delete()

View file

@ -24,7 +24,7 @@ def test_object_creation(app, backend):
emails=["john@doe.com"],
)
assert not user.exists
user.save()
backend.save(user)
assert user.exists
user = backend.get(models.User, id=user.id)
@ -45,7 +45,7 @@ def test_dn_when_leading_space_in_id_attribute(testclient, backend):
user_name=" user",
emails=["john@doe.com"],
)
user.save()
backend.save(user)
dn = user.dn
assert dn == "uid=user,ou=users,dc=mydomain,dc=tld"
@ -66,7 +66,7 @@ def test_special_chars_in_rdn(testclient, backend):
user_name="#user", # special char
emails=["john@doe.com"],
)
user.save()
backend.save(user)
dn = user.dn
assert ldap.dn.is_dn(dn)

View file

@ -12,13 +12,13 @@ def test_model_comparison(testclient, backend):
family_name="foo",
formatted_name="foo",
)
foo1.save()
backend.save(foo1)
bar = models.User(
user_name="bar",
family_name="bar",
formatted_name="bar",
)
bar.save()
backend.save(bar)
foo2 = backend.get(models.User, id=foo1.id)
assert foo1 == foo2
@ -41,7 +41,7 @@ def test_model_lifecycle(testclient, backend):
assert not backend.query(models.User, id="invalid")
assert not backend.get(models.User, id=user.id)
user.save()
backend.save(user)
assert backend.query(models.User) == [user]
assert backend.query(models.User, id=user.id) == [user]
@ -72,7 +72,7 @@ def test_model_attribute_edition(testclient, backend):
display_name="display_name",
emails=["email1@user.com", "email2@user.com"],
)
user.save()
backend.save(user)
assert user.user_name == "user_name"
assert user.family_name == "family_name"
@ -85,7 +85,7 @@ def test_model_attribute_edition(testclient, backend):
user.family_name = "new_family_name"
user.emails = ["email1@user.com"]
user.save()
backend.save(user)
assert user.family_name == "new_family_name"
assert user.emails == ["email1@user.com"]
@ -97,7 +97,7 @@ def test_model_attribute_edition(testclient, backend):
user.display_name = ""
assert not user.display_name
user.save()
backend.save(user)
assert not user.display_name
user.delete()
@ -110,7 +110,7 @@ def test_model_indexation(testclient, backend):
formatted_name="formatted_name",
emails=["email1@user.com", "email2@user.com"],
)
user.save()
backend.save(user)
assert backend.get(models.User, family_name="family_name") == user
assert not backend.get(models.User, family_name="new_family_name")
@ -125,7 +125,7 @@ def test_model_indexation(testclient, backend):
assert backend.get(models.User, emails=["email1@user.com"]) != user
assert not backend.get(models.User, emails=["email3@user.com"])
user.save()
backend.save(user)
assert not backend.get(models.User, family_name="family_name")
assert backend.get(models.User, family_name="new_family_name") == user
@ -177,14 +177,14 @@ def test_model_references(testclient, user, foo_group, admin, bar_group, backend
assert user not in bar_group.members
assert bar_group not in user.groups
user.groups = user.groups + [bar_group]
user.save()
backend.save(user)
bar_group.reload()
assert user in bar_group.members
assert bar_group in user.groups
bar_group.members = [admin]
bar_group.save()
backend.save(bar_group)
user.reload()
assert user not in bar_group.members
@ -201,7 +201,7 @@ def test_model_creation_edition_datetime(testclient, backend):
family_name="foo",
formatted_name="foo",
)
user.save()
backend.save(user)
assert user.created == datetime.datetime(
2020, 1, 1, 2, tzinfo=datetime.timezone.utc
)
@ -211,7 +211,7 @@ def test_model_creation_edition_datetime(testclient, backend):
with time_machine.travel("2021-01-01 02:00:00+00:00", tick=False):
user.family_name = "bar"
user.save()
backend.save(user)
assert user.created == datetime.datetime(
2020, 1, 1, 2, tzinfo=datetime.timezone.utc
)

View file

@ -191,7 +191,7 @@ def user(app, backend):
profile_url="https://john.example",
formatted_address="1235, somewhere",
)
u.save()
backend.save(u)
yield u
u.delete()
@ -205,7 +205,7 @@ def admin(app, backend):
emails=["jane@doe.com"],
password="admin",
)
u.save()
backend.save(u)
yield u
u.delete()
@ -219,7 +219,7 @@ def moderator(app, backend):
emails=["jack@doe.com"],
password="moderator",
)
u.save()
backend.save(u)
yield u
u.delete()
@ -251,7 +251,7 @@ def foo_group(app, user, backend):
members=[user],
display_name="foo",
)
group.save()
backend.save(group)
user.reload()
yield group
group.delete()
@ -263,7 +263,7 @@ def bar_group(app, admin, backend):
members=[admin],
display_name="bar",
)
group.save()
backend.save(group)
admin.reload()
yield group
group.delete()

View file

@ -27,7 +27,7 @@ def test_user_deleted_in_session(testclient, backend):
emails=["jake@doe.com"],
password="correct horse battery staple",
)
u.save()
backend.save(u)
testclient.get("/profile/jake", status=403)
with testclient.session_transaction() as session:
@ -66,7 +66,7 @@ def test_admin_self_deletion(testclient, backend):
emails=["temp@temp.com"],
password="admin",
)
admin.save()
backend.save(admin)
with testclient.session_transaction() as sess:
sess["user_id"] = [admin.id]
@ -92,7 +92,7 @@ def test_user_self_deletion(testclient, backend):
emails=["temp@temp.com"],
password="correct horse battery staple",
)
user.save()
backend.save(user)
with testclient.session_transaction() as sess:
sess["user_id"] = [user.id]
@ -134,7 +134,7 @@ def test_account_locking(user, backend):
user.lock_date = datetime.datetime.now(datetime.timezone.utc)
assert user.locked
user.save()
backend.save(user)
assert user.locked
assert backend.get(models.User, id=user.id).locked
assert backend.check_user_password(user, "correct horse battery staple") == (
@ -143,7 +143,7 @@ def test_account_locking(user, backend):
)
user.lock_date = None
user.save()
backend.save(user)
assert not user.locked
assert not backend.get(models.User, id=user.id).locked
assert backend.check_user_password(user, "correct horse battery staple") == (
@ -163,7 +163,7 @@ def test_account_locking_past_date(user, backend):
user.lock_date = datetime.datetime.now(datetime.timezone.utc).replace(
microsecond=0
) - datetime.timedelta(days=30)
user.save()
backend.save(user)
assert user.locked
assert backend.get(models.User, id=user.id).locked
assert backend.check_user_password(user, "correct horse battery staple") == (
@ -183,7 +183,7 @@ def test_account_locking_future_date(user, backend):
user.lock_date = datetime.datetime.now(datetime.timezone.utc).replace(
microsecond=0
) + datetime.timedelta(days=365 * 4)
user.save()
backend.save(user)
assert not user.locked
assert not backend.get(models.User, id=user.id).locked
assert backend.check_user_password(user, "correct horse battery staple") == (
@ -192,7 +192,7 @@ def test_account_locking_future_date(user, backend):
)
def test_account_locked_during_session(testclient, logged_user):
def test_account_locked_during_session(testclient, logged_user, backend):
logged_user.lock_date = datetime.datetime.now(datetime.timezone.utc)
logged_user.save()
backend.save(logged_user)
testclient.get("/profile/user/settings", status=403)

View file

@ -154,7 +154,7 @@ def test_user_without_password_first_login(testclient, backend, smtpd):
user_name="temp",
emails=["john@doe.com", "johhny@doe.com"],
)
u.save()
backend.save(u)
res = testclient.get("/login", status=200)
res.form["login"] = "temp"
@ -189,7 +189,7 @@ def test_first_login_account_initialization_mail_sending_failed(
user_name="temp",
emails=["john@doe.com"],
)
u.save()
backend.save(u)
res = testclient.get("/firstlogin/temp")
res = res.form.submit(name="action", value="sendmail", expect_errors=True)
@ -211,7 +211,7 @@ def test_first_login_form_error(testclient, backend, smtpd):
user_name="temp",
emails=["john@doe.com"],
)
u.save()
backend.save(u)
res = testclient.get("/firstlogin/temp", status=200)
res.form["csrf_token"] = "invalid"
@ -236,7 +236,7 @@ def test_user_password_deleted_during_login(testclient, backend):
emails=["john@doe.com"],
password="correct horse battery staple",
)
u.save()
backend.save(u)
res = testclient.get("/login")
res.form["login"] = "temp"
@ -244,7 +244,7 @@ def test_user_password_deleted_during_login(testclient, backend):
res.form["password"] = "correct horse battery staple"
u.password = None
u.save()
backend.save(u)
res = res.form.submit(status=302)
assert res.location == "/firstlogin/temp"
@ -272,12 +272,12 @@ def test_wrong_login(testclient, user):
res.mustcontain("The login 'invalid' does not exist")
def test_signin_locked_account(testclient, user):
def test_signin_locked_account(testclient, user, backend):
with testclient.session_transaction() as session:
assert not session.get("user_id")
user.lock_date = datetime.datetime.now(datetime.timezone.utc)
user.save()
backend.save(user)
res = testclient.get("/login", status=200)
res.form["login"] = "user"
@ -289,4 +289,4 @@ def test_signin_locked_account(testclient, user):
res.mustcontain("Your account has been locked.")
user.lock_date = None
user.save()
backend.save(user)

View file

@ -371,14 +371,14 @@ def test_confirmation_email_already_used_link(testclient, backend, user, admin):
assert "new_email@mydomain.tld" not in user.emails
def test_delete_email(testclient, logged_user):
def test_delete_email(testclient, logged_user, backend):
"""Tests that user can deletes its emails unless they have only one
left."""
res = testclient.get("/profile/user")
assert "email_remove" not in res.forms["emailconfirmationform"].fields
logged_user.emails = logged_user.emails + ["new@email.com"]
logged_user.save()
backend.save(logged_user)
res = testclient.get("/profile/user")
assert "email_remove" in res.forms["emailconfirmationform"].fields
@ -391,10 +391,10 @@ def test_delete_email(testclient, logged_user):
assert logged_user.emails == ["john@doe.com"]
def test_delete_wrong_email(testclient, logged_user):
def test_delete_wrong_email(testclient, logged_user, backend):
"""Tests that removing an already removed email do not produce anything."""
logged_user.emails = logged_user.emails + ["new@email.com"]
logged_user.save()
backend.save(logged_user)
res = testclient.get("/profile/user")
@ -412,10 +412,10 @@ def test_delete_wrong_email(testclient, logged_user):
assert logged_user.emails == ["john@doe.com"]
def test_delete_last_email(testclient, logged_user):
def test_delete_last_email(testclient, logged_user, backend):
"""Tests that users cannot remove their last email address."""
logged_user.emails = logged_user.emails + ["new@email.com"]
logged_user.save()
backend.save(logged_user)
res = testclient.get("/profile/user")

View file

@ -26,9 +26,9 @@ def test_password_forgotten(smtpd, testclient, user):
assert len(smtpd.messages) == 1
def test_password_forgotten_multiple_mails(smtpd, testclient, user):
def test_password_forgotten_multiple_mails(smtpd, testclient, user, backend):
user.emails = ["foo@bar.com", "foo@baz.com", "foo@foo.com"]
user.save()
backend.save(user)
res = testclient.get("/reset", status=200)

View file

@ -53,13 +53,13 @@ def test_group_deletion(testclient, backend):
user_name="foobar",
emails=["foo@bar.com"],
)
user.save()
backend.save(user)
group = models.Group(
members=[user],
display_name="foobar",
)
group.save()
backend.save(group)
user.reload()
assert user.groups == [group]
@ -86,19 +86,19 @@ def test_group_list_search(testclient, logged_admin, foo_group, bar_group):
res.mustcontain(no=bar_group.display_name)
def test_set_groups(app, user, foo_group, bar_group):
def test_set_groups(app, user, foo_group, bar_group, backend):
assert user in foo_group.members
assert user.groups == [foo_group]
user.groups = [foo_group, bar_group]
user.save()
backend.save(user)
bar_group.reload()
assert user in bar_group.members
assert bar_group in user.groups
user.groups = [foo_group]
user.save()
backend.save(user)
foo_group.reload()
bar_group.reload()
@ -106,23 +106,23 @@ def test_set_groups(app, user, foo_group, bar_group):
assert user not in bar_group.members
def test_set_groups_with_leading_space_in_user_id_attribute(app, foo_group):
def test_set_groups_with_leading_space_in_user_id_attribute(app, foo_group, backend):
user = models.User(
formatted_name=" Doe", # leading space in id attribute
family_name="Doe",
user_name="user2",
emails=["john@doe.com"],
)
user.save()
backend.save(user)
user.groups = [foo_group]
user.save()
backend.save(user)
foo_group.reload()
assert user in foo_group.members
user.groups = []
user.save()
backend.save(user)
foo_group.reload()
assert user.id not in foo_group.members
@ -231,14 +231,14 @@ def test_edition_failed(testclient, logged_moderator, foo_group):
assert foo_group.display_name == "foo"
def test_user_list_pagination(testclient, logged_admin, foo_group):
def test_user_list_pagination(testclient, logged_admin, foo_group, backend):
res = testclient.get("/groups/foo")
res.mustcontain("1 item")
users = fake_users(25)
for user in users:
foo_group.members = foo_group.members + [user]
foo_group.save()
backend.save(foo_group)
assert len(foo_group.members) == 26
res = testclient.get("/groups/foo")
@ -274,9 +274,11 @@ def test_user_list_bad_pages(testclient, logged_admin, foo_group):
)
def test_user_list_search(testclient, logged_admin, foo_group, user, moderator):
def test_user_list_search(
testclient, logged_admin, foo_group, user, moderator, backend
):
foo_group.members = foo_group.members + [logged_admin, moderator]
foo_group.save()
backend.save(foo_group)
res = testclient.get("/groups/foo")
res.mustcontain("3 items")
@ -294,9 +296,9 @@ def test_user_list_search(testclient, logged_admin, foo_group, user, moderator):
res.mustcontain(no=moderator.formatted_name)
def test_remove_member(testclient, logged_admin, foo_group, user, moderator):
def test_remove_member(testclient, logged_admin, foo_group, user, moderator, backend):
foo_group.members = [user, moderator]
foo_group.save()
backend.save(foo_group)
res = testclient.get("/groups/foo")
form = res.forms[f"deletegroupmemberform-{user.id}"]
@ -313,15 +315,15 @@ def test_remove_member(testclient, logged_admin, foo_group, user, moderator):
def test_remove_member_already_remove_from_group(
testclient, logged_admin, foo_group, user, moderator
testclient, logged_admin, foo_group, user, moderator, backend
):
foo_group.members = [user, moderator]
foo_group.save()
backend.save(foo_group)
res = testclient.get("/groups/foo")
form = res.forms[f"deletegroupmemberform-{user.id}"]
foo_group.members = [moderator]
foo_group.save()
backend.save(foo_group)
res = form.submit(name="action", value="confirm-remove-member")
assert (
@ -331,17 +333,17 @@ def test_remove_member_already_remove_from_group(
def test_confirm_remove_member_already_removed_from_group(
testclient, logged_admin, foo_group, user, moderator
testclient, logged_admin, foo_group, user, moderator, backend
):
foo_group.members = [user, moderator]
foo_group.save()
backend.save(foo_group)
res = testclient.get("/groups/foo")
form = res.forms[f"deletegroupmemberform-{user.id}"]
res = form.submit(name="action", value="confirm-remove-member")
foo_group.members = [moderator]
foo_group.save()
backend.save(foo_group)
res = res.form.submit(name="action", value="remove-member")
assert (
@ -351,10 +353,10 @@ def test_confirm_remove_member_already_removed_from_group(
def test_remove_member_already_deleted(
testclient, logged_admin, foo_group, user, moderator
testclient, logged_admin, foo_group, user, moderator, backend
):
foo_group.members = [user, moderator]
foo_group.save()
backend.save(foo_group)
res = testclient.get("/groups/foo")
form = res.forms[f"deletegroupmemberform-{user.id}"]
@ -368,10 +370,10 @@ def test_remove_member_already_deleted(
def test_confirm_remove_member_already_deleted(
testclient, logged_admin, foo_group, user, moderator
testclient, logged_admin, foo_group, user, moderator, backend
):
foo_group.members = [user, moderator]
foo_group.save()
backend.save(foo_group)
res = testclient.get("/groups/foo")
form = res.forms[f"deletegroupmemberform-{user.id}"]

View file

@ -13,13 +13,13 @@ def test_user_has_password(testclient, backend):
user_name="temp",
emails=["john@doe.com"],
)
user.save()
backend.save(user)
assert user.password is None
assert not user.has_password()
user.password = "foobar"
user.save()
backend.save(user)
assert user.password is not None
assert user.has_password()

View file

@ -25,7 +25,7 @@ def test_password_reset(testclient, user, backend):
def test_password_reset_multiple_emails(testclient, user, backend):
user.emails = ["foo@bar.com", "foo@baz.com"]
user.save()
backend.save(user)
assert not backend.check_user_password(user, "foobarbaz")[0]
hash = build_hash("user", "foo@baz.com", user.password)

View file

@ -118,6 +118,7 @@ def test_edition(
logged_user,
admin,
jpeg_photo,
backend,
):
res = testclient.get("/profile/user", status=200)
form = res.forms["baseform"]
@ -168,13 +169,14 @@ def test_edition(
logged_user.emails = ["john@doe.com"]
logged_user.given_name = None
logged_user.photo = None
logged_user.save()
backend.save(logged_user)
def test_edition_remove_fields(
testclient,
logged_user,
admin,
backend,
):
res = testclient.get("/profile/user", status=200)
form = res.forms["baseform"]
@ -195,13 +197,13 @@ def test_edition_remove_fields(
logged_user.emails = ["john@doe.com"]
logged_user.given_name = None
logged_user.photo = None
logged_user.save()
backend.save(logged_user)
def test_field_permissions_none(testclient, logged_user):
def test_field_permissions_none(testclient, logged_user, backend):
testclient.get("/profile/user", status=200)
logged_user.phone_numbers = ["555-666-777"]
logged_user.save()
backend.save(logged_user)
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"] = {
"READ": ["user_name"],
@ -227,10 +229,10 @@ def test_field_permissions_none(testclient, logged_user):
assert logged_user.phone_numbers == ["555-666-777"]
def test_field_permissions_read(testclient, logged_user):
def test_field_permissions_read(testclient, logged_user, backend):
testclient.get("/profile/user", status=200)
logged_user.phone_numbers = ["555-666-777"]
logged_user.save()
backend.save(logged_user)
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"] = {
"READ": ["user_name", "phone_numbers"],
@ -255,10 +257,10 @@ def test_field_permissions_read(testclient, logged_user):
assert logged_user.phone_numbers == ["555-666-777"]
def test_field_permissions_write(testclient, logged_user):
def test_field_permissions_write(testclient, logged_user, backend):
testclient.get("/profile/user", status=200)
logged_user.phone_numbers = ["555-666-777"]
logged_user.save()
backend.save(logged_user)
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"] = {
"READ": ["user_name"],

View file

@ -5,9 +5,9 @@ from webtest import Upload
from canaille.app import models
def test_photo(testclient, user, jpeg_photo):
def test_photo(testclient, user, jpeg_photo, backend):
user.photo = jpeg_photo
user.save()
backend.save(user)
user.reload()
res = testclient.get("/profile/user/photo")

View file

@ -34,13 +34,13 @@ def test_edition(testclient, logged_user, admin, foo_group, bar_group, backend):
assert backend.check_user_password(logged_user, "correct horse battery staple")[0]
logged_user.user_name = "user"
logged_user.save()
backend.save(logged_user)
def test_group_removal(testclient, logged_admin, user, foo_group, backend):
"""Tests that one can remove a group from a user."""
foo_group.members = [user, logged_admin]
foo_group.save()
backend.save(foo_group)
user.reload()
assert foo_group in user.groups
@ -115,7 +115,7 @@ def test_edition_without_groups(
assert backend.check_user_password(logged_user, "correct horse battery staple")[0]
logged_user.user_name = "user"
logged_user.save()
backend.save(logged_user)
def test_password_change(testclient, logged_user, backend):
@ -171,7 +171,7 @@ def test_password_initialization_mail(smtpd, testclient, backend, logged_admin):
user_name="temp",
emails=["john@doe.com"],
)
u.save()
backend.save(u)
res = testclient.get("/profile/temp/settings", status=200)
res.mustcontain("This user does not have a password yet")
@ -188,7 +188,7 @@ def test_password_initialization_mail(smtpd, testclient, backend, logged_admin):
u.reload()
u.password = "correct horse battery staple"
u.save()
backend.save(u)
res = testclient.get("/profile/temp/settings", status=200)
res.mustcontain(no="This user does not have a password yet")
@ -207,7 +207,7 @@ def test_password_initialization_mail_send_fail(
user_name="temp",
emails=["john@doe.com"],
)
u.save()
backend.save(u)
res = testclient.get("/profile/temp/settings", status=200)
res.mustcontain("This user does not have a password yet")
@ -272,7 +272,7 @@ def test_impersonate_locked_user(testclient, backend, logged_admin, user):
user.lock_date = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
days=1
)
user.save()
backend.save(user)
assert user.locked
res = testclient.get("/profile/user/settings")
@ -295,7 +295,7 @@ def test_password_reset_email(smtpd, testclient, backend, logged_admin):
emails=["john@doe.com"],
password="correct horse battery staple",
)
u.save()
backend.save(u)
res = testclient.get("/profile/temp/settings", status=200)
res.mustcontain("If the user has forgotten his password")
@ -323,7 +323,7 @@ def test_password_reset_email_failed(SMTP, smtpd, testclient, backend, logged_ad
emails=["john@doe.com"],
password="correct horse battery staple",
)
u.save()
backend.save(u)
res = testclient.get("/profile/temp/settings", status=200)
res.mustcontain("If the user has forgotten his password")
@ -454,7 +454,7 @@ def test_empty_lock_date(
second=0, microsecond=0
) + datetime.timedelta(days=30)
user.lock_date = expiration_datetime
user.save()
backend.save(user)
res = testclient.get("/profile/user/settings", status=200)
res.form["lock_date"] = ""

View file

@ -21,7 +21,7 @@ def test_clean_command(testclient, backend, client, user):
challenge="challenge",
challenge_method="method",
)
valid_code.save()
backend.save(valid_code)
expired_code = models.AuthorizationCode(
authorization_code_id=gen_salt(48),
code="my-expired-code",
@ -39,7 +39,7 @@ def test_clean_command(testclient, backend, client, user):
challenge="challenge",
challenge_method="method",
)
expired_code.save()
backend.save(expired_code)
valid_token = models.Token(
token_id=gen_salt(48),
@ -53,7 +53,7 @@ def test_clean_command(testclient, backend, client, user):
),
lifetime=3600,
)
valid_token.save()
backend.save(valid_token)
expired_token = models.Token(
token_id=gen_salt(48),
access_token="my-expired-token",
@ -67,7 +67,7 @@ def test_clean_command(testclient, backend, client, user):
),
lifetime=3600,
)
expired_token.save()
backend.save(expired_token)
assert backend.get(models.AuthorizationCode, code="my-expired-code")
assert backend.get(models.Token, access_token="my-expired-token")

View file

@ -67,9 +67,9 @@ def client(testclient, trusted_client, backend):
token_endpoint_auth_method="client_secret_basic",
post_logout_redirect_uris=["https://mydomain.tld/disconnected"],
)
c.save()
backend.save(c)
c.audience = [c, trusted_client]
c.save()
backend.save(c)
yield c
c.delete()
@ -105,9 +105,9 @@ def trusted_client(testclient, backend):
post_logout_redirect_uris=["https://myotherdomain.tld/disconnected"],
preconsent=True,
)
c.save()
backend.save(c)
c.audience = [c]
c.save()
backend.save(c)
yield c
c.delete()
@ -129,7 +129,7 @@ def authorization(testclient, user, client, backend):
challenge="challenge",
challenge_method="method",
)
a.save()
backend.save(a)
yield a
a.delete()
@ -147,7 +147,7 @@ def token(testclient, client, user, backend):
issue_date=datetime.datetime.now(datetime.timezone.utc),
lifetime=3600,
)
t.save()
backend.save(t)
yield t
t.delete()
@ -171,6 +171,6 @@ def consent(testclient, client, user, backend):
scope=["openid", "profile"],
issue_date=datetime.datetime.now(datetime.timezone.utc),
)
t.save()
backend.save(t)
yield t
t.delete()

View file

@ -167,7 +167,7 @@ def test_preconsented_client(
assert not backend.query(models.Consent)
client.preconsent = True
client.save()
backend.save(client)
res = testclient.get(
"/oauth/authorize",
@ -318,7 +318,7 @@ def test_code_challenge(testclient, logged_user, client, backend):
assert not backend.query(models.Consent)
client.token_endpoint_auth_method = "none"
client.save()
backend.save(client)
code_verifier = gen_salt(48)
code_challenge = create_s256_code_challenge(code_verifier)
@ -373,7 +373,7 @@ def test_code_challenge(testclient, logged_user, client, backend):
assert res.json["name"] == "John (johnny) Doe"
client.token_endpoint_auth_method = "client_secret_basic"
client.save()
backend.save(client)
for consent in consents:
consent.delete()
@ -565,7 +565,7 @@ def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client, b
def test_request_scope_too_large(testclient, logged_user, keypair, client, backend):
assert not backend.query(models.Consent)
client.scope = ["openid", "profile", "groups"]
client.save()
backend.save(client)
res = testclient.get(
"/oauth/authorize",
@ -679,7 +679,7 @@ def test_code_with_invalid_user(testclient, admin, client, backend):
emails=["temp@temp.com"],
password="correct horse battery staple",
)
user.save()
backend.save(user)
res = testclient.get(
"/oauth/authorize",
@ -738,7 +738,7 @@ def test_locked_account(
)
logged_user.lock_date = datetime.datetime.now(datetime.timezone.utc)
logged_user.save()
backend.save(logged_user)
res = res.form.submit(name="answer", value="accept", status=302)

View file

@ -14,7 +14,7 @@ from canaille.app import models
from canaille.core.endpoints.account import RegistrationPayload
def test_prompt_none(testclient, logged_user, client):
def test_prompt_none(testclient, logged_user, client, backend):
"""Nominal case with prompt=none."""
consent = models.Consent(
consent_id=str(uuid.uuid4()),
@ -22,7 +22,7 @@ def test_prompt_none(testclient, logged_user, client):
subject=logged_user,
scope=["openid", "profile"],
)
consent.save()
backend.save(consent)
res = testclient.get(
"/oauth/authorize",
@ -42,7 +42,7 @@ def test_prompt_none(testclient, logged_user, client):
consent.delete()
def test_prompt_not_logged(testclient, user, client):
def test_prompt_not_logged(testclient, user, client, backend):
"""Prompt=none should return a login_required error when no user is logged
in.
@ -58,7 +58,7 @@ def test_prompt_not_logged(testclient, user, client):
subject=user,
scope=["openid", "profile"],
)
consent.save()
backend.save(consent)
res = testclient.get(
"/oauth/authorize",
@ -100,7 +100,7 @@ def test_prompt_no_consent(testclient, logged_user, client):
assert "consent_required" == res.json.get("error")
def test_prompt_create_logged(testclient, logged_user, client):
def test_prompt_create_logged(testclient, logged_user, client, backend):
"""If prompt=create and user is already logged in, then go straight to the
consent page."""
testclient.app.config["CANAILLE"]["ENABLE_REGISTRATION"] = True
@ -111,7 +111,7 @@ def test_prompt_create_logged(testclient, logged_user, client):
subject=logged_user,
scope=["openid", "profile"],
)
consent.save()
backend.save(consent)
res = testclient.get(
"/oauth/authorize",

View file

@ -22,13 +22,15 @@ def test_client_list(testclient, client, logged_admin):
res.mustcontain(client.client_name)
def test_client_list_pagination(testclient, logged_admin, client, trusted_client):
def test_client_list_pagination(
testclient, logged_admin, client, trusted_client, backend
):
res = testclient.get("/admin/client")
res.mustcontain("2 items")
clients = []
for _ in range(25):
client = models.Client(client_id=gen_salt(48), client_name=gen_salt(48))
client.save()
backend.save(client)
clients.append(client)
res = testclient.get("/admin/client")
@ -216,22 +218,22 @@ def test_client_edit_missing_fields(testclient, client, logged_admin, trusted_cl
def test_client_delete(testclient, logged_admin, backend):
client = models.Client(client_id="client_id")
client.save()
backend.save(client)
token = models.Token(
token_id="id",
client=client,
subject=logged_admin,
issue_date=datetime.datetime.now(datetime.timezone.utc),
)
token.save()
backend.save(token)
consent = models.Consent(
consent_id="consent_id", subject=logged_admin, client=client, scope=["openid"]
)
consent.save()
backend.save(consent)
authorization_code = models.AuthorizationCode(
authorization_code_id="id", client=client, subject=logged_admin
)
authorization_code.save()
backend.save(authorization_code)
res = testclient.get("/admin/client/edit/" + client.client_id)
res = res.forms["clientaddform"].submit(name="action", value="confirm-delete")

View file

@ -16,7 +16,7 @@ def test_authorizaton_list(testclient, authorization, logged_admin):
res.mustcontain(authorization.authorization_code_id)
def test_authorization_list_pagination(testclient, logged_admin, client):
def test_authorization_list_pagination(testclient, logged_admin, client, backend):
res = testclient.get("/admin/authorization")
res.mustcontain("0 items")
authorizations = []
@ -24,7 +24,7 @@ def test_authorization_list_pagination(testclient, logged_admin, client):
code = models.AuthorizationCode(
authorization_code_id=gen_salt(48), client=client, subject=logged_admin
)
code.save()
backend.save(code)
authorizations.append(code)
res = testclient.get("/admin/authorization")
@ -64,18 +64,18 @@ def test_authorization_list_bad_pages(testclient, logged_admin):
)
def test_authorization_list_search(testclient, logged_admin, client):
def test_authorization_list_search(testclient, logged_admin, client, backend):
id1 = gen_salt(48)
auth1 = models.AuthorizationCode(
authorization_code_id=id1, client=client, subject=logged_admin
)
auth1.save()
backend.save(auth1)
id2 = gen_salt(48)
auth2 = models.AuthorizationCode(
authorization_code_id=id2, client=client, subject=logged_admin
)
auth2.save()
backend.save(auth2)
res = testclient.get("/admin/authorization")
res.mustcontain("2 items")

View file

@ -139,13 +139,15 @@ def test_oidc_authorization_after_revokation(
assert token.subject == logged_user
def test_preconsented_client_appears_in_consent_list(testclient, client, logged_user):
def test_preconsented_client_appears_in_consent_list(
testclient, client, logged_user, backend
):
assert not client.preconsent
res = testclient.get("/consent/pre-consents")
res.mustcontain(no=client.client_name)
client.preconsent = True
client.save()
backend.save(client)
res = testclient.get("/consent/pre-consents")
res.mustcontain(client.client_name)
@ -153,7 +155,7 @@ def test_preconsented_client_appears_in_consent_list(testclient, client, logged_
def test_revoke_preconsented_client(testclient, client, logged_user, token, backend):
client.preconsent = True
client.save()
backend.save(client)
assert not backend.get(models.Consent)
assert not token.revoked
@ -190,22 +192,22 @@ def test_revoke_invalid_preconsented_client(testclient, logged_user):
def test_revoke_preconsented_client_with_manual_consent(
testclient, logged_user, client, consent
testclient, logged_user, client, consent, backend
):
client.preconsent = True
client.save()
backend.save(client)
res = testclient.get(f"/consent/revoke-preconsent/{client.client_id}", status=302)
res = res.follow()
assert ("success", "The access has been revoked") in res.flashes
def test_revoke_preconsented_client_with_manual_revokation(
testclient, logged_user, client, consent
testclient, logged_user, client, consent, backend
):
client.preconsent = True
client.save()
backend.save(client)
consent.revoke()
consent.save()
backend.save(consent)
res = testclient.get(f"/consent/revoke-preconsent/{client.client_id}", status=302)
res = res.follow()

View file

@ -146,7 +146,7 @@ def test_delete(testclient, backend, user):
]
client = models.Client(client_id="foobar", client_name="Some client")
client.save()
backend.save(client)
headers = {"Authorization": "Bearer static-token"}
with warnings.catch_warnings(record=True):

View file

@ -10,7 +10,7 @@ def test_oauth_implicit(testclient, user, client, backend):
client.grant_types = ["token"]
client.token_endpoint_auth_method = "none"
client.save()
backend.save(client)
res = testclient.get(
"/oauth/authorize",
@ -48,14 +48,14 @@ def test_oauth_implicit(testclient, user, client, backend):
client.grant_types = ["code"]
client.token_endpoint_auth_method = "client_secret_basic"
client.save()
backend.save(client)
def test_oidc_implicit(testclient, keypair, user, client, trusted_client, backend):
client.grant_types = ["token id_token"]
client.token_endpoint_auth_method = "none"
client.save()
backend.save(client)
res = testclient.get(
"/oauth/authorize",
@ -101,7 +101,7 @@ def test_oidc_implicit(testclient, keypair, user, client, trusted_client, backen
client.grant_types = ["code"]
client.token_endpoint_auth_method = "client_secret_basic"
client.save()
backend.save(client)
def test_oidc_implicit_with_group(
@ -110,7 +110,7 @@ def test_oidc_implicit_with_group(
client.grant_types = ["token id_token"]
client.token_endpoint_auth_method = "none"
client.save()
backend.save(client)
res = testclient.get(
"/oauth/authorize",
@ -157,4 +157,4 @@ def test_oidc_implicit_with_group(
client.grant_types = ["code"]
client.token_endpoint_auth_method = "client_secret_basic"
client.save()
backend.save(client)

View file

@ -33,7 +33,7 @@ def test_password_flow_basic(testclient, user, client, backend):
def test_password_flow_post(testclient, user, client, backend):
client.token_endpoint_auth_method = "client_secret_post"
client.save()
backend.save(client)
res = testclient.post(
"/oauth/token",

View file

@ -82,7 +82,7 @@ def test_refresh_token_with_invalid_user(testclient, client, backend):
emails=["temp@temp.com"],
password="correct horse battery staple",
)
user.save()
backend.save(user)
res = testclient.get(
"/oauth/authorize",
@ -137,7 +137,9 @@ def test_refresh_token_with_invalid_user(testclient, client, backend):
backend.get(models.Token, access_token=access_token).delete()
def test_cannot_refresh_token_for_locked_users(testclient, logged_user, client):
def test_cannot_refresh_token_for_locked_users(
testclient, logged_user, client, backend
):
"""Canaille should not issue new tokens for locked users."""
res = testclient.get(
"/oauth/authorize",
@ -167,7 +169,7 @@ def test_cannot_refresh_token_for_locked_users(testclient, logged_user, client):
)
logged_user.lock_date = datetime.datetime.now(datetime.timezone.utc)
logged_user.save()
backend.save(logged_user)
res = testclient.post(
"/oauth/token",

View file

@ -18,7 +18,7 @@ def test_token_list(testclient, token, logged_admin):
res.mustcontain(token.token_id)
def test_token_list_pagination(testclient, logged_admin, client):
def test_token_list_pagination(testclient, logged_admin, client, backend):
res = testclient.get("/admin/token")
res.mustcontain("0 items")
tokens = []
@ -36,7 +36,7 @@ def test_token_list_pagination(testclient, logged_admin, client):
),
lifetime=3600,
)
token.save()
backend.save(token)
tokens.append(token)
res = testclient.get("/admin/token")
@ -74,7 +74,7 @@ def test_token_list_bad_pages(testclient, logged_admin):
)
def test_token_list_search(testclient, logged_admin, client):
def test_token_list_search(testclient, logged_admin, client, backend):
token1 = models.Token(
token_id=gen_salt(48),
access_token="this-token-is-ok",
@ -88,7 +88,7 @@ def test_token_list_search(testclient, logged_admin, client):
),
lifetime=3600,
)
token1.save()
backend.save(token1)
token2 = models.Token(
token_id=gen_salt(48),
access_token="this-token-is-valid",
@ -102,7 +102,7 @@ def test_token_list_search(testclient, logged_admin, client):
),
lifetime=3600,
)
token2.save()
backend.save(token2)
res = testclient.get("/admin/token")
res.mustcontain("2 items")

View file

@ -63,11 +63,11 @@ def test_revoke_refresh_token_with_hint(testclient, user, client, token):
assert token.revokation_date
def test_cannot_refresh_after_revocation(testclient, user, client, token):
def test_cannot_refresh_after_revocation(testclient, user, client, token, backend):
token.revokation_date = datetime.datetime.now(
datetime.timezone.utc
) - datetime.timedelta(days=7)
token.save()
backend.save(token)
res = testclient.post(
"/oauth/token",

View file

@ -146,9 +146,9 @@ def test_generate_user_claims(user, foo_group):
}
def test_userinfo(testclient, token, user, foo_group):
def test_userinfo(testclient, token, user, foo_group, backend):
token.scope = ["openid"]
token.save()
backend.save(token)
testclient.get(
"/oauth/userinfo",
headers={"Authorization": f"Bearer {token.access_token}"},
@ -156,7 +156,7 @@ def test_userinfo(testclient, token, user, foo_group):
)
token.scope = ["openid", "profile"]
token.save()
backend.save(token)
res = testclient.get(
"/oauth/userinfo",
headers={"Authorization": f"Bearer {token.access_token}"},
@ -172,7 +172,7 @@ def test_userinfo(testclient, token, user, foo_group):
}
token.scope = ["openid", "profile", "email"]
token.save()
backend.save(token)
res = testclient.get(
"/oauth/userinfo",
headers={"Authorization": f"Bearer {token.access_token}"},
@ -189,7 +189,7 @@ def test_userinfo(testclient, token, user, foo_group):
}
token.scope = ["openid", "profile", "address"]
token.save()
backend.save(token)
res = testclient.get(
"/oauth/userinfo",
headers={"Authorization": f"Bearer {token.access_token}"},
@ -206,7 +206,7 @@ def test_userinfo(testclient, token, user, foo_group):
}
token.scope = ["openid", "profile", "phone"]
token.save()
backend.save(token)
res = testclient.get(
"/oauth/userinfo",
headers={"Authorization": f"Bearer {token.access_token}"},
@ -223,7 +223,7 @@ def test_userinfo(testclient, token, user, foo_group):
}
token.scope = ["openid", "profile", "groups"]
token.save()
backend.save(token)
res = testclient.get(
"/oauth/userinfo",
headers={"Authorization": f"Bearer {token.access_token}"},
@ -296,7 +296,7 @@ def test_claim_is_omitted_if_empty(testclient, backend, user):
# According to https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse
# it's better to not insert a null or empty string value
user.emails = []
user.save()
backend.save(user)
default_jwt_mapping = JWTSettings().model_dump()
data = generate_user_claims(user, STANDARD_CLAIMS, default_jwt_mapping)