refactor: move BackendModel.save to Backend.save

This commit is contained in:
Éloi Rivard 2024-04-14 20:31:43 +02:00
parent 44573713ed
commit 09588e0f48
No known key found for this signature in database
GPG key ID: 7EDA204EA57DD184
49 changed files with 327 additions and 306 deletions

View file

@ -85,6 +85,10 @@ class BaseBackend:
only one element or :py:data:`None` if no item is matching.""" only one element or :py:data:`None` if no item is matching."""
raise NotImplementedError() raise NotImplementedError()
def save(self, instance):
"""Validate the current modifications in the database."""
raise NotImplementedError()
def check_user_password(self, user, password: str) -> bool: def check_user_password(self, user, password: str) -> bool:
"""Check if the password matches the user password in the database.""" """Check if the password matches the user password in the database."""
raise NotImplementedError() raise NotImplementedError()

View file

@ -9,6 +9,7 @@ from flask import current_app
from ldap.controls import DecodeControlTuples from ldap.controls import DecodeControlTuples
from ldap.controls.ppolicy import PasswordPolicyControl from ldap.controls.ppolicy import PasswordPolicyControl
from ldap.controls.ppolicy import PasswordPolicyError from ldap.controls.ppolicy import PasswordPolicyError
from ldap.controls.readentry import PostReadControl
from canaille.app import models from canaille.app import models
from canaille.app.configuration import ConfigurationException from canaille.app.configuration import ConfigurationException
@ -128,7 +129,7 @@ class Backend(BaseBackend):
emails=f"canaille_{uuid.uuid4()}@mydomain.tld", emails=f"canaille_{uuid.uuid4()}@mydomain.tld",
password="correct horse battery staple", password="correct horse battery staple",
) )
user.save() BaseBackend.instance.save(user)
user.delete() user.delete()
except ldap.INSUFFICIENT_ACCESS as exc: except ldap.INSUFFICIENT_ACCESS as exc:
@ -147,13 +148,13 @@ class Backend(BaseBackend):
emails=f"canaille_{uuid.uuid4()}@mydomain.tld", emails=f"canaille_{uuid.uuid4()}@mydomain.tld",
password="correct horse battery staple", password="correct horse battery staple",
) )
user.save() BaseBackend.instance.save(user)
group = models.Group( group = models.Group(
display_name=f"canaille_{uuid.uuid4()}", display_name=f"canaille_{uuid.uuid4()}",
members=[user], members=[user],
) )
group.save() BaseBackend.instance.save(group)
group.delete() group.delete()
except ldap.INSUFFICIENT_ACCESS as exc: except ldap.INSUFFICIENT_ACCESS as exc:
@ -324,6 +325,69 @@ class Backend(BaseBackend):
return None return None
def save(self, instance):
# run the instance save callback if existing
save_callback = instance.save() if hasattr(instance, "save") else iter([])
next(save_callback, None)
current_object_classes = instance.get_ldap_attribute("objectClass") or []
instance.set_ldap_attribute(
"objectClass",
list(set(instance.ldap_object_class) | set(current_object_classes)),
)
# PostReadControl allows to read the updated object attributes on creation/edition
attributes = ["objectClass"] + [
instance.python_attribute_to_ldap(name) for name in instance.attributes
]
read_post_control = PostReadControl(criticality=True, attrList=attributes)
# Object already exists in the LDAP database
if instance.exists:
deletions = [
name
for name, value in instance.changes.items()
if (
value is None
or value == []
or (isinstance(value, list) and len(value) == 1 and not value[0])
)
and name in instance.state
]
changes = {
name: value
for name, value in instance.changes.items()
if name not in deletions and instance.state.get(name) != value
}
formatted_changes = python_attrs_to_ldap(changes, null_allowed=False)
modlist = [(ldap.MOD_DELETE, name, None) for name in deletions] + [
(ldap.MOD_REPLACE, name, values)
for name, values in formatted_changes.items()
]
_, _, _, [result] = self.connection.modify_ext_s(
instance.dn, modlist, serverctrls=[read_post_control]
)
# Object does not exist yet in the LDAP database
else:
changes = {
name: value
for name, value in {**instance.state, **instance.changes}.items()
if value and value[0]
}
formatted_changes = python_attrs_to_ldap(changes, null_allowed=False)
modlist = [(name, values) for name, values in formatted_changes.items()]
_, _, _, [result] = self.connection.add_ext_s(
instance.dn, modlist, serverctrls=[read_post_control]
)
instance.exists = True
instance.state = {**result.entry, **instance.changes}
instance.changes = {}
# run the instance save callback again if existing
next(save_callback, None)
def setup_ldap_models(config): def setup_ldap_models(config):
from canaille.app import models from canaille.app import models

View file

@ -2,7 +2,6 @@ import typing
import ldap.dn import ldap.dn
import ldap.filter import ldap.filter
from ldap.controls.readentry import PostReadControl
from canaille.backends.models import BackendModel from canaille.backends.models import BackendModel
@ -264,64 +263,6 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
self.changes = {} self.changes = {}
self.state = result[0][1] self.state = result[0][1]
def save(self):
conn = Backend.instance.connection
current_object_classes = self.get_ldap_attribute("objectClass") or []
self.set_ldap_attribute(
"objectClass",
list(set(self.ldap_object_class) | set(current_object_classes)),
)
# PostReadControl allows to read the updated object attributes on creation/edition
attributes = ["objectClass"] + [
self.python_attribute_to_ldap(name) for name in self.attributes
]
read_post_control = PostReadControl(criticality=True, attrList=attributes)
# Object already exists in the LDAP database
if self.exists:
deletions = [
name
for name, value in self.changes.items()
if (
value is None
or value == []
or (isinstance(value, list) and len(value) == 1 and not value[0])
)
and name in self.state
]
changes = {
name: value
for name, value in self.changes.items()
if name not in deletions and self.state.get(name) != value
}
formatted_changes = python_attrs_to_ldap(changes, null_allowed=False)
modlist = [(ldap.MOD_DELETE, name, None) for name in deletions] + [
(ldap.MOD_REPLACE, name, values)
for name, values in formatted_changes.items()
]
_, _, _, [result] = conn.modify_ext_s(
self.dn, modlist, serverctrls=[read_post_control]
)
# Object does not exist yet in the LDAP database
else:
changes = {
name: value
for name, value in {**self.state, **self.changes}.items()
if value and value[0]
}
formatted_changes = python_attrs_to_ldap(changes, null_allowed=False)
modlist = [(name, values) for name, values in formatted_changes.items()]
_, _, _, [result] = conn.add_ext_s(
self.dn, modlist, serverctrls=[read_post_control]
)
self.exists = True
self.state = {**result.entry, **self.changes}
self.changes = {}
def delete(self): def delete(self):
conn = Backend.instance.connection conn = Backend.instance.connection
try: try:

View file

@ -43,10 +43,10 @@ class User(canaille.core.models.User, LDAPObject):
return super().match_filter(filter) return super().match_filter(filter)
def save(self, *args, **kwargs): def save(self):
group_attr = self.python_attribute_to_ldap("groups") group_attr = self.python_attribute_to_ldap("groups")
if group_attr not in self.changes: if group_attr not in self.changes:
return super().save(*args, **kwargs) return
# The LDAP attribute memberOf cannot directly be edited, # The LDAP attribute memberOf cannot directly be edited,
# so this is needed to update the Group.member attribute # so this is needed to update the Group.member attribute
@ -60,11 +60,11 @@ class User(canaille.core.models.User, LDAPObject):
to_del = set(old_groups) - set(new_groups) to_del = set(old_groups) - set(new_groups)
del self.changes[group_attr] del self.changes[group_attr]
super().save(*args, **kwargs) yield
for group in to_add: for group in to_add:
group.members = group.members + [self] group.members = group.members + [self]
group.save() Backend.instance.save(group)
for group in to_del: for group in to_del:
# LDAP groups cannot be empty because groupOfNames.member # LDAP groups cannot be empty because groupOfNames.member
@ -73,7 +73,7 @@ class User(canaille.core.models.User, LDAPObject):
# TODO: properly manage the situation where one wants to # TODO: properly manage the situation where one wants to
# remove the last member of a group # remove the last member of a group
group.members = [member for member in group.members if member != self] group.members = [member for member in group.members if member != self]
group.save() Backend.instance.save(group)
self.state[group_attr] = new_groups self.state[group_attr] = new_groups

View file

@ -1,3 +1,6 @@
import datetime
import uuid
from canaille.backends import BaseBackend from canaille.backends import BaseBackend
@ -39,7 +42,7 @@ class Backend(BaseBackend):
def set_user_password(self, user, password): def set_user_password(self, user, password):
user.password = password user.password = password
user.save() self.save(user)
def query(self, model, **kwargs): def query(self, model, **kwargs):
# if there is no filter, return all models # if there is no filter, return all models
@ -91,3 +94,17 @@ class Backend(BaseBackend):
results = self.query(model, **kwargs) results = self.query(model, **kwargs)
return results[0] if results else None return results[0] if results else None
def save(self, instance):
if not instance.id:
instance.id = str(uuid.uuid4())
instance.last_modified = datetime.datetime.now(datetime.timezone.utc).replace(
microsecond=0
)
if not instance.created:
instance.created = instance.last_modified
instance.index_delete()
instance.index_save()
instance._cache = {}

View file

@ -1,7 +1,5 @@
import copy import copy
import datetime
import typing import typing
import uuid
import canaille.core.models import canaille.core.models
import canaille.oidc.models import canaille.oidc.models
@ -67,20 +65,6 @@ class MemoryModel(BackendModel):
return value return value
def save(self):
if not self.id:
self.id = str(uuid.uuid4())
self.last_modified = datetime.datetime.now(datetime.timezone.utc).replace(
microsecond=0
)
if not self.created:
self.created = self.last_modified
self.index_delete()
self.index_save()
self._cache = {}
def delete(self): def delete(self):
self.index_delete() self.index_delete()

View file

@ -88,10 +88,6 @@ class BackendModel:
implemented for every model and for every backend. implemented for every model and for every backend.
""" """
def save(self):
"""Validate the current modifications in the database."""
raise NotImplementedError()
def delete(self): def delete(self):
"""Remove the current instance from the database.""" """Remove the current instance from the database."""
raise NotImplementedError() raise NotImplementedError()

View file

