Moved every model import to canaille.models

This commit is contained in:
Éloi Rivard 2023-04-09 11:37:04 +02:00
parent e110c4851b
commit c1d1706007
43 changed files with 421 additions and 428 deletions

View file

@ -3,7 +3,7 @@ import hashlib
import json
import re
from canaille.core.models import User
from canaille.app import models
from flask import current_app
from flask import request
from flask_babel import gettext as _
@ -26,7 +26,7 @@ def profile_hash(*args):
def login_placeholder():
user_filter = current_app.config["BACKENDS"]["LDAP"].get(
"USER_FILTER", User.DEFAULT_FILTER
"USER_FILTER", models.User.DEFAULT_FILTER
)
placeholders = []

View file

@ -3,7 +3,7 @@ from functools import wraps
from urllib.parse import urlsplit
from urllib.parse import urlunsplit
from canaille.core.models import User
from canaille.app import models
from flask import abort
from flask import current_app
from flask import render_template
@ -14,7 +14,7 @@ from flask_babel import gettext as _
def current_user():
for user_id in session.get("user_id", [])[::-1]:
user = User.get(id=user_id)
user = models.User.get(id=user_id)
if user:
return user

7
canaille/app/models.py Normal file
View file

@ -0,0 +1,7 @@
# nopycln: file
from canaille.core.models import Group
from canaille.core.models import User
from canaille.oidc.models import AuthorizationCode
from canaille.oidc.models import Client
from canaille.oidc.models import Consent
from canaille.oidc.models import Token

View file

