Merge branch '214-dump-passwords' into 'main'

CLI commands can dump hashed passwords

Closes #214

See merge request yaal/canaille!212
This commit is contained in:
Éloi Rivard 2025-01-09 08:41:20 +00:00
commit a5377f4544
7 changed files with 95 additions and 55 deletions

View file

@ -10,6 +10,7 @@ Added
Changed Changed
^^^^^^^ ^^^^^^^
- fixed a bug on updating user's settings :issue:`206` - fixed a bug on updating user's settings :issue:`206`
- CLI commands dump hashed passwords :issue:`214`
Changed Changed
^^^^^^^ ^^^^^^^

View file

@ -1,5 +1,8 @@
import datetime
import importlib import importlib
import json
import os import os
import typing
from contextlib import contextmanager from contextlib import contextmanager
from math import ceil from math import ceil
@ -8,8 +11,48 @@ from flask import g
from canaille.app import classproperty from canaille.app import classproperty
class ModelEncoder(json.JSONEncoder):
"""JSON serializer that can handle Canaille models."""
@staticmethod
def serialize_model(instance):
def serialize_attribute(attribute_name, value):
"""Replace model instances by their id."""
multiple = typing.get_origin(instance.attributes[attribute_name]) is list
if multiple and isinstance(value, list):
return [serialize_attribute(attribute_name, v) for v in value]
model, _ = instance.get_model_annotations(attribute_name)
if model:
return value.id
return value
result = {}
for attribute in instance.attributes:
if serialized := serialize_attribute(
attribute, getattr(instance, attribute)
):
result[attribute] = serialized
return result
def default(self, obj):
from canaille.backends.models import Model
if isinstance(obj, datetime.datetime):
return obj.isoformat()
if isinstance(obj, Model):
return self.serialize_model(obj)
return super().default(obj)
class Backend: class Backend:
_instance = None _instance = None
json_encoder = ModelEncoder
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config

View file

@ -76,38 +76,6 @@ def register(cli):
cli.add_command(reset_otp) cli.add_command(reset_otp)
def serialize(instance):
"""Quick and dirty serialization method.
This can probably be made simpler when we will use pydantic models.
"""
def serialize_attribute(attribute_name, value):
multiple = is_multiple(instance.attributes[attribute_name])
if multiple and isinstance(value, list):
return [serialize_attribute(attribute_name, v) for v in value]
model, _ = instance.get_model_annotations(attribute_name)
if model:
return value.id
anonymized = ("password",)
if attribute_name in anonymized and value:
return "***"
if isinstance(value, datetime.datetime):
return value.isoformat()
return value
result = {}
for attribute in instance.attributes:
if serialized := serialize_attribute(attribute, getattr(instance, attribute)):
result[attribute] = serialized
return result
def get_factory(model): def get_factory(model):
command_help = f"""Search for {model.__name__.lower()}s and display the command_help = f"""Search for {model.__name__.lower()}s and display the
matching models as JSON.""" matching models as JSON."""
@ -120,7 +88,7 @@ def get_factory(model):
attribute: value for attribute, value in kwargs.items() if value is not None attribute: value for attribute, value in kwargs.items() if value is not None
} }
items = Backend.instance.query(model, **filter) items = Backend.instance.query(model, **filter)
output = json.dumps([serialize(item) for item in items]) output = json.dumps(list(items), cls=Backend.instance.json_encoder)
click.echo(output) click.echo(output)
for attribute, attribute_type in model.attributes.items(): for attribute, attribute_type in model.attributes.items():
@ -141,6 +109,8 @@ def get_factory(model):
help="Dump all the model instances", help="Dump all the model instances",
) )
@click.pass_context @click.pass_context
@with_appcontext
@with_backendcontext
def get_command(ctx, all: bool): def get_command(ctx, all: bool):
"""Read information about models. """Read information about models.
@ -158,11 +128,9 @@ def get_command(ctx, all: bool):
if all: if all:
objects = {} objects = {}
for model_name, model in MODELS.items(): for model_name, model in MODELS.items():
objects[model_name] = [ objects[model_name] = list(Backend.instance.query(model))
serialize(instance) for instance in Backend.instance.query(model)
]
output = json.dumps(objects) output = json.dumps(objects, cls=Backend.instance.json_encoder)
click.echo(output) click.echo(output)
@ -200,7 +168,7 @@ def set_factory(model):
except Exception as exc: # pragma: no cover except Exception as exc: # pragma: no cover
raise click.ClickException(exc) from exc raise click.ClickException(exc) from exc
output = json.dumps(serialize(instance)) output = json.dumps(instance, cls=Backend.instance.json_encoder)
click.echo(output) click.echo(output)
attributes = dict(model.attributes) attributes = dict(model.attributes)
@ -253,7 +221,7 @@ def create_factory(model):
except Exception as exc: # pragma: no cover except Exception as exc: # pragma: no cover
raise click.ClickException(exc) from exc raise click.ClickException(exc) from exc
output = json.dumps(serialize(instance)) output = json.dumps(instance, cls=Backend.instance.json_encoder)
click.echo(output) click.echo(output)
attributes = dict(model.attributes) attributes = dict(model.attributes)
@ -342,5 +310,5 @@ def reset_otp(identifier):
except Exception as exc: # pragma: no cover except Exception as exc: # pragma: no cover
raise click.ClickException(exc) from exc raise click.ClickException(exc) from exc
output = json.dumps(serialize(user)) output = json.dumps(user, cls=Backend.instance.json_encoder)
click.echo(output) click.echo(output)