@ -1,3 +1,5 @@
import datetime
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy import or_ from sqlalchemy import or_
from sqlalchemy import select from sqlalchemy import select
@ -66,7 +68,7 @@ class Backend(BaseBackend):
def set_user_password(self, user, password): def set_user_password(self, user, password):
user.password = password user.password = password
user.save() self.save(user)
def query(self, model, **kwargs): def query(self, model, **kwargs):
filter = [ filter = [
@ -106,3 +108,13 @@ class Backend(BaseBackend):
return Backend.instance.db_session.execute( return Backend.instance.db_session.execute(
select(model).filter(*filter) select(model).filter(*filter)
).scalar_one_or_none() ).scalar_one_or_none()
def save(self, instance):
instance.last_modified = datetime.datetime.now(datetime.timezone.utc).replace(
microsecond=0
)
if not instance.created:
instance.created = instance.last_modified
Backend.instance.db_session.add(instance)
Backend.instance.db_session.commit()

View file

@ -47,16 +47,6 @@ class SqlAlchemyModel(BackendModel):
return getattr(cls, name) == value return getattr(cls, name) == value
def save(self):
self.last_modified = datetime.datetime.now(datetime.timezone.utc).replace(
microsecond=0
)
if not self.created:
self.created = self.last_modified
Backend.instance.db_session.add(self)
Backend.instance.db_session.commit()
def delete(self): def delete(self):
Backend.instance.db_session.delete(self) Backend.instance.db_session.delete(self)
Backend.instance.db_session.commit() Backend.instance.db_session.commit()

View file

@ -404,7 +404,7 @@ def email_confirmation(data, hash):
return redirect(url_for("core.account.index")) return redirect(url_for("core.account.index"))
user.emails = user.emails + [confirmation_obj.email] user.emails = user.emails + [confirmation_obj.email]
user.save() BaseBackend.instance.save(user)
flash(_("Your email address have been confirmed."), "success") flash(_("Your email address have been confirmed."), "success")
return redirect(url_for("core.account.index")) return redirect(url_for("core.account.index"))
@ -460,11 +460,11 @@ def profile_create(current_app, form):
given_name = user.given_name if user.given_name else "" given_name = user.given_name if user.given_name else ""
family_name = user.family_name if user.family_name else "" family_name = user.family_name if user.family_name else ""
user.formatted_name = f"{given_name} {family_name}".strip() user.formatted_name = f"{given_name} {family_name}".strip()
user.save() BaseBackend.instance.save(user)
if form["password1"].data: if form["password1"].data:
BaseBackend.instance.set_user_password(user, form["password1"].data) BaseBackend.instance.set_user_password(user, form["password1"].data)
user.save() BaseBackend.instance.save(user)
return user return user
@ -536,7 +536,7 @@ def profile_edition_main_form_validation(user, edited_user, profile_form):
if profile_form["preferred_language"].data == "auto": if profile_form["preferred_language"].data == "auto":
edited_user.preferred_language = None edited_user.preferred_language = None
edited_user.save() BaseBackend.instance.save(edited_user)
g.user.reload() g.user.reload()
@ -574,7 +574,7 @@ def profile_edition_remove_email(user, edited_user, email):
return False return False
edited_user.emails = [m for m in edited_user.emails if m != email] edited_user.emails = [m for m in edited_user.emails if m != email]
edited_user.save() BaseBackend.instance.save(edited_user)
return True return True
@ -730,7 +730,7 @@ def profile_settings(user, edited_user):
): ):
flash(_("The account has been locked"), "success") flash(_("The account has been locked"), "success")
edited_user.lock_date = datetime.datetime.now(datetime.timezone.utc) edited_user.lock_date = datetime.datetime.now(datetime.timezone.utc)
edited_user.save() BaseBackend.instance.save(edited_user)
return profile_settings_edit(user, edited_user) return profile_settings_edit(user, edited_user)
@ -741,7 +741,7 @@ def profile_settings(user, edited_user):
): ):
flash(_("The account has been unlocked"), "success") flash(_("The account has been unlocked"), "success")
edited_user.lock_date = None edited_user.lock_date = None
edited_user.save() BaseBackend.instance.save(edited_user)
return profile_settings_edit(user, edited_user) return profile_settings_edit(user, edited_user)
@ -791,7 +791,7 @@ def profile_settings_edit(editor, edited_user):
edited_user, form["password1"].data edited_user, form["password1"].data
) )
edited_user.save() BaseBackend.instance.save(edited_user)
flash(_("Profile updated successfully."), "success") flash(_("Profile updated successfully."), "success")
return redirect( return redirect(
url_for("core.account.profile_settings", edited_user=edited_user) url_for("core.account.profile_settings", edited_user=edited_user)

View file

@ -11,6 +11,7 @@ from canaille.app.flask import render_htmx_template
from canaille.app.forms import TableForm from canaille.app.forms import TableForm
from canaille.app.i18n import gettext as _ from canaille.app.i18n import gettext as _
from canaille.app.themes import render_template from canaille.app.themes import render_template
from canaille.backends import BaseBackend
from .forms import CreateGroupForm from .forms import CreateGroupForm
from .forms import DeleteGroupMemberForm from .forms import DeleteGroupMemberForm
@ -42,7 +43,7 @@ def create_group(user):
group.members = [user] group.members = [user]
group.display_name = form.display_name.data group.display_name = form.display_name.data
group.description = form.description.data group.description = form.description.data
group.save() BaseBackend.instance.save(group)
flash( flash(
_( _(
"The group %(group)s has been sucessfully created", "The group %(group)s has been sucessfully created",
@ -102,7 +103,7 @@ def edit_group(group):
): ):
if form.validate(): if form.validate():
group.description = form.description.data group.description = form.description.data
group.save() BaseBackend.instance.save(group)
flash( flash(
_( _(
"The group %(group)s has been sucessfully edited.", "The group %(group)s has been sucessfully edited.",
@ -151,7 +152,7 @@ def delete_member(group):
group.members = [ group.members = [
member for member in group.members if member != form.member.data member for member in group.members if member != form.member.data
] ]
group.save() BaseBackend.instance.save(group)
return edit_group(group) return edit_group(group)

View file

@ -40,7 +40,7 @@ def fake_users(nb=1):
password=fake.password(), password=fake.password(),
preferred_language=fake._locales[0], preferred_language=fake._locales[0],
) )
user.save() BaseBackend.instance.save(user)
users.append(user) users.append(user)
except Exception: # pragma: no cover except Exception: # pragma: no cover
pass pass
@ -59,7 +59,7 @@ def fake_groups(nb=1, nb_users_max=1):
) )
nb_users = random.randrange(1, nb_users_max + 1) nb_users = random.randrange(1, nb_users_max + 1)
group.members = list({random.choice(users) for _ in range(nb_users)}) group.members = list({random.choice(users) for _ in range(nb_users)})
group.save() BaseBackend.instance.save(group)
groups.append(group) groups.append(group)
except Exception: # pragma: no cover except Exception: # pragma: no cover
pass pass

View file

@ -14,6 +14,7 @@ from canaille.app.flask import render_htmx_template
from canaille.app.forms import TableForm from canaille.app.forms import TableForm
from canaille.app.i18n import gettext as _ from canaille.app.i18n import gettext as _
from canaille.app.themes import render_template from canaille.app.themes import render_template
from canaille.backends import BaseBackend
from .forms import ClientAddForm from .forms import ClientAddForm
@ -73,9 +74,9 @@ def add(user):
if form["token_endpoint_auth_method"].data == "none" if form["token_endpoint_auth_method"].data == "none"
else gen_salt(48), else gen_salt(48),
) )
client.save() BaseBackend.instance.save(client)
client.audience = [client] client.audience = [client]
client.save() BaseBackend.instance.save(client)
flash( flash(
_("The client has been created."), _("The client has been created."),
"success", "success",
@ -137,7 +138,7 @@ def client_edit(client):
audience=form["audience"].data, audience=form["audience"].data,
preconsent=form["preconsent"].data, preconsent=form["preconsent"].data,
) )
client.save() BaseBackend.instance.save(client)
flash( flash(
_("The client has been edited."), _("The client has been edited."),
"success", "success",

View file

@ -95,7 +95,7 @@ def restore(user, consent):
consent.restore() consent.restore()
if not consent.issue_date: if not consent.issue_date:
consent.issue_date = datetime.datetime.now(datetime.timezone.utc) consent.issue_date = datetime.datetime.now(datetime.timezone.utc)
consent.save() BaseBackend.instance.save(consent)
flash(_("The access has been restored"), "success") flash(_("The access has been restored"), "success")
return redirect(url_for("oidc.consents.consents")) return redirect(url_for("oidc.consents.consents"))
@ -119,7 +119,7 @@ def revoke_preconsent(user, client):
scope=client.scope, scope=client.scope,
) )
consent.revoke() consent.revoke()
consent.save() BaseBackend.instance.save(consent)
flash(_("The access has been revoked"), "success") flash(_("The access has been revoked"), "success")
return redirect(url_for("oidc.consents.consents")) return redirect(url_for("oidc.consents.consents"))

View file

@ -177,7 +177,7 @@ def authorize_consent(client, user):
scope=allowed_scopes, scope=allowed_scopes,
issue_date=datetime.datetime.now(datetime.timezone.utc), issue_date=datetime.datetime.now(datetime.timezone.utc),
) )
consent.save() BaseBackend.instance.save(consent)
response = authorization.create_authorization_response(grant_user=grant_user) response = authorization.create_authorization_response(grant_user=grant_user)
current_app.logger.debug("authorization endpoint response: %s", response.location) current_app.logger.debug("authorization endpoint response: %s", response.location)

View file

@ -11,6 +11,7 @@ from canaille.app.flask import render_htmx_template
from canaille.app.forms import TableForm from canaille.app.forms import TableForm
from canaille.app.i18n import gettext as _ from canaille.app.i18n import gettext as _
from canaille.app.themes import render_template from canaille.app.themes import render_template
from canaille.backends import BaseBackend
from .forms import TokenRevokationForm from .forms import TokenRevokationForm
@ -40,7 +41,7 @@ def view(user, token):
elif request.form.get("action") == "revoke": elif request.form.get("action") == "revoke":
token.revokation_date = datetime.datetime.now(datetime.timezone.utc) token.revokation_date = datetime.datetime.now(datetime.timezone.utc)
token.save() BaseBackend.instance.save(token)
flash(_("The token has successfully been revoked."), "success") flash(_("The token has successfully been revoked."), "success")
else: else:

View file

