canaille-globuzma/canaille/backends/sql/backend.py
2025-01-10 12:32:18 +01:00

190 lines
5.9 KiB
Python

import datetime
from pathlib import Path
from flask import current_app
from flask_alembic import Alembic
from sqlalchemy import create_engine
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import declarative_base
from sqlalchemy_utils import Password
from canaille.backends import Backend
from canaille.backends import ModelEncoder
from canaille.backends import get_lockout_delay_message
# Declarative base class for the SQL models; its ``metadata`` is the schema
# handed to Alembic in ``SQLBackend.__init__`` for migrations.
Base = declarative_base()
class SQLModelEncoder(ModelEncoder):
    """JSON encoder for SQL-backed models.

    Extends ``ModelEncoder`` with support for ``sqlalchemy_utils.Password``
    values, which are serialized as their decoded hash string.
    """

    def default(self, obj):
        """Serialize *obj*, delegating anything that is not a ``Password``."""
        if not isinstance(obj, Password):
            return super().default(obj)
        # ``obj.hash`` is bytes; emit it as text so it is JSON-serializable.
        return obj.hash.decode()
class SQLBackend(Backend):
    """Canaille backend that stores models in a SQL database via SQLAlchemy.

    The engine and the Alembic helper are stored as *class* attributes, so
    they are shared by every ``SQLBackend`` object created in the process.
    """

    engine = None
    db_session = None
    json_encoder = SQLModelEncoder
    alembic = None

    def __init__(self, config):
        """Create the SQLAlchemy engine and the Alembic helper.

        :param config: configuration mapping; ``config["CANAILLE_SQL"]["DATABASE_URI"]``
            (read through ``self.config``) is the database URL.
        """
        super().__init__(config)
        SQLBackend.engine = create_engine(
            self.config["CANAILLE_SQL"]["DATABASE_URI"], echo=False, future=True
        )
        SQLBackend.alembic = Alembic(metadatas=Base.metadata, engines=SQLBackend.engine)

    @classmethod
    def install(cls, app):  # pragma: no cover
        """Configure Alembic for *app* and apply all pending migrations."""
        cls.init_alembic(app)
        SQLBackend.alembic.upgrade()

    @classmethod
    def init_alembic(cls, app):
        """Point Alembic at the ``migrations`` directory next to this module."""
        app.config["ALEMBIC"] = {
            "script_location": str(Path(__file__).resolve().parent / "migrations"),
        }
        SQLBackend.alembic.init_app(app)

    def init_app(self, app, init_backend=None):
        """Bind the backend to *app*, optionally running migrations.

        :param init_backend: when ``None``, the ``CANAILLE_SQL.AUTO_MIGRATE``
            setting decides whether migrations are applied; otherwise the
            given value decides.
        """
        super().init_app(app)
        self.init_alembic(app)
        init_backend = (
            app.config["CANAILLE_SQL"]["AUTO_MIGRATE"]
            if init_backend is None
            else init_backend
        )
        if init_backend:  # pragma: no cover
            with app.app_context():
                self.alembic.upgrade()

    def setup(self):
        """Open the database session unless one is already open."""
        # Explicit None check: the session object's truthiness is not the
        # question being asked here.
        if self.db_session is None:
            self.db_session = Session(SQLBackend.engine)

    def teardown(self):
        """Nothing to release for this backend."""
        pass

    @classmethod
    def validate(cls, config):
        """No backend-specific configuration validation is performed."""
        pass

    @classmethod
    def login_placeholder(cls):
        """Return the placeholder text for the login field (empty here)."""
        return ""

    def has_account_lockability(self):
        """Accounts stored in SQL can be locked."""
        return True

    def get_user_from_login(self, login):
        """Return the ``User`` whose ``user_name`` equals *login*, or ``None``."""
        from .models import User

        return self.get(User, user_name=login)

    def check_user_password(self, user, password):
        """Check *password* for *user*.

        :returns: ``(success, message)``. *message* is a lockout or
            account-locked explanation, or ``None``.
        """
        if current_app.features.has_intruder_lockout:
            if current_lockout_delay := user.get_intruder_lockout_delay():
                # NOTE(review): presumably persists state that
                # get_intruder_lockout_delay() updated on the user — confirm.
                self.save(user)
                return (False, get_lockout_delay_message(current_lockout_delay))

        # user.password is expected to be a sqlalchemy_utils Password, whose
        # comparison checks the candidate against the stored hash.
        if password != user.password:
            if current_app.features.has_intruder_lockout:
                self.record_failed_attempt(user)
            return (False, None)

        if user.locked:
            return (False, "Your account has been locked.")

        return (True, None)

    def set_user_password(self, user, password):
        """Store *password* for *user* and record the update time (UTC, whole seconds)."""
        user.password = password
        user.password_last_update = datetime.datetime.now(
            datetime.timezone.utc
        ).replace(microsecond=0)
        self.save(user)

    def query(self, model, **kwargs):
        """Return every *model* instance matching all the attribute filters in *kwargs*."""
        filter = [
            model.attribute_filter(attribute_name, expected_value)
            for attribute_name, expected_value in kwargs.items()
        ]
        return (
            SQLBackend.instance.db_session.execute(select(model).filter(*filter))
            .scalars()
            .all()
        )

    def fuzzy(self, model, query, attributes=None, **kwargs):
        """Return *model* instances where any textual attribute contains *query*.

        The match is case-insensitive (``ILIKE``). *query* is matched
        literally: ``%`` and ``_`` are escaped rather than treated as SQL
        wildcards.

        :param attributes: names of the attributes to search; defaults to
            ``model.attributes``. Only attributes whose declared type mentions
            ``str`` are searched, and ``photo`` is always skipped.
        """
        attributes = attributes or model.attributes
        # Escape LIKE metacharacters (same scheme as SQLAlchemy's autoescape,
        # with "/" as the escape character) so user input cannot inject
        # wildcards.
        escaped = str(query).replace("/", "//").replace("%", "/%").replace("_", "/_")
        pattern = f"%{escaped}%"
        clauses = [
            getattr(model, attribute_name).ilike(pattern, escape="/")
            for attribute_name in attributes
            if "str" in str(model.attributes[attribute_name])
            # erk, photo is an URL string according to SCIM, but bytes here
            and attribute_name != "photo"
        ]
        if not clauses:
            # An argument-less or_() is deprecated and renders no condition,
            # which would return every row: with nothing searchable, nothing
            # can match.
            return []
        return (
            self.db_session.execute(select(model).filter(or_(*clauses)))
            .scalars()
            .all()
        )

    def get(self, model, identifier=None, /, **kwargs):
        """Return the single matching *model* instance, or ``None``.

        When *identifier* is given (any value other than ``None``), look it
        up first through ``model.identifier_attribute`` and then through
        ``id``; *kwargs* are ignored in that case. Otherwise filter on
        *kwargs*. Several matching rows make ``scalar_one_or_none`` raise.
        """
        if identifier is not None:
            # A falsy identifier such as "" must still be looked up; falling
            # through would run an unfiltered query over the whole table.
            return self.get(
                model, **{model.identifier_attribute: identifier}
            ) or self.get(model, id=identifier)

        filter = [
            model.attribute_filter(attribute_name, expected_value)
            for attribute_name, expected_value in kwargs.items()
        ]
        return SQLBackend.instance.db_session.execute(
            select(model).filter(*filter)
        ).scalar_one_or_none()

    def save(self, instance):
        """Persist *instance*, stamping ``last_modified`` (and ``created`` if unset)."""
        # run the instance save callback if existing
        if hasattr(instance, "save"):
            instance.save()

        instance.last_modified = datetime.datetime.now(datetime.timezone.utc).replace(
            microsecond=0
        )
        if not instance.created:
            instance.created = instance.last_modified

        SQLBackend.instance.db_session.add(instance)
        SQLBackend.instance.db_session.commit()

    def delete(self, instance):
        """Delete *instance* and commit.

        If the instance has a ``delete`` method, it is treated as a generator
        callback: it is advanced once before the deletion and once after.
        """
        # run the instance delete callback if existing
        delete_callback = instance.delete() if hasattr(instance, "delete") else iter([])
        next(delete_callback, None)

        SQLBackend.instance.db_session.delete(instance)
        SQLBackend.instance.db_session.commit()

        # run the instance delete callback again if existing
        next(delete_callback, None)

    def reload(self, instance):
        """Refresh *instance* from the database.

        If the instance has a ``reload`` method, it is treated as a generator
        callback: it is advanced once before the refresh and once after.
        """
        # run the instance reload callback if existing
        reload_callback = instance.reload() if hasattr(instance, "reload") else iter([])
        next(reload_callback, None)

        SQLBackend.instance.db_session.refresh(instance)

        # run the instance reload callback again if existing
        next(reload_callback, None)

    def record_failed_attempt(self, user):
        """Append the current UTC time (as a string) to the user's failure log and save."""
        if user.password_failure_timestamps is None:
            user.password_failure_timestamps = []
        # NOTE(review): appends to the private backing attribute rather than
        # the public one — presumably the public name is a property over it;
        # confirm against the User model.
        user._password_failure_timestamps.append(
            str(datetime.datetime.now(datetime.timezone.utc))
        )
        self.save(user)