@ -76,8 +76,7 @@ class LDAPBackend(Backend):
@classmethod
def validate(cls, config):
from canaille.app.configuration import ConfigurationException
from canaille.core.models import Group
from canaille.core.models import User
from canaille.app import models
try:
conn = ldap.initialize(config["BACKENDS"]["LDAP"]["URI"])
@ -100,8 +99,8 @@ class LDAPBackend(Backend):
) from exc
try:
User.ldap_object_classes(conn)
user = User(
models.User.ldap_object_classes(conn)
user = models.User(
formatted_name=f"canaille_{uuid.uuid4()}",
family_name=f"canaille_{uuid.uuid4()}",
user_name=f"canaille_{uuid.uuid4()}",
@ -118,9 +117,9 @@ class LDAPBackend(Backend):
) from exc
try:
Group.ldap_object_classes(conn)
models.Group.ldap_object_classes(conn)
user = User(
user = models.User(
cn=f"canaille_{uuid.uuid4()}",
family_name=f"canaille_{uuid.uuid4()}",
user_name=f"canaille_{uuid.uuid4()}",
@ -129,7 +128,7 @@ class LDAPBackend(Backend):
)
user.save(conn)
group = Group(
group = models.Group(
display_name=f"canaille_{uuid.uuid4()}",
members=[user],
)
@ -150,22 +149,21 @@ class LDAPBackend(Backend):
def setup_ldap_models(config):
from .ldapobject import LDAPObject
from canaille.core.models import Group
from canaille.core.models import User
from canaille.app import models
LDAPObject.root_dn = config["BACKENDS"]["LDAP"]["ROOT_DN"]
user_base = config["BACKENDS"]["LDAP"]["USER_BASE"].replace(
f',{config["BACKENDS"]["LDAP"]["ROOT_DN"]}', ""
)
User.base = user_base
User.rdn_attribute = config["BACKENDS"]["LDAP"].get(
"USER_ID_ATTRIBUTE", User.DEFAULT_ID_ATTRIBUTE
models.User.base = user_base
models.User.rdn_attribute = config["BACKENDS"]["LDAP"].get(
"USER_ID_ATTRIBUTE", models.User.DEFAULT_ID_ATTRIBUTE
)
object_class = config["BACKENDS"]["LDAP"].get(
"USER_CLASS", User.DEFAULT_OBJECT_CLASS
"USER_CLASS", models.User.DEFAULT_OBJECT_CLASS
)
User.ldap_object_class = (
models.User.ldap_object_class = (
object_class if isinstance(object_class, list) else [object_class]
)
@ -174,13 +172,13 @@ def setup_ldap_models(config):
.get("GROUP_BASE", "")
.replace(f',{config["BACKENDS"]["LDAP"]["ROOT_DN"]}', "")
)
Group.base = group_base or None
Group.rdn_attribute = config["BACKENDS"]["LDAP"].get(
"GROUP_ID_ATTRIBUTE", Group.DEFAULT_ID_ATTRIBUTE
models.Group.base = group_base or None
models.Group.rdn_attribute = config["BACKENDS"]["LDAP"].get(
"GROUP_ID_ATTRIBUTE", models.Group.DEFAULT_ID_ATTRIBUTE
)
object_class = config["BACKENDS"]["LDAP"].get(
"GROUP_CLASS", Group.DEFAULT_OBJECT_CLASS
"GROUP_CLASS", models.Group.DEFAULT_OBJECT_CLASS
)
Group.ldap_object_class = (
models.Group.ldap_object_class = (
object_class if isinstance(object_class, list) else [object_class]
)

View file

@ -9,6 +9,7 @@ import wtforms
from canaille.app import b64_to_obj
from canaille.app import default_fields
from canaille.app import login_placeholder
from canaille.app import models
from canaille.app import obj_to_b64
from canaille.app import profile_hash
from canaille.app.flask import current_user
@ -43,8 +44,6 @@ from .forms import profile_form
from .mails import send_invitation_mail
from .mails import send_password_initialization_mail
from .mails import send_password_reset_mail
from .models import Group
from .models import User
bp = Blueprint("account", __name__)
@ -88,12 +87,12 @@ def login():
form["login"].render_kw["placeholder"] = login_placeholder()
if request.form:
user = User.get_from_login(form.login.data)
user = models.User.get_from_login(form.login.data)
if user and not user.has_password():
return redirect(url_for("account.firstlogin", user_name=user.user_name[0]))
if not form.validate():
User.logout()
models.User.logout()
flash(_("Login failed, please check your information"), "error")
return render_template("login.html", form=form)
@ -111,7 +110,7 @@ def password():
form = PasswordForm(request.form or None)
if request.form:
user = User.get_from_login(session["attempt_login"])
user = models.User.get_from_login(session["attempt_login"])
if user and not user.has_password():
return redirect(url_for("account.firstlogin", user_name=user.user_name[0]))
@ -120,7 +119,7 @@ def password():
or not user
or not user.check_password(form.password.data)
):
User.logout()
models.User.logout()
flash(_("Login failed, please check your information"), "error")
return render_template(
"password.html", form=form, username=session["attempt_login"]
@ -156,7 +155,7 @@ def logout():
@bp.route("/firstlogin/<user_name>", methods=("GET", "POST"))
def firstlogin(user_name):
user = User.get_from_login(user_name)
user = models.User.get_from_login(user_name)
if not user or user.has_password():
abort(404)
@ -182,7 +181,9 @@ def firstlogin(user_name):
@bp.route("/users", methods=["GET", "POST"])
@permissions_needed("manage_users")
def users(user):
table_form = TableForm(User, fields=user.read | user.write, formdata=request.form)
table_form = TableForm(
models.User, fields=user.read | user.write, formdata=request.form
)
if request.form and not table_form.validate():
abort(404)
@ -278,7 +279,7 @@ def registration(data, hash):
)
return redirect(url_for("account.index"))
if User.get_from_login(invitation.user_name):
if models.User.get_from_login(invitation.user_name):
flash(
_("Your account has already been created."),
"error",
@ -311,7 +312,7 @@ def registration(data, hash):
if "groups" not in form and invitation.groups:
form["groups"] = wtforms.SelectMultipleField(
_("Groups"),
choices=[(group.id, group.display_name) for group in Group.query()],
choices=[(group.id, group.display_name) for group in models.Group.query()],
render_kw={"readonly": "true"},
)
form.process(CombinedMultiDict((request.files, request.form)) or None, data=data)
@ -381,7 +382,7 @@ def profile_creation(user):
def profile_create(current_app, form):
user = User()
user = models.User()
for attribute in form:
if attribute.name in user.attributes:
if isinstance(attribute.data, FileStorage):
@ -420,7 +421,7 @@ def profile_edition(user, username):
menuitem = "profile" if username == editor.user_name[0] else "users"
fields = editor.read | editor.write
if username != editor.user_name[0]:
user = User.get_from_login(username)
user = models.User.get_from_login(username)
else:
user = editor
@ -505,7 +506,7 @@ def profile_settings(user, username):
):
abort(403)
edited_user = User.get_from_login(username)
edited_user = models.User.get_from_login(username)
if not edited_user:
abort(404)
@ -622,7 +623,7 @@ def profile_delete(user, edited_user):
@bp.route("/impersonate/<username>")
@permissions_needed("impersonate_users")
def impersonate(user, username):
puppet = User.get_from_login(username)
puppet = models.User.get_from_login(username)
if not puppet:
abort(404)
@ -649,7 +650,7 @@ def forgotten():
flash(_("Could not send the password reset link."), "error")
return render_template("forgotten-password.html", form=form)
user = User.get_from_login(form.login.data)
user = models.User.get_from_login(form.login.data)
success_message = _(
"A password reset link has been sent at your email address. You should receive it within a few minutes."
)
@ -689,7 +690,7 @@ def reset(user_name, hash):
abort(404)
form = PasswordResetForm(request.form)
user = User.get_from_login(user_name)
user = models.User.get_from_login(user_name)
if not user or hash != profile_hash(
user.user_name[0],
@ -719,7 +720,7 @@ def photo(user_name, field):
if field.lower() != "photo":
abort(404)
user = User.get_from_login(user_name)
user = models.User.get_from_login(user_name)
if not user:
abort(404)

View file

@ -1,4 +1,5 @@
import wtforms.form
from canaille.app import models
from canaille.app.forms import HTMXBaseForm
from canaille.app.forms import HTMXForm
from canaille.app.forms import is_uri
@ -9,12 +10,9 @@ from flask_babel import lazy_gettext as _
from flask_wtf.file import FileAllowed
from flask_wtf.file import FileField
from .models import Group
from .models import User
def unique_login(form, field):
if User.get_from_login(field.data) and (
if models.User.get_from_login(field.data) and (
not getattr(form, "user", None) or form.user.user_name[0] != field.data
):
raise wtforms.ValidationError(
@ -23,7 +21,7 @@ def unique_login(form, field):
def unique_email(form, field):
if User.get(email=field.data) and (
if models.User.get(email=field.data) and (
not getattr(form, "user", None) or form.user.email[0] != field.data
):
raise wtforms.ValidationError(
@ -32,7 +30,7 @@ def unique_email(form, field):
def unique_group(form, field):
if Group.get(display_name=field.data):
if models.Group.get(display_name=field.data):
raise wtforms.ValidationError(
_("The group '{group}' already exists").format(group=field.data)
)
@ -41,7 +39,7 @@ def unique_group(form, field):
def existing_login(form, field):
if not current_app.config.get(
"HIDE_INVALID_LOGINS", True
) and not User.get_from_login(field.data):
) and not models.User.get_from_login(field.data):
raise wtforms.ValidationError(
_("The login '{login}' does not exist").format(login=field.data)
)
@ -257,7 +255,9 @@ PROFILE_FORM_FIELDS = dict(
),
groups=wtforms.SelectMultipleField(
_("Groups"),
choices=lambda: [(group.id, group.display_name) for group in Group.query()],
choices=lambda: [
(group.id, group.display_name) for group in models.Group.query()
],
render_kw={"placeholder": _("users, admins …")},
),
)
@ -276,7 +276,7 @@ def profile_form(write_field_names, readonly_field_names, user=None):
if PROFILE_FORM_FIELDS.get(name)
}
if "groups" in fields and not Group.query():
if "groups" in fields and not models.Group.query():
del fields["groups"]
form = HTMXBaseForm(fields)
@ -338,6 +338,8 @@ class InvitationForm(HTMXForm):
)
groups = wtforms.SelectMultipleField(
_("Groups"),
choices=lambda: [(group.id, group.display_name) for group in Group.query()],
choices=lambda: [
(group.id, group.display_name) for group in models.Group.query()
],
render_kw={},
)

View file

@ -1,3 +1,4 @@
from canaille.app import models
from canaille.app.flask import permissions_needed
from canaille.app.flask import render_htmx_template
from canaille.app.forms import TableForm
@ -12,8 +13,6 @@ from flask_themer import render_template
from .forms import CreateGroupForm
from .forms import EditGroupForm
from .models import Group
from .models import User
bp = Blueprint("groups", __name__, url_prefix="/groups")
@ -21,7 +20,7 @@ bp = Blueprint("groups", __name__, url_prefix="/groups")
@bp.route("/", methods=["GET", "POST"])
@permissions_needed("manage_groups")
def groups(user):
table_form = TableForm(Group, formdata=request.form)
table_form = TableForm(models.Group, formdata=request.form)
if request.form and request.form.get("page") and not table_form.validate():
abort(404)
@ -37,7 +36,7 @@ def create_group(user):
if not form.validate():
flash(_("Group creation failed."), "error")
else:
group = Group()
group = models.Group()
group.members = [user]
group.display_name = [form.display_name.data]
group.description = [form.description.data]
@ -59,7 +58,7 @@ def create_group(user):
@bp.route("/<groupname>", methods=("GET", "POST"))
@permissions_needed("manage_groups")
def group(user, groupname):
group = Group.get(display_name=groupname)
group = models.Group.get(display_name=groupname)
if not group:
abort(404)
@ -78,7 +77,7 @@ def group(user, groupname):
def edit_group(group):
table_form = TableForm(User, filter={"groups": group}, formdata=request.form)
table_form = TableForm(models.User, filter={"groups": group}, formdata=request.form)
if request.form and request.form.get("page") and not table_form.validate():
abort(404)

View file

@ -1,9 +1,8 @@
import random
import faker
from canaille.app import models
from canaille.app.i18n import available_language_codes
from canaille.core.models import Group
from canaille.core.models import User
from faker.config import AVAILABLE_LOCALES
@ -19,7 +18,7 @@ def fake_users(nb=1):
try:
fake = random.choice(fakes)
name = fake.unique.name()
user = User(
user = models.User(
formatted_name=name,
given_name=name.split(" ")[0],
family_name=name.split(" ")[1],
@ -47,11 +46,11 @@ def fake_users(nb=1):
def fake_groups(nb=1, nb_users_max=1):
fake = faker_generator(["en_US"])[0]
users = User.query()
users = models.User.query()
groups = list()
for _ in range(nb):
try:
group = Group(
group = models.Group(
display_name=fake.unique.word(),
description=fake.sentence(),
)

View file

@ -1,7 +1,7 @@
from canaille.app import models
from canaille.app.flask import permissions_needed
from canaille.app.flask import render_htmx_template
from canaille.app.forms import TableForm
from canaille.oidc.models import AuthorizationCode
from flask import abort
from flask import Blueprint
from flask import request
@ -14,7 +14,7 @@ bp = Blueprint("authorizations", __name__, url_prefix="/admin/authorization")
@bp.route("/", methods=["GET", "POST"])
@permissions_needed("manage_oidc")
def index(user):
table_form = TableForm(AuthorizationCode, formdata=request.form)
table_form = TableForm(models.AuthorizationCode, formdata=request.form)
if request.form and request.form.get("page") and not table_form.validate():
abort(404)
@ -28,7 +28,7 @@ def index(user):
@bp.route("/<authorization_id>", methods=["GET", "POST"])
@permissions_needed("manage_oidc")
def view(user, authorization_id):
authorization = AuthorizationCode.get(authorization_code_id=authorization_id)
authorization = models.AuthorizationCode.get(authorization_code_id=authorization_id)
return render_template(
"oidc/admin/authorization_view.html",
authorization=authorization,

View file

@ -1,11 +1,11 @@
import datetime
from canaille.app import models
from canaille.app.flask import permissions_needed
from canaille.app.flask import render_htmx_template
from canaille.app.flask import request_is_htmx
from canaille.app.forms import TableForm
from canaille.oidc.forms import ClientAddForm
from canaille.oidc.models import Client
from flask import abort
from flask import Blueprint
from flask import flash
@ -23,7 +23,7 @@ bp = Blueprint("clients", __name__, url_prefix="/admin/client")
@bp.route("/", methods=["GET", "POST"])
@permissions_needed("manage_oidc")
def index(user):
table_form = TableForm(Client, formdata=request.form)
table_form = TableForm(models.Client, formdata=request.form)
if request.form and request.form.get("page") and not table_form.validate():
abort(404)
@ -53,7 +53,7 @@ def add(user):
client_id = gen_salt(24)
client_id_issued_at = datetime.datetime.now(datetime.timezone.utc)
client = Client(
client = models.Client(
client_id=client_id,
client_id_issued_at=client_id_issued_at,
client_name=form["client_name"].data,
@ -104,7 +104,7 @@ def edit(user, client_id):
def client_edit(client_id):
client = Client.get(client_id=client_id)
client = models.Client.get(client_id=client_id)
if not client:
abort(404)
@ -152,7 +152,7 @@ def client_edit(client_id):
software_version=form["software_version"].data,
jwk=form["jwk"].data,
jwks_uri=form["jwks_uri"].data,
audience=[Client.get(id=id) for id in form["audience"].data],
audience=[models.Client.get(id=id) for id in form["audience"].data],
preconsent=form["preconsent"].data,
)
client.save()
@ -164,7 +164,7 @@ def client_edit(client_id):
def client_delete(client_id):
client = Client.get(client_id=client_id)
client = models.Client.get(client_id=client_id)
if not client:
abort(404)

View file

@ -1,7 +1,6 @@
import click
from canaille.app import models
from canaille.app.commands import with_backendcontext
from canaille.oidc.models import AuthorizationCode
from canaille.oidc.models import Token
from flask.cli import with_appcontext
@ -12,11 +11,11 @@ def clean():
"""
Remove expired tokens and authorization codes.
"""
for t in Token.query():
for t in models.Token.query():
if t.is_expired():
t.delete()
for a in AuthorizationCode.query():
for a in models.AuthorizationCode.query():
if a.is_expired():
a.delete()

View file

@ -1,9 +1,8 @@
import datetime
import uuid
from canaille.app import models
from canaille.app.flask import user_needed
from canaille.oidc.models import Client
from canaille.oidc.models import Consent
from flask import Blueprint
from flask import flash
from flask import redirect
@ -20,12 +19,14 @@ bp = Blueprint("consents", __name__, url_prefix="/consent")
@bp.route("/")
@user_needed()
def consents(user):
consents = Consent.query(subject=user)
consents = models.Consent.query(subject=user)
clients = {t.client for t in consents}
nb_consents = len(consents)
nb_preconsents = sum(
1 for client in Client.query() if client.preconsent and client not in clients
1
for client in models.Client.query()
if client.preconsent and client not in clients
)
return render_template(
@ -42,11 +43,11 @@ def consents(user):
@bp.route("/pre-consents")
@user_needed()
def pre_consents(user):
consents = Consent.query(subject=user)
consents = models.Consent.query(subject=user)
clients = {t.client for t in consents}
preconsented = [
client
for client in Client.query()
for client in models.Client.query()
if client.preconsent and client not in clients
]
@ -67,7 +68,7 @@ def pre_consents(user):
@bp.route("/revoke/<consent_id>")
@user_needed()
def revoke(user, consent_id):
consent = Consent.get(consent_id=consent_id)
consent = models.Consent.get(consent_id=consent_id)
if not consent or consent.subject != user:
flash(_("Could not revoke this access"), "error")
@ -85,7 +86,7 @@ def revoke(user, consent_id):
@bp.route("/restore/<consent_id>")
@user_needed()
def restore(user, consent_id):
consent = Consent.get(consent_id=consent_id)
consent = models.Consent.get(consent_id=consent_id)
if not consent or consent.subject != user:
flash(_("Could not restore this access"), "error")
@ -106,19 +107,19 @@ def restore(user, consent_id):
@bp.route("/revoke-preconsent/<client_id>")
@user_needed()
def revoke_preconsent(user, client_id):
client = Client.get(client_id=client_id)
client = models.Client.get(client_id=client_id)
if not client or not client.preconsent:
flash(_("Could not revoke this access"), "error")
return redirect(url_for("oidc.consents.consents"))
consent = Consent.get(client=client, subject=user)
consent = models.Consent.get(client=client, subject=user)
if consent:
return redirect(
url_for("oidc.consents.revoke", consent_id=consent.consent_id[0])
)
consent = Consent(
consent = models.Consent(
consent_id=str(uuid.uuid4()),
client=client,
subject=user,

View file

@ -6,10 +6,10 @@ from authlib.jose import JsonWebKey
from authlib.jose import jwt
from authlib.oauth2 import OAuth2Error
from canaille import csrf
from canaille.app import models
from canaille.app.flask import current_user
from canaille.app.flask import set_parameter_in_url_query
from canaille.core.forms import FullLoginForm
from canaille.core.models import User
from flask import abort
from flask import Blueprint
from flask import current_app
@ -25,8 +25,6 @@ from werkzeug.datastructures import CombinedMultiDict
from .forms import AuthorizeForm
from .forms import LogoutForm
from .models import Client
from .models import Consent
from .oauth import authorization
from .oauth import ClientConfigurationEndpoint
from .oauth import ClientRegistrationEndpoint
@ -59,7 +57,7 @@ def authorize():
if "client_id" not in request.args:
abort(400)
client = Client.get(request.args["client_id"])
client = models.Client.get(client_id=request.args["client_id"])
if not client:
abort(400)
@ -78,7 +76,7 @@ def authorize():
if request.method == "GET":
return render_template("login.html", form=form, menu=False)
user = User.get_from_login(form.login.data)
user = models.User.get_from_login(form.login.data)
if (
not form.validate()
or not user
@ -96,7 +94,7 @@ def authorize():
# CONSENT
consents = Consent.query(
consents = models.Consent.query(
client=client,
subject=user,
)
@ -153,7 +151,7 @@ def authorize():
list(set(scopes + consents[0].scope))
).split(" ")
else:
consent = Consent(
consent = models.Consent(
consent_id=str(uuid.uuid4()),
client=client,
subject=user,
@ -275,7 +273,7 @@ def end_session():
valid_uris = []
if "client_id" in data:
client = Client.get(data["client_id"])
client = models.Client.get(client_id=data["client_id"])
if client:
valid_uris = client.post_logout_redirect_uris
@ -317,7 +315,7 @@ def end_session():
else [id_token["aud"]]
)
for client_id in client_ids:
client = Client.get(client_id)
client = models.Client.get(client_id=client_id)
if client:
valid_uris.extend(client.post_logout_redirect_uris or [])

View file

@ -1,7 +1,7 @@
import wtforms
from canaille.app import models
from canaille.app.forms import HTMXForm
from canaille.app.forms import is_uri
from canaille.oidc.models import Client
from flask_babel import lazy_gettext as _
@ -14,7 +14,7 @@ class LogoutForm(HTMXForm):
def client_audiences():
return [(client.id, client.client_name) for client in Client.query()]
return [(client.id, client.client_name) for client in models.Client.query()]
class ClientAddForm(HTMXForm):

View file

@ -1,11 +1,8 @@
import os
from canaille.app import models
from canaille.backends.ldap.installation import install_schema
from canaille.backends.ldap.installation import ldap_connection
from canaille.oidc.models import AuthorizationCode
from canaille.oidc.models import Client
from canaille.oidc.models import Consent
from canaille.oidc.models import Token
from cryptography.hazmat.backends import default_backend as crypto_default_backend
from cryptography.hazmat.primitives import serialization as crypto_serialization
from cryptography.hazmat.primitives.asymmetric import rsa
@ -19,10 +16,10 @@ def install(config):
def setup_ldap_tree(config):
with ldap_connection(config) as conn:
Token.install(conn)
AuthorizationCode.install(conn)
Client.install(conn)
Consent.install(conn)
models.Token.install(conn)
models.AuthorizationCode.install(conn)
models.Client.install(conn)
models.Consent.install(conn)
def setup_keypair(config):

View file

@ -4,6 +4,7 @@ from authlib.oauth2.rfc6749 import AuthorizationCodeMixin
from authlib.oauth2.rfc6749 import ClientMixin
from authlib.oauth2.rfc6749 import TokenMixin
from authlib.oauth2.rfc6749 import util
from canaille.app import models
from canaille.backends.ldap.ldapobject import LDAPObject
@ -97,13 +98,13 @@ class Client(LDAPObject, ClientMixin):
return metadata
def delete(self):
for consent in Consent.query(client=self):
for consent in models.Consent.query(client=self):
consent.delete()
for code in AuthorizationCode.query(client=self):
for code in models.AuthorizationCode.query(client=self):
code.delete()
for token in Token.query(client=self):
for token in models.Token.query(client=self):
token.delete()
super().delete()
@ -243,7 +244,7 @@ class Consent(LDAPObject):
self.revokation_date = datetime.datetime.now(datetime.timezone.utc)
self.save()
tokens = Token.query(
tokens = models.Token.query(
client=self.client,
subject=self.subject,
)

View file

@ -26,15 +26,11 @@ from authlib.oidc.core.grants import OpenIDCode as _OpenIDCode
from authlib.oidc.core.grants import OpenIDHybridGrant as _OpenIDHybridGrant
from authlib.oidc.core.grants import OpenIDImplicitGrant as _OpenIDImplicitGrant
from authlib.oidc.core.grants.util import generate_id_token
from canaille.core.models import User
from canaille.app import models
from flask import current_app
from flask import request
from werkzeug.security import gen_salt
from .models import AuthorizationCode
from .models import Client
from .models import Token
DEFAULT_JWT_KTY = "RSA"
DEFAULT_JWT_ALG = "RS256"
DEFAULT_JWT_EXP = 3600
@ -55,8 +51,8 @@ DEFAULT_JWT_MAPPING = {
def exists_nonce(nonce, req):
client = Client.get(id=req.client_id)
exists = AuthorizationCode.query(client=client, nonce=nonce)
client = models.Client.get(id=req.client_id)
exists = models.AuthorizationCode.query(client=client, nonce=nonce)
return bool(exists)
@ -142,7 +138,7 @@ def save_authorization_code(code, request):
nonce = request.data.get("nonce")
now = datetime.datetime.now(datetime.timezone.utc)
scope = request.client.get_allowed_scope(request.scope)
code = AuthorizationCode(
code = models.AuthorizationCode(
authorization_code_id=gen_salt(48),
code=code,
subject=request.user,
@ -166,7 +162,7 @@ class AuthorizationCodeGrant(_AuthorizationCodeGrant):
return save_authorization_code(code, request)
def query_authorization_code(self, code, client):
item = AuthorizationCode.query(code=code, client=client)
item = models.AuthorizationCode.query(code=code, client=client)
if item and not item[0].is_expired():
return item[0]
@ -196,7 +192,7 @@ class PasswordGrant(_ResourceOwnerPasswordCredentialsGrant):
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]
def authenticate_user(self, username, password):
user = User.get_from_login(username)
user = models.User.get_from_login(username)
if not user or not user.check_password(password):
return None
@ -206,7 +202,7 @@ class PasswordGrant(_ResourceOwnerPasswordCredentialsGrant):
class RefreshTokenGrant(_RefreshTokenGrant):
def authenticate_refresh_token(self, refresh_token):
token = Token.query(refresh_token=refresh_token)
token = models.Token.query(refresh_token=refresh_token)
if token and token[0].is_refresh_token_active():
return token[0]
@ -252,12 +248,12 @@ class OpenIDHybridGrant(_OpenIDHybridGrant):
def query_client(client_id):
return Client.get(client_id=client_id)
return models.Client.get(client_id=client_id)
def save_token(token, request):
now = datetime.datetime.now(datetime.timezone.utc)
t = Token(
t = models.Token(
token_id=gen_salt(48),
type=token["token_type"],
access_token=token["access_token"],
@ -274,20 +270,20 @@ def save_token(token, request):
class BearerTokenValidator(_BearerTokenValidator):
def authenticate_token(self, token_string):
return Token.get(access_token=token_string)
return models.Token.get(access_token=token_string)
def query_token(token, token_type_hint):
if token_type_hint == "access_token":
return Token.get(access_token=token)
return models.Token.get(access_token=token)
elif token_type_hint == "refresh_token":
return Token.get(refresh_token=token)
return models.Token.get(refresh_token=token)
item = Token.get(access_token=token)
item = models.Token.get(access_token=token)
if item:
return item
item = Token.get(refresh_token=token)
item = models.Token.get(refresh_token=token)
if item:
return item
@ -369,7 +365,7 @@ class ClientRegistrationEndpoint(ClientManagementMixin, _ClientRegistrationEndpo
client_metadata["scope"], list
):
client_metadata["scope"] = client_metadata["scope"].split(" ")
client = Client(**client_info, **client_metadata)
client = models.Client(**client_info, **client_metadata)
client.save()
return client
@ -377,7 +373,7 @@ class ClientRegistrationEndpoint(ClientManagementMixin, _ClientRegistrationEndpo
class ClientConfigurationEndpoint(ClientManagementMixin, _ClientConfigurationEndpoint):
def authenticate_client(self, request):
client_id = request.uri.split("/")[-1]
return Client.get(client_id=client_id)
return models.Client.get(client_id=client_id)
def revoke_access_token(self, request, token):
pass

View file

@ -1,9 +1,9 @@
import datetime
from canaille.app import models
from canaille.app.flask import permissions_needed
from canaille.app.flask import render_htmx_template
from canaille.app.forms import TableForm
from canaille.oidc.models import Token
from flask import abort
from flask import Blueprint
from flask import flash
@ -19,7 +19,7 @@ bp = Blueprint("tokens", __name__, url_prefix="/admin/token")
@bp.route("/", methods=["GET", "POST"])
@permissions_needed("manage_oidc")
def index(user):
table_form = TableForm(Token, formdata=request.form)
table_form = TableForm(models.Token, formdata=request.form)
if request.form and request.form.get("page") and not table_form.validate():
abort(404)
@ -31,7 +31,7 @@ def index(user):
@bp.route("/<token_id>", methods=["GET", "POST"])
@permissions_needed("manage_oidc")
def view(user, token_id):
token = Token.get(token_id=token_id)
token = models.Token.get(token_id=token_id)
if not token:
abort(404)
@ -46,7 +46,7 @@ def view(user, token_id):
@bp.route("/<token_id>/revoke", methods=["GET", "POST"])
@permissions_needed("manage_oidc")
def revoke(user, token_id):
token = Token.get(token_id=token_id)
token = models.Token.get(token_id=token_id)
if not token:
abort(404)

View file

@ -10,15 +10,13 @@ from canaille import create_app as canaille_app
def populate(app):
from canaille.core.models import Group
from canaille.core.models import User
from canaille.app import models
from canaille.core.populate import fake_groups
from canaille.core.populate import fake_users
from canaille.oidc.models import Client
with app.app_context():
with app.backend.session():
jane = User(
jane = models.User(
formatted_name="Jane Doe",
given_name="Jane",
family_name="Doe",
@ -38,7 +36,7 @@ def populate(app):
)
jane.save()
jack = User(
jack = models.User(
formatted_name="Jack Doe",
given_name="Jack",
family_name="Doe",
@ -53,7 +51,7 @@ def populate(app):
)
jack.save()
john = User(
john = models.User(
formatted_name="John Doe",
given_name="John",
family_name="Doe",
@ -68,7 +66,7 @@ def populate(app):
)
john.save()
james = User(
james = models.User(
formatted_name="James Doe",
given_name="James",
family_name="Doe",
@ -77,28 +75,28 @@ def populate(app):
)
james.save()
users = Group(
users = models.Group(
display_name="users",
members=[jane, jack, john, james],
description="The regular users.",
)
users.save()
users = Group(
users = models.Group(
display_name="admins",
members=[jane],
description="The administrators.",
)
users.save()
users = Group(
users = models.Group(
display_name="moderators",
members=[james],
description="People who can manage users.",
)
users.save()
client1 = Client(
client1 = models.Client(
client_id="1JGkkzCbeHpGtlqgI5EENByf",
client_secret="2xYPSReTQRmGG1yppMVZQ0ASXwFejPyirvuPbKhNa6TmKC5x",
client_name="Client1",
@ -115,7 +113,7 @@ def populate(app):
)
client1.save()
client2 = Client(
client2 = models.Client(
client_id="gn4yFN7GDykL7QP8v8gS9YfV",
client_secret="ouFJE5WpICt6hxTyf8icXPeeklMektMY4gV0Rmf3aY60VElA",
client_name="Client2",

View file

@ -3,6 +3,7 @@ from unittest import mock
import ldap.dn
import pytest
from canaille.app import models
from canaille.app.configuration import ConfigurationException
from canaille.app.configuration import validate
from canaille.backends.ldap.backend import setup_ldap_models
@ -11,12 +12,10 @@ from canaille.backends.ldap.ldapobject import python_attrs_to_ldap
from canaille.backends.ldap.utils import ldap_to_python
from canaille.backends.ldap.utils import python_to_ldap
from canaille.backends.ldap.utils import Syntax
from canaille.core.models import Group
from canaille.core.models import User
def test_object_creation(app, backend):
user = User(
user = models.User(
formatted_name="Doe", # leading space
family_name="Doe",
user_name="user",
@ -26,7 +25,7 @@ def test_object_creation(app, backend):
user.save()
assert user.exists
user = User.get(id=user.id)
user = models.User.get(id=user.id)
assert user.exists
user.delete()
@ -38,7 +37,7 @@ def test_repr(backend, foo_group, user):
def test_dn_when_leading_space_in_id_attribute(backend):
user = User(
user = models.User(
formatted_name=" Doe", # leading space
family_name="Doe",
user_name="user",
@ -54,7 +53,7 @@ def test_dn_when_leading_space_in_id_attribute(backend):
def test_dn_when_ldap_special_char_in_id_attribute(backend):
user = User(
user = models.User(
formatted_name="#Doe", # special char
family_name="Doe",
user_name="user",
@ -70,27 +69,32 @@ def test_dn_when_ldap_special_char_in_id_attribute(backend):
def test_filter(backend, foo_group, bar_group):
assert Group.query(display_name="foo") == [foo_group]
assert Group.query(display_name="bar") == [bar_group]
assert models.Group.query(display_name="foo") == [foo_group]
assert models.Group.query(display_name="bar") == [bar_group]
assert Group.query(display_name="foo") != 3
assert models.Group.query(display_name="foo") != 3
assert Group.query(display_name=["foo"]) == [foo_group]
assert Group.query(display_name=["bar"]) == [bar_group]
assert models.Group.query(display_name=["foo"]) == [foo_group]
assert models.Group.query(display_name=["bar"]) == [bar_group]
assert set(Group.query(display_name=["foo", "bar"])) == {foo_group, bar_group}
assert set(models.Group.query(display_name=["foo", "bar"])) == {
foo_group,
bar_group,
}
def test_fuzzy(backend, user, moderator, admin):
assert set(User.query()) == {user, moderator, admin}
assert set(User.fuzzy("Jack")) == {moderator}
assert set(User.fuzzy("Jack", ["formatted_name"])) == {moderator}
assert set(User.fuzzy("Jack", ["user_name"])) == set()
assert set(User.fuzzy("Jack", ["user_name", "formatted_name"])) == {moderator}
assert set(User.fuzzy("moderator")) == {moderator}
assert set(User.fuzzy("oderat")) == {moderator}
assert set(User.fuzzy("oDeRat")) == {moderator}
assert set(User.fuzzy("ack")) == {moderator}
assert set(models.User.query()) == {user, moderator, admin}
assert set(models.User.fuzzy("Jack")) == {moderator}
assert set(models.User.fuzzy("Jack", ["formatted_name"])) == {moderator}
assert set(models.User.fuzzy("Jack", ["user_name"])) == set()
assert set(models.User.fuzzy("Jack", ["user_name", "formatted_name"])) == {
moderator
}
assert set(models.User.fuzzy("moderator")) == {moderator}
assert set(models.User.fuzzy("oderat")) == {moderator}
assert set(models.User.fuzzy("oDeRat")) == {moderator}
assert set(models.User.fuzzy("ack")) == {moderator}
def test_ldap_to_python():
@ -180,11 +184,11 @@ def test_guess_object_from_dn(backend, testclient, foo_group):
foo_group.members = [foo_group]
foo_group.save()
g = LDAPObject.get(id=foo_group.dn)
assert isinstance(g, Group)
assert isinstance(g, models.Group)
assert g == foo_group
assert g.cn == foo_group.cn
ou = LDAPObject.get(id=f"{Group.base},{Group.root_dn}")
ou = LDAPObject.get(id=f"{models.Group.base},{models.Group.root_dn}")
assert isinstance(ou, LDAPObject)
@ -192,11 +196,11 @@ def test_object_class_update(backend, testclient):
testclient.app.config["BACKENDS"]["LDAP"]["USER_CLASS"] = ["inetOrgPerson"]
setup_ldap_models(testclient.app.config)
user1 = User(cn="foo1", sn="bar1")
user1 = models.User(cn="foo1", sn="bar1")
user1.save()
assert user1.objectClass == ["inetOrgPerson"]
assert User.get(id=user1.id).objectClass == ["inetOrgPerson"]
assert models.User.get(id=user1.id).objectClass == ["inetOrgPerson"]
testclient.app.config["BACKENDS"]["LDAP"]["USER_CLASS"] = [
"inetOrgPerson",
@ -204,18 +208,24 @@ def test_object_class_update(backend, testclient):
]
setup_ldap_models(testclient.app.config)
user2 = User(cn="foo2", sn="bar2")
user2 = models.User(cn="foo2", sn="bar2")
user2.save()
assert user2.objectClass == ["inetOrgPerson", "extensibleObject"]
assert User.get(id=user2.id).objectClass == ["inetOrgPerson", "extensibleObject"]
assert models.User.get(id=user2.id).objectClass == [
"inetOrgPerson",
"extensibleObject",
]
user1 = User.get(id=user1.id)
user1 = models.User.get(id=user1.id)
assert user1.objectClass == ["inetOrgPerson"]
user1.save()
assert user1.objectClass == ["inetOrgPerson", "extensibleObject"]
assert User.get(id=user1.id).objectClass == ["inetOrgPerson", "extensibleObject"]
assert models.User.get(id=user1.id).objectClass == [
"inetOrgPerson",
"extensibleObject",
]
user1.delete()
user2.delete()

View file

@ -1,21 +1,20 @@
from canaille.core.models import Group
from canaille.core.models import User
from canaille.app import models
def test_model_comparison(testclient, backend):
foo1 = User(
foo1 = models.User(
user_name="foo",
family_name="foo",
formatted_name="foo",
)
foo1.save()
bar = User(
bar = models.User(
user_name="bar",
family_name="bar",
formatted_name="bar",
)
bar.save()
foo2 = User.get(id=foo1.id)
foo2 = models.User.get(id=foo1.id)
assert foo1 == foo2
assert foo1 != bar
@ -25,23 +24,23 @@ def test_model_comparison(testclient, backend):
def test_model_lifecycle(testclient, backend):
user = User(
user = models.User(
user_name="user_name",
family_name="family_name",
formatted_name="formatted_name",
)
assert not User.query()
assert not User.query(id=user.id)
assert not User.query(id="invalid")
assert not User.get(id=user.id)
assert not models.User.query()
assert not models.User.query(id=user.id)
assert not models.User.query(id="invalid")
assert not models.User.get(id=user.id)
user.save()
assert User.query() == [user]
assert User.query(id=user.id) == [user]
assert not User.query(id="invalid")
assert User.get(id=user.id) == user
assert models.User.query() == [user]
assert models.User.query(id=user.id) == [user]
assert not models.User.query(id="invalid")
assert models.User.get(id=user.id) == user
user.family_name = "new_family_name"
@ -53,12 +52,12 @@ def test_model_lifecycle(testclient, backend):
user.delete()
assert not User.query(id=user.id)
assert not User.get(id=user.id)
assert not models.User.query(id=user.id)
assert not models.User.get(id=user.id)
def test_model_attribute_edition(testclient, backend):
user = User(
user = models.User(
user_name="user_name",
family_name="family_name",
formatted_name="formatted_name",
@ -71,7 +70,7 @@ def test_model_attribute_edition(testclient, backend):
assert user.family_name == ["family_name"]
assert user.email == ["email1@user.com", "email2@user.com"]
user = User.get(id=user.id)
user = models.User.get(id=user.id)
assert user.user_name == ["user_name"]
assert user.family_name == ["family_name"]
assert user.email == ["email1@user.com", "email2@user.com"]
@ -83,7 +82,7 @@ def test_model_attribute_edition(testclient, backend):
assert user.family_name == ["new_family_name"]
assert user.email == ["email1@user.com"]
user = User.get(id=user.id)
user = models.User.get(id=user.id)
assert user.family_name == ["new_family_name"]
assert user.email == ["email1@user.com"]
@ -96,7 +95,7 @@ def test_model_attribute_edition(testclient, backend):
def test_model_indexation(testclient, backend):
user = User(
user = models.User(
user_name="user_name",
family_name="family_name",
formatted_name="formatted_name",
@ -104,56 +103,58 @@ def test_model_indexation(testclient, backend):
)
user.save()
assert User.get(family_name="family_name") == user
assert not User.get(family_name="new_family_name")
assert User.get(email="email1@user.com") == user
assert User.get(email="email2@user.com") == user
assert not User.get(email="email3@user.com")
assert models.User.get(family_name="family_name") == user
assert not models.User.get(family_name="new_family_name")
assert models.User.get(email="email1@user.com") == user
assert models.User.get(email="email2@user.com") == user
assert not models.User.get(email="email3@user.com")
user.family_name = "new_family_name"
user.email = ["email2@user.com"]
assert User.get(family_name="family_name") != user
assert not User.get(family_name="new_family_name")
assert User.get(email="email1@user.com") != user
assert User.get(email="email2@user.com") != user
assert not User.get(email="email3@user.com")
assert models.User.get(family_name="family_name") != user
assert not models.User.get(family_name="new_family_name")
assert models.User.get(email="email1@user.com") != user
assert models.User.get(email="email2@user.com") != user
assert not models.User.get(email="email3@user.com")
user.save()
assert not User.get(family_name="family_name")
assert User.get(family_name="new_family_name") == user
assert not User.get(email="email1@user.com")
assert User.get(email="email2@user.com") == user
assert not User.get(email="email3@user.com")
assert not models.User.get(family_name="family_name")
assert models.User.get(family_name="new_family_name") == user
assert not models.User.get(email="email1@user.com")
assert models.User.get(email="email2@user.com") == user
assert not models.User.get(email="email3@user.com")
user.delete()
assert not User.get(family_name="family_name")
assert not User.get(family_name="new_family_name")
assert not User.get(email="email1@user.com")
assert not User.get(email="email2@user.com")
assert not User.get(email="email3@user.com")
assert not models.User.get(family_name="family_name")
assert not models.User.get(family_name="new_family_name")
assert not models.User.get(email="email1@user.com")
assert not models.User.get(email="email2@user.com")
assert not models.User.get(email="email3@user.com")
def test_fuzzy(user, moderator, admin, backend):
assert set(User.query()) == {user, moderator, admin}
assert set(User.fuzzy("Jack")) == {moderator}
assert set(User.fuzzy("Jack", ["formatted_name"])) == {moderator}
assert set(User.fuzzy("Jack", ["user_name"])) == set()
assert set(User.fuzzy("Jack", ["user_name", "formatted_name"])) == {moderator}
assert set(User.fuzzy("moderator")) == {moderator}
assert set(User.fuzzy("oderat")) == {moderator}
assert set(User.fuzzy("oDeRat")) == {moderator}
assert set(User.fuzzy("ack")) == {moderator}
assert set(models.User.query()) == {user, moderator, admin}
assert set(models.User.fuzzy("Jack")) == {moderator}
assert set(models.User.fuzzy("Jack", ["formatted_name"])) == {moderator}
assert set(models.User.fuzzy("Jack", ["user_name"])) == set()
assert set(models.User.fuzzy("Jack", ["user_name", "formatted_name"])) == {
moderator
}
assert set(models.User.fuzzy("moderator")) == {moderator}
assert set(models.User.fuzzy("oderat")) == {moderator}
assert set(models.User.fuzzy("oDeRat")) == {moderator}
assert set(models.User.fuzzy("ack")) == {moderator}
# def test_model_references(user, admin, foo_group, bar_group):
def test_model_references(testclient, user, foo_group, admin, bar_group, backend):
assert foo_group.members == [user]
assert user.groups == [foo_group]
assert foo_group in Group.query(members=user)
assert user in User.query(groups=foo_group)
assert foo_group in models.Group.query(members=user)
assert user in models.User.query(groups=foo_group)
assert user not in bar_group.members
assert bar_group not in user.groups
@ -175,11 +176,11 @@ def test_model_references(testclient, user, foo_group, admin, bar_group, backend
def test_model_references_set_unsaved_object(
testclient, logged_moderator, user, backend
):
group = Group(members=[user], display_name="foo")
group = models.Group(members=[user], display_name="foo")
group.save()
user.reload() # an LDAP group can be inconsistent by containing members which doesn't exist
non_existent_user = User(formatted_name="foo", family_name="bar")
non_existent_user = models.User(formatted_name="foo", family_name="bar")
group.members = group.members + [non_existent_user]
assert group.members == [user, non_existent_user]

View file

@ -1,9 +1,8 @@
import pytest
import slapd
from canaille import create_app
from canaille.app import models
from canaille.backends.ldap.backend import LDAPBackend
from canaille.core.models import Group
from canaille.core.models import User
from flask_webtest import TestApp
from werkzeug.security import gen_salt
@ -156,7 +155,7 @@ def testclient(app):
@pytest.fixture
def user(app, backend):
u = User(
u = models.User(
formatted_name="John (johnny) Doe",
given_name="John",
family_name="Doe",
@ -176,7 +175,7 @@ def user(app, backend):
@pytest.fixture
def admin(app, backend):
u = User(
u = models.User(
formatted_name="Jane Doe",
family_name="Doe",
user_name="admin",
@ -190,7 +189,7 @@ def admin(app, backend):
@pytest.fixture
def moderator(app, backend):
u = User(
u = models.User(
formatted_name="Jack Doe",
family_name="Doe",
user_name="moderator",
@ -225,7 +224,7 @@ def logged_moderator(moderator, testclient):
@pytest.fixture
def foo_group(app, user, backend):
group = Group(
group = models.Group(
members=[user],
display_name="foo",
)
@ -237,7 +236,7 @@ def foo_group(app, user, backend):
@pytest.fixture
def bar_group(app, admin, backend):
group = Group(
group = models.Group(
members=[admin],
display_name="bar",
)

View file

@ -1,17 +1,16 @@
from canaille.app import models
from canaille.commands import cli
from canaille.core.models import Group
from canaille.core.models import User
from canaille.core.populate import fake_users
def test_populate_users(testclient, backend):
runner = testclient.app.test_cli_runner()
assert len(User.query()) == 0
assert len(models.User.query()) == 0
res = runner.invoke(cli, ["populate", "--nb", "10", "users"])
assert res.exit_code == 0, res.stdout
assert len(User.query()) == 10
for user in User.query():
assert len(models.User.query()) == 10
for user in models.User.query():
user.delete()
@ -19,13 +18,13 @@ def test_populate_groups(testclient, backend):
fake_users(10)
runner = testclient.app.test_cli_runner()
assert len(Group.query()) == 0
assert len(models.Group.query()) == 0
res = runner.invoke(cli, ["populate", "--nb", "10", "groups"])
assert res.exit_code == 0, res.stdout
assert len(Group.query()) == 10
assert len(models.Group.query()) == 10
for group in Group.query():
for group in models.Group.query():
group.delete()
for user in User.query():
for user in models.User.query():
user.delete()

View file

@ -1,6 +1,6 @@
from unittest import mock
from canaille.core.models import User
from canaille.app import models
def test_index(testclient, user):
@ -112,7 +112,7 @@ def test_password_page_without_signin_in_redirects_to_login_page(testclient, use
def test_user_without_password_first_login(testclient, backend, smtpd):
assert len(smtpd.messages) == 0
u = User(
u = models.User(
formatted_name="Temp User",
family_name="Temp",
user_name="temp",
@ -146,7 +146,7 @@ def test_first_login_account_initialization_mail_sending_failed(
SMTP.side_effect = mock.Mock(side_effect=OSError("unit test mail error"))
assert len(smtpd.messages) == 0
u = User(
u = models.User(
formatted_name="Temp User",
family_name="Temp",
user_name="temp",
@ -168,7 +168,7 @@ def test_first_login_account_initialization_mail_sending_failed(
def test_first_login_form_error(testclient, backend, smtpd):
assert len(smtpd.messages) == 0
u = User(
u = models.User(
formatted_name="Temp User",
family_name="Temp",
user_name="temp",
@ -190,7 +190,7 @@ def test_first_login_page_unavailable_for_users_with_password(
def test_user_password_deleted_during_login(testclient, backend):
u = User(
u = models.User(
formatted_name="Temp User",
family_name="Temp",
user_name="temp",
@ -214,7 +214,7 @@ def test_user_password_deleted_during_login(testclient, backend):
def test_user_deleted_in_session(testclient, backend):
u = User(
u = models.User(
formatted_name="Jake Doe",
family_name="Jake",
user_name="jake",
@ -278,7 +278,7 @@ def test_wrong_login(testclient, user):
def test_admin_self_deletion(testclient, backend):
admin = User(
admin = models.User(
formatted_name="Temp admin",
family_name="admin",
user_name="temp",
@ -296,14 +296,14 @@ def test_admin_self_deletion(testclient, backend):
.follow(status=200)
)
assert User.get_from_login("temp") is None
assert models.User.get_from_login("temp") is None
with testclient.session_transaction() as sess:
assert not sess.get("user_id")
def test_user_self_deletion(testclient, backend):
user = User(
user = models.User(
formatted_name="Temp user",
family_name="user",
user_name="temp",
@ -330,7 +330,7 @@ def test_user_self_deletion(testclient, backend):
.follow(status=200)
)
assert User.get_from_login("temp") is None
assert models.User.get_from_login("temp") is None
with testclient.session_transaction() as sess:
assert not sess.get("user_id")

View file

@ -1,11 +1,10 @@
from canaille.core.models import Group
from canaille.core.models import User
from canaille.app import models
from canaille.core.populate import fake_groups
from canaille.core.populate import fake_users
def test_no_group(app, backend):
assert Group.query() == []
assert models.Group.query() == []
def test_group_list_pagination(testclient, logged_admin, foo_group):
@ -49,7 +48,7 @@ def test_group_list_bad_pages(testclient, logged_admin):
def test_group_deletion(testclient, backend):
user = User(
user = models.User(
formatted_name="foobar",
family_name="foobar",
user_name="foobar",
@ -57,7 +56,7 @@ def test_group_deletion(testclient, backend):
)
user.save()
group = Group(
group = models.Group(
members=[user],
display_name="foobar",
)
@ -109,7 +108,7 @@ def test_set_groups(app, user, foo_group, bar_group):
def test_set_groups_with_leading_space_in_user_id_attribute(app, foo_group):
user = User(
user = models.User(
formatted_name=" Doe", # leading space in id attribute
family_name="Doe",
user_name="user2",
@ -137,8 +136,8 @@ def test_moderator_can_create_edit_and_delete_group(
):
# The group does not exist
res = testclient.get("/groups", status=200)
assert Group.get(display_name="bar") is None
assert Group.get(display_name="foo") == foo_group
assert models.Group.get(display_name="bar") is None
assert models.Group.get(display_name="foo") == foo_group
res.mustcontain(no="bar")
res.mustcontain("foo")
@ -152,7 +151,7 @@ def test_moderator_can_create_edit_and_delete_group(
res = form.submit(status=302).follow(status=200)
logged_moderator.reload()
bar_group = Group.get(display_name="bar")
bar_group = models.Group.get(display_name="bar")
assert bar_group.display_name == "bar"
assert bar_group.description == ["yolo"]
assert bar_group.members == [
@ -168,17 +167,17 @@ def test_moderator_can_create_edit_and_delete_group(
res = form.submit(name="action", value="edit").follow()
bar_group = Group.get(display_name="bar")
bar_group = models.Group.get(display_name="bar")
assert bar_group.display_name == "bar"
assert bar_group.description == ["yolo2"]
assert Group.get(display_name="bar2") is None
assert models.Group.get(display_name="bar2") is None
members = bar_group.members
for member in members:
res.mustcontain(member.formatted_name[0])
# Group is deleted
res = form.submit(name="action", value="delete", status=302)
assert Group.get(display_name="bar") is None
assert models.Group.get(display_name="bar") is None
assert ("success", "The group bar has been sucessfully deleted") in res.flashes

View file

@ -1,11 +1,11 @@
import datetime
from canaille.app import models
from canaille.core.account import Invitation
from canaille.core.models import User
def test_invitation(testclient, logged_admin, foo_group, smtpd):
assert User.get_from_login("someone") is None
assert models.User.get_from_login("someone") is None
res = testclient.get("/invite", status=200)
@ -39,7 +39,7 @@ def test_invitation(testclient, logged_admin, foo_group, smtpd):
assert ("success", "Your account has been created successfuly.") in res.flashes
res = res.follow(status=200)
user = User.get_from_login("someone")
user = models.User.get_from_login("someone")
foo_group.reload()
assert user.check_password("whatever")
assert user.groups == [foo_group]
@ -53,8 +53,8 @@ def test_invitation(testclient, logged_admin, foo_group, smtpd):
def test_invitation_editable_user_name(testclient, logged_admin, foo_group, smtpd):
assert User.get_from_login("jackyjack") is None
assert User.get_from_login("djorje") is None
assert models.User.get_from_login("jackyjack") is None
assert models.User.get_from_login("djorje") is None
res = testclient.get("/invite", status=200)
@ -89,7 +89,7 @@ def test_invitation_editable_user_name(testclient, logged_admin, foo_group, smtp
assert ("success", "Your account has been created successfuly.") in res.flashes
res = res.follow(status=200)
user = User.get_from_login("djorje")
user = models.User.get_from_login("djorje")
foo_group.reload()
assert user.check_password("whatever")
assert user.groups == [foo_group]
@ -101,7 +101,7 @@ def test_invitation_editable_user_name(testclient, logged_admin, foo_group, smtp
def test_generate_link(testclient, logged_admin, foo_group, smtpd):
assert User.get_from_login("sometwo") is None
assert models.User.get_from_login("sometwo") is None
res = testclient.get("/invite", status=200)
@ -131,7 +131,7 @@ def test_generate_link(testclient, logged_admin, foo_group, smtpd):
res = res.form.submit(status=302)
res = res.follow(status=200)
user = User.get_from_login("sometwo")
user = models.User.get_from_login("sometwo")
foo_group.reload()
assert user.check_password("whatever")
assert user.groups == [foo_group]
@ -228,7 +228,7 @@ def test_registration_no_password(testclient, foo_group):
res = res.form.submit(status=200)
res.mustcontain("This field is required.")
assert not User.get_from_login("someoneelse")
assert not models.User.get_from_login("someoneelse")
with testclient.session_transaction() as sess:
assert "user_id" not in sess
@ -295,7 +295,7 @@ def test_groups_are_saved_even_when_user_does_not_have_read_permission(
res = res.form.submit(status=302)
res = res.follow(status=200)
user = User.get_from_login("someoneelse")
user = models.User.get_from_login("someoneelse")
foo_group.reload()
assert user.groups == [foo_group]
user.delete()

View file

@ -1,13 +1,13 @@
from canaille.core.models import User
from canaille.app import models
def test_user_get_from_login(testclient, user, backend):
assert User.get_from_login(login="invalid") is None
assert User.get_from_login(login="user") == user
assert models.User.get_from_login(login="invalid") is None
assert models.User.get_from_login(login="user") == user
def test_user_has_password(testclient, backend):
u = User(
u = models.User(
formatted_name="Temp User",
family_name="Temp",
user_name="temp",

View file

@ -1,4 +1,4 @@
from canaille.core.models import User
from canaille.app import models
def test_user_creation_edition_and_deletion(
@ -6,7 +6,7 @@ def test_user_creation_edition_and_deletion(
):
# The user does not exist.
res = testclient.get("/users", status=200)
assert User.get_from_login("george") is None
assert models.User.get_from_login("george") is None
res.mustcontain(no="george")
# Fill the profile for a new user.
@ -24,7 +24,7 @@ def test_user_creation_edition_and_deletion(
res = res.form.submit(name="action", value="edit", status=302)
assert ("success", "User account creation succeed.") in res.flashes
res = res.follow(status=200)
george = User.get_from_login("george")
george = models.User.get_from_login("george")
foo_group.reload()
assert "George" == george.given_name[0]
assert george.groups == [foo_group]
@ -45,7 +45,7 @@ def test_user_creation_edition_and_deletion(
res.form["groups"] = [foo_group.id, bar_group.id]
res = res.form.submit(name="action", value="edit").follow()
george = User.get_from_login("george")
george = models.User.get_from_login("george")
assert "Georgio" == george.given_name[0]
assert george.check_password("totoyolo")
@ -60,7 +60,7 @@ def test_user_creation_edition_and_deletion(
# User have been deleted.
res = testclient.get("/profile/george/settings", status=200)
res = res.form.submit(name="action", value="delete", status=302).follow(status=200)
assert User.get_from_login("george") is None
assert models.User.get_from_login("george") is None
res.mustcontain(no="george")
@ -89,7 +89,7 @@ def test_user_creation_without_password(testclient, logged_moderator):
res = res.form.submit(name="action", value="edit", status=302)
assert ("success", "User account creation succeed.") in res.flashes
res = res.follow(status=200)
george = User.get_from_login("george")
george = models.User.get_from_login("george")
assert george.user_name[0] == "george"
assert not george.has_password()
@ -100,13 +100,13 @@ def test_user_creation_form_validation_failed(
testclient, logged_moderator, foo_group, bar_group
):
res = testclient.get("/users", status=200)
assert User.get_from_login("george") is None
assert models.User.get_from_login("george") is None
res.mustcontain(no="george")
res = testclient.get("/profile", status=200)
res = res.form.submit(name="action", value="edit")
assert ("error", "User account creation failed.") in res.flashes
assert User.get_from_login("george") is None
assert models.User.get_from_login("george") is None
def test_username_already_taken(
@ -140,7 +140,7 @@ def test_cn_setting_with_given_name_and_surname(testclient, logged_moderator):
res = res.form.submit(name="action", value="edit", status=302).follow(status=200)
george = User.get_from_login("george")
george = models.User.get_from_login("george")
assert george.formatted_name[0] == "George Abitbol"
george.delete()
@ -153,6 +153,6 @@ def test_cn_setting_with_surname_only(testclient, logged_moderator):
res = res.form.submit(name="action", value="edit", status=302).follow(status=200)
george = User.get_from_login("george")
george = models.User.get_from_login("george")
assert george.formatted_name[0] == "Abitbol"
george.delete()

View file

@ -1,6 +1,6 @@
import datetime
from canaille.core.models import User
from canaille.app import models
from webtest import Upload
@ -101,7 +101,7 @@ def test_photo_on_profile_edition(
def test_photo_on_profile_creation(testclient, jpeg_photo, logged_admin):
res = testclient.get("/users", status=200)
assert User.get_from_login("foobar") is None
assert models.User.get_from_login("foobar") is None
res.mustcontain(no="foobar")
res = testclient.get("/profile", status=200)
@ -111,14 +111,14 @@ def test_photo_on_profile_creation(testclient, jpeg_photo, logged_admin):
res.form["email"] = "george@abitbol.com"
res = res.form.submit(name="action", value="edit", status=302).follow(status=200)
user = User.get_from_login("foobar")
user = models.User.get_from_login("foobar")
assert user.photo == [jpeg_photo]
user.delete()
def test_photo_deleted_on_profile_creation(testclient, jpeg_photo, logged_admin):
res = testclient.get("/users", status=200)
assert User.get_from_login("foobar") is None
assert models.User.get_from_login("foobar") is None
res.mustcontain(no="foobar")
res = testclient.get("/profile", status=200)
@ -129,6 +129,6 @@ def test_photo_deleted_on_profile_creation(testclient, jpeg_photo, logged_admin)
res.form["email"] = "george@abitbol.com"
res = res.form.submit(name="action", value="edit", status=302).follow(status=200)
user = User.get_from_login("foobar")
user = models.User.get_from_login("foobar")
assert user.photo == []
user.delete()

View file

@ -1,6 +1,6 @@
from unittest import mock
from canaille.core.models import User
from canaille.app import models
def test_edition(
@ -126,7 +126,7 @@ def test_password_change_fail(testclient, logged_user):
def test_password_initialization_mail(smtpd, testclient, backend, logged_admin):
u = User(
u = models.User(
formatted_name="Temp User",
family_name="Temp",
user_name="temp",
@ -162,7 +162,7 @@ def test_password_initialization_mail_send_fail(
SMTP, smtpd, testclient, backend, logged_admin
):
SMTP.side_effect = mock.Mock(side_effect=OSError("unit test mail error"))
u = User(
u = models.User(
formatted_name="Temp User",
family_name="Temp",
user_name="temp",
@ -232,7 +232,7 @@ def test_invalid_form_request(testclient, logged_admin):
def test_password_reset_email(smtpd, testclient, backend, logged_admin):
u = User(
u = models.User(
formatted_name="Temp User",
family_name="Temp",
user_name="temp",
@ -260,7 +260,7 @@ def test_password_reset_email(smtpd, testclient, backend, logged_admin):
@mock.patch("smtplib.SMTP")
def test_password_reset_email_failed(SMTP, smtpd, testclient, backend, logged_admin):
SMTP.side_effect = mock.Mock(side_effect=OSError("unit test mail error"))
u = User(
u = models.User(
formatted_name="Temp User",
family_name="Temp",
user_name="temp",

View file

@ -1,13 +1,12 @@
import datetime
from canaille.app import models
from canaille.commands import cli
from canaille.oidc.models import AuthorizationCode
from canaille.oidc.models import Token
from werkzeug.security import gen_salt
def test_clean_command(testclient, backend, client, user):
valid_code = AuthorizationCode(
valid_code = models.AuthorizationCode(
authorization_code_id=gen_salt(48),
code="my-valid-code",
client=client,
@ -23,7 +22,7 @@ def test_clean_command(testclient, backend, client, user):
revokation="",
)
valid_code.save()
expired_code = AuthorizationCode(
expired_code = models.AuthorizationCode(
authorization_code_id=gen_salt(48),
code="my-expired-code",
client=client,
@ -43,7 +42,7 @@ def test_clean_command(testclient, backend, client, user):
)
expired_code.save()
valid_token = Token(
valid_token = models.Token(
token_id=gen_salt(48),
access_token="my-valid-token",
client=client,
@ -57,7 +56,7 @@ def test_clean_command(testclient, backend, client, user):
lifetime=3600,
)
valid_token.save()
expired_token = Token(
expired_token = models.Token(
token_id=gen_salt(48),
access_token="my-expired-token",
client=client,
@ -73,8 +72,8 @@ def test_clean_command(testclient, backend, client, user):
)
expired_token.save()
assert AuthorizationCode.get(code="my-expired-code")
assert Token.get(access_token="my-expired-token")
assert models.AuthorizationCode.get(code="my-expired-code")
assert models.Token.get(access_token="my-expired-token")
assert expired_code.is_expired()
assert expired_token.is_expired()
@ -82,5 +81,5 @@ def test_clean_command(testclient, backend, client, user):
res = runner.invoke(cli, ["clean"])
assert res.exit_code == 0, res.stdout
assert AuthorizationCode.query() == [valid_code]
assert Token.query() == [valid_token]
assert models.AuthorizationCode.query() == [valid_code]
assert models.Token.query() == [valid_token]

View file

@ -4,10 +4,7 @@ import uuid
import pytest
from authlib.oidc.core.grants.util import generate_id_token
from canaille.oidc.models import AuthorizationCode
from canaille.oidc.models import Client
from canaille.oidc.models import Consent
from canaille.oidc.models import Token
from canaille.app import models
from canaille.oidc.oauth import generate_user_info
from canaille.oidc.oauth import get_jwt_config
from cryptography.hazmat.backends import default_backend as crypto_default_backend
@ -71,7 +68,7 @@ def configuration(configuration, keypair_path):
@pytest.fixture
def client(testclient, other_client, backend):
c = Client(
c = models.Client(
client_id=gen_salt(24),
client_name="Some client",
contacts="contact@mydomain.tld",
@ -107,7 +104,7 @@ def client(testclient, other_client, backend):
@pytest.fixture
def other_client(testclient, backend):
c = Client(
c = models.Client(
client_id=gen_salt(24),
client_name="Some other client",
contacts="contact@myotherdomain.tld",
@ -143,7 +140,7 @@ def other_client(testclient, backend):
@pytest.fixture
def authorization(testclient, user, client, backend):
a = AuthorizationCode(
a = models.AuthorizationCode(
authorization_code_id=gen_salt(48),
code="my-code",
client=client,
@ -165,7 +162,7 @@ def authorization(testclient, user, client, backend):
@pytest.fixture
def token(testclient, client, user, backend):
t = Token(
t = models.Token(
token_id=gen_salt(48),
access_token=gen_salt(48),
audience=[client],
@ -194,7 +191,7 @@ def id_token(testclient, client, user, backend):
@pytest.fixture
def consent(testclient, client, user, backend):
t = Consent(
t = models.Consent(
consent_id=str(uuid.uuid4()),
client=client,
subject=user,

View file

@ -5,10 +5,7 @@ from urllib.parse import urlsplit
import freezegun
from authlib.jose import jwt
from authlib.oauth2.rfc7636 import create_s256_code_challenge
from canaille.core.models import User
from canaille.oidc.models import AuthorizationCode
from canaille.oidc.models import Consent
from canaille.oidc.models import Token
from canaille.app import models
from canaille.oidc.oauth import setup_oauth
from werkzeug.security import gen_salt
@ -18,7 +15,7 @@ from . import client_credentials
def test_authorization_code_flow(
testclient, logged_user, client, keypair, other_client
):
assert not Consent.query()
assert not models.Consent.query()
res = testclient.get(
"/oauth/authorize",
@ -36,7 +33,7 @@ def test_authorization_code_flow(
assert res.location.startswith(client.redirect_uris[0])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = AuthorizationCode.get(code=code)
authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None
assert set(authcode.scope[0].split(" ")) == {
"openid",
@ -47,7 +44,7 @@ def test_authorization_code_flow(
"phone",
}
consents = Consent.query(client=client, subject=logged_user)
consents = models.Consent.query(client=client, subject=logged_user)
assert set(consents[0].scope) == {
"openid",
"profile",
@ -70,7 +67,7 @@ def test_authorization_code_flow(
)
access_token = res.json["access_token"]
token = Token.get(access_token=access_token)
token = models.Token.get(access_token=access_token)
assert token.client == client
assert token.subject == logged_user
assert set(token.scope[0].split(" ")) == {
@ -119,7 +116,7 @@ def test_invalid_client(testclient, logged_user, keypair):
def test_authorization_code_flow_with_redirect_uri(
testclient, logged_user, client, keypair, other_client
):
assert not Consent.query()
assert not models.Consent.query()
res = testclient.get(
"/oauth/authorize",
@ -138,9 +135,9 @@ def test_authorization_code_flow_with_redirect_uri(
assert res.location.startswith(client.redirect_uris[1])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = AuthorizationCode.get(code=code)
authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None
consents = Consent.query(client=client, subject=logged_user)
consents = models.Consent.query(client=client, subject=logged_user)
res = testclient.post(
"/oauth/token",
@ -155,7 +152,7 @@ def test_authorization_code_flow_with_redirect_uri(
)
access_token = res.json["access_token"]
token = Token.get(access_token=access_token)
token = models.Token.get(access_token=access_token)
assert token.client == client
assert token.subject == logged_user
@ -166,7 +163,7 @@ def test_authorization_code_flow_with_redirect_uri(
def test_authorization_code_flow_preconsented(
testclient, logged_user, client, keypair, other_client
):
assert not Consent.query()
assert not models.Consent.query()
client.preconsent = True
client.save()
@ -185,10 +182,10 @@ def test_authorization_code_flow_preconsented(
assert res.location.startswith(client.redirect_uris[0])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = AuthorizationCode.get(code=code)
authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None
consents = Consent.query(client=client, subject=logged_user)
consents = models.Consent.query(client=client, subject=logged_user)
assert not consents
res = testclient.post(
@ -204,7 +201,7 @@ def test_authorization_code_flow_preconsented(
)
access_token = res.json["access_token"]
token = Token.get(access_token=access_token)
token = models.Token.get(access_token=access_token)
assert token.client == client
assert token.subject == logged_user
@ -223,7 +220,7 @@ def test_authorization_code_flow_preconsented(
def test_logout_login(testclient, logged_user, client):
assert not Consent.query()
assert not models.Consent.query()
res = testclient.get(
"/oauth/authorize",
@ -254,10 +251,10 @@ def test_logout_login(testclient, logged_user, client):
assert res.location.startswith(client.redirect_uris[0])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = AuthorizationCode.get(code=code)
authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None
consents = Consent.query(client=client, subject=logged_user)
consents = models.Consent.query(client=client, subject=logged_user)
assert "profile" in consents[0].scope
res = testclient.post(
@ -273,7 +270,7 @@ def test_logout_login(testclient, logged_user, client):
)
access_token = res.json["access_token"]
token = Token.get(access_token=access_token)
token = models.Token.get(access_token=access_token)
assert token.client == client
assert token.subject == logged_user
@ -289,7 +286,7 @@ def test_logout_login(testclient, logged_user, client):
def test_deny(testclient, logged_user, client):
assert not Consent.query()
assert not models.Consent.query()
res = testclient.get(
"/oauth/authorize",
@ -308,11 +305,11 @@ def test_deny(testclient, logged_user, client):
error = params["error"][0]
assert error == "access_denied"
assert not Consent.query()
assert not models.Consent.query()
def test_refresh_token(testclient, user, client):
assert not Consent.query()
assert not models.Consent.query()
with freezegun.freeze_time("2020-01-01 01:00:00"):
res = testclient.get(
@ -335,10 +332,10 @@ def test_refresh_token(testclient, user, client):
assert res.location.startswith(client.redirect_uris[0])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = AuthorizationCode.get(code=code)
authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None
consents = Consent.query(client=client, subject=user)
consents = models.Consent.query(client=client, subject=user)
assert "profile" in consents[0].scope
with freezegun.freeze_time("2020-01-01 00:01:00"):
@ -354,7 +351,7 @@ def test_refresh_token(testclient, user, client):
status=200,
)
access_token = res.json["access_token"]
old_token = Token.get(access_token=access_token)
old_token = models.Token.get(access_token=access_token)
assert old_token is not None
assert not old_token.revokation_date
@ -369,7 +366,7 @@ def test_refresh_token(testclient, user, client):
status=200,
)
access_token = res.json["access_token"]
new_token = Token.get(access_token=access_token)
new_token = models.Token.get(access_token=access_token)
assert new_token is not None
assert old_token.access_token != new_token.access_token
@ -389,7 +386,7 @@ def test_refresh_token(testclient, user, client):
def test_code_challenge(testclient, logged_user, client):
assert not Consent.query()
assert not models.Consent.query()
client.token_endpoint_auth_method = "none"
client.save()
@ -415,10 +412,10 @@ def test_code_challenge(testclient, logged_user, client):
assert res.location.startswith(client.redirect_uris[0])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = AuthorizationCode.get(code=code)
authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None
consents = Consent.query(client=client, subject=logged_user)
consents = models.Consent.query(client=client, subject=logged_user)
assert "profile" in consents[0].scope
res = testclient.post(
@ -435,7 +432,7 @@ def test_code_challenge(testclient, logged_user, client):
)
access_token = res.json["access_token"]
token = Token.get(access_token=access_token)
token = models.Token.get(access_token=access_token)
assert token.client == client
assert token.subject == logged_user
@ -456,7 +453,7 @@ def test_code_challenge(testclient, logged_user, client):
def test_authorization_code_flow_when_consent_already_given(
testclient, logged_user, client
):
assert not Consent.query()
assert not models.Consent.query()
res = testclient.get(
"/oauth/authorize",
@ -474,10 +471,10 @@ def test_authorization_code_flow_when_consent_already_given(
assert res.location.startswith(client.redirect_uris[0])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = AuthorizationCode.get(code=code)
authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None
consents = Consent.query(client=client, subject=logged_user)
consents = models.Consent.query(client=client, subject=logged_user)
assert "profile" in consents[0].scope
res = testclient.post(
@ -514,7 +511,7 @@ def test_authorization_code_flow_when_consent_already_given(
def test_authorization_code_flow_when_consent_already_given_but_for_a_smaller_scope(
testclient, logged_user, client
):
assert not Consent.query()
assert not models.Consent.query()
res = testclient.get(
"/oauth/authorize",
@ -532,10 +529,10 @@ def test_authorization_code_flow_when_consent_already_given_but_for_a_smaller_sc
assert res.location.startswith(client.redirect_uris[0])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = AuthorizationCode.get(code=code)
authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None
consents = Consent.query(client=client, subject=logged_user)
consents = models.Consent.query(client=client, subject=logged_user)
assert "profile" in consents[0].scope
assert "groups" not in consents[0].scope
@ -568,10 +565,10 @@ def test_authorization_code_flow_when_consent_already_given_but_for_a_smaller_sc
assert res.location.startswith(client.redirect_uris[0])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = AuthorizationCode.get(code=code)
authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None
consents = Consent.query(client=client, subject=logged_user)
consents = models.Consent.query(client=client, subject=logged_user)
assert "profile" in consents[0].scope
assert "groups" in consents[0].scope
@ -604,7 +601,7 @@ def test_authorization_code_flow_but_user_cannot_use_oidc(
def test_prompt_none(testclient, logged_user, client):
consent = Consent(
consent = models.Consent(
consent_id=str(uuid.uuid4()),
client=client,
subject=logged_user,
@ -631,7 +628,7 @@ def test_prompt_none(testclient, logged_user, client):
def test_prompt_not_logged(testclient, user, client):
consent = Consent(
consent = models.Consent(
consent_id=str(uuid.uuid4()),
client=client,
subject=user,
@ -685,7 +682,7 @@ def test_nonce_required_in_oidc_requests(testclient, logged_user, client):
def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client):
assert not Consent.query()
assert not models.Consent.query()
testclient.app.config["REQUIRE_NONCE"] = False
res = testclient.get(
@ -701,14 +698,14 @@ def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client):
res = res.form.submit(name="answer", value="accept", status=302)
assert res.location.startswith(client.redirect_uris[0])
for consent in Consent.query():
for consent in models.Consent.query():
consent.delete()
def test_authorization_code_request_scope_too_large(
testclient, logged_user, keypair, other_client
):
assert not Consent.query()
assert not models.Consent.query()
assert "email" not in other_client.scope
res = testclient.get(
@ -726,13 +723,13 @@ def test_authorization_code_request_scope_too_large(
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = AuthorizationCode.get(code=code)
authcode = models.AuthorizationCode.get(code=code)
assert set(authcode.scope[0].split(" ")) == {
"openid",
"profile",
}
consents = Consent.query(client=other_client, subject=logged_user)
consents = models.Consent.query(client=other_client, subject=logged_user)
assert set(consents[0].scope) == {
"openid",
"profile",
@ -751,7 +748,7 @@ def test_authorization_code_request_scope_too_large(
)
access_token = res.json["access_token"]
token = Token.get(access_token=access_token)
token = models.Token.get(access_token=access_token)
assert token.client == other_client
assert token.subject == logged_user
assert set(token.scope[0].split(" ")) == {
@ -813,7 +810,7 @@ def test_authorization_code_expired(testclient, user, client):
def test_code_with_invalid_user(testclient, admin, client):
user = User(
user = models.User(
formatted_name="John Doe",
family_name="Doe",
user_name="temp",
@ -838,7 +835,7 @@ def test_code_with_invalid_user(testclient, admin, client):
res = res.form.submit(name="answer", value="accept", status=302)
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = AuthorizationCode.get(code=code)
authcode = models.AuthorizationCode.get(code=code)
user.delete()
@ -861,7 +858,7 @@ def test_code_with_invalid_user(testclient, admin, client):
def test_refresh_token_with_invalid_user(testclient, client):
user = User(
user = models.User(
formatted_name="John Doe",
family_name="Doe",
user_name="temp",
@ -888,7 +885,7 @@ def test_refresh_token_with_invalid_user(testclient, client):
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = AuthorizationCode.get(code=code)
authcode = models.AuthorizationCode.get(code=code)
res = testclient.post(
"/oauth/token",
@ -919,7 +916,7 @@ def test_refresh_token_with_invalid_user(testclient, client):
"error": "invalid_request",
"error_description": 'There is no "user" for this token.',
}
Token.get(access_token=access_token).delete()
models.Token.get(access_token=access_token).delete()
def test_token_default_expiration_date(testclient, logged_user, client, keypair):
@ -937,7 +934,7 @@ def test_token_default_expiration_date(testclient, logged_user, client, keypair)
res = res.form.submit(name="answer", value="accept", status=302)
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = AuthorizationCode.get(code=code)
authcode = models.AuthorizationCode.get(code=code)
assert authcode.lifetime == 84400
res = testclient.post(
@ -955,7 +952,7 @@ def test_token_default_expiration_date(testclient, logged_user, client, keypair)
assert res.json["expires_in"] == 864000
access_token = res.json["access_token"]
token = Token.get(access_token=access_token)
token = models.Token.get(access_token=access_token)
assert token.lifetime == 864000
claims = jwt.decode(access_token, keypair[1])
@ -965,7 +962,7 @@ def test_token_default_expiration_date(testclient, logged_user, client, keypair)
claims = jwt.decode(id_token, keypair[1])
assert claims["exp"] - claims["iat"] == 3600
consents = Consent.query(client=client, subject=logged_user)
consents = models.Consent.query(client=client, subject=logged_user)
for consent in consents:
consent.delete()
@ -995,7 +992,7 @@ def test_token_custom_expiration_date(testclient, logged_user, client, keypair):
res = res.form.submit(name="answer", value="accept", status=302)
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = AuthorizationCode.get(code=code)
authcode = models.AuthorizationCode.get(code=code)
assert authcode.lifetime == 84400
res = testclient.post(
@ -1013,7 +1010,7 @@ def test_token_custom_expiration_date(testclient, logged_user, client, keypair):
assert res.json["expires_in"] == 1000
access_token = res.json["access_token"]
token = Token.get(access_token=access_token)
token = models.Token.get(access_token=access_token)
assert token.lifetime == 1000
claims = jwt.decode(access_token, keypair[1])
@ -1023,6 +1020,6 @@ def test_token_custom_expiration_date(testclient, logged_user, client, keypair):
claims = jwt.decode(id_token, keypair[1])
assert claims["exp"] - claims["iat"] == 6000
consents = Consent.query(client=client, subject=logged_user)
consents = models.Consent.query(client=client, subject=logged_user)
for consent in consents:
consent.delete()

View file

@ -1,9 +1,6 @@
import datetime
from canaille.oidc.models import AuthorizationCode
from canaille.oidc.models import Client
from canaille.oidc.models import Consent
from canaille.oidc.models import Token
from canaille.app import models
from werkzeug.security import gen_salt
@ -29,7 +26,7 @@ def test_client_list_pagination(testclient, logged_admin, client, other_client):
res.mustcontain("2 items")
clients = []
for _ in range(25):
client = Client(client_id=gen_salt(48), client_name=gen_salt(48))
client = models.Client(client_id=gen_salt(48), client_name=gen_salt(48))
client.save()
clients.append(client)
@ -115,7 +112,7 @@ def test_client_add(testclient, logged_admin):
res = res.follow(status=200)
client_id = res.forms["readonly"]["client_id"].value
client = Client.get(client_id=client_id)
client = models.Client.get(client_id=client_id)
data["audience"] = [client]
for k, v in data.items():
client_value = getattr(client, k)
@ -198,27 +195,29 @@ def test_client_edit_missing_fields(testclient, client, logged_admin, other_clie
def test_client_delete(testclient, logged_admin):
client = Client(client_id="client_id")
client = models.Client(client_id="client_id")
client.save()
token = Token(
token = models.Token(
token_id="id",
client=client,
issue_datetime=datetime.datetime.now(datetime.timezone.utc),
)
token.save()
consent = Consent(
consent = models.Consent(
consent_id="consent_id", subject=logged_admin, client=client, scope="openid"
)
consent.save()
code = AuthorizationCode(authorization_code_id="id", client=client, subject=client)
code = models.AuthorizationCode(
authorization_code_id="id", client=client, subject=client
)
res = testclient.get("/admin/client/edit/" + client.client_id)
res = res.forms["clientaddform"].submit(name="action", value="delete").follow()
assert not Client.get()
assert not Token.get()
assert not AuthorizationCode.get()
assert not Consent.get()
assert not models.Client.get()
assert not models.Token.get()
assert not models.AuthorizationCode.get()
assert not models.Consent.get()
def test_client_delete_invalid_client(testclient, logged_admin, client):

View file

@ -1,4 +1,4 @@
from canaille.oidc.models import AuthorizationCode
from canaille.app import models
from werkzeug.security import gen_salt
@ -20,7 +20,7 @@ def test_authorization_list_pagination(testclient, logged_admin, client):
res.mustcontain("0 items")
authorizations = []
for _ in range(26):
code = AuthorizationCode(
code = models.AuthorizationCode(
authorization_code_id=gen_salt(48), client=client, subject=logged_admin
)
code.save()
@ -66,13 +66,13 @@ def test_authorization_list_bad_pages(testclient, logged_admin):
def test_authorization_list_search(testclient, logged_admin, client):
id1 = gen_salt(48)
auth1 = AuthorizationCode(
auth1 = models.AuthorizationCode(
authorization_code_id=id1, client=client, subject=logged_admin
)
auth1.save()
id2 = gen_salt(48)
auth2 = AuthorizationCode(
auth2 = models.AuthorizationCode(
authorization_code_id=id2, client=client, subject=logged_admin
)
auth2.save()

View file

@ -1,8 +1,7 @@
from urllib.parse import parse_qs
from urllib.parse import urlsplit
from canaille.oidc.models import Consent
from canaille.oidc.models import Token
from canaille.app import models
from . import client_credentials
@ -118,7 +117,7 @@ def test_oidc_authorization_after_revokation(
res = res.form.submit(name="answer", value="accept", status=302)
consents = Consent.query(client=client, subject=logged_user)
consents = models.Consent.query(client=client, subject=logged_user)
consent.reload()
assert consents[0] == consent
assert not consent.revoked
@ -138,7 +137,7 @@ def test_oidc_authorization_after_revokation(
)
access_token = res.json["access_token"]
token = Token.get(access_token=access_token)
token = models.Token.get(access_token=access_token)
assert token.client == client
assert token.subject == logged_user
@ -158,13 +157,13 @@ def test_preconsented_client_appears_in_consent_list(testclient, client, logged_
def test_revoke_preconsented_client(testclient, client, logged_user, token):
client.preconsent = True
client.save()
assert not Consent.get()
assert not models.Consent.get()
assert not token.revoked
res = testclient.get(f"/consent/revoke-preconsent/{client.client_id}", status=302)
assert ("success", "The access has been revoked") in res.flashes
consent = Consent.get()
consent = models.Consent.get()
assert consent.client == client
assert consent.subject == logged_user
assert consent.scope == ["openid", "email", "profile", "groups", "address", "phone"]

View file

@ -1,7 +1,7 @@
from unittest import mock
from authlib.jose import jwt
from canaille.oidc.models import Client
from canaille.app import models
def test_client_registration_with_authentication_static_token(
@ -29,7 +29,7 @@ def test_client_registration_with_authentication_static_token(
headers = {"Authorization": "Bearer static-token"}
res = testclient.post_json("/oauth/register", payload, headers=headers, status=201)
client = Client.get(client_id=res.json["client_id"])
client = models.Client.get(client_id=res.json["client_id"])
assert res.json == {
"client_id": client.client_id,
@ -148,7 +148,7 @@ def test_client_registration_with_software_statement(testclient, backend, keypai
}
res = testclient.post_json("/oauth/register", payload, status=201)
client = Client.get(client_id=res.json["client_id"])
client = models.Client.get(client_id=res.json["client_id"])
assert res.json == {
"client_id": client.client_id,
"client_secret": client.client_secret,
@ -199,7 +199,7 @@ def test_client_registration_without_authentication_ok(testclient, backend):
res = testclient.post_json("/oauth/register", payload, status=201)
client = Client.get(client_id=res.json["client_id"])
client = models.Client.get(client_id=res.json["client_id"])
assert res.json == {
"client_id": mock.ANY,
"client_secret": mock.ANY,

View file

@ -1,7 +1,7 @@
import warnings
from datetime import datetime
from canaille.oidc.models import Client
from canaille.app import models
def test_get(testclient, backend, client, user):
@ -95,7 +95,7 @@ def test_update(testclient, backend, client, user):
res = testclient.put_json(
f"/oauth/register/{client.client_id}", payload, headers=headers, status=200
)
client = Client.get(client_id=res.json["client_id"])
client = models.Client.get(client_id=res.json["client_id"])
assert res.json == {
"client_id": client.client_id,
@ -145,7 +145,7 @@ def test_delete(testclient, backend, user):
"static-token"
]
client = Client(client_id="foobar", client_name="Some client")
client = models.Client(client_id="foobar", client_name="Some client")
client.save()
headers = {"Authorization": "Bearer static-token"}
@ -153,7 +153,7 @@ def test_delete(testclient, backend, user):
res = testclient.delete(
f"/oauth/register/{client.client_id}", headers=headers, status=204
)
assert not Client.get(client_id=client.client_id)
assert not models.Client.get(client_id=client.client_id)
def test_invalid_client(testclient, backend, user):

View file

@ -2,8 +2,7 @@ from urllib.parse import parse_qs
from urllib.parse import urlsplit
from authlib.jose import jwt
from canaille.oidc.models import AuthorizationCode
from canaille.oidc.models import Token
from canaille.app import models
def test_oauth_hybrid(testclient, backend, user, client):
@ -32,11 +31,11 @@ def test_oauth_hybrid(testclient, backend, user, client):
params = parse_qs(urlsplit(res.location).fragment)
code = params["code"][0]
authcode = AuthorizationCode.get(code=code)
authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None
access_token = params["access_token"][0]
token = Token.get(access_token=access_token)
token = models.Token.get(access_token=access_token)
assert token is not None
res = testclient.get(
@ -65,11 +64,11 @@ def test_oidc_hybrid(testclient, backend, logged_user, client, keypair, other_cl
params = parse_qs(urlsplit(res.location).fragment)
code = params["code"][0]
authcode = AuthorizationCode.get(code=code)
authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None
access_token = params["access_token"][0]
token = Token.get(access_token=access_token)
token = models.Token.get(access_token=access_token)
assert token is not None
id_token = params["id_token"][0]

View file

@ -2,7 +2,7 @@ from urllib.parse import parse_qs
from urllib.parse import urlsplit
from authlib.jose import jwt
from canaille.oidc.models import Token
from canaille.app import models
def test_oauth_implicit(testclient, user, client):
@ -35,7 +35,7 @@ def test_oauth_implicit(testclient, user, client):
params = parse_qs(urlsplit(res.location).fragment)
access_token = params["access_token"][0]
token = Token.get(access_token=access_token)
token = models.Token.get(access_token=access_token)
assert token is not None
res = testclient.get(
@ -79,7 +79,7 @@ def test_oidc_implicit(testclient, keypair, user, client, other_client):
params = parse_qs(urlsplit(res.location).fragment)
access_token = params["access_token"][0]
token = Token.get(access_token=access_token)
token = models.Token.get(access_token=access_token)
assert token is not None
id_token = params["id_token"][0]
@ -133,7 +133,7 @@ def test_oidc_implicit_with_group(
params = parse_qs(urlsplit(res.location).fragment)
access_token = params["access_token"][0]
token = Token.get(access_token=access_token)
token = models.Token.get(access_token=access_token)
assert token is not None
id_token = params["id_token"][0]

View file

@ -1,4 +1,4 @@
from canaille.oidc.models import Token
from canaille.app import models
from . import client_credentials
@ -20,7 +20,7 @@ def test_password_flow_basic(testclient, user, client):
assert res.json["token_type"] == "Bearer"
access_token = res.json["access_token"]
token = Token.get(access_token=access_token)
token = models.Token.get(access_token=access_token)
assert token is not None
res = testclient.get(
@ -52,7 +52,7 @@ def test_password_flow_post(testclient, user, client):
assert res.json["token_type"] == "Bearer"
access_token = res.json["access_token"]
token = Token.get(access_token=access_token)
token = models.Token.get(access_token=access_token)
assert token is not None
res = testclient.get(

View file

@ -1,6 +1,6 @@
import datetime
from canaille.oidc.models import Token
from canaille.app import models
from werkzeug.security import gen_salt
@ -22,7 +22,7 @@ def test_token_list_pagination(testclient, logged_admin, client):
res.mustcontain("0 items")
tokens = []
for _ in range(26):
token = Token(
token = models.Token(
token_id=gen_salt(48),
access_token="my-valid-token",
client=client,
@ -75,7 +75,7 @@ def test_token_list_bad_pages(testclient, logged_admin):
def test_token_list_search(testclient, logged_admin, client):
token1 = Token(
token1 = models.Token(
token_id=gen_salt(48),
access_token="this-token-is-ok",
client=client,
@ -89,7 +89,7 @@ def test_token_list_search(testclient, logged_admin, client):
lifetime=3600,
)
token1.save()
token2 = Token(
token2 = models.Token(
token_id=gen_salt(48),
access_token="this-token-is-valid",
client=client,

View file

@ -1,8 +1,7 @@
from urllib.parse import parse_qs
from urllib.parse import urlsplit
from canaille.oidc.models import AuthorizationCode
from canaille.oidc.models import Token
from canaille.app import models
from . import client_credentials
@ -76,7 +75,7 @@ def test_full_flow(testclient, logged_user, client, user, other_client):
assert res.location.startswith(client.redirect_uris[0])
params = parse_qs(urlsplit(res.location).query)
code = params["code"][0]
authcode = AuthorizationCode.get(code=code)
authcode = models.AuthorizationCode.get(code=code)
assert authcode is not None
res = testclient.post(
@ -92,7 +91,7 @@ def test_full_flow(testclient, logged_user, client, user, other_client):
)
access_token = res.json["access_token"]
token = Token.get(access_token=access_token)
token = models.Token.get(access_token=access_token)
assert token.client == client
assert token.subject == logged_user