@ -184,7 +184,7 @@ class Consent(BaseConsent):
def revoke(self): def revoke(self):
self.revokation_date = datetime.datetime.now(datetime.timezone.utc) self.revokation_date = datetime.datetime.now(datetime.timezone.utc)
self.save() BaseBackend.instance.save(self)
tokens = BaseBackend.instance.query( tokens = BaseBackend.instance.query(
models.Token, models.Token,
@ -194,8 +194,8 @@ class Consent(BaseConsent):
tokens = [token for token in tokens if not token.revoked] tokens = [token for token in tokens if not token.revoked]
for t in tokens: for t in tokens:
t.revokation_date = self.revokation_date t.revokation_date = self.revokation_date
t.save() BaseBackend.instance.save(t)
def restore(self): def restore(self):
self.revokation_date = None self.revokation_date = None
self.save() BaseBackend.instance.save(self)

View file

@ -228,7 +228,7 @@ def save_authorization_code(code, request):
challenge=request.data.get("code_challenge"), challenge=request.data.get("code_challenge"),
challenge_method=request.data.get("code_challenge_method"), challenge_method=request.data.get("code_challenge_method"),
) )
code.save() BaseBackend.instance.save(code)
return code.code return code.code
@ -297,7 +297,7 @@ class RefreshTokenGrant(_RefreshTokenGrant):
def revoke_old_credential(self, credential): def revoke_old_credential(self, credential):
credential.revokation_date = datetime.datetime.now(datetime.timezone.utc) credential.revokation_date = datetime.datetime.now(datetime.timezone.utc)
credential.save() BaseBackend.instance.save(credential)
class OpenIDImplicitGrant(_OpenIDImplicitGrant): class OpenIDImplicitGrant(_OpenIDImplicitGrant):
@ -351,7 +351,7 @@ def save_token(token, request):
subject=request.user, subject=request.user,
audience=request.client.audience, audience=request.client.audience,
) )
t.save() BaseBackend.instance.save(t)
class BearerTokenValidator(_BearerTokenValidator): class BearerTokenValidator(_BearerTokenValidator):
@ -382,7 +382,7 @@ class RevocationEndpoint(_RevocationEndpoint):
def revoke_token(self, token, request): def revoke_token(self, token, request):
token.revokation_date = datetime.datetime.now(datetime.timezone.utc) token.revokation_date = datetime.datetime.now(datetime.timezone.utc)
token.save() BaseBackend.instance.save(token)
class IntrospectionEndpoint(_IntrospectionEndpoint): class IntrospectionEndpoint(_IntrospectionEndpoint):
@ -463,9 +463,9 @@ class ClientRegistrationEndpoint(ClientManagementMixin, _ClientRegistrationEndpo
post_logout_redirect_uris=request.data.get("post_logout_redirect_uris"), post_logout_redirect_uris=request.data.get("post_logout_redirect_uris"),
**self.client_convert_data(**client_info, **client_metadata), **self.client_convert_data(**client_info, **client_metadata),
) )
client.save() BaseBackend.instance.save(client)
client.audience = [client] client.audience = [client]
client.save() BaseBackend.instance.save(client)
return client return client
@ -485,7 +485,7 @@ class ClientConfigurationEndpoint(ClientManagementMixin, _ClientConfigurationEnd
def update_client(self, client, client_metadata, request): def update_client(self, client, client_metadata, request):
client.update(**self.client_convert_data(**client_metadata)) client.update(**self.client_convert_data(**client_metadata))
client.save() BaseBackend.instance.save(client)
return client return client
def generate_client_registration_info(self, client, request): def generate_client_registration_info(self, client, request):

View file

@ -206,13 +206,13 @@ def test_fieldlist_add_readonly(testclient, logged_user):
testclient.post("/profile/user", data, status=403) testclient.post("/profile/user", data, status=403)
def test_fieldlist_remove_readonly(testclient, logged_user): def test_fieldlist_remove_readonly(testclient, logged_user, backend):
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["WRITE"].remove("phone_numbers") testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["WRITE"].remove("phone_numbers")
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["READ"].append("phone_numbers") testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"]["READ"].append("phone_numbers")
logged_user.reload() logged_user.reload()
logged_user.phone_numbers = ["555-555-000", "555-555-111"] logged_user.phone_numbers = ["555-555-000", "555-555-111"]
logged_user.save() backend.save(logged_user)
res = testclient.get("/profile/user") res = testclient.get("/profile/user")
form = res.forms["baseform"] form = res.forms["baseform"]

View file

@ -1,9 +1,9 @@
from flask_babel import refresh from flask_babel import refresh
def test_preferred_language(testclient, logged_user): def test_preferred_language(testclient, logged_user, backend):
logged_user.preferred_language = None logged_user.preferred_language = None
logged_user.save() backend.save(logged_user)
res = testclient.get("/profile/user", status=200) res = testclient.get("/profile/user", status=200)
form = res.forms["baseform"] form = res.forms["baseform"]
@ -49,9 +49,9 @@ def test_preferred_language(testclient, logged_user):
res.mustcontain(no="Mon profil") res.mustcontain(no="Mon profil")
def test_form_translations(testclient, logged_user): def test_form_translations(testclient, logged_user, backend):
logged_user.preferred_language = "fr" logged_user.preferred_language = "fr"
logged_user.save() backend.save(logged_user)
res = testclient.get("/profile/user", status=200) res = testclient.get("/profile/user", status=200)
form = res.forms["baseform"] form = res.forms["baseform"]
@ -62,9 +62,9 @@ def test_form_translations(testclient, logged_user):
res.mustcontain("Nest pas un numéro de téléphone valide") res.mustcontain("Nest pas un numéro de téléphone valide")
def test_language_config(testclient, logged_user): def test_language_config(testclient, logged_user, backend):
logged_user.preferred_language = None logged_user.preferred_language = None
logged_user.save() backend.save(logged_user)
res = testclient.get("/profile/user", status=200) res = testclient.get("/profile/user", status=200)
assert res.pyquery("html")[0].attrib["lang"] == "en" assert res.pyquery("html")[0].attrib["lang"] == "en"

View file

@ -7,7 +7,7 @@ def test_model_references_set_unsaved_object(
"""LDAP groups can be inconsistent by containing members which doesn't """LDAP groups can be inconsistent by containing members which doesn't
exist.""" exist."""
group = models.Group(members=[user], display_name="foo") group = models.Group(members=[user], display_name="foo")
group.save() backend.save(group)
user.reload() user.reload()
non_existent_user = models.User( non_existent_user = models.User(
@ -16,7 +16,7 @@ def test_model_references_set_unsaved_object(
group.members = group.members + [non_existent_user] group.members = group.members + [non_existent_user]
assert group.members == [user, non_existent_user] assert group.members == [user, non_existent_user]
group.save() backend.save(group)
assert group.members == [user, non_existent_user] assert group.members == [user, non_existent_user]
group.reload() group.reload()

View file

@ -5,7 +5,7 @@ from canaille.backends.ldap.ldapobject import LDAPObject
def test_guess_object_from_dn(backend, testclient, foo_group): def test_guess_object_from_dn(backend, testclient, foo_group):
foo_group.members = [foo_group] foo_group.members = [foo_group]
foo_group.save() backend.save(foo_group)
dn = foo_group.dn dn = foo_group.dn
g = backend.get(LDAPObject, dn) g = backend.get(LDAPObject, dn)
assert isinstance(g, models.Group) assert isinstance(g, models.Group)
@ -18,7 +18,7 @@ def test_object_class_update(backend, testclient):
setup_ldap_models(testclient.app.config) setup_ldap_models(testclient.app.config)
user1 = models.User(cn="foo1", sn="bar1", user_name="baz1") user1 = models.User(cn="foo1", sn="bar1", user_name="baz1")
user1.save() backend.save(user1)
assert set(user1.get_ldap_attribute("objectClass")) == {"inetOrgPerson"} assert set(user1.get_ldap_attribute("objectClass")) == {"inetOrgPerson"}
assert set( assert set(
@ -32,7 +32,7 @@ def test_object_class_update(backend, testclient):
setup_ldap_models(testclient.app.config) setup_ldap_models(testclient.app.config)
user2 = models.User(cn="foo2", sn="bar2", user_name="baz2") user2 = models.User(cn="foo2", sn="bar2", user_name="baz2")
user2.save() backend.save(user2)
assert set(user2.get_ldap_attribute("objectClass")) == { assert set(user2.get_ldap_attribute("objectClass")) == {
"inetOrgPerson", "inetOrgPerson",
@ -48,7 +48,7 @@ def test_object_class_update(backend, testclient):
user1 = backend.get(models.User, id=user1.id) user1 = backend.get(models.User, id=user1.id)
assert user1.get_ldap_attribute("objectClass") == ["inetOrgPerson"] assert user1.get_ldap_attribute("objectClass") == ["inetOrgPerson"]
user1.save() backend.save(user1)
assert set(user1.get_ldap_attribute("objectClass")) == { assert set(user1.get_ldap_attribute("objectClass")) == {
"inetOrgPerson", "inetOrgPerson",
"extensibleObject", "extensibleObject",
@ -72,7 +72,7 @@ def test_keep_old_object_classes(backend, testclient, slapd_server):
attributes. attributes.
""" """
user = models.User(cn="foo", sn="bar", user_name="baz") user = models.User(cn="foo", sn="bar", user_name="baz")
user.save() backend.save(user)
ldif = f"""dn: {user.dn} ldif = f"""dn: {user.dn}
changetype: modify changetype: modify
@ -95,6 +95,6 @@ homeDirectory: /home/foobar
user.reload() user.reload()
# saving an object should not raise a ldap.OBJECT_CLASS_VIOLATION exception # saving an object should not raise a ldap.OBJECT_CLASS_VIOLATION exception
user.save() backend.save(user)
user.delete() user.delete()

View file

@ -24,7 +24,7 @@ def test_object_creation(app, backend):
emails=["john@doe.com"], emails=["john@doe.com"],
) )
assert not user.exists assert not user.exists
user.save() backend.save(user)
assert user.exists assert user.exists
user = backend.get(models.User, id=user.id) user = backend.get(models.User, id=user.id)
@ -45,7 +45,7 @@ def test_dn_when_leading_space_in_id_attribute(testclient, backend):
user_name=" user", user_name=" user",
emails=["john@doe.com"], emails=["john@doe.com"],
) )
user.save() backend.save(user)
dn = user.dn dn = user.dn
assert dn == "uid=user,ou=users,dc=mydomain,dc=tld" assert dn == "uid=user,ou=users,dc=mydomain,dc=tld"
@ -66,7 +66,7 @@ def test_special_chars_in_rdn(testclient, backend):
user_name="#user", # special char user_name="#user", # special char
emails=["john@doe.com"], emails=["john@doe.com"],
) )
user.save() backend.save(user)
dn = user.dn dn = user.dn
assert ldap.dn.is_dn(dn) assert ldap.dn.is_dn(dn)

View file

@ -12,13 +12,13 @@ def test_model_comparison(testclient, backend):
family_name="foo", family_name="foo",
formatted_name="foo", formatted_name="foo",
) )
foo1.save() backend.save(foo1)
bar = models.User( bar = models.User(
user_name="bar", user_name="bar",
family_name="bar", family_name="bar",
formatted_name="bar", formatted_name="bar",
) )
bar.save() backend.save(bar)
foo2 = backend.get(models.User, id=foo1.id) foo2 = backend.get(models.User, id=foo1.id)
assert foo1 == foo2 assert foo1 == foo2
@ -41,7 +41,7 @@ def test_model_lifecycle(testclient, backend):
assert not backend.query(models.User, id="invalid") assert not backend.query(models.User, id="invalid")
assert not backend.get(models.User, id=user.id) assert not backend.get(models.User, id=user.id)
user.save() backend.save(user)
assert backend.query(models.User) == [user] assert backend.query(models.User) == [user]
assert backend.query(models.User, id=user.id) == [user] assert backend.query(models.User, id=user.id) == [user]
@ -72,7 +72,7 @@ def test_model_attribute_edition(testclient, backend):
display_name="display_name", display_name="display_name",
emails=["email1@user.com", "email2@user.com"], emails=["email1@user.com", "email2@user.com"],
) )
user.save() backend.save(user)
assert user.user_name == "user_name" assert user.user_name == "user_name"
assert user.family_name == "family_name" assert user.family_name == "family_name"
@ -85,7 +85,7 @@ def test_model_attribute_edition(testclient, backend):
user.family_name = "new_family_name" user.family_name = "new_family_name"
user.emails = ["email1@user.com"] user.emails = ["email1@user.com"]
user.save() backend.save(user)
assert user.family_name == "new_family_name" assert user.family_name == "new_family_name"
assert user.emails == ["email1@user.com"] assert user.emails == ["email1@user.com"]
@ -97,7 +97,7 @@ def test_model_attribute_edition(testclient, backend):
user.display_name = "" user.display_name = ""
assert not user.display_name assert not user.display_name
user.save() backend.save(user)
assert not user.display_name assert not user.display_name
user.delete() user.delete()
@ -110,7 +110,7 @@ def test_model_indexation(testclient, backend):
formatted_name="formatted_name", formatted_name="formatted_name",
emails=["email1@user.com", "email2@user.com"], emails=["email1@user.com", "email2@user.com"],
) )
user.save() backend.save(user)
assert backend.get(models.User, family_name="family_name") == user assert backend.get(models.User, family_name="family_name") == user
assert not backend.get(models.User, family_name="new_family_name") assert not backend.get(models.User, family_name="new_family_name")
@ -125,7 +125,7 @@ def test_model_indexation(testclient, backend):
assert backend.get(models.User, emails=["email1@user.com"]) != user assert backend.get(models.User, emails=["email1@user.com"]) != user
assert not backend.get(models.User, emails=["email3@user.com"]) assert not backend.get(models.User, emails=["email3@user.com"])
user.save() backend.save(user)
assert not backend.get(models.User, family_name="family_name") assert not backend.get(models.User, family_name="family_name")
assert backend.get(models.User, family_name="new_family_name") == user assert backend.get(models.User, family_name="new_family_name") == user
@ -177,14 +177,14 @@ def test_model_references(testclient, user, foo_group, admin, bar_group, backend
assert user not in bar_group.members assert user not in bar_group.members
assert bar_group not in user.groups assert bar_group not in user.groups
user.groups = user.groups + [bar_group] user.groups = user.groups + [bar_group]
user.save() backend.save(user)
bar_group.reload() bar_group.reload()
assert user in bar_group.members assert user in bar_group.members
assert bar_group in user.groups assert bar_group in user.groups
bar_group.members = [admin] bar_group.members = [admin]
bar_group.save() backend.save(bar_group)
user.reload() user.reload()
assert user not in bar_group.members assert user not in bar_group.members
@ -201,7 +201,7 @@ def test_model_creation_edition_datetime(testclient, backend):
family_name="foo", family_name="foo",
formatted_name="foo", formatted_name="foo",
) )
user.save() backend.save(user)
assert user.created == datetime.datetime( assert user.created == datetime.datetime(
2020, 1, 1, 2, tzinfo=datetime.timezone.utc 2020, 1, 1, 2, tzinfo=datetime.timezone.utc
) )
@ -211,7 +211,7 @@ def test_model_creation_edition_datetime(testclient, backend):
with time_machine.travel("2021-01-01 02:00:00+00:00", tick=False): with time_machine.travel("2021-01-01 02:00:00+00:00", tick=False):
user.family_name = "bar" user.family_name = "bar"
user.save() backend.save(user)
assert user.created == datetime.datetime( assert user.created == datetime.datetime(
2020, 1, 1, 2, tzinfo=datetime.timezone.utc 2020, 1, 1, 2, tzinfo=datetime.timezone.utc
) )

