canaille-globuzma/canaille/backends/sql/models.py
Éloi Rivard 27639081f0
feat: implement sqlalchemy backend
Co-authored-by: Loan Robert <loan@yaal.coop>
2023-11-24 13:57:46 +01:00

378 lines
14 KiB
Python

import datetime
import uuid
from typing import List
import canaille.core.models
import canaille.oidc.models
from canaille.app import models
from canaille.backends.models import Model
from flask import current_app
from sqlalchemy import Boolean
from sqlalchemy import Column
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import LargeBinary
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import reconstructor
from sqlalchemy.orm import relationship
from sqlalchemy_json import MutableJson
from .backend import Backend
from .backend import Base
from .utils import TZDateTime
class SqlAlchemyModel(Model):
    """Shared query/persistence helpers for the SQLAlchemy-backed models.

    Every operation goes through the session exposed by
    ``Backend.get().db_session``. Concrete models are expected to define
    ``identifier_attribute`` (the name of the attribute used as the
    human-facing identifier) and an ``id`` attribute.
    """

    # Character used to escape LIKE wildcards in :meth:`fuzzy`.
    _LIKE_ESCAPE = "/"

    def __html__(self):
        """Return the object ``id`` as its HTML representation."""
        return self.id

    def __repr__(self):
        return (
            f"<{self.__class__.__name__} {self.identifier_attribute}={self.identifier}>"
        )

    @classmethod
    def query(cls, **kwargs):
        """Return every instance matching all the ``attribute=value`` filters.

        With no keyword argument, no filter is applied.
        """
        conditions = [
            cls.attribute_filter(attribute_name, expected_value)
            for attribute_name, expected_value in kwargs.items()
        ]
        return (
            Backend.get()
            .db_session.execute(select(cls).filter(*conditions))
            .scalars()
            .all()
        )

    @classmethod
    def fuzzy(cls, query, attributes=None, **kwargs):
        """Return instances where any string attribute contains ``query``.

        The match is case-insensitive. Only attributes whose annotation
        mentions ``str`` are searched.

        :param query: Text to look for; it is matched literally.
        :param attributes: Attribute names to search, defaulting to
            ``cls.attributes``.
        :param kwargs: Accepted for interface compatibility; unused here.
        """
        attributes = attributes or cls.attributes

        # Escape the LIKE wildcards so that a '%' or '_' typed by the user
        # is searched literally instead of matching arbitrary characters.
        # The escape character itself must be escaped first.
        escape = cls._LIKE_ESCAPE
        literal = str(query)
        for special in (escape, "%", "_"):
            literal = literal.replace(special, escape + special)
        pattern = f"%{literal}%"

        condition = or_(
            getattr(cls, attribute_name).ilike(pattern, escape=escape)
            for attribute_name in attributes
            if "str" in str(cls.__annotations__[attribute_name])
        )
        return (
            Backend.get()
            .db_session.execute(select(cls).filter(condition))
            .scalars()
            .all()
        )

    @classmethod
    def attribute_filter(cls, name, value):
        """Build the SQL condition matching attribute ``name`` against ``value``.

        A list ``value`` produces an OR of the conditions for each element.
        Attributes annotated as ``List`` are tested with ``contains``,
        other attributes with equality.
        """
        if isinstance(value, list):
            return or_(cls.attribute_filter(name, v) for v in value)

        multiple = "List" in str(cls.__annotations__[name])
        if multiple:
            return getattr(cls, name).contains(value)
        return getattr(cls, name) == value

    @classmethod
    def get(cls, identifier=None, **kwargs):
        """Return the single instance matching the filters, or ``None``.

        :param identifier: When truthy, added as a filter on
            ``cls.identifier_attribute``.
        :param kwargs: Additional ``attribute=value`` filters.
        """
        if identifier:
            kwargs[cls.identifier_attribute] = identifier

        conditions = [
            cls.attribute_filter(attribute_name, expected_value)
            for attribute_name, expected_value in kwargs.items()
        ]
        return (
            Backend.get()
            .db_session.execute(select(cls).filter(*conditions))
            .scalar_one_or_none()
        )

    @property
    def identifier(self):
        """Value of the attribute named by ``identifier_attribute``."""
        return getattr(self, self.identifier_attribute)

    def save(self):
        """Add the object to the session and commit."""
        Backend.get().db_session.add(self)
        Backend.get().db_session.commit()

    def delete(self):
        """Delete the object through the session and commit."""
        Backend.get().db_session.delete(self)
        Backend.get().db_session.commit()

    def reload(self):
        """Refresh the object state through the session."""
        Backend.get().db_session.refresh(self)
