Refactored the unit test backend fixtures

This commit is contained in:
Éloi Rivard 2023-05-20 17:17:46 +02:00
parent 475a6d153c
commit 6f637b8129
26 changed files with 302 additions and 330 deletions

View file

@ -45,6 +45,18 @@ def setup_config(app, config=None, validate=True):
canaille.app.configuration.validate(app.config)
def setup_backend(app, backend):
from .backends.ldap.backend import LDAPBackend
if not backend:
backend = LDAPBackend(app.config)
backend.init_app(app)
with app.app_context():
g.backend = backend
app.backend = backend
def setup_sentry(app): # pragma: no cover
if not app.config.get("SENTRY_DSN"):
return None
@ -167,18 +179,17 @@ def setup_flask(app):
return render_template("error.html", error=500), 500
def create_app(config=None, validate=True):
def create_app(config=None, validate=True, backend=None):
app = Flask(__name__)
setup_config(app, config, validate)
sentry_sdk = setup_sentry(app)
try:
from .oidc.oauth import setup_oauth
from .backends.ldap.backend import LDAPBackend
from .app.i18n import setup_i18n
setup_logging(app)
LDAPBackend(app)
setup_backend(app, backend)
setup_oauth(app)
setup_blueprints(app)
setup_jinja(app)

View file