View file

@ -191,7 +191,7 @@ def user(app, backend):
profile_url="https://john.example", profile_url="https://john.example",
formatted_address="1235, somewhere", formatted_address="1235, somewhere",
) )
u.save() backend.save(u)
yield u yield u
u.delete() u.delete()
@ -205,7 +205,7 @@ def admin(app, backend):
emails=["jane@doe.com"], emails=["jane@doe.com"],
password="admin", password="admin",
) )
u.save() backend.save(u)
yield u yield u
u.delete() u.delete()
@ -219,7 +219,7 @@ def moderator(app, backend):
emails=["jack@doe.com"], emails=["jack@doe.com"],
password="moderator", password="moderator",
) )
u.save() backend.save(u)
yield u yield u
u.delete() u.delete()
@ -251,7 +251,7 @@ def foo_group(app, user, backend):
members=[user], members=[user],
display_name="foo", display_name="foo",
) )
group.save() backend.save(group)
user.reload() user.reload()
yield group yield group
group.delete() group.delete()
@ -263,7 +263,7 @@ def bar_group(app, admin, backend):
members=[admin], members=[admin],
display_name="bar", display_name="bar",
) )
group.save() backend.save(group)
admin.reload() admin.reload()
yield group yield group
group.delete() group.delete()

View file

@ -27,7 +27,7 @@ def test_user_deleted_in_session(testclient, backend):
emails=["jake@doe.com"], emails=["jake@doe.com"],
password="correct horse battery staple", password="correct horse battery staple",
) )
u.save() backend.save(u)
testclient.get("/profile/jake", status=403) testclient.get("/profile/jake", status=403)
with testclient.session_transaction() as session: with testclient.session_transaction() as session:
@ -66,7 +66,7 @@ def test_admin_self_deletion(testclient, backend):
emails=["temp@temp.com"], emails=["temp@temp.com"],
password="admin", password="admin",
) )
admin.save() backend.save(admin)
with testclient.session_transaction() as sess: with testclient.session_transaction() as sess:
sess["user_id"] = [admin.id] sess["user_id"] = [admin.id]
@ -92,7 +92,7 @@ def test_user_self_deletion(testclient, backend):
emails=["temp@temp.com"], emails=["temp@temp.com"],
password="correct horse battery staple", password="correct horse battery staple",
) )
user.save() backend.save(user)
with testclient.session_transaction() as sess: with testclient.session_transaction() as sess:
sess["user_id"] = [user.id] sess["user_id"] = [user.id]
@ -134,7 +134,7 @@ def test_account_locking(user, backend):
user.lock_date = datetime.datetime.now(datetime.timezone.utc) user.lock_date = datetime.datetime.now(datetime.timezone.utc)
assert user.locked assert user.locked
user.save() backend.save(user)
assert user.locked assert user.locked
assert backend.get(models.User, id=user.id).locked assert backend.get(models.User, id=user.id).locked
assert backend.check_user_password(user, "correct horse battery staple") == ( assert backend.check_user_password(user, "correct horse battery staple") == (
@ -143,7 +143,7 @@ def test_account_locking(user, backend):
) )
user.lock_date = None user.lock_date = None
user.save() backend.save(user)
assert not user.locked assert not user.locked
assert not backend.get(models.User, id=user.id).locked assert not backend.get(models.User, id=user.id).locked
assert backend.check_user_password(user, "correct horse battery staple") == ( assert backend.check_user_password(user, "correct horse battery staple") == (
@ -163,7 +163,7 @@ def test_account_locking_past_date(user, backend):
user.lock_date = datetime.datetime.now(datetime.timezone.utc).replace( user.lock_date = datetime.datetime.now(datetime.timezone.utc).replace(
microsecond=0 microsecond=0
) - datetime.timedelta(days=30) ) - datetime.timedelta(days=30)
user.save() backend.save(user)
assert user.locked assert user.locked
assert backend.get(models.User, id=user.id).locked assert backend.get(models.User, id=user.id).locked
assert backend.check_user_password(user, "correct horse battery staple") == ( assert backend.check_user_password(user, "correct horse battery staple") == (
@ -183,7 +183,7 @@ def test_account_locking_future_date(user, backend):
user.lock_date = datetime.datetime.now(datetime.timezone.utc).replace( user.lock_date = datetime.datetime.now(datetime.timezone.utc).replace(
microsecond=0 microsecond=0
) + datetime.timedelta(days=365 * 4) ) + datetime.timedelta(days=365 * 4)
user.save() backend.save(user)
assert not user.locked assert not user.locked
assert not backend.get(models.User, id=user.id).locked assert not backend.get(models.User, id=user.id).locked
assert backend.check_user_password(user, "correct horse battery staple") == ( assert backend.check_user_password(user, "correct horse battery staple") == (
@ -192,7 +192,7 @@ def test_account_locking_future_date(user, backend):
) )
def test_account_locked_during_session(testclient, logged_user): def test_account_locked_during_session(testclient, logged_user, backend):
logged_user.lock_date = datetime.datetime.now(datetime.timezone.utc) logged_user.lock_date = datetime.datetime.now(datetime.timezone.utc)
logged_user.save() backend.save(logged_user)
testclient.get("/profile/user/settings", status=403) testclient.get("/profile/user/settings", status=403)

View file

@ -154,7 +154,7 @@ def test_user_without_password_first_login(testclient, backend, smtpd):
user_name="temp", user_name="temp",
emails=["john@doe.com", "johhny@doe.com"], emails=["john@doe.com", "johhny@doe.com"],
) )
u.save() backend.save(u)
res = testclient.get("/login", status=200) res = testclient.get("/login", status=200)
res.form["login"] = "temp" res.form["login"] = "temp"
@ -189,7 +189,7 @@ def test_first_login_account_initialization_mail_sending_failed(
user_name="temp", user_name="temp",
emails=["john@doe.com"], emails=["john@doe.com"],
) )
u.save() backend.save(u)
res = testclient.get("/firstlogin/temp") res = testclient.get("/firstlogin/temp")
res = res.form.submit(name="action", value="sendmail", expect_errors=True) res = res.form.submit(name="action", value="sendmail", expect_errors=True)
@ -211,7 +211,7 @@ def test_first_login_form_error(testclient, backend, smtpd):
user_name="temp", user_name="temp",
emails=["john@doe.com"], emails=["john@doe.com"],
) )
u.save() backend.save(u)
res = testclient.get("/firstlogin/temp", status=200) res = testclient.get("/firstlogin/temp", status=200)
res.form["csrf_token"] = "invalid" res.form["csrf_token"] = "invalid"
@ -236,7 +236,7 @@ def test_user_password_deleted_during_login(testclient, backend):
emails=["john@doe.com"], emails=["john@doe.com"],
password="correct horse battery staple", password="correct horse battery staple",
) )
u.save() backend.save(u)
res = testclient.get("/login") res = testclient.get("/login")
res.form["login"] = "temp" res.form["login"] = "temp"
@ -244,7 +244,7 @@ def test_user_password_deleted_during_login(testclient, backend):
res.form["password"] = "correct horse battery staple" res.form["password"] = "correct horse battery staple"
u.password = None u.password = None
u.save() backend.save(u)
res = res.form.submit(status=302) res = res.form.submit(status=302)
assert res.location == "/firstlogin/temp" assert res.location == "/firstlogin/temp"
@ -272,12 +272,12 @@ def test_wrong_login(testclient, user):
res.mustcontain("The login 'invalid' does not exist") res.mustcontain("The login 'invalid' does not exist")
def test_signin_locked_account(testclient, user): def test_signin_locked_account(testclient, user, backend):
with testclient.session_transaction() as session: with testclient.session_transaction() as session:
assert not session.get("user_id") assert not session.get("user_id")
user.lock_date = datetime.datetime.now(datetime.timezone.utc) user.lock_date = datetime.datetime.now(datetime.timezone.utc)
user.save() backend.save(user)
res = testclient.get("/login", status=200) res = testclient.get("/login", status=200)
res.form["login"] = "user" res.form["login"] = "user"
@ -289,4 +289,4 @@ def test_signin_locked_account(testclient, user):
res.mustcontain("Your account has been locked.") res.mustcontain("Your account has been locked.")
user.lock_date = None user.lock_date = None
user.save() backend.save(user)

