forked from Github-Mirrors/canaille
Merge branch '214-dump-passwords' into 'main'
CLI commands can dump hashed passwords Closes #214 See merge request yaal/canaille!212
This commit is contained in:
commit
a5377f4544
7 changed files with 95 additions and 55 deletions
|
@ -10,6 +10,7 @@ Added
|
|||
Changed
|
||||
^^^^^^^
|
||||
- fixed a bug on updating user's settings :issue:`206`
|
||||
- CLI commands dump hashed passwords :issue:`214`
|
||||
|
||||
Changed
|
||||
^^^^^^^
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
import datetime
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import typing
|
||||
from contextlib import contextmanager
|
||||
from math import ceil
|
||||
|
||||
|
@ -8,8 +11,48 @@ from flask import g
|
|||
from canaille.app import classproperty
|
||||
|
||||
|
||||
class ModelEncoder(json.JSONEncoder):
|
||||
"""JSON serializer that can handle Canaille models."""
|
||||
|
||||
@staticmethod
|
||||
def serialize_model(instance):
|
||||
def serialize_attribute(attribute_name, value):
|
||||
"""Replace model instances by their id."""
|
||||
|
||||
multiple = typing.get_origin(instance.attributes[attribute_name]) is list
|
||||
if multiple and isinstance(value, list):
|
||||
return [serialize_attribute(attribute_name, v) for v in value]
|
||||
|
||||
model, _ = instance.get_model_annotations(attribute_name)
|
||||
if model:
|
||||
return value.id
|
||||
|
||||
return value
|
||||
|
||||
result = {}
|
||||
for attribute in instance.attributes:
|
||||
if serialized := serialize_attribute(
|
||||
attribute, getattr(instance, attribute)
|
||||
):
|
||||
result[attribute] = serialized
|
||||
|
||||
return result
|
||||
|
||||
def default(self, obj):
|
||||
from canaille.backends.models import Model
|
||||
|
||||
if isinstance(obj, datetime.datetime):
|
||||
return obj.isoformat()
|
||||
|
||||
if isinstance(obj, Model):
|
||||
return self.serialize_model(obj)
|
||||
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
class Backend:
|
||||
_instance = None
|
||||
json_encoder = ModelEncoder
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
|
|
@ -76,38 +76,6 @@ def register(cli):
|
|||
cli.add_command(reset_otp)
|
||||
|
||||
|
||||
def serialize(instance):
|
||||
"""Quick and dirty serialization method.
|
||||
|
||||
This can probably be made simpler when we will use pydantic models.
|
||||
"""
|
||||
|
||||
def serialize_attribute(attribute_name, value):
|
||||
multiple = is_multiple(instance.attributes[attribute_name])
|
||||
if multiple and isinstance(value, list):
|
||||
return [serialize_attribute(attribute_name, v) for v in value]
|
||||
|
||||
model, _ = instance.get_model_annotations(attribute_name)
|
||||
if model:
|
||||
return value.id
|
||||
|
||||
anonymized = ("password",)
|
||||
if attribute_name in anonymized and value:
|
||||
return "***"
|
||||
|
||||
if isinstance(value, datetime.datetime):
|
||||
return value.isoformat()
|
||||
|
||||
return value
|
||||
|
||||
result = {}
|
||||
for attribute in instance.attributes:
|
||||
if serialized := serialize_attribute(attribute, getattr(instance, attribute)):
|
||||
result[attribute] = serialized
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_factory(model):
|
||||
command_help = f"""Search for {model.__name__.lower()}s and display the
|
||||
matching models as JSON."""
|
||||
|
@ -120,7 +88,7 @@ def get_factory(model):
|
|||
attribute: value for attribute, value in kwargs.items() if value is not None
|
||||
}
|
||||
items = Backend.instance.query(model, **filter)
|
||||
output = json.dumps([serialize(item) for item in items])
|
||||
output = json.dumps(list(items), cls=Backend.instance.json_encoder)
|
||||
click.echo(output)
|
||||
|
||||
for attribute, attribute_type in model.attributes.items():
|
||||
|
@ -141,6 +109,8 @@ def get_factory(model):
|
|||
help="Dump all the model instances",
|
||||
)
|
||||
@click.pass_context
|
||||
@with_appcontext
|
||||
@with_backendcontext
|
||||
def get_command(ctx, all: bool):
|
||||
"""Read information about models.
|
||||
|
||||
|
@ -158,11 +128,9 @@ def get_command(ctx, all: bool):
|
|||
if all:
|
||||
objects = {}
|
||||
for model_name, model in MODELS.items():
|
||||
objects[model_name] = [
|
||||
serialize(instance) for instance in Backend.instance.query(model)
|
||||
]
|
||||
objects[model_name] = list(Backend.instance.query(model))
|
||||
|
||||
output = json.dumps(objects)
|
||||
output = json.dumps(objects, cls=Backend.instance.json_encoder)
|
||||
click.echo(output)
|
||||
|
||||
|
||||
|
@ -200,7 +168,7 @@ def set_factory(model):
|
|||
except Exception as exc: # pragma: no cover
|
||||
raise click.ClickException(exc) from exc
|
||||
|
||||
output = json.dumps(serialize(instance))
|
||||
output = json.dumps(instance, cls=Backend.instance.json_encoder)
|
||||
click.echo(output)
|
||||
|
||||
attributes = dict(model.attributes)
|
||||
|
@ -253,7 +221,7 @@ def create_factory(model):
|
|||
except Exception as exc: # pragma: no cover
|
||||
raise click.ClickException(exc) from exc
|
||||
|
||||
output = json.dumps(serialize(instance))
|
||||
output = json.dumps(instance, cls=Backend.instance.json_encoder)
|
||||
click.echo(output)
|
||||
|
||||
attributes = dict(model.attributes)
|
||||
|
@ -342,5 +310,5 @@ def reset_otp(identifier):
|
|||
except Exception as exc: # pragma: no cover
|
||||
raise click.ClickException(exc) from exc
|
||||
|
||||
output = json.dumps(serialize(user))
|
||||
output = json.dumps(user, cls=Backend.instance.json_encoder)
|
||||
click.echo(output)
|
||||
|
|
|
@ -6,8 +6,10 @@ from sqlalchemy import or_
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from sqlalchemy_utils import Password
|
||||
|
||||
from canaille.backends import Backend
|
||||
from canaille.backends import ModelEncoder
|
||||
from canaille.backends import get_lockout_delay_message
|
||||
|
||||
Base = declarative_base()
|
||||
|
@ -21,8 +23,16 @@ def db_session(db_uri=None, init=False):
|
|||
return session
|
||||
|
||||
|
||||
class SQLModelEncoder(ModelEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, Password):
|
||||
return obj.hash.decode()
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
class SQLBackend(Backend):
|
||||
db_session = None
|
||||
json_encoder = SQLModelEncoder
|
||||
|
||||
@classmethod
|
||||
def install(cls, config): # pragma: no cover
|
||||
|
|
|
@ -1,14 +1,26 @@
|
|||
import datetime
|
||||
import json
|
||||
from unittest import mock
|
||||
|
||||
from canaille.backends import ModelEncoder
|
||||
from canaille.commands import cli
|
||||
|
||||
|
||||
def test_serialize(user):
|
||||
"""Test ModelSerializer with basic types."""
|
||||
assert json.dumps({"foo": "bar"}, cls=ModelEncoder) == '{"foo": "bar"}'
|
||||
|
||||
assert (
|
||||
json.dumps({"foo": datetime.datetime(1970, 1, 1)}, cls=ModelEncoder)
|
||||
== '{"foo": "1970-01-01T00:00:00"}'
|
||||
)
|
||||
|
||||
|
||||
def test_get_list_models(testclient, backend, user):
|
||||
"""Nominal case test for model get command."""
|
||||
|
||||
runner = testclient.app.test_cli_runner()
|
||||
res = runner.invoke(cli, ["get"])
|
||||
res = runner.invoke(cli, ["get"], catch_exceptions=False)
|
||||
assert res.exit_code == 0, res.stdout
|
||||
models = ("user", "group")
|
||||
for model in models:
|
||||
|
@ -19,7 +31,7 @@ def test_get(testclient, backend, user):
|
|||
"""Nominal case test for model get command."""
|
||||
|
||||
runner = testclient.app.test_cli_runner()
|
||||
res = runner.invoke(cli, ["get", "user"])
|
||||
res = runner.invoke(cli, ["get", "user"], catch_exceptions=False)
|
||||
assert res.exit_code == 0, res.stdout
|
||||
assert json.loads(res.stdout) == [
|
||||
{
|
||||
|
@ -34,7 +46,7 @@ def test_get(testclient, backend, user):
|
|||
"given_name": "John",
|
||||
"id": user.id,
|
||||
"last_modified": mock.ANY,
|
||||
"password": "***",
|
||||
"password": mock.ANY,
|
||||
"phone_numbers": [
|
||||
"555-000-000",
|
||||
],
|
||||
|
@ -49,7 +61,9 @@ def test_get_model_filter(testclient, backend, user, admin, foo_group):
|
|||
"""Test model get filter."""
|
||||
|
||||
runner = testclient.app.test_cli_runner()
|
||||
res = runner.invoke(cli, ["get", "user", "--groups", foo_group.id])
|
||||
res = runner.invoke(
|
||||
cli, ["get", "user", "--groups", foo_group.id], catch_exceptions=False
|
||||
)
|
||||
assert res.exit_code == 0, res.stdout
|
||||
assert json.loads(res.stdout) == [
|
||||
{
|
||||
|
@ -64,7 +78,7 @@ def test_get_model_filter(testclient, backend, user, admin, foo_group):
|
|||
"given_name": "John",
|
||||
"id": user.id,
|
||||
"last_modified": mock.ANY,
|
||||
"password": "***",
|
||||
"password": mock.ANY,
|
||||
"phone_numbers": [
|
||||
"555-000-000",
|
||||
],
|
||||
|
@ -80,7 +94,11 @@ def test_get_datetime_filter(testclient, backend, user):
|
|||
"""Test model get filter."""
|
||||
|
||||
runner = testclient.app.test_cli_runner()
|
||||
res = runner.invoke(cli, ["get", "user", "--created", user.created.isoformat()])
|
||||
res = runner.invoke(
|
||||
cli,
|
||||
["get", "user", "--created", user.created.isoformat()],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
assert res.exit_code == 0, res.stdout
|
||||
assert json.loads(res.stdout) == [
|
||||
{
|
||||
|
@ -95,7 +113,7 @@ def test_get_datetime_filter(testclient, backend, user):
|
|||
"given_name": "John",
|
||||
"id": user.id,
|
||||
"last_modified": mock.ANY,
|
||||
"password": "***",
|
||||
"password": mock.ANY,
|
||||
"phone_numbers": [
|
||||
"555-000-000",
|
||||
],
|
||||
|
@ -110,7 +128,7 @@ def test_get_all(testclient, backend, user, foo_group):
|
|||
"""Test the full database dump command."""
|
||||
|
||||
runner = testclient.app.test_cli_runner()
|
||||
res = runner.invoke(cli, ["get", "--all"])
|
||||
res = runner.invoke(cli, ["get", "--all"], catch_exceptions=False)
|
||||
assert res.exit_code == 0, res.stdout
|
||||
assert json.loads(res.stdout) == {
|
||||
"authorizationcode": [],
|
||||
|
@ -142,7 +160,7 @@ def test_get_all(testclient, backend, user, foo_group):
|
|||
"groups": [foo_group.id],
|
||||
"id": user.id,
|
||||
"last_modified": mock.ANY,
|
||||
"password": "***",
|
||||
"password": mock.ANY,
|
||||
"phone_numbers": [
|
||||
"555-000-000",
|
||||
],
|
||||
|
|
|
@ -38,7 +38,7 @@ def test_reset_otp_by_id(testclient, backend, caplog, user_otp, otp_method):
|
|||
"given_name": "John",
|
||||
"id": user_otp.id,
|
||||
"last_modified": mock.ANY,
|
||||
"password": "***",
|
||||
"password": mock.ANY,
|
||||
"phone_numbers": [
|
||||
"555-000-000",
|
||||
],
|
||||
|
|
|
@ -22,7 +22,7 @@ def test_set_string_by_id(testclient, backend, user):
|
|||
"given_name": "foobar",
|
||||
"id": user.id,
|
||||
"last_modified": mock.ANY,
|
||||
"password": "***",
|
||||
"password": mock.ANY,
|
||||
"phone_numbers": [
|
||||
"555-000-000",
|
||||
],
|
||||
|
@ -52,7 +52,7 @@ def test_set_string_by_identifier(testclient, backend, user):
|
|||
"given_name": "foobar",
|
||||
"id": user.id,
|
||||
"last_modified": mock.ANY,
|
||||
"password": "***",
|
||||
"password": mock.ANY,
|
||||
"phone_numbers": [
|
||||
"555-000-000",
|
||||
],
|
||||
|
@ -94,7 +94,7 @@ def test_set_multiple(testclient, backend, user):
|
|||
"given_name": "John",
|
||||
"id": user.id,
|
||||
"last_modified": mock.ANY,
|
||||
"password": "***",
|
||||
"password": mock.ANY,
|
||||
"phone_numbers": [
|
||||
"555-000-000",
|
||||
],
|
||||
|
@ -136,7 +136,7 @@ def test_set_remove_simple_attribute(testclient, backend, user, admin):
|
|||
"given_name": "John",
|
||||
"id": user.id,
|
||||
"last_modified": mock.ANY,
|
||||
"password": "***",
|
||||
"password": mock.ANY,
|
||||
"phone_numbers": [
|
||||
"555-000-000",
|
||||
],
|
||||
|
@ -169,7 +169,7 @@ def test_set_remove_multiple_attribute(testclient, backend, user, admin, foo_gro
|
|||
"given_name": "John",
|
||||
"id": user.id,
|
||||
"last_modified": mock.ANY,
|
||||
"password": "***",
|
||||
"password": mock.ANY,
|
||||
"phone_numbers": [
|
||||
"555-000-000",
|
||||
],
|
||||
|
|
Loading…
Reference in a new issue