# Association table linking users and groups; it backs the ``User.groups``
# and ``Group.members`` relationships. Each (user_id, group_id) pair forms
# the composite primary key.
membership_association_table = Table(
    "membership_association_table",
    Base.metadata,
    Column("user_id", ForeignKey("user.id"), primary_key=True),
    Column("group_id", ForeignKey("group.id"), primary_key=True),
)
class User(canaille.core.models.User, Base, SqlAlchemyModel):
    """User account stored in the ``user`` table.

    Besides the mapped columns, each instance carries three sets computed by
    :meth:`load_permissions` from the ``ACL`` application setting:
    ``permissions``, ``read`` and ``write``.
    """

    __tablename__ = "user"
    identifier_attribute = "user_name"

    id: Mapped[str] = mapped_column(
        String, primary_key=True, default=lambda: str(uuid.uuid4())
    )
    user_name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
    # NOTE(review): set_password() stores the value as given; no hashing is
    # applied anywhere in this class.
    password: Mapped[str] = mapped_column(String, nullable=True)
    preferred_language: Mapped[str] = mapped_column(String, nullable=True)
    family_name: Mapped[str] = mapped_column(String, nullable=True)
    given_name: Mapped[str] = mapped_column(String, nullable=True)
    formatted_name: Mapped[str] = mapped_column(String, nullable=True)
    display_name: Mapped[str] = mapped_column(String, nullable=True)
    emails: Mapped[List[str]] = mapped_column(MutableJson, nullable=True)
    phone_numbers: Mapped[List[str]] = mapped_column(MutableJson, nullable=True)
    formatted_address: Mapped[str] = mapped_column(String, nullable=True)
    street: Mapped[str] = mapped_column(String, nullable=True)
    postal_code: Mapped[str] = mapped_column(String, nullable=True)
    locality: Mapped[str] = mapped_column(String, nullable=True)
    region: Mapped[str] = mapped_column(String, nullable=True)
    photo: Mapped[bytes] = mapped_column(LargeBinary, nullable=True)
    profile_url: Mapped[str] = mapped_column(String, nullable=True)
    employee_number: Mapped[str] = mapped_column(String, nullable=True)
    department: Mapped[str] = mapped_column(String, nullable=True)
    title: Mapped[str] = mapped_column(String, nullable=True)
    organization: Mapped[str] = mapped_column(String, nullable=True)
    last_modified: Mapped[datetime.datetime] = mapped_column(
        TZDateTime(timezone=True), nullable=True
    )
    groups: Mapped[List["Group"]] = relationship(
        secondary=membership_association_table, back_populates="members"
    )
    lock_date: Mapped[datetime.datetime] = mapped_column(
        TZDateTime(timezone=True), nullable=True
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.load_permissions()

    def reload(self):
        """Refresh the object state, then recompute the ACL-derived sets."""
        super().reload()
        self.load_permissions()

    @reconstructor
    def load_permissions(self):
        """Compute ``permissions``, ``read`` and ``write`` from the ACL setting.

        Each entry of ``current_app.config["ACL"]`` whose ``FILTER`` matches
        this user contributes its ``PERMISSIONS``, ``READ`` and ``WRITE``
        values. Called from ``__init__`` and ``reload``, and registered as a
        SQLAlchemy reconstructor.
        """
        self.permissions = set()
        self.read = set()
        self.write = set()
        # Only the ACL details are needed, not the access group names.
        for details in current_app.config["ACL"].values():
            if self.match_filter(details.get("FILTER")):
                self.permissions |= set(details.get("PERMISSIONS", []))
                self.read |= set(details.get("READ", []))
                self.write |= set(details.get("WRITE", []))

    def normalize_filter_value(self, attribute, value):
        """Convert a raw filter value into something comparable to the attribute.

        For ``groups``, ``value`` is resolved to a group looked up by ``id``
        and then by ``display_name``; when neither lookup finds anything,
        ``value`` is returned unchanged. Other attributes are returned as is.
        """
        # not super generic, but we can improve this when we have
        # type checking and/or pydantic for the models
        if attribute == "groups":
            # Each lookup runs once; the second only when the first fails.
            return (
                models.Group.get(id=value)
                or models.Group.get(display_name=value)
                or value
            )
        return value

    def match_filter(self, filter):
        """Tell whether this user matches an ACL filter.

        :param filter: ``None`` (always matches), a dict mapping attribute
            names to expected values (all must match), or an iterable of
            such filters (at least one must match).
        """
        if filter is None:
            return True

        if isinstance(filter, dict):
            for attribute, value in filter.items():
                multiple = "List" in str(self.__annotations__[attribute])
                expected = self.normalize_filter_value(attribute, value)
                if multiple:
                    # Nullable list columns may hold None: treat it as empty
                    # rather than failing on the membership test.
                    matched = expected in (getattr(self, attribute, None) or [])
                else:
                    matched = expected == getattr(self, attribute, None)
                if not matched:
                    return False
            return True

        return any(self.match_filter(subfilter) for subfilter in filter)

    @classmethod
    def get_from_login(cls, login=None, **kwargs):
        """Return the user whose ``user_name`` equals ``login``, or ``None``."""
        return cls.get(user_name=login)

    def has_password(self):
        """Tell whether a non-empty password is set."""
        return bool(self.password)

    def check_password(self, password):
        """Check ``password`` against the stored one.

        :return: A ``(success, message)`` tuple. ``message`` is only set
            when the password is right but the account is locked.
        """
        # Compare for equality: a substring test would accept an empty
        # password or any fragment of the real one. Empty input is always
        # rejected, which also covers accounts without a password.
        if not password or password != self.password:
            return (False, None)

        if self.locked:
            return (False, "Your account has been locked.")

        return (True, None)

    def set_password(self, password):
        """Store ``password`` and save the user."""
        self.password = password
        self.save()

    def save(self):
        """Stamp ``last_modified`` with the current UTC time, then save."""
        # Microseconds are dropped from the stored timestamp.
        self.last_modified = datetime.datetime.now(datetime.timezone.utc).replace(
            microsecond=0
        )
        super().save()
class Group(canaille.core.models.Group, Base, SqlAlchemyModel):
    """Group of users stored in the ``group`` table."""

    __tablename__ = "group"
    # Groups are looked up by display name in SqlAlchemyModel.get().
    identifier_attribute = "display_name"

    # Primary key: a random UUID4 rendered as a string.
    id: Mapped[str] = mapped_column(
        String, primary_key=True, default=lambda: str(uuid.uuid4())
    )
    display_name: Mapped[str] = mapped_column(String)
    description: Mapped[str] = mapped_column(String, nullable=True)
    # Many-to-many link with users, mirrored by ``User.groups``.
    members: Mapped[List["User"]] = relationship(
        secondary=membership_association_table, back_populates="groups"
    )
# Association table linking a client to the clients in its audience; it
# backs the self-referential ``Client.audience`` relationship. Both columns
# reference ``client.id``.
# NOTE(review): the columns are declared both primary_key=True and
# nullable=True; confirm the explicit nullable=True is intended.
client_audience_association_table = Table(
    "client_audience_association_table",
    Base.metadata,
    Column("audience_id", ForeignKey("client.id"), primary_key=True, nullable=True),
    Column("client_id", ForeignKey("client.id"), primary_key=True, nullable=True),
)
class Client(canaille.oidc.models.Client, Base, SqlAlchemyModel):
    """OIDC client stored in the ``client`` table."""

    __tablename__ = "client"

    # Primary key: a random UUID4 rendered as a string. This class-level
    # name is also what ``id`` refers to in the ``audience`` join
    # conditions below.
    id: Mapped[str] = mapped_column(
        String, primary_key=True, default=lambda: str(uuid.uuid4())
    )
    identifier_attribute = "client_id"

    description: Mapped[str] = mapped_column(String, nullable=True)
    preconsent: Mapped[bool] = mapped_column(Boolean, nullable=True)
    post_logout_redirect_uris: Mapped[List[str]] = mapped_column(
        MutableJson, nullable=True
    )
    # Self-referential many-to-many: this client is joined on the
    # ``client_id`` column, the audience clients on ``audience_id``.
    audience: Mapped[List["Client"]] = relationship(
        "Client",
        secondary=client_audience_association_table,
        primaryjoin=id == client_audience_association_table.c.client_id,
        secondaryjoin=id == client_audience_association_table.c.audience_id,
    )
    client_id: Mapped[str] = mapped_column(String, nullable=True)
    client_secret: Mapped[str] = mapped_column(String, nullable=True)
    client_id_issued_at: Mapped[datetime.datetime] = mapped_column(
        TZDateTime(timezone=True), nullable=True
    )
    client_secret_expires_at: Mapped[datetime.datetime] = mapped_column(
        TZDateTime(timezone=True), nullable=True
    )
    client_name: Mapped[str] = mapped_column(String, nullable=True)
    # List-valued attributes are stored in MutableJson columns.
    contacts: Mapped[List[str]] = mapped_column(MutableJson, nullable=True)
    client_uri: Mapped[str] = mapped_column(String, nullable=True)
    redirect_uris: Mapped[List[str]] = mapped_column(MutableJson, nullable=True)
    logo_uri: Mapped[str] = mapped_column(String, nullable=True)
    grant_types: Mapped[List[str]] = mapped_column(MutableJson, nullable=True)
    response_types: Mapped[List[str]] = mapped_column(MutableJson, nullable=True)
    scope: Mapped[List[str]] = mapped_column(MutableJson, nullable=True)
    tos_uri: Mapped[str] = mapped_column(String, nullable=True)
    policy_uri: Mapped[str] = mapped_column(String, nullable=True)
    jwks_uri: Mapped[str] = mapped_column(String, nullable=True)
    jwk: Mapped[str] = mapped_column(String, nullable=True)
    token_endpoint_auth_method: Mapped[str] = mapped_column(String, nullable=True)
    software_id: Mapped[str] = mapped_column(String, nullable=True)
    software_version: Mapped[str] = mapped_column(String, nullable=True)
class AuthorizationCode(canaille.oidc.models.AuthorizationCode, Base, SqlAlchemyModel):
    """OIDC authorization code stored in the ``authorization_code`` table."""

    __tablename__ = "authorization_code"
    identifier_attribute = "authorization_code_id"

    # Primary key: a random UUID4 rendered as a string.
    id: Mapped[str] = mapped_column(
        String, primary_key=True, default=lambda: str(uuid.uuid4())
    )
    authorization_code_id: Mapped[str] = mapped_column(String, nullable=True)
    code: Mapped[str] = mapped_column(String, nullable=True)
    # Foreign key to ``client.id`` and the matching relationship.
    client_id: Mapped[str] = mapped_column(ForeignKey("client.id"))
    client: Mapped["Client"] = relationship()
    # Foreign key to ``user.id`` and the matching relationship.
    subject_id: Mapped[str] = mapped_column(ForeignKey("user.id"))
    subject: Mapped["User"] = relationship()
    redirect_uri: Mapped[str] = mapped_column(String, nullable=True)
    response_type: Mapped[str] = mapped_column(String, nullable=True)
    scope: Mapped[List[str]] = mapped_column(MutableJson, nullable=True)
    nonce: Mapped[str] = mapped_column(String, nullable=True)
    issue_date: Mapped[datetime.datetime] = mapped_column(
        TZDateTime(timezone=True), nullable=True
    )
    # NOTE(review): presumably a duration in seconds; confirm against callers.
    lifetime: Mapped[int] = mapped_column(Integer, nullable=True)
    challenge: Mapped[str] = mapped_column(String, nullable=True)
    challenge_method: Mapped[str] = mapped_column(String, nullable=True)
    revokation_date: Mapped[datetime.datetime] = mapped_column(
        TZDateTime(timezone=True), nullable=True
    )
# Association table linking tokens to the clients in their audience; it
# backs the ``Token.audience`` relationship.
# NOTE(review): the columns are declared both primary_key=True and
# nullable=True; confirm the explicit nullable=True is intended.
token_audience_association_table = Table(
    "token_audience_association_table",
    Base.metadata,
    Column("token_id", ForeignKey("token.id"), primary_key=True, nullable=True),
    Column("client_id", ForeignKey("client.id"), primary_key=True, nullable=True),
)
class Token(canaille.oidc.models.Token, Base, SqlAlchemyModel):
    """OIDC token stored in the ``token`` table."""

    __tablename__ = "token"
    identifier_attribute = "token_id"

    # Primary key: a random UUID4 rendered as a string. This class-level
    # name is also what ``id`` refers to in the ``audience`` primary join.
    id: Mapped[str] = mapped_column(
        String, primary_key=True, default=lambda: str(uuid.uuid4())
    )
    token_id: Mapped[str] = mapped_column(String, nullable=True)
    access_token: Mapped[str] = mapped_column(String, nullable=True)
    # Foreign key to ``client.id`` and the matching relationship.
    client_id: Mapped[str] = mapped_column(ForeignKey("client.id"))
    client: Mapped["Client"] = relationship()
    # Foreign key to ``user.id`` and the matching relationship.
    subject_id: Mapped[str] = mapped_column(ForeignKey("user.id"))
    subject: Mapped["User"] = relationship()
    type: Mapped[str] = mapped_column(String, nullable=True)
    refresh_token: Mapped[str] = mapped_column(String, nullable=True)
    scope: Mapped[List[str]] = mapped_column(MutableJson, nullable=True)
    issue_date: Mapped[datetime.datetime] = mapped_column(
        TZDateTime(timezone=True), nullable=True
    )
    # NOTE(review): presumably a duration in seconds; confirm against callers.
    lifetime: Mapped[int] = mapped_column(Integer, nullable=True)
    revokation_date: Mapped[datetime.datetime] = mapped_column(
        TZDateTime(timezone=True), nullable=True
    )
    # Many-to-many link to the audience clients: the token is joined on
    # ``token_id``, the clients on ``client_id``.
    audience: Mapped[List["Client"]] = relationship(
        "Client",
        secondary=token_audience_association_table,
        primaryjoin=id == token_audience_association_table.c.token_id,
        secondaryjoin=Client.id == token_audience_association_table.c.client_id,
    )
class Consent(canaille.oidc.models.Consent, Base, SqlAlchemyModel):
    """Consent linking a user and a client, stored in the ``consent`` table."""

    __tablename__ = "consent"
    identifier_attribute = "consent_id"

    # Primary key: a random UUID4 rendered as a string.
    id: Mapped[str] = mapped_column(
        String, primary_key=True, default=lambda: str(uuid.uuid4())
    )
    consent_id: Mapped[str] = mapped_column(String, nullable=True)
    # Foreign key to ``user.id`` and the matching relationship.
    subject_id: Mapped[str] = mapped_column(ForeignKey("user.id"))
    subject: Mapped["User"] = relationship()
    # Foreign key to ``client.id`` and the matching relationship.
    client_id: Mapped[str] = mapped_column(ForeignKey("client.id"))
    client: Mapped["Client"] = relationship()
    scope: Mapped[List[str]] = mapped_column(MutableJson, nullable=True)
    issue_date: Mapped[datetime.datetime] = mapped_column(
        TZDateTime(timezone=True), nullable=True
    )
    revokation_date: Mapped[datetime.datetime] = mapped_column(
        TZDateTime(timezone=True), nullable=True
    )