View file

@ -371,14 +371,14 @@ def test_confirmation_email_already_used_link(testclient, backend, user, admin):
assert "new_email@mydomain.tld" not in user.emails assert "new_email@mydomain.tld" not in user.emails
def test_delete_email(testclient, logged_user): def test_delete_email(testclient, logged_user, backend):
"""Tests that user can deletes its emails unless they have only one """Tests that user can deletes its emails unless they have only one
left.""" left."""
res = testclient.get("/profile/user") res = testclient.get("/profile/user")
assert "email_remove" not in res.forms["emailconfirmationform"].fields assert "email_remove" not in res.forms["emailconfirmationform"].fields
logged_user.emails = logged_user.emails + ["new@email.com"] logged_user.emails = logged_user.emails + ["new@email.com"]
logged_user.save() backend.save(logged_user)
res = testclient.get("/profile/user") res = testclient.get("/profile/user")
assert "email_remove" in res.forms["emailconfirmationform"].fields assert "email_remove" in res.forms["emailconfirmationform"].fields
@ -391,10 +391,10 @@ def test_delete_email(testclient, logged_user):
assert logged_user.emails == ["john@doe.com"] assert logged_user.emails == ["john@doe.com"]
def test_delete_wrong_email(testclient, logged_user): def test_delete_wrong_email(testclient, logged_user, backend):
"""Tests that removing an already removed email do not produce anything.""" """Tests that removing an already removed email do not produce anything."""
logged_user.emails = logged_user.emails + ["new@email.com"] logged_user.emails = logged_user.emails + ["new@email.com"]
logged_user.save() backend.save(logged_user)
res = testclient.get("/profile/user") res = testclient.get("/profile/user")
@ -412,10 +412,10 @@ def test_delete_wrong_email(testclient, logged_user):
assert logged_user.emails == ["john@doe.com"] assert logged_user.emails == ["john@doe.com"]
def test_delete_last_email(testclient, logged_user): def test_delete_last_email(testclient, logged_user, backend):
"""Tests that users cannot remove their last email address.""" """Tests that users cannot remove their last email address."""
logged_user.emails = logged_user.emails + ["new@email.com"] logged_user.emails = logged_user.emails + ["new@email.com"]
logged_user.save() backend.save(logged_user)
res = testclient.get("/profile/user") res = testclient.get("/profile/user")

View file

@ -26,9 +26,9 @@ def test_password_forgotten(smtpd, testclient, user):
assert len(smtpd.messages) == 1 assert len(smtpd.messages) == 1
def test_password_forgotten_multiple_mails(smtpd, testclient, user): def test_password_forgotten_multiple_mails(smtpd, testclient, user, backend):
user.emails = ["foo@bar.com", "foo@baz.com", "foo@foo.com"] user.emails = ["foo@bar.com", "foo@baz.com", "foo@foo.com"]
user.save() backend.save(user)
res = testclient.get("/reset", status=200) res = testclient.get("/reset", status=200)

View file

