forked from Github-Mirrors/canaille
Dynamic model registration
This commit is contained in:
parent
cfe5d0c7b8
commit
54abdaea3b
4 changed files with 97 additions and 36 deletions
|
@ -3,7 +3,6 @@ import os
|
||||||
from logging.config import dictConfig
|
from logging.config import dictConfig
|
||||||
|
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
from flask import g
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask import session
|
from flask import session
|
||||||
from flask_themer import FileSystemThemeLoader
|
from flask_themer import FileSystemThemeLoader
|
||||||
|
@ -15,21 +14,6 @@ from flask_wtf.csrf import CSRFProtect
|
||||||
csrf = CSRFProtect()
|
csrf = CSRFProtect()
|
||||||
|
|
||||||
|
|
||||||
def setup_backend(app, backend):
|
|
||||||
from .backends.ldap.backend import Backend
|
|
||||||
|
|
||||||
if not backend:
|
|
||||||
backend = Backend(app.config)
|
|
||||||
backend.init_app(app)
|
|
||||||
|
|
||||||
with app.app_context():
|
|
||||||
g.backend = backend
|
|
||||||
app.backend = backend
|
|
||||||
|
|
||||||
if app.debug: # pragma: no cover
|
|
||||||
backend.install(app.config)
|
|
||||||
|
|
||||||
|
|
||||||
def setup_sentry(app): # pragma: no cover
|
def setup_sentry(app): # pragma: no cover
|
||||||
if not app.config.get("SENTRY_DSN"):
|
if not app.config.get("SENTRY_DSN"):
|
||||||
return None
|
return None
|
||||||
|
@ -159,20 +143,15 @@ def setup_flask_converters(app):
|
||||||
from canaille.app.flask import model_converter
|
from canaille.app.flask import model_converter
|
||||||
from canaille.app import models
|
from canaille.app import models
|
||||||
|
|
||||||
app.url_map.converters["user"] = model_converter(models.User)
|
for model_name, model_class in models.MODELS.items():
|
||||||
app.url_map.converters["group"] = model_converter(models.Group)
|
app.url_map.converters[model_name.lower()] = model_converter(model_class)
|
||||||
app.url_map.converters["client"] = model_converter(models.Client)
|
|
||||||
app.url_map.converters["token"] = model_converter(models.Token)
|
|
||||||
app.url_map.converters["authorizationcode"] = model_converter(
|
|
||||||
models.AuthorizationCode
|
|
||||||
)
|
|
||||||
app.url_map.converters["consent"] = model_converter(models.Consent)
|
|
||||||
|
|
||||||
|
|
||||||
def create_app(config=None, validate=True, backend=None):
|
def create_app(config=None, validate=True, backend=None):
|
||||||
from .oidc.oauth import setup_oauth
|
from .oidc.oauth import setup_oauth
|
||||||
from .app.i18n import setup_i18n
|
from .app.i18n import setup_i18n
|
||||||
from .app.configuration import setup_config
|
from .app.configuration import setup_config
|
||||||
|
from .backends import setup_backend
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
from canaille.backends.ldap.models import AuthorizationCode # noqa: F401
|
MODELS = {}
|
||||||
from canaille.backends.ldap.models import Client # noqa: F401
|
|
||||||
from canaille.backends.ldap.models import Consent # noqa: F401
|
|
||||||
from canaille.backends.ldap.models import Group # noqa: F401
|
def __getattr__(name):
|
||||||
from canaille.backends.ldap.models import Token # noqa: F401
|
if name in MODELS:
|
||||||
from canaille.backends.ldap.models import User # noqa: F401
|
return MODELS[name]
|
||||||
|
|
||||||
|
|
||||||
|
def register(model):
|
||||||
|
MODELS[model.__name__] = model
|
||||||
|
|
|
@ -1,11 +1,16 @@
|
||||||
|
import importlib
|
||||||
|
import os
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from flask import g
|
||||||
|
|
||||||
|
|
||||||
class BaseBackend:
|
class BaseBackend:
|
||||||
instance = None
|
instance = None
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
BaseBackend.instance = self
|
BaseBackend.instance = self
|
||||||
|
self.register_models()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls):
|
def get(cls):
|
||||||
|
@ -61,3 +66,47 @@ class BaseBackend:
|
||||||
Indicates wether the backend supports locking user accounts.
|
Indicates wether the backend supports locking user accounts.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def register_models(self):
|
||||||
|
from canaille.app import models
|
||||||
|
|
||||||
|
module = ".".join(self.__class__.__module__.split(".")[:-1] + ["models"])
|
||||||
|
try:
|
||||||
|
backend_models = importlib.import_module(module)
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
return
|
||||||
|
|
||||||
|
model_names = [
|
||||||
|
"AuthorizationCode",
|
||||||
|
"Client",
|
||||||
|
"Consent",
|
||||||
|
"Group",
|
||||||
|
"Token",
|
||||||
|
"User",
|
||||||
|
]
|
||||||
|
for model_name in model_names:
|
||||||
|
models.register(getattr(backend_models, model_name))
|
||||||
|
|
||||||
|
|
||||||
|
def setup_backend(app, backend):
|
||||||
|
if not backend:
|
||||||
|
backend_name = list(app.config.get("BACKENDS").keys())[0].lower()
|
||||||
|
module = importlib.import_module(f"canaille.backends.{backend_name}.backend")
|
||||||
|
backend_class = getattr(module, "Backend")
|
||||||
|
backend = backend_class(app.config)
|
||||||
|
backend.init_app(app)
|
||||||
|
|
||||||
|
with app.app_context():
|
||||||
|
g.backend = backend
|
||||||
|
app.backend = backend
|
||||||
|
|
||||||
|
if app.debug: # pragma: no cover
|
||||||
|
backend.install(app.config, True)
|
||||||
|
|
||||||
|
|
||||||
|
def available_backends():
|
||||||
|
return {
|
||||||
|
elt.name
|
||||||
|
for elt in os.scandir(os.path.dirname(__file__))
|
||||||
|
if elt.is_dir() and os.path.exists(os.path.join(elt, "backend.py"))
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from canaille import create_app
|
from canaille import create_app
|
||||||
from canaille.app import models
|
from canaille.app import models
|
||||||
|
from canaille.backends import available_backends
|
||||||
from flask_webtest import TestApp
|
from flask_webtest import TestApp
|
||||||
from jinja2 import StrictUndefined
|
from jinja2 import StrictUndefined
|
||||||
from pytest_lazyfixture import lazy_fixture
|
from pytest_lazyfixture import lazy_fixture
|
||||||
|
@ -8,10 +11,41 @@ from werkzeug.security import gen_salt
|
||||||
|
|
||||||
|
|
||||||
pytest_plugins = [
|
pytest_plugins = [
|
||||||
"tests.backends.ldap.fixtures",
|
f"tests.backends.{backend}.fixtures"
|
||||||
|
for backend in available_backends()
|
||||||
|
if os.path.exists(os.path.join("tests", "backends", backend, "fixtures.py"))
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_addoption(parser):
|
||||||
|
parser.addoption(
|
||||||
|
"--backend", action="append", default=[], help="the backends to test"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_generate_tests(metafunc):
|
||||||
|
backends = available_backends()
|
||||||
|
if metafunc.config.getoption("backend"): # pragma: no cover
|
||||||
|
backends &= set(metafunc.config.getoption("backend"))
|
||||||
|
|
||||||
|
# tests in tests.backends.BACKENDNAME should only run one backend
|
||||||
|
if metafunc.module.__name__.startswith("tests.backends"):
|
||||||
|
module_name_parts = metafunc.module.__name__.split(".")
|
||||||
|
in_backend_module = len(module_name_parts) > 3
|
||||||
|
if in_backend_module:
|
||||||
|
backend = module_name_parts[2]
|
||||||
|
if backend not in backends: # pragma: no cover
|
||||||
|
pytest.skip()
|
||||||
|
elif "backend" in metafunc.fixturenames:
|
||||||
|
metafunc.parametrize("backend", [lazy_fixture(f"{backend}_backend")])
|
||||||
|
return
|
||||||
|
|
||||||
|
if "backend" in metafunc.fixturenames:
|
||||||
|
metafunc.parametrize(
|
||||||
|
"backend", [lazy_fixture(f"{backend}_backend") for backend in backends]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def configuration(smtpd):
|
def configuration(smtpd):
|
||||||
smtpd.config.use_starttls = True
|
smtpd.config.use_starttls = True
|
||||||
|
@ -77,11 +111,6 @@ def configuration(smtpd):
|
||||||
return conf
|
return conf
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(params=[lazy_fixture("ldap_backend")])
|
|
||||||
def backend(request):
|
|
||||||
yield request.param
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def app(configuration, backend):
|
def app(configuration, backend):
|
||||||
return create_app(configuration, backend=backend)
|
return create_app(configuration, backend=backend)
|
||||||
|
|
Loading…
Reference in a new issue