refactor: split the base model class in two

This commit is contained in:
Éloi Rivard 2024-04-01 18:25:38 +02:00
parent c1b901261f
commit 18e3f8cde5
No known key found for this signature in database
GPG key ID: 7EDA204EA57DD184
5 changed files with 18 additions and 37 deletions

View file

@ -5,7 +5,7 @@ from collections.abc import Iterable
import ldap.dn import ldap.dn
import ldap.filter import ldap.filter
from canaille.backends.models import Model from canaille.backends.models import BackendModel
from .backend import Backend from .backend import Backend
from .utils import cardinalize_attribute from .utils import cardinalize_attribute
@ -104,7 +104,7 @@ class LDAPObjectQuery:
return klass return klass
class LDAPObject(Model, metaclass=LDAPObjectMetaclass): class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
_object_class_by_name = None _object_class_by_name = None
_attribute_type_by_name = None _attribute_type_by_name = None
_may = None _may = None
@ -166,8 +166,8 @@ class LDAPObject(Model, metaclass=LDAPObjectMetaclass):
def __hash__(self): def __hash__(self):
return hash(self.id) return hash(self.id)
def __getattr__(self, name): def __getattribute__(self, name):
if name not in self.attributes: if name == "attributes" or name not in self.attributes:
return super().__getattribute__(name) return super().__getattribute__(name)
ldap_name = self.python_attribute_to_ldap(name) ldap_name = self.python_attribute_to_ldap(name)

View file

@ -8,17 +8,17 @@ from flask import current_app
import canaille.core.models import canaille.core.models
import canaille.oidc.models import canaille.oidc.models
from canaille.app import models from canaille.app import models
from canaille.backends.models import Model from canaille.backends.models import BackendModel
class MemoryModel(Model): class MemoryModel(BackendModel):
indexes = {} indexes = {}
"""Associates ids and states.""" """Associates ids and states."""
attribute_indexes = {} attribute_indexes = {}
"""Associates attribute values and ids.""" """Associates attribute values and ids."""
def __init__(self, **kwargs): def __init__(self, *args, **kwargs):
kwargs.setdefault("id", str(uuid.uuid4())) kwargs.setdefault("id", str(uuid.uuid4()))
self._state = {} self._state = {}
self._cache = {} self._cache = {}

View file

@ -8,8 +8,7 @@ from canaille.app import classproperty
class Model: class Model:
"""The model abstract class. """The model abstract class.
It details all the methods and attributes that are expected to be It details all the common attributes shared by every models.
implemented for every model and for every backend.
""" """
created: Optional[datetime.datetime] created: Optional[datetime.datetime]
@ -24,6 +23,14 @@ class Model:
the value MUST be the same as the value of :attr:`~canaille.backends.models.Model.created`. the value MUST be the same as the value of :attr:`~canaille.backends.models.Model.created`.
""" """
class BackendModel:
"""The backend model abstract class.
It details all the methods and attributes that are expected to be
implemented for every model and for every backend.
"""
@classproperty @classproperty
def attributes(cls): def attributes(cls):
return ChainMap( return ChainMap(

View file

@ -24,7 +24,7 @@ from sqlalchemy_utils import force_auto_coercion
import canaille.core.models import canaille.core.models
import canaille.oidc.models import canaille.oidc.models
from canaille.app import models from canaille.app import models
from canaille.backends.models import Model from canaille.backends.models import BackendModel
from .backend import Backend from .backend import Backend
from .backend import Base from .backend import Base
@ -33,7 +33,7 @@ from .utils import TZDateTime
force_auto_coercion() force_auto_coercion()
class SqlAlchemyModel(Model): class SqlAlchemyModel(BackendModel):
def __html__(self): def __html__(self):
return self.id return self.id

View file

@ -4,32 +4,6 @@ import freezegun
import pytest import pytest
from canaille.app import models from canaille.app import models
from canaille.backends.models import Model
def test_required_methods(testclient):
with pytest.raises(NotImplementedError):
Model.query()
with pytest.raises(NotImplementedError):
Model.fuzzy("foobar")
with pytest.raises(NotImplementedError):
Model.get()
obj = Model()
with pytest.raises(NotImplementedError):
obj.identifier
with pytest.raises(NotImplementedError):
obj.save()
with pytest.raises(NotImplementedError):
obj.delete()
with pytest.raises(NotImplementedError):
obj.reload()
def test_model_comparison(testclient, backend): def test_model_comparison(testclient, backend):