@ -53,13 +53,13 @@ def test_group_deletion(testclient, backend):
user_name="foobar", user_name="foobar",
emails=["foo@bar.com"], emails=["foo@bar.com"],
) )
user.save() backend.save(user)
group = models.Group( group = models.Group(
members=[user], members=[user],
display_name="foobar", display_name="foobar",
) )
group.save() backend.save(group)
user.reload() user.reload()
assert user.groups == [group] assert user.groups == [group]
@ -86,19 +86,19 @@ def test_group_list_search(testclient, logged_admin, foo_group, bar_group):
res.mustcontain(no=bar_group.display_name) res.mustcontain(no=bar_group.display_name)
def test_set_groups(app, user, foo_group, bar_group): def test_set_groups(app, user, foo_group, bar_group, backend):
assert user in foo_group.members assert user in foo_group.members
assert user.groups == [foo_group] assert user.groups == [foo_group]
user.groups = [foo_group, bar_group] user.groups = [foo_group, bar_group]
user.save() backend.save(user)
bar_group.reload() bar_group.reload()
assert user in bar_group.members assert user in bar_group.members
assert bar_group in user.groups assert bar_group in user.groups
user.groups = [foo_group] user.groups = [foo_group]
user.save() backend.save(user)
foo_group.reload() foo_group.reload()
bar_group.reload() bar_group.reload()
@ -106,23 +106,23 @@ def test_set_groups(app, user, foo_group, bar_group):
assert user not in bar_group.members assert user not in bar_group.members
def test_set_groups_with_leading_space_in_user_id_attribute(app, foo_group): def test_set_groups_with_leading_space_in_user_id_attribute(app, foo_group, backend):
user = models.User( user = models.User(
formatted_name=" Doe", # leading space in id attribute formatted_name=" Doe", # leading space in id attribute
family_name="Doe", family_name="Doe",
user_name="user2", user_name="user2",
emails=["john@doe.com"], emails=["john@doe.com"],
) )
user.save() backend.save(user)
user.groups = [foo_group] user.groups = [foo_group]
user.save() backend.save(user)
foo_group.reload() foo_group.reload()
assert user in foo_group.members assert user in foo_group.members
user.groups = [] user.groups = []
user.save() backend.save(user)
foo_group.reload() foo_group.reload()
assert user.id not in foo_group.members assert user.id not in foo_group.members
@ -231,14 +231,14 @@ def test_edition_failed(testclient, logged_moderator, foo_group):
assert foo_group.display_name == "foo" assert foo_group.display_name == "foo"
def test_user_list_pagination(testclient, logged_admin, foo_group): def test_user_list_pagination(testclient, logged_admin, foo_group, backend):
res = testclient.get("/groups/foo") res = testclient.get("/groups/foo")
res.mustcontain("1 item") res.mustcontain("1 item")
users = fake_users(25) users = fake_users(25)
for user in users: for user in users:
foo_group.members = foo_group.members + [user] foo_group.members = foo_group.members + [user]
foo_group.save() backend.save(foo_group)
assert len(foo_group.members) == 26 assert len(foo_group.members) == 26
res = testclient.get("/groups/foo") res = testclient.get("/groups/foo")
@ -274,9 +274,11 @@ def test_user_list_bad_pages(testclient, logged_admin, foo_group):
) )
def test_user_list_search(testclient, logged_admin, foo_group, user, moderator): def test_user_list_search(
testclient, logged_admin, foo_group, user, moderator, backend
):
foo_group.members = foo_group.members + [logged_admin, moderator] foo_group.members = foo_group.members + [logged_admin, moderator]
foo_group.save() backend.save(foo_group)
res = testclient.get("/groups/foo") res = testclient.get("/groups/foo")
res.mustcontain("3 items") res.mustcontain("3 items")
@ -294,9 +296,9 @@ def test_user_list_search(testclient, logged_admin, foo_group, user, moderator):
res.mustcontain(no=moderator.formatted_name) res.mustcontain(no=moderator.formatted_name)
def test_remove_member(testclient, logged_admin, foo_group, user, moderator): def test_remove_member(testclient, logged_admin, foo_group, user, moderator, backend):
foo_group.members = [user, moderator] foo_group.members = [user, moderator]
foo_group.save() backend.save(foo_group)
res = testclient.get("/groups/foo") res = testclient.get("/groups/foo")
form = res.forms[f"deletegroupmemberform-{user.id}"] form = res.forms[f"deletegroupmemberform-{user.id}"]
@ -313,15 +315,15 @@ def test_remove_member(testclient, logged_admin, foo_group, user, moderator):
def test_remove_member_already_remove_from_group( def test_remove_member_already_remove_from_group(
testclient, logged_admin, foo_group, user, moderator testclient, logged_admin, foo_group, user, moderator, backend
): ):
foo_group.members = [user, moderator] foo_group.members = [user, moderator]
foo_group.save() backend.save(foo_group)
res = testclient.get("/groups/foo") res = testclient.get("/groups/foo")
form = res.forms[f"deletegroupmemberform-{user.id}"] form = res.forms[f"deletegroupmemberform-{user.id}"]
foo_group.members = [moderator] foo_group.members = [moderator]
foo_group.save() backend.save(foo_group)
res = form.submit(name="action", value="confirm-remove-member") res = form.submit(name="action", value="confirm-remove-member")
assert ( assert (
@ -331,17 +333,17 @@ def test_remove_member_already_remove_from_group(
def test_confirm_remove_member_already_removed_from_group( def test_confirm_remove_member_already_removed_from_group(
testclient, logged_admin, foo_group, user, moderator testclient, logged_admin, foo_group, user, moderator, backend
): ):
foo_group.members = [user, moderator] foo_group.members = [user, moderator]
foo_group.save() backend.save(foo_group)
res = testclient.get("/groups/foo") res = testclient.get("/groups/foo")
form = res.forms[f"deletegroupmemberform-{user.id}"] form = res.forms[f"deletegroupmemberform-{user.id}"]
res = form.submit(name="action", value="confirm-remove-member") res = form.submit(name="action", value="confirm-remove-member")
foo_group.members = [moderator] foo_group.members = [moderator]
foo_group.save() backend.save(foo_group)
res = res.form.submit(name="action", value="remove-member") res = res.form.submit(name="action", value="remove-member")
assert ( assert (
@ -351,10 +353,10 @@ def test_confirm_remove_member_already_removed_from_group(
def test_remove_member_already_deleted( def test_remove_member_already_deleted(
testclient, logged_admin, foo_group, user, moderator testclient, logged_admin, foo_group, user, moderator, backend
): ):
foo_group.members = [user, moderator] foo_group.members = [user, moderator]
foo_group.save() backend.save(foo_group)
res = testclient.get("/groups/foo") res = testclient.get("/groups/foo")
form = res.forms[f"deletegroupmemberform-{user.id}"] form = res.forms[f"deletegroupmemberform-{user.id}"]
@ -368,10 +370,10 @@ def test_remove_member_already_deleted(
def test_confirm_remove_member_already_deleted( def test_confirm_remove_member_already_deleted(
testclient, logged_admin, foo_group, user, moderator testclient, logged_admin, foo_group, user, moderator, backend
): ):
foo_group.members = [user, moderator] foo_group.members = [user, moderator]
foo_group.save() backend.save(foo_group)
res = testclient.get("/groups/foo") res = testclient.get("/groups/foo")
form = res.forms[f"deletegroupmemberform-{user.id}"] form = res.forms[f"deletegroupmemberform-{user.id}"]

View file

@ -13,13 +13,13 @@ def test_user_has_password(testclient, backend):
user_name="temp", user_name="temp",
emails=["john@doe.com"], emails=["john@doe.com"],
) )
user.save() backend.save(user)
assert user.password is None assert user.password is None
assert not user.has_password() assert not user.has_password()
user.password = "foobar" user.password = "foobar"
user.save() backend.save(user)
assert user.password is not None assert user.password is not None
assert user.has_password() assert user.has_password()

View file

@ -25,7 +25,7 @@ def test_password_reset(testclient, user, backend):
def test_password_reset_multiple_emails(testclient, user, backend): def test_password_reset_multiple_emails(testclient, user, backend):
user.emails = ["foo@bar.com", "foo@baz.com"] user.emails = ["foo@bar.com", "foo@baz.com"]
user.save() backend.save(user)
assert not backend.check_user_password(user, "foobarbaz")[0] assert not backend.check_user_password(user, "foobarbaz")[0]
hash = build_hash("user", "foo@baz.com", user.password) hash = build_hash("user", "foo@baz.com", user.password)

View file

@ -118,6 +118,7 @@ def test_edition(
logged_user, logged_user,
admin, admin,
jpeg_photo, jpeg_photo,
backend,
): ):
res = testclient.get("/profile/user", status=200) res = testclient.get("/profile/user", status=200)
form = res.forms["baseform"] form = res.forms["baseform"]
@ -168,13 +169,14 @@ def test_edition(
logged_user.emails = ["john@doe.com"] logged_user.emails = ["john@doe.com"]
logged_user.given_name = None logged_user.given_name = None
logged_user.photo = None logged_user.photo = None
logged_user.save() backend.save(logged_user)
def test_edition_remove_fields( def test_edition_remove_fields(
testclient, testclient,
logged_user, logged_user,
admin, admin,
backend,
): ):
res = testclient.get("/profile/user", status=200) res = testclient.get("/profile/user", status=200)
form = res.forms["baseform"] form = res.forms["baseform"]
@ -195,13 +197,13 @@ def test_edition_remove_fields(
logged_user.emails = ["john@doe.com"] logged_user.emails = ["john@doe.com"]
logged_user.given_name = None logged_user.given_name = None
logged_user.photo = None logged_user.photo = None
logged_user.save() backend.save(logged_user)
def test_field_permissions_none(testclient, logged_user): def test_field_permissions_none(testclient, logged_user, backend):
testclient.get("/profile/user", status=200) testclient.get("/profile/user", status=200)
logged_user.phone_numbers = ["555-666-777"] logged_user.phone_numbers = ["555-666-777"]
logged_user.save() backend.save(logged_user)
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"] = { testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"] = {
"READ": ["user_name"], "READ": ["user_name"],
@ -227,10 +229,10 @@ def test_field_permissions_none(testclient, logged_user):
assert logged_user.phone_numbers == ["555-666-777"] assert logged_user.phone_numbers == ["555-666-777"]
def test_field_permissions_read(testclient, logged_user): def test_field_permissions_read(testclient, logged_user, backend):
testclient.get("/profile/user", status=200) testclient.get("/profile/user", status=200)
logged_user.phone_numbers = ["555-666-777"] logged_user.phone_numbers = ["555-666-777"]
logged_user.save() backend.save(logged_user)
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"] = { testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"] = {
"READ": ["user_name", "phone_numbers"], "READ": ["user_name", "phone_numbers"],
@ -255,10 +257,10 @@ def test_field_permissions_read(testclient, logged_user):
assert logged_user.phone_numbers == ["555-666-777"] assert logged_user.phone_numbers == ["555-666-777"]
def test_field_permissions_write(testclient, logged_user): def test_field_permissions_write(testclient, logged_user, backend):
testclient.get("/profile/user", status=200) testclient.get("/profile/user", status=200)
logged_user.phone_numbers = ["555-666-777"] logged_user.phone_numbers = ["555-666-777"]
logged_user.save() backend.save(logged_user)
testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"] = { testclient.app.config["CANAILLE"]["ACL"]["DEFAULT"] = {
"READ": ["user_name"], "READ": ["user_name"],

View file

@ -5,9 +5,9 @@ from webtest import Upload
from canaille.app import models from canaille.app import models
def test_photo(testclient, user, jpeg_photo): def test_photo(testclient, user, jpeg_photo, backend):
user.photo = jpeg_photo user.photo = jpeg_photo
user.save() backend.save(user)
user.reload() user.reload()
res = testclient.get("/profile/user/photo") res = testclient.get("/profile/user/photo")

View file

@ -34,13 +34,13 @@ def test_edition(testclient, logged_user, admin, foo_group, bar_group, backend):
assert backend.check_user_password(logged_user, "correct horse battery staple")[0] assert backend.check_user_password(logged_user, "correct horse battery staple")[0]
logged_user.user_name = "user" logged_user.user_name = "user"
logged_user.save() backend.save(logged_user)
def test_group_removal(testclient, logged_admin, user, foo_group, backend): def test_group_removal(testclient, logged_admin, user, foo_group, backend):
"""Tests that one can remove a group from a user.""" """Tests that one can remove a group from a user."""
foo_group.members = [user, logged_admin] foo_group.members = [user, logged_admin]
foo_group.save() backend.save(foo_group)
user.reload() user.reload()
assert foo_group in user.groups assert foo_group in user.groups
@ -115,7 +115,7 @@ def test_edition_without_groups(
assert backend.check_user_password(logged_user, "correct horse battery staple")[0] assert backend.check_user_password(logged_user, "correct horse battery staple")[0]
logged_user.user_name = "user" logged_user.user_name = "user"
logged_user.save() backend.save(logged_user)
def test_password_change(testclient, logged_user, backend): def test_password_change(testclient, logged_user, backend):
@ -171,7 +171,7 @@ def test_password_initialization_mail(smtpd, testclient, backend, logged_admin):
user_name="temp", user_name="temp",
emails=["john@doe.com"], emails=["john@doe.com"],
) )
u.save() backend.save(u)
res = testclient.get("/profile/temp/settings", status=200) res = testclient.get("/profile/temp/settings", status=200)
res.mustcontain("This user does not have a password yet") res.mustcontain("This user does not have a password yet")
@ -188,7 +188,7 @@ def test_password_initialization_mail(smtpd, testclient, backend, logged_admin):
u.reload() u.reload()
u.password = "correct horse battery staple" u.password = "correct horse battery staple"
u.save() backend.save(u)
res = testclient.get("/profile/temp/settings", status=200) res = testclient.get("/profile/temp/settings", status=200)
res.mustcontain(no="This user does not have a password yet") res.mustcontain(no="This user does not have a password yet")
@ -207,7 +207,7 @@ def test_password_initialization_mail_send_fail(
user_name="temp", user_name="temp",
emails=["john@doe.com"], emails=["john@doe.com"],
) )
u.save() backend.save(u)
res = testclient.get("/profile/temp/settings", status=200) res = testclient.get("/profile/temp/settings", status=200)
res.mustcontain("This user does not have a password yet") res.mustcontain("This user does not have a password yet")
@ -272,7 +272,7 @@ def test_impersonate_locked_user(testclient, backend, logged_admin, user):
user.lock_date = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta( user.lock_date = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
days=1 days=1
) )
user.save() backend.save(user)
assert user.locked assert user.locked
res = testclient.get("/profile/user/settings") res = testclient.get("/profile/user/settings")
@ -295,7 +295,7 @@ def test_password_reset_email(smtpd, testclient, backend, logged_admin):
emails=["john@doe.com"], emails=["john@doe.com"],
password="correct horse battery staple", password="correct horse battery staple",
) )
u.save() backend.save(u)
res = testclient.get("/profile/temp/settings", status=200) res = testclient.get("/profile/temp/settings", status=200)
res.mustcontain("If the user has forgotten his password") res.mustcontain("If the user has forgotten his password")
@ -323,7 +323,7 @@ def test_password_reset_email_failed(SMTP, smtpd, testclient, backend, logged_ad
emails=["john@doe.com"], emails=["john@doe.com"],
password="correct horse battery staple", password="correct horse battery staple",
) )
u.save() backend.save(u)
res = testclient.get("/profile/temp/settings", status=200) res = testclient.get("/profile/temp/settings", status=200)
res.mustcontain("If the user has forgotten his password") res.mustcontain("If the user has forgotten his password")
@ -454,7 +454,7 @@ def test_empty_lock_date(
second=0, microsecond=0 second=0, microsecond=0
) + datetime.timedelta(days=30) ) + datetime.timedelta(days=30)
user.lock_date = expiration_datetime user.lock_date = expiration_datetime
user.save() backend.save(user)
res = testclient.get("/profile/user/settings", status=200) res = testclient.get("/profile/user/settings", status=200)
res.form["lock_date"] = "" res.form["lock_date"] = ""

View file

@ -21,7 +21,7 @@ def test_clean_command(testclient, backend, client, user):
challenge="challenge", challenge="challenge",
challenge_method="method", challenge_method="method",
) )
valid_code.save() backend.save(valid_code)
expired_code = models.AuthorizationCode( expired_code = models.AuthorizationCode(
authorization_code_id=gen_salt(48), authorization_code_id=gen_salt(48),
code="my-expired-code", code="my-expired-code",
@ -39,7 +39,7 @@ def test_clean_command(testclient, backend, client, user):
challenge="challenge", challenge="challenge",
challenge_method="method", challenge_method="method",
) )
expired_code.save() backend.save(expired_code)
valid_token = models.Token( valid_token = models.Token(
token_id=gen_salt(48), token_id=gen_salt(48),
@ -53,7 +53,7 @@ def test_clean_command(testclient, backend, client, user):
), ),
lifetime=3600, lifetime=3600,
) )
valid_token.save() backend.save(valid_token)
expired_token = models.Token( expired_token = models.Token(
token_id=gen_salt(48), token_id=gen_salt(48),
access_token="my-expired-token", access_token="my-expired-token",
@ -67,7 +67,7 @@ def test_clean_command(testclient, backend, client, user):
), ),
lifetime=3600, lifetime=3600,
) )
expired_token.save() backend.save(expired_token)
assert backend.get(models.AuthorizationCode, code="my-expired-code") assert backend.get(models.AuthorizationCode, code="my-expired-code")
assert backend.get(models.Token, access_token="my-expired-token") assert backend.get(models.Token, access_token="my-expired-token")

View file

@ -67,9 +67,9 @@ def client(testclient, trusted_client, backend):
token_endpoint_auth_method="client_secret_basic", token_endpoint_auth_method="client_secret_basic",
post_logout_redirect_uris=["https://mydomain.tld/disconnected"], post_logout_redirect_uris=["https://mydomain.tld/disconnected"],
) )
c.save() backend.save(c)
c.audience = [c, trusted_client] c.audience = [c, trusted_client]
c.save() backend.save(c)
yield c yield c
c.delete() c.delete()
@ -105,9 +105,9 @@ def trusted_client(testclient, backend):
post_logout_redirect_uris=["https://myotherdomain.tld/disconnected"], post_logout_redirect_uris=["https://myotherdomain.tld/disconnected"],
preconsent=True, preconsent=True,
) )
c.save() backend.save(c)
c.audience = [c] c.audience = [c]
c.save() backend.save(c)
yield c yield c
c.delete() c.delete()
@ -129,7 +129,7 @@ def authorization(testclient, user, client, backend):
challenge="challenge", challenge="challenge",
challenge_method="method", challenge_method="method",
) )
a.save() backend.save(a)
yield a yield a
a.delete() a.delete()
@ -147,7 +147,7 @@ def token(testclient, client, user, backend):
issue_date=datetime.datetime.now(datetime.timezone.utc), issue_date=datetime.datetime.now(datetime.timezone.utc),
lifetime=3600, lifetime=3600,
) )
t.save() backend.save(t)
yield t yield t
t.delete() t.delete()
@ -171,6 +171,6 @@ def consent(testclient, client, user, backend):
scope=["openid", "profile"], scope=["openid", "profile"],
issue_date=datetime.datetime.now(datetime.timezone.utc), issue_date=datetime.datetime.now(datetime.timezone.utc),
) )
t.save() backend.save(t)
yield t yield t
t.delete() t.delete()

