forked from Github-Mirrors/canaille
Moved every model import to canaille.models
This commit is contained in:
parent
e110c4851b
commit
c1d1706007
43 changed files with 421 additions and 428 deletions
|
@ -3,7 +3,7 @@ import hashlib
|
|||
import json
|
||||
import re
|
||||
|
||||
from canaille.core.models import User
|
||||
from canaille.app import models
|
||||
from flask import current_app
|
||||
from flask import request
|
||||
from flask_babel import gettext as _
|
||||
|
@ -26,7 +26,7 @@ def profile_hash(*args):
|
|||
|
||||
def login_placeholder():
|
||||
user_filter = current_app.config["BACKENDS"]["LDAP"].get(
|
||||
"USER_FILTER", User.DEFAULT_FILTER
|
||||
"USER_FILTER", models.User.DEFAULT_FILTER
|
||||
)
|
||||
placeholders = []
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ from functools import wraps
|
|||
from urllib.parse import urlsplit
|
||||
from urllib.parse import urlunsplit
|
||||
|
||||
from canaille.core.models import User
|
||||
from canaille.app import models
|
||||
from flask import abort
|
||||
from flask import current_app
|
||||
from flask import render_template
|
||||
|
@ -14,7 +14,7 @@ from flask_babel import gettext as _
|
|||
|
||||
def current_user():
|
||||
for user_id in session.get("user_id", [])[::-1]:
|
||||
user = User.get(id=user_id)
|
||||
user = models.User.get(id=user_id)
|
||||
if user:
|
||||
return user
|
||||
|
||||
|
|
7
canaille/app/models.py
Normal file
7
canaille/app/models.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
# nopycln: file
|
||||
from canaille.core.models import Group
|
||||
from canaille.core.models import User
|
||||
from canaille.oidc.models import AuthorizationCode
|
||||
from canaille.oidc.models import Client
|
||||
from canaille.oidc.models import Consent
|
||||
from canaille.oidc.models import Token
|
|
@ -76,8 +76,7 @@ class LDAPBackend(Backend):
|
|||
@classmethod
|
||||
def validate(cls, config):
|
||||
from canaille.app.configuration import ConfigurationException
|
||||
from canaille.core.models import Group
|
||||
from canaille.core.models import User
|
||||
from canaille.app import models
|
||||
|
||||
try:
|
||||
conn = ldap.initialize(config["BACKENDS"]["LDAP"]["URI"])
|
||||
|
@ -100,8 +99,8 @@ class LDAPBackend(Backend):
|
|||
) from exc
|
||||
|
||||
try:
|
||||
User.ldap_object_classes(conn)
|
||||
user = User(
|
||||
models.User.ldap_object_classes(conn)
|
||||
user = models.User(
|
||||
formatted_name=f"canaille_{uuid.uuid4()}",
|
||||
family_name=f"canaille_{uuid.uuid4()}",
|
||||
user_name=f"canaille_{uuid.uuid4()}",
|
||||
|
@ -118,9 +117,9 @@ class LDAPBackend(Backend):
|
|||
) from exc
|
||||
|
||||
try:
|
||||
Group.ldap_object_classes(conn)
|
||||
models.Group.ldap_object_classes(conn)
|
||||
|
||||
user = User(
|
||||
user = models.User(
|
||||
cn=f"canaille_{uuid.uuid4()}",
|
||||
family_name=f"canaille_{uuid.uuid4()}",
|
||||
user_name=f"canaille_{uuid.uuid4()}",
|
||||
|
@ -129,7 +128,7 @@ class LDAPBackend(Backend):
|
|||
)
|
||||
user.save(conn)
|
||||
|
||||
group = Group(
|
||||
group = models.Group(
|
||||
display_name=f"canaille_{uuid.uuid4()}",
|
||||
members=[user],
|
||||
)
|
||||
|
@ -150,22 +149,21 @@ class LDAPBackend(Backend):
|
|||
|
||||
def setup_ldap_models(config):
|
||||
from .ldapobject import LDAPObject
|
||||
from canaille.core.models import Group
|
||||
from canaille.core.models import User
|
||||
from canaille.app import models
|
||||
|
||||
LDAPObject.root_dn = config["BACKENDS"]["LDAP"]["ROOT_DN"]
|
||||
|
||||
user_base = config["BACKENDS"]["LDAP"]["USER_BASE"].replace(
|
||||
f',{config["BACKENDS"]["LDAP"]["ROOT_DN"]}', ""
|
||||
)
|
||||
User.base = user_base
|
||||
User.rdn_attribute = config["BACKENDS"]["LDAP"].get(
|
||||
"USER_ID_ATTRIBUTE", User.DEFAULT_ID_ATTRIBUTE
|
||||
models.User.base = user_base
|
||||
models.User.rdn_attribute = config["BACKENDS"]["LDAP"].get(
|
||||
"USER_ID_ATTRIBUTE", models.User.DEFAULT_ID_ATTRIBUTE
|
||||
)
|
||||
object_class = config["BACKENDS"]["LDAP"].get(
|
||||
"USER_CLASS", User.DEFAULT_OBJECT_CLASS
|
||||
"USER_CLASS", models.User.DEFAULT_OBJECT_CLASS
|
||||
)
|
||||
User.ldap_object_class = (
|
||||
models.User.ldap_object_class = (
|
||||
object_class if isinstance(object_class, list) else [object_class]
|
||||
)
|
||||
|
||||
|
@ -174,13 +172,13 @@ def setup_ldap_models(config):
|
|||
.get("GROUP_BASE", "")
|
||||
.replace(f',{config["BACKENDS"]["LDAP"]["ROOT_DN"]}', "")
|
||||
)
|
||||
Group.base = group_base or None
|
||||
Group.rdn_attribute = config["BACKENDS"]["LDAP"].get(
|
||||
"GROUP_ID_ATTRIBUTE", Group.DEFAULT_ID_ATTRIBUTE
|
||||
models.Group.base = group_base or None
|
||||
models.Group.rdn_attribute = config["BACKENDS"]["LDAP"].get(
|
||||
"GROUP_ID_ATTRIBUTE", models.Group.DEFAULT_ID_ATTRIBUTE
|
||||
)
|
||||
object_class = config["BACKENDS"]["LDAP"].get(
|
||||
"GROUP_CLASS", Group.DEFAULT_OBJECT_CLASS
|
||||
"GROUP_CLASS", models.Group.DEFAULT_OBJECT_CLASS
|
||||
)
|
||||
Group.ldap_object_class = (
|
||||
models.Group.ldap_object_class = (
|
||||
object_class if isinstance(object_class, list) else [object_class]
|
||||
)
|
||||
|
|
|
@ -9,6 +9,7 @@ import wtforms
|
|||
from canaille.app import b64_to_obj
|
||||
from canaille.app import default_fields
|
||||
from canaille.app import login_placeholder
|
||||
from canaille.app import models
|
||||
from canaille.app import obj_to_b64
|
||||
from canaille.app import profile_hash
|
||||
from canaille.app.flask import current_user
|
||||
|
@ -43,8 +44,6 @@ from .forms import profile_form
|
|||
from .mails import send_invitation_mail
|
||||
from .mails import send_password_initialization_mail
|
||||
from .mails import send_password_reset_mail
|
||||
from .models import Group
|
||||
from .models import User
|
||||
|
||||
|
||||
bp = Blueprint("account", __name__)
|
||||
|
@ -88,12 +87,12 @@ def login():
|
|||
form["login"].render_kw["placeholder"] = login_placeholder()
|
||||
|
||||
if request.form:
|
||||
user = User.get_from_login(form.login.data)
|
||||
user = models.User.get_from_login(form.login.data)
|
||||
if user and not user.has_password():
|
||||
return redirect(url_for("account.firstlogin", user_name=user.user_name[0]))
|
||||
|
||||
if not form.validate():
|
||||
User.logout()
|
||||
models.User.logout()
|
||||
flash(_("Login failed, please check your information"), "error")
|
||||
return render_template("login.html", form=form)
|
||||
|
||||
|
@ -111,7 +110,7 @@ def password():
|
|||
form = PasswordForm(request.form or None)
|
||||
|
||||
if request.form:
|
||||
user = User.get_from_login(session["attempt_login"])
|
||||
user = models.User.get_from_login(session["attempt_login"])
|
||||
if user and not user.has_password():
|
||||
return redirect(url_for("account.firstlogin", user_name=user.user_name[0]))
|
||||
|
||||
|
@ -120,7 +119,7 @@ def password():
|
|||
or not user
|
||||
or not user.check_password(form.password.data)
|
||||
):
|
||||
User.logout()
|
||||
models.User.logout()
|
||||
flash(_("Login failed, please check your information"), "error")
|
||||
return render_template(
|
||||
"password.html", form=form, username=session["attempt_login"]
|
||||
|
@ -156,7 +155,7 @@ def logout():
|
|||
|
||||
@bp.route("/firstlogin/<user_name>", methods=("GET", "POST"))
|
||||
def firstlogin(user_name):
|
||||
user = User.get_from_login(user_name)
|
||||
user = models.User.get_from_login(user_name)
|
||||
if not user or user.has_password():
|
||||
abort(404)
|
||||
|
||||
|
@ -182,7 +181,9 @@ def firstlogin(user_name):
|
|||
@bp.route("/users", methods=["GET", "POST"])
|
||||
@permissions_needed("manage_users")
|
||||
def users(user):
|
||||
table_form = TableForm(User, fields=user.read | user.write, formdata=request.form)
|
||||
table_form = TableForm(
|
||||
models.User, fields=user.read | user.write, formdata=request.form
|
||||
)
|
||||
if request.form and not table_form.validate():
|
||||
abort(404)
|
||||
|
||||
|
@ -278,7 +279,7 @@ def registration(data, hash):
|
|||
)
|
||||
return redirect(url_for("account.index"))
|
||||
|
||||
if User.get_from_login(invitation.user_name):
|
||||
if models.User.get_from_login(invitation.user_name):
|
||||
flash(
|
||||
_("Your account has already been created."),
|
||||
"error",
|
||||
|
@ -311,7 +312,7 @@ def registration(data, hash):
|
|||
if "groups" not in form and invitation.groups:
|
||||
form["groups"] = wtforms.SelectMultipleField(
|
||||
_("Groups"),
|
||||
choices=[(group.id, group.display_name) for group in Group.query()],
|
||||
choices=[(group.id, group.display_name) for group in models.Group.query()],
|
||||
render_kw={"readonly": "true"},
|
||||
)
|
||||
form.process(CombinedMultiDict((request.files, request.form)) or None, data=data)
|
||||
|
@ -381,7 +382,7 @@ def profile_creation(user):
|
|||
|
||||
|
||||
def profile_create(current_app, form):
|
||||
user = User()
|
||||
user = models.User()
|
||||
for attribute in form:
|
||||
if attribute.name in user.attributes:
|
||||
if isinstance(attribute.data, FileStorage):
|
||||
|
@ -420,7 +421,7 @@ def profile_edition(user, username):
|
|||
menuitem = "profile" if username == editor.user_name[0] else "users"
|
||||
fields = editor.read | editor.write
|
||||
if username != editor.user_name[0]:
|
||||
user = User.get_from_login(username)
|
||||
user = models.User.get_from_login(username)
|
||||
else:
|
||||
user = editor
|
||||
|
||||
|
@ -505,7 +506,7 @@ def profile_settings(user, username):
|
|||
):
|
||||
abort(403)
|
||||
|
||||
edited_user = User.get_from_login(username)
|
||||
edited_user = models.User.get_from_login(username)
|
||||
if not edited_user:
|
||||
abort(404)
|
||||
|
||||
|
@ -622,7 +623,7 @@ def profile_delete(user, edited_user):
|
|||
@bp.route("/impersonate/<username>")
|
||||
@permissions_needed("impersonate_users")
|
||||
def impersonate(user, username):
|
||||
puppet = User.get_from_login(username)
|
||||
puppet = models.User.get_from_login(username)
|
||||
if not puppet:
|
||||
abort(404)
|
||||
|
||||
|
@ -649,7 +650,7 @@ def forgotten():
|
|||
flash(_("Could not send the password reset link."), "error")
|
||||
return render_template("forgotten-password.html", form=form)
|
||||
|
||||
user = User.get_from_login(form.login.data)
|
||||
user = models.User.get_from_login(form.login.data)
|
||||
success_message = _(
|
||||
"A password reset link has been sent at your email address. You should receive it within a few minutes."
|
||||
)
|
||||
|
@ -689,7 +690,7 @@ def reset(user_name, hash):
|
|||
abort(404)
|
||||
|
||||
form = PasswordResetForm(request.form)
|
||||
user = User.get_from_login(user_name)
|
||||
user = models.User.get_from_login(user_name)
|
||||
|
||||
if not user or hash != profile_hash(
|
||||
user.user_name[0],
|
||||
|
@ -719,7 +720,7 @@ def photo(user_name, field):
|
|||
if field.lower() != "photo":
|
||||
abort(404)
|
||||
|
||||
user = User.get_from_login(user_name)
|
||||
user = models.User.get_from_login(user_name)
|
||||
if not user:
|
||||
abort(404)
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import wtforms.form
|
||||
from canaille.app import models
|
||||
from canaille.app.forms import HTMXBaseForm
|
||||
from canaille.app.forms import HTMXForm
|
||||
from canaille.app.forms import is_uri
|
||||
|
@ -9,12 +10,9 @@ from flask_babel import lazy_gettext as _
|
|||
from flask_wtf.file import FileAllowed
|
||||
from flask_wtf.file import FileField
|
||||
|
||||
from .models import Group
|
||||
from .models import User
|
||||
|
||||
|
||||
def unique_login(form, field):
|
||||
if User.get_from_login(field.data) and (
|
||||
if models.User.get_from_login(field.data) and (
|
||||
not getattr(form, "user", None) or form.user.user_name[0] != field.data
|
||||
):
|
||||
raise wtforms.ValidationError(
|
||||
|
@ -23,7 +21,7 @@ def unique_login(form, field):
|
|||
|
||||
|
||||
def unique_email(form, field):
|
||||
if User.get(email=field.data) and (
|
||||
if models.User.get(email=field.data) and (
|
||||
not getattr(form, "user", None) or form.user.email[0] != field.data
|
||||
):
|
||||
raise wtforms.ValidationError(
|
||||
|
@ -32,7 +30,7 @@ def unique_email(form, field):
|
|||
|
||||
|
||||
def unique_group(form, field):
|
||||
if Group.get(display_name=field.data):
|
||||
if models.Group.get(display_name=field.data):
|
||||
raise wtforms.ValidationError(
|
||||
_("The group '{group}' already exists").format(group=field.data)
|
||||
)
|
||||
|
@ -41,7 +39,7 @@ def unique_group(form, field):
|
|||
def existing_login(form, field):
|
||||
if not current_app.config.get(
|
||||
"HIDE_INVALID_LOGINS", True
|
||||
) and not User.get_from_login(field.data):
|
||||
) and not models.User.get_from_login(field.data):
|
||||
raise wtforms.ValidationError(
|
||||
_("The login '{login}' does not exist").format(login=field.data)
|
||||
)
|
||||
|
@ -257,7 +255,9 @@ PROFILE_FORM_FIELDS = dict(
|
|||
),
|
||||
groups=wtforms.SelectMultipleField(
|
||||
_("Groups"),
|
||||
choices=lambda: [(group.id, group.display_name) for group in Group.query()],
|
||||
choices=lambda: [
|
||||
(group.id, group.display_name) for group in models.Group.query()
|
||||
],
|
||||
render_kw={"placeholder": _("users, admins …")},
|
||||
),
|
||||
)
|
||||
|
@ -276,7 +276,7 @@ def profile_form(write_field_names, readonly_field_names, user=None):
|
|||
if PROFILE_FORM_FIELDS.get(name)
|
||||
}
|
||||
|
||||
if "groups" in fields and not Group.query():
|
||||
if "groups" in fields and not models.Group.query():
|
||||
del fields["groups"]
|
||||
|
||||
form = HTMXBaseForm(fields)
|
||||
|
@ -338,6 +338,8 @@ class InvitationForm(HTMXForm):
|
|||
)
|
||||
groups = wtforms.SelectMultipleField(
|
||||
_("Groups"),
|
||||
choices=lambda: [(group.id, group.display_name) for group in Group.query()],
|
||||
choices=lambda: [
|
||||
(group.id, group.display_name) for group in models.Group.query()
|
||||
],
|
||||
render_kw={},
|
||||
)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from canaille.app import models
|
||||
from canaille.app.flask import permissions_needed
|
||||
from canaille.app.flask import render_htmx_template
|
||||
from canaille.app.forms import TableForm
|
||||
|
@ -12,8 +13,6 @@ from flask_themer import render_template
|
|||
|
||||
from .forms import CreateGroupForm
|
||||
from .forms import EditGroupForm
|
||||
from .models import Group
|
||||
from .models import User
|
||||
|
||||
bp = Blueprint("groups", __name__, url_prefix="/groups")
|
||||
|
||||
|
@ -21,7 +20,7 @@ bp = Blueprint("groups", __name__, url_prefix="/groups")
|
|||
@bp.route("/", methods=["GET", "POST"])
|
||||
@permissions_needed("manage_groups")
|
||||
def groups(user):
|
||||
table_form = TableForm(Group, formdata=request.form)
|
||||
table_form = TableForm(models.Group, formdata=request.form)
|
||||
if request.form and request.form.get("page") and not table_form.validate():
|
||||
abort(404)
|
||||
|
||||
|
@ -37,7 +36,7 @@ def create_group(user):
|
|||
if not form.validate():
|
||||
flash(_("Group creation failed."), "error")
|
||||
else:
|
||||
group = Group()
|
||||
group = models.Group()
|
||||
group.members = [user]
|
||||
group.display_name = [form.display_name.data]
|
||||
group.description = [form.description.data]
|
||||
|
@ -59,7 +58,7 @@ def create_group(user):
|
|||
@bp.route("/<groupname>", methods=("GET", "POST"))
|
||||
@permissions_needed("manage_groups")
|
||||
def group(user, groupname):
|
||||
group = Group.get(display_name=groupname)
|
||||
group = models.Group.get(display_name=groupname)
|
||||
|
||||
if not group:
|
||||
abort(404)
|
||||
|
@ -78,7 +77,7 @@ def group(user, groupname):
|
|||
|
||||
|
||||
def edit_group(group):
|
||||
table_form = TableForm(User, filter={"groups": group}, formdata=request.form)
|
||||
table_form = TableForm(models.User, filter={"groups": group}, formdata=request.form)
|
||||
if request.form and request.form.get("page") and not table_form.validate():
|
||||
abort(404)
|
||||
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
import random
|
||||
|
||||
import faker
|
||||
from canaille.app import models
|
||||
from canaille.app.i18n import available_language_codes
|
||||
from canaille.core.models import Group
|
||||
from canaille.core.models import User
|
||||
from faker.config import AVAILABLE_LOCALES
|
||||
|
||||
|
||||
|
@ -19,7 +18,7 @@ def fake_users(nb=1):
|
|||
try:
|
||||
fake = random.choice(fakes)
|
||||
name = fake.unique.name()
|
||||
user = User(
|
||||
user = models.User(
|
||||
formatted_name=name,
|
||||
given_name=name.split(" ")[0],
|
||||
family_name=name.split(" ")[1],
|
||||
|
@ -47,11 +46,11 @@ def fake_users(nb=1):
|
|||
|
||||
def fake_groups(nb=1, nb_users_max=1):
|
||||
fake = faker_generator(["en_US"])[0]
|
||||
users = User.query()
|
||||
users = models.User.query()
|
||||
groups = list()
|
||||
for _ in range(nb):
|
||||
try:
|
||||
group = Group(
|
||||
group = models.Group(
|
||||
display_name=fake.unique.word(),
|
||||
description=fake.sentence(),
|
||||
)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from canaille.app import models
|
||||
from canaille.app.flask import permissions_needed
|
||||
from canaille.app.flask import render_htmx_template
|
||||
from canaille.app.forms import TableForm
|
||||
from canaille.oidc.models import AuthorizationCode
|
||||
from flask import abort
|
||||
from flask import Blueprint
|
||||
from flask import request
|
||||
|
@ -14,7 +14,7 @@ bp = Blueprint("authorizations", __name__, url_prefix="/admin/authorization")
|
|||
@bp.route("/", methods=["GET", "POST"])
|
||||
@permissions_needed("manage_oidc")
|
||||
def index(user):
|
||||
table_form = TableForm(AuthorizationCode, formdata=request.form)
|
||||
table_form = TableForm(models.AuthorizationCode, formdata=request.form)
|
||||
if request.form and request.form.get("page") and not table_form.validate():
|
||||
abort(404)
|
||||
|
||||
|
@ -28,7 +28,7 @@ def index(user):
|
|||
@bp.route("/<authorization_id>", methods=["GET", "POST"])
|
||||
@permissions_needed("manage_oidc")
|
||||
def view(user, authorization_id):
|
||||
authorization = AuthorizationCode.get(authorization_code_id=authorization_id)
|
||||
authorization = models.AuthorizationCode.get(authorization_code_id=authorization_id)
|
||||
return render_template(
|
||||
"oidc/admin/authorization_view.html",
|
||||
authorization=authorization,
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import datetime
|
||||
|
||||
from canaille.app import models
|
||||
from canaille.app.flask import permissions_needed
|
||||
from canaille.app.flask import render_htmx_template
|
||||
from canaille.app.flask import request_is_htmx
|
||||
from canaille.app.forms import TableForm
|
||||
from canaille.oidc.forms import ClientAddForm
|
||||
from canaille.oidc.models import Client
|
||||
from flask import abort
|
||||
from flask import Blueprint
|
||||
from flask import flash
|
||||
|
@ -23,7 +23,7 @@ bp = Blueprint("clients", __name__, url_prefix="/admin/client")
|
|||
@bp.route("/", methods=["GET", "POST"])
|
||||
@permissions_needed("manage_oidc")
|
||||
def index(user):
|
||||
table_form = TableForm(Client, formdata=request.form)
|
||||
table_form = TableForm(models.Client, formdata=request.form)
|
||||
if request.form and request.form.get("page") and not table_form.validate():
|
||||
abort(404)
|
||||
|
||||
|
@ -53,7 +53,7 @@ def add(user):
|
|||
|
||||
client_id = gen_salt(24)
|
||||
client_id_issued_at = datetime.datetime.now(datetime.timezone.utc)
|
||||
client = Client(
|
||||
client = models.Client(
|
||||
client_id=client_id,
|
||||
client_id_issued_at=client_id_issued_at,
|
||||
client_name=form["client_name"].data,
|
||||
|
@ -104,7 +104,7 @@ def edit(user, client_id):
|
|||
|
||||
|
||||
def client_edit(client_id):
|
||||
client = Client.get(client_id=client_id)
|
||||
client = models.Client.get(client_id=client_id)
|
||||
|
||||
if not client:
|
||||
abort(404)
|
||||
|
@ -152,7 +152,7 @@ def client_edit(client_id):
|
|||
software_version=form["software_version"].data,
|
||||
jwk=form["jwk"].data,
|
||||
jwks_uri=form["jwks_uri"].data,
|
||||
audience=[Client.get(id=id) for id in form["audience"].data],
|
||||
audience=[models.Client.get(id=id) for id in form["audience"].data],
|
||||
preconsent=form["preconsent"].data,
|
||||
)
|
||||
client.save()
|
||||
|
@ -164,7 +164,7 @@ def client_edit(client_id):
|
|||
|
||||
|
||||
def client_delete(client_id):
|
||||
client = Client.get(client_id=client_id)
|
||||
client = models.Client.get(client_id=client_id)
|
||||
|
||||
if not client:
|
||||
abort(404)
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import click
|
||||
from canaille.app import models
|
||||
from canaille.app.commands import with_backendcontext
|
||||
from canaille.oidc.models import AuthorizationCode
|
||||
from canaille.oidc.models import Token
|
||||
from flask.cli import with_appcontext
|
||||
|
||||
|
||||
|
@ -12,11 +11,11 @@ def clean():
|
|||
"""
|
||||
Remove expired tokens and authorization codes.
|
||||
"""
|
||||
for t in Token.query():
|
||||
for t in models.Token.query():
|
||||
if t.is_expired():
|
||||
t.delete()
|
||||
|
||||
for a in AuthorizationCode.query():
|
||||
for a in models.AuthorizationCode.query():
|
||||
if a.is_expired():
|
||||
a.delete()
|
||||
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
import datetime
|
||||
import uuid
|
||||
|
||||
from canaille.app import models
|
||||
from canaille.app.flask import user_needed
|
||||
from canaille.oidc.models import Client
|
||||
from canaille.oidc.models import Consent
|
||||
from flask import Blueprint
|
||||
from flask import flash
|
||||
from flask import redirect
|
||||
|
@ -20,12 +19,14 @@ bp = Blueprint("consents", __name__, url_prefix="/consent")
|
|||
@bp.route("/")
|
||||
@user_needed()
|
||||
def consents(user):
|
||||
consents = Consent.query(subject=user)
|
||||
consents = models.Consent.query(subject=user)
|
||||
clients = {t.client for t in consents}
|
||||
|
||||
nb_consents = len(consents)
|
||||
nb_preconsents = sum(
|
||||
1 for client in Client.query() if client.preconsent and client not in clients
|
||||
1
|
||||
for client in models.Client.query()
|
||||
if client.preconsent and client not in clients
|
||||
)
|
||||
|
||||
return render_template(
|
||||
|
@ -42,11 +43,11 @@ def consents(user):
|
|||
@bp.route("/pre-consents")
|
||||
@user_needed()
|
||||
def pre_consents(user):
|
||||
consents = Consent.query(subject=user)
|
||||
consents = models.Consent.query(subject=user)
|
||||
clients = {t.client for t in consents}
|
||||
preconsented = [
|
||||
client
|
||||
for client in Client.query()
|
||||
for client in models.Client.query()
|
||||
if client.preconsent and client not in clients
|
||||
]
|
||||
|
||||
|
@ -67,7 +68,7 @@ def pre_consents(user):
|
|||
@bp.route("/revoke/<consent_id>")
|
||||
@user_needed()
|
||||
def revoke(user, consent_id):
|
||||
consent = Consent.get(consent_id=consent_id)
|
||||
consent = models.Consent.get(consent_id=consent_id)
|
||||
|
||||
if not consent or consent.subject != user:
|
||||
flash(_("Could not revoke this access"), "error")
|
||||
|
@ -85,7 +86,7 @@ def revoke(user, consent_id):
|
|||
@bp.route("/restore/<consent_id>")
|
||||
@user_needed()
|
||||
def restore(user, consent_id):
|
||||
consent = Consent.get(consent_id=consent_id)
|
||||
consent = models.Consent.get(consent_id=consent_id)
|
||||
|
||||
if not consent or consent.subject != user:
|
||||
flash(_("Could not restore this access"), "error")
|
||||
|
@ -106,19 +107,19 @@ def restore(user, consent_id):
|
|||
@bp.route("/revoke-preconsent/<client_id>")
|
||||
@user_needed()
|
||||
def revoke_preconsent(user, client_id):
|
||||
client = Client.get(client_id=client_id)
|
||||
client = models.Client.get(client_id=client_id)
|
||||
|
||||
if not client or not client.preconsent:
|
||||
flash(_("Could not revoke this access"), "error")
|
||||
return redirect(url_for("oidc.consents.consents"))
|
||||
|
||||
consent = Consent.get(client=client, subject=user)
|
||||
consent = models.Consent.get(client=client, subject=user)
|
||||
if consent:
|
||||
return redirect(
|
||||
url_for("oidc.consents.revoke", consent_id=consent.consent_id[0])
|
||||
)
|
||||
|
||||
consent = Consent(
|
||||
consent = models.Consent(
|
||||
consent_id=str(uuid.uuid4()),
|
||||
client=client,
|
||||
subject=user,
|
||||
|
|
|
@ -6,10 +6,10 @@ from authlib.jose import JsonWebKey
|
|||
from authlib.jose import jwt
|
||||
from authlib.oauth2 import OAuth2Error
|
||||
from canaille import csrf
|
||||
from canaille.app import models
|
||||
from canaille.app.flask import current_user
|
||||
from canaille.app.flask import set_parameter_in_url_query
|
||||
from canaille.core.forms import FullLoginForm
|
||||
from canaille.core.models import User
|
||||
from flask import abort
|
||||
from flask import Blueprint
|
||||
from flask import current_app
|
||||
|
@ -25,8 +25,6 @@ from werkzeug.datastructures import CombinedMultiDict
|
|||
|
||||
from .forms import AuthorizeForm
|
||||
from .forms import LogoutForm
|
||||
from .models import Client
|
||||
from .models import Consent
|
||||
from .oauth import authorization
|
||||
from .oauth import ClientConfigurationEndpoint
|
||||
from .oauth import ClientRegistrationEndpoint
|
||||
|
@ -59,7 +57,7 @@ def authorize():
|
|||
if "client_id" not in request.args:
|
||||
abort(400)
|
||||
|
||||
client = Client.get(request.args["client_id"])
|
||||
client = models.Client.get(client_id=request.args["client_id"])
|
||||
if not client:
|
||||
abort(400)
|
||||
|
||||
|
@ -78,7 +76,7 @@ def authorize():
|
|||
if request.method == "GET":
|
||||
return render_template("login.html", form=form, menu=False)
|
||||
|
||||
user = User.get_from_login(form.login.data)
|
||||
user = models.User.get_from_login(form.login.data)
|
||||
if (
|
||||
not form.validate()
|
||||
or not user
|
||||
|
@ -96,7 +94,7 @@ def authorize():
|
|||
|
||||
# CONSENT
|
||||
|
||||
consents = Consent.query(
|
||||
consents = models.Consent.query(
|
||||
client=client,
|
||||
subject=user,
|
||||
)
|
||||
|
@ -153,7 +151,7 @@ def authorize():
|
|||
list(set(scopes + consents[0].scope))
|
||||
).split(" ")
|
||||
else:
|
||||
consent = Consent(
|
||||
consent = models.Consent(
|
||||
consent_id=str(uuid.uuid4()),
|
||||
client=client,
|
||||
subject=user,
|
||||
|
@ -275,7 +273,7 @@ def end_session():
|
|||
valid_uris = []
|
||||
|
||||
if "client_id" in data:
|
||||
client = Client.get(data["client_id"])
|
||||
client = models.Client.get(client_id=data["client_id"])
|
||||
if client:
|
||||
valid_uris = client.post_logout_redirect_uris
|
||||
|
||||
|
@ -317,7 +315,7 @@ def end_session():
|
|||
else [id_token["aud"]]
|
||||
)
|
||||
for client_id in client_ids:
|
||||
client = Client.get(client_id)
|
||||
client = models.Client.get(client_id=client_id)
|
||||
if client:
|
||||
valid_uris.extend(client.post_logout_redirect_uris or [])
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import wtforms
|
||||
from canaille.app import models
|
||||
from canaille.app.forms import HTMXForm
|
||||
from canaille.app.forms import is_uri
|
||||
from canaille.oidc.models import Client
|
||||
from flask_babel import lazy_gettext as _
|
||||
|
||||
|
||||
|
@ -14,7 +14,7 @@ class LogoutForm(HTMXForm):
|
|||
|
||||
|
||||
def client_audiences():
|
||||
return [(client.id, client.client_name) for client in Client.query()]
|
||||
return [(client.id, client.client_name) for client in models.Client.query()]
|
||||
|
||||
|
||||
class ClientAddForm(HTMXForm):
|
||||
|
|
|
@ -1,11 +1,8 @@
|
|||
import os
|
||||
|
||||
from canaille.app import models
|
||||
from canaille.backends.ldap.installation import install_schema
|
||||
from canaille.backends.ldap.installation import ldap_connection
|
||||
from canaille.oidc.models import AuthorizationCode
|
||||
from canaille.oidc.models import Client
|
||||
from canaille.oidc.models import Consent
|
||||
from canaille.oidc.models import Token
|
||||
from cryptography.hazmat.backends import default_backend as crypto_default_backend
|
||||
from cryptography.hazmat.primitives import serialization as crypto_serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
|
@ -19,10 +16,10 @@ def install(config):
|
|||
|
||||
def setup_ldap_tree(config):
|
||||
with ldap_connection(config) as conn:
|
||||
Token.install(conn)
|
||||
AuthorizationCode.install(conn)
|
||||
Client.install(conn)
|
||||
Consent.install(conn)
|
||||
models.Token.install(conn)
|
||||
models.AuthorizationCode.install(conn)
|
||||
models.Client.install(conn)
|
||||
models.Consent.install(conn)
|
||||
|
||||
|
||||
def setup_keypair(config):
|
||||
|
|
|
@ -4,6 +4,7 @@ from authlib.oauth2.rfc6749 import AuthorizationCodeMixin
|
|||
from authlib.oauth2.rfc6749 import ClientMixin
|
||||
from authlib.oauth2.rfc6749 import TokenMixin
|
||||
from authlib.oauth2.rfc6749 import util
|
||||
from canaille.app import models
|
||||
from canaille.backends.ldap.ldapobject import LDAPObject
|
||||
|
||||
|
||||
|
@ -97,13 +98,13 @@ class Client(LDAPObject, ClientMixin):
|
|||
return metadata
|
||||
|
||||
def delete(self):
|
||||
for consent in Consent.query(client=self):
|
||||
for consent in models.Consent.query(client=self):
|
||||
consent.delete()
|
||||
|
||||
for code in AuthorizationCode.query(client=self):
|
||||
for code in models.AuthorizationCode.query(client=self):
|
||||
code.delete()
|
||||
|
||||
for token in Token.query(client=self):
|
||||
for token in models.Token.query(client=self):
|
||||
token.delete()
|
||||
|
||||
super().delete()
|
||||
|
@ -243,7 +244,7 @@ class Consent(LDAPObject):
|
|||
self.revokation_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
self.save()
|
||||
|
||||
tokens = Token.query(
|
||||
tokens = models.Token.query(
|
||||
client=self.client,
|
||||
subject=self.subject,
|
||||
)
|
||||
|
|
|
@ -26,15 +26,11 @@ from authlib.oidc.core.grants import OpenIDCode as _OpenIDCode
|
|||
from authlib.oidc.core.grants import OpenIDHybridGrant as _OpenIDHybridGrant
|
||||
from authlib.oidc.core.grants import OpenIDImplicitGrant as _OpenIDImplicitGrant
|
||||
from authlib.oidc.core.grants.util import generate_id_token
|
||||
from canaille.core.models import User
|
||||
from canaille.app import models
|
||||
from flask import current_app
|
||||
from flask import request
|
||||
from werkzeug.security import gen_salt
|
||||
|
||||
from .models import AuthorizationCode
|
||||
from .models import Client
|
||||
from .models import Token
|
||||
|
||||
DEFAULT_JWT_KTY = "RSA"
|
||||
DEFAULT_JWT_ALG = "RS256"
|
||||
DEFAULT_JWT_EXP = 3600
|
||||
|
@ -55,8 +51,8 @@ DEFAULT_JWT_MAPPING = {
|
|||
|
||||
|
||||
def exists_nonce(nonce, req):
|
||||
client = Client.get(id=req.client_id)
|
||||
exists = AuthorizationCode.query(client=client, nonce=nonce)
|
||||
client = models.Client.get(id=req.client_id)
|
||||
exists = models.AuthorizationCode.query(client=client, nonce=nonce)
|
||||
return bool(exists)
|
||||
|
||||
|
||||
|
@ -142,7 +138,7 @@ def save_authorization_code(code, request):
|
|||
nonce = request.data.get("nonce")
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
scope = request.client.get_allowed_scope(request.scope)
|
||||
code = AuthorizationCode(
|
||||
code = models.AuthorizationCode(
|
||||
authorization_code_id=gen_salt(48),
|
||||
code=code,
|
||||
subject=request.user,
|
||||
|
@ -166,7 +162,7 @@ class AuthorizationCodeGrant(_AuthorizationCodeGrant):
|
|||
return save_authorization_code(code, request)
|
||||
|
||||
def query_authorization_code(self, code, client):
|
||||
item = AuthorizationCode.query(code=code, client=client)
|
||||
item = models.AuthorizationCode.query(code=code, client=client)
|
||||
if item and not item[0].is_expired():
|
||||
return item[0]
|
||||
|
||||
|
@ -196,7 +192,7 @@ class PasswordGrant(_ResourceOwnerPasswordCredentialsGrant):
|
|||
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]
|
||||
|
||||
def authenticate_user(self, username, password):
|
||||
user = User.get_from_login(username)
|
||||
user = models.User.get_from_login(username)
|
||||
|
||||
if not user or not user.check_password(password):
|
||||
return None
|
||||
|
@ -206,7 +202,7 @@ class PasswordGrant(_ResourceOwnerPasswordCredentialsGrant):
|
|||
|
||||
class RefreshTokenGrant(_RefreshTokenGrant):
|
||||
def authenticate_refresh_token(self, refresh_token):
|
||||
token = Token.query(refresh_token=refresh_token)
|
||||
token = models.Token.query(refresh_token=refresh_token)
|
||||
if token and token[0].is_refresh_token_active():
|
||||
return token[0]
|
||||
|
||||
|
@ -252,12 +248,12 @@ class OpenIDHybridGrant(_OpenIDHybridGrant):
|
|||
|
||||
|
||||
def query_client(client_id):
|
||||
return Client.get(client_id=client_id)
|
||||
return models.Client.get(client_id=client_id)
|
||||
|
||||
|
||||
def save_token(token, request):
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
t = Token(
|
||||
t = models.Token(
|
||||
token_id=gen_salt(48),
|
||||
type=token["token_type"],
|
||||
access_token=token["access_token"],
|
||||
|
@ -274,20 +270,20 @@ def save_token(token, request):
|
|||
|
||||
class BearerTokenValidator(_BearerTokenValidator):
|
||||
def authenticate_token(self, token_string):
|
||||
return Token.get(access_token=token_string)
|
||||
return models.Token.get(access_token=token_string)
|
||||
|
||||
|
||||
def query_token(token, token_type_hint):
|
||||
if token_type_hint == "access_token":
|
||||
return Token.get(access_token=token)
|
||||
return models.Token.get(access_token=token)
|
||||
elif token_type_hint == "refresh_token":
|
||||
return Token.get(refresh_token=token)
|
||||
return models.Token.get(refresh_token=token)
|
||||
|
||||
item = Token.get(access_token=token)
|
||||
item = models.Token.get(access_token=token)
|
||||
if item:
|
||||
return item
|
||||
|
||||
item = Token.get(refresh_token=token)
|
||||
item = models.Token.get(refresh_token=token)
|
||||
if item:
|
||||
return item
|
||||
|
||||
|
@ -369,7 +365,7 @@ class ClientRegistrationEndpoint(ClientManagementMixin, _ClientRegistrationEndpo
|
|||
client_metadata["scope"], list
|
||||
):
|
||||
client_metadata["scope"] = client_metadata["scope"].split(" ")
|
||||
client = Client(**client_info, **client_metadata)
|
||||
client = models.Client(**client_info, **client_metadata)
|
||||
client.save()
|
||||
return client
|
||||
|
||||
|
@ -377,7 +373,7 @@ class ClientRegistrationEndpoint(ClientManagementMixin, _ClientRegistrationEndpo
|
|||
class ClientConfigurationEndpoint(ClientManagementMixin, _ClientConfigurationEndpoint):
|
||||
def authenticate_client(self, request):
|
||||
client_id = request.uri.split("/")[-1]
|
||||
return Client.get(client_id=client_id)
|
||||
return models.Client.get(client_id=client_id)
|
||||
|
||||
def revoke_access_token(self, request, token):
|
||||
pass
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import datetime
|
||||
|
||||
from canaille.app import models
|
||||
from canaille.app.flask import permissions_needed
|
||||
from canaille.app.flask import render_htmx_template
|
||||
from canaille.app.forms import TableForm
|
||||
from canaille.oidc.models import Token
|
||||
from flask import abort
|
||||
from flask import Blueprint
|
||||
from flask import flash
|
||||
|
@ -19,7 +19,7 @@ bp = Blueprint("tokens", __name__, url_prefix="/admin/token")
|
|||
@bp.route("/", methods=["GET", "POST"])
|
||||
@permissions_needed("manage_oidc")
|
||||
def index(user):
|
||||
table_form = TableForm(Token, formdata=request.form)
|
||||
table_form = TableForm(models.Token, formdata=request.form)
|
||||
if request.form and request.form.get("page") and not table_form.validate():
|
||||
abort(404)
|
||||
|
||||
|
@ -31,7 +31,7 @@ def index(user):
|
|||
@bp.route("/<token_id>", methods=["GET", "POST"])
|
||||
@permissions_needed("manage_oidc")
|
||||
def view(user, token_id):
|
||||
token = Token.get(token_id=token_id)
|
||||
token = models.Token.get(token_id=token_id)
|
||||
|
||||
if not token:
|
||||
abort(404)
|
||||
|
@ -46,7 +46,7 @@ def view(user, token_id):
|
|||
@bp.route("/<token_id>/revoke", methods=["GET", "POST"])
|
||||
@permissions_needed("manage_oidc")
|
||||
def revoke(user, token_id):
|
||||
token = Token.get(token_id=token_id)
|
||||
token = models.Token.get(token_id=token_id)
|
||||
|
||||
if not token:
|
||||
abort(404)
|
||||
|
|
|
@ -10,15 +10,13 @@ from canaille import create_app as canaille_app
|
|||
|
||||
|
||||
def populate(app):
|
||||
from canaille.core.models import Group
|
||||
from canaille.core.models import User
|
||||
from canaille.app import models
|
||||
from canaille.core.populate import fake_groups
|
||||
from canaille.core.populate import fake_users
|
||||
from canaille.oidc.models import Client
|
||||
|
||||
with app.app_context():
|
||||
with app.backend.session():
|
||||
jane = User(
|
||||
jane = models.User(
|
||||
formatted_name="Jane Doe",
|
||||
given_name="Jane",
|
||||
family_name="Doe",
|
||||
|
@ -38,7 +36,7 @@ def populate(app):
|
|||
)
|
||||
jane.save()
|
||||
|
||||
jack = User(
|
||||
jack = models.User(
|
||||
formatted_name="Jack Doe",
|
||||
given_name="Jack",
|
||||
family_name="Doe",
|
||||
|
@ -53,7 +51,7 @@ def populate(app):
|
|||
)
|
||||
jack.save()
|
||||
|
||||
john = User(
|
||||
john = models.User(
|
||||
formatted_name="John Doe",
|
||||
given_name="John",
|
||||
family_name="Doe",
|
||||
|
@ -68,7 +66,7 @@ def populate(app):
|
|||
)
|
||||
john.save()
|
||||
|
||||
james = User(
|
||||
james = models.User(
|
||||
formatted_name="James Doe",
|
||||
given_name="James",
|
||||
family_name="Doe",
|
||||
|
@ -77,28 +75,28 @@ def populate(app):
|
|||
)
|
||||
james.save()
|
||||
|
||||
users = Group(
|
||||
users = models.Group(
|
||||
display_name="users",
|
||||
members=[jane, jack, john, james],
|
||||
description="The regular users.",
|
||||
)
|
||||
users.save()
|
||||
|
||||
users = Group(
|
||||
users = models.Group(
|
||||
display_name="admins",
|
||||
members=[jane],
|
||||
description="The administrators.",
|
||||
)
|
||||
users.save()
|
||||
|
||||
users = Group(
|
||||
users = models.Group(
|
||||
display_name="moderators",
|
||||
members=[james],
|
||||
description="People who can manage users.",
|
||||
)
|
||||
users.save()
|
||||
|
||||
client1 = Client(
|
||||
client1 = models.Client(
|
||||
client_id="1JGkkzCbeHpGtlqgI5EENByf",
|
||||
client_secret="2xYPSReTQRmGG1yppMVZQ0ASXwFejPyirvuPbKhNa6TmKC5x",
|
||||
client_name="Client1",
|
||||
|
@ -115,7 +113,7 @@ def populate(app):
|
|||
)
|
||||
client1.save()
|
||||
|
||||
client2 = Client(
|
||||
client2 = models.Client(
|
||||
client_id="gn4yFN7GDykL7QP8v8gS9YfV",
|
||||
client_secret="ouFJE5WpICt6hxTyf8icXPeeklMektMY4gV0Rmf3aY60VElA",
|
||||
client_name="Client2",
|
||||
|
|
|
@ -3,6 +3,7 @@ from unittest import mock
|
|||
|
||||
import ldap.dn
|
||||
import pytest
|
||||
from canaille.app import models
|
||||
from canaille.app.configuration import ConfigurationException
|
||||
from canaille.app.configuration import validate
|
||||
from canaille.backends.ldap.backend import setup_ldap_models
|
||||
|
@ -11,12 +12,10 @@ from canaille.backends.ldap.ldapobject import python_attrs_to_ldap
|
|||
from canaille.backends.ldap.utils import ldap_to_python
|
||||
from canaille.backends.ldap.utils import python_to_ldap
|
||||
from canaille.backends.ldap.utils import Syntax
|
||||
from canaille.core.models import Group
|
||||
from canaille.core.models import User
|
||||
|
||||
|
||||
def test_object_creation(app, backend):
|
||||
user = User(
|
||||
user = models.User(
|
||||
formatted_name="Doe", # leading space
|
||||
family_name="Doe",
|
||||
user_name="user",
|
||||
|
@ -26,7 +25,7 @@ def test_object_creation(app, backend):
|
|||
user.save()
|
||||
assert user.exists
|
||||
|
||||
user = User.get(id=user.id)
|
||||
user = models.User.get(id=user.id)
|
||||
assert user.exists
|
||||
|
||||
user.delete()
|
||||
|
@ -38,7 +37,7 @@ def test_repr(backend, foo_group, user):
|
|||
|
||||
|
||||
def test_dn_when_leading_space_in_id_attribute(backend):
|
||||
user = User(
|
||||
user = models.User(
|
||||
formatted_name=" Doe", # leading space
|
||||
family_name="Doe",
|
||||
user_name="user",
|
||||
|
@ -54,7 +53,7 @@ def test_dn_when_leading_space_in_id_attribute(backend):
|
|||
|
||||
|
||||
def test_dn_when_ldap_special_char_in_id_attribute(backend):
|
||||
user = User(
|
||||
user = models.User(
|
||||
formatted_name="#Doe", # special char
|
||||
family_name="Doe",
|
||||
user_name="user",
|
||||
|
@ -70,27 +69,32 @@ def test_dn_when_ldap_special_char_in_id_attribute(backend):
|
|||
|
||||
|
||||
def test_filter(backend, foo_group, bar_group):
|
||||
assert Group.query(display_name="foo") == [foo_group]
|
||||
assert Group.query(display_name="bar") == [bar_group]
|
||||
assert models.Group.query(display_name="foo") == [foo_group]
|
||||
assert models.Group.query(display_name="bar") == [bar_group]
|
||||
|
||||
assert Group.query(display_name="foo") != 3
|
||||
assert models.Group.query(display_name="foo") != 3
|
||||
|
||||
assert Group.query(display_name=["foo"]) == [foo_group]
|
||||
assert Group.query(display_name=["bar"]) == [bar_group]
|
||||
assert models.Group.query(display_name=["foo"]) == [foo_group]
|
||||
assert models.Group.query(display_name=["bar"]) == [bar_group]
|
||||
|
||||
assert set(Group.query(display_name=["foo", "bar"])) == {foo_group, bar_group}
|
||||
assert set(models.Group.query(display_name=["foo", "bar"])) == {
|
||||
foo_group,
|
||||
bar_group,
|
||||
}
|
||||
|
||||
|
||||
def test_fuzzy(backend, user, moderator, admin):
|
||||
assert set(User.query()) == {user, moderator, admin}
|
||||
assert set(User.fuzzy("Jack")) == {moderator}
|
||||
assert set(User.fuzzy("Jack", ["formatted_name"])) == {moderator}
|
||||
assert set(User.fuzzy("Jack", ["user_name"])) == set()
|
||||
assert set(User.fuzzy("Jack", ["user_name", "formatted_name"])) == {moderator}
|
||||
assert set(User.fuzzy("moderator")) == {moderator}
|
||||
assert set(User.fuzzy("oderat")) == {moderator}
|
||||
assert set(User.fuzzy("oDeRat")) == {moderator}
|
||||
assert set(User.fuzzy("ack")) == {moderator}
|
||||
assert set(models.User.query()) == {user, moderator, admin}
|
||||
assert set(models.User.fuzzy("Jack")) == {moderator}
|
||||
assert set(models.User.fuzzy("Jack", ["formatted_name"])) == {moderator}
|
||||
assert set(models.User.fuzzy("Jack", ["user_name"])) == set()
|
||||
assert set(models.User.fuzzy("Jack", ["user_name", "formatted_name"])) == {
|
||||
moderator
|
||||
}
|
||||
assert set(models.User.fuzzy("moderator")) == {moderator}
|
||||
assert set(models.User.fuzzy("oderat")) == {moderator}
|
||||
assert set(models.User.fuzzy("oDeRat")) == {moderator}
|
||||
assert set(models.User.fuzzy("ack")) == {moderator}
|
||||
|
||||
|
||||
def test_ldap_to_python():
|
||||
|
@ -180,11 +184,11 @@ def test_guess_object_from_dn(backend, testclient, foo_group):
|
|||
foo_group.members = [foo_group]
|
||||
foo_group.save()
|
||||
g = LDAPObject.get(id=foo_group.dn)
|
||||
assert isinstance(g, Group)
|
||||
assert isinstance(g, models.Group)
|
||||
assert g == foo_group
|
||||
assert g.cn == foo_group.cn
|
||||
|
||||
ou = LDAPObject.get(id=f"{Group.base},{Group.root_dn}")
|
||||
ou = LDAPObject.get(id=f"{models.Group.base},{models.Group.root_dn}")
|
||||
assert isinstance(ou, LDAPObject)
|
||||
|
||||
|
||||
|
@ -192,11 +196,11 @@ def test_object_class_update(backend, testclient):
|
|||
testclient.app.config["BACKENDS"]["LDAP"]["USER_CLASS"] = ["inetOrgPerson"]
|
||||
setup_ldap_models(testclient.app.config)
|
||||
|
||||
user1 = User(cn="foo1", sn="bar1")
|
||||
user1 = models.User(cn="foo1", sn="bar1")
|
||||
user1.save()
|
||||
|
||||
assert user1.objectClass == ["inetOrgPerson"]
|
||||
assert User.get(id=user1.id).objectClass == ["inetOrgPerson"]
|
||||
assert models.User.get(id=user1.id).objectClass == ["inetOrgPerson"]
|
||||
|
||||
testclient.app.config["BACKENDS"]["LDAP"]["USER_CLASS"] = [
|
||||
"inetOrgPerson",
|
||||
|
@ -204,18 +208,24 @@ def test_object_class_update(backend, testclient):
|
|||
]
|
||||
setup_ldap_models(testclient.app.config)
|
||||
|
||||
user2 = User(cn="foo2", sn="bar2")
|
||||
user2 = models.User(cn="foo2", sn="bar2")
|
||||
user2.save()
|
||||
|
||||
assert user2.objectClass == ["inetOrgPerson", "extensibleObject"]
|
||||
assert User.get(id=user2.id).objectClass == ["inetOrgPerson", "extensibleObject"]
|
||||
assert models.User.get(id=user2.id).objectClass == [
|
||||
"inetOrgPerson",
|
||||
"extensibleObject",
|
||||
]
|
||||
|
||||
user1 = User.get(id=user1.id)
|
||||
user1 = models.User.get(id=user1.id)
|
||||
assert user1.objectClass == ["inetOrgPerson"]
|
||||
|
||||
user1.save()
|
||||
assert user1.objectClass == ["inetOrgPerson", "extensibleObject"]
|
||||
assert User.get(id=user1.id).objectClass == ["inetOrgPerson", "extensibleObject"]
|
||||
assert models.User.get(id=user1.id).objectClass == [
|
||||
"inetOrgPerson",
|
||||
"extensibleObject",
|
||||
]
|
||||
|
||||
user1.delete()
|
||||
user2.delete()
|
||||
|
|
|
@ -1,21 +1,20 @@
|
|||
from canaille.core.models import Group
|
||||
from canaille.core.models import User
|
||||
from canaille.app import models
|
||||
|
||||
|
||||
def test_model_comparison(testclient, backend):
|
||||
foo1 = User(
|
||||
foo1 = models.User(
|
||||
user_name="foo",
|
||||
family_name="foo",
|
||||
formatted_name="foo",
|
||||
)
|
||||
foo1.save()
|
||||
bar = User(
|
||||
bar = models.User(
|
||||
user_name="bar",
|
||||
family_name="bar",
|
||||
formatted_name="bar",
|
||||
)
|
||||
bar.save()
|
||||
foo2 = User.get(id=foo1.id)
|
||||
foo2 = models.User.get(id=foo1.id)
|
||||
|
||||
assert foo1 == foo2
|
||||
assert foo1 != bar
|
||||
|
@ -25,23 +24,23 @@ def test_model_comparison(testclient, backend):
|
|||
|
||||
|
||||
def test_model_lifecycle(testclient, backend):
|
||||
user = User(
|
||||
user = models.User(
|
||||
user_name="user_name",
|
||||
family_name="family_name",
|
||||
formatted_name="formatted_name",
|
||||
)
|
||||
|
||||
assert not User.query()
|
||||
assert not User.query(id=user.id)
|
||||
assert not User.query(id="invalid")
|
||||
assert not User.get(id=user.id)
|
||||
assert not models.User.query()
|
||||
assert not models.User.query(id=user.id)
|
||||
assert not models.User.query(id="invalid")
|
||||
assert not models.User.get(id=user.id)
|
||||
|
||||
user.save()
|
||||
|
||||
assert User.query() == [user]
|
||||
assert User.query(id=user.id) == [user]
|
||||
assert not User.query(id="invalid")
|
||||
assert User.get(id=user.id) == user
|
||||
assert models.User.query() == [user]
|
||||
assert models.User.query(id=user.id) == [user]
|
||||
assert not models.User.query(id="invalid")
|
||||
assert models.User.get(id=user.id) == user
|
||||
|
||||
user.family_name = "new_family_name"
|
||||
|
||||
|
@ -53,12 +52,12 @@ def test_model_lifecycle(testclient, backend):
|
|||
|
||||
user.delete()
|
||||
|
||||
assert not User.query(id=user.id)
|
||||
assert not User.get(id=user.id)
|
||||
assert not models.User.query(id=user.id)
|
||||
assert not models.User.get(id=user.id)
|
||||
|
||||
|
||||
def test_model_attribute_edition(testclient, backend):
|
||||
user = User(
|
||||
user = models.User(
|
||||
user_name="user_name",
|
||||
family_name="family_name",
|
||||
formatted_name="formatted_name",
|
||||
|
@ -71,7 +70,7 @@ def test_model_attribute_edition(testclient, backend):
|
|||
assert user.family_name == ["family_name"]
|
||||
assert user.email == ["email1@user.com", "email2@user.com"]
|
||||
|
||||
user = User.get(id=user.id)
|
||||
user = models.User.get(id=user.id)
|
||||
assert user.user_name == ["user_name"]
|
||||
assert user.family_name == ["family_name"]
|
||||
assert user.email == ["email1@user.com", "email2@user.com"]
|
||||
|
@ -83,7 +82,7 @@ def test_model_attribute_edition(testclient, backend):
|
|||
assert user.family_name == ["new_family_name"]
|
||||
assert user.email == ["email1@user.com"]
|
||||
|
||||
user = User.get(id=user.id)
|
||||
user = models.User.get(id=user.id)
|
||||
assert user.family_name == ["new_family_name"]
|
||||
assert user.email == ["email1@user.com"]
|
||||
|
||||
|
@ -96,7 +95,7 @@ def test_model_attribute_edition(testclient, backend):
|
|||
|
||||
|
||||
def test_model_indexation(testclient, backend):
|
||||
user = User(
|
||||
user = models.User(
|
||||
user_name="user_name",
|
||||
family_name="family_name",
|
||||
formatted_name="formatted_name",
|
||||
|
@ -104,56 +103,58 @@ def test_model_indexation(testclient, backend):
|
|||
)
|
||||
user.save()
|
||||
|
||||
assert User.get(family_name="family_name") == user
|
||||
assert not User.get(family_name="new_family_name")
|
||||
assert User.get(email="email1@user.com") == user
|
||||
assert User.get(email="email2@user.com") == user
|
||||
assert not User.get(email="email3@user.com")
|
||||
assert models.User.get(family_name="family_name") == user
|
||||
assert not models.User.get(family_name="new_family_name")
|
||||
assert models.User.get(email="email1@user.com") == user
|
||||
assert models.User.get(email="email2@user.com") == user
|
||||
assert not models.User.get(email="email3@user.com")
|
||||
|
||||
user.family_name = "new_family_name"
|
||||
user.email = ["email2@user.com"]
|
||||
|
||||
assert User.get(family_name="family_name") != user
|
||||
assert not User.get(family_name="new_family_name")
|
||||
assert User.get(email="email1@user.com") != user
|
||||
assert User.get(email="email2@user.com") != user
|
||||
assert not User.get(email="email3@user.com")
|
||||
assert models.User.get(family_name="family_name") != user
|
||||
assert not models.User.get(family_name="new_family_name")
|
||||
assert models.User.get(email="email1@user.com") != user
|
||||
assert models.User.get(email="email2@user.com") != user
|
||||
assert not models.User.get(email="email3@user.com")
|
||||
|
||||
user.save()
|
||||
|
||||
assert not User.get(family_name="family_name")
|
||||
assert User.get(family_name="new_family_name") == user
|
||||
assert not User.get(email="email1@user.com")
|
||||
assert User.get(email="email2@user.com") == user
|
||||
assert not User.get(email="email3@user.com")
|
||||
assert not models.User.get(family_name="family_name")
|
||||
assert models.User.get(family_name="new_family_name") == user
|
||||
assert not models.User.get(email="email1@user.com")
|
||||
assert models.User.get(email="email2@user.com") == user
|
||||
assert not models.User.get(email="email3@user.com")
|
||||
|
||||
user.delete()
|
||||
|
||||
assert not User.get(family_name="family_name")
|
||||
assert not User.get(family_name="new_family_name")
|
||||
assert not User.get(email="email1@user.com")
|
||||
assert not User.get(email="email2@user.com")
|
||||
assert not User.get(email="email3@user.com")
|
||||
assert not models.User.get(family_name="family_name")
|
||||
assert not models.User.get(family_name="new_family_name")
|
||||
assert not models.User.get(email="email1@user.com")
|
||||
assert not models.User.get(email="email2@user.com")
|
||||
assert not models.User.get(email="email3@user.com")
|
||||
|
||||
|
||||
def test_fuzzy(user, moderator, admin, backend):
|
||||
assert set(User.query()) == {user, moderator, admin}
|
||||
assert set(User.fuzzy("Jack")) == {moderator}
|
||||
assert set(User.fuzzy("Jack", ["formatted_name"])) == {moderator}
|
||||
assert set(User.fuzzy("Jack", ["user_name"])) == set()
|
||||
assert set(User.fuzzy("Jack", ["user_name", "formatted_name"])) == {moderator}
|
||||
assert set(User.fuzzy("moderator")) == {moderator}
|
||||
assert set(User.fuzzy("oderat")) == {moderator}
|
||||
assert set(User.fuzzy("oDeRat")) == {moderator}
|
||||
assert set(User.fuzzy("ack")) == {moderator}
|
||||
assert set(models.User.query()) == {user, moderator, admin}
|
||||
assert set(models.User.fuzzy("Jack")) == {moderator}
|
||||
assert set(models.User.fuzzy("Jack", ["formatted_name"])) == {moderator}
|
||||
assert set(models.User.fuzzy("Jack", ["user_name"])) == set()
|
||||
assert set(models.User.fuzzy("Jack", ["user_name", "formatted_name"])) == {
|
||||
moderator
|
||||
}
|
||||
assert set(models.User.fuzzy("moderator")) == {moderator}
|
||||
assert set(models.User.fuzzy("oderat")) == {moderator}
|
||||
assert set(models.User.fuzzy("oDeRat")) == {moderator}
|
||||
assert set(models.User.fuzzy("ack")) == {moderator}
|
||||
|
||||
|
||||
# def test_model_references(user, admin, foo_group, bar_group):
|
||||
def test_model_references(testclient, user, foo_group, admin, bar_group, backend):
|
||||
assert foo_group.members == [user]
|
||||
assert user.groups == [foo_group]
|
||||
assert foo_group in Group.query(members=user)
|
||||
assert user in User.query(groups=foo_group)
|
||||
assert foo_group in models.Group.query(members=user)
|
||||
assert user in models.User.query(groups=foo_group)
|
||||
|
||||
assert user not in bar_group.members
|
||||
assert bar_group not in user.groups
|
||||
|
@ -175,11 +176,11 @@ def test_model_references(testclient, user, foo_group, admin, bar_group, backend
|
|||
def test_model_references_set_unsaved_object(
|
||||
testclient, logged_moderator, user, backend
|
||||
):
|
||||
group = Group(members=[user], display_name="foo")
|
||||
group = models.Group(members=[user], display_name="foo")
|
||||
group.save()
|
||||
user.reload() # an LDAP group can be inconsistent by containing members which doesn't exist
|
||||
|
||||
non_existent_user = User(formatted_name="foo", family_name="bar")
|
||||
non_existent_user = models.User(formatted_name="foo", family_name="bar")
|
||||
group.members = group.members + [non_existent_user]
|
||||
assert group.members == [user, non_existent_user]
|
||||
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
import pytest
|
||||
import slapd
|
||||
from canaille import create_app
|
||||
from canaille.app import models
|
||||
from canaille.backends.ldap.backend import LDAPBackend
|
||||
from canaille.core.models import Group
|
||||
from canaille.core.models import User
|
||||
from flask_webtest import TestApp
|
||||
from werkzeug.security import gen_salt
|
||||
|
||||
|
@ -156,7 +155,7 @@ def testclient(app):
|
|||
|
||||
@pytest.fixture
|
||||
def user(app, backend):
|
||||
u = User(
|
||||
u = models.User(
|
||||
formatted_name="John (johnny) Doe",
|
||||
given_name="John",
|
||||
family_name="Doe",
|
||||
|
@ -176,7 +175,7 @@ def user(app, backend):
|
|||
|
||||
@pytest.fixture
|
||||
def admin(app, backend):
|
||||
u = User(
|
||||
u = models.User(
|
||||
formatted_name="Jane Doe",
|
||||
family_name="Doe",
|
||||
user_name="admin",
|
||||
|
@ -190,7 +189,7 @@ def admin(app, backend):
|
|||
|
||||
@pytest.fixture
|
||||
def moderator(app, backend):
|
||||
u = User(
|
||||
u = models.User(
|
||||
formatted_name="Jack Doe",
|
||||
family_name="Doe",
|
||||
user_name="moderator",
|
||||
|
@ -225,7 +224,7 @@ def logged_moderator(moderator, testclient):
|
|||
|
||||
@pytest.fixture
|
||||
def foo_group(app, user, backend):
|
||||
group = Group(
|
||||
group = models.Group(
|
||||
members=[user],
|
||||
display_name="foo",
|
||||
)
|
||||
|
@ -237,7 +236,7 @@ def foo_group(app, user, backend):
|
|||
|
||||
@pytest.fixture
|
||||
def bar_group(app, admin, backend):
|
||||
group = Group(
|
||||
group = models.Group(
|
||||
members=[admin],
|
||||
display_name="bar",
|
||||
)
|
||||
|
|
|
@ -1,17 +1,16 @@
|
|||
from canaille.app import models
|
||||
from canaille.commands import cli
|
||||
from canaille.core.models import Group
|
||||
from canaille.core.models import User
|
||||
from canaille.core.populate import fake_users
|
||||
|
||||
|
||||
def test_populate_users(testclient, backend):
|
||||
runner = testclient.app.test_cli_runner()
|
||||
|
||||
assert len(User.query()) == 0
|
||||
assert len(models.User.query()) == 0
|
||||
res = runner.invoke(cli, ["populate", "--nb", "10", "users"])
|
||||
assert res.exit_code == 0, res.stdout
|
||||
assert len(User.query()) == 10
|
||||
for user in User.query():
|
||||
assert len(models.User.query()) == 10
|
||||
for user in models.User.query():
|
||||
user.delete()
|
||||
|
||||
|
||||
|
@ -19,13 +18,13 @@ def test_populate_groups(testclient, backend):
|
|||
fake_users(10)
|
||||
runner = testclient.app.test_cli_runner()
|
||||
|
||||
assert len(Group.query()) == 0
|
||||
assert len(models.Group.query()) == 0
|
||||
res = runner.invoke(cli, ["populate", "--nb", "10", "groups"])
|
||||
assert res.exit_code == 0, res.stdout
|
||||
assert len(Group.query()) == 10
|
||||
assert len(models.Group.query()) == 10
|
||||
|
||||
for group in Group.query():
|
||||
for group in models.Group.query():
|
||||
group.delete()
|
||||
|
||||
for user in User.query():
|
||||
for user in models.User.query():
|
||||
user.delete()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from unittest import mock
|
||||
|
||||
from canaille.core.models import User
|
||||
from canaille.app import models
|
||||
|
||||
|
||||
def test_index(testclient, user):
|
||||
|
@ -112,7 +112,7 @@ def test_password_page_without_signin_in_redirects_to_login_page(testclient, use
|
|||
|
||||
def test_user_without_password_first_login(testclient, backend, smtpd):
|
||||
assert len(smtpd.messages) == 0
|
||||
u = User(
|
||||
u = models.User(
|
||||
formatted_name="Temp User",
|
||||
family_name="Temp",
|
||||
user_name="temp",
|
||||
|
@ -146,7 +146,7 @@ def test_first_login_account_initialization_mail_sending_failed(
|
|||
SMTP.side_effect = mock.Mock(side_effect=OSError("unit test mail error"))
|
||||
assert len(smtpd.messages) == 0
|
||||
|
||||
u = User(
|
||||
u = models.User(
|
||||
formatted_name="Temp User",
|
||||
family_name="Temp",
|
||||
user_name="temp",
|
||||
|
@ -168,7 +168,7 @@ def test_first_login_account_initialization_mail_sending_failed(
|
|||
|
||||
def test_first_login_form_error(testclient, backend, smtpd):
|
||||
assert len(smtpd.messages) == 0
|
||||
u = User(
|
||||
u = models.User(
|
||||
formatted_name="Temp User",
|
||||
family_name="Temp",
|
||||
user_name="temp",
|
||||
|
@ -190,7 +190,7 @@ def test_first_login_page_unavailable_for_users_with_password(
|
|||
|
||||
|
||||
def test_user_password_deleted_during_login(testclient, backend):
|
||||
u = User(
|
||||
u = models.User(
|
||||
formatted_name="Temp User",
|
||||
family_name="Temp",
|
||||
user_name="temp",
|
||||
|
@ -214,7 +214,7 @@ def test_user_password_deleted_during_login(testclient, backend):
|
|||
|
||||
|
||||
def test_user_deleted_in_session(testclient, backend):
|
||||
u = User(
|
||||
u = models.User(
|
||||
formatted_name="Jake Doe",
|
||||
family_name="Jake",
|
||||
user_name="jake",
|
||||
|
@ -278,7 +278,7 @@ def test_wrong_login(testclient, user):
|
|||
|
||||
|
||||
def test_admin_self_deletion(testclient, backend):
|
||||
admin = User(
|
||||
admin = models.User(
|
||||
formatted_name="Temp admin",
|
||||
family_name="admin",
|
||||
user_name="temp",
|
||||
|
@ -296,14 +296,14 @@ def test_admin_self_deletion(testclient, backend):
|
|||
.follow(status=200)
|
||||
)
|
||||
|
||||
assert User.get_from_login("temp") is None
|
||||
assert models.User.get_from_login("temp") is None
|
||||
|
||||
with testclient.session_transaction() as sess:
|
||||
assert not sess.get("user_id")
|
||||
|
||||
|
||||
def test_user_self_deletion(testclient, backend):
|
||||
user = User(
|
||||
user = models.User(
|
||||
formatted_name="Temp user",
|
||||
family_name="user",
|
||||
user_name="temp",
|
||||
|
@ -330,7 +330,7 @@ def test_user_self_deletion(testclient, backend):
|
|||
.follow(status=200)
|
||||
)
|
||||
|
||||
assert User.get_from_login("temp") is None
|
||||
assert models.User.get_from_login("temp") is None
|
||||
|
||||
with testclient.session_transaction() as sess:
|
||||
assert not sess.get("user_id")
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
from canaille.core.models import Group
|
||||
from canaille.core.models import User
|
||||
from canaille.app import models
|
||||
from canaille.core.populate import fake_groups
|
||||
from canaille.core.populate import fake_users
|
||||
|
||||
|
||||
def test_no_group(app, backend):
|
||||
assert Group.query() == []
|
||||
assert models.Group.query() == []
|
||||
|
||||
|
||||
def test_group_list_pagination(testclient, logged_admin, foo_group):
|
||||
|
@ -49,7 +48,7 @@ def test_group_list_bad_pages(testclient, logged_admin):
|
|||
|
||||
|
||||
def test_group_deletion(testclient, backend):
|
||||
user = User(
|
||||
user = models.User(
|
||||
formatted_name="foobar",
|
||||
family_name="foobar",
|
||||
user_name="foobar",
|
||||
|
@ -57,7 +56,7 @@ def test_group_deletion(testclient, backend):
|
|||
)
|
||||
user.save()
|
||||
|
||||
group = Group(
|
||||
group = models.Group(
|
||||
members=[user],
|
||||
display_name="foobar",
|
||||
)
|
||||
|
@ -109,7 +108,7 @@ def test_set_groups(app, user, foo_group, bar_group):
|
|||
|
||||
|
||||
def test_set_groups_with_leading_space_in_user_id_attribute(app, foo_group):
|
||||
user = User(
|
||||
user = models.User(
|
||||
formatted_name=" Doe", # leading space in id attribute
|
||||
family_name="Doe",
|
||||
user_name="user2",
|
||||
|
@ -137,8 +136,8 @@ def test_moderator_can_create_edit_and_delete_group(
|
|||
):
|
||||
# The group does not exist
|
||||
res = testclient.get("/groups", status=200)
|
||||
assert Group.get(display_name="bar") is None
|
||||
assert Group.get(display_name="foo") == foo_group
|
||||
assert models.Group.get(display_name="bar") is None
|
||||
assert models.Group.get(display_name="foo") == foo_group
|
||||
res.mustcontain(no="bar")
|
||||
res.mustcontain("foo")
|
||||
|
||||
|
@ -152,7 +151,7 @@ def test_moderator_can_create_edit_and_delete_group(
|
|||
res = form.submit(status=302).follow(status=200)
|
||||
|
||||
logged_moderator.reload()
|
||||
bar_group = Group.get(display_name="bar")
|
||||
bar_group = models.Group.get(display_name="bar")
|
||||
assert bar_group.display_name == "bar"
|
||||
assert bar_group.description == ["yolo"]
|
||||
assert bar_group.members == [
|
||||
|
@ -168,17 +167,17 @@ def test_moderator_can_create_edit_and_delete_group(
|
|||
|
||||
res = form.submit(name="action", value="edit").follow()
|
||||
|
||||
bar_group = Group.get(display_name="bar")
|
||||
bar_group = models.Group.get(display_name="bar")
|
||||
assert bar_group.display_name == "bar"
|
||||
assert bar_group.description == ["yolo2"]
|
||||
assert Group.get(display_name="bar2") is None
|
||||
assert models.Group.get(display_name="bar2") is None
|
||||
members = bar_group.members
|
||||
for member in members:
|
||||
res.mustcontain(member.formatted_name[0])
|
||||
|
||||
# Group is deleted
|
||||
res = form.submit(name="action", value="delete", status=302)
|
||||
assert Group.get(display_name="bar") is None
|
||||
assert models.Group.get(display_name="bar") is None
|
||||
assert ("success", "The group bar has been sucessfully deleted") in res.flashes
|
||||
|
||||
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import datetime
|
||||
|
||||
from canaille.app import models
|
||||
from canaille.core.account import Invitation
|
||||
from canaille.core.models import User
|
||||
|
||||
|
||||
def test_invitation(testclient, logged_admin, foo_group, smtpd):
|
||||
assert User.get_from_login("someone") is None
|
||||
assert models.User.get_from_login("someone") is None
|
||||
|
||||
res = testclient.get("/invite", status=200)
|
||||
|
||||
|
@ -39,7 +39,7 @@ def test_invitation(testclient, logged_admin, foo_group, smtpd):
|
|||
assert ("success", "Your account has been created successfuly.") in res.flashes
|
||||
res = res.follow(status=200)
|
||||
|
||||
user = User.get_from_login("someone")
|
||||
user = models.User.get_from_login("someone")
|
||||
foo_group.reload()
|
||||
assert user.check_password("whatever")
|
||||
assert user.groups == [foo_group]
|
||||
|
@ -53,8 +53,8 @@ def test_invitation(testclient, logged_admin, foo_group, smtpd):
|
|||
|
||||
|
||||
def test_invitation_editable_user_name(testclient, logged_admin, foo_group, smtpd):
|
||||
assert User.get_from_login("jackyjack") is None
|
||||
assert User.get_from_login("djorje") is None
|
||||
assert models.User.get_from_login("jackyjack") is None
|
||||
assert models.User.get_from_login("djorje") is None
|
||||
|
||||
res = testclient.get("/invite", status=200)
|
||||
|
||||
|
@ -89,7 +89,7 @@ def test_invitation_editable_user_name(testclient, logged_admin, foo_group, smtp
|
|||
assert ("success", "Your account has been created successfuly.") in res.flashes
|
||||
res = res.follow(status=200)
|
||||
|
||||
user = User.get_from_login("djorje")
|
||||
user = models.User.get_from_login("djorje")
|
||||
foo_group.reload()
|
||||
assert user.check_password("whatever")
|
||||
assert user.groups == [foo_group]
|
||||
|
@ -101,7 +101,7 @@ def test_invitation_editable_user_name(testclient, logged_admin, foo_group, smtp
|
|||
|
||||
|
||||
def test_generate_link(testclient, logged_admin, foo_group, smtpd):
|
||||
assert User.get_from_login("sometwo") is None
|
||||
assert models.User.get_from_login("sometwo") is None
|
||||
|
||||
res = testclient.get("/invite", status=200)
|
||||
|
||||
|
@ -131,7 +131,7 @@ def test_generate_link(testclient, logged_admin, foo_group, smtpd):
|
|||
res = res.form.submit(status=302)
|
||||
res = res.follow(status=200)
|
||||
|
||||
user = User.get_from_login("sometwo")
|
||||
user = models.User.get_from_login("sometwo")
|
||||
foo_group.reload()
|
||||
assert user.check_password("whatever")
|
||||
assert user.groups == [foo_group]
|
||||
|
@ -228,7 +228,7 @@ def test_registration_no_password(testclient, foo_group):
|
|||
res = res.form.submit(status=200)
|
||||
res.mustcontain("This field is required.")
|
||||
|
||||
assert not User.get_from_login("someoneelse")
|
||||
assert not models.User.get_from_login("someoneelse")
|
||||
|
||||
with testclient.session_transaction() as sess:
|
||||
assert "user_id" not in sess
|
||||
|
@ -295,7 +295,7 @@ def test_groups_are_saved_even_when_user_does_not_have_read_permission(
|
|||
res = res.form.submit(status=302)
|
||||
res = res.follow(status=200)
|
||||
|
||||
user = User.get_from_login("someoneelse")
|
||||
user = models.User.get_from_login("someoneelse")
|
||||
foo_group.reload()
|
||||
assert user.groups == [foo_group]
|
||||
user.delete()
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
from canaille.core.models import User
|
||||
from canaille.app import models
|
||||
|
||||
|
||||
def test_user_get_from_login(testclient, user, backend):
|
||||
assert User.get_from_login(login="invalid") is None
|
||||
assert User.get_from_login(login="user") == user
|
||||
assert models.User.get_from_login(login="invalid") is None
|
||||
assert models.User.get_from_login(login="user") == user
|
||||
|
||||
|
||||
def test_user_has_password(testclient, backend):
|
||||
u = User(
|
||||
u = models.User(
|
||||
formatted_name="Temp User",
|
||||
family_name="Temp",
|
||||
user_name="temp",
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from canaille.core.models import User
|
||||
from canaille.app import models
|
||||
|
||||
|
||||
def test_user_creation_edition_and_deletion(
|
||||
|
@ -6,7 +6,7 @@ def test_user_creation_edition_and_deletion(
|
|||
):
|
||||
# The user does not exist.
|
||||
res = testclient.get("/users", status=200)
|
||||
assert User.get_from_login("george") is None
|
||||
assert models.User.get_from_login("george") is None
|
||||
res.mustcontain(no="george")
|
||||
|
||||
# Fill the profile for a new user.
|
||||
|
@ -24,7 +24,7 @@ def test_user_creation_edition_and_deletion(
|
|||
res = res.form.submit(name="action", value="edit", status=302)
|
||||
assert ("success", "User account creation succeed.") in res.flashes
|
||||
res = res.follow(status=200)
|
||||
george = User.get_from_login("george")
|
||||
george = models.User.get_from_login("george")
|
||||
foo_group.reload()
|
||||
assert "George" == george.given_name[0]
|
||||
assert george.groups == [foo_group]
|
||||
|
@ -45,7 +45,7 @@ def test_user_creation_edition_and_deletion(
|
|||
res.form["groups"] = [foo_group.id, bar_group.id]
|
||||
res = res.form.submit(name="action", value="edit").follow()
|
||||
|
||||
george = User.get_from_login("george")
|
||||
george = models.User.get_from_login("george")
|
||||
assert "Georgio" == george.given_name[0]
|
||||
assert george.check_password("totoyolo")
|
||||
|
||||
|
@ -60,7 +60,7 @@ def test_user_creation_edition_and_deletion(
|
|||
# User have been deleted.
|
||||
res = testclient.get("/profile/george/settings", status=200)
|
||||
res = res.form.submit(name="action", value="delete", status=302).follow(status=200)
|
||||
assert User.get_from_login("george") is None
|
||||
assert models.User.get_from_login("george") is None
|
||||
res.mustcontain(no="george")
|
||||
|
||||
|
||||
|
@ -89,7 +89,7 @@ def test_user_creation_without_password(testclient, logged_moderator):
|
|||
res = res.form.submit(name="action", value="edit", status=302)
|
||||
assert ("success", "User account creation succeed.") in res.flashes
|
||||
res = res.follow(status=200)
|
||||
george = User.get_from_login("george")
|
||||
george = models.User.get_from_login("george")
|
||||
assert george.user_name[0] == "george"
|
||||
assert not george.has_password()
|
||||
|
||||
|
@ -100,13 +100,13 @@ def test_user_creation_form_validation_failed(
|
|||
testclient, logged_moderator, foo_group, bar_group
|
||||
):
|
||||
res = testclient.get("/users", status=200)
|
||||
assert User.get_from_login("george") is None
|
||||
assert models.User.get_from_login("george") is None
|
||||
res.mustcontain(no="george")
|
||||
|
||||
res = testclient.get("/profile", status=200)
|
||||
res = res.form.submit(name="action", value="edit")
|
||||
assert ("error", "User account creation failed.") in res.flashes
|
||||
assert User.get_from_login("george") is None
|
||||
assert models.User.get_from_login("george") is None
|
||||
|
||||
|
||||
def test_username_already_taken(
|
||||
|
@ -140,7 +140,7 @@ def test_cn_setting_with_given_name_and_surname(testclient, logged_moderator):
|
|||
|
||||
res = res.form.submit(name="action", value="edit", status=302).follow(status=200)
|
||||
|
||||
george = User.get_from_login("george")
|
||||
george = models.User.get_from_login("george")
|
||||
assert george.formatted_name[0] == "George Abitbol"
|
||||
george.delete()
|
||||
|
||||
|
@ -153,6 +153,6 @@ def test_cn_setting_with_surname_only(testclient, logged_moderator):
|
|||
|
||||
res = res.form.submit(name="action", value="edit", status=302).follow(status=200)
|
||||
|
||||
george = User.get_from_login("george")
|
||||
george = models.User.get_from_login("george")
|
||||
assert george.formatted_name[0] == "Abitbol"
|
||||
george.delete()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import datetime
|
||||
|
||||
from canaille.core.models import User
|
||||
from canaille.app import models
|
||||
from webtest import Upload
|
||||
|
||||
|
||||
|
@ -101,7 +101,7 @@ def test_photo_on_profile_edition(
|
|||
|
||||
def test_photo_on_profile_creation(testclient, jpeg_photo, logged_admin):
|
||||
res = testclient.get("/users", status=200)
|
||||
assert User.get_from_login("foobar") is None
|
||||
assert models.User.get_from_login("foobar") is None
|
||||
res.mustcontain(no="foobar")
|
||||
|
||||
res = testclient.get("/profile", status=200)
|
||||
|
@ -111,14 +111,14 @@ def test_photo_on_profile_creation(testclient, jpeg_photo, logged_admin):
|
|||
res.form["email"] = "george@abitbol.com"
|
||||
res = res.form.submit(name="action", value="edit", status=302).follow(status=200)
|
||||
|
||||
user = User.get_from_login("foobar")
|
||||
user = models.User.get_from_login("foobar")
|
||||
assert user.photo == [jpeg_photo]
|
||||
user.delete()
|
||||
|
||||
|
||||
def test_photo_deleted_on_profile_creation(testclient, jpeg_photo, logged_admin):
|
||||
res = testclient.get("/users", status=200)
|
||||
assert User.get_from_login("foobar") is None
|
||||
assert models.User.get_from_login("foobar") is None
|
||||
res.mustcontain(no="foobar")
|
||||
|
||||
res = testclient.get("/profile", status=200)
|
||||
|
@ -129,6 +129,6 @@ def test_photo_deleted_on_profile_creation(testclient, jpeg_photo, logged_admin)
|
|||
res.form["email"] = "george@abitbol.com"
|
||||
res = res.form.submit(name="action", value="edit", status=302).follow(status=200)
|
||||
|
||||
user = User.get_from_login("foobar")
|
||||
user = models.User.get_from_login("foobar")
|
||||
assert user.photo == []
|
||||
user.delete()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from unittest import mock
|
||||
|
||||
from canaille.core.models import User
|
||||
from canaille.app import models
|
||||
|
||||
|
||||
def test_edition(
|
||||
|
@ -126,7 +126,7 @@ def test_password_change_fail(testclient, logged_user):
|
|||
|
||||
|
||||
def test_password_initialization_mail(smtpd, testclient, backend, logged_admin):
|
||||
u = User(
|
||||
u = models.User(
|
||||
formatted_name="Temp User",
|
||||
family_name="Temp",
|
||||
user_name="temp",
|
||||
|
@ -162,7 +162,7 @@ def test_password_initialization_mail_send_fail(
|
|||
SMTP, smtpd, testclient, backend, logged_admin
|
||||
):
|
||||
SMTP.side_effect = mock.Mock(side_effect=OSError("unit test mail error"))
|
||||
u = User(
|
||||
u = models.User(
|
||||
formatted_name="Temp User",
|
||||
family_name="Temp",
|
||||
user_name="temp",
|
||||
|
@ -232,7 +232,7 @@ def test_invalid_form_request(testclient, logged_admin):
|
|||
|
||||
|
||||
def test_password_reset_email(smtpd, testclient, backend, logged_admin):
|
||||
u = User(
|
||||
u = models.User(
|
||||
formatted_name="Temp User",
|
||||
family_name="Temp",
|
||||
user_name="temp",
|
||||
|
@ -260,7 +260,7 @@ def test_password_reset_email(smtpd, testclient, backend, logged_admin):
|
|||
@mock.patch("smtplib.SMTP")
|
||||
def test_password_reset_email_failed(SMTP, smtpd, testclient, backend, logged_admin):
|
||||
SMTP.side_effect = mock.Mock(side_effect=OSError("unit test mail error"))
|
||||
u = User(
|
||||
u = models.User(
|
||||
formatted_name="Temp User",
|
||||
family_name="Temp",
|
||||
user_name="temp",
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
import datetime
|
||||
|
||||
from canaille.app import models
|
||||
from canaille.commands import cli
|
||||
from canaille.oidc.models import AuthorizationCode
|
||||
from canaille.oidc.models import Token
|
||||
from werkzeug.security import gen_salt
|
||||
|
||||
|
||||
def test_clean_command(testclient, backend, client, user):
|
||||
valid_code = AuthorizationCode(
|
||||
valid_code = models.AuthorizationCode(
|
||||
authorization_code_id=gen_salt(48),
|
||||
code="my-valid-code",
|
||||
client=client,
|
||||
|
@ -23,7 +22,7 @@ def test_clean_command(testclient, backend, client, user):
|
|||
revokation="",
|
||||
)
|
||||
valid_code.save()
|
||||
expired_code = AuthorizationCode(
|
||||
expired_code = models.AuthorizationCode(
|
||||
authorization_code_id=gen_salt(48),
|
||||
code="my-expired-code",
|
||||
client=client,
|
||||
|
@ -43,7 +42,7 @@ def test_clean_command(testclient, backend, client, user):
|
|||
)
|
||||
expired_code.save()
|
||||
|
||||
valid_token = Token(
|
||||
valid_token = models.Token(
|
||||
token_id=gen_salt(48),
|
||||
access_token="my-valid-token",
|
||||
client=client,
|
||||
|
@ -57,7 +56,7 @@ def test_clean_command(testclient, backend, client, user):
|
|||
lifetime=3600,
|
||||
)
|
||||
valid_token.save()
|
||||
expired_token = Token(
|
||||
expired_token = models.Token(
|
||||
token_id=gen_salt(48),
|
||||
access_token="my-expired-token",
|
||||
client=client,
|
||||
|
@ -73,8 +72,8 @@ def test_clean_command(testclient, backend, client, user):
|
|||
)
|
||||
expired_token.save()
|
||||
|
||||
assert AuthorizationCode.get(code="my-expired-code")
|
||||
assert Token.get(access_token="my-expired-token")
|
||||
assert models.AuthorizationCode.get(code="my-expired-code")
|
||||
assert models.Token.get(access_token="my-expired-token")
|
||||
assert expired_code.is_expired()
|
||||
assert expired_token.is_expired()
|
||||
|
||||
|
@ -82,5 +81,5 @@ def test_clean_command(testclient, backend, client, user):
|
|||
res = runner.invoke(cli, ["clean"])
|
||||
assert res.exit_code == 0, res.stdout
|
||||
|
||||
assert AuthorizationCode.query() == [valid_code]
|
||||
assert Token.query() == [valid_token]
|
||||
assert models.AuthorizationCode.query() == [valid_code]
|
||||
assert models.Token.query() == [valid_token]
|
||||
|
|
|
@ -4,10 +4,7 @@ import uuid
|
|||
|
||||
import pytest
|
||||
from authlib.oidc.core.grants.util import generate_id_token
|
||||
from canaille.oidc.models import AuthorizationCode
|
||||
from canaille.oidc.models import Client
|
||||
from canaille.oidc.models import Consent
|
||||
from canaille.oidc.models import Token
|
||||
from canaille.app import models
|
||||
from canaille.oidc.oauth import generate_user_info
|
||||
from canaille.oidc.oauth import get_jwt_config
|
||||
from cryptography.hazmat.backends import default_backend as crypto_default_backend
|
||||
|
@ -71,7 +68,7 @@ def configuration(configuration, keypair_path):
|
|||
|
||||
@pytest.fixture
|
||||
def client(testclient, other_client, backend):
|
||||
c = Client(
|
||||
c = models.Client(
|
||||
client_id=gen_salt(24),
|
||||
client_name="Some client",
|
||||
contacts="contact@mydomain.tld",
|
||||
|
@ -107,7 +104,7 @@ def client(testclient, other_client, backend):
|
|||
|
||||
@pytest.fixture
|
||||
def other_client(testclient, backend):
|
||||
c = Client(
|
||||
c = models.Client(
|
||||
client_id=gen_salt(24),
|
||||
client_name="Some other client",
|
||||
contacts="contact@myotherdomain.tld",
|
||||
|
@ -143,7 +140,7 @@ def other_client(testclient, backend):
|
|||
|
||||
@pytest.fixture
|
||||
def authorization(testclient, user, client, backend):
|
||||
a = AuthorizationCode(
|
||||
a = models.AuthorizationCode(
|
||||
authorization_code_id=gen_salt(48),
|
||||
code="my-code",
|
||||
client=client,
|
||||
|
@ -165,7 +162,7 @@ def authorization(testclient, user, client, backend):
|
|||
|
||||
@pytest.fixture
|
||||
def token(testclient, client, user, backend):
|
||||
t = Token(
|
||||
t = models.Token(
|
||||
token_id=gen_salt(48),
|
||||
access_token=gen_salt(48),
|
||||
audience=[client],
|
||||
|
@ -194,7 +191,7 @@ def id_token(testclient, client, user, backend):
|
|||
|
||||
@pytest.fixture
|
||||
def consent(testclient, client, user, backend):
|
||||
t = Consent(
|
||||
t = models.Consent(
|
||||
consent_id=str(uuid.uuid4()),
|
||||
client=client,
|
||||
subject=user,
|
||||
|
|
|
@ -5,10 +5,7 @@ from urllib.parse import urlsplit
|
|||
import freezegun
|
||||
from authlib.jose import jwt
|
||||
from authlib.oauth2.rfc7636 import create_s256_code_challenge
|
||||
from canaille.core.models import User
|
||||
from canaille.oidc.models import AuthorizationCode
|
||||
from canaille.oidc.models import Consent
|
||||
from canaille.oidc.models import Token
|
||||
from canaille.app import models
|
||||
from canaille.oidc.oauth import setup_oauth
|
||||
from werkzeug.security import gen_salt
|
||||
|
||||
|
@ -18,7 +15,7 @@ from . import client_credentials
|
|||
def test_authorization_code_flow(
|
||||
testclient, logged_user, client, keypair, other_client
|
||||
):
|
||||
assert not Consent.query()
|
||||
assert not models.Consent.query()
|
||||
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
|
@ -36,7 +33,7 @@ def test_authorization_code_flow(
|
|||
assert res.location.startswith(client.redirect_uris[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code=code)
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode is not None
|
||||
assert set(authcode.scope[0].split(" ")) == {
|
||||
"openid",
|
||||
|
@ -47,7 +44,7 @@ def test_authorization_code_flow(
|
|||
"phone",
|
||||
}
|
||||
|
||||
consents = Consent.query(client=client, subject=logged_user)
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
assert set(consents[0].scope) == {
|
||||
"openid",
|
||||
"profile",
|
||||
|
@ -70,7 +67,7 @@ def test_authorization_code_flow(
|
|||
)
|
||||
|
||||
access_token = res.json["access_token"]
|
||||
token = Token.get(access_token=access_token)
|
||||
token = models.Token.get(access_token=access_token)
|
||||
assert token.client == client
|
||||
assert token.subject == logged_user
|
||||
assert set(token.scope[0].split(" ")) == {
|
||||
|
@ -119,7 +116,7 @@ def test_invalid_client(testclient, logged_user, keypair):
|
|||
def test_authorization_code_flow_with_redirect_uri(
|
||||
testclient, logged_user, client, keypair, other_client
|
||||
):
|
||||
assert not Consent.query()
|
||||
assert not models.Consent.query()
|
||||
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
|
@ -138,9 +135,9 @@ def test_authorization_code_flow_with_redirect_uri(
|
|||
assert res.location.startswith(client.redirect_uris[1])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code=code)
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode is not None
|
||||
consents = Consent.query(client=client, subject=logged_user)
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
|
||||
res = testclient.post(
|
||||
"/oauth/token",
|
||||
|
@ -155,7 +152,7 @@ def test_authorization_code_flow_with_redirect_uri(
|
|||
)
|
||||
|
||||
access_token = res.json["access_token"]
|
||||
token = Token.get(access_token=access_token)
|
||||
token = models.Token.get(access_token=access_token)
|
||||
assert token.client == client
|
||||
assert token.subject == logged_user
|
||||
|
||||
|
@ -166,7 +163,7 @@ def test_authorization_code_flow_with_redirect_uri(
|
|||
def test_authorization_code_flow_preconsented(
|
||||
testclient, logged_user, client, keypair, other_client
|
||||
):
|
||||
assert not Consent.query()
|
||||
assert not models.Consent.query()
|
||||
|
||||
client.preconsent = True
|
||||
client.save()
|
||||
|
@ -185,10 +182,10 @@ def test_authorization_code_flow_preconsented(
|
|||
assert res.location.startswith(client.redirect_uris[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code=code)
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = Consent.query(client=client, subject=logged_user)
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
assert not consents
|
||||
|
||||
res = testclient.post(
|
||||
|
@ -204,7 +201,7 @@ def test_authorization_code_flow_preconsented(
|
|||
)
|
||||
|
||||
access_token = res.json["access_token"]
|
||||
token = Token.get(access_token=access_token)
|
||||
token = models.Token.get(access_token=access_token)
|
||||
assert token.client == client
|
||||
assert token.subject == logged_user
|
||||
|
||||
|
@ -223,7 +220,7 @@ def test_authorization_code_flow_preconsented(
|
|||
|
||||
|
||||
def test_logout_login(testclient, logged_user, client):
|
||||
assert not Consent.query()
|
||||
assert not models.Consent.query()
|
||||
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
|
@ -254,10 +251,10 @@ def test_logout_login(testclient, logged_user, client):
|
|||
assert res.location.startswith(client.redirect_uris[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code=code)
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = Consent.query(client=client, subject=logged_user)
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
assert "profile" in consents[0].scope
|
||||
|
||||
res = testclient.post(
|
||||
|
@ -273,7 +270,7 @@ def test_logout_login(testclient, logged_user, client):
|
|||
)
|
||||
|
||||
access_token = res.json["access_token"]
|
||||
token = Token.get(access_token=access_token)
|
||||
token = models.Token.get(access_token=access_token)
|
||||
assert token.client == client
|
||||
assert token.subject == logged_user
|
||||
|
||||
|
@ -289,7 +286,7 @@ def test_logout_login(testclient, logged_user, client):
|
|||
|
||||
|
||||
def test_deny(testclient, logged_user, client):
|
||||
assert not Consent.query()
|
||||
assert not models.Consent.query()
|
||||
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
|
@ -308,11 +305,11 @@ def test_deny(testclient, logged_user, client):
|
|||
error = params["error"][0]
|
||||
assert error == "access_denied"
|
||||
|
||||
assert not Consent.query()
|
||||
assert not models.Consent.query()
|
||||
|
||||
|
||||
def test_refresh_token(testclient, user, client):
|
||||
assert not Consent.query()
|
||||
assert not models.Consent.query()
|
||||
|
||||
with freezegun.freeze_time("2020-01-01 01:00:00"):
|
||||
res = testclient.get(
|
||||
|
@ -335,10 +332,10 @@ def test_refresh_token(testclient, user, client):
|
|||
assert res.location.startswith(client.redirect_uris[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code=code)
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = Consent.query(client=client, subject=user)
|
||||
consents = models.Consent.query(client=client, subject=user)
|
||||
assert "profile" in consents[0].scope
|
||||
|
||||
with freezegun.freeze_time("2020-01-01 00:01:00"):
|
||||
|
@ -354,7 +351,7 @@ def test_refresh_token(testclient, user, client):
|
|||
status=200,
|
||||
)
|
||||
access_token = res.json["access_token"]
|
||||
old_token = Token.get(access_token=access_token)
|
||||
old_token = models.Token.get(access_token=access_token)
|
||||
assert old_token is not None
|
||||
assert not old_token.revokation_date
|
||||
|
||||
|
@ -369,7 +366,7 @@ def test_refresh_token(testclient, user, client):
|
|||
status=200,
|
||||
)
|
||||
access_token = res.json["access_token"]
|
||||
new_token = Token.get(access_token=access_token)
|
||||
new_token = models.Token.get(access_token=access_token)
|
||||
assert new_token is not None
|
||||
assert old_token.access_token != new_token.access_token
|
||||
|
||||
|
@ -389,7 +386,7 @@ def test_refresh_token(testclient, user, client):
|
|||
|
||||
|
||||
def test_code_challenge(testclient, logged_user, client):
|
||||
assert not Consent.query()
|
||||
assert not models.Consent.query()
|
||||
|
||||
client.token_endpoint_auth_method = "none"
|
||||
client.save()
|
||||
|
@ -415,10 +412,10 @@ def test_code_challenge(testclient, logged_user, client):
|
|||
assert res.location.startswith(client.redirect_uris[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code=code)
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = Consent.query(client=client, subject=logged_user)
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
assert "profile" in consents[0].scope
|
||||
|
||||
res = testclient.post(
|
||||
|
@ -435,7 +432,7 @@ def test_code_challenge(testclient, logged_user, client):
|
|||
)
|
||||
access_token = res.json["access_token"]
|
||||
|
||||
token = Token.get(access_token=access_token)
|
||||
token = models.Token.get(access_token=access_token)
|
||||
assert token.client == client
|
||||
assert token.subject == logged_user
|
||||
|
||||
|
@ -456,7 +453,7 @@ def test_code_challenge(testclient, logged_user, client):
|
|||
def test_authorization_code_flow_when_consent_already_given(
|
||||
testclient, logged_user, client
|
||||
):
|
||||
assert not Consent.query()
|
||||
assert not models.Consent.query()
|
||||
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
|
@ -474,10 +471,10 @@ def test_authorization_code_flow_when_consent_already_given(
|
|||
assert res.location.startswith(client.redirect_uris[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code=code)
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = Consent.query(client=client, subject=logged_user)
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
assert "profile" in consents[0].scope
|
||||
|
||||
res = testclient.post(
|
||||
|
@ -514,7 +511,7 @@ def test_authorization_code_flow_when_consent_already_given(
|
|||
def test_authorization_code_flow_when_consent_already_given_but_for_a_smaller_scope(
|
||||
testclient, logged_user, client
|
||||
):
|
||||
assert not Consent.query()
|
||||
assert not models.Consent.query()
|
||||
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
|
@ -532,10 +529,10 @@ def test_authorization_code_flow_when_consent_already_given_but_for_a_smaller_sc
|
|||
assert res.location.startswith(client.redirect_uris[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code=code)
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = Consent.query(client=client, subject=logged_user)
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
assert "profile" in consents[0].scope
|
||||
assert "groups" not in consents[0].scope
|
||||
|
||||
|
@ -568,10 +565,10 @@ def test_authorization_code_flow_when_consent_already_given_but_for_a_smaller_sc
|
|||
assert res.location.startswith(client.redirect_uris[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code=code)
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode is not None
|
||||
|
||||
consents = Consent.query(client=client, subject=logged_user)
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
assert "profile" in consents[0].scope
|
||||
assert "groups" in consents[0].scope
|
||||
|
||||
|
@ -604,7 +601,7 @@ def test_authorization_code_flow_but_user_cannot_use_oidc(
|
|||
|
||||
|
||||
def test_prompt_none(testclient, logged_user, client):
|
||||
consent = Consent(
|
||||
consent = models.Consent(
|
||||
consent_id=str(uuid.uuid4()),
|
||||
client=client,
|
||||
subject=logged_user,
|
||||
|
@ -631,7 +628,7 @@ def test_prompt_none(testclient, logged_user, client):
|
|||
|
||||
|
||||
def test_prompt_not_logged(testclient, user, client):
|
||||
consent = Consent(
|
||||
consent = models.Consent(
|
||||
consent_id=str(uuid.uuid4()),
|
||||
client=client,
|
||||
subject=user,
|
||||
|
@ -685,7 +682,7 @@ def test_nonce_required_in_oidc_requests(testclient, logged_user, client):
|
|||
|
||||
|
||||
def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client):
|
||||
assert not Consent.query()
|
||||
assert not models.Consent.query()
|
||||
testclient.app.config["REQUIRE_NONCE"] = False
|
||||
|
||||
res = testclient.get(
|
||||
|
@ -701,14 +698,14 @@ def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client):
|
|||
res = res.form.submit(name="answer", value="accept", status=302)
|
||||
|
||||
assert res.location.startswith(client.redirect_uris[0])
|
||||
for consent in Consent.query():
|
||||
for consent in models.Consent.query():
|
||||
consent.delete()
|
||||
|
||||
|
||||
def test_authorization_code_request_scope_too_large(
|
||||
testclient, logged_user, keypair, other_client
|
||||
):
|
||||
assert not Consent.query()
|
||||
assert not models.Consent.query()
|
||||
assert "email" not in other_client.scope
|
||||
|
||||
res = testclient.get(
|
||||
|
@ -726,13 +723,13 @@ def test_authorization_code_request_scope_too_large(
|
|||
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code=code)
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert set(authcode.scope[0].split(" ")) == {
|
||||
"openid",
|
||||
"profile",
|
||||
}
|
||||
|
||||
consents = Consent.query(client=other_client, subject=logged_user)
|
||||
consents = models.Consent.query(client=other_client, subject=logged_user)
|
||||
assert set(consents[0].scope) == {
|
||||
"openid",
|
||||
"profile",
|
||||
|
@ -751,7 +748,7 @@ def test_authorization_code_request_scope_too_large(
|
|||
)
|
||||
|
||||
access_token = res.json["access_token"]
|
||||
token = Token.get(access_token=access_token)
|
||||
token = models.Token.get(access_token=access_token)
|
||||
assert token.client == other_client
|
||||
assert token.subject == logged_user
|
||||
assert set(token.scope[0].split(" ")) == {
|
||||
|
@ -813,7 +810,7 @@ def test_authorization_code_expired(testclient, user, client):
|
|||
|
||||
|
||||
def test_code_with_invalid_user(testclient, admin, client):
|
||||
user = User(
|
||||
user = models.User(
|
||||
formatted_name="John Doe",
|
||||
family_name="Doe",
|
||||
user_name="temp",
|
||||
|
@ -838,7 +835,7 @@ def test_code_with_invalid_user(testclient, admin, client):
|
|||
res = res.form.submit(name="answer", value="accept", status=302)
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code=code)
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
|
||||
user.delete()
|
||||
|
||||
|
@ -861,7 +858,7 @@ def test_code_with_invalid_user(testclient, admin, client):
|
|||
|
||||
|
||||
def test_refresh_token_with_invalid_user(testclient, client):
|
||||
user = User(
|
||||
user = models.User(
|
||||
formatted_name="John Doe",
|
||||
family_name="Doe",
|
||||
user_name="temp",
|
||||
|
@ -888,7 +885,7 @@ def test_refresh_token_with_invalid_user(testclient, client):
|
|||
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code=code)
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
|
||||
res = testclient.post(
|
||||
"/oauth/token",
|
||||
|
@ -919,7 +916,7 @@ def test_refresh_token_with_invalid_user(testclient, client):
|
|||
"error": "invalid_request",
|
||||
"error_description": 'There is no "user" for this token.',
|
||||
}
|
||||
Token.get(access_token=access_token).delete()
|
||||
models.Token.get(access_token=access_token).delete()
|
||||
|
||||
|
||||
def test_token_default_expiration_date(testclient, logged_user, client, keypair):
|
||||
|
@ -937,7 +934,7 @@ def test_token_default_expiration_date(testclient, logged_user, client, keypair)
|
|||
res = res.form.submit(name="answer", value="accept", status=302)
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code=code)
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode.lifetime == 84400
|
||||
|
||||
res = testclient.post(
|
||||
|
@ -955,7 +952,7 @@ def test_token_default_expiration_date(testclient, logged_user, client, keypair)
|
|||
assert res.json["expires_in"] == 864000
|
||||
|
||||
access_token = res.json["access_token"]
|
||||
token = Token.get(access_token=access_token)
|
||||
token = models.Token.get(access_token=access_token)
|
||||
assert token.lifetime == 864000
|
||||
|
||||
claims = jwt.decode(access_token, keypair[1])
|
||||
|
@ -965,7 +962,7 @@ def test_token_default_expiration_date(testclient, logged_user, client, keypair)
|
|||
claims = jwt.decode(id_token, keypair[1])
|
||||
assert claims["exp"] - claims["iat"] == 3600
|
||||
|
||||
consents = Consent.query(client=client, subject=logged_user)
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
for consent in consents:
|
||||
consent.delete()
|
||||
|
||||
|
@ -995,7 +992,7 @@ def test_token_custom_expiration_date(testclient, logged_user, client, keypair):
|
|||
res = res.form.submit(name="answer", value="accept", status=302)
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code=code)
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode.lifetime == 84400
|
||||
|
||||
res = testclient.post(
|
||||
|
@ -1013,7 +1010,7 @@ def test_token_custom_expiration_date(testclient, logged_user, client, keypair):
|
|||
assert res.json["expires_in"] == 1000
|
||||
|
||||
access_token = res.json["access_token"]
|
||||
token = Token.get(access_token=access_token)
|
||||
token = models.Token.get(access_token=access_token)
|
||||
assert token.lifetime == 1000
|
||||
|
||||
claims = jwt.decode(access_token, keypair[1])
|
||||
|
@ -1023,6 +1020,6 @@ def test_token_custom_expiration_date(testclient, logged_user, client, keypair):
|
|||
claims = jwt.decode(id_token, keypair[1])
|
||||
assert claims["exp"] - claims["iat"] == 6000
|
||||
|
||||
consents = Consent.query(client=client, subject=logged_user)
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
for consent in consents:
|
||||
consent.delete()
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
import datetime
|
||||
|
||||
from canaille.oidc.models import AuthorizationCode
|
||||
from canaille.oidc.models import Client
|
||||
from canaille.oidc.models import Consent
|
||||
from canaille.oidc.models import Token
|
||||
from canaille.app import models
|
||||
from werkzeug.security import gen_salt
|
||||
|
||||
|
||||
|
@ -29,7 +26,7 @@ def test_client_list_pagination(testclient, logged_admin, client, other_client):
|
|||
res.mustcontain("2 items")
|
||||
clients = []
|
||||
for _ in range(25):
|
||||
client = Client(client_id=gen_salt(48), client_name=gen_salt(48))
|
||||
client = models.Client(client_id=gen_salt(48), client_name=gen_salt(48))
|
||||
client.save()
|
||||
clients.append(client)
|
||||
|
||||
|
@ -115,7 +112,7 @@ def test_client_add(testclient, logged_admin):
|
|||
res = res.follow(status=200)
|
||||
|
||||
client_id = res.forms["readonly"]["client_id"].value
|
||||
client = Client.get(client_id=client_id)
|
||||
client = models.Client.get(client_id=client_id)
|
||||
data["audience"] = [client]
|
||||
for k, v in data.items():
|
||||
client_value = getattr(client, k)
|
||||
|
@ -198,27 +195,29 @@ def test_client_edit_missing_fields(testclient, client, logged_admin, other_clie
|
|||
|
||||
|
||||
def test_client_delete(testclient, logged_admin):
|
||||
client = Client(client_id="client_id")
|
||||
client = models.Client(client_id="client_id")
|
||||
client.save()
|
||||
token = Token(
|
||||
token = models.Token(
|
||||
token_id="id",
|
||||
client=client,
|
||||
issue_datetime=datetime.datetime.now(datetime.timezone.utc),
|
||||
)
|
||||
token.save()
|
||||
consent = Consent(
|
||||
consent = models.Consent(
|
||||
consent_id="consent_id", subject=logged_admin, client=client, scope="openid"
|
||||
)
|
||||
consent.save()
|
||||
code = AuthorizationCode(authorization_code_id="id", client=client, subject=client)
|
||||
code = models.AuthorizationCode(
|
||||
authorization_code_id="id", client=client, subject=client
|
||||
)
|
||||
|
||||
res = testclient.get("/admin/client/edit/" + client.client_id)
|
||||
res = res.forms["clientaddform"].submit(name="action", value="delete").follow()
|
||||
|
||||
assert not Client.get()
|
||||
assert not Token.get()
|
||||
assert not AuthorizationCode.get()
|
||||
assert not Consent.get()
|
||||
assert not models.Client.get()
|
||||
assert not models.Token.get()
|
||||
assert not models.AuthorizationCode.get()
|
||||
assert not models.Consent.get()
|
||||
|
||||
|
||||
def test_client_delete_invalid_client(testclient, logged_admin, client):
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from canaille.oidc.models import AuthorizationCode
|
||||
from canaille.app import models
|
||||
from werkzeug.security import gen_salt
|
||||
|
||||
|
||||
|
@ -20,7 +20,7 @@ def test_authorization_list_pagination(testclient, logged_admin, client):
|
|||
res.mustcontain("0 items")
|
||||
authorizations = []
|
||||
for _ in range(26):
|
||||
code = AuthorizationCode(
|
||||
code = models.AuthorizationCode(
|
||||
authorization_code_id=gen_salt(48), client=client, subject=logged_admin
|
||||
)
|
||||
code.save()
|
||||
|
@ -66,13 +66,13 @@ def test_authorization_list_bad_pages(testclient, logged_admin):
|
|||
|
||||
def test_authorization_list_search(testclient, logged_admin, client):
|
||||
id1 = gen_salt(48)
|
||||
auth1 = AuthorizationCode(
|
||||
auth1 = models.AuthorizationCode(
|
||||
authorization_code_id=id1, client=client, subject=logged_admin
|
||||
)
|
||||
auth1.save()
|
||||
|
||||
id2 = gen_salt(48)
|
||||
auth2 = AuthorizationCode(
|
||||
auth2 = models.AuthorizationCode(
|
||||
authorization_code_id=id2, client=client, subject=logged_admin
|
||||
)
|
||||
auth2.save()
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from urllib.parse import parse_qs
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from canaille.oidc.models import Consent
|
||||
from canaille.oidc.models import Token
|
||||
from canaille.app import models
|
||||
|
||||
from . import client_credentials
|
||||
|
||||
|
@ -118,7 +117,7 @@ def test_oidc_authorization_after_revokation(
|
|||
|
||||
res = res.form.submit(name="answer", value="accept", status=302)
|
||||
|
||||
consents = Consent.query(client=client, subject=logged_user)
|
||||
consents = models.Consent.query(client=client, subject=logged_user)
|
||||
consent.reload()
|
||||
assert consents[0] == consent
|
||||
assert not consent.revoked
|
||||
|
@ -138,7 +137,7 @@ def test_oidc_authorization_after_revokation(
|
|||
)
|
||||
|
||||
access_token = res.json["access_token"]
|
||||
token = Token.get(access_token=access_token)
|
||||
token = models.Token.get(access_token=access_token)
|
||||
assert token.client == client
|
||||
assert token.subject == logged_user
|
||||
|
||||
|
@ -158,13 +157,13 @@ def test_preconsented_client_appears_in_consent_list(testclient, client, logged_
|
|||
def test_revoke_preconsented_client(testclient, client, logged_user, token):
|
||||
client.preconsent = True
|
||||
client.save()
|
||||
assert not Consent.get()
|
||||
assert not models.Consent.get()
|
||||
assert not token.revoked
|
||||
|
||||
res = testclient.get(f"/consent/revoke-preconsent/{client.client_id}", status=302)
|
||||
assert ("success", "The access has been revoked") in res.flashes
|
||||
|
||||
consent = Consent.get()
|
||||
consent = models.Consent.get()
|
||||
assert consent.client == client
|
||||
assert consent.subject == logged_user
|
||||
assert consent.scope == ["openid", "email", "profile", "groups", "address", "phone"]
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from unittest import mock
|
||||
|
||||
from authlib.jose import jwt
|
||||
from canaille.oidc.models import Client
|
||||
from canaille.app import models
|
||||
|
||||
|
||||
def test_client_registration_with_authentication_static_token(
|
||||
|
@ -29,7 +29,7 @@ def test_client_registration_with_authentication_static_token(
|
|||
headers = {"Authorization": "Bearer static-token"}
|
||||
|
||||
res = testclient.post_json("/oauth/register", payload, headers=headers, status=201)
|
||||
client = Client.get(client_id=res.json["client_id"])
|
||||
client = models.Client.get(client_id=res.json["client_id"])
|
||||
|
||||
assert res.json == {
|
||||
"client_id": client.client_id,
|
||||
|
@ -148,7 +148,7 @@ def test_client_registration_with_software_statement(testclient, backend, keypai
|
|||
}
|
||||
res = testclient.post_json("/oauth/register", payload, status=201)
|
||||
|
||||
client = Client.get(client_id=res.json["client_id"])
|
||||
client = models.Client.get(client_id=res.json["client_id"])
|
||||
assert res.json == {
|
||||
"client_id": client.client_id,
|
||||
"client_secret": client.client_secret,
|
||||
|
@ -199,7 +199,7 @@ def test_client_registration_without_authentication_ok(testclient, backend):
|
|||
|
||||
res = testclient.post_json("/oauth/register", payload, status=201)
|
||||
|
||||
client = Client.get(client_id=res.json["client_id"])
|
||||
client = models.Client.get(client_id=res.json["client_id"])
|
||||
assert res.json == {
|
||||
"client_id": mock.ANY,
|
||||
"client_secret": mock.ANY,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import warnings
|
||||
from datetime import datetime
|
||||
|
||||
from canaille.oidc.models import Client
|
||||
from canaille.app import models
|
||||
|
||||
|
||||
def test_get(testclient, backend, client, user):
|
||||
|
@ -95,7 +95,7 @@ def test_update(testclient, backend, client, user):
|
|||
res = testclient.put_json(
|
||||
f"/oauth/register/{client.client_id}", payload, headers=headers, status=200
|
||||
)
|
||||
client = Client.get(client_id=res.json["client_id"])
|
||||
client = models.Client.get(client_id=res.json["client_id"])
|
||||
|
||||
assert res.json == {
|
||||
"client_id": client.client_id,
|
||||
|
@ -145,7 +145,7 @@ def test_delete(testclient, backend, user):
|
|||
"static-token"
|
||||
]
|
||||
|
||||
client = Client(client_id="foobar", client_name="Some client")
|
||||
client = models.Client(client_id="foobar", client_name="Some client")
|
||||
client.save()
|
||||
|
||||
headers = {"Authorization": "Bearer static-token"}
|
||||
|
@ -153,7 +153,7 @@ def test_delete(testclient, backend, user):
|
|||
res = testclient.delete(
|
||||
f"/oauth/register/{client.client_id}", headers=headers, status=204
|
||||
)
|
||||
assert not Client.get(client_id=client.client_id)
|
||||
assert not models.Client.get(client_id=client.client_id)
|
||||
|
||||
|
||||
def test_invalid_client(testclient, backend, user):
|
||||
|
|
|
@ -2,8 +2,7 @@ from urllib.parse import parse_qs
|
|||
from urllib.parse import urlsplit
|
||||
|
||||
from authlib.jose import jwt
|
||||
from canaille.oidc.models import AuthorizationCode
|
||||
from canaille.oidc.models import Token
|
||||
from canaille.app import models
|
||||
|
||||
|
||||
def test_oauth_hybrid(testclient, backend, user, client):
|
||||
|
@ -32,11 +31,11 @@ def test_oauth_hybrid(testclient, backend, user, client):
|
|||
params = parse_qs(urlsplit(res.location).fragment)
|
||||
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code=code)
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode is not None
|
||||
|
||||
access_token = params["access_token"][0]
|
||||
token = Token.get(access_token=access_token)
|
||||
token = models.Token.get(access_token=access_token)
|
||||
assert token is not None
|
||||
|
||||
res = testclient.get(
|
||||
|
@ -65,11 +64,11 @@ def test_oidc_hybrid(testclient, backend, logged_user, client, keypair, other_cl
|
|||
params = parse_qs(urlsplit(res.location).fragment)
|
||||
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code=code)
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode is not None
|
||||
|
||||
access_token = params["access_token"][0]
|
||||
token = Token.get(access_token=access_token)
|
||||
token = models.Token.get(access_token=access_token)
|
||||
assert token is not None
|
||||
|
||||
id_token = params["id_token"][0]
|
||||
|
|
|
@ -2,7 +2,7 @@ from urllib.parse import parse_qs
|
|||
from urllib.parse import urlsplit
|
||||
|
||||
from authlib.jose import jwt
|
||||
from canaille.oidc.models import Token
|
||||
from canaille.app import models
|
||||
|
||||
|
||||
def test_oauth_implicit(testclient, user, client):
|
||||
|
@ -35,7 +35,7 @@ def test_oauth_implicit(testclient, user, client):
|
|||
params = parse_qs(urlsplit(res.location).fragment)
|
||||
|
||||
access_token = params["access_token"][0]
|
||||
token = Token.get(access_token=access_token)
|
||||
token = models.Token.get(access_token=access_token)
|
||||
assert token is not None
|
||||
|
||||
res = testclient.get(
|
||||
|
@ -79,7 +79,7 @@ def test_oidc_implicit(testclient, keypair, user, client, other_client):
|
|||
params = parse_qs(urlsplit(res.location).fragment)
|
||||
|
||||
access_token = params["access_token"][0]
|
||||
token = Token.get(access_token=access_token)
|
||||
token = models.Token.get(access_token=access_token)
|
||||
assert token is not None
|
||||
|
||||
id_token = params["id_token"][0]
|
||||
|
@ -133,7 +133,7 @@ def test_oidc_implicit_with_group(
|
|||
params = parse_qs(urlsplit(res.location).fragment)
|
||||
|
||||
access_token = params["access_token"][0]
|
||||
token = Token.get(access_token=access_token)
|
||||
token = models.Token.get(access_token=access_token)
|
||||
assert token is not None
|
||||
|
||||
id_token = params["id_token"][0]
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from canaille.oidc.models import Token
|
||||
from canaille.app import models
|
||||
|
||||
from . import client_credentials
|
||||
|
||||
|
@ -20,7 +20,7 @@ def test_password_flow_basic(testclient, user, client):
|
|||
assert res.json["token_type"] == "Bearer"
|
||||
access_token = res.json["access_token"]
|
||||
|
||||
token = Token.get(access_token=access_token)
|
||||
token = models.Token.get(access_token=access_token)
|
||||
assert token is not None
|
||||
|
||||
res = testclient.get(
|
||||
|
@ -52,7 +52,7 @@ def test_password_flow_post(testclient, user, client):
|
|||
assert res.json["token_type"] == "Bearer"
|
||||
access_token = res.json["access_token"]
|
||||
|
||||
token = Token.get(access_token=access_token)
|
||||
token = models.Token.get(access_token=access_token)
|
||||
assert token is not None
|
||||
|
||||
res = testclient.get(
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import datetime
|
||||
|
||||
from canaille.oidc.models import Token
|
||||
from canaille.app import models
|
||||
from werkzeug.security import gen_salt
|
||||
|
||||
|
||||
|
@ -22,7 +22,7 @@ def test_token_list_pagination(testclient, logged_admin, client):
|
|||
res.mustcontain("0 items")
|
||||
tokens = []
|
||||
for _ in range(26):
|
||||
token = Token(
|
||||
token = models.Token(
|
||||
token_id=gen_salt(48),
|
||||
access_token="my-valid-token",
|
||||
client=client,
|
||||
|
@ -75,7 +75,7 @@ def test_token_list_bad_pages(testclient, logged_admin):
|
|||
|
||||
|
||||
def test_token_list_search(testclient, logged_admin, client):
|
||||
token1 = Token(
|
||||
token1 = models.Token(
|
||||
token_id=gen_salt(48),
|
||||
access_token="this-token-is-ok",
|
||||
client=client,
|
||||
|
@ -89,7 +89,7 @@ def test_token_list_search(testclient, logged_admin, client):
|
|||
lifetime=3600,
|
||||
)
|
||||
token1.save()
|
||||
token2 = Token(
|
||||
token2 = models.Token(
|
||||
token_id=gen_salt(48),
|
||||
access_token="this-token-is-valid",
|
||||
client=client,
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from urllib.parse import parse_qs
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from canaille.oidc.models import AuthorizationCode
|
||||
from canaille.oidc.models import Token
|
||||
from canaille.app import models
|
||||
|
||||
from . import client_credentials
|
||||
|
||||
|
@ -76,7 +75,7 @@ def test_full_flow(testclient, logged_user, client, user, other_client):
|
|||
assert res.location.startswith(client.redirect_uris[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code=code)
|
||||
authcode = models.AuthorizationCode.get(code=code)
|
||||
assert authcode is not None
|
||||
|
||||
res = testclient.post(
|
||||
|
@ -92,7 +91,7 @@ def test_full_flow(testclient, logged_user, client, user, other_client):
|
|||
)
|
||||
access_token = res.json["access_token"]
|
||||
|
||||
token = Token.get(access_token=access_token)
|
||||
token = models.Token.get(access_token=access_token)
|
||||
assert token.client == client
|
||||
assert token.subject == logged_user
|
||||
|
||||
|
|
Loading…
Reference in a new issue