View file

@ -6,8 +6,10 @@ from sqlalchemy import or_
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy.orm import declarative_base from sqlalchemy.orm import declarative_base
from sqlalchemy_utils import Password
from canaille.backends import Backend from canaille.backends import Backend
from canaille.backends import ModelEncoder
from canaille.backends import get_lockout_delay_message from canaille.backends import get_lockout_delay_message
Base = declarative_base() Base = declarative_base()
@ -21,8 +23,16 @@ def db_session(db_uri=None, init=False):
return session return session
class SQLModelEncoder(ModelEncoder):
def default(self, obj):
if isinstance(obj, Password):
return obj.hash.decode()
return super().default(obj)
class SQLBackend(Backend): class SQLBackend(Backend):
db_session = None db_session = None
json_encoder = SQLModelEncoder
@classmethod @classmethod
def install(cls, config): # pragma: no cover def install(cls, config): # pragma: no cover

View file

@ -1,14 +1,26 @@
import datetime
import json import json
from unittest import mock from unittest import mock
from canaille.backends import ModelEncoder
from canaille.commands import cli from canaille.commands import cli
def test_serialize(user):
"""Test ModelSerializer with basic types."""
assert json.dumps({"foo": "bar"}, cls=ModelEncoder) == '{"foo": "bar"}'
assert (
json.dumps({"foo": datetime.datetime(1970, 1, 1)}, cls=ModelEncoder)
== '{"foo": "1970-01-01T00:00:00"}'
)
def test_get_list_models(testclient, backend, user): def test_get_list_models(testclient, backend, user):
"""Nominal case test for model get command.""" """Nominal case test for model get command."""
runner = testclient.app.test_cli_runner() runner = testclient.app.test_cli_runner()
res = runner.invoke(cli, ["get"]) res = runner.invoke(cli, ["get"], catch_exceptions=False)
assert res.exit_code == 0, res.stdout assert res.exit_code == 0, res.stdout
models = ("user", "group") models = ("user", "group")
for model in models: for model in models:
@ -19,7 +31,7 @@ def test_get(testclient, backend, user):
"""Nominal case test for model get command.""" """Nominal case test for model get command."""
runner = testclient.app.test_cli_runner() runner = testclient.app.test_cli_runner()
res = runner.invoke(cli, ["get", "user"]) res = runner.invoke(cli, ["get", "user"], catch_exceptions=False)
assert res.exit_code == 0, res.stdout assert res.exit_code == 0, res.stdout
assert json.loads(res.stdout) == [ assert json.loads(res.stdout) == [
{ {
@ -34,7 +46,7 @@ def test_get(testclient, backend, user):
"given_name": "John", "given_name": "John",
"id": user.id, "id": user.id,
"last_modified": mock.ANY, "last_modified": mock.ANY,
"password": "***", "password": mock.ANY,
"phone_numbers": [ "phone_numbers": [
"555-000-000", "555-000-000",
], ],
@ -49,7 +61,9 @@ def test_get_model_filter(testclient, backend, user, admin, foo_group):
"""Test model get filter.""" """Test model get filter."""
runner = testclient.app.test_cli_runner() runner = testclient.app.test_cli_runner()
res = runner.invoke(cli, ["get", "user", "--groups", foo_group.id]) res = runner.invoke(
cli, ["get", "user", "--groups", foo_group.id], catch_exceptions=False
)
assert res.exit_code == 0, res.stdout assert res.exit_code == 0, res.stdout
assert json.loads(res.stdout) == [ assert json.loads(res.stdout) == [
{ {
@ -64,7 +78,7 @@ def test_get_model_filter(testclient, backend, user, admin, foo_group):
"given_name": "John", "given_name": "John",
"id": user.id, "id": user.id,
"last_modified": mock.ANY, "last_modified": mock.ANY,
"password": "***", "password": mock.ANY,
"phone_numbers": [ "phone_numbers": [
"555-000-000", "555-000-000",
], ],
@ -80,7 +94,11 @@ def test_get_datetime_filter(testclient, backend, user):
"""Test model get filter.""" """Test model get filter."""
runner = testclient.app.test_cli_runner() runner = testclient.app.test_cli_runner()
res = runner.invoke(cli, ["get", "user", "--created", user.created.isoformat()]) res = runner.invoke(
cli,
["get", "user", "--created", user.created.isoformat()],
catch_exceptions=False,
)
assert res.exit_code == 0, res.stdout assert res.exit_code == 0, res.stdout
assert json.loads(res.stdout) == [ assert json.loads(res.stdout) == [
{ {
@ -95,7 +113,7 @@ def test_get_datetime_filter(testclient, backend, user):
"given_name": "John", "given_name": "John",
"id": user.id, "id": user.id,
"last_modified": mock.ANY, "last_modified": mock.ANY,
"password": "***", "password": mock.ANY,
"phone_numbers": [ "phone_numbers": [
"555-000-000", "555-000-000",
], ],
@ -110,7 +128,7 @@ def test_get_all(testclient, backend, user, foo_group):
"""Test the full database dump command.""" """Test the full database dump command."""
runner = testclient.app.test_cli_runner() runner = testclient.app.test_cli_runner()
res = runner.invoke(cli, ["get", "--all"]) res = runner.invoke(cli, ["get", "--all"], catch_exceptions=False)
assert res.exit_code == 0, res.stdout assert res.exit_code == 0, res.stdout
assert json.loads(res.stdout) == { assert json.loads(res.stdout) == {
"authorizationcode": [], "authorizationcode": [],
@ -142,7 +160,7 @@ def test_get_all(testclient, backend, user, foo_group):
"groups": [foo_group.id], "groups": [foo_group.id],
"id": user.id, "id": user.id,
"last_modified": mock.ANY, "last_modified": mock.ANY,
"password": "***", "password": mock.ANY,
"phone_numbers": [ "phone_numbers": [
"555-000-000", "555-000-000",
], ],

View file

@ -38,7 +38,7 @@ def test_reset_otp_by_id(testclient, backend, caplog, user_otp, otp_method):
"given_name": "John", "given_name": "John",
"id": user_otp.id, "id": user_otp.id,
"last_modified": mock.ANY, "last_modified": mock.ANY,
"password": "***", "password": mock.ANY,
"phone_numbers": [ "phone_numbers": [
"555-000-000", "555-000-000",
], ],

View file

@ -22,7 +22,7 @@ def test_set_string_by_id(testclient, backend, user):
"given_name": "foobar", "given_name": "foobar",
"id": user.id, "id": user.id,
"last_modified": mock.ANY, "last_modified": mock.ANY,
"password": "***", "password": mock.ANY,
"phone_numbers": [ "phone_numbers": [
"555-000-000", "555-000-000",
], ],
@ -52,7 +52,7 @@ def test_set_string_by_identifier(testclient, backend, user):
"given_name": "foobar", "given_name": "foobar",
"id": user.id, "id": user.id,
"last_modified": mock.ANY, "last_modified": mock.ANY,
"password": "***", "password": mock.ANY,
"phone_numbers": [ "phone_numbers": [
"555-000-000", "555-000-000",
], ],
@ -94,7 +94,7 @@ def test_set_multiple(testclient, backend, user):
"given_name": "John", "given_name": "John",
"id": user.id, "id": user.id,
"last_modified": mock.ANY, "last_modified": mock.ANY,
"password": "***", "password": mock.ANY,
"phone_numbers": [ "phone_numbers": [
"555-000-000", "555-000-000",
], ],
@ -136,7 +136,7 @@ def test_set_remove_simple_attribute(testclient, backend, user, admin):
"given_name": "John", "given_name": "John",
"id": user.id, "id": user.id,
"last_modified": mock.ANY, "last_modified": mock.ANY,
"password": "***", "password": mock.ANY,
"phone_numbers": [ "phone_numbers": [
"555-000-000", "555-000-000",
], ],
@ -169,7 +169,7 @@ def test_set_remove_multiple_attribute(testclient, backend, user, admin, foo_gro
"given_name": "John", "given_name": "John",
"id": user.id, "id": user.id,
"last_modified": mock.ANY, "last_modified": mock.ANY,
"password": "***", "password": mock.ANY,
"phone_numbers": [ "phone_numbers": [
"555-000-000", "555-000-000",
], ],