View file

@ -167,7 +167,7 @@ def test_preconsented_client(
assert not backend.query(models.Consent) assert not backend.query(models.Consent)
client.preconsent = True client.preconsent = True
client.save() backend.save(client)
res = testclient.get( res = testclient.get(
"/oauth/authorize", "/oauth/authorize",
@ -318,7 +318,7 @@ def test_code_challenge(testclient, logged_user, client, backend):
assert not backend.query(models.Consent) assert not backend.query(models.Consent)
client.token_endpoint_auth_method = "none" client.token_endpoint_auth_method = "none"
client.save() backend.save(client)
code_verifier = gen_salt(48) code_verifier = gen_salt(48)
code_challenge = create_s256_code_challenge(code_verifier) code_challenge = create_s256_code_challenge(code_verifier)
@ -373,7 +373,7 @@ def test_code_challenge(testclient, logged_user, client, backend):
assert res.json["name"] == "John (johnny) Doe" assert res.json["name"] == "John (johnny) Doe"
client.token_endpoint_auth_method = "client_secret_basic" client.token_endpoint_auth_method = "client_secret_basic"
client.save() backend.save(client)
for consent in consents: for consent in consents:
consent.delete() consent.delete()
@ -565,7 +565,7 @@ def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client, b
def test_request_scope_too_large(testclient, logged_user, keypair, client, backend): def test_request_scope_too_large(testclient, logged_user, keypair, client, backend):
assert not backend.query(models.Consent) assert not backend.query(models.Consent)
client.scope = ["openid", "profile", "groups"] client.scope = ["openid", "profile", "groups"]
client.save() backend.save(client)
res = testclient.get( res = testclient.get(
"/oauth/authorize", "/oauth/authorize",
@ -679,7 +679,7 @@ def test_code_with_invalid_user(testclient, admin, client, backend):
emails=["temp@temp.com"], emails=["temp@temp.com"],
password="correct horse battery staple", password="correct horse battery staple",
) )
user.save() backend.save(user)
res = testclient.get( res = testclient.get(
"/oauth/authorize", "/oauth/authorize",
@ -738,7 +738,7 @@ def test_locked_account(
) )
logged_user.lock_date = datetime.datetime.now(datetime.timezone.utc) logged_user.lock_date = datetime.datetime.now(datetime.timezone.utc)
logged_user.save() backend.save(logged_user)
res = res.form.submit(name="answer", value="accept", status=302) res = res.form.submit(name="answer", value="accept", status=302)

View file

@ -14,7 +14,7 @@ from canaille.app import models
from canaille.core.endpoints.account import RegistrationPayload from canaille.core.endpoints.account import RegistrationPayload
def test_prompt_none(testclient, logged_user, client): def test_prompt_none(testclient, logged_user, client, backend):
"""Nominal case with prompt=none.""" """Nominal case with prompt=none."""
consent = models.Consent( consent = models.Consent(
consent_id=str(uuid.uuid4()), consent_id=str(uuid.uuid4()),
@ -22,7 +22,7 @@ def test_prompt_none(testclient, logged_user, client):
subject=logged_user, subject=logged_user,
scope=["openid", "profile"], scope=["openid", "profile"],
) )
consent.save() backend.save(consent)
res = testclient.get( res = testclient.get(
"/oauth/authorize", "/oauth/authorize",
@ -42,7 +42,7 @@ def test_prompt_none(testclient, logged_user, client):
consent.delete() consent.delete()
def test_prompt_not_logged(testclient, user, client): def test_prompt_not_logged(testclient, user, client, backend):
"""Prompt=none should return a login_required error when no user is logged """Prompt=none should return a login_required error when no user is logged
in. in.
@ -58,7 +58,7 @@ def test_prompt_not_logged(testclient, user, client):
subject=user, subject=user,
scope=["openid", "profile"], scope=["openid", "profile"],
) )
consent.save() backend.save(consent)
res = testclient.get( res = testclient.get(
"/oauth/authorize", "/oauth/authorize",
@ -100,7 +100,7 @@ def test_prompt_no_consent(testclient, logged_user, client):
assert "consent_required" == res.json.get("error") assert "consent_required" == res.json.get("error")
def test_prompt_create_logged(testclient, logged_user, client): def test_prompt_create_logged(testclient, logged_user, client, backend):
"""If prompt=create and user is already logged in, then go straight to the """If prompt=create and user is already logged in, then go straight to the
consent page.""" consent page."""
testclient.app.config["CANAILLE"]["ENABLE_REGISTRATION"] = True testclient.app.config["CANAILLE"]["ENABLE_REGISTRATION"] = True
@ -111,7 +111,7 @@ def test_prompt_create_logged(testclient, logged_user, client):
subject=logged_user, subject=logged_user,
scope=["openid", "profile"], scope=["openid", "profile"],
) )
consent.save() backend.save(consent)
res = testclient.get( res = testclient.get(
"/oauth/authorize", "/oauth/authorize",

View file

@ -22,13 +22,15 @@ def test_client_list(testclient, client, logged_admin):
res.mustcontain(client.client_name) res.mustcontain(client.client_name)
def test_client_list_pagination(testclient, logged_admin, client, trusted_client): def test_client_list_pagination(
testclient, logged_admin, client, trusted_client, backend
):
res = testclient.get("/admin/client") res = testclient.get("/admin/client")
res.mustcontain("2 items") res.mustcontain("2 items")
clients = [] clients = []
for _ in range(25): for _ in range(25):
client = models.Client(client_id=gen_salt(48), client_name=gen_salt(48)) client = models.Client(client_id=gen_salt(48), client_name=gen_salt(48))
client.save() backend.save(client)
clients.append(client) clients.append(client)
res = testclient.get("/admin/client") res = testclient.get("/admin/client")
@ -216,22 +218,22 @@ def test_client_edit_missing_fields(testclient, client, logged_admin, trusted_cl
def test_client_delete(testclient, logged_admin, backend): def test_client_delete(testclient, logged_admin, backend):
client = models.Client(client_id="client_id") client = models.Client(client_id="client_id")
client.save() backend.save(client)
token = models.Token( token = models.Token(
token_id="id", token_id="id",
client=client, client=client,
subject=logged_admin, subject=logged_admin,
issue_date=datetime.datetime.now(datetime.timezone.utc), issue_date=datetime.datetime.now(datetime.timezone.utc),
) )
token.save() backend.save(token)
consent = models.Consent( consent = models.Consent(
consent_id="consent_id", subject=logged_admin, client=client, scope=["openid"] consent_id="consent_id", subject=logged_admin, client=client, scope=["openid"]
) )
consent.save() backend.save(consent)
authorization_code = models.AuthorizationCode( authorization_code = models.AuthorizationCode(
authorization_code_id="id", client=client, subject=logged_admin authorization_code_id="id", client=client, subject=logged_admin
) )
authorization_code.save() backend.save(authorization_code)
res = testclient.get("/admin/client/edit/" + client.client_id) res = testclient.get("/admin/client/edit/" + client.client_id)
res = res.forms["clientaddform"].submit(name="action", value="confirm-delete") res = res.forms["clientaddform"].submit(name="action", value="confirm-delete")

View file

@ -16,7 +16,7 @@ def test_authorizaton_list(testclient, authorization, logged_admin):
res.mustcontain(authorization.authorization_code_id) res.mustcontain(authorization.authorization_code_id)
def test_authorization_list_pagination(testclient, logged_admin, client): def test_authorization_list_pagination(testclient, logged_admin, client, backend):
res = testclient.get("/admin/authorization") res = testclient.get("/admin/authorization")
res.mustcontain("0 items") res.mustcontain("0 items")
authorizations = [] authorizations = []
@ -24,7 +24,7 @@ def test_authorization_list_pagination(testclient, logged_admin, client):
code = models.AuthorizationCode( code = models.AuthorizationCode(
authorization_code_id=gen_salt(48), client=client, subject=logged_admin authorization_code_id=gen_salt(48), client=client, subject=logged_admin
) )
code.save() backend.save(code)
authorizations.append(code) authorizations.append(code)
res = testclient.get("/admin/authorization") res = testclient.get("/admin/authorization")
@ -64,18 +64,18 @@ def test_authorization_list_bad_pages(testclient, logged_admin):
) )
def test_authorization_list_search(testclient, logged_admin, client): def test_authorization_list_search(testclient, logged_admin, client, backend):
id1 = gen_salt(48) id1 = gen_salt(48)
auth1 = models.AuthorizationCode( auth1 = models.AuthorizationCode(
authorization_code_id=id1, client=client, subject=logged_admin authorization_code_id=id1, client=client, subject=logged_admin
) )
auth1.save() backend.save(auth1)
id2 = gen_salt(48) id2 = gen_salt(48)
auth2 = models.AuthorizationCode( auth2 = models.AuthorizationCode(
authorization_code_id=id2, client=client, subject=logged_admin authorization_code_id=id2, client=client, subject=logged_admin
) )
auth2.save() backend.save(auth2)
res = testclient.get("/admin/authorization") res = testclient.get("/admin/authorization")
res.mustcontain("2 items") res.mustcontain("2 items")

View file

@ -139,13 +139,15 @@ def test_oidc_authorization_after_revokation(
assert token.subject == logged_user assert token.subject == logged_user
def test_preconsented_client_appears_in_consent_list(testclient, client, logged_user): def test_preconsented_client_appears_in_consent_list(
testclient, client, logged_user, backend
):
assert not client.preconsent assert not client.preconsent
res = testclient.get("/consent/pre-consents") res = testclient.get("/consent/pre-consents")
res.mustcontain(no=client.client_name) res.mustcontain(no=client.client_name)
client.preconsent = True client.preconsent = True
client.save() backend.save(client)
res = testclient.get("/consent/pre-consents") res = testclient.get("/consent/pre-consents")
res.mustcontain(client.client_name) res.mustcontain(client.client_name)
@ -153,7 +155,7 @@ def test_preconsented_client_appears_in_consent_list(testclient, client, logged_
def test_revoke_preconsented_client(testclient, client, logged_user, token, backend): def test_revoke_preconsented_client(testclient, client, logged_user, token, backend):
client.preconsent = True client.preconsent = True
client.save() backend.save(client)
assert not backend.get(models.Consent) assert not backend.get(models.Consent)
assert not token.revoked assert not token.revoked
@ -190,22 +192,22 @@ def test_revoke_invalid_preconsented_client(testclient, logged_user):
def test_revoke_preconsented_client_with_manual_consent( def test_revoke_preconsented_client_with_manual_consent(
testclient, logged_user, client, consent testclient, logged_user, client, consent, backend
): ):
client.preconsent = True client.preconsent = True
client.save() backend.save(client)
res = testclient.get(f"/consent/revoke-preconsent/{client.client_id}", status=302) res = testclient.get(f"/consent/revoke-preconsent/{client.client_id}", status=302)
res = res.follow() res = res.follow()
assert ("success", "The access has been revoked") in res.flashes assert ("success", "The access has been revoked") in res.flashes
def test_revoke_preconsented_client_with_manual_revokation( def test_revoke_preconsented_client_with_manual_revokation(
testclient, logged_user, client, consent testclient, logged_user, client, consent, backend
): ):
client.preconsent = True client.preconsent = True
client.save() backend.save(client)
consent.revoke() consent.revoke()
consent.save() backend.save(consent)
res = testclient.get(f"/consent/revoke-preconsent/{client.client_id}", status=302) res = testclient.get(f"/consent/revoke-preconsent/{client.client_id}", status=302)
res = res.follow() res = res.follow()

View file

@ -146,7 +146,7 @@ def test_delete(testclient, backend, user):
] ]
client = models.Client(client_id="foobar", client_name="Some client") client = models.Client(client_id="foobar", client_name="Some client")
client.save() backend.save(client)
headers = {"Authorization": "Bearer static-token"} headers = {"Authorization": "Bearer static-token"}
with warnings.catch_warnings(record=True): with warnings.catch_warnings(record=True):

View file

@ -10,7 +10,7 @@ def test_oauth_implicit(testclient, user, client, backend):
client.grant_types = ["token"] client.grant_types = ["token"]
client.token_endpoint_auth_method = "none" client.token_endpoint_auth_method = "none"
client.save() backend.save(client)
res = testclient.get( res = testclient.get(
"/oauth/authorize", "/oauth/authorize",
@ -48,14 +48,14 @@ def test_oauth_implicit(testclient, user, client, backend):
client.grant_types = ["code"] client.grant_types = ["code"]
client.token_endpoint_auth_method = "client_secret_basic" client.token_endpoint_auth_method = "client_secret_basic"
client.save() backend.save(client)
def test_oidc_implicit(testclient, keypair, user, client, trusted_client, backend): def test_oidc_implicit(testclient, keypair, user, client, trusted_client, backend):
client.grant_types = ["token id_token"] client.grant_types = ["token id_token"]
client.token_endpoint_auth_method = "none" client.token_endpoint_auth_method = "none"
client.save() backend.save(client)
res = testclient.get( res = testclient.get(
"/oauth/authorize", "/oauth/authorize",
@ -101,7 +101,7 @@ def test_oidc_implicit(testclient, keypair, user, client, trusted_client, backen
client.grant_types = ["code"] client.grant_types = ["code"]
client.token_endpoint_auth_method = "client_secret_basic" client.token_endpoint_auth_method = "client_secret_basic"
client.save() backend.save(client)
def test_oidc_implicit_with_group( def test_oidc_implicit_with_group(
@ -110,7 +110,7 @@ def test_oidc_implicit_with_group(
client.grant_types = ["token id_token"] client.grant_types = ["token id_token"]
client.token_endpoint_auth_method = "none" client.token_endpoint_auth_method = "none"
client.save() backend.save(client)
res = testclient.get( res = testclient.get(
"/oauth/authorize", "/oauth/authorize",
@ -157,4 +157,4 @@ def test_oidc_implicit_with_group(
client.grant_types = ["code"] client.grant_types = ["code"]
client.token_endpoint_auth_method = "client_secret_basic" client.token_endpoint_auth_method = "client_secret_basic"
client.save() backend.save(client)

View file

@ -33,7 +33,7 @@ def test_password_flow_basic(testclient, user, client, backend):
def test_password_flow_post(testclient, user, client, backend): def test_password_flow_post(testclient, user, client, backend):
client.token_endpoint_auth_method = "client_secret_post" client.token_endpoint_auth_method = "client_secret_post"
client.save() backend.save(client)
res = testclient.post( res = testclient.post(
"/oauth/token", "/oauth/token",

View file

@ -82,7 +82,7 @@ def test_refresh_token_with_invalid_user(testclient, client, backend):
emails=["temp@temp.com"], emails=["temp@temp.com"],
password="correct horse battery staple", password="correct horse battery staple",
) )
user.save() backend.save(user)
res = testclient.get( res = testclient.get(
"/oauth/authorize", "/oauth/authorize",
@ -137,7 +137,9 @@ def test_refresh_token_with_invalid_user(testclient, client, backend):
backend.get(models.Token, access_token=access_token).delete() backend.get(models.Token, access_token=access_token).delete()
def test_cannot_refresh_token_for_locked_users(testclient, logged_user, client): def test_cannot_refresh_token_for_locked_users(
testclient, logged_user, client, backend
):
"""Canaille should not issue new tokens for locked users.""" """Canaille should not issue new tokens for locked users."""
res = testclient.get( res = testclient.get(
"/oauth/authorize", "/oauth/authorize",
@ -167,7 +169,7 @@ def test_cannot_refresh_token_for_locked_users(testclient, logged_user, client):
) )
logged_user.lock_date = datetime.datetime.now(datetime.timezone.utc) logged_user.lock_date = datetime.datetime.now(datetime.timezone.utc)
logged_user.save() backend.save(logged_user)
res = testclient.post( res = testclient.post(
"/oauth/token", "/oauth/token",

View file

@ -18,7 +18,7 @@ def test_token_list(testclient, token, logged_admin):
res.mustcontain(token.token_id) res.mustcontain(token.token_id)
def test_token_list_pagination(testclient, logged_admin, client): def test_token_list_pagination(testclient, logged_admin, client, backend):
res = testclient.get("/admin/token") res = testclient.get("/admin/token")
res.mustcontain("0 items") res.mustcontain("0 items")
tokens = [] tokens = []
@ -36,7 +36,7 @@ def test_token_list_pagination(testclient, logged_admin, client):
), ),
lifetime=3600, lifetime=3600,
) )
token.save() backend.save(token)
tokens.append(token) tokens.append(token)
res = testclient.get("/admin/token") res = testclient.get("/admin/token")
@ -74,7 +74,7 @@ def test_token_list_bad_pages(testclient, logged_admin):
) )
def test_token_list_search(testclient, logged_admin, client): def test_token_list_search(testclient, logged_admin, client, backend):
token1 = models.Token( token1 = models.Token(
token_id=gen_salt(48), token_id=gen_salt(48),
access_token="this-token-is-ok", access_token="this-token-is-ok",
@ -88,7 +88,7 @@ def test_token_list_search(testclient, logged_admin, client):
), ),
lifetime=3600, lifetime=3600,
) )
token1.save() backend.save(token1)
token2 = models.Token( token2 = models.Token(
token_id=gen_salt(48), token_id=gen_salt(48),
access_token="this-token-is-valid", access_token="this-token-is-valid",
@ -102,7 +102,7 @@ def test_token_list_search(testclient, logged_admin, client):
), ),
lifetime=3600, lifetime=3600,
) )
token2.save() backend.save(token2)
res = testclient.get("/admin/token") res = testclient.get("/admin/token")
res.mustcontain("2 items") res.mustcontain("2 items")

View file

@ -63,11 +63,11 @@ def test_revoke_refresh_token_with_hint(testclient, user, client, token):
assert token.revokation_date assert token.revokation_date
def test_cannot_refresh_after_revocation(testclient, user, client, token): def test_cannot_refresh_after_revocation(testclient, user, client, token, backend):
token.revokation_date = datetime.datetime.now( token.revokation_date = datetime.datetime.now(
datetime.timezone.utc datetime.timezone.utc
) - datetime.timedelta(days=7) ) - datetime.timedelta(days=7)
token.save() backend.save(token)
res = testclient.post( res = testclient.post(
"/oauth/token", "/oauth/token",

View file

@ -146,9 +146,9 @@ def test_generate_user_claims(user, foo_group):
} }
def test_userinfo(testclient, token, user, foo_group): def test_userinfo(testclient, token, user, foo_group, backend):
token.scope = ["openid"] token.scope = ["openid"]
token.save() backend.save(token)
testclient.get( testclient.get(
"/oauth/userinfo", "/oauth/userinfo",
headers={"Authorization": f"Bearer {token.access_token}"}, headers={"Authorization": f"Bearer {token.access_token}"},
@ -156,7 +156,7 @@ def test_userinfo(testclient, token, user, foo_group):
) )
token.scope = ["openid", "profile"] token.scope = ["openid", "profile"]
token.save() backend.save(token)
res = testclient.get( res = testclient.get(
"/oauth/userinfo", "/oauth/userinfo",
headers={"Authorization": f"Bearer {token.access_token}"}, headers={"Authorization": f"Bearer {token.access_token}"},
@ -172,7 +172,7 @@ def test_userinfo(testclient, token, user, foo_group):
} }
token.scope = ["openid", "profile", "email"] token.scope = ["openid", "profile", "email"]
token.save() backend.save(token)
res = testclient.get( res = testclient.get(
"/oauth/userinfo", "/oauth/userinfo",
headers={"Authorization": f"Bearer {token.access_token}"}, headers={"Authorization": f"Bearer {token.access_token}"},
@ -189,7 +189,7 @@ def test_userinfo(testclient, token, user, foo_group):
} }
token.scope = ["openid", "profile", "address"] token.scope = ["openid", "profile", "address"]
token.save() backend.save(token)
res = testclient.get( res = testclient.get(
"/oauth/userinfo", "/oauth/userinfo",
headers={"Authorization": f"Bearer {token.access_token}"}, headers={"Authorization": f"Bearer {token.access_token}"},
@ -206,7 +206,7 @@ def test_userinfo(testclient, token, user, foo_group):
} }
token.scope = ["openid", "profile", "phone"] token.scope = ["openid", "profile", "phone"]
token.save() backend.save(token)
res = testclient.get( res = testclient.get(
"/oauth/userinfo", "/oauth/userinfo",
headers={"Authorization": f"Bearer {token.access_token}"}, headers={"Authorization": f"Bearer {token.access_token}"},
@ -223,7 +223,7 @@ def test_userinfo(testclient, token, user, foo_group):
} }
token.scope = ["openid", "profile", "groups"] token.scope = ["openid", "profile", "groups"]
token.save() backend.save(token)
res = testclient.get( res = testclient.get(
"/oauth/userinfo", "/oauth/userinfo",
headers={"Authorization": f"Bearer {token.access_token}"}, headers={"Authorization": f"Bearer {token.access_token}"},
@ -296,7 +296,7 @@ def test_claim_is_omitted_if_empty(testclient, backend, user):
# According to https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse # According to https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse
# it's better to not insert a null or empty string value # it's better to not insert a null or empty string value
user.emails = [] user.emails = []
user.save() backend.save(user)
default_jwt_mapping = JWTSettings().model_dump() default_jwt_mapping = JWTSettings().model_dump()
data = generate_user_claims(user, STANDARD_CLAIMS, default_jwt_mapping) data = generate_user_claims(user, STANDARD_CLAIMS, default_jwt_mapping)