@ -1,18 +1,25 @@
class Backend:
def __init__(self, app):
self.app = app
from contextlib import contextmanager
@self.app.before_request
class Backend:
def __init__(self, config):
pass
def init_app(self, app):
@app.before_request
def before_request():
if not app.config["TESTING"]:
return self.setup()
@self.app.after_request
@app.after_request
def after_request(response):
if not app.config["TESTING"]:
self.teardown()
return response
@contextmanager
def session(self):
yield self.setup()
self.teardown()
def setup(self):
"""
This method will be called before each http request,

View file

@ -2,18 +2,28 @@ import logging
import uuid
import ldap
from canaille.app.configuration import ConfigurationException
from canaille.backends import Backend
from flask import g
from flask import render_template
from flask import request
from flask_babel import gettext as _
class LDAPBackend(Backend):
def __init__(self, app):
setup_ldap_models(app.config)
super().__init__(app)
instance = None
def __init__(self, config):
from canaille.oidc.installation import setup_ldap_tree
LDAPBackend.instance = self
self.config = config
self.connection = None
setup_ldap_models(config)
setup_ldap_tree(config)
super().__init__(config)
@classmethod
def get(cls):
return cls.instance
def setup(self):
try: # pragma: no cover
@ -23,21 +33,19 @@ class LDAPBackend(Backend):
pass
try:
g.ldap_connection = ldap.initialize(
self.app.config["BACKENDS"]["LDAP"]["URI"]
)
g.ldap_connection.set_option(
self.connection = ldap.initialize(self.config["BACKENDS"]["LDAP"]["URI"])
self.connection.set_option(
ldap.OPT_NETWORK_TIMEOUT,
self.app.config["BACKENDS"]["LDAP"].get("TIMEOUT"),
self.config["BACKENDS"]["LDAP"].get("TIMEOUT"),
)
g.ldap_connection.simple_bind_s(
self.app.config["BACKENDS"]["LDAP"]["BIND_DN"],
self.app.config["BACKENDS"]["LDAP"]["BIND_PW"],
self.connection.simple_bind_s(
self.config["BACKENDS"]["LDAP"]["BIND_DN"],
self.config["BACKENDS"]["LDAP"]["BIND_PW"],
)
except ldap.SERVER_DOWN:
message = _("Could not connect to the LDAP server '{uri}'").format(
uri=self.app.config["BACKENDS"]["LDAP"]["URI"]
uri=self.config["BACKENDS"]["LDAP"]["URI"]
)
logging.error(message)
return (
@ -45,7 +53,7 @@ class LDAPBackend(Backend):
"error.html",
error=500,
icon="database",
debug=self.app.config.get("DEBUG", False),
debug=self.config.get("DEBUG", False),
description=message,
),
500,
@ -53,7 +61,7 @@ class LDAPBackend(Backend):
except ldap.INVALID_CREDENTIALS:
message = _("LDAP authentication failed with user '{user}'").format(
user=self.app.config["BACKENDS"]["LDAP"]["BIND_DN"]
user=self.config["BACKENDS"]["LDAP"]["BIND_DN"]
)
logging.error(message)
return (
@ -61,19 +69,20 @@ class LDAPBackend(Backend):
"error.html",
error=500,
icon="key",
debug=self.app.config.get("DEBUG", False),
debug=self.config.get("DEBUG", False),
description=message,
),
500,
)
def teardown(self):
if g.get("ldap_connection"): # pragma: no branch
g.ldap_connection.unbind_s()
g.ldap_connection = None
if self.connection: # pragma: no branch
self.connection.unbind_s()
self.connection = None
@classmethod
def validate(cls, config):
from canaille.app.configuration import ConfigurationException
from canaille.core.models import Group
from canaille.core.models import User

View file

@ -3,8 +3,8 @@ from collections.abc import Iterable
import ldap.dn
import ldap.filter
from flask import g
from .backend import LDAPBackend
from .utils import ldap_to_python
from .utils import python_to_ldap
@ -214,13 +214,9 @@ class LDAPObject(metaclass=LDAPObjectMetaclass):
def must(cls):
return cls._must
@classmethod
def ldap_connection(cls):
return g.ldap_connection
@classmethod
def install(cls, conn=None):
conn = conn or cls.ldap_connection()
conn = conn or LDAPBackend.get().connection
cls.ldap_object_classes(conn)
cls.ldap_object_attributes(conn)
@ -245,7 +241,7 @@ class LDAPObject(metaclass=LDAPObjectMetaclass):
if cls._object_class_by_name and not force:
return cls._object_class_by_name
conn = conn or cls.ldap_connection()
conn = conn or LDAPBackend.get().connection
res = conn.search_s(
"cn=subschema", ldap.SCOPE_BASE, "(objectclass=*)", ["*", "+"]
@ -267,7 +263,7 @@ class LDAPObject(metaclass=LDAPObjectMetaclass):
if cls._attribute_type_by_name and not force:
return cls._attribute_type_by_name
conn = conn or cls.ldap_connection()
conn = conn or LDAPBackend.get().connection
res = conn.search_s(
"cn=subschema", ldap.SCOPE_BASE, "(objectclass=*)", ["*", "+"]
@ -293,7 +289,7 @@ class LDAPObject(metaclass=LDAPObjectMetaclass):
@classmethod
def query(cls, id=None, filter=None, conn=None, **kwargs):
conn = conn or cls.ldap_connection()
conn = conn or LDAPBackend.get().connection
base = id or kwargs.get("id")
if base is None:
@ -375,13 +371,13 @@ class LDAPObject(metaclass=LDAPObjectMetaclass):
cls._must = list(set(cls._must))
def reload(self, conn=None):
conn = conn or self.ldap_connection()
conn = conn or LDAPBackend.get().connection
result = conn.search_s(self.id, ldap.SCOPE_SUBTREE, None, ["+", "*"])
self.changes = {}
self.state = result[0][1]
def save(self, conn=None):
conn = conn or self.ldap_connection()
conn = conn or LDAPBackend.get().connection
setattr(self, "objectClass", self.ldap_object_class)
@ -436,5 +432,5 @@ class LDAPObject(metaclass=LDAPObjectMetaclass):
self.__setattr__(k, v)
def delete(self, conn=None):
conn = conn or self.ldap_connection()
conn = conn or LDAPBackend.get().connection
conn.delete_s(self.id)

View file

@ -1,4 +1,5 @@
import ldap.filter
from canaille.backends.ldap.backend import LDAPBackend
from canaille.backends.ldap.ldapobject import LDAPObject
from flask import current_app
from flask import session
@ -124,7 +125,7 @@ class User(LDAPObject):
conn.unbind_s()
def set_password(self, password):
conn = self.ldap_connection()
conn = LDAPBackend.get().connection
conn.passwd_s(
self.id,
None,
@ -162,7 +163,7 @@ class User(LDAPObject):
self.state[group_attr] = new_groups
def load_permissions(self):
conn = self.ldap_connection()
conn = LDAPBackend.get().connection
for access_group_name, details in current_app.config["ACL"].items():
filter_ = self.acl_filter_to_ldap_filter(details.get("FILTER"))

View file

@ -14,16 +14,13 @@ def create_app():
@app.before_first_request
def populate():
from canaille.backends.ldap.backend import LDAPBackend
from canaille.core.models import Group
from canaille.core.models import User
from canaille.core.populate import fake_groups
from canaille.core.populate import fake_users
from canaille.oidc.models import Client
backend = LDAPBackend(app)
backend.setup()
with app.backend.session():
jane = User(
formatted_name="Jane Doe",
given_name="Jane",
@ -142,6 +139,4 @@ def create_app():
fake_users(50)
fake_groups(10, nb_users_max=10)
backend.teardown()
return app

View file

@ -7,9 +7,7 @@ from canaille.app.configuration import validate
from flask_webtest import TestApp
def test_smtp_connection_remote_smtp_unreachable(
testclient, slapd_connection, configuration
):
def test_smtp_connection_remote_smtp_unreachable(testclient, backend, configuration):
configuration["SMTP"]["HOST"] = "smtp://invalid-smtp.com"
with pytest.raises(
ConfigurationException,
@ -19,7 +17,7 @@ def test_smtp_connection_remote_smtp_unreachable(
def test_smtp_connection_remote_smtp_wrong_credentials(
testclient, slapd_connection, configuration
testclient, backend, configuration
):
configuration["SMTP"]["PASSWORD"] = "invalid-password"
with pytest.raises(
@ -29,15 +27,13 @@ def test_smtp_connection_remote_smtp_wrong_credentials(
validate(configuration, validate_remote=True)
def test_smtp_connection_remote_smtp_no_credentials(
testclient, slapd_connection, configuration
):
def test_smtp_connection_remote_smtp_no_credentials(testclient, backend, configuration):
del configuration["SMTP"]["LOGIN"]
del configuration["SMTP"]["PASSWORD"]
validate(configuration, validate_remote=True)
def test_smtp_bad_tls(testclient, slapd_connection, smtpd, configuration):
def test_smtp_bad_tls(testclient, backend, smtpd, configuration):
configuration["SMTP"]["TLS"] = False
with pytest.raises(
ConfigurationException,
@ -50,7 +46,7 @@ def test_smtp_bad_tls(testclient, slapd_connection, smtpd, configuration):
def themed_testclient(
app,
configuration,
slapd_connection,
backend,
):
configuration["TESTING"] = True
@ -63,7 +59,7 @@ def themed_testclient(
return TestApp(app)
def test_theme(testclient, themed_testclient, slapd_connection):
def test_theme(testclient, themed_testclient, backend):
res = testclient.get("/login")
res.mustcontain(no="TEST_THEME")
@ -71,7 +67,7 @@ def test_theme(testclient, themed_testclient, slapd_connection):
res.mustcontain("TEST_THEME")
def test_invalid_theme(configuration, slapd_connection):
def test_invalid_theme(configuration, backend):
validate(configuration, validate_remote=False)
with pytest.raises(

View file

@ -44,14 +44,14 @@ def test_no_configuration():
assert "No configuration file found." in str(exc)
def test_logging_to_file(configuration, tmp_path, smtpd, admin, slapd_server):
def test_logging_to_file(configuration, backend, tmp_path, smtpd, admin, slapd_server):
assert len(smtpd.messages) == 0
log_path = os.path.join(tmp_path, "canaille.log")
logging_configuration = {
**configuration,
"LOGGING": {"LEVEL": "DEBUG", "PATH": log_path},
}
app = create_app(logging_configuration)
app = create_app(logging_configuration, backend=backend)
testclient = TestApp(app)
with testclient.session_transaction() as sess:

View file

@ -185,8 +185,8 @@ def test_default_from_addr(testclient, user, smtpd):
assert smtpd.messages[0]["From"] == '"Canaille" <admin@localhost>'
def test_default_from_flask_server_name(configuration, user, smtpd, slapd_server):
app = create_app(configuration)
def test_default_from_flask_server_name(configuration, user, smtpd, backend):
app = create_app(configuration, backend=backend)
del app.config["SMTP"]["FROM_ADDR"]
app.config["SERVER_NAME"] = "foobar.tld"

View file

@ -1,26 +1,32 @@
def test_ldap_connection_remote_ldap_unreachable(testclient):
testclient.app.config["TESTING"] = False
from canaille import create_app
from flask_webtest import TestApp
testclient.app.config["BACKENDS"]["LDAP"]["URI"] = "ldap://invalid-ldap.com"
testclient.app.config["DEBUG"] = True
def test_ldap_connection_remote_ldap_unreachable(configuration):
app = create_app(configuration)
testclient = TestApp(app)
app.config["BACKENDS"]["LDAP"]["URI"] = "ldap://invalid-ldap.com"
app.config["DEBUG"] = True
res = testclient.get("/", status=500, expect_errors=True)
res.mustcontain("Could not connect to the LDAP server")
testclient.app.config["DEBUG"] = False
app.config["DEBUG"] = False
res = testclient.get("/", status=500, expect_errors=True)
res.mustcontain(no="Could not connect to the LDAP server")
def test_ldap_connection_remote_ldap_wrong_credentials(testclient):
testclient.app.config["TESTING"] = False
def test_ldap_connection_remote_ldap_wrong_credentials(configuration):
app = create_app(configuration)
testclient = TestApp(app)
testclient.app.config["BACKENDS"]["LDAP"]["BIND_PW"] = "invalid-password"
app.config["BACKENDS"]["LDAP"]["BIND_PW"] = "invalid-password"
testclient.app.config["DEBUG"] = True
app.config["DEBUG"] = True
res = testclient.get("/", status=500, expect_errors=True)
res.mustcontain("LDAP authentication failed with user")
testclient.app.config["DEBUG"] = False
app.config["DEBUG"] = False
res = testclient.get("/", status=500, expect_errors=True)
res.mustcontain(no="LDAP authentication failed with user")

View file

@ -15,8 +15,7 @@ from canaille.core.models import Group
from canaille.core.models import User
def test_object_creation(slapd_connection):
User.install(slapd_connection)
def test_object_creation(app, backend):
user = User(
formatted_name="Doe", # leading space
family_name="Doe",
@ -33,19 +32,12 @@ def test_object_creation(slapd_connection):
user.delete()
def test_repr(slapd_connection, foo_group, user):
def test_repr(backend, foo_group, user):
assert repr(foo_group) == "<Group display_name=foo>"
assert repr(user) == "<User formatted_name=John (johnny) Doe>"
def test_equality(slapd_connection, foo_group, bar_group):
assert foo_group != bar_group
foo_group2 = Group.get(id=foo_group.id)
assert foo_group == foo_group2
def test_dn_when_leading_space_in_id_attribute(slapd_connection):
User.install(slapd_connection)
def test_dn_when_leading_space_in_id_attribute(backend):
user = User(
formatted_name=" Doe", # leading space
family_name="Doe",
@ -61,8 +53,7 @@ def test_dn_when_leading_space_in_id_attribute(slapd_connection):
user.delete()
def test_dn_when_ldap_special_char_in_id_attribute(slapd_connection):
User.install(slapd_connection)
def test_dn_when_ldap_special_char_in_id_attribute(backend):
user = User(
formatted_name="#Doe", # special char
family_name="Doe",
@ -78,7 +69,7 @@ def test_dn_when_ldap_special_char_in_id_attribute(slapd_connection):
user.delete()
def test_filter(slapd_connection, foo_group, bar_group):
def test_filter(backend, foo_group, bar_group):
assert Group.query(display_name="foo") == [foo_group]
assert Group.query(display_name="bar") == [bar_group]
@ -90,7 +81,7 @@ def test_filter(slapd_connection, foo_group, bar_group):
assert set(Group.query(display_name=["foo", "bar"])) == {foo_group, bar_group}
def test_fuzzy(slapd_connection, user, moderator, admin):
def test_fuzzy(backend, user, moderator, admin):
assert set(User.query()) == {user, moderator, admin}
assert set(User.fuzzy("Jack")) == {moderator}
assert set(User.fuzzy("Jack", ["formatted_name"])) == {moderator}
@ -170,9 +161,9 @@ def test_python_to_ldap():
assert ldap_to_python(b"foobar", Syntax.JPEG) == b"foobar"
def test_operational_attribute_conversion(slapd_connection):
assert "oauthClientName" in LDAPObject.ldap_object_attributes(slapd_connection)
assert "invalidAttribute" not in LDAPObject.ldap_object_attributes(slapd_connection)
def test_operational_attribute_conversion(backend):
assert "oauthClientName" in LDAPObject.ldap_object_attributes(backend)
assert "invalidAttribute" not in LDAPObject.ldap_object_attributes(backend)
assert python_attrs_to_ldap(
{
@ -185,7 +176,7 @@ def test_operational_attribute_conversion(slapd_connection):
}
def test_guess_object_from_dn(slapd_connection, testclient, foo_group):
def test_guess_object_from_dn(backend, testclient, foo_group):
foo_group.members = [foo_group]
foo_group.save()
g = LDAPObject.get(id=foo_group.dn)
@ -197,7 +188,7 @@ def test_guess_object_from_dn(slapd_connection, testclient, foo_group):
assert isinstance(ou, LDAPObject)
def test_object_class_update(slapd_connection, testclient):
def test_object_class_update(backend, testclient):
testclient.app.config["BACKENDS"]["LDAP"]["USER_CLASS"] = ["inetOrgPerson"]
setup_ldap_models(testclient.app.config)
@ -234,7 +225,7 @@ def test_ldap_connection_no_remote(testclient, configuration):
validate(configuration)
def test_ldap_connection_remote(testclient, configuration, slapd_connection):
def test_ldap_connection_remote(testclient, configuration, backend):
validate(configuration, validate_remote=True)
@ -256,7 +247,7 @@ def test_ldap_connection_remote_ldap_wrong_credentials(testclient, configuration
validate(configuration, validate_remote=True)
def test_ldap_cannot_create_users(testclient, configuration, slapd_connection):
def test_ldap_cannot_create_users(testclient, configuration, backend):
from canaille.core.models import User
def fake_init(*args, **kwarg):
@ -270,7 +261,7 @@ def test_ldap_cannot_create_users(testclient, configuration, slapd_connection):
validate(configuration, validate_remote=True)
def test_ldap_cannot_create_groups(testclient, configuration, slapd_connection):
def test_ldap_cannot_create_groups(testclient, configuration, backend):
from canaille.core.models import Group
def fake_init(*args, **kwarg):

View file

@ -6,7 +6,7 @@ def test_required_methods(testclient):
with pytest.raises(NotImplementedError):
Backend.validate({})
backend = Backend(testclient.app)
backend = Backend(testclient.app.config)
with pytest.raises(NotImplementedError):
backend.setup()

View file

@ -2,7 +2,7 @@ from canaille.core.models import Group
from canaille.core.models import User
def test_model_comparison(testclient, slapd_connection):
def test_model_comparison(testclient, backend):
foo1 = User(
user_name="foo",
family_name="foo",
@ -24,7 +24,7 @@ def test_model_comparison(testclient, slapd_connection):
bar.delete()
def test_model_lifecycle(testclient, slapd_connection, slapd_server):
def test_model_lifecycle(testclient, backend, slapd_server):
user = User(
user_name="user_name",
family_name="family_name",
@ -57,7 +57,7 @@ def test_model_lifecycle(testclient, slapd_connection, slapd_server):
assert not User.get(id=user.id)
def test_model_attribute_edition(testclient, slapd_connection):
def test_model_attribute_edition(testclient, backend):
user = User(
user_name="user_name",
family_name="family_name",
@ -95,7 +95,7 @@ def test_model_attribute_edition(testclient, slapd_connection):
user.delete()
def test_model_indexation(testclient, slapd_connection):
def test_model_indexation(testclient, backend):
user = User(
user_name="user_name",
family_name="family_name",
@ -136,7 +136,7 @@ def test_model_indexation(testclient, slapd_connection):
assert not User.get(email="email3@user.com")
def test_fuzzy(user, moderator, admin, slapd_connection):
def test_fuzzy(user, moderator, admin, backend):
assert set(User.query()) == {user, moderator, admin}
assert set(User.fuzzy("Jack")) == {moderator}
assert set(User.fuzzy("Jack", ["formatted_name"])) == {moderator}
@ -149,9 +149,7 @@ def test_fuzzy(user, moderator, admin, slapd_connection):
# def test_model_references(user, admin, foo_group, bar_group):
def test_model_references(
testclient, user, foo_group, admin, bar_group, slapd_connection
):
def test_model_references(testclient, user, foo_group, admin, bar_group, backend):
assert foo_group.members == [user]
assert user.groups == [foo_group]
assert foo_group in Group.query(members=user)
@ -175,7 +173,7 @@ def test_model_references(
def test_model_references_set_unsaved_object(
testclient, logged_moderator, user, slapd_connection
testclient, logged_moderator, user, backend
):
group = Group(members=[user], display_name="foo")
group.save()

View file

@ -1,12 +1,9 @@
import ldap.ldapobject
import pytest
import slapd
from canaille import create_app
from canaille.backends.ldap.backend import setup_ldap_models
from canaille.backends.ldap.backend import LDAPBackend
from canaille.core.models import Group
from canaille.core.models import User
from canaille.oidc.installation import setup_ldap_tree
from flask import g
from flask_webtest import TestApp
from werkzeug.security import gen_salt
@ -63,15 +60,6 @@ def slapd_server():
slapd.stop()
@pytest.fixture
def slapd_connection(slapd_server, testclient):
g.ldap_connection = ldap.ldapobject.SimpleLDAPObject(slapd_server.ldap_uri)
g.ldap_connection.protocol_version = 3
g.ldap_connection.simple_bind_s(slapd_server.root_dn, slapd_server.root_pw)
yield g.ldap_connection
g.ldap_connection.unbind_s()
@pytest.fixture
def configuration(slapd_server, smtpd):
smtpd.config.use_starttls = True
@ -148,11 +136,15 @@ def configuration(slapd_server, smtpd):
@pytest.fixture
def app(configuration):
setup_ldap_models(configuration)
setup_ldap_tree(configuration)
app = create_app(configuration)
return app
def backend(slapd_server, configuration):
backend = LDAPBackend(configuration)
with backend.session():
yield backend
@pytest.fixture
def app(configuration, backend):
return create_app(configuration, backend=backend)
@pytest.fixture
@ -163,7 +155,7 @@ def testclient(app):
@pytest.fixture
def user(app, slapd_connection):
def user(app, backend):
u = User(
formatted_name="John (johnny) Doe",
given_name="John",
@ -183,7 +175,7 @@ def user(app, slapd_connection):
@pytest.fixture
def admin(app, slapd_connection):
def admin(app, backend):
u = User(
formatted_name="Jane Doe",
family_name="Doe",
@ -197,7 +189,7 @@ def admin(app, slapd_connection):
@pytest.fixture
def moderator(app, slapd_connection):
def moderator(app, backend):
u = User(
formatted_name="Jack Doe",
family_name="Doe",
@ -232,7 +224,7 @@ def logged_moderator(moderator, testclient):
@pytest.fixture
def foo_group(app, user, slapd_connection):
def foo_group(app, user, backend):
group = Group(
members=[user],
display_name="foo",
@ -244,7 +236,7 @@ def foo_group(app, user, slapd_connection):
@pytest.fixture
def bar_group(app, admin, slapd_connection):
def bar_group(app, admin, backend):
group = Group(
members=[admin],
display_name="bar",

View file

@ -4,7 +4,7 @@ from canaille.core.models import User
from canaille.core.populate import fake_users
def test_populate_users(testclient, slapd_connection):
def test_populate_users(testclient, backend):
runner = testclient.app.test_cli_runner()
assert len(User.query()) == 0
@ -15,7 +15,7 @@ def test_populate_users(testclient, slapd_connection):
user.delete()
def test_populate_groups(testclient, slapd_connection):
def test_populate_groups(testclient, backend):
fake_users(10)
runner = testclient.app.test_cli_runner()

View file

@ -110,7 +110,7 @@ def test_password_page_without_signin_in_redirects_to_login_page(testclient, use
assert res.location == "/login"
def test_user_without_password_first_login(testclient, slapd_connection, smtpd):
def test_user_without_password_first_login(testclient, backend, smtpd):
assert len(smtpd.messages) == 0
u = User(
formatted_name="Temp User",
@ -141,7 +141,7 @@ def test_user_without_password_first_login(testclient, slapd_connection, smtpd):
@mock.patch("smtplib.SMTP")
def test_first_login_account_initialization_mail_sending_failed(
SMTP, testclient, slapd_connection, smtpd
SMTP, testclient, backend, smtpd
):
SMTP.side_effect = mock.Mock(side_effect=OSError("unit test mail error"))
assert len(smtpd.messages) == 0
@ -166,7 +166,7 @@ def test_first_login_account_initialization_mail_sending_failed(
u.delete()
def test_first_login_form_error(testclient, slapd_connection, smtpd):
def test_first_login_form_error(testclient, backend, smtpd):
assert len(smtpd.messages) == 0
u = User(
formatted_name="Temp User",
@ -184,12 +184,12 @@ def test_first_login_form_error(testclient, slapd_connection, smtpd):
def test_first_login_page_unavailable_for_users_with_password(
testclient, slapd_connection, user
testclient, backend, user
):
testclient.get("/firstlogin/user", status=404)
def test_user_password_deleted_during_login(testclient, slapd_connection):
def test_user_password_deleted_during_login(testclient, backend):
u = User(
formatted_name="Temp User",
family_name="Temp",
@ -213,7 +213,7 @@ def test_user_password_deleted_during_login(testclient, slapd_connection):
u.delete()
def test_user_deleted_in_session(testclient, slapd_connection):
def test_user_deleted_in_session(testclient, backend):
u = User(
formatted_name="Jake Doe",
family_name="Jake",
@ -277,7 +277,7 @@ def test_wrong_login(testclient, user):
res.mustcontain("The login &#39;invalid&#39; does not exist")
def test_admin_self_deletion(testclient, slapd_connection):
def test_admin_self_deletion(testclient, backend):
admin = User(
formatted_name="Temp admin",
family_name="admin",
@ -302,7 +302,7 @@ def test_admin_self_deletion(testclient, slapd_connection):
assert not sess.get("user_id")
def test_user_self_deletion(testclient, slapd_connection):
def test_user_self_deletion(testclient, backend):
user = User(
formatted_name="Temp user",
family_name="user",

View file

@ -4,7 +4,7 @@ from canaille.core.populate import fake_groups
from canaille.core.populate import fake_users
def test_no_group(app, slapd_connection):
def test_no_group(app, backend):
assert Group.query() == []
@ -48,7 +48,7 @@ def test_group_list_bad_pages(testclient, logged_admin):
)
def test_group_deletion(testclient, slapd_server, slapd_connection):
def test_group_deletion(testclient, slapd_server, backend):
user = User(
formatted_name="foobar",
family_name="foobar",

View file

@ -1,12 +1,12 @@
from canaille.core.models import User
def test_user_get_from_login(testclient, user, slapd_connection):
def test_user_get_from_login(testclient, user, backend):
assert User.get_from_login(login="invalid") is None
assert User.get_from_login(login="user") == user
def test_user_has_password(testclient, slapd_connection):
def test_user_has_password(testclient, backend):
u = User(
formatted_name="Temp User",
family_name="Temp",
@ -25,7 +25,7 @@ def test_user_has_password(testclient, slapd_connection):
u.delete()
def test_user_set_and_check_password(testclient, user, slapd_connection):
def test_user_set_and_check_password(testclient, user, backend):
assert not user.check_password("something else")
assert user.check_password("correct horse battery staple")

View file

@ -127,9 +127,7 @@ def test_password_change_fail(testclient, logged_user):
assert logged_user.check_password("correct horse battery staple")
def test_password_initialization_mail(
smtpd, testclient, slapd_connection, logged_admin
):
def test_password_initialization_mail(smtpd, testclient, backend, logged_admin):
u = User(
formatted_name="Temp User",
family_name="Temp",
@ -163,7 +161,7 @@ def test_password_initialization_mail(
@mock.patch("smtplib.SMTP")
def test_password_initialization_mail_send_fail(
SMTP, smtpd, testclient, slapd_connection, logged_admin
SMTP, smtpd, testclient, backend, logged_admin
):
SMTP.side_effect = mock.Mock(side_effect=OSError("unit test mail error"))
u = User(
@ -192,9 +190,7 @@ def test_password_initialization_mail_send_fail(
u.delete()
def test_password_initialization_invalid_user(
smtpd, testclient, slapd_connection, logged_admin
):
def test_password_initialization_invalid_user(smtpd, testclient, backend, logged_admin):
assert len(smtpd.messages) == 0
res = testclient.get("/profile/admin/settings")
testclient.post(
@ -208,7 +204,7 @@ def test_password_initialization_invalid_user(
assert len(smtpd.messages) == 0
def test_password_reset_invalid_user(smtpd, testclient, slapd_connection, logged_admin):
def test_password_reset_invalid_user(smtpd, testclient, backend, logged_admin):
assert len(smtpd.messages) == 0
res = testclient.get("/profile/admin/settings")
testclient.post(
@ -219,7 +215,7 @@ def test_password_reset_invalid_user(smtpd, testclient, slapd_connection, logged
assert len(smtpd.messages) == 0
def test_delete_invalid_user(testclient, slapd_connection, logged_admin):
def test_delete_invalid_user(testclient, backend, logged_admin):
res = testclient.get("/profile/admin/settings")
testclient.post(
"/profile/invalid/settings",
@ -228,7 +224,7 @@ def test_delete_invalid_user(testclient, slapd_connection, logged_admin):
)
def test_impersonate_invalid_user(testclient, slapd_connection, logged_admin):
def test_impersonate_invalid_user(testclient, backend, logged_admin):
testclient.get("/impersonate/invalid", status=404)
@ -237,7 +233,7 @@ def test_invalid_form_request(testclient, logged_admin):
res = res.form.submit(name="action", value="invalid-action", status=400)
def test_password_reset_email(smtpd, testclient, slapd_connection, logged_admin):
def test_password_reset_email(smtpd, testclient, backend, logged_admin):
u = User(
formatted_name="Temp User",
family_name="Temp",
@ -264,9 +260,7 @@ def test_password_reset_email(smtpd, testclient, slapd_connection, logged_admin)
@mock.patch("smtplib.SMTP")
def test_password_reset_email_failed(
SMTP, smtpd, testclient, slapd_connection, logged_admin
):
def test_password_reset_email_failed(SMTP, smtpd, testclient, backend, logged_admin):
SMTP.side_effect = mock.Mock(side_effect=OSError("unit test mail error"))
u = User(
formatted_name="Temp User",

View file

@ -6,7 +6,7 @@ from canaille.oidc.models import Token
from werkzeug.security import gen_salt
def test_clean_command(testclient, slapd_connection, client, user):
def test_clean_command(testclient, backend, client, user):
valid_code = AuthorizationCode(
authorization_code_id=gen_salt(48),
code="my-valid-code",

View file

@ -70,7 +70,7 @@ def configuration(configuration, keypair_path):
@pytest.fixture
def client(testclient, other_client, slapd_connection):
def client(testclient, other_client, backend):
c = Client(
client_id=gen_salt(24),
client_name="Some client",
@ -106,7 +106,7 @@ def client(testclient, other_client, slapd_connection):
@pytest.fixture
def other_client(testclient, slapd_connection):
def other_client(testclient, backend):
c = Client(
client_id=gen_salt(24),
client_name="Some other client",
@ -142,7 +142,7 @@ def other_client(testclient, slapd_connection):
@pytest.fixture
def authorization(testclient, user, client, slapd_connection):
def authorization(testclient, user, client, backend):
a = AuthorizationCode(
authorization_code_id=gen_salt(48),
code="my-code",
@ -164,7 +164,7 @@ def authorization(testclient, user, client, slapd_connection):
@pytest.fixture
def token(testclient, client, user, slapd_connection):
def token(testclient, client, user, backend):
t = Token(
token_id=gen_salt(48),
access_token=gen_salt(48),
@ -183,7 +183,7 @@ def token(testclient, client, user, slapd_connection):
@pytest.fixture
def id_token(testclient, client, user, slapd_connection):
def id_token(testclient, client, user, backend):
return generate_id_token(
{},
generate_user_info(user, client.scope),
@ -193,7 +193,7 @@ def id_token(testclient, client, user, slapd_connection):
@pytest.fixture
def consent(testclient, client, user, slapd_connection):
def consent(testclient, client, user, backend):
t = Consent(
consent_id=str(uuid.uuid4()),
client=client,

View file

@ -5,7 +5,7 @@ from canaille.oidc.models import Client
def test_client_registration_with_authentication_static_token(
testclient, slapd_connection, client, user
testclient, backend, client, user
):
assert not testclient.app.config.get("OIDC", {}).get(
"DYNAMIC_CLIENT_REGISTRATION_OPEN"
@ -60,7 +60,7 @@ def test_client_registration_with_authentication_static_token(
def test_client_registration_with_authentication_no_token(
testclient, slapd_connection, client, user
testclient, backend, client, user
):
assert not testclient.app.config.get("OIDC", {}).get(
"DYNAMIC_CLIENT_REGISTRATION_OPEN"
@ -94,7 +94,7 @@ def test_client_registration_with_authentication_no_token(
def test_client_registration_with_authentication_invalid_token(
testclient, slapd_connection, client, user
testclient, backend, client, user
):
assert not testclient.app.config.get("OIDC", {}).get(
"DYNAMIC_CLIENT_REGISTRATION_OPEN"
@ -120,9 +120,7 @@ def test_client_registration_with_authentication_invalid_token(
}
def test_client_registration_with_software_statement(
testclient, slapd_connection, keypair_path
):
def test_client_registration_with_software_statement(testclient, backend, keypair_path):
private_key_path, _ = keypair_path
testclient.app.config["OIDC"]["DYNAMIC_CLIENT_REGISTRATION_OPEN"] = True
@ -176,7 +174,7 @@ def test_client_registration_with_software_statement(
client.delete()
def test_client_registration_without_authentication_ok(testclient, slapd_connection):
def test_client_registration_without_authentication_ok(testclient, backend):
testclient.app.config["OIDC"]["DYNAMIC_CLIENT_REGISTRATION_OPEN"] = True
payload = {

View file

@ -4,7 +4,7 @@ from datetime import datetime
from canaille.oidc.models import Client
def test_get(testclient, slapd_connection, client, user):
def test_get(testclient, backend, client, user):
assert not testclient.app.config.get("OIDC", {}).get(
"DYNAMIC_CLIENT_REGISTRATION_OPEN"
)
@ -50,7 +50,7 @@ def test_get(testclient, slapd_connection, client, user):
}
def test_update(testclient, slapd_connection, client, user):
def test_update(testclient, backend, client, user):
assert not testclient.app.config.get("OIDC", {}).get(
"DYNAMIC_CLIENT_REGISTRATION_OPEN"
)
@ -137,7 +137,7 @@ def test_update(testclient, slapd_connection, client, user):
assert client.software_version == "3.14"
def test_delete(testclient, slapd_connection, user):
def test_delete(testclient, backend, user):
assert not testclient.app.config.get("OIDC", {}).get(
"DYNAMIC_CLIENT_REGISTRATION_OPEN"
)
@ -156,7 +156,7 @@ def test_delete(testclient, slapd_connection, user):
assert not Client.get(client_id=client.client_id)
def test_invalid_client(testclient, slapd_connection, user):
def test_invalid_client(testclient, backend, user):
assert not testclient.app.config.get("OIDC", {}).get(
"DYNAMIC_CLIENT_REGISTRATION_OPEN"
)

View file

@ -3,7 +3,7 @@ from canaille.oidc.oauth import generate_user_info
from canaille.oidc.oauth import get_jwt_config
def test_end_session(testclient, slapd_connection, logged_user, client, id_token):
def test_end_session(testclient, backend, logged_user, client, id_token):
testclient.get(f"/profile/{logged_user.user_name[0]}", status=200)
post_logout_redirect_url = "https://mydomain.tld/disconnected"
@ -27,9 +27,7 @@ def test_end_session(testclient, slapd_connection, logged_user, client, id_token
testclient.get(f"/profile/{logged_user.user_name[0]}", status=403)
def test_end_session_no_client_id(
testclient, slapd_connection, logged_user, client, id_token
):
def test_end_session_no_client_id(testclient, backend, logged_user, client, id_token):
testclient.get(f"/profile/{logged_user.user_name[0]}", status=200)
post_logout_redirect_url = "https://mydomain.tld/disconnected"
@ -53,7 +51,7 @@ def test_end_session_no_client_id(
def test_no_redirect_uri_no_redirect(
testclient, slapd_connection, logged_user, client, id_token
testclient, backend, logged_user, client, id_token
):
testclient.get(f"/profile/{logged_user.user_name[0]}", status=200)
@ -77,7 +75,7 @@ def test_no_redirect_uri_no_redirect(
def test_bad_redirect_uri_no_redirect(
testclient, slapd_connection, logged_user, client, id_token
testclient, backend, logged_user, client, id_token
):
testclient.get(f"/profile/{logged_user.user_name[0]}", status=200)
@ -102,9 +100,7 @@ def test_bad_redirect_uri_no_redirect(
testclient.get(f"/profile/{logged_user.user_name[0]}", status=403)
def test_no_client_hint_no_redirect(
testclient, slapd_connection, logged_user, client, id_token
):
def test_no_client_hint_no_redirect(testclient, backend, logged_user, client, id_token):
testclient.get(f"/profile/{logged_user.user_name[0]}", status=200)
post_logout_redirect_url = "https://mydomain.tld/disconnected"
@ -128,9 +124,7 @@ def test_no_client_hint_no_redirect(
testclient.get(f"/profile/{logged_user.user_name[0]}", status=403)
def test_end_session_invalid_client_id(
testclient, slapd_connection, logged_user, client
):
def test_end_session_invalid_client_id(testclient, backend, logged_user, client):
testclient.get(f"/profile/{logged_user.user_name[0]}", status=200)
post_logout_redirect_url = "https://mydomain.tld/disconnected"
@ -156,7 +150,7 @@ def test_end_session_invalid_client_id(
testclient.get(f"/profile/{logged_user.user_name[0]}", status=403)
def test_client_hint_invalid(testclient, slapd_connection, logged_user, client):
def test_client_hint_invalid(testclient, backend, logged_user, client):
id_token = generate_id_token(
{},
generate_user_info(logged_user, client.scope),
@ -186,7 +180,7 @@ def test_client_hint_invalid(testclient, slapd_connection, logged_user, client):
testclient.get(f"/profile/{logged_user.user_name[0]}", status=403)
def test_no_jwt_logout(testclient, slapd_connection, logged_user, client):
def test_no_jwt_logout(testclient, backend, logged_user, client):
testclient.get(f"/profile/{logged_user.user_name[0]}", status=200)
post_logout_redirect_url = "https://mydomain.tld/disconnected"
@ -212,7 +206,7 @@ def test_no_jwt_logout(testclient, slapd_connection, logged_user, client):
testclient.get(f"/profile/{logged_user.user_name[0]}", status=403)
def test_no_jwt_no_logout(testclient, slapd_connection, logged_user, client):
def test_no_jwt_no_logout(testclient, backend, logged_user, client):
testclient.get(f"/profile/{logged_user.user_name[0]}", status=200)
post_logout_redirect_url = "https://mydomain.tld/disconnected"
@ -237,9 +231,7 @@ def test_no_jwt_no_logout(testclient, slapd_connection, logged_user, client):
testclient.get(f"/profile/{logged_user.user_name[0]}", status=200)
def test_jwt_not_issued_here(
testclient, slapd_connection, logged_user, client, id_token
):
def test_jwt_not_issued_here(testclient, backend, logged_user, client, id_token):
testclient.app.config["OIDC"]["JWT"]["ISS"] = "https://foo.bar"
testclient.get(f"/profile/{logged_user.user_name[0]}", status=200)
@ -263,7 +255,7 @@ def test_jwt_not_issued_here(
}
def test_client_hint_mismatch(testclient, slapd_connection, logged_user, client):
def test_client_hint_mismatch(testclient, backend, logged_user, client):
id_token = generate_id_token(
{},
generate_user_info(logged_user, client.scope),
@ -292,9 +284,7 @@ def test_client_hint_mismatch(testclient, slapd_connection, logged_user, client)
}
def test_bad_user_id_token_mismatch(
testclient, slapd_connection, logged_user, client, admin
):
def test_bad_user_id_token_mismatch(testclient, backend, logged_user, client, admin):
testclient.get(f"/profile/{logged_user.user_name[0]}", status=200)
id_token = generate_id_token(
@ -328,9 +318,7 @@ def test_bad_user_id_token_mismatch(
testclient.get(f"/profile/{logged_user.user_name[0]}", status=403)
def test_bad_user_hint(
testclient, slapd_connection, logged_user, client, id_token, admin
):
def test_bad_user_hint(testclient, backend, logged_user, client, id_token, admin):
testclient.get(f"/profile/{logged_user.user_name[0]}", status=200)
post_logout_redirect_url = "https://mydomain.tld/disconnected"
@ -357,7 +345,7 @@ def test_bad_user_hint(
testclient.get(f"/profile/{logged_user.user_name[0]}", status=403)
def test_no_jwt_bad_csrf(testclient, slapd_connection, logged_user, client):
def test_no_jwt_bad_csrf(testclient, backend, logged_user, client):
testclient.get(f"/profile/{logged_user.user_name[0]}", status=200)
post_logout_redirect_url = "https://mydomain.tld/disconnected"
@ -377,9 +365,7 @@ def test_no_jwt_bad_csrf(testclient, slapd_connection, logged_user, client):
res = form.submit(name="answer", value="logout", status=400)
def test_end_session_already_disconnected(
testclient, slapd_connection, user, client, id_token
):
def test_end_session_already_disconnected(testclient, backend, user, client, id_token):
post_logout_redirect_url = "https://mydomain.tld/disconnected"
res = testclient.get(
"/oauth/end_session",
@ -396,9 +382,7 @@ def test_end_session_already_disconnected(
assert res.location == "/"
def test_end_session_no_state(
testclient, slapd_connection, logged_user, client, id_token
):
def test_end_session_no_state(testclient, backend, logged_user, client, id_token):
testclient.get(f"/profile/{logged_user.user_name[0]}", status=200)
post_logout_redirect_url = "https://mydomain.tld/disconnected"

View file

@ -6,7 +6,7 @@ from canaille.oidc.models import AuthorizationCode
from canaille.oidc.models import Token
def test_oauth_hybrid(testclient, slapd_connection, user, client):
def test_oauth_hybrid(testclient, backend, user, client):
res = testclient.get(
"/oauth/authorize",
params=dict(
@ -47,9 +47,7 @@ def test_oauth_hybrid(testclient, slapd_connection, user, client):
assert res.json["name"] == "John (johnny) Doe"
def test_oidc_hybrid(
testclient, slapd_connection, logged_user, client, keypair, other_client
):
def test_oidc_hybrid(testclient, backend, logged_user, client, keypair, other_client):
res = testclient.get(
"/oauth/authorize",
params=dict(

View file

@ -264,9 +264,7 @@ STANDARD_CLAIMS = [
]
def test_generate_user_standard_claims_with_default_config(
testclient, slapd_connection, user
):
def test_generate_user_standard_claims_with_default_config(testclient, backend, user):
user.preferred_language = ["fr"]
data = generate_user_claims(user, STANDARD_CLAIMS, DEFAULT_JWT_MAPPING)
@ -284,9 +282,7 @@ def test_generate_user_standard_claims_with_default_config(
}
def test_custom_config_format_claim_is_well_formated(
testclient, slapd_connection, user
):
def test_custom_config_format_claim_is_well_formated(testclient, backend, user):
jwt_mapping_config = DEFAULT_JWT_MAPPING.copy()
jwt_mapping_config["EMAIL"] = "{{ user.user_name[0] }}@mydomain.tld"
@ -295,7 +291,7 @@ def test_custom_config_format_claim_is_well_formated(
assert data["email"] == "user@mydomain.tld"
def test_claim_is_omitted_if_empty(testclient, slapd_connection, user):
def test_claim_is_omitted_if_empty(testclient, backend, user):
# According to https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse
# it's better to not insert a null or empty string value
user.email = ""