forked from Github-Mirrors/canaille
Dynamic model registration
This commit is contained in:
parent
cfe5d0c7b8
commit
54abdaea3b
4 changed files with 97 additions and 36 deletions
|
@ -3,7 +3,6 @@ import os
|
|||
from logging.config import dictConfig
|
||||
|
||||
from flask import Flask
|
||||
from flask import g
|
||||
from flask import request
|
||||
from flask import session
|
||||
from flask_themer import FileSystemThemeLoader
|
||||
|
@ -15,21 +14,6 @@ from flask_wtf.csrf import CSRFProtect
|
|||
csrf = CSRFProtect()
|
||||
|
||||
|
||||
def setup_backend(app, backend):
|
||||
from .backends.ldap.backend import Backend
|
||||
|
||||
if not backend:
|
||||
backend = Backend(app.config)
|
||||
backend.init_app(app)
|
||||
|
||||
with app.app_context():
|
||||
g.backend = backend
|
||||
app.backend = backend
|
||||
|
||||
if app.debug: # pragma: no cover
|
||||
backend.install(app.config)
|
||||
|
||||
|
||||
def setup_sentry(app): # pragma: no cover
|
||||
if not app.config.get("SENTRY_DSN"):
|
||||
return None
|
||||
|
@ -159,20 +143,15 @@ def setup_flask_converters(app):
|
|||
from canaille.app.flask import model_converter
|
||||
from canaille.app import models
|
||||
|
||||
app.url_map.converters["user"] = model_converter(models.User)
|
||||
app.url_map.converters["group"] = model_converter(models.Group)
|
||||
app.url_map.converters["client"] = model_converter(models.Client)
|
||||
app.url_map.converters["token"] = model_converter(models.Token)
|
||||
app.url_map.converters["authorizationcode"] = model_converter(
|
||||
models.AuthorizationCode
|
||||
)
|
||||
app.url_map.converters["consent"] = model_converter(models.Consent)
|
||||
for model_name, model_class in models.MODELS.items():
|
||||
app.url_map.converters[model_name.lower()] = model_converter(model_class)
|
||||
|
||||
|
||||
def create_app(config=None, validate=True, backend=None):
|
||||
from .oidc.oauth import setup_oauth
|
||||
from .app.i18n import setup_i18n
|
||||
from .app.configuration import setup_config
|
||||
from .backends import setup_backend
|
||||
|
||||
app = Flask(__name__)
|
||||
with app.app_context():
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
from canaille.backends.ldap.models import AuthorizationCode # noqa: F401
|
||||
from canaille.backends.ldap.models import Client # noqa: F401
|
||||
from canaille.backends.ldap.models import Consent # noqa: F401
|
||||
from canaille.backends.ldap.models import Group # noqa: F401
|
||||
from canaille.backends.ldap.models import Token # noqa: F401
|
||||
from canaille.backends.ldap.models import User # noqa: F401
|
||||
MODELS = {}
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
if name in MODELS:
|
||||
return MODELS[name]
|
||||
|
||||
|
||||
def register(model):
|
||||
MODELS[model.__name__] = model
|
||||
|
|
|
@ -1,11 +1,16 @@
|
|||
import importlib
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
|
||||
from flask import g
|
||||
|
||||
|
||||
class BaseBackend:
|
||||
instance = None
|
||||
|
||||
def __init__(self, config):
|
||||
BaseBackend.instance = self
|
||||
self.register_models()
|
||||
|
||||
@classmethod
|
||||
def get(cls):
|
||||
|
@ -61,3 +66,47 @@ class BaseBackend:
|
|||
Indicates wether the backend supports locking user accounts.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def register_models(self):
|
||||
from canaille.app import models
|
||||
|
||||
module = ".".join(self.__class__.__module__.split(".")[:-1] + ["models"])
|
||||
try:
|
||||
backend_models = importlib.import_module(module)
|
||||
except ModuleNotFoundError:
|
||||
return
|
||||
|
||||
model_names = [
|
||||
"AuthorizationCode",
|
||||
"Client",
|
||||
"Consent",
|
||||
"Group",
|
||||
"Token",
|
||||
"User",
|
||||
]
|
||||
for model_name in model_names:
|
||||
models.register(getattr(backend_models, model_name))
|
||||
|
||||
|
||||
def setup_backend(app, backend):
|
||||
if not backend:
|
||||
backend_name = list(app.config.get("BACKENDS").keys())[0].lower()
|
||||
module = importlib.import_module(f"canaille.backends.{backend_name}.backend")
|
||||
backend_class = getattr(module, "Backend")
|
||||
backend = backend_class(app.config)
|
||||
backend.init_app(app)
|
||||
|
||||
with app.app_context():
|
||||
g.backend = backend
|
||||
app.backend = backend
|
||||
|
||||
if app.debug: # pragma: no cover
|
||||
backend.install(app.config, True)
|
||||
|
||||
|
||||
def available_backends():
|
||||
return {
|
||||
elt.name
|
||||
for elt in os.scandir(os.path.dirname(__file__))
|
||||
if elt.is_dir() and os.path.exists(os.path.join(elt, "backend.py"))
|
||||
}
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
from canaille import create_app
|
||||
from canaille.app import models
|
||||
from canaille.backends import available_backends
|
||||
from flask_webtest import TestApp
|
||||
from jinja2 import StrictUndefined
|
||||
from pytest_lazyfixture import lazy_fixture
|
||||
|
@ -8,10 +11,41 @@ from werkzeug.security import gen_salt
|
|||
|
||||
|
||||
pytest_plugins = [
|
||||
"tests.backends.ldap.fixtures",
|
||||
f"tests.backends.{backend}.fixtures"
|
||||
for backend in available_backends()
|
||||
if os.path.exists(os.path.join("tests", "backends", backend, "fixtures.py"))
|
||||
]
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--backend", action="append", default=[], help="the backends to test"
|
||||
)
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
backends = available_backends()
|
||||
if metafunc.config.getoption("backend"): # pragma: no cover
|
||||
backends &= set(metafunc.config.getoption("backend"))
|
||||
|
||||
# tests in tests.backends.BACKENDNAME should only run one backend
|
||||
if metafunc.module.__name__.startswith("tests.backends"):
|
||||
module_name_parts = metafunc.module.__name__.split(".")
|
||||
in_backend_module = len(module_name_parts) > 3
|
||||
if in_backend_module:
|
||||
backend = module_name_parts[2]
|
||||
if backend not in backends: # pragma: no cover
|
||||
pytest.skip()
|
||||
elif "backend" in metafunc.fixturenames:
|
||||
metafunc.parametrize("backend", [lazy_fixture(f"{backend}_backend")])
|
||||
return
|
||||
|
||||
if "backend" in metafunc.fixturenames:
|
||||
metafunc.parametrize(
|
||||
"backend", [lazy_fixture(f"{backend}_backend") for backend in backends]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def configuration(smtpd):
|
||||
smtpd.config.use_starttls = True
|
||||
|
@ -77,11 +111,6 @@ def configuration(smtpd):
|
|||
return conf
|
||||
|
||||
|
||||
@pytest.fixture(params=[lazy_fixture("ldap_backend")])
|
||||
def backend(request):
|
||||
yield request.param
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(configuration, backend):
|
||||
return create_app(configuration, backend=backend)
|
||||
|
|
Loading…
Reference in a new issue