diff --git a/CHANGES.rst b/CHANGES.rst index 2e686528..3c8be58b 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -10,6 +10,7 @@ Added Changed ^^^^^^^ - fixed a bug on updating user's settings :issue:`206` +- CLI commands dump hashed passwords :issue:`214` Changed ^^^^^^^ diff --git a/canaille/backends/__init__.py b/canaille/backends/__init__.py index 940b28ac..fe4b2207 100644 --- a/canaille/backends/__init__.py +++ b/canaille/backends/__init__.py @@ -1,5 +1,8 @@ +import datetime import importlib +import json import os +import typing from contextlib import contextmanager from math import ceil @@ -8,8 +11,48 @@ from flask import g from canaille.app import classproperty +class ModelEncoder(json.JSONEncoder): + """JSON serializer that can handle Canaille models.""" + + @staticmethod + def serialize_model(instance): + def serialize_attribute(attribute_name, value): + """Replace model instances by their id.""" + + multiple = typing.get_origin(instance.attributes[attribute_name]) is list + if multiple and isinstance(value, list): + return [serialize_attribute(attribute_name, v) for v in value] + + model, _ = instance.get_model_annotations(attribute_name) + if model: + return value.id + + return value + + result = {} + for attribute in instance.attributes: + if serialized := serialize_attribute( + attribute, getattr(instance, attribute) + ): + result[attribute] = serialized + + return result + + def default(self, obj): + from canaille.backends.models import Model + + if isinstance(obj, datetime.datetime): + return obj.isoformat() + + if isinstance(obj, Model): + return self.serialize_model(obj) + + return super().default(obj) + + class Backend: _instance = None + json_encoder = ModelEncoder def __init__(self, config): self.config = config diff --git a/canaille/backends/commands.py b/canaille/backends/commands.py index 51753d8e..d376f3f3 100644 --- a/canaille/backends/commands.py +++ b/canaille/backends/commands.py @@ -76,38 +76,6 @@ def register(cli): cli.add_command(reset_otp) -def serialize(instance): - """Quick and dirty serialization method. - - This can probably be made simpler when we will use pydantic models. - """ - - def serialize_attribute(attribute_name, value): - multiple = is_multiple(instance.attributes[attribute_name]) - if multiple and isinstance(value, list): - return [serialize_attribute(attribute_name, v) for v in value] - - model, _ = instance.get_model_annotations(attribute_name) - if model: - return value.id - - anonymized = ("password",) - if attribute_name in anonymized and value: - return "***" - - if isinstance(value, datetime.datetime): - return value.isoformat() - - return value - - result = {} - for attribute in instance.attributes: - if serialized := serialize_attribute(attribute, getattr(instance, attribute)): - result[attribute] = serialized - - return result - - def get_factory(model): command_help = f"""Search for {model.__name__.lower()}s and display the matching models as JSON.""" @@ -120,7 +88,7 @@ def get_factory(model): attribute: value for attribute, value in kwargs.items() if value is not None } items = Backend.instance.query(model, **filter) - output = json.dumps([serialize(item) for item in items]) + output = json.dumps(list(items), cls=Backend.instance.json_encoder) click.echo(output) for attribute, attribute_type in model.attributes.items(): @@ -141,6 +109,8 @@ def get_factory(model): help="Dump all the model instances", ) @click.pass_context +@with_appcontext +@with_backendcontext def get_command(ctx, all: bool): """Read information about models. @@ -158,11 +128,9 @@ def get_command(ctx, all: bool): if all: objects = {} for model_name, model in MODELS.items(): - objects[model_name] = [ - serialize(instance) for instance in Backend.instance.query(model) - ] + objects[model_name] = list(Backend.instance.query(model)) - output = json.dumps(objects) + output = json.dumps(objects, cls=Backend.instance.json_encoder) click.echo(output) @@ -200,7 +168,7 @@ def set_factory(model): except Exception as exc: # pragma: no cover raise click.ClickException(exc) from exc - output = json.dumps(serialize(instance)) + output = json.dumps(instance, cls=Backend.instance.json_encoder) click.echo(output) attributes = dict(model.attributes) @@ -253,7 +221,7 @@ def create_factory(model): except Exception as exc: # pragma: no cover raise click.ClickException(exc) from exc - output = json.dumps(serialize(instance)) + output = json.dumps(instance, cls=Backend.instance.json_encoder) click.echo(output) attributes = dict(model.attributes) @@ -342,5 +310,5 @@ def reset_otp(identifier): except Exception as exc: # pragma: no cover raise click.ClickException(exc) from exc - output = json.dumps(serialize(user)) + output = json.dumps(user, cls=Backend.instance.json_encoder) click.echo(output) diff --git a/canaille/backends/sql/backend.py b/canaille/backends/sql/backend.py index 9bf23d62..429d644f 100644 --- a/canaille/backends/sql/backend.py +++ b/canaille/backends/sql/backend.py @@ -6,8 +6,10 @@ from sqlalchemy import or_ from sqlalchemy import select from sqlalchemy.orm import Session from sqlalchemy.orm import declarative_base +from sqlalchemy_utils import Password from canaille.backends import Backend +from canaille.backends import ModelEncoder from canaille.backends import get_lockout_delay_message Base = declarative_base() @@ -21,8 +23,16 @@ def db_session(db_uri=None, init=False): return session +class SQLModelEncoder(ModelEncoder): + def default(self, obj): + if isinstance(obj, Password): + return obj.hash.decode() + return super().default(obj) + + class SQLBackend(Backend): db_session = None + json_encoder = SQLModelEncoder @classmethod def install(cls, config): # pragma: no cover diff --git a/tests/app/commands/test_get.py b/tests/app/commands/test_get.py index c2b68cd8..85fb7037 100644 --- a/tests/app/commands/test_get.py +++ b/tests/app/commands/test_get.py @@ -1,14 +1,26 @@ +import datetime import json from unittest import mock +from canaille.backends import ModelEncoder from canaille.commands import cli +def test_serialize(user): + """Test ModelSerializer with basic types.""" + assert json.dumps({"foo": "bar"}, cls=ModelEncoder) == '{"foo": "bar"}' + + assert ( + json.dumps({"foo": datetime.datetime(1970, 1, 1)}, cls=ModelEncoder) + == '{"foo": "1970-01-01T00:00:00"}' + ) + + def test_get_list_models(testclient, backend, user): """Nominal case test for model get command.""" runner = testclient.app.test_cli_runner() - res = runner.invoke(cli, ["get"]) + res = runner.invoke(cli, ["get"], catch_exceptions=False) assert res.exit_code == 0, res.stdout models = ("user", "group") for model in models: @@ -19,7 +31,7 @@ def test_get(testclient, backend, user): """Nominal case test for model get command.""" runner = testclient.app.test_cli_runner() - res = runner.invoke(cli, ["get", "user"]) + res = runner.invoke(cli, ["get", "user"], catch_exceptions=False) assert res.exit_code == 0, res.stdout assert json.loads(res.stdout) == [ { @@ -34,7 +46,7 @@ def test_get(testclient, backend, user): "given_name": "John", "id": user.id, "last_modified": mock.ANY, - "password": "***", + "password": mock.ANY, "phone_numbers": [ "555-000-000", ], @@ -49,7 +61,9 @@ def test_get_model_filter(testclient, backend, user, admin, foo_group): """Test model get filter.""" runner = testclient.app.test_cli_runner() - res = runner.invoke(cli, ["get", "user", "--groups", foo_group.id]) + res = runner.invoke( + cli, ["get", "user", "--groups", foo_group.id], catch_exceptions=False + ) assert res.exit_code == 0, res.stdout assert json.loads(res.stdout) == [ { @@ -64,7 +78,7 @@ def test_get_model_filter(testclient, backend, user, admin, foo_group): "given_name": "John", "id": user.id, "last_modified": mock.ANY, - "password": "***", + "password": mock.ANY, "phone_numbers": [ "555-000-000", ], @@ -80,7 +94,11 @@ def test_get_datetime_filter(testclient, backend, user): """Test model get filter.""" runner = testclient.app.test_cli_runner() - res = runner.invoke(cli, ["get", "user", "--created", user.created.isoformat()]) + res = runner.invoke( + cli, + ["get", "user", "--created", user.created.isoformat()], + catch_exceptions=False, + ) assert res.exit_code == 0, res.stdout assert json.loads(res.stdout) == [ { @@ -95,7 +113,7 @@ def test_get_datetime_filter(testclient, backend, user): "given_name": "John", "id": user.id, "last_modified": mock.ANY, - "password": "***", + "password": mock.ANY, "phone_numbers": [ "555-000-000", ], @@ -110,7 +128,7 @@ def test_get_all(testclient, backend, user, foo_group): """Test the full database dump command.""" runner = testclient.app.test_cli_runner() - res = runner.invoke(cli, ["get", "--all"]) + res = runner.invoke(cli, ["get", "--all"], catch_exceptions=False) assert res.exit_code == 0, res.stdout assert json.loads(res.stdout) == { "authorizationcode": [], @@ -142,7 +160,7 @@ def test_get_all(testclient, backend, user, foo_group): "groups": [foo_group.id], "id": user.id, "last_modified": mock.ANY, - "password": "***", + "password": mock.ANY, "phone_numbers": [ "555-000-000", ], diff --git a/tests/app/commands/test_reset_otp.py b/tests/app/commands/test_reset_otp.py index cbf979c7..138a8a16 100644 --- a/tests/app/commands/test_reset_otp.py +++ b/tests/app/commands/test_reset_otp.py @@ -38,7 +38,7 @@ def test_reset_otp_by_id(testclient, backend, caplog, user_otp, otp_method): "given_name": "John", "id": user_otp.id, "last_modified": mock.ANY, - "password": "***", + "password": mock.ANY, "phone_numbers": [ "555-000-000", ], diff --git a/tests/app/commands/test_set.py b/tests/app/commands/test_set.py index 41966f03..0f8f0114 100644 --- a/tests/app/commands/test_set.py +++ b/tests/app/commands/test_set.py @@ -22,7 +22,7 @@ def test_set_string_by_id(testclient, backend, user): "given_name": "foobar", "id": user.id, "last_modified": mock.ANY, - "password": "***", + "password": mock.ANY, "phone_numbers": [ "555-000-000", ], @@ -52,7 +52,7 @@ def test_set_string_by_identifier(testclient, backend, user): "given_name": "foobar", "id": user.id, "last_modified": mock.ANY, - "password": "***", + "password": mock.ANY, "phone_numbers": [ "555-000-000", ], @@ -94,7 +94,7 @@ def test_set_multiple(testclient, backend, user): "given_name": "John", "id": user.id, "last_modified": mock.ANY, - "password": "***", + "password": mock.ANY, "phone_numbers": [ "555-000-000", ], @@ -136,7 +136,7 @@ def test_set_remove_simple_attribute(testclient, backend, user, admin): "given_name": "John", "id": user.id, "last_modified": mock.ANY, - "password": "***", + "password": mock.ANY, "phone_numbers": [ "555-000-000", ], @@ -169,7 +169,7 @@ def test_set_remove_multiple_attribute(testclient, backend, user, admin, foo_gro "given_name": "John", "id": user.id, "last_modified": mock.ANY, - "password": "***", + "password": mock.ANY, "phone_numbers": [ "555-000-000", ],