forked from Github-Mirrors/canaille
Moved every model import to canaille.models
This commit is contained in:
parent
e110c4851b
commit
c1d1706007
43 changed files with 421 additions and 428 deletions
|
@ -3,7 +3,7 @@ import hashlib
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from canaille.core.models import User
|
from canaille.app import models
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_babel import gettext as _
|
from flask_babel import gettext as _
|
||||||
|
@ -26,7 +26,7 @@ def profile_hash(*args):
|
||||||
|
|
||||||
def login_placeholder():
|
def login_placeholder():
|
||||||
user_filter = current_app.config["BACKENDS"]["LDAP"].get(
|
user_filter = current_app.config["BACKENDS"]["LDAP"].get(
|
||||||
"USER_FILTER", User.DEFAULT_FILTER
|
"USER_FILTER", models.User.DEFAULT_FILTER
|
||||||
)
|
)
|
||||||
placeholders = []
|
placeholders = []
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ from functools import wraps
|
||||||
from urllib.parse import urlsplit
|
from urllib.parse import urlsplit
|
||||||
from urllib.parse import urlunsplit
|
from urllib.parse import urlunsplit
|
||||||
|
|
||||||
from canaille.core.models import User
|
from canaille.app import models
|
||||||
from flask import abort
|
from flask import abort
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
from flask import render_template
|
from flask import render_template
|
||||||
|
@ -14,7 +14,7 @@ from flask_babel import gettext as _
|
||||||
|
|
||||||
def current_user():
|
def current_user():
|
||||||
for user_id in session.get("user_id", [])[::-1]:
|
for user_id in session.get("user_id", [])[::-1]:
|
||||||
user = User.get(id=user_id)
|
user = models.User.get(id=user_id)
|
||||||
if user:
|
if user:
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
7
canaille/app/models.py
Normal file
7
canaille/app/models.py
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
# nopycln: file
|
||||||
|
from canaille.core.models import Group
|
||||||
|
from canaille.core.models import User
|
||||||
|
from canaille.oidc.models import AuthorizationCode
|
||||||
|
from canaille.oidc.models import Client
|
||||||
|
from canaille.oidc.models import Consent
|
||||||
|
from canaille.oidc.models import Token
|
|
@ -76,8 +76,7 @@ class LDAPBackend(Backend):
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls, config):
|
def validate(cls, config):
|
||||||
from canaille.app.configuration import ConfigurationException
|
from canaille.app.configuration import ConfigurationException
|
||||||
from canaille.core.models import Group
|
from canaille.app import models
|
||||||
from canaille.core.models import User
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conn = ldap.initialize(config["BACKENDS"]["LDAP"]["URI"])
|
conn = ldap.initialize(config["BACKENDS"]["LDAP"]["URI"])
|
||||||
|
@ -100,8 +99,8 @@ class LDAPBackend(Backend):
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
try:
|
try:
|
||||||
User.ldap_object_classes(conn)
|
models.User.ldap_object_classes(conn)
|
||||||
user = User(
|
user = models.User(
|
||||||
formatted_name=f"canaille_{uuid.uuid4()}",
|
formatted_name=f"canaille_{uuid.uuid4()}",
|
||||||
family_name=f"canaille_{uuid.uuid4()}",
|
family_name=f"canaille_{uuid.uuid4()}",
|
||||||
user_name=f"canaille_{uuid.uuid4()}",
|
user_name=f"canaille_{uuid.uuid4()}",
|
||||||
|
@ -118,9 +117,9 @@ class LDAPBackend(Backend):
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
try:
|
try:
|
||||||
Group.ldap_object_classes(conn)
|
models.Group.ldap_object_classes(conn)
|
||||||
|
|
||||||
user = User(
|
user = models.User(
|
||||||
cn=f"canaille_{uuid.uuid4()}",
|
cn=f"canaille_{uuid.uuid4()}",
|
||||||
family_name=f"canaille_{uuid.uuid4()}",
|
family_name=f"canaille_{uuid.uuid4()}",
|
||||||
user_name=f"canaille_{uuid.uuid4()}",
|
user_name=f"canaille_{uuid.uuid4()}",
|
||||||
|
@ -129,7 +128,7 @@ class LDAPBackend(Backend):
|
||||||
)
|
)
|
||||||
user.save(conn)
|
user.save(conn)
|
||||||
|
|
||||||
group = Group(
|
group = models.Group(
|
||||||
display_name=f"canaille_{uuid.uuid4()}",
|
display_name=f"canaille_{uuid.uuid4()}",
|
||||||
members=[user],
|
members=[user],
|
||||||
)
|
)
|
||||||
|
@ -150,22 +149,21 @@ class LDAPBackend(Backend):
|
||||||
|
|
||||||
def setup_ldap_models(config):
|
def setup_ldap_models(config):
|
||||||
from .ldapobject import LDAPObject
|
from .ldapobject import LDAPObject
|
||||||
from canaille.core.models import Group
|
from canaille.app import models
|
||||||
from canaille.core.models import User
|
|
||||||
|
|
||||||
LDAPObject.root_dn = config["BACKENDS"]["LDAP"]["ROOT_DN"]
|
LDAPObject.root_dn = config["BACKENDS"]["LDAP"]["ROOT_DN"]
|
||||||
|
|
||||||
user_base = config["BACKENDS"]["LDAP"]["USER_BASE"].replace(
|
user_base = config["BACKENDS"]["LDAP"]["USER_BASE"].replace(
|
||||||
f',{config["BACKENDS"]["LDAP"]["ROOT_DN"]}', ""
|
f',{config["BACKENDS"]["LDAP"]["ROOT_DN"]}', ""
|
||||||
)
|
)
|
||||||
User.base = user_base
|
models.User.base = user_base
|
||||||
User.rdn_attribute = config["BACKENDS"]["LDAP"].get(
|
models.User.rdn_attribute = config["BACKENDS"]["LDAP"].get(
|
||||||
"USER_ID_ATTRIBUTE", User.DEFAULT_ID_ATTRIBUTE
|
"USER_ID_ATTRIBUTE", models.User.DEFAULT_ID_ATTRIBUTE
|
||||||
)
|
)
|
||||||
object_class = config["BACKENDS"]["LDAP"].get(
|
object_class = config["BACKENDS"]["LDAP"].get(
|
||||||
"USER_CLASS", User.DEFAULT_OBJECT_CLASS
|
"USER_CLASS", models.User.DEFAULT_OBJECT_CLASS
|
||||||
)
|
)
|
||||||
User.ldap_object_class = (
|
models.User.ldap_object_class = (
|
||||||
object_class if isinstance(object_class, list) else [object_class]
|
object_class if isinstance(object_class, list) else [object_class]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -174,13 +172,13 @@ def setup_ldap_models(config):
|
||||||
.get("GROUP_BASE", "")
|
.get("GROUP_BASE", "")
|
||||||
.replace(f',{config["BACKENDS"]["LDAP"]["ROOT_DN"]}', "")
|
.replace(f',{config["BACKENDS"]["LDAP"]["ROOT_DN"]}', "")
|
||||||
)
|
)
|
||||||
Group.base = group_base or None
|
models.Group.base = group_base or None
|
||||||
Group.rdn_attribute = config["BACKENDS"]["LDAP"].get(
|
models.Group.rdn_attribute = config["BACKENDS"]["LDAP"].get(
|
||||||
"GROUP_ID_ATTRIBUTE", Group.DEFAULT_ID_ATTRIBUTE
|
"GROUP_ID_ATTRIBUTE", models.Group.DEFAULT_ID_ATTRIBUTE
|
||||||
)
|
)
|
||||||
object_class = config["BACKENDS"]["LDAP"].get(
|
object_class = config["BACKENDS"]["LDAP"].get(
|
||||||
"GROUP_CLASS", Group.DEFAULT_OBJECT_CLASS
|
"GROUP_CLASS", models.Group.DEFAULT_OBJECT_CLASS
|
||||||
)
|
)
|
||||||
Group.ldap_object_class = (
|
models.Group.ldap_object_class = (
|
||||||
object_class if isinstance(object_class, list) else [object_class]
|
object_class if isinstance(object_class, list) else [object_class]
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,6 +9,7 @@ import wtforms
|
||||||
from canaille.app import b64_to_obj
|
from canaille.app import b64_to_obj
|
||||||
from canaille.app import default_fields
|
from canaille.app import default_fields
|
||||||
from canaille.app import login_placeholder
|
from canaille.app import login_placeholder
|
||||||
|
from canaille.app import models
|
||||||
from canaille.app import obj_to_b64
|
from canaille.app import obj_to_b64
|
||||||
from canaille.app import profile_hash
|
from canaille.app import profile_hash
|
||||||
from canaille.app.flask import current_user
|
from canaille.app.flask import current_user
|
||||||
|
@ -43,8 +44,6 @@ from .forms import profile_form
|
||||||
from .mails import send_invitation_mail
|
from .mails import send_invitation_mail
|
||||||
from .mails import send_password_initialization_mail
|
from .mails import send_password_initialization_mail
|
||||||
from .mails import send_password_reset_mail
|
from .mails import send_password_reset_mail
|
||||||
from .models import Group
|
|
||||||
from .models import User
|
|
||||||
|
|
||||||
|
|
||||||
bp = Blueprint("account", __name__)
|
bp = Blueprint("account", __name__)
|
||||||
|
@ -88,12 +87,12 @@ def login():
|
||||||
form["login"].render_kw["placeholder"] = login_placeholder()
|
form["login"].render_kw["placeholder"] = login_placeholder()
|
||||||
|
|
||||||
if request.form:
|
if request.form:
|
||||||
user = User.get_from_login(form.login.data)
|
user = models.User.get_from_login(form.login.data)
|
||||||
if user and not user.has_password():
|
if user and not user.has_password():
|
||||||
return redirect(url_for("account.firstlogin", user_name=user.user_name[0]))
|
return redirect(url_for("account.firstlogin", user_name=user.user_name[0]))
|
||||||
|
|
||||||
if not form.validate():
|
if not form.validate():
|
||||||
User.logout()
|
models.User.logout()
|
||||||
flash(_("Login failed, please check your information"), "error")
|
flash(_("Login failed, please check your information"), "error")
|
||||||
return render_template("login.html", form=form)
|
return render_template("login.html", form=form)
|
||||||
|
|
||||||
|
@ -111,7 +110,7 @@ def password():
|
||||||
form = PasswordForm(request.form or None)
|
form = PasswordForm(request.form or None)
|
||||||
|
|
||||||
if request.form:
|
if request.form:
|
||||||
user = User.get_from_login(session["attempt_login"])
|
user = models.User.get_from_login(session["attempt_login"])
|
||||||
if user and not user.has_password():
|
if user and not user.has_password():
|
||||||
return redirect(url_for("account.firstlogin", user_name=user.user_name[0]))
|
return redirect(url_for("account.firstlogin", user_name=user.user_name[0]))
|
||||||
|
|
||||||
|
@ -120,7 +119,7 @@ def password():
|
||||||
or not user
|
or not user
|
||||||
or not user.check_password(form.password.data)
|
or not user.check_password(form.password.data)
|
||||||
):
|
):
|
||||||
User.logout()
|
models.User.logout()
|
||||||
flash(_("Login failed, please check your information"), "error")
|
flash(_("Login failed, please check your information"), "error")
|
||||||
return render_template(
|
return render_template(
|
||||||
"password.html", form=form, username=session["attempt_login"]
|
"password.html", form=form, username=session["attempt_login"]
|
||||||
|
@ -156,7 +155,7 @@ def logout():
|
||||||
|
|
||||||
@bp.route("/firstlogin/<user_name>", methods=("GET", "POST"))
|
@bp.route("/firstlogin/<user_name>", methods=("GET", "POST"))
|
||||||
def firstlogin(user_name):
|
def firstlogin(user_name):
|
||||||
user = User.get_from_login(user_name)
|
user = models.User.get_from_login(user_name)
|
||||||
if not user or user.has_password():
|
if not user or user.has_password():
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
|
@ -182,7 +181,9 @@ def firstlogin(user_name):
|
||||||
@bp.route("/users", methods=["GET", "POST"])
|
@bp.route("/users", methods=["GET", "POST"])
|
||||||
@permissions_needed("manage_users")
|
@permissions_needed("manage_users")
|
||||||
def users(user):
|
def users(user):
|
||||||
table_form = TableForm(User, fields=user.read | user.write, formdata=request.form)
|
table_form = TableForm(
|
||||||
|
models.User, fields=user.read | user.write, formdata=request.form
|
||||||
|
)
|
||||||
if request.form and not table_form.validate():
|
if request.form and not table_form.validate():
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
|
@ -278,7 +279,7 @@ def registration(data, hash):
|
||||||
)
|
)
|
||||||
return redirect(url_for("account.index"))
|
return redirect(url_for("account.index"))
|
||||||
|
|
||||||
if User.get_from_login(invitation.user_name):
|
if models.User.get_from_login(invitation.user_name):
|
||||||
flash(
|
flash(
|
||||||
_("Your account has already been created."),
|
_("Your account has already been created."),
|
||||||
"error",
|
"error",
|
||||||
|
@ -311,7 +312,7 @@ def registration(data, hash):
|
||||||
if "groups" not in form and invitation.groups:
|
if "groups" not in form and invitation.groups:
|
||||||
form["groups"] = wtforms.SelectMultipleField(
|
form["groups"] = wtforms.SelectMultipleField(
|
||||||
_("Groups"),
|
_("Groups"),
|
||||||
choices=[(group.id, group.display_name) for group in Group.query()],
|
choices=[(group.id, group.display_name) for group in models.Group.query()],
|
||||||
render_kw={"readonly": "true"},
|
render_kw={"readonly": "true"},
|
||||||
)
|
)
|
||||||
form.process(CombinedMultiDict((request.files, request.form)) or None, data=data)
|
form.process(CombinedMultiDict((request.files, request.form)) or None, data=data)
|
||||||
|
@ -381,7 +382,7 @@ def profile_creation(user):
|
||||||
|
|
||||||
|
|
||||||
def profile_create(current_app, form):
|
def profile_create(current_app, form):
|
||||||
user = User()
|
user = models.User()
|
||||||
for attribute in form:
|
for attribute in form:
|
||||||
if attribute.name in user.attributes:
|
if attribute.name in user.attributes:
|
||||||
if isinstance(attribute.data, FileStorage):
|
if isinstance(attribute.data, FileStorage):
|
||||||
|
@ -420,7 +421,7 @@ def profile_edition(user, username):
|
||||||
menuitem = "profile" if username == editor.user_name[0] else "users"
|
menuitem = "profile" if username == editor.user_name[0] else "users"
|
||||||
fields = editor.read | editor.write
|
fields = editor.read | editor.write
|
||||||
if username != editor.user_name[0]:
|
if username != editor.user_name[0]:
|
||||||
user = User.get_from_login(username)
|
user = models.User.get_from_login(username)
|
||||||
else:
|
else:
|
||||||
user = editor
|
user = editor
|
||||||
|
|
||||||
|
@ -505,7 +506,7 @@ def profile_settings(user, username):
|
||||||
):
|
):
|
||||||
abort(403)
|
abort(403)
|
||||||
|
|
||||||
edited_user = User.get_from_login(username)
|
edited_user = models.User.get_from_login(username)
|
||||||
if not edited_user:
|
if not edited_user:
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
|
@ -622,7 +623,7 @@ def profile_delete(user, edited_user):
|
||||||
@bp.route("/impersonate/<username>")
|
@bp.route("/impersonate/<username>")
|
||||||
@permissions_needed("impersonate_users")
|
@permissions_needed("impersonate_users")
|
||||||
def impersonate(user, username):
|
def impersonate(user, username):
|
||||||
puppet = User.get_from_login(username)
|
puppet = models.User.get_from_login(username)
|
||||||
if not puppet:
|
if not puppet:
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
|
@ -649,7 +650,7 @@ def forgotten():
|
||||||
flash(_("Could not send the password reset link."), "error")
|
flash(_("Could not send the password reset link."), "error")
|
||||||
return render_template("forgotten-password.html", form=form)
|
return render_template("forgotten-password.html", form=form)
|
||||||
|
|
||||||
user = User.get_from_login(form.login.data)
|
user = models.User.get_from_login(form.login.data)
|
||||||
success_message = _(
|
success_message = _(
|
||||||
"A password reset link has been sent at your email address. You should receive it within a few minutes."
|
"A password reset link has been sent at your email address. You should receive it within a few minutes."
|
||||||
)
|
)
|
||||||
|
@ -689,7 +690,7 @@ def reset(user_name, hash):
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
form = PasswordResetForm(request.form)
|
form = PasswordResetForm(request.form)
|
||||||
user = User.get_from_login(user_name)
|
user = models.User.get_from_login(user_name)
|
||||||
|
|
||||||
if not user or hash != profile_hash(
|
if not user or hash != profile_hash(
|
||||||
user.user_name[0],
|
user.user_name[0],
|
||||||
|
@ -719,7 +720,7 @@ def photo(user_name, field):
|
||||||
if field.lower() != "photo":
|
if field.lower() != "photo":
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
user = User.get_from_login(user_name)
|
user = models.User.get_from_login(user_name)
|
||||||
if not user:
|
if not user:
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import wtforms.form
|
import wtforms.form
|
||||||
|
from canaille.app import models
|
||||||
from canaille.app.forms import HTMXBaseForm
|
from canaille.app.forms import HTMXBaseForm
|
||||||
from canaille.app.forms import HTMXForm
|
from canaille.app.forms import HTMXForm
|
||||||
from canaille.app.forms import is_uri
|
from canaille.app.forms import is_uri
|
||||||
|
@ -9,12 +10,9 @@ from flask_babel import lazy_gettext as _
|
||||||
from flask_wtf.file import FileAllowed
|
from flask_wtf.file import FileAllowed
|
||||||
from flask_wtf.file import FileField
|
from flask_wtf.file import FileField
|
||||||
|
|
||||||
from .models import Group
|
|
||||||
from .models import User
|
|
||||||
|
|
||||||
|
|
||||||
def unique_login(form, field):
|
def unique_login(form, field):
|
||||||
if User.get_from_login(field.data) and (
|
if models.User.get_from_login(field.data) and (
|
||||||
not getattr(form, "user", None) or form.user.user_name[0] != field.data
|
not getattr(form, "user", None) or form.user.user_name[0] != field.data
|
||||||
):
|
):
|
||||||
raise wtforms.ValidationError(
|
raise wtforms.ValidationError(
|
||||||
|
@ -23,7 +21,7 @@ def unique_login(form, field):
|
||||||
|
|
||||||
|
|
||||||
def unique_email(form, field):
|
def unique_email(form, field):
|
||||||
if User.get(email=field.data) and (
|
if models.User.get(email=field.data) and (
|
||||||
not getattr(form, "user", None) or form.user.email[0] != field.data
|
not getattr(form, "user", None) or form.user.email[0] != field.data
|
||||||
):
|
):
|
||||||
raise wtforms.ValidationError(
|
raise wtforms.ValidationError(
|
||||||
|
@ -32,7 +30,7 @@ def unique_email(form, field):
|
||||||
|
|
||||||
|
|
||||||
def unique_group(form, field):
|
def unique_group(form, field):
|
||||||
if Group.get(display_name=field.data):
|
if models.Group.get(display_name=field.data):
|
||||||
raise wtforms.ValidationError(
|
raise wtforms.ValidationError(
|
||||||
_("The group '{group}' already exists").format(group=field.data)
|
_("The group '{group}' already exists").format(group=field.data)
|
||||||
)
|
)
|
||||||
|
@ -41,7 +39,7 @@ def unique_group(form, field):
|
||||||
def existing_login(form, field):
|
def existing_login(form, field):
|
||||||
if not current_app.config.get(
|
if not current_app.config.get(
|
||||||
"HIDE_INVALID_LOGINS", True
|
"HIDE_INVALID_LOGINS", True
|
||||||
) and not User.get_from_login(field.data):
|
) and not models.User.get_from_login(field.data):
|
||||||
raise wtforms.ValidationError(
|
raise wtforms.ValidationError(
|
||||||
_("The login '{login}' does not exist").format(login=field.data)
|
_("The login '{login}' does not exist").format(login=field.data)
|
||||||
)
|
)
|
||||||
|
@ -257,7 +255,9 @@ PROFILE_FORM_FIELDS = dict(
|
||||||
),
|
),
|
||||||
groups=wtforms.SelectMultipleField(
|
groups=wtforms.SelectMultipleField(
|
||||||
_("Groups"),
|
_("Groups"),
|
||||||
choices=lambda: [(group.id, group.display_name) for group in Group.query()],
|
choices=lambda: [
|
||||||
|
(group.id, group.display_name) for group in models.Group.query()
|
||||||
|
],
|
||||||
render_kw={"placeholder": _("users, admins …")},
|
render_kw={"placeholder": _("users, admins …")},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -276,7 +276,7 @@ def profile_form(write_field_names, readonly_field_names, user=None):
|
||||||
if PROFILE_FORM_FIELDS.get(name)
|
if PROFILE_FORM_FIELDS.get(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
if "groups" in fields and not Group.query():
|
if "groups" in fields and not models.Group.query():
|
||||||
del fields["groups"]
|
del fields["groups"]
|
||||||
|
|
||||||
form = HTMXBaseForm(fields)
|
form = HTMXBaseForm(fields)
|
||||||
|
@ -338,6 +338,8 @@ class InvitationForm(HTMXForm):
|
||||||
)
|
)
|
||||||
groups = wtforms.SelectMultipleField(
|
groups = wtforms.SelectMultipleField(
|
||||||
_("Groups"),
|
_("Groups"),
|
||||||
choices=lambda: [(group.id, group.display_name) for group in Group.query()],
|
choices=lambda: [
|
||||||
|
(group.id, group.display_name) for group in models.Group.query()
|
||||||
|
],
|
||||||
render_kw={},
|
render_kw={},
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from canaille.app import models
|
||||||
from canaille.app.flask import permissions_needed
|
from canaille.app.flask import permissions_needed
|
||||||
from canaille.app.flask import render_htmx_template
|
from canaille.app.flask import render_htmx_template
|
||||||
from canaille.app.forms import TableForm
|
from canaille.app.forms import TableForm
|
||||||
|
@ -12,8 +13,6 @@ from flask_themer import render_template
|
||||||
|
|
||||||
from .forms import CreateGroupForm
|
from .forms import CreateGroupForm
|
||||||
from .forms import EditGroupForm
|
from .forms import EditGroupForm
|
||||||
from .models import Group
|
|
||||||
from .models import User
|
|
||||||
|
|
||||||
bp = Blueprint("groups", __name__, url_prefix="/groups")
|
bp = Blueprint("groups", __name__, url_prefix="/groups")
|
||||||
|
|
||||||
|
@ -21,7 +20,7 @@ bp = Blueprint("groups", __name__, url_prefix="/groups")
|
||||||
@bp.route("/", methods=["GET", "POST"])
|
@bp.route("/", methods=["GET", "POST"])
|
||||||
@permissions_needed("manage_groups")
|
@permissions_needed("manage_groups")
|
||||||
def groups(user):
|
def groups(user):
|
||||||
table_form = TableForm(Group, formdata=request.form)
|
table_form = TableForm(models.Group, formdata=request.form)
|
||||||
if request.form and request.form.get("page") and not table_form.validate():
|
if request.form and request.form.get("page") and not table_form.validate():
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
|
@ -37,7 +36,7 @@ def create_group(user):
|
||||||
if not form.validate():
|
if not form.validate():
|
||||||
flash(_("Group creation failed."), "error")
|
flash(_("Group creation failed."), "error")
|
||||||
else:
|
else:
|
||||||
group = Group()
|
group = models.Group()
|
||||||
group.members = [user]
|
group.members = [user]
|
||||||
group.display_name = [form.display_name.data]
|
group.display_name = [form.display_name.data]
|
||||||
group.description = [form.description.data]
|
group.description = [form.description.data]
|
||||||
|
@ -59,7 +58,7 @@ def create_group(user):
|
||||||
@bp.route("/<groupname>", methods=("GET", "POST"))
|
@bp.route("/<groupname>", methods=("GET", "POST"))
|
||||||
@permissions_needed("manage_groups")
|
@permissions_needed("manage_groups")
|
||||||
def group(user, groupname):
|
def group(user, groupname):
|
||||||
group = Group.get(display_name=groupname)
|
group = models.Group.get(display_name=groupname)
|
||||||
|
|
||||||
if not group:
|
if not group:
|
||||||
abort(404)
|
abort(404)
|
||||||
|
@ -78,7 +77,7 @@ def group(user, groupname):
|
||||||
|
|
||||||
|
|
||||||
def edit_group(group):
|
def edit_group(group):
|
||||||
table_form = TableForm(User, filter={"groups": group}, formdata=request.form)
|
table_form = TableForm(models.User, filter={"groups": group}, formdata=request.form)
|
||||||
if request.form and request.form.get("page") and not table_form.validate():
|
if request.form and request.form.get("page") and not table_form.validate():
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import faker
|
import faker
|
||||||
|
from canaille.app import models
|
||||||
from canaille.app.i18n import available_language_codes
|
from canaille.app.i18n import available_language_codes
|
||||||
from canaille.core.models import Group
|
|
||||||
from canaille.core.models import User
|
|
||||||
from faker.config import AVAILABLE_LOCALES
|
from faker.config import AVAILABLE_LOCALES
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,7 +18,7 @@ def fake_users(nb=1):
|
||||||
try:
|
try:
|
||||||
fake = random.choice(fakes)
|
fake = random.choice(fakes)
|
||||||
name = fake.unique.name()
|
name = fake.unique.name()
|
||||||
user = User(
|
user = models.User(
|
||||||
formatted_name=name,
|
formatted_name=name,
|
||||||
given_name=name.split(" ")[0],
|
given_name=name.split(" ")[0],
|
||||||
family_name=name.split(" ")[1],
|
family_name=name.split(" ")[1],
|
||||||
|
@ -47,11 +46,11 @@ def fake_users(nb=1):
|
||||||
|
|
||||||
def fake_groups(nb=1, nb_users_max=1):
|
def fake_groups(nb=1, nb_users_max=1):
|
||||||
fake = faker_generator(["en_US"])[0]
|
fake = faker_generator(["en_US"])[0]
|
||||||
users = User.query()
|
users = models.User.query()
|
||||||
groups = list()
|
groups = list()
|
||||||
for _ in range(nb):
|
for _ in range(nb):
|
||||||
try:
|
try:
|
||||||
group = Group(
|
group = models.Group(
|
||||||
display_name=fake.unique.word(),
|
display_name=fake.unique.word(),
|
||||||
description=fake.sentence(),
|
description=fake.sentence(),
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
|
from canaille.app import models
|
||||||
from canaille.app.flask import permissions_needed
|
from canaille.app.flask import permissions_needed
|
||||||
from canaille.app.flask import render_htmx_template
|
from canaille.app.flask import render_htmx_template
|
||||||
from canaille.app.forms import TableForm
|
from canaille.app.forms import TableForm
|
||||||
from canaille.oidc.models import AuthorizationCode
|
|
||||||
from flask import abort
|
from flask import abort
|
||||||
from flask import Blueprint
|
from flask import Blueprint
|
||||||
from flask import request
|
from flask import request
|
||||||
|
@ -14,7 +14,7 @@ bp = Blueprint("authorizations", __name__, url_prefix="/admin/authorization")
|
||||||
@bp.route("/", methods=["GET", "POST"])
|
@bp.route("/", methods=["GET", "POST"])
|
||||||
@permissions_needed("manage_oidc")
|
@permissions_needed("manage_oidc")
|
||||||
def index(user):
|
def index(user):
|
||||||
table_form = TableForm(AuthorizationCode, formdata=request.form)
|
table_form = TableForm(models.AuthorizationCode, formdata=request.form)
|
||||||
if request.form and request.form.get("page") and not table_form.validate():
|
if request.form and request.form.get("page") and not table_form.validate():
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ def index(user):
|
||||||
@bp.route("/<authorization_id>", methods=["GET", "POST"])
|
@bp.route("/<authorization_id>", methods=["GET", "POST"])
|
||||||
@permissions_needed("manage_oidc")
|
@permissions_needed("manage_oidc")
|
||||||
def view(user, authorization_id):
|
def view(user, authorization_id):
|
||||||
authorization = AuthorizationCode.get(authorization_code_id=authorization_id)
|
authorization = models.AuthorizationCode.get(authorization_code_id=authorization_id)
|
||||||
return render_template(
|
return render_template(
|
||||||
"oidc/admin/authorization_view.html",
|
"oidc/admin/authorization_view.html",
|
||||||
authorization=authorization,
|
authorization=authorization,
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
from canaille.app import models
|
||||||
from canaille.app.flask import permissions_needed
|
from canaille.app.flask import permissions_needed
|
||||||
from canaille.app.flask import render_htmx_template
|
from canaille.app.flask import render_htmx_template
|
||||||
from canaille.app.flask import request_is_htmx
|
from canaille.app.flask import request_is_htmx
|
||||||
from canaille.app.forms import TableForm
|
from canaille.app.forms import TableForm
|
||||||
from canaille.oidc.forms import ClientAddForm
|
from canaille.oidc.forms import ClientAddForm
|
||||||
from canaille.oidc.models import Client
|
|
||||||
from flask import abort
|
from flask import abort
|
||||||
from flask import Blueprint
|
from flask import Blueprint
|
||||||
from flask import flash
|
from flask import flash
|
||||||
|
@ -23,7 +23,7 @@ bp = Blueprint("clients", __name__, url_prefix="/admin/client")
|
||||||
@bp.route("/", methods=["GET", "POST"])
|
@bp.route("/", methods=["GET", "POST"])
|
||||||
@permissions_needed("manage_oidc")
|
@permissions_needed("manage_oidc")
|
||||||
def index(user):
|
def index(user):
|
||||||
table_form = TableForm(Client, formdata=request.form)
|
table_form = TableForm(models.Client, formdata=request.form)
|
||||||
if request.form and request.form.get("page") and not table_form.validate():
|
if request.form and request.form.get("page") and not table_form.validate():
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
|
@ -53,7 +53,7 @@ def add(user):
|
||||||
|
|
||||||
client_id = gen_salt(24)
|
client_id = gen_salt(24)
|
||||||
client_id_issued_at = datetime.datetime.now(datetime.timezone.utc)
|
client_id_issued_at = datetime.datetime.now(datetime.timezone.utc)
|
||||||
client = Client(
|
client = models.Client(
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
client_id_issued_at=client_id_issued_at,
|
client_id_issued_at=client_id_issued_at,
|
||||||
client_name=form["client_name"].data,
|
client_name=form["client_name"].data,
|
||||||
|
@ -104,7 +104,7 @@ def edit(user, client_id):
|
||||||
|
|
||||||
|
|
||||||
def client_edit(client_id):
|
def client_edit(client_id):
|
||||||
client = Client.get(client_id=client_id)
|
client = models.Client.get(client_id=client_id)
|
||||||
|
|
||||||
if not client:
|
if not client:
|
||||||
abort(404)
|
abort(404)
|
||||||
|
@ -152,7 +152,7 @@ def client_edit(client_id):
|
||||||
software_version=form["software_version"].data,
|
software_version=form["software_version"].data,
|
||||||
jwk=form["jwk"].data,
|
jwk=form["jwk"].data,
|
||||||
jwks_uri=form["jwks_uri"].data,
|
jwks_uri=form["jwks_uri"].data,
|
||||||
audience=[Client.get(id=id) for id in form["audience"].data],
|
audience=[models.Client.get(id=id) for id in form["audience"].data],
|
||||||
preconsent=form["preconsent"].data,
|
preconsent=form["preconsent"].data,
|
||||||
)
|
)
|
||||||
client.save()
|
client.save()
|
||||||
|
@ -164,7 +164,7 @@ def client_edit(client_id):
|
||||||
|
|
||||||
|
|
||||||
def client_delete(client_id):
|
def client_delete(client_id):
|
||||||
client = Client.get(client_id=client_id)
|
client = models.Client.get(client_id=client_id)
|
||||||
|
|
||||||
if not client:
|
if not client:
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
import click
|
import click
|
||||||
|
from canaille.app import models
|
||||||
from canaille.app.commands import with_backendcontext
|
from canaille.app.commands import with_backendcontext
|
||||||
from canaille.oidc.models import AuthorizationCode
|
|
||||||
from canaille.oidc.models import Token
|
|
||||||
from flask.cli import with_appcontext
|
from flask.cli import with_appcontext
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,11 +11,11 @@ def clean():
|
||||||
"""
|
"""
|
||||||
Remove expired tokens and authorization codes.
|
Remove expired tokens and authorization codes.
|
||||||
"""
|
"""
|
||||||
for t in Token.query():
|
for t in models.Token.query():
|
||||||
if t.is_expired():
|
if t.is_expired():
|
||||||
t.delete()
|
t.delete()
|
||||||
|
|
||||||
for a in AuthorizationCode.query():
|
for a in models.AuthorizationCode.query():
|
||||||
if a.is_expired():
|
if a.is_expired():
|
||||||
a.delete()
|
a.delete()
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
import datetime
|
import datetime
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
from canaille.app import models
|
||||||
from canaille.app.flask import user_needed
|
from canaille.app.flask import user_needed
|
||||||
from canaille.oidc.models import Client
|
|
||||||
from canaille.oidc.models import Consent
|
|
||||||
from flask import Blueprint
|
from flask import Blueprint
|
||||||
from flask import flash
|
from flask import flash
|
||||||
from flask import redirect
|
from flask import redirect
|
||||||
|
@ -20,12 +19,14 @@ bp = Blueprint("consents", __name__, url_prefix="/consent")
|
||||||
@bp.route("/")
|
@bp.route("/")
|
||||||
@user_needed()
|
@user_needed()
|
||||||
def consents(user):
|
def consents(user):
|
||||||
consents = Consent.query(subject=user)
|
consents = models.Consent.query(subject=user)
|
||||||
clients = {t.client for t in consents}
|
clients = {t.client for t in consents}
|
||||||
|
|
||||||
nb_consents = len(consents)
|
nb_consents = len(consents)
|
||||||
nb_preconsents = sum(
|
nb_preconsents = sum(
|
||||||
1 for client in Client.query() if client.preconsent and client not in clients
|
1
|
||||||
|
for client in models.Client.query()
|
||||||
|
if client.preconsent and client not in clients
|
||||||
)
|
)
|
||||||
|
|
||||||
return render_template(
|
return render_template(
|
||||||
|
@ -42,11 +43,11 @@ def consents(user):
|
||||||
@bp.route("/pre-consents")
|
@bp.route("/pre-consents")
|
||||||
@user_needed()
|
@user_needed()
|
||||||
def pre_consents(user):
|
def pre_consents(user):
|
||||||
consents = Consent.query(subject=user)
|
consents = models.Consent.query(subject=user)
|
||||||
clients = {t.client for t in consents}
|
clients = {t.client for t in consents}
|
||||||
preconsented = [
|
preconsented = [
|
||||||
client
|
client
|
||||||
for client in Client.query()
|
for client in models.Client.query()
|
||||||
if client.preconsent and client not in clients
|
if client.preconsent and client not in clients
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -67,7 +68,7 @@ def pre_consents(user):
|
||||||
@bp.route("/revoke/<consent_id>")
|
@bp.route("/revoke/<consent_id>")
|
||||||
@user_needed()
|
@user_needed()
|
||||||
def revoke(user, consent_id):
|
def revoke(user, consent_id):
|
||||||
consent = Consent.get(consent_id=consent_id)
|
consent = models.Consent.get(consent_id=consent_id)
|
||||||
|
|
||||||
if not consent or consent.subject != user:
|
if not consent or consent.subject != user:
|
||||||
flash(_("Could not revoke this access"), "error")
|
flash(_("Could not revoke this access"), "error")
|
||||||
|
@ -85,7 +86,7 @@ def revoke(user, consent_id):
|
||||||
@bp.route("/restore/<consent_id>")
|
@bp.route("/restore/<consent_id>")
|
||||||
@user_needed()
|
@user_needed()
|
||||||
def restore(user, consent_id):
|
def restore(user, consent_id):
|
||||||
consent = Consent.get(consent_id=consent_id)
|
consent = models.Consent.get(consent_id=consent_id)
|
||||||
|
|
||||||
if not consent or consent.subject != user:
|
if not consent or consent.subject != user:
|
||||||
flash(_("Could not restore this access"), "error")
|
flash(_("Could not restore this access"), "error")
|
||||||
|
@ -106,19 +107,19 @@ def restore(user, consent_id):
|
||||||
@bp.route("/revoke-preconsent/<client_id>")
|
@bp.route("/revoke-preconsent/<client_id>")
|
||||||
@user_needed()
|
@user_needed()
|
||||||
def revoke_preconsent(user, client_id):
|
def revoke_preconsent(user, client_id):
|
||||||
client = Client.get(client_id=client_id)
|
client = models.Client.get(client_id=client_id)
|
||||||
|
|
||||||
if not client or not client.preconsent:
|
if not client or not client.preconsent:
|
||||||
flash(_("Could not revoke this access"), "error")
|
flash(_("Could not revoke this access"), "error")
|
||||||
return redirect(url_for("oidc.consents.consents"))
|
return redirect(url_for("oidc.consents.consents"))
|
||||||
|
|
||||||
consent = Consent.get(client=client, subject=user)
|
consent = models.Consent.get(client=client, subject=user)
|
||||||
if consent:
|
if consent:
|
||||||
return redirect(
|
return redirect(
|
||||||
url_for("oidc.consents.revoke", consent_id=consent.consent_id[0])
|
url_for("oidc.consents.revoke", consent_id=consent.consent_id[0])
|
||||||
)
|
)
|
||||||
|
|
||||||
consent = Consent(
|
consent = models.Consent(
|
||||||
consent_id=str(uuid.uuid4()),
|
consent_id=str(uuid.uuid4()),
|
||||||
client=client,
|
client=client,
|
||||||
subject=user,
|
subject=user,
|
||||||
|
|
|
@ -6,10 +6,10 @@ from authlib.jose import JsonWebKey
|
||||||
from authlib.jose import jwt
|
from authlib.jose import jwt
|
||||||
from authlib.oauth2 import OAuth2Error
|
from authlib.oauth2 import OAuth2Error
|
||||||
from canaille import csrf
|
from canaille import csrf
|
||||||
|
from canaille.app import models
|
||||||
from canaille.app.flask import current_user
|
from canaille.app.flask import current_user
|
||||||
from canaille.app.flask import set_parameter_in_url_query
|
from canaille.app.flask import set_parameter_in_url_query
|
||||||
from canaille.core.forms import FullLoginForm
|
from canaille.core.forms import FullLoginForm
|
||||||
from canaille.core.models import User
|
|
||||||
from flask import abort
|
from flask import abort
|
||||||
from flask import Blueprint
|
from flask import Blueprint
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
|
@ -25,8 +25,6 @@ from werkzeug.datastructures import CombinedMultiDict
|
||||||
|
|
||||||
from .forms import AuthorizeForm
|
from .forms import AuthorizeForm
|
||||||
from .forms import LogoutForm
|
from .forms import LogoutForm
|
||||||
from .models import Client
|
|
||||||
from .models import Consent
|
|
||||||
from .oauth import authorization
|
from .oauth import authorization
|
||||||
from .oauth import ClientConfigurationEndpoint
|
from .oauth import ClientConfigurationEndpoint
|
||||||
from .oauth import ClientRegistrationEndpoint
|
from .oauth import ClientRegistrationEndpoint
|
||||||
|
@ -59,7 +57,7 @@ def authorize():
|
||||||
if "client_id" not in request.args:
|
if "client_id" not in request.args:
|
||||||
abort(400)
|
abort(400)
|
||||||
|
|
||||||
client = Client.get(request.args["client_id"])
|
client = models.Client.get(client_id=request.args["client_id"])
|
||||||
if not client:
|
if not client:
|
||||||
abort(400)
|
abort(400)
|
||||||
|
|
||||||
|
@ -78,7 +76,7 @@ def authorize():
|
||||||
if request.method == "GET":
|
if request.method == "GET":
|
||||||
return render_template("login.html", form=form, menu=False)
|
return render_template("login.html", form=form, menu=False)
|
||||||
|
|
||||||
user = User.get_from_login(form.login.data)
|
user = models.User.get_from_login(form.login.data)
|
||||||
if (
|
if (
|
||||||
not form.validate()
|
not form.validate()
|
||||||
or not user
|
or not user
|
||||||
|
@ -96,7 +94,7 @@ def authorize():
|
||||||
|
|
||||||
# CONSENT
|
# CONSENT
|
||||||
|
|
||||||
consents = Consent.query(
|
consents = models.Consent.query(
|
||||||
client=client,
|
client=client,
|
||||||
subject=user,
|
subject=user,
|
||||||
)
|
)
|
||||||
|
@ -153,7 +151,7 @@ def authorize():
|
||||||
list(set(scopes + consents[0].scope))
|
list(set(scopes + consents[0].scope))
|
||||||
).split(" ")
|
).split(" ")
|
||||||
else:
|
else:
|
||||||
consent = Consent(
|
consent = models.Consent(
|
||||||
consent_id=str(uuid.uuid4()),
|
consent_id=str(uuid.uuid4()),
|
||||||
client=client,
|
client=client,
|
||||||
subject=user,
|
subject=user,
|
||||||
|
@ -275,7 +273,7 @@ def end_session():
|
||||||
valid_uris = []
|
valid_uris = []
|
||||||
|
|
||||||
if "client_id" in data:
|
if "client_id" in data:
|
||||||
client = Client.get(data["client_id"])
|
client = models.Client.get(client_id=data["client_id"])
|
||||||
if client:
|
if client:
|
||||||
valid_uris = client.post_logout_redirect_uris
|
valid_uris = client.post_logout_redirect_uris
|
||||||
|
|
||||||
|
@ -317,7 +315,7 @@ def end_session():
|
||||||
else [id_token["aud"]]
|
else [id_token["aud"]]
|
||||||
)
|
)
|
||||||
for client_id in client_ids:
|
for client_id in client_ids:
|
||||||
client = Client.get(client_id)
|
client = models.Client.get(client_id=client_id)
|
||||||
if client:
|
if client:
|
||||||
valid_uris.extend(client.post_logout_redirect_uris or [])
|
valid_uris.extend(client.post_logout_redirect_uris or [])
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import wtforms
|
import wtforms
|
||||||
|
from canaille.app import models
|
||||||
from canaille.app.forms import HTMXForm
|
from canaille.app.forms import HTMXForm
|
||||||
from canaille.app.forms import is_uri
|
from canaille.app.forms import is_uri
|
||||||
from canaille.oidc.models import Client
|
|
||||||
from flask_babel import lazy_gettext as _
|
from flask_babel import lazy_gettext as _
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ class LogoutForm(HTMXForm):
|
||||||
|
|
||||||
|
|
||||||
def client_audiences():
|
def client_audiences():
|
||||||
return [(client.id, client.client_name) for client in Client.query()]
|
return [(client.id, client.client_name) for client in models.Client.query()]
|
||||||
|
|
||||||
|
|
||||||
class ClientAddForm(HTMXForm):
|
class ClientAddForm(HTMXForm):
|
||||||
|
|
|
@ -1,11 +1,8 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from canaille.app import models
|
||||||
from canaille.backends.ldap.installation import install_schema
|
from canaille.backends.ldap.installation import install_schema
|
||||||
from canaille.backends.ldap.installation import ldap_connection
|
from canaille.backends.ldap.installation import ldap_connection
|
||||||
from canaille.oidc.models import AuthorizationCode
|
|
||||||
from canaille.oidc.models import Client
|
|
||||||
from canaille.oidc.models import Consent
|
|
||||||
from canaille.oidc.models import Token
|
|
||||||
from cryptography.hazmat.backends import default_backend as crypto_default_backend
|
from cryptography.hazmat.backends import default_backend as crypto_default_backend
|
||||||
from cryptography.hazmat.primitives import serialization as crypto_serialization
|
from cryptography.hazmat.primitives import serialization as crypto_serialization
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
@ -19,10 +16,10 @@ def install(config):
|
||||||
|
|
||||||
def setup_ldap_tree(config):
|
def setup_ldap_tree(config):
|
||||||
with ldap_connection(config) as conn:
|
with ldap_connection(config) as conn:
|
||||||
Token.install(conn)
|
models.Token.install(conn)
|
||||||
AuthorizationCode.install(conn)
|
models.AuthorizationCode.install(conn)
|
||||||
Client.install(conn)
|
models.Client.install(conn)
|
||||||
Consent.install(conn)
|
models.Consent.install(conn)
|
||||||
|
|
||||||
|
|
||||||
def setup_keypair(config):
|
def setup_keypair(config):
|
||||||
|
|
|
@ -4,6 +4,7 @@ from authlib.oauth2.rfc6749 import AuthorizationCodeMixin
|
||||||
from authlib.oauth2.rfc6749 import ClientMixin
|
from authlib.oauth2.rfc6749 import ClientMixin
|
||||||
from authlib.oauth2.rfc6749 import TokenMixin
|
from authlib.oauth2.rfc6749 import TokenMixin
|
||||||
from authlib.oauth2.rfc6749 import util
|
from authlib.oauth2.rfc6749 import util
|
||||||
|
from canaille.app import models
|
||||||
from canaille.backends.ldap.ldapobject import LDAPObject
|
from canaille.backends.ldap.ldapobject import LDAPObject
|
||||||
|
|
||||||
|
|
||||||
|
@ -97,13 +98,13 @@ class Client(LDAPObject, ClientMixin):
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
def delete(self):
|
def delete(self):
|
||||||
for consent in Consent.query(client=self):
|
for consent in models.Consent.query(client=self):
|
||||||
consent.delete()
|
consent.delete()
|
||||||
|
|
||||||
for code in AuthorizationCode.query(client=self):
|
for code in models.AuthorizationCode.query(client=self):
|
||||||
code.delete()
|
code.delete()
|
||||||
|
|
||||||
for token in Token.query(client=self):
|
for token in models.Token.query(client=self):
|
||||||
token.delete()
|
token.delete()
|
||||||
|
|
||||||
super().delete()
|
super().delete()
|
||||||
|
@ -243,7 +244,7 @@ class Consent(LDAPObject):
|
||||||
self.revokation_date = datetime.datetime.now(datetime.timezone.utc)
|
self.revokation_date = datetime.datetime.now(datetime.timezone.utc)
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
tokens = Token.query(
|
tokens = models.Token.query(
|
||||||
client=self.client,
|
client=self.client,
|
||||||
subject=self.subject,
|
subject=self.subject,
|
||||||
)
|
)
|
||||||
|
|
|
@ -26,15 +26,11 @@ from authlib.oidc.core.grants import OpenIDCode as _OpenIDCode
|
||||||
from authlib.oidc.core.grants import OpenIDHybridGrant as _OpenIDHybridGrant
|
from authlib.oidc.core.grants import OpenIDHybridGrant as _OpenIDHybridGrant
|
||||||
from authlib.oidc.core.grants import OpenIDImplicitGrant as _OpenIDImplicitGrant
|
from authlib.oidc.core.grants import OpenIDImplicitGrant as _OpenIDImplicitGrant
|
||||||
from authlib.oidc.core.grants.util import generate_id_token
|
from authlib.oidc.core.grants.util import generate_id_token
|
||||||
from canaille.core.models import User
|
from canaille.app import models
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
from flask import request
|
from flask import request
|
||||||
from werkzeug.security import gen_salt
|
from werkzeug.security import gen_salt
|
||||||
|
|
||||||
from .models import AuthorizationCode
|
|
||||||
from .models import Client
|
|
||||||
from .models import Token
|
|
||||||
|
|
||||||
DEFAULT_JWT_KTY = "RSA"
|
DEFAULT_JWT_KTY = "RSA"
|
||||||
DEFAULT_JWT_ALG = "RS256"
|
DEFAULT_JWT_ALG = "RS256"
|
||||||
DEFAULT_JWT_EXP = 3600
|
DEFAULT_JWT_EXP = 3600
|
||||||
|
@ -55,8 +51,8 @@ DEFAULT_JWT_MAPPING = {
|
||||||
|
|
||||||
|
|
||||||
def exists_nonce(nonce, req):
|
def exists_nonce(nonce, req):
|
||||||
client = Client.get(id=req.client_id)
|
client = models.Client.get(id=req.client_id)
|
||||||
exists = AuthorizationCode.query(client=client, nonce=nonce)
|
exists = models.AuthorizationCode.query(client=client, nonce=nonce)
|
||||||
return bool(exists)
|
return bool(exists)
|
||||||
|
|
||||||
|
|
||||||
|
@ -142,7 +138,7 @@ def save_authorization_code(code, request):
|
||||||
nonce = request.data.get("nonce")
|
nonce = request.data.get("nonce")
|
||||||
now = datetime.datetime.now(datetime.timezone.utc)
|
now = datetime.datetime.now(datetime.timezone.utc)
|
||||||
scope = request.client.get_allowed_scope(request.scope)
|
scope = request.client.get_allowed_scope(request.scope)
|
||||||
code = AuthorizationCode(
|
code = models.AuthorizationCode(
|
||||||
authorization_code_id=gen_salt(48),
|
authorization_code_id=gen_salt(48),
|
||||||
code=code,
|
code=code,
|
||||||
subject=request.user,
|
subject=request.user,
|
||||||
|
@ -166,7 +162,7 @@ class AuthorizationCodeGrant(_AuthorizationCodeGrant):
|
||||||
return save_authorization_code(code, request)
|
return save_authorization_code(code, request)
|
||||||
|
|
||||||
def query_authorization_code(self, code, client):
|
def query_authorization_code(self, code, client):
|
||||||
item = AuthorizationCode.query(code=code, client=client)
|
item = models.AuthorizationCode.query(code=code, client=client)
|
||||||
if item and not item[0].is_expired():
|
if item and not item[0].is_expired():
|
||||||
return item[0]
|
return item[0]
|
||||||
|
|
||||||
|
@ -196,7 +192,7 @@ class PasswordGrant(_ResourceOwnerPasswordCredentialsGrant):
|
||||||
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]
|
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]
|
||||||
|
|
||||||
def authenticate_user(self, username, password):
|
def authenticate_user(self, username, password):
|
||||||
user = User.get_from_login(username)
|
user = models.User.get_from_login(username)
|
||||||
|
|
||||||
if not user or not user.check_password(password):
|
if not user or not user.check_password(password):
|
||||||
return None
|
return None
|
||||||
|
@ -206,7 +202,7 @@ class PasswordGrant(_ResourceOwnerPasswordCredentialsGrant):
|
||||||
|
|
||||||
class RefreshTokenGrant(_RefreshTokenGrant):
|
class RefreshTokenGrant(_RefreshTokenGrant):
|
||||||
def authenticate_refresh_token(self, refresh_token):
|
def authenticate_refresh_token(self, refresh_token):
|
||||||
token = Token.query(refresh_token=refresh_token)
|
token = models.Token.query(refresh_token=refresh_token)
|
||||||
if token and token[0].is_refresh_token_active():
|
if token and token[0].is_refresh_token_active():
|
||||||
return token[0]
|
return token[0]
|
||||||
|
|
||||||
|
@ -252,12 +248,12 @@ class OpenIDHybridGrant(_OpenIDHybridGrant):
|
||||||
|
|
||||||
|
|
||||||
def query_client(client_id):
|
def query_client(client_id):
|
||||||
return Client.get(client_id=client_id)
|
return models.Client.get(client_id=client_id)
|
||||||
|
|
||||||
|
|
||||||
def save_token(token, request):
|
def save_token(token, request):
|
||||||
now = datetime.datetime.now(datetime.timezone.utc)
|
now = datetime.datetime.now(datetime.timezone.utc)
|
||||||
t = Token(
|
t = models.Token(
|
||||||
token_id=gen_salt(48),
|
token_id=gen_salt(48),
|
||||||
type=token["token_type"],
|
type=token["token_type"],
|
||||||
access_token=token["access_token"],
|
access_token=token["access_token"],
|
||||||
|
@ -274,20 +270,20 @@ def save_token(token, request):
|
||||||
|
|
||||||
class BearerTokenValidator(_BearerTokenValidator):
|
class BearerTokenValidator(_BearerTokenValidator):
|
||||||
def authenticate_token(self, token_string):
|
def authenticate_token(self, token_string):
|
||||||
return Token.get(access_token=token_string)
|
return models.Token.get(access_token=token_string)
|
||||||
|
|
||||||
|
|
||||||
def query_token(token, token_type_hint):
|
def query_token(token, token_type_hint):
|
||||||
if token_type_hint == "access_token":
|
if token_type_hint == "access_token":
|
||||||
return Token.get(access_token=token)
|
return models.Token.get(access_token=token)
|
||||||
elif token_type_hint == "refresh_token":
|
elif token_type_hint == "refresh_token":
|
||||||
return Token.get(refresh_token=token)
|
return models.Token.get(refresh_token=token)
|
||||||
|
|
||||||
item = Token.get(access_token=token)
|
item = models.Token.get(access_token=token)
|
||||||
if item:
|
if item:
|
||||||
return item
|
return item
|
||||||
|
|
||||||
item = Token.get(refresh_token=token)
|
item = models.Token.get(refresh_token=token)
|
||||||
if item:
|
if item:
|
||||||
return item
|
return item
|
||||||
|
|
||||||
|
@ -369,7 +365,7 @@ class ClientRegistrationEndpoint(ClientManagementMixin, _ClientRegistrationEndpo
|
||||||
client_metadata["scope"], list
|
client_metadata["scope"], list
|
||||||
):
|
):
|
||||||
client_metadata["scope"] = client_metadata["scope"].split(" ")
|
client_metadata["scope"] = client_metadata["scope"].split(" ")
|
||||||
client = Client(**client_info, **client_metadata)
|
client = models.Client(**client_info, **client_metadata)
|
||||||
client.save()
|
client.save()
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
@ -377,7 +373,7 @@ class ClientRegistrationEndpoint(ClientManagementMixin, _ClientRegistrationEndpo
|
||||||
class ClientConfigurationEndpoint(ClientManagementMixin, _ClientConfigurationEndpoint):
|
class ClientConfigurationEndpoint(ClientManagementMixin, _ClientConfigurationEndpoint):
|
||||||
def authenticate_client(self, request):
|
def authenticate_client(self, request):
|
||||||
client_id = request.uri.split("/")[-1]
|
client_id = request.uri.split("/")[-1]
|
||||||
return Client.get(client_id=client_id)
|
return models.Client.get(client_id=client_id)
|
||||||
|
|
||||||
def revoke_access_token(self, request, token):
|
def revoke_access_token(self, request, token):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
from canaille.app import models
|
||||||
from canaille.app.flask import permissions_needed
|
from canaille.app.flask import permissions_needed
|
||||||
from canaille.app.flask import render_htmx_template
|
from canaille.app.flask import render_htmx_template
|
||||||
from canaille.app.forms import TableForm
|
from canaille.app.forms import TableForm
|
||||||
from canaille.oidc.models import Token
|
|
||||||
from flask import abort
|
from flask import abort
|
||||||
from flask import Blueprint
|
from flask import Blueprint
|
||||||
from flask import flash
|
from flask import flash
|
||||||
|
@ -19,7 +19,7 @@ bp = Blueprint("tokens", __name__, url_prefix="/admin/token")
|
||||||
@bp.route("/", methods=["GET", "POST"])
|
@bp.route("/", methods=["GET", "POST"])
|
||||||
@permissions_needed("manage_oidc")
|
@permissions_needed("manage_oidc")
|
||||||
def index(user):
|
def index(user):
|
||||||
table_form = TableForm(Token, formdata=request.form)
|
table_form = TableForm(models.Token, formdata=request.form)
|
||||||
if request.form and request.form.get("page") and not table_form.validate():
|
if request.form and request.form.get("page") and not table_form.validate():
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ def index(user):
|
||||||
@bp.route("/<token_id>", methods=["GET", "POST"])
|
@bp.route("/<token_id>", methods=["GET", "POST"])
|
||||||
@permissions_needed("manage_oidc")
|
@permissions_needed("manage_oidc")
|
||||||
def view(user, token_id):
|
def view(user, token_id):
|
||||||
token = Token.get(token_id=token_id)
|
token = models.Token.get(token_id=token_id)
|
||||||
|
|
||||||
if not token:
|
if not token:
|
||||||
abort(404)
|
abort(404)
|
||||||
|
@ -46,7 +46,7 @@ def view(user, token_id):
|
||||||
@bp.route("/<token_id>/revoke", methods=["GET", "POST"])
|
@bp.route("/<token_id>/revoke", methods=["GET", "POST"])
|
||||||
@permissions_needed("manage_oidc")
|
@permissions_needed("manage_oidc")
|
||||||
def revoke(user, token_id):
|
def revoke(user, token_id):
|
||||||
token = Token.get(token_id=token_id)
|
token = models.Token.get(token_id=token_id)
|
||||||
|
|
||||||
if not token:
|
if not token:
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
|
@ -10,15 +10,13 @@ from canaille import create_app as canaille_app
|
||||||
|
|
||||||
|
|
||||||
def populate(app):
|
def populate(app):
|
||||||
from canaille.core.models import Group
|
from canaille.app import models
|
||||||
from canaille.core.models import User
|
|
||||||
from canaille.core.populate import fake_groups
|
from canaille.core.populate import fake_groups
|
||||||
from canaille.core.populate import fake_users
|
from canaille.core.populate import fake_users
|
||||||
from canaille.oidc.models import Client
|
|
||||||
|
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
with app.backend.session():
|
with app.backend.session():
|
||||||
jane = User(
|
jane = models.User(
|
||||||
formatted_name="Jane Doe",
|
formatted_name="Jane Doe",
|
||||||
given_name="Jane",
|
given_name="Jane",
|
||||||
family_name="Doe",
|
family_name="Doe",
|
||||||
|
@ -38,7 +36,7 @@ def populate(app):
|
||||||
)
|
)
|
||||||
jane.save()
|
jane.save()
|
||||||
|
|
||||||
jack = User(
|
jack = models.User(
|
||||||
formatted_name="Jack Doe",
|
formatted_name="Jack Doe",
|
||||||
given_name="Jack",
|
given_name="Jack",
|
||||||
family_name="Doe",
|
family_name="Doe",
|
||||||
|
@ -53,7 +51,7 @@ def populate(app):
|
||||||
)
|
)
|
||||||
jack.save()
|
jack.save()
|
||||||
|
|
||||||
john = User(
|
john = models.User(
|
||||||
formatted_name="John Doe",
|
formatted_name="John Doe",
|
||||||
given_name="John",
|
given_name="John",
|
||||||
family_name="Doe",
|
family_name="Doe",
|
||||||
|
@ -68,7 +66,7 @@ def populate(app):
|
||||||
)
|
)
|
||||||
john.save()
|
john.save()
|
||||||
|
|
||||||
james = User(
|
james = models.User(
|
||||||
formatted_name="James Doe",
|
formatted_name="James Doe",
|
||||||
given_name="James",
|
given_name="James",
|
||||||
family_name="Doe",
|
family_name="Doe",
|
||||||
|
@ -77,28 +75,28 @@ def populate(app):
|
||||||
)
|
)
|
||||||
james.save()
|
james.save()
|
||||||
|
|
||||||
users = Group(
|
users = models.Group(
|
||||||
display_name="users",
|
display_name="users",
|
||||||
members=[jane, jack, john, james],
|
members=[jane, jack, john, james],
|
||||||
description="The regular users.",
|
description="The regular users.",
|
||||||
)
|
)
|
||||||
users.save()
|
users.save()
|
||||||
|
|
||||||
users = Group(
|
users = models.Group(
|
||||||
display_name="admins",
|
display_name="admins",
|
||||||
members=[jane],
|
members=[jane],
|
||||||
description="The administrators.",
|
description="The administrators.",
|
||||||
)
|
)
|
||||||
users.save()
|
users.save()
|
||||||
|
|
||||||
users = Group(
|
users = models.Group(
|
||||||
display_name="moderators",
|
display_name="moderators",
|
||||||
members=[james],
|
members=[james],
|
||||||
description="People who can manage users.",
|
description="People who can manage users.",
|
||||||
)
|
)
|
||||||
users.save()
|
users.save()
|
||||||
|
|
||||||
client1 = Client(
|
client1 = models.Client(
|
||||||
client_id="1JGkkzCbeHpGtlqgI5EENByf",
|
client_id="1JGkkzCbeHpGtlqgI5EENByf",
|
||||||
client_secret="2xYPSReTQRmGG1yppMVZQ0ASXwFejPyirvuPbKhNa6TmKC5x",
|
client_secret="2xYPSReTQRmGG1yppMVZQ0ASXwFejPyirvuPbKhNa6TmKC5x",
|
||||||
client_name="Client1",
|
client_name="Client1",
|
||||||
|
@ -115,7 +113,7 @@ def populate(app):
|
||||||
)
|
)
|
||||||
client1.save()
|
client1.save()
|
||||||
|
|
||||||
client2 = Client(
|
client2 = models.Client(
|
||||||
client_id="gn4yFN7GDykL7QP8v8gS9YfV",
|
client_id="gn4yFN7GDykL7QP8v8gS9YfV",
|
||||||
client_secret="ouFJE5WpICt6hxTyf8icXPeeklMektMY4gV0Rmf3aY60VElA",
|
client_secret="ouFJE5WpICt6hxTyf8icXPeeklMektMY4gV0Rmf3aY60VElA",
|
||||||
client_name="Client2",
|
client_name="Client2",
|
||||||
|
|
|
@ -3,6 +3,7 @@ from unittest import mock
|
||||||
|
|
||||||
import ldap.dn
|
import ldap.dn
|
||||||
import pytest
|
import pytest
|
||||||
|
from canaille.app import models
|
||||||
from canaille.app.configuration import ConfigurationException
|
from canaille.app.configuration import ConfigurationException
|
||||||
from canaille.app.configuration import validate
|
from canaille.app.configuration import validate
|
||||||
from canaille.backends.ldap.backend import setup_ldap_models
|
from canaille.backends.ldap.backend import setup_ldap_models
|
||||||
|
@ -11,12 +12,10 @@ from canaille.backends.ldap.ldapobject import python_attrs_to_ldap
|
||||||
from canaille.backends.ldap.utils import ldap_to_python
|
from canaille.backends.ldap.utils import ldap_to_python
|
||||||
from canaille.backends.ldap.utils import python_to_ldap
|
from canaille.backends.ldap.utils import python_to_ldap
|
||||||
from canaille.backends.ldap.utils import Syntax
|
from canaille.backends.ldap.utils import Syntax
|
||||||
from canaille.core.models import Group
|
|
||||||
from canaille.core.models import User
|
|
||||||
|
|
||||||
|
|
||||||
def test_object_creation(app, backend):
|
def test_object_creation(app, backend):
|
||||||
user = User(
|
user = models.User(
|
||||||
formatted_name="Doe", # leading space
|
formatted_name="Doe", # leading space
|
||||||
family_name="Doe",
|
family_name="Doe",
|
||||||
user_name="user",
|
user_name="user",
|
||||||
|
@ -26,7 +25,7 @@ def test_object_creation(app, backend):
|
||||||
user.save()
|
user.save()
|
||||||
assert user.exists
|
assert user.exists
|
||||||
|
|
||||||
user = User.get(id=user.id)
|
user = models.User.get(id=user.id)
|
||||||
assert user.exists
|
assert user.exists
|
||||||
|
|
||||||
user.delete()
|
user.delete()
|
||||||
|
@ -38,7 +37,7 @@ def test_repr(backend, foo_group, user):
|
||||||
|
|
||||||
|
|
||||||
def test_dn_when_leading_space_in_id_attribute(backend):
|
def test_dn_when_leading_space_in_id_attribute(backend):
|
||||||
user = User(
|
user = models.User(
|
||||||
formatted_name=" Doe", # leading space
|
formatted_name=" Doe", # leading space
|
||||||
family_name="Doe",
|
family_name="Doe",
|
||||||
user_name="user",
|
user_name="user",
|
||||||
|
@ -54,7 +53,7 @@ def test_dn_when_leading_space_in_id_attribute(backend):
|
||||||
|
|
||||||
|
|
||||||
def test_dn_when_ldap_special_char_in_id_attribute(backend):
|
def test_dn_when_ldap_special_char_in_id_attribute(backend):
|
||||||
user = User(
|
user = models.User(
|
||||||
formatted_name="#Doe", # special char
|
formatted_name="#Doe", # special char
|
||||||
family_name="Doe",
|
family_name="Doe",
|
||||||
user_name="user",
|
user_name="user",
|
||||||
|
@ -70,27 +69,32 @@ def test_dn_when_ldap_special_char_in_id_attribute(backend):
|
||||||
|
|
||||||
|
|
||||||
def test_filter(backend, foo_group, bar_group):
|
def test_filter(backend, foo_group, bar_group):
|
||||||
assert Group.query(display_name="foo") == [foo_group]
|
assert models.Group.query(display_name="foo") == [foo_group]
|
||||||
assert Group.query(display_name="bar") == [bar_group]
|
assert models.Group.query(display_name="bar") == [bar_group]
|
||||||
|
|
||||||
assert Group.query(display_name="foo") != 3
|
assert models.Group.query(display_name="foo") != 3
|
||||||
|
|
||||||
assert Group.query(display_name=["foo"]) == [foo_group]
|
assert models.Group.query(display_name=["foo"]) == [foo_group]
|
||||||
assert Group.query(display_name=["bar"]) == [bar_group]
|
assert models.Group.query(display_name=["bar"]) == [bar_group]
|
||||||
|
|
||||||
assert set(Group.query(display_name=["foo", "bar"])) == {foo_group, bar_group}
|
assert set(models.Group.query(display_name=["foo", "bar"])) == {
|
||||||
|
foo_group,
|
||||||
|
bar_group,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_fuzzy(backend, user, moderator, admin):
|
def test_fuzzy(backend, user, moderator, admin):
|
||||||
assert set(User.query()) == {user, moderator, admin}
|
assert set(models.User.query()) == {user, moderator, admin}
|
||||||
assert set(User.fuzzy("Jack")) == {moderator}
|
assert set(models.User.fuzzy("Jack")) == {moderator}
|
||||||
assert set(User.fuzzy("Jack", ["formatted_name"])) == {moderator}
|
assert set(models.User.fuzzy("Jack", ["formatted_name"])) == {moderator}
|
||||||
assert set(User.fuzzy("Jack", ["user_name"])) == set()
|
assert set(models.User.fuzzy("Jack", ["user_name"])) == set()
|
||||||
assert set(User.fuzzy("Jack", ["user_name", "formatted_name"])) == {moderator}
|
assert set(models.User.fuzzy("Jack", ["user_name", "formatted_name"])) == {
|
||||||
assert set(User.fuzzy("moderator")) == {moderator}
|
moderator
|
||||||
assert set(User.fuzzy("oderat")) == {moderator}
|
}
|
||||||
assert set(User.fuzzy("oDeRat")) == {moderator}
|
assert set(models.User.fuzzy("moderator")) == {moderator}
|
||||||
assert set(User.fuzzy("ack")) == {moderator}
|
assert set(models.User.fuzzy("oderat")) == {moderator}
|
||||||
|
assert set(models.User.fuzzy("oDeRat")) == {moderator}
|
||||||
|
assert set(models.User.fuzzy("ack")) == {moderator}
|
||||||
|
|
||||||
|
|
||||||
def test_ldap_to_python():
|
def test_ldap_to_python():
|
||||||
|
@ -180,11 +184,11 @@ def test_guess_object_from_dn(backend, testclient, foo_group):
|
||||||
foo_group.members = [foo_group]
|
foo_group.members = [foo_group]
|
||||||
foo_group.save()
|
foo_group.save()
|
||||||
g = LDAPObject.get(id=foo_group.dn)
|
g = LDAPObject.get(id=foo_group.dn)
|
||||||
assert isinstance(g, Group)
|
assert isinstance(g, models.Group)
|
||||||
assert g == foo_group
|
assert g == foo_group
|
||||||
assert g.cn == foo_group.cn
|
assert g.cn == foo_group.cn
|
||||||
|
|
||||||
ou = LDAPObject.get(id=f"{Group.base},{Group.root_dn}")
|
ou = LDAPObject.get(id=f"{models.Group.base},{models.Group.root_dn}")
|
||||||
assert isinstance(ou, LDAPObject)
|
assert isinstance(ou, LDAPObject)
|
||||||
|
|
||||||
|
|
||||||
|
@ -192,11 +196,11 @@ def test_object_class_update(backend, testclient):
|
||||||
testclient.app.config["BACKENDS"]["LDAP"]["USER_CLASS"] = ["inetOrgPerson"]
|
testclient.app.config["BACKENDS"]["LDAP"]["USER_CLASS"] = ["inetOrgPerson"]
|
||||||
setup_ldap_models(testclient.app.config)
|
setup_ldap_models(testclient.app.config)
|
||||||
|
|
||||||
user1 = User(cn="foo1", sn="bar1")
|
user1 = models.User(cn="foo1", sn="bar1")
|
||||||
user1.save()
|
user1.save()
|
||||||
|
|
||||||
assert user1.objectClass == ["inetOrgPerson"]
|
assert user1.objectClass == ["inetOrgPerson"]
|
||||||
assert User.get(id=user1.id).objectClass == ["inetOrgPerson"]
|
assert models.User.get(id=user1.id).objectClass == ["inetOrgPerson"]
|
||||||
|
|
||||||
testclient.app.config["BACKENDS"]["LDAP"]["USER_CLASS"] = [
|
testclient.app.config["BACKENDS"]["LDAP"]["USER_CLASS"] = [
|
||||||
"inetOrgPerson",
|
"inetOrgPerson",
|
||||||
|
@ -204,18 +208,24 @@ def test_object_class_update(backend, testclient):
|
||||||
]
|
]
|
||||||
setup_ldap_models(testclient.app.config)
|
setup_ldap_models(testclient.app.config)
|
||||||
|
|
||||||
user2 = User(cn="foo2", sn="bar2")
|
user2 = models.User(cn="foo2", sn="bar2")
|
||||||
user2.save()
|
user2.save()
|
||||||
|
|
||||||
assert user2.objectClass == ["inetOrgPerson", "extensibleObject"]
|
assert user2.objectClass == ["inetOrgPerson", "extensibleObject"]
|
||||||
assert User.get(id=user2.id).objectClass == ["inetOrgPerson", "extensibleObject"]
|
assert models.User.get(id=user2.id).objectClass == [
|
||||||
|
"inetOrgPerson",
|
||||||
|
"extensibleObject",
|
||||||
|
]
|
||||||
|
|
||||||
user1 = User.get(id=user1.id)
|
user1 = models.User.get(id=user1.id)
|
||||||
assert user1.objectClass == ["inetOrgPerson"]
|
assert user1.objectClass == ["inetOrgPerson"]
|
||||||
|
|
||||||
user1.save()
|
user1.save()
|
||||||
assert user1.objectClass == ["inetOrgPerson", "extensibleObject"]
|
assert user1.objectClass == ["inetOrgPerson", "extensibleObject"]
|
||||||
assert User.get(id=user1.id).objectClass == ["inetOrgPerson", "extensibleObject"]
|
assert models.User.get(id=user1.id).objectClass == [
|
||||||
|
"inetOrgPerson",
|
||||||
|
"extensibleObject",
|
||||||
|
]
|
||||||
|
|
||||||
user1.delete()
|
user1.delete()
|
||||||
user2.delete()
|
user2.delete()
|
||||||
|
|
|
@ -1,21 +1,20 @@
|
||||||
from canaille.core.models import Group
|
from canaille.app import models
|
||||||
from canaille.core.models import User
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_comparison(testclient, backend):
|
def test_model_comparison(testclient, backend):
|
||||||
foo1 = User(
|
foo1 = models.User(
|
||||||
user_name="foo",
|
user_name="foo",
|
||||||
family_name="foo",
|
family_name="foo",
|
||||||
formatted_name="foo",
|
formatted_name="foo",
|
||||||
)
|
)
|
||||||
foo1.save()
|
foo1.save()
|
||||||
bar = User(
|
bar = models.User(
|
||||||
user_name="bar",
|
user_name="bar",
|
||||||
family_name="bar",
|
family_name="bar",
|
||||||
formatted_name="bar",
|
formatted_name="bar",
|
||||||
)
|
)
|
||||||
bar.save()
|
bar.save()
|
||||||
foo2 = User.get(id=foo1.id)
|
foo2 = models.User.get(id=foo1.id)
|
||||||
|
|
||||||
assert foo1 == foo2
|
assert foo1 == foo2
|
||||||
assert foo1 != bar
|
assert foo1 != bar
|
||||||
|
@ -25,23 +24,23 @@ def test_model_comparison(testclient, backend):
|
||||||
|
|
||||||
|
|
||||||
def test_model_lifecycle(testclient, backend):
|
def test_model_lifecycle(testclient, backend):
|
||||||
user = User(
|
user = models.User(
|
||||||
user_name="user_name",
|
user_name="user_name",
|
||||||
family_name="family_name",
|
family_name="family_name",
|
||||||
formatted_name="formatted_name",
|
formatted_name="formatted_name",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert not User.query()
|
assert not models.User.query()
|
||||||
assert not User.query(id=user.id)
|
assert not models.User.query(id=user.id)
|
||||||
assert not User.query(id="invalid")
|
assert not models.User.query(id="invalid")
|
||||||
assert not User.get(id=user.id)
|
assert not models.User.get(id=user.id)
|
||||||
|
|
||||||
user.save()
|
user.save()
|
||||||
|
|
||||||
assert User.query() == [user]
|
assert models.User.query() == [user]
|
||||||
assert User.query(id=user.id) == [user]
|
assert models.User.query(id=user.id) == [user]
|
||||||
assert not User.query(id="invalid")
|
assert not models.User.query(id="invalid")
|
||||||
assert User.get(id=user.id) == user
|
assert models.User.get(id=user.id) == user
|
||||||
|
|
||||||
user.family_name = "new_family_name"
|
user.family_name = "new_family_name"
|
||||||
|
|
||||||
|
@ -53,12 +52,12 @@ def test_model_lifecycle(testclient, backend):
|
||||||
|
|
||||||
user.delete()
|
user.delete()
|
||||||
|
|
||||||
assert not User.query(id=user.id)
|
assert not models.User.query(id=user.id)
|
||||||
assert not User.get(id=user.id)
|
assert not models.User.get(id=user.id)
|
||||||
|
|
||||||
|
|
||||||
def test_model_attribute_edition(testclient, backend):
|
def test_model_attribute_edition(testclient, backend):
|
||||||
user = User(
|
user = models.User(
|
||||||
user_name="user_name",
|
user_name="user_name",
|
||||||
family_name="family_name",
|
family_name="family_name",
|
||||||
formatted_name="formatted_name",
|
formatted_name="formatted_name",
|
||||||
|
@ -71,7 +70,7 @@ def test_model_attribute_edition(testclient, backend):
|
||||||
assert user.family_name == ["family_name"]
|
assert user.family_name == ["family_name"]
|
||||||
assert user.email == ["email1@user.com", "email2@user.com"]
|
assert user.email == ["email1@user.com", "email2@user.com"]
|
||||||
|
|
||||||
user = User.get(id=user.id)
|
user = models.User.get(id=user.id)
|
||||||
assert user.user_name == ["user_name"]
|
assert user.user_name == ["user_name"]
|
||||||
assert user.family_name == ["family_name"]
|
assert user.family_name == ["family_name"]
|
||||||
assert user.email == ["email1@user.com", "email2@user.com"]
|
assert user.email == ["email1@user.com", "email2@user.com"]
|
||||||
|
@ -83,7 +82,7 @@ def test_model_attribute_edition(testclient, backend):
|
||||||
assert user.family_name == ["new_family_name"]
|
assert user.family_name == ["new_family_name"]
|
||||||
assert user.email == ["email1@user.com"]
|
assert user.email == ["email1@user.com"]
|
||||||
|
|
||||||
user = User.get(id=user.id)
|
user = models.User.get(id=user.id)
|
||||||
assert user.family_name == ["new_family_name"]
|
assert user.family_name == ["new_family_name"]
|
||||||
assert user.email == ["email1@user.com"]
|
assert user.email == ["email1@user.com"]
|
||||||
|
|
||||||
|
@ -96,7 +95,7 @@ def test_model_attribute_edition(testclient, backend):
|
||||||
|
|
||||||
|
|
||||||
def test_model_indexation(testclient, backend):
|
def test_model_indexation(testclient, backend):
|
||||||
user = User(
|
user = models.User(
|
||||||
user_name="user_name",
|
user_name="user_name",
|
||||||
family_name="family_name",
|
family_name="family_name",
|
||||||
formatted_name="formatted_name",
|
formatted_name="formatted_name",
|
||||||
|
@ -104,56 +103,58 @@ def test_model_indexation(testclient, backend):
|
||||||
)
|
)
|
||||||
user.save()
|
user.save()
|
||||||
|
|
||||||
assert User.get(family_name="family_name") == user
|
assert models.User.get(family_name="family_name") == user
|
||||||
assert not User.get(family_name="new_family_name")
|
assert not models.User.get(family_name="new_family_name")
|
||||||
assert User.get(email="email1@user.com") == user
|
assert models.User.get(email="email1@user.com") == user
|
||||||
assert User.get(email="email2@user.com") == user
|
assert models.User.get(email="email2@user.com") == user
|
||||||
assert not User.get(email="email3@user.com")
|
assert not models.User.get(email="email3@user.com")
|
||||||
|
|
||||||
user.family_name = "new_family_name"
|
user.family_name = "new_family_name"
|
||||||
user.email = ["email2@user.com"]
|
user.email = ["email2@user.com"]
|
||||||
|
|
||||||
assert User.get(family_name="family_name") != user
|
assert models.User.get(family_name="family_name") != user
|
||||||
assert not User.get(family_name="new_family_name")
|
assert not models.User.get(family_name="new_family_name")
|
||||||
assert User.get(email="email1@user.com") != user
|
assert models.User.get(email="email1@user.com") != user
|
||||||
assert User.get(email="email2@user.com") != user
|
assert models.User.get(email="email2@user.com") != user
|
||||||
assert not User.get(email="email3@user.com")
|
assert not models.User.get(email="email3@user.com")
|
||||||
|
|
||||||
user.save()
|
user.save()
|
||||||
|
|
||||||
assert not User.get(family_name="family_name")
|
assert not models.User.get(family_name="family_name")
|
||||||
assert User.get(family_name="new_family_name") == user
|
assert models.User.get(family_name="new_family_name") == user
|
||||||
assert not User.get(email="email1@user.com")
|
assert not models.User.get(email="email1@user.com")
|
||||||
assert User.get(email="email2@user.com") == user
|
assert models.User.get(email="email2@user.com") == user
|
||||||
assert not User.get(email="email3@user.com")
|
assert not models.User.get(email="email3@user.com")
|
||||||
|
|
||||||
user.delete()
|
user.delete()
|
||||||
|
|
||||||
assert not User.get(family_name="family_name")
|
assert not models.User.get(family_name="family_name")
|
||||||
assert not User.get(family_name="new_family_name")
|
assert not models.User.get(family_name="new_family_name")
|
||||||
assert not User.get(email="email1@user.com")
|
assert not models.User.get(email="email1@user.com")
|
||||||
assert not User.get(email="email2@user.com")
|
assert not models.User.get(email="email2@user.com")
|
||||||
assert not User.get(email="email3@user.com")
|
assert not models.User.get(email="email3@user.com")
|
||||||
|
|
||||||
|
|
||||||
def test_fuzzy(user, moderator, admin, backend):
|
def test_fuzzy(user, moderator, admin, backend):
|
||||||
assert set(User.query()) == {user, moderator, admin}
|
assert set(models.User.query()) == {user, moderator, admin}
|
||||||
assert set(User.fuzzy("Jack")) == {moderator}
|
assert set(models.User.fuzzy("Jack")) == {moderator}
|
||||||
assert set(User.fuzzy("Jack", ["formatted_name"])) == {moderator}
|
assert set(models.User.fuzzy("Jack", ["formatted_name"])) == {moderator}
|
||||||
assert set(User.fuzzy("Jack", ["user_name"])) == set()
|
assert set(models.User.fuzzy("Jack", ["user_name"])) == set()
|
||||||
assert set(User.fuzzy("Jack", ["user_name", "formatted_name"])) == {moderator}
|
assert set(models.User.fuzzy("Jack", ["user_name", "formatted_name"])) == {
|
||||||
assert set(User.fuzzy("moderator")) == {moderator}
|
moderator
|
||||||
assert set(User.fuzzy("oderat")) == {moderator}
|
}
|
||||||
assert set(User.fuzzy("oDeRat")) == {moderator}
|
assert set(models.User.fuzzy("moderator")) == {moderator}
|
||||||
assert set(User.fuzzy("ack")) == {moderator}
|
assert set(models.User.fuzzy("oderat")) == {moderator}
|
||||||
|
assert set(models.User.fuzzy("oDeRat")) == {moderator}
|
||||||
|
assert set(models.User.fuzzy("ack")) == {moderator}
|
||||||
|
|
||||||
|
|
||||||
# def test_model_references(user, admin, foo_group, bar_group):
|
# def test_model_references(user, admin, foo_group, bar_group):
|
||||||
def test_model_references(testclient, user, foo_group, admin, bar_group, backend):
|
def test_model_references(testclient, user, foo_group, admin, bar_group, backend):
|
||||||
assert foo_group.members == [user]
|
assert foo_group.members == [user]
|
||||||
assert user.groups == [foo_group]
|
assert user.groups == [foo_group]
|
||||||
assert foo_group in Group.query(members=user)
|
assert foo_group in models.Group.query(members=user)
|
||||||
assert user in User.query(groups=foo_group)
|
assert user in models.User.query(groups=foo_group)
|
||||||
|
|
||||||
assert user not in bar_group.members
|
assert user not in bar_group.members
|
||||||
assert bar_group not in user.groups
|
assert bar_group not in user.groups
|
||||||
|
@ -175,11 +176,11 @@ def test_model_references(testclient, user, foo_group, admin, bar_group, backend
|
||||||
def test_model_references_set_unsaved_object(
|
def test_model_references_set_unsaved_object(
|
||||||
testclient, logged_moderator, user, backend
|
testclient, logged_moderator, user, backend
|
||||||
):
|
):
|
||||||
group = Group(members=[user], display_name="foo")
|
group = models.Group(members=[user], display_name="foo")
|
||||||
group.save()
|
group.save()
|
||||||
user.reload() # an LDAP group can be inconsistent by containing members which doesn't exist
|
user.reload() # an LDAP group can be inconsistent by containing members which doesn't exist
|
||||||
|
|
||||||
non_existent_user = User(formatted_name="foo", family_name="bar")
|
non_existent_user = models.User(formatted_name="foo", family_name="bar")
|
||||||
group.members = group.members + [non_existent_user]
|
group.members = group.members + [non_existent_user]
|
||||||
assert group.members == [user, non_existent_user]
|
assert group.members == [user, non_existent_user]
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
import pytest
|
import pytest
|
||||||
import slapd
|
import slapd
|
||||||
from canaille import create_app
|
from canaille import create_app
|
||||||
|
from canaille.app import models
|
||||||
from canaille.backends.ldap.backend import LDAPBackend
|
from canaille.backends.ldap.backend import LDAPBackend
|
||||||
from canaille.core.models import Group
|
|
||||||
from canaille.core.models import User
|
|
||||||
from flask_webtest import TestApp
|
from flask_webtest import TestApp
|
||||||
from werkzeug.security import gen_salt
|
from werkzeug.security import gen_salt
|
||||||
|
|
||||||
|
@ -156,7 +155,7 @@ def testclient(app):
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def user(app, backend):
|
def user(app, backend):
|
||||||
u = User(
|
u = models.User(
|
||||||
formatted_name="John (johnny) Doe",
|
formatted_name="John (johnny) Doe",
|
||||||
given_name="John",
|
given_name="John",
|
||||||
family_name="Doe",
|
family_name="Doe",
|
||||||
|
@ -176,7 +175,7 @@ def user(app, backend):
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def admin(app, backend):
|
def admin(app, backend):
|
||||||
u = User(
|
u = models.User(
|
||||||
formatted_name="Jane Doe",
|
formatted_name="Jane Doe",
|
||||||
family_name="Doe",
|
family_name="Doe",
|
||||||
user_name="admin",
|
user_name="admin",
|
||||||
|
@ -190,7 +189,7 @@ def admin(app, backend):
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def moderator(app, backend):
|
def moderator(app, backend):
|
||||||
u = User(
|
u = models.User(
|
||||||
formatted_name="Jack Doe",
|
formatted_name="Jack Doe",
|
||||||
family_name="Doe",
|
family_name="Doe",
|
||||||
user_name="moderator",
|
user_name="moderator",
|
||||||
|
@ -225,7 +224,7 @@ def logged_moderator(moderator, testclient):
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def foo_group(app, user, backend):
|
def foo_group(app, user, backend):
|
||||||
group = Group(
|
group = models.Group(
|
||||||
members=[user],
|
members=[user],
|
||||||
display_name="foo",
|
display_name="foo",
|
||||||
)
|
)
|
||||||
|
@ -237,7 +236,7 @@ def foo_group(app, user, backend):
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def bar_group(app, admin, backend):
|
def bar_group(app, admin, backend):
|
||||||
group = Group(
|
group = models.Group(
|
||||||
members=[admin],
|
members=[admin],
|
||||||
display_name="bar",
|
display_name="bar",
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,17 +1,16 @@
|
||||||
|
from canaille.app import models
|
||||||
from canaille.commands import cli
|
from canaille.commands import cli
|
||||||
from canaille.core.models import Group
|
|
||||||
from canaille.core.models import User
|
|
||||||
from canaille.core.populate import fake_users
|
from canaille.core.populate import fake_users
|
||||||
|
|
||||||
|
|
||||||
def test_populate_users(testclient, backend):
|
def test_populate_users(testclient, backend):
|
||||||
runner = testclient.app.test_cli_runner()
|
runner = testclient.app.test_cli_runner()
|
||||||
|
|
||||||
assert len(User.query()) == 0
|
assert len(models.User.query()) == 0
|
||||||
res = runner.invoke(cli, ["populate", "--nb", "10", "users"])
|
res = runner.invoke(cli, ["populate", "--nb", "10", "users"])
|
||||||
assert res.exit_code == 0, res.stdout
|
assert res.exit_code == 0, res.stdout
|
||||||
assert len(User.query()) == 10
|
assert len(models.User.query()) == 10
|
||||||
for user in User.query():
|
for user in models.User.query():
|
||||||
user.delete()
|
user.delete()
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,13 +18,13 @@ def test_populate_groups(testclient, backend):
|
||||||
fake_users(10)
|
fake_users(10)
|
||||||
runner = testclient.app.test_cli_runner()
|
runner = testclient.app.test_cli_runner()
|
||||||
|
|
||||||
assert len(Group.query()) == 0
|
assert len(models.Group.query()) == 0
|
||||||
res = runner.invoke(cli, ["populate", "--nb", "10", "groups"])
|
res = runner.invoke(cli, ["populate", "--nb", "10", "groups"])
|
||||||
assert res.exit_code == 0, res.stdout
|
assert res.exit_code == 0, res.stdout
|
||||||
assert len(Group.query()) == 10
|
assert len(models.Group.query()) == 10
|
||||||
|
|
||||||
for group in Group.query():
|
for group in models.Group.query():
|
||||||
group.delete()
|
group.delete()
|
||||||
|
|
||||||
for user in User.query():
|
for user in models.User.query():
|
||||||
user.delete()
|
user.delete()
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from canaille.core.models import User
|
from canaille.app import models
|
||||||
|
|
||||||
|
|
||||||
def test_index(testclient, user):
|
def test_index(testclient, user):
|
||||||
|
@ -112,7 +112,7 @@ def test_password_page_without_signin_in_redirects_to_login_page(testclient, use
|
||||||
|
|
||||||
def test_user_without_password_first_login(testclient, backend, smtpd):
|
def test_user_without_password_first_login(testclient, backend, smtpd):
|
||||||
assert len(smtpd.messages) == 0
|
assert len(smtpd.messages) == 0
|
||||||
u = User(
|
u = models.User(
|
||||||
formatted_name="Temp User",
|
formatted_name="Temp User",
|
||||||
family_name="Temp",
|
family_name="Temp",
|
||||||
user_name="temp",
|
user_name="temp",
|
||||||
|
@ -146,7 +146,7 @@ def test_first_login_account_initialization_mail_sending_failed(
|
||||||
SMTP.side_effect = mock.Mock(side_effect=OSError("unit test mail error"))
|
SMTP.side_effect = mock.Mock(side_effect=OSError("unit test mail error"))
|
||||||
assert len(smtpd.messages) == 0
|
assert len(smtpd.messages) == 0
|
||||||
|
|
||||||
u = User(
|
u = models.User(
|
||||||
formatted_name="Temp User",
|
formatted_name="Temp User",
|
||||||
family_name="Temp",
|
family_name="Temp",
|
||||||
user_name="temp",
|
user_name="temp",
|
||||||
|
@ -168,7 +168,7 @@ def test_first_login_account_initialization_mail_sending_failed(
|
||||||
|
|
||||||
def test_first_login_form_error(testclient, backend, smtpd):
|
def test_first_login_form_error(testclient, backend, smtpd):
|
||||||
assert len(smtpd.messages) == 0
|
assert len(smtpd.messages) == 0
|
||||||
u = User(
|
u = models.User(
|
||||||
formatted_name="Temp User",
|
formatted_name="Temp User",
|
||||||
family_name="Temp",
|
family_name="Temp",
|
||||||
user_name="temp",
|
user_name="temp",
|
||||||
|
@ -190,7 +190,7 @@ def test_first_login_page_unavailable_for_users_with_password(
|
||||||
|
|
||||||
|
|
||||||
def test_user_password_deleted_during_login(testclient, backend):
|
def test_user_password_deleted_during_login(testclient, backend):
|
||||||
u = User(
|
u = models.User(
|
||||||
formatted_name="Temp User",
|
formatted_name="Temp User",
|
||||||
family_name="Temp",
|
family_name="Temp",
|
||||||
user_name="temp",
|
user_name="temp",
|
||||||
|
@ -214,7 +214,7 @@ def test_user_password_deleted_during_login(testclient, backend):
|
||||||
|
|
||||||
|
|
||||||
def test_user_deleted_in_session(testclient, backend):
|
def test_user_deleted_in_session(testclient, backend):
|
||||||
u = User(
|
u = models.User(
|
||||||
formatted_name="Jake Doe",
|
formatted_name="Jake Doe",
|
||||||
family_name="Jake",
|
family_name="Jake",
|
||||||
user_name="jake",
|
user_name="jake",
|
||||||
|
@ -278,7 +278,7 @@ def test_wrong_login(testclient, user):
|
||||||
|
|
||||||
|
|
||||||
def test_admin_self_deletion(testclient, backend):
|
def test_admin_self_deletion(testclient, backend):
|
||||||
admin = User(
|
admin = models.User(
|
||||||
formatted_name="Temp admin",
|
formatted_name="Temp admin",
|
||||||
family_name="admin",
|
family_name="admin",
|
||||||
user_name="temp",
|
user_name="temp",
|
||||||
|
@ -296,14 +296,14 @@ def test_admin_self_deletion(testclient, backend):
|
||||||
.follow(status=200)
|
.follow(status=200)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert User.get_from_login("temp") is None
|
assert models.User.get_from_login("temp") is None
|
||||||
|
|
||||||
with testclient.session_transaction() as sess:
|
with testclient.session_transaction() as sess:
|
||||||
assert not sess.get("user_id")
|
assert not sess.get("user_id")
|
||||||
|
|
||||||
|
|
||||||
def test_user_self_deletion(testclient, backend):
|
def test_user_self_deletion(testclient, backend):
|
||||||
user = User(
|
user = models.User(
|
||||||
formatted_name="Temp user",
|
formatted_name="Temp user",
|
||||||
family_name="user",
|
family_name="user",
|
||||||
user_name="temp",
|
user_name="temp",
|
||||||
|
@ -330,7 +330,7 @@ def test_user_self_deletion(testclient, backend):
|
||||||
.follow(status=200)
|
.follow(status=200)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert User.get_from_login("temp") is None
|
assert models.User.get_from_login("temp") is None
|
||||||
|
|
||||||
with testclient.session_transaction() as sess:
|
with testclient.session_transaction() as sess:
|
||||||
assert not sess.get("user_id")
|
assert not sess.get("user_id")
|
||||||
|
|
|
@ -1,11 +1,10 @@
|
||||||
from canaille.core.models import Group
|
from canaille.app import models
|
||||||
from canaille.core.models import User
|
|
||||||
from canaille.core.populate import fake_groups
|
from canaille.core.populate import fake_groups
|
||||||
from canaille.core.populate import fake_users
|
from canaille.core.populate import fake_users
|
||||||
|
|
||||||
|
|
||||||
def test_no_group(app, backend):
|
def test_no_group(app, backend):
|
||||||
assert Group.query() == []
|
assert models.Group.query() == []
|
||||||
|
|
||||||
|
|
||||||
def test_group_list_pagination(testclient, logged_admin, foo_group):
|
def test_group_list_pagination(testclient, logged_admin, foo_group):
|
||||||
|
@ -49,7 +48,7 @@ def test_group_list_bad_pages(testclient, logged_admin):
|
||||||
|
|
||||||
|
|
||||||
def test_group_deletion(testclient, backend):
|
def test_group_deletion(testclient, backend):
|
||||||
user = User(
|
user = models.User(
|
||||||
formatted_name="foobar",
|
formatted_name="foobar",
|
||||||
family_name="foobar",
|
family_name="foobar",
|
||||||
user_name="foobar",
|
user_name="foobar",
|
||||||
|
@ -57,7 +56,7 @@ def test_group_deletion(testclient, backend):
|
||||||
)
|
)
|
||||||
user.save()
|
user.save()
|
||||||
|
|
||||||
group = Group(
|
group = models.Group(
|
||||||
members=[user],
|
members=[user],
|
||||||
display_name="foobar",
|
display_name="foobar",
|
||||||
)
|
)
|
||||||
|
@ -109,7 +108,7 @@ def test_set_groups(app, user, foo_group, bar_group):
|
||||||
|
|
||||||
|
|
||||||
def test_set_groups_with_leading_space_in_user_id_attribute(app, foo_group):
|
def test_set_groups_with_leading_space_in_user_id_attribute(app, foo_group):
|
||||||
user = User(
|
user = models.User(
|
||||||
formatted_name=" Doe", # leading space in id attribute
|
formatted_name=" Doe", # leading space in id attribute
|
||||||
family_name="Doe",
|
family_name="Doe",
|
||||||
user_name="user2",
|
user_name="user2",
|
||||||
|
@ -137,8 +136,8 @@ def test_moderator_can_create_edit_and_delete_group(
|
||||||
):
|
):
|
||||||
# The group does not exist
|
# The group does not exist
|
||||||
res = testclient.get("/groups", status=200)
|
res = testclient.get("/groups", status=200)
|
||||||
assert Group.get(display_name="bar") is None
|
assert models.Group.get(display_name="bar") is None
|
||||||
assert Group.get(display_name="foo") == foo_group
|
assert models.Group.get(display_name="foo") == foo_group
|
||||||
res.mustcontain(no="bar")
|
res.mustcontain(no="bar")
|
||||||
res.mustcontain("foo")
|
res.mustcontain("foo")
|
||||||
|
|
||||||
|
@ -152,7 +151,7 @@ def test_moderator_can_create_edit_and_delete_group(
|
||||||
res = form.submit(status=302).follow(status=200)
|
res = form.submit(status=302).follow(status=200)
|
||||||
|
|
||||||
logged_moderator.reload()
|
logged_moderator.reload()
|
||||||
bar_group = Group.get(display_name="bar")
|
bar_group = models.Group.get(display_name="bar")
|
||||||
assert bar_group.display_name == "bar"
|
assert bar_group.display_name == "bar"
|
||||||
assert bar_group.description == ["yolo"]
|
assert bar_group.description == ["yolo"]
|
||||||
assert bar_group.members == [
|
assert bar_group.members == [
|
||||||
|
@ -168,17 +167,17 @@ def test_moderator_can_create_edit_and_delete_group(
|
||||||
|
|
||||||
res = form.submit(name="action", value="edit").follow()
|
res = form.submit(name="action", value="edit").follow()
|
||||||
|
|
||||||
bar_group = Group.get(display_name="bar")
|
bar_group = models.Group.get(display_name="bar")
|
||||||
assert bar_group.display_name == "bar"
|
assert bar_group.display_name == "bar"
|
||||||
assert bar_group.description == ["yolo2"]
|
assert bar_group.description == ["yolo2"]
|
||||||
assert Group.get(display_name="bar2") is None
|
assert models.Group.get(display_name="bar2") is None
|
||||||
members = bar_group.members
|
members = bar_group.members
|
||||||
for member in members:
|
for member in members:
|
||||||
res.mustcontain(member.formatted_name[0])
|
res.mustcontain(member.formatted_name[0])
|
||||||
|
|
||||||
# Group is deleted
|
# Group is deleted
|
||||||
res = form.submit(name="action", value="delete", status=302)
|
res = form.submit(name="action", value="delete", status=302)
|
||||||
assert Group.get(display_name="bar") is None
|
assert models.Group.get(display_name="bar") is None
|
||||||
assert ("success", "The group bar has been sucessfully deleted") in res.flashes
|
assert ("success", "The group bar has been sucessfully deleted") in res.flashes
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
from canaille.app import models
|
||||||
from canaille.core.account import Invitation
|
from canaille.core.account import Invitation
|
||||||
from canaille.core.models import User
|
|
||||||
|
|
||||||
|
|
||||||
def test_invitation(testclient, logged_admin, foo_group, smtpd):
|
def test_invitation(testclient, logged_admin, foo_group, smtpd):
|
||||||
assert User.get_from_login("someone") is None
|
assert models.User.get_from_login("someone") is None
|
||||||
|
|
||||||
res = testclient.get("/invite", status=200)
|
res = testclient.get("/invite", status=200)
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ def test_invitation(testclient, logged_admin, foo_group, smtpd):
|
||||||
assert ("success", "Your account has been created successfuly.") in res.flashes
|
assert ("success", "Your account has been created successfuly.") in res.flashes
|
||||||
res = res.follow(status=200)
|
res = res.follow(status=200)
|
||||||
|
|
||||||
user = User.get_from_login("someone")
|
user = models.User.get_from_login("someone")
|
||||||
foo_group.reload()
|
foo_group.reload()
|
||||||
assert user.check_password("whatever")
|
assert user.check_password("whatever")
|
||||||
assert user.groups == [foo_group]
|
assert user.groups == [foo_group]
|
||||||
|
@ -53,8 +53,8 @@ def test_invitation(testclient, logged_admin, foo_group, smtpd):
|
||||||
|
|
||||||
|
|
||||||
def test_invitation_editable_user_name(testclient, logged_admin, foo_group, smtpd):
|
def test_invitation_editable_user_name(testclient, logged_admin, foo_group, smtpd):
|
||||||
assert User.get_from_login("jackyjack") is None
|
assert models.User.get_from_login("jackyjack") is None
|
||||||
assert User.get_from_login("djorje") is None
|
assert models.User.get_from_login("djorje") is None
|
||||||
|
|
||||||
res = testclient.get("/invite", status=200)
|
res = testclient.get("/invite", status=200)
|
||||||
|
|
||||||
|
@ -89,7 +89,7 @@ def test_invitation_editable_user_name(testclient, logged_admin, foo_group, smtp
|
||||||
assert ("success", "Your account has been created successfuly.") in res.flashes
|
assert ("success", "Your account has been created successfuly.") in res.flashes
|
||||||
res = res.follow(status=200)
|
res = res.follow(status=200)
|
||||||
|
|
||||||
user = User.get_from_login("djorje")
|
user = models.User.get_from_login("djorje")
|
||||||
foo_group.reload()
|
foo_group.reload()
|
||||||
assert user.check_password("whatever")
|
assert user.check_password("whatever")
|
||||||
assert user.groups == [foo_group]
|
assert user.groups == [foo_group]
|
||||||
|
@ -101,7 +101,7 @@ def test_invitation_editable_user_name(testclient, logged_admin, foo_group, smtp
|
||||||
|
|
||||||
|
|
||||||
def test_generate_link(testclient, logged_admin, foo_group, smtpd):
|
def test_generate_link(testclient, logged_admin, foo_group, smtpd):
|
||||||
assert User.get_from_login("sometwo") is None
|
assert models.User.get_from_login("sometwo") is None
|
||||||
|
|
||||||
res = testclient.get("/invite", status=200)
|
res = testclient.get("/invite", status=200)
|
||||||
|
|
||||||
|
@ -131,7 +131,7 @@ def test_generate_link(testclient, logged_admin, foo_group, smtpd):
|
||||||
res = res.form.submit(status=302)
|
res = res.form.submit(status=302)
|
||||||
res = res.follow(status=200)
|
res = res.follow(status=200)
|
||||||
|
|
||||||
user = User.get_from_login("sometwo")
|
user = models.User.get_from_login("sometwo")
|
||||||
foo_group.reload()
|
foo_group.reload()
|
||||||
assert user.check_password("whatever")
|
assert user.check_password("whatever")
|
||||||
assert user.groups == [foo_group]
|
assert user.groups == [foo_group]
|
||||||
|
@ -228,7 +228,7 @@ def test_registration_no_password(testclient, foo_group):
|
||||||
res = res.form.submit(status=200)
|
res = res.form.submit(status=200)
|
||||||
res.mustcontain("This field is required.")
|
res.mustcontain("This field is required.")
|
||||||
|
|
||||||
assert not User.get_from_login("someoneelse")
|
assert not models.User.get_from_login("someoneelse")
|
||||||
|
|
||||||
with testclient.session_transaction() as sess:
|
with testclient.session_transaction() as sess:
|
||||||
assert "user_id" not in sess
|
assert "user_id" not in sess
|
||||||
|
@ -295,7 +295,7 @@ def test_groups_are_saved_even_when_user_does_not_have_read_permission(
|
||||||
res = res.form.submit(status=302)
|
res = res.form.submit(status=302)
|
||||||
res = res.follow(status=200)
|
res = res.follow(status=200)
|
||||||
|
|
||||||
user = User.get_from_login("someoneelse")
|
user = models.User.get_from_login("someoneelse")
|
||||||
foo_group.reload()
|
foo_group.reload()
|
||||||
assert user.groups == [foo_group]
|
assert user.groups == [foo_group]
|
||||||
user.delete()
|
user.delete()
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
from canaille.core.models import User
|
from canaille.app import models
|
||||||
|
|
||||||
|
|
||||||
def test_user_get_from_login(testclient, user, backend):
|
def test_user_get_from_login(testclient, user, backend):
|
||||||
assert User.get_from_login(login="invalid") is None
|
assert models.User.get_from_login(login="invalid") is None
|
||||||
assert User.get_from_login(login="user") == user
|
assert models.User.get_from_login(login="user") == user
|
||||||
|
|
||||||
|
|
||||||
def test_user_has_password(testclient, backend):
|
def test_user_has_password(testclient, backend):
|
||||||
u = User(
|
u = models.User(
|
||||||
formatted_name="Temp User",
|
formatted_name="Temp User",
|
||||||
family_name="Temp",
|
family_name="Temp",
|
||||||
user_name="temp",
|
user_name="temp",
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from canaille.core.models import User
|
from canaille.app import models
|
||||||
|
|
||||||
|
|
||||||
def test_user_creation_edition_and_deletion(
|
def test_user_creation_edition_and_deletion(
|
||||||
|
@ -6,7 +6,7 @@ def test_user_creation_edition_and_deletion(
|
||||||
):
|
):
|
||||||
# The user does not exist.
|
# The user does not exist.
|
||||||
res = testclient.get("/users", status=200)
|
res = testclient.get("/users", status=200)
|
||||||
assert User.get_from_login("george") is None
|
assert models.User.get_from_login("george") is None
|
||||||
res.mustcontain(no="george")
|
res.mustcontain(no="george")
|
||||||
|
|
||||||
# Fill the profile for a new user.
|
# Fill the profile for a new user.
|
||||||
|
@ -24,7 +24,7 @@ def test_user_creation_edition_and_deletion(
|
||||||
res = res.form.submit(name="action", value="edit", status=302)
|
res = res.form.submit(name="action", value="edit", status=302)
|
||||||
assert ("success", "User account creation succeed.") in res.flashes
|
assert ("success", "User account creation succeed.") in res.flashes
|
||||||
res = res.follow(status=200)
|
res = res.follow(status=200)
|
||||||
george = User.get_from_login("george")
|
george = models.User.get_from_login("george")
|
||||||
foo_group.reload()
|
foo_group.reload()
|
||||||
assert "George" == george.given_name[0]
|
assert "George" == george.given_name[0]
|
||||||
assert george.groups == [foo_group]
|
assert george.groups == [foo_group]
|
||||||
|
@ -45,7 +45,7 @@ def test_user_creation_edition_and_deletion(
|
||||||
res.form["groups"] = [foo_group.id, bar_group.id]
|
res.form["groups"] = [foo_group.id, bar_group.id]
|
||||||
res = res.form.submit(name="action", value="edit").follow()
|
res = res.form.submit(name="action", value="edit").follow()
|
||||||
|
|
||||||
george = User.get_from_login("george")
|
george = models.User.get_from_login("george")
|
||||||
assert "Georgio" == george.given_name[0]
|
assert "Georgio" == george.given_name[0]
|
||||||
assert george.check_password("totoyolo")
|
assert george.check_password("totoyolo")
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ def test_user_creation_edition_and_deletion(
|
||||||
# User have been deleted.
|
# User have been deleted.
|
||||||
res = testclient.get("/profile/george/settings", status=200)
|
res = testclient.get("/profile/george/settings", status=200)
|
||||||
res = res.form.submit(name="action", value="delete", status=302).follow(status=200)
|
res = res.form.submit(name="action", value="delete", status=302).follow(status=200)
|
||||||
assert User.get_from_login("george") is None
|
assert models.User.get_from_login("george") is None
|
||||||
res.mustcontain(no="george")
|
res.mustcontain(no="george")
|
||||||
|
|
||||||
|
|
||||||
|
@ -89,7 +89,7 @@ def test_user_creation_without_password(testclient, logged_moderator):
|
||||||
res = res.form.submit(name="action", value="edit", status=302)
|
res = res.form.submit(name="action", value="edit", status=302)
|
||||||
assert ("success", "User account creation succeed.") in res.flashes
|
assert ("success", "User account creation succeed.") in res.flashes
|
||||||
res = res.follow(status=200)
|
res = res.follow(status=200)
|
||||||
george = User.get_from_login("george")
|
george = models.User.get_from_login("george")
|
||||||
assert george.user_name[0] == "george"
|
assert george.user_name[0] == "george"
|
||||||
assert not george.has_password()
|
assert not george.has_password()
|
||||||
|
|
||||||
|
@ -100,13 +100,13 @@ def test_user_creation_form_validation_failed(
|
||||||
testclient, logged_moderator, foo_group, bar_group
|
testclient, logged_moderator, foo_group, bar_group
|
||||||
):
|
):
|
||||||
res = testclient.get("/users", status=200)
|
res = testclient.get("/users", status=200)
|
||||||
assert User.get_from_login("george") is None
|
assert models.User.get_from_login("george") is None
|
||||||
res.mustcontain(no="george")
|
res.mustcontain(no="george")
|
||||||
|
|
||||||
res = testclient.get("/profile", status=200)
|
res = testclient.get("/profile", status=200)
|
||||||
res = res.form.submit(name="action", value="edit")
|
res = res.form.submit(name="action", value="edit")
|
||||||
assert ("error", "User account creation failed.") in res.flashes
|
assert ("error", "User account creation failed.") in res.flashes
|
||||||
assert User.get_from_login("george") is None
|
assert models.User.get_from_login("george") is None
|
||||||
|
|
||||||
|
|
||||||
def test_username_already_taken(
|
def test_username_already_taken(
|
||||||
|
@ -140,7 +140,7 @@ def test_cn_setting_with_given_name_and_surname(testclient, logged_moderator):
|
||||||
|
|
||||||
res = res.form.submit(name="action", value="edit", status=302).follow(status=200)
|
res = res.form.submit(name="action", value="edit", status=302).follow(status=200)
|
||||||
|
|
||||||
george = User.get_from_login("george")
|
george = models.User.get_from_login("george")
|
||||||
assert george.formatted_name[0] == "George Abitbol"
|
assert george.formatted_name[0] == "George Abitbol"
|
||||||
george.delete()
|
george.delete()
|
||||||
|
|
||||||
|
@ -153,6 +153,6 @@ def test_cn_setting_with_surname_only(testclient, logged_moderator):
|
||||||
|
|
||||||
res = res.form.submit(name="action", value="edit", status=302).follow(status=200)
|
res = res.form.submit(name="action", value="edit", status=302).follow(status=200)
|
||||||
|
|
||||||
george = User.get_from_login("george")
|
george = models.User.get_from_login("george")
|
||||||
assert george.formatted_name[0] == "Abitbol"
|
assert george.formatted_name[0] == "Abitbol"
|
||||||
george.delete()
|
george.delete()
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from canaille.core.models import User
|
from canaille.app import models
|
||||||
from webtest import Upload
|
from webtest import Upload
|
||||||
|
|
||||||
|
|
||||||
|
@ -101,7 +101,7 @@ def test_photo_on_profile_edition(
|
||||||
|
|
||||||
def test_photo_on_profile_creation(testclient, jpeg_photo, logged_admin):
|
def test_photo_on_profile_creation(testclient, jpeg_photo, logged_admin):
|
||||||
res = testclient.get("/users", status=200)
|
res = testclient.get("/users", status=200)
|
||||||
assert User.get_from_login("foobar") is None
|
assert models.User.get_from_login("foobar") is None
|
||||||
res.mustcontain(no="foobar")
|
res.mustcontain(no="foobar")
|
||||||
|
|
||||||
res = testclient.get("/profile", status=200)
|
res = testclient.get("/profile", status=200)
|
||||||
|
@ -111,14 +111,14 @@ def test_photo_on_profile_creation(testclient, jpeg_photo, logged_admin):
|
||||||
res.form["email"] = "george@abitbol.com"
|
res.form["email"] = "george@abitbol.com"
|
||||||
res = res.form.submit(name="action", value="edit", status=302).follow(status=200)
|
res = res.form.submit(name="action", value="edit", status=302).follow(status=200)
|
||||||
|
|
||||||
user = User.get_from_login("foobar")
|
user = models.User.get_from_login("foobar")
|
||||||
assert user.photo == [jpeg_photo]
|
assert user.photo == [jpeg_photo]
|
||||||
user.delete()
|
user.delete()
|
||||||
|
|
||||||
|
|
||||||
def test_photo_deleted_on_profile_creation(testclient, jpeg_photo, logged_admin):
|
def test_photo_deleted_on_profile_creation(testclient, jpeg_photo, logged_admin):
|
||||||
res = testclient.get("/users", status=200)
|
res = testclient.get("/users", status=200)
|
||||||
assert User.get_from_login("foobar") is None
|
assert models.User.get_from_login("foobar") is None
|
||||||
res.mustcontain(no="foobar")
|
res.mustcontain(no="foobar")
|
||||||
|
|
||||||
res = testclient.get("/profile", status=200)
|
res = testclient.get("/profile", status=200)
|
||||||
|
@ -129,6 +129,6 @@ def test_photo_deleted_on_profile_creation(testclient, jpeg_photo, logged_admin)
|
||||||
res.form["email"] = "george@abitbol.com"
|
res.form["email"] = "george@abitbol.com"
|
||||||
res = res.form.submit(name="action", value="edit", status=302).follow(status=200)
|
res = res.form.submit(name="action", value="edit", status=302).follow(status=200)
|
||||||
|
|
||||||
user = User.get_from_login("foobar")
|
user = models.User.get_from_login("foobar")
|
||||||
assert user.photo == []
|
assert user.photo == []
|
||||||
user.delete()
|
user.delete()
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from canaille.core.models import User
|
from canaille.app import models
|
||||||
|
|
||||||
|
|
||||||
def test_edition(
|
def test_edition(
|
||||||
|
@ -126,7 +126,7 @@ def test_password_change_fail(testclient, logged_user):
|
||||||
|
|
||||||
|
|
||||||
def test_password_initialization_mail(smtpd, testclient, backend, logged_admin):
|
def test_password_initialization_mail(smtpd, testclient, backend, logged_admin):
|
||||||
u = User(
|
u = models.User(
|
||||||
formatted_name="Temp User",
|
formatted_name="Temp User",
|
||||||
family_name="Temp",
|
family_name="Temp",
|
||||||
user_name="temp",
|
user_name="temp",
|
||||||
|
@ -162,7 +162,7 @@ def test_password_initialization_mail_send_fail(
|
||||||
SMTP, smtpd, testclient, backend, logged_admin
|
SMTP, smtpd, testclient, backend, logged_admin
|
||||||
):
|
):
|
||||||
SMTP.side_effect = mock.Mock(side_effect=OSError("unit test mail error"))
|
SMTP.side_effect = mock.Mock(side_effect=OSError("unit test mail error"))
|
||||||
u = User(
|
u = models.User(
|
||||||
formatted_name="Temp User",
|
formatted_name="Temp User",
|
||||||
family_name="Temp",
|
family_name="Temp",
|
||||||
user_name="temp",
|
user_name="temp",
|
||||||
|
@ -232,7 +232,7 @@ def test_invalid_form_request(testclient, logged_admin):
|
||||||
|
|
||||||
|
|
||||||
def test_password_reset_email(smtpd, testclient, backend, logged_admin):
|
def test_password_reset_email(smtpd, testclient, backend, logged_admin):
|
||||||
u = User(
|
u = models.User(
|
||||||
formatted_name="Temp User",
|
formatted_name="Temp User",
|
||||||
family_name="Temp",
|
family_name="Temp",
|
||||||
user_name="temp",
|
user_name="temp",
|
||||||
|
@ -260,7 +260,7 @@ def test_password_reset_email(smtpd, testclient, backend, logged_admin):
|
||||||
@mock.patch("smtplib.SMTP")
|
@mock.patch("smtplib.SMTP")
|
||||||
def test_password_reset_email_failed(SMTP, smtpd, testclient, backend, logged_admin):
|
def test_password_reset_email_failed(SMTP, smtpd, testclient, backend, logged_admin):
|
||||||
SMTP.side_effect = mock.Mock(side_effect=OSError("unit test mail error"))
|
SMTP.side_effect = mock.Mock(side_effect=OSError("unit test mail error"))
|
||||||
u = User(
|
u = models.User(
|
||||||
formatted_name="Temp User",
|
formatted_name="Temp User",
|
||||||
family_name="Temp",
|
family_name="Temp",
|
||||||
user_name="temp",
|
user_name="temp",
|
||||||
|
|
|
@ -1,13 +1,12 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
from canaille.app import models
|
||||||
from canaille.commands import cli
|
from canaille.commands import cli
|
||||||
from canaille.oidc.models import AuthorizationCode
|
|
||||||
from canaille.oidc.models import Token
|
|
||||||
from werkzeug.security import gen_salt
|
from werkzeug.security import gen_salt
|
||||||
|
|
||||||
|
|
||||||
def test_clean_command(testclient, backend, client, user):
|
def test_clean_command(testclient, backend, client, user):
|
||||||
valid_code = AuthorizationCode(
|
valid_code = models.AuthorizationCode(
|
||||||
authorization_code_id=gen_salt(48),
|
authorization_code_id=gen_salt(48),
|
||||||
code="my-valid-code",
|
code="my-valid-code",
|
||||||
client=client,
|
client=client,
|
||||||
|
@ -23,7 +22,7 @@ def test_clean_command(testclient, backend, client, user):
|
||||||
revokation="",
|
revokation="",
|
||||||
)
|
)
|
||||||
valid_code.save()
|
valid_code.save()
|
||||||
expired_code = AuthorizationCode(
|
expired_code = models.AuthorizationCode(
|
||||||
authorization_code_id=gen_salt(48),
|
authorization_code_id=gen_salt(48),
|
||||||
code="my-expired-code",
|
code="my-expired-code",
|
||||||
client=client,
|
client=client,
|
||||||
|
@ -43,7 +42,7 @@ def test_clean_command(testclient, backend, client, user):
|
||||||
)
|
)
|
||||||
expired_code.save()
|
expired_code.save()
|
||||||
|
|
||||||
valid_token = Token(
|
valid_token = models.Token(
|
||||||
token_id=gen_salt(48),
|
token_id=gen_salt(48),
|
||||||
access_token="my-valid-token",
|
access_token="my-valid-token",
|
||||||
client=client,
|
client=client,
|
||||||
|
@ -57,7 +56,7 @@ def test_clean_command(testclient, backend, client, user):
|
||||||
lifetime=3600,
|
lifetime=3600,
|
||||||
)
|
)
|
||||||
valid_token.save()
|
valid_token.save()
|
||||||
expired_token = Token(
|
expired_token = models.Token(
|
||||||
token_id=gen_salt(48),
|
token_id=gen_salt(48),
|
||||||
access_token="my-expired-token",
|
access_token="my-expired-token",
|
||||||
client=client,
|
client=client,
|
||||||
|
@ -73,8 +72,8 @@ def test_clean_command(testclient, backend, client, user):
|
||||||
)
|
)
|
||||||
expired_token.save()
|
expired_token.save()
|
||||||
|
|
||||||
assert AuthorizationCode.get(code="my-expired-code")
|
assert models.AuthorizationCode.get(code="my-expired-code")
|
||||||
assert Token.get(access_token="my-expired-token")
|
assert models.Token.get(access_token="my-expired-token")
|
||||||
assert expired_code.is_expired()
|
assert expired_code.is_expired()
|
||||||
assert expired_token.is_expired()
|
assert expired_token.is_expired()
|
||||||
|
|
||||||
|
@ -82,5 +81,5 @@ def test_clean_command(testclient, backend, client, user):
|
||||||
res = runner.invoke(cli, ["clean"])
|
res = runner.invoke(cli, ["clean"])
|
||||||
assert res.exit_code == 0, res.stdout
|
assert res.exit_code == 0, res.stdout
|
||||||
|
|
||||||
assert AuthorizationCode.query() == [valid_code]
|
assert models.AuthorizationCode.query() == [valid_code]
|
||||||
assert Token.query() == [valid_token]
|
assert models.Token.query() == [valid_token]
|
||||||
|
|
|
@ -4,10 +4,7 @@ import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from authlib.oidc.core.grants.util import generate_id_token
|
from authlib.oidc.core.grants.util import generate_id_token
|
||||||
from canaille.oidc.models import AuthorizationCode
|
from canaille.app import models
|
||||||
from canaille.oidc.models import Client
|
|
||||||
from canaille.oidc.models import Consent
|
|
||||||
from canaille.oidc.models import Token
|
|
||||||
from canaille.oidc.oauth import generate_user_info
|
from canaille.oidc.oauth import generate_user_info
|
||||||
from canaille.oidc.oauth import get_jwt_config
|
from canaille.oidc.oauth import get_jwt_config
|
||||||
from cryptography.hazmat.backends import default_backend as crypto_default_backend
|
from cryptography.hazmat.backends import default_backend as crypto_default_backend
|
||||||
|
@ -71,7 +68,7 @@ def configuration(configuration, keypair_path):
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client(testclient, other_client, backend):
|
def client(testclient, other_client, backend):
|
||||||
c = Client(
|
c = models.Client(
|
||||||
client_id=gen_salt(24),
|
client_id=gen_salt(24),
|
||||||
client_name="Some client",
|
client_name="Some client",
|
||||||
contacts="contact@mydomain.tld",
|
contacts="contact@mydomain.tld",
|
||||||
|
@ -107,7 +104,7 @@ def client(testclient, other_client, backend):
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def other_client(testclient, backend):
|
def other_client(testclient, backend):
|
||||||
c = Client(
|
c = models.Client(
|
||||||
client_id=gen_salt(24),
|
client_id=gen_salt(24),
|
||||||
client_name="Some other client",
|
client_name="Some other client",
|
||||||
contacts="contact@myotherdomain.tld",
|
contacts="contact@myotherdomain.tld",
|
||||||
|
@ -143,7 +140,7 @@ def other_client(testclient, backend):
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def authorization(testclient, user, client, backend):
|
def authorization(testclient, user, client, backend):
|
||||||
a = AuthorizationCode(
|
a = models.AuthorizationCode(
|
||||||
authorization_code_id=gen_salt(48),
|
authorization_code_id=gen_salt(48),
|
||||||
code="my-code",
|
code="my-code",
|
||||||
client=client,
|
client=client,
|
||||||
|
@ -165,7 +162,7 @@ def authorization(testclient, user, client, backend):
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def token(testclient, client, user, backend):
|
def token(testclient, client, user, backend):
|
||||||
t = Token(
|
t = models.Token(
|
||||||
token_id=gen_salt(48),
|
token_id=gen_salt(48),
|
||||||
access_token=gen_salt(48),
|
access_token=gen_salt(48),
|
||||||
audience=[client],
|
audience=[client],
|
||||||
|
@ -194,7 +191,7 @@ def id_token(testclient, client, user, backend):
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def consent(testclient, client, user, backend):
|
def consent(testclient, client, user, backend):
|
||||||
t = Consent(
|
t = models.Consent(
|
||||||
consent_id=str(uuid.uuid4()),
|
consent_id=str(uuid.uuid4()),
|
||||||
client=client,
|
client=client,
|
||||||
subject=user,
|
subject=user,
|
||||||
|
|
|
@ -5,10 +5,7 @@ from urllib.parse import urlsplit
|
||||||
import freezegun
|
import freezegun
|
||||||
from authlib.jose import jwt
|
from authlib.jose import jwt
|
||||||
from authlib.oauth2.rfc7636 import create_s256_code_challenge
|
from authlib.oauth2.rfc7636 import create_s256_code_challenge
|
||||||
from canaille.core.models import User
|
from canaille.app import models
|
||||||
from canaille.oidc.models import AuthorizationCode
|
|
||||||
from canaille.oidc.models import Consent
|
|
||||||
from canaille.oidc.models import Token
|
|
||||||
from canaille.oidc.oauth import setup_oauth
|
from canaille.oidc.oauth import setup_oauth
|
||||||
from werkzeug.security import gen_salt
|
from werkzeug.security import gen_salt
|
||||||
|
|
||||||
|
@ -18,7 +15,7 @@ from . import client_credentials
|
||||||
def test_authorization_code_flow(
|
def test_authorization_code_flow(
|
||||||
testclient, logged_user, client, keypair, other_client
|
testclient, logged_user, client, keypair, other_client
|
||||||
):
|
):
|
||||||
assert not Consent.query()
|
assert not models.Consent.query()
|
||||||
|
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
"/oauth/authorize",
|
"/oauth/authorize",
|
||||||
|
@ -36,7 +33,7 @@ def test_authorization_code_flow(
|
||||||
assert res.location.startswith(client.redirect_uris[0])
|
assert res.location.startswith(client.redirect_uris[0])
|
||||||
params = parse_qs(urlsplit(res.location).query)
|
params = parse_qs(urlsplit(res.location).query)
|
||||||
code = params["code"][0]
|
code = params["code"][0]
|
||||||
authcode = AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode is not None
|
assert authcode is not None
|
||||||
assert set(authcode.scope[0].split(" ")) == {
|
assert set(authcode.scope[0].split(" ")) == {
|
||||||
"openid",
|
"openid",
|
||||||
|
@ -47,7 +44,7 @@ def test_authorization_code_flow(
|
||||||
"phone",
|
"phone",
|
||||||
}
|
}
|
||||||
|
|
||||||
consents = Consent.query(client=client, subject=logged_user)
|
consents = models.Consent.query(client=client, subject=logged_user)
|
||||||
assert set(consents[0].scope) == {
|
assert set(consents[0].scope) == {
|
||||||
"openid",
|
"openid",
|
||||||
"profile",
|
"profile",
|
||||||
|
@ -70,7 +67,7 @@ def test_authorization_code_flow(
|
||||||
)
|
)
|
||||||
|
|
||||||
access_token = res.json["access_token"]
|
access_token = res.json["access_token"]
|
||||||
token = Token.get(access_token=access_token)
|
token = models.Token.get(access_token=access_token)
|
||||||
assert token.client == client
|
assert token.client == client
|
||||||
assert token.subject == logged_user
|
assert token.subject == logged_user
|
||||||
assert set(token.scope[0].split(" ")) == {
|
assert set(token.scope[0].split(" ")) == {
|
||||||
|
@ -119,7 +116,7 @@ def test_invalid_client(testclient, logged_user, keypair):
|
||||||
def test_authorization_code_flow_with_redirect_uri(
|
def test_authorization_code_flow_with_redirect_uri(
|
||||||
testclient, logged_user, client, keypair, other_client
|
testclient, logged_user, client, keypair, other_client
|
||||||
):
|
):
|
||||||
assert not Consent.query()
|
assert not models.Consent.query()
|
||||||
|
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
"/oauth/authorize",
|
"/oauth/authorize",
|
||||||
|
@ -138,9 +135,9 @@ def test_authorization_code_flow_with_redirect_uri(
|
||||||
assert res.location.startswith(client.redirect_uris[1])
|
assert res.location.startswith(client.redirect_uris[1])
|
||||||
params = parse_qs(urlsplit(res.location).query)
|
params = parse_qs(urlsplit(res.location).query)
|
||||||
code = params["code"][0]
|
code = params["code"][0]
|
||||||
authcode = AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode is not None
|
assert authcode is not None
|
||||||
consents = Consent.query(client=client, subject=logged_user)
|
consents = models.Consent.query(client=client, subject=logged_user)
|
||||||
|
|
||||||
res = testclient.post(
|
res = testclient.post(
|
||||||
"/oauth/token",
|
"/oauth/token",
|
||||||
|
@ -155,7 +152,7 @@ def test_authorization_code_flow_with_redirect_uri(
|
||||||
)
|
)
|
||||||
|
|
||||||
access_token = res.json["access_token"]
|
access_token = res.json["access_token"]
|
||||||
token = Token.get(access_token=access_token)
|
token = models.Token.get(access_token=access_token)
|
||||||
assert token.client == client
|
assert token.client == client
|
||||||
assert token.subject == logged_user
|
assert token.subject == logged_user
|
||||||
|
|
||||||
|
@ -166,7 +163,7 @@ def test_authorization_code_flow_with_redirect_uri(
|
||||||
def test_authorization_code_flow_preconsented(
|
def test_authorization_code_flow_preconsented(
|
||||||
testclient, logged_user, client, keypair, other_client
|
testclient, logged_user, client, keypair, other_client
|
||||||
):
|
):
|
||||||
assert not Consent.query()
|
assert not models.Consent.query()
|
||||||
|
|
||||||
client.preconsent = True
|
client.preconsent = True
|
||||||
client.save()
|
client.save()
|
||||||
|
@ -185,10 +182,10 @@ def test_authorization_code_flow_preconsented(
|
||||||
assert res.location.startswith(client.redirect_uris[0])
|
assert res.location.startswith(client.redirect_uris[0])
|
||||||
params = parse_qs(urlsplit(res.location).query)
|
params = parse_qs(urlsplit(res.location).query)
|
||||||
code = params["code"][0]
|
code = params["code"][0]
|
||||||
authcode = AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode is not None
|
assert authcode is not None
|
||||||
|
|
||||||
consents = Consent.query(client=client, subject=logged_user)
|
consents = models.Consent.query(client=client, subject=logged_user)
|
||||||
assert not consents
|
assert not consents
|
||||||
|
|
||||||
res = testclient.post(
|
res = testclient.post(
|
||||||
|
@ -204,7 +201,7 @@ def test_authorization_code_flow_preconsented(
|
||||||
)
|
)
|
||||||
|
|
||||||
access_token = res.json["access_token"]
|
access_token = res.json["access_token"]
|
||||||
token = Token.get(access_token=access_token)
|
token = models.Token.get(access_token=access_token)
|
||||||
assert token.client == client
|
assert token.client == client
|
||||||
assert token.subject == logged_user
|
assert token.subject == logged_user
|
||||||
|
|
||||||
|
@ -223,7 +220,7 @@ def test_authorization_code_flow_preconsented(
|
||||||
|
|
||||||
|
|
||||||
def test_logout_login(testclient, logged_user, client):
|
def test_logout_login(testclient, logged_user, client):
|
||||||
assert not Consent.query()
|
assert not models.Consent.query()
|
||||||
|
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
"/oauth/authorize",
|
"/oauth/authorize",
|
||||||
|
@ -254,10 +251,10 @@ def test_logout_login(testclient, logged_user, client):
|
||||||
assert res.location.startswith(client.redirect_uris[0])
|
assert res.location.startswith(client.redirect_uris[0])
|
||||||
params = parse_qs(urlsplit(res.location).query)
|
params = parse_qs(urlsplit(res.location).query)
|
||||||
code = params["code"][0]
|
code = params["code"][0]
|
||||||
authcode = AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode is not None
|
assert authcode is not None
|
||||||
|
|
||||||
consents = Consent.query(client=client, subject=logged_user)
|
consents = models.Consent.query(client=client, subject=logged_user)
|
||||||
assert "profile" in consents[0].scope
|
assert "profile" in consents[0].scope
|
||||||
|
|
||||||
res = testclient.post(
|
res = testclient.post(
|
||||||
|
@ -273,7 +270,7 @@ def test_logout_login(testclient, logged_user, client):
|
||||||
)
|
)
|
||||||
|
|
||||||
access_token = res.json["access_token"]
|
access_token = res.json["access_token"]
|
||||||
token = Token.get(access_token=access_token)
|
token = models.Token.get(access_token=access_token)
|
||||||
assert token.client == client
|
assert token.client == client
|
||||||
assert token.subject == logged_user
|
assert token.subject == logged_user
|
||||||
|
|
||||||
|
@ -289,7 +286,7 @@ def test_logout_login(testclient, logged_user, client):
|
||||||
|
|
||||||
|
|
||||||
def test_deny(testclient, logged_user, client):
|
def test_deny(testclient, logged_user, client):
|
||||||
assert not Consent.query()
|
assert not models.Consent.query()
|
||||||
|
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
"/oauth/authorize",
|
"/oauth/authorize",
|
||||||
|
@ -308,11 +305,11 @@ def test_deny(testclient, logged_user, client):
|
||||||
error = params["error"][0]
|
error = params["error"][0]
|
||||||
assert error == "access_denied"
|
assert error == "access_denied"
|
||||||
|
|
||||||
assert not Consent.query()
|
assert not models.Consent.query()
|
||||||
|
|
||||||
|
|
||||||
def test_refresh_token(testclient, user, client):
|
def test_refresh_token(testclient, user, client):
|
||||||
assert not Consent.query()
|
assert not models.Consent.query()
|
||||||
|
|
||||||
with freezegun.freeze_time("2020-01-01 01:00:00"):
|
with freezegun.freeze_time("2020-01-01 01:00:00"):
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
|
@ -335,10 +332,10 @@ def test_refresh_token(testclient, user, client):
|
||||||
assert res.location.startswith(client.redirect_uris[0])
|
assert res.location.startswith(client.redirect_uris[0])
|
||||||
params = parse_qs(urlsplit(res.location).query)
|
params = parse_qs(urlsplit(res.location).query)
|
||||||
code = params["code"][0]
|
code = params["code"][0]
|
||||||
authcode = AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode is not None
|
assert authcode is not None
|
||||||
|
|
||||||
consents = Consent.query(client=client, subject=user)
|
consents = models.Consent.query(client=client, subject=user)
|
||||||
assert "profile" in consents[0].scope
|
assert "profile" in consents[0].scope
|
||||||
|
|
||||||
with freezegun.freeze_time("2020-01-01 00:01:00"):
|
with freezegun.freeze_time("2020-01-01 00:01:00"):
|
||||||
|
@ -354,7 +351,7 @@ def test_refresh_token(testclient, user, client):
|
||||||
status=200,
|
status=200,
|
||||||
)
|
)
|
||||||
access_token = res.json["access_token"]
|
access_token = res.json["access_token"]
|
||||||
old_token = Token.get(access_token=access_token)
|
old_token = models.Token.get(access_token=access_token)
|
||||||
assert old_token is not None
|
assert old_token is not None
|
||||||
assert not old_token.revokation_date
|
assert not old_token.revokation_date
|
||||||
|
|
||||||
|
@ -369,7 +366,7 @@ def test_refresh_token(testclient, user, client):
|
||||||
status=200,
|
status=200,
|
||||||
)
|
)
|
||||||
access_token = res.json["access_token"]
|
access_token = res.json["access_token"]
|
||||||
new_token = Token.get(access_token=access_token)
|
new_token = models.Token.get(access_token=access_token)
|
||||||
assert new_token is not None
|
assert new_token is not None
|
||||||
assert old_token.access_token != new_token.access_token
|
assert old_token.access_token != new_token.access_token
|
||||||
|
|
||||||
|
@ -389,7 +386,7 @@ def test_refresh_token(testclient, user, client):
|
||||||
|
|
||||||
|
|
||||||
def test_code_challenge(testclient, logged_user, client):
|
def test_code_challenge(testclient, logged_user, client):
|
||||||
assert not Consent.query()
|
assert not models.Consent.query()
|
||||||
|
|
||||||
client.token_endpoint_auth_method = "none"
|
client.token_endpoint_auth_method = "none"
|
||||||
client.save()
|
client.save()
|
||||||
|
@ -415,10 +412,10 @@ def test_code_challenge(testclient, logged_user, client):
|
||||||
assert res.location.startswith(client.redirect_uris[0])
|
assert res.location.startswith(client.redirect_uris[0])
|
||||||
params = parse_qs(urlsplit(res.location).query)
|
params = parse_qs(urlsplit(res.location).query)
|
||||||
code = params["code"][0]
|
code = params["code"][0]
|
||||||
authcode = AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode is not None
|
assert authcode is not None
|
||||||
|
|
||||||
consents = Consent.query(client=client, subject=logged_user)
|
consents = models.Consent.query(client=client, subject=logged_user)
|
||||||
assert "profile" in consents[0].scope
|
assert "profile" in consents[0].scope
|
||||||
|
|
||||||
res = testclient.post(
|
res = testclient.post(
|
||||||
|
@ -435,7 +432,7 @@ def test_code_challenge(testclient, logged_user, client):
|
||||||
)
|
)
|
||||||
access_token = res.json["access_token"]
|
access_token = res.json["access_token"]
|
||||||
|
|
||||||
token = Token.get(access_token=access_token)
|
token = models.Token.get(access_token=access_token)
|
||||||
assert token.client == client
|
assert token.client == client
|
||||||
assert token.subject == logged_user
|
assert token.subject == logged_user
|
||||||
|
|
||||||
|
@ -456,7 +453,7 @@ def test_code_challenge(testclient, logged_user, client):
|
||||||
def test_authorization_code_flow_when_consent_already_given(
|
def test_authorization_code_flow_when_consent_already_given(
|
||||||
testclient, logged_user, client
|
testclient, logged_user, client
|
||||||
):
|
):
|
||||||
assert not Consent.query()
|
assert not models.Consent.query()
|
||||||
|
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
"/oauth/authorize",
|
"/oauth/authorize",
|
||||||
|
@ -474,10 +471,10 @@ def test_authorization_code_flow_when_consent_already_given(
|
||||||
assert res.location.startswith(client.redirect_uris[0])
|
assert res.location.startswith(client.redirect_uris[0])
|
||||||
params = parse_qs(urlsplit(res.location).query)
|
params = parse_qs(urlsplit(res.location).query)
|
||||||
code = params["code"][0]
|
code = params["code"][0]
|
||||||
authcode = AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode is not None
|
assert authcode is not None
|
||||||
|
|
||||||
consents = Consent.query(client=client, subject=logged_user)
|
consents = models.Consent.query(client=client, subject=logged_user)
|
||||||
assert "profile" in consents[0].scope
|
assert "profile" in consents[0].scope
|
||||||
|
|
||||||
res = testclient.post(
|
res = testclient.post(
|
||||||
|
@ -514,7 +511,7 @@ def test_authorization_code_flow_when_consent_already_given(
|
||||||
def test_authorization_code_flow_when_consent_already_given_but_for_a_smaller_scope(
|
def test_authorization_code_flow_when_consent_already_given_but_for_a_smaller_scope(
|
||||||
testclient, logged_user, client
|
testclient, logged_user, client
|
||||||
):
|
):
|
||||||
assert not Consent.query()
|
assert not models.Consent.query()
|
||||||
|
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
"/oauth/authorize",
|
"/oauth/authorize",
|
||||||
|
@ -532,10 +529,10 @@ def test_authorization_code_flow_when_consent_already_given_but_for_a_smaller_sc
|
||||||
assert res.location.startswith(client.redirect_uris[0])
|
assert res.location.startswith(client.redirect_uris[0])
|
||||||
params = parse_qs(urlsplit(res.location).query)
|
params = parse_qs(urlsplit(res.location).query)
|
||||||
code = params["code"][0]
|
code = params["code"][0]
|
||||||
authcode = AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode is not None
|
assert authcode is not None
|
||||||
|
|
||||||
consents = Consent.query(client=client, subject=logged_user)
|
consents = models.Consent.query(client=client, subject=logged_user)
|
||||||
assert "profile" in consents[0].scope
|
assert "profile" in consents[0].scope
|
||||||
assert "groups" not in consents[0].scope
|
assert "groups" not in consents[0].scope
|
||||||
|
|
||||||
|
@ -568,10 +565,10 @@ def test_authorization_code_flow_when_consent_already_given_but_for_a_smaller_sc
|
||||||
assert res.location.startswith(client.redirect_uris[0])
|
assert res.location.startswith(client.redirect_uris[0])
|
||||||
params = parse_qs(urlsplit(res.location).query)
|
params = parse_qs(urlsplit(res.location).query)
|
||||||
code = params["code"][0]
|
code = params["code"][0]
|
||||||
authcode = AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode is not None
|
assert authcode is not None
|
||||||
|
|
||||||
consents = Consent.query(client=client, subject=logged_user)
|
consents = models.Consent.query(client=client, subject=logged_user)
|
||||||
assert "profile" in consents[0].scope
|
assert "profile" in consents[0].scope
|
||||||
assert "groups" in consents[0].scope
|
assert "groups" in consents[0].scope
|
||||||
|
|
||||||
|
@ -604,7 +601,7 @@ def test_authorization_code_flow_but_user_cannot_use_oidc(
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_none(testclient, logged_user, client):
|
def test_prompt_none(testclient, logged_user, client):
|
||||||
consent = Consent(
|
consent = models.Consent(
|
||||||
consent_id=str(uuid.uuid4()),
|
consent_id=str(uuid.uuid4()),
|
||||||
client=client,
|
client=client,
|
||||||
subject=logged_user,
|
subject=logged_user,
|
||||||
|
@ -631,7 +628,7 @@ def test_prompt_none(testclient, logged_user, client):
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_not_logged(testclient, user, client):
|
def test_prompt_not_logged(testclient, user, client):
|
||||||
consent = Consent(
|
consent = models.Consent(
|
||||||
consent_id=str(uuid.uuid4()),
|
consent_id=str(uuid.uuid4()),
|
||||||
client=client,
|
client=client,
|
||||||
subject=user,
|
subject=user,
|
||||||
|
@ -685,7 +682,7 @@ def test_nonce_required_in_oidc_requests(testclient, logged_user, client):
|
||||||
|
|
||||||
|
|
||||||
def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client):
|
def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client):
|
||||||
assert not Consent.query()
|
assert not models.Consent.query()
|
||||||
testclient.app.config["REQUIRE_NONCE"] = False
|
testclient.app.config["REQUIRE_NONCE"] = False
|
||||||
|
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
|
@ -701,14 +698,14 @@ def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client):
|
||||||
res = res.form.submit(name="answer", value="accept", status=302)
|
res = res.form.submit(name="answer", value="accept", status=302)
|
||||||
|
|
||||||
assert res.location.startswith(client.redirect_uris[0])
|
assert res.location.startswith(client.redirect_uris[0])
|
||||||
for consent in Consent.query():
|
for consent in models.Consent.query():
|
||||||
consent.delete()
|
consent.delete()
|
||||||
|
|
||||||
|
|
||||||
def test_authorization_code_request_scope_too_large(
|
def test_authorization_code_request_scope_too_large(
|
||||||
testclient, logged_user, keypair, other_client
|
testclient, logged_user, keypair, other_client
|
||||||
):
|
):
|
||||||
assert not Consent.query()
|
assert not models.Consent.query()
|
||||||
assert "email" not in other_client.scope
|
assert "email" not in other_client.scope
|
||||||
|
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
|
@ -726,13 +723,13 @@ def test_authorization_code_request_scope_too_large(
|
||||||
|
|
||||||
params = parse_qs(urlsplit(res.location).query)
|
params = parse_qs(urlsplit(res.location).query)
|
||||||
code = params["code"][0]
|
code = params["code"][0]
|
||||||
authcode = AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert set(authcode.scope[0].split(" ")) == {
|
assert set(authcode.scope[0].split(" ")) == {
|
||||||
"openid",
|
"openid",
|
||||||
"profile",
|
"profile",
|
||||||
}
|
}
|
||||||
|
|
||||||
consents = Consent.query(client=other_client, subject=logged_user)
|
consents = models.Consent.query(client=other_client, subject=logged_user)
|
||||||
assert set(consents[0].scope) == {
|
assert set(consents[0].scope) == {
|
||||||
"openid",
|
"openid",
|
||||||
"profile",
|
"profile",
|
||||||
|
@ -751,7 +748,7 @@ def test_authorization_code_request_scope_too_large(
|
||||||
)
|
)
|
||||||
|
|
||||||
access_token = res.json["access_token"]
|
access_token = res.json["access_token"]
|
||||||
token = Token.get(access_token=access_token)
|
token = models.Token.get(access_token=access_token)
|
||||||
assert token.client == other_client
|
assert token.client == other_client
|
||||||
assert token.subject == logged_user
|
assert token.subject == logged_user
|
||||||
assert set(token.scope[0].split(" ")) == {
|
assert set(token.scope[0].split(" ")) == {
|
||||||
|
@ -813,7 +810,7 @@ def test_authorization_code_expired(testclient, user, client):
|
||||||
|
|
||||||
|
|
||||||
def test_code_with_invalid_user(testclient, admin, client):
|
def test_code_with_invalid_user(testclient, admin, client):
|
||||||
user = User(
|
user = models.User(
|
||||||
formatted_name="John Doe",
|
formatted_name="John Doe",
|
||||||
family_name="Doe",
|
family_name="Doe",
|
||||||
user_name="temp",
|
user_name="temp",
|
||||||
|
@ -838,7 +835,7 @@ def test_code_with_invalid_user(testclient, admin, client):
|
||||||
res = res.form.submit(name="answer", value="accept", status=302)
|
res = res.form.submit(name="answer", value="accept", status=302)
|
||||||
params = parse_qs(urlsplit(res.location).query)
|
params = parse_qs(urlsplit(res.location).query)
|
||||||
code = params["code"][0]
|
code = params["code"][0]
|
||||||
authcode = AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
|
|
||||||
user.delete()
|
user.delete()
|
||||||
|
|
||||||
|
@ -861,7 +858,7 @@ def test_code_with_invalid_user(testclient, admin, client):
|
||||||
|
|
||||||
|
|
||||||
def test_refresh_token_with_invalid_user(testclient, client):
|
def test_refresh_token_with_invalid_user(testclient, client):
|
||||||
user = User(
|
user = models.User(
|
||||||
formatted_name="John Doe",
|
formatted_name="John Doe",
|
||||||
family_name="Doe",
|
family_name="Doe",
|
||||||
user_name="temp",
|
user_name="temp",
|
||||||
|
@ -888,7 +885,7 @@ def test_refresh_token_with_invalid_user(testclient, client):
|
||||||
|
|
||||||
params = parse_qs(urlsplit(res.location).query)
|
params = parse_qs(urlsplit(res.location).query)
|
||||||
code = params["code"][0]
|
code = params["code"][0]
|
||||||
authcode = AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
|
|
||||||
res = testclient.post(
|
res = testclient.post(
|
||||||
"/oauth/token",
|
"/oauth/token",
|
||||||
|
@ -919,7 +916,7 @@ def test_refresh_token_with_invalid_user(testclient, client):
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
"error_description": 'There is no "user" for this token.',
|
"error_description": 'There is no "user" for this token.',
|
||||||
}
|
}
|
||||||
Token.get(access_token=access_token).delete()
|
models.Token.get(access_token=access_token).delete()
|
||||||
|
|
||||||
|
|
||||||
def test_token_default_expiration_date(testclient, logged_user, client, keypair):
|
def test_token_default_expiration_date(testclient, logged_user, client, keypair):
|
||||||
|
@ -937,7 +934,7 @@ def test_token_default_expiration_date(testclient, logged_user, client, keypair)
|
||||||
res = res.form.submit(name="answer", value="accept", status=302)
|
res = res.form.submit(name="answer", value="accept", status=302)
|
||||||
params = parse_qs(urlsplit(res.location).query)
|
params = parse_qs(urlsplit(res.location).query)
|
||||||
code = params["code"][0]
|
code = params["code"][0]
|
||||||
authcode = AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode.lifetime == 84400
|
assert authcode.lifetime == 84400
|
||||||
|
|
||||||
res = testclient.post(
|
res = testclient.post(
|
||||||
|
@ -955,7 +952,7 @@ def test_token_default_expiration_date(testclient, logged_user, client, keypair)
|
||||||
assert res.json["expires_in"] == 864000
|
assert res.json["expires_in"] == 864000
|
||||||
|
|
||||||
access_token = res.json["access_token"]
|
access_token = res.json["access_token"]
|
||||||
token = Token.get(access_token=access_token)
|
token = models.Token.get(access_token=access_token)
|
||||||
assert token.lifetime == 864000
|
assert token.lifetime == 864000
|
||||||
|
|
||||||
claims = jwt.decode(access_token, keypair[1])
|
claims = jwt.decode(access_token, keypair[1])
|
||||||
|
@ -965,7 +962,7 @@ def test_token_default_expiration_date(testclient, logged_user, client, keypair)
|
||||||
claims = jwt.decode(id_token, keypair[1])
|
claims = jwt.decode(id_token, keypair[1])
|
||||||
assert claims["exp"] - claims["iat"] == 3600
|
assert claims["exp"] - claims["iat"] == 3600
|
||||||
|
|
||||||
consents = Consent.query(client=client, subject=logged_user)
|
consents = models.Consent.query(client=client, subject=logged_user)
|
||||||
for consent in consents:
|
for consent in consents:
|
||||||
consent.delete()
|
consent.delete()
|
||||||
|
|
||||||
|
@ -995,7 +992,7 @@ def test_token_custom_expiration_date(testclient, logged_user, client, keypair):
|
||||||
res = res.form.submit(name="answer", value="accept", status=302)
|
res = res.form.submit(name="answer", value="accept", status=302)
|
||||||
params = parse_qs(urlsplit(res.location).query)
|
params = parse_qs(urlsplit(res.location).query)
|
||||||
code = params["code"][0]
|
code = params["code"][0]
|
||||||
authcode = AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode.lifetime == 84400
|
assert authcode.lifetime == 84400
|
||||||
|
|
||||||
res = testclient.post(
|
res = testclient.post(
|
||||||
|
@ -1013,7 +1010,7 @@ def test_token_custom_expiration_date(testclient, logged_user, client, keypair):
|
||||||
assert res.json["expires_in"] == 1000
|
assert res.json["expires_in"] == 1000
|
||||||
|
|
||||||
access_token = res.json["access_token"]
|
access_token = res.json["access_token"]
|
||||||
token = Token.get(access_token=access_token)
|
token = models.Token.get(access_token=access_token)
|
||||||
assert token.lifetime == 1000
|
assert token.lifetime == 1000
|
||||||
|
|
||||||
claims = jwt.decode(access_token, keypair[1])
|
claims = jwt.decode(access_token, keypair[1])
|
||||||
|
@ -1023,6 +1020,6 @@ def test_token_custom_expiration_date(testclient, logged_user, client, keypair):
|
||||||
claims = jwt.decode(id_token, keypair[1])
|
claims = jwt.decode(id_token, keypair[1])
|
||||||
assert claims["exp"] - claims["iat"] == 6000
|
assert claims["exp"] - claims["iat"] == 6000
|
||||||
|
|
||||||
consents = Consent.query(client=client, subject=logged_user)
|
consents = models.Consent.query(client=client, subject=logged_user)
|
||||||
for consent in consents:
|
for consent in consents:
|
||||||
consent.delete()
|
consent.delete()
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from canaille.oidc.models import AuthorizationCode
|
from canaille.app import models
|
||||||
from canaille.oidc.models import Client
|
|
||||||
from canaille.oidc.models import Consent
|
|
||||||
from canaille.oidc.models import Token
|
|
||||||
from werkzeug.security import gen_salt
|
from werkzeug.security import gen_salt
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,7 +26,7 @@ def test_client_list_pagination(testclient, logged_admin, client, other_client):
|
||||||
res.mustcontain("2 items")
|
res.mustcontain("2 items")
|
||||||
clients = []
|
clients = []
|
||||||
for _ in range(25):
|
for _ in range(25):
|
||||||
client = Client(client_id=gen_salt(48), client_name=gen_salt(48))
|
client = models.Client(client_id=gen_salt(48), client_name=gen_salt(48))
|
||||||
client.save()
|
client.save()
|
||||||
clients.append(client)
|
clients.append(client)
|
||||||
|
|
||||||
|
@ -115,7 +112,7 @@ def test_client_add(testclient, logged_admin):
|
||||||
res = res.follow(status=200)
|
res = res.follow(status=200)
|
||||||
|
|
||||||
client_id = res.forms["readonly"]["client_id"].value
|
client_id = res.forms["readonly"]["client_id"].value
|
||||||
client = Client.get(client_id=client_id)
|
client = models.Client.get(client_id=client_id)
|
||||||
data["audience"] = [client]
|
data["audience"] = [client]
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
client_value = getattr(client, k)
|
client_value = getattr(client, k)
|
||||||
|
@ -198,27 +195,29 @@ def test_client_edit_missing_fields(testclient, client, logged_admin, other_clie
|
||||||
|
|
||||||
|
|
||||||
def test_client_delete(testclient, logged_admin):
|
def test_client_delete(testclient, logged_admin):
|
||||||
client = Client(client_id="client_id")
|
client = models.Client(client_id="client_id")
|
||||||
client.save()
|
client.save()
|
||||||
token = Token(
|
token = models.Token(
|
||||||
token_id="id",
|
token_id="id",
|
||||||
client=client,
|
client=client,
|
||||||
issue_datetime=datetime.datetime.now(datetime.timezone.utc),
|
issue_datetime=datetime.datetime.now(datetime.timezone.utc),
|
||||||
)
|
)
|
||||||
token.save()
|
token.save()
|
||||||
consent = Consent(
|
consent = models.Consent(
|
||||||
consent_id="consent_id", subject=logged_admin, client=client, scope="openid"
|
consent_id="consent_id", subject=logged_admin, client=client, scope="openid"
|
||||||
)
|
)
|
||||||
consent.save()
|
consent.save()
|
||||||
code = AuthorizationCode(authorization_code_id="id", client=client, subject=client)
|
code = models.AuthorizationCode(
|
||||||
|
authorization_code_id="id", client=client, subject=client
|
||||||
|
)
|
||||||
|
|
||||||
res = testclient.get("/admin/client/edit/" + client.client_id)
|
res = testclient.get("/admin/client/edit/" + client.client_id)
|
||||||
res = res.forms["clientaddform"].submit(name="action", value="delete").follow()
|
res = res.forms["clientaddform"].submit(name="action", value="delete").follow()
|
||||||
|
|
||||||
assert not Client.get()
|
assert not models.Client.get()
|
||||||
assert not Token.get()
|
assert not models.Token.get()
|
||||||
assert not AuthorizationCode.get()
|
assert not models.AuthorizationCode.get()
|
||||||
assert not Consent.get()
|
assert not models.Consent.get()
|
||||||
|
|
||||||
|
|
||||||
def test_client_delete_invalid_client(testclient, logged_admin, client):
|
def test_client_delete_invalid_client(testclient, logged_admin, client):
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from canaille.oidc.models import AuthorizationCode
|
from canaille.app import models
|
||||||
from werkzeug.security import gen_salt
|
from werkzeug.security import gen_salt
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ def test_authorization_list_pagination(testclient, logged_admin, client):
|
||||||
res.mustcontain("0 items")
|
res.mustcontain("0 items")
|
||||||
authorizations = []
|
authorizations = []
|
||||||
for _ in range(26):
|
for _ in range(26):
|
||||||
code = AuthorizationCode(
|
code = models.AuthorizationCode(
|
||||||
authorization_code_id=gen_salt(48), client=client, subject=logged_admin
|
authorization_code_id=gen_salt(48), client=client, subject=logged_admin
|
||||||
)
|
)
|
||||||
code.save()
|
code.save()
|
||||||
|
@ -66,13 +66,13 @@ def test_authorization_list_bad_pages(testclient, logged_admin):
|
||||||
|
|
||||||
def test_authorization_list_search(testclient, logged_admin, client):
|
def test_authorization_list_search(testclient, logged_admin, client):
|
||||||
id1 = gen_salt(48)
|
id1 = gen_salt(48)
|
||||||
auth1 = AuthorizationCode(
|
auth1 = models.AuthorizationCode(
|
||||||
authorization_code_id=id1, client=client, subject=logged_admin
|
authorization_code_id=id1, client=client, subject=logged_admin
|
||||||
)
|
)
|
||||||
auth1.save()
|
auth1.save()
|
||||||
|
|
||||||
id2 = gen_salt(48)
|
id2 = gen_salt(48)
|
||||||
auth2 = AuthorizationCode(
|
auth2 = models.AuthorizationCode(
|
||||||
authorization_code_id=id2, client=client, subject=logged_admin
|
authorization_code_id=id2, client=client, subject=logged_admin
|
||||||
)
|
)
|
||||||
auth2.save()
|
auth2.save()
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
from urllib.parse import parse_qs
|
from urllib.parse import parse_qs
|
||||||
from urllib.parse import urlsplit
|
from urllib.parse import urlsplit
|
||||||
|
|
||||||
from canaille.oidc.models import Consent
|
from canaille.app import models
|
||||||
from canaille.oidc.models import Token
|
|
||||||
|
|
||||||
from . import client_credentials
|
from . import client_credentials
|
||||||
|
|
||||||
|
@ -118,7 +117,7 @@ def test_oidc_authorization_after_revokation(
|
||||||
|
|
||||||
res = res.form.submit(name="answer", value="accept", status=302)
|
res = res.form.submit(name="answer", value="accept", status=302)
|
||||||
|
|
||||||
consents = Consent.query(client=client, subject=logged_user)
|
consents = models.Consent.query(client=client, subject=logged_user)
|
||||||
consent.reload()
|
consent.reload()
|
||||||
assert consents[0] == consent
|
assert consents[0] == consent
|
||||||
assert not consent.revoked
|
assert not consent.revoked
|
||||||
|
@ -138,7 +137,7 @@ def test_oidc_authorization_after_revokation(
|
||||||
)
|
)
|
||||||
|
|
||||||
access_token = res.json["access_token"]
|
access_token = res.json["access_token"]
|
||||||
token = Token.get(access_token=access_token)
|
token = models.Token.get(access_token=access_token)
|
||||||
assert token.client == client
|
assert token.client == client
|
||||||
assert token.subject == logged_user
|
assert token.subject == logged_user
|
||||||
|
|
||||||
|
@ -158,13 +157,13 @@ def test_preconsented_client_appears_in_consent_list(testclient, client, logged_
|
||||||
def test_revoke_preconsented_client(testclient, client, logged_user, token):
|
def test_revoke_preconsented_client(testclient, client, logged_user, token):
|
||||||
client.preconsent = True
|
client.preconsent = True
|
||||||
client.save()
|
client.save()
|
||||||
assert not Consent.get()
|
assert not models.Consent.get()
|
||||||
assert not token.revoked
|
assert not token.revoked
|
||||||
|
|
||||||
res = testclient.get(f"/consent/revoke-preconsent/{client.client_id}", status=302)
|
res = testclient.get(f"/consent/revoke-preconsent/{client.client_id}", status=302)
|
||||||
assert ("success", "The access has been revoked") in res.flashes
|
assert ("success", "The access has been revoked") in res.flashes
|
||||||
|
|
||||||
consent = Consent.get()
|
consent = models.Consent.get()
|
||||||
assert consent.client == client
|
assert consent.client == client
|
||||||
assert consent.subject == logged_user
|
assert consent.subject == logged_user
|
||||||
assert consent.scope == ["openid", "email", "profile", "groups", "address", "phone"]
|
assert consent.scope == ["openid", "email", "profile", "groups", "address", "phone"]
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from authlib.jose import jwt
|
from authlib.jose import jwt
|
||||||
from canaille.oidc.models import Client
|
from canaille.app import models
|
||||||
|
|
||||||
|
|
||||||
def test_client_registration_with_authentication_static_token(
|
def test_client_registration_with_authentication_static_token(
|
||||||
|
@ -29,7 +29,7 @@ def test_client_registration_with_authentication_static_token(
|
||||||
headers = {"Authorization": "Bearer static-token"}
|
headers = {"Authorization": "Bearer static-token"}
|
||||||
|
|
||||||
res = testclient.post_json("/oauth/register", payload, headers=headers, status=201)
|
res = testclient.post_json("/oauth/register", payload, headers=headers, status=201)
|
||||||
client = Client.get(client_id=res.json["client_id"])
|
client = models.Client.get(client_id=res.json["client_id"])
|
||||||
|
|
||||||
assert res.json == {
|
assert res.json == {
|
||||||
"client_id": client.client_id,
|
"client_id": client.client_id,
|
||||||
|
@ -148,7 +148,7 @@ def test_client_registration_with_software_statement(testclient, backend, keypai
|
||||||
}
|
}
|
||||||
res = testclient.post_json("/oauth/register", payload, status=201)
|
res = testclient.post_json("/oauth/register", payload, status=201)
|
||||||
|
|
||||||
client = Client.get(client_id=res.json["client_id"])
|
client = models.Client.get(client_id=res.json["client_id"])
|
||||||
assert res.json == {
|
assert res.json == {
|
||||||
"client_id": client.client_id,
|
"client_id": client.client_id,
|
||||||
"client_secret": client.client_secret,
|
"client_secret": client.client_secret,
|
||||||
|
@ -199,7 +199,7 @@ def test_client_registration_without_authentication_ok(testclient, backend):
|
||||||
|
|
||||||
res = testclient.post_json("/oauth/register", payload, status=201)
|
res = testclient.post_json("/oauth/register", payload, status=201)
|
||||||
|
|
||||||
client = Client.get(client_id=res.json["client_id"])
|
client = models.Client.get(client_id=res.json["client_id"])
|
||||||
assert res.json == {
|
assert res.json == {
|
||||||
"client_id": mock.ANY,
|
"client_id": mock.ANY,
|
||||||
"client_secret": mock.ANY,
|
"client_secret": mock.ANY,
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import warnings
|
import warnings
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from canaille.oidc.models import Client
|
from canaille.app import models
|
||||||
|
|
||||||
|
|
||||||
def test_get(testclient, backend, client, user):
|
def test_get(testclient, backend, client, user):
|
||||||
|
@ -95,7 +95,7 @@ def test_update(testclient, backend, client, user):
|
||||||
res = testclient.put_json(
|
res = testclient.put_json(
|
||||||
f"/oauth/register/{client.client_id}", payload, headers=headers, status=200
|
f"/oauth/register/{client.client_id}", payload, headers=headers, status=200
|
||||||
)
|
)
|
||||||
client = Client.get(client_id=res.json["client_id"])
|
client = models.Client.get(client_id=res.json["client_id"])
|
||||||
|
|
||||||
assert res.json == {
|
assert res.json == {
|
||||||
"client_id": client.client_id,
|
"client_id": client.client_id,
|
||||||
|
@ -145,7 +145,7 @@ def test_delete(testclient, backend, user):
|
||||||
"static-token"
|
"static-token"
|
||||||
]
|
]
|
||||||
|
|
||||||
client = Client(client_id="foobar", client_name="Some client")
|
client = models.Client(client_id="foobar", client_name="Some client")
|
||||||
client.save()
|
client.save()
|
||||||
|
|
||||||
headers = {"Authorization": "Bearer static-token"}
|
headers = {"Authorization": "Bearer static-token"}
|
||||||
|
@ -153,7 +153,7 @@ def test_delete(testclient, backend, user):
|
||||||
res = testclient.delete(
|
res = testclient.delete(
|
||||||
f"/oauth/register/{client.client_id}", headers=headers, status=204
|
f"/oauth/register/{client.client_id}", headers=headers, status=204
|
||||||
)
|
)
|
||||||
assert not Client.get(client_id=client.client_id)
|
assert not models.Client.get(client_id=client.client_id)
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_client(testclient, backend, user):
|
def test_invalid_client(testclient, backend, user):
|
||||||
|
|
|
@ -2,8 +2,7 @@ from urllib.parse import parse_qs
|
||||||
from urllib.parse import urlsplit
|
from urllib.parse import urlsplit
|
||||||
|
|
||||||
from authlib.jose import jwt
|
from authlib.jose import jwt
|
||||||
from canaille.oidc.models import AuthorizationCode
|
from canaille.app import models
|
||||||
from canaille.oidc.models import Token
|
|
||||||
|
|
||||||
|
|
||||||
def test_oauth_hybrid(testclient, backend, user, client):
|
def test_oauth_hybrid(testclient, backend, user, client):
|
||||||
|
@ -32,11 +31,11 @@ def test_oauth_hybrid(testclient, backend, user, client):
|
||||||
params = parse_qs(urlsplit(res.location).fragment)
|
params = parse_qs(urlsplit(res.location).fragment)
|
||||||
|
|
||||||
code = params["code"][0]
|
code = params["code"][0]
|
||||||
authcode = AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode is not None
|
assert authcode is not None
|
||||||
|
|
||||||
access_token = params["access_token"][0]
|
access_token = params["access_token"][0]
|
||||||
token = Token.get(access_token=access_token)
|
token = models.Token.get(access_token=access_token)
|
||||||
assert token is not None
|
assert token is not None
|
||||||
|
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
|
@ -65,11 +64,11 @@ def test_oidc_hybrid(testclient, backend, logged_user, client, keypair, other_cl
|
||||||
params = parse_qs(urlsplit(res.location).fragment)
|
params = parse_qs(urlsplit(res.location).fragment)
|
||||||
|
|
||||||
code = params["code"][0]
|
code = params["code"][0]
|
||||||
authcode = AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode is not None
|
assert authcode is not None
|
||||||
|
|
||||||
access_token = params["access_token"][0]
|
access_token = params["access_token"][0]
|
||||||
token = Token.get(access_token=access_token)
|
token = models.Token.get(access_token=access_token)
|
||||||
assert token is not None
|
assert token is not None
|
||||||
|
|
||||||
id_token = params["id_token"][0]
|
id_token = params["id_token"][0]
|
||||||
|
|
|
@ -2,7 +2,7 @@ from urllib.parse import parse_qs
|
||||||
from urllib.parse import urlsplit
|
from urllib.parse import urlsplit
|
||||||
|
|
||||||
from authlib.jose import jwt
|
from authlib.jose import jwt
|
||||||
from canaille.oidc.models import Token
|
from canaille.app import models
|
||||||
|
|
||||||
|
|
||||||
def test_oauth_implicit(testclient, user, client):
|
def test_oauth_implicit(testclient, user, client):
|
||||||
|
@ -35,7 +35,7 @@ def test_oauth_implicit(testclient, user, client):
|
||||||
params = parse_qs(urlsplit(res.location).fragment)
|
params = parse_qs(urlsplit(res.location).fragment)
|
||||||
|
|
||||||
access_token = params["access_token"][0]
|
access_token = params["access_token"][0]
|
||||||
token = Token.get(access_token=access_token)
|
token = models.Token.get(access_token=access_token)
|
||||||
assert token is not None
|
assert token is not None
|
||||||
|
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
|
@ -79,7 +79,7 @@ def test_oidc_implicit(testclient, keypair, user, client, other_client):
|
||||||
params = parse_qs(urlsplit(res.location).fragment)
|
params = parse_qs(urlsplit(res.location).fragment)
|
||||||
|
|
||||||
access_token = params["access_token"][0]
|
access_token = params["access_token"][0]
|
||||||
token = Token.get(access_token=access_token)
|
token = models.Token.get(access_token=access_token)
|
||||||
assert token is not None
|
assert token is not None
|
||||||
|
|
||||||
id_token = params["id_token"][0]
|
id_token = params["id_token"][0]
|
||||||
|
@ -133,7 +133,7 @@ def test_oidc_implicit_with_group(
|
||||||
params = parse_qs(urlsplit(res.location).fragment)
|
params = parse_qs(urlsplit(res.location).fragment)
|
||||||
|
|
||||||
access_token = params["access_token"][0]
|
access_token = params["access_token"][0]
|
||||||
token = Token.get(access_token=access_token)
|
token = models.Token.get(access_token=access_token)
|
||||||
assert token is not None
|
assert token is not None
|
||||||
|
|
||||||
id_token = params["id_token"][0]
|
id_token = params["id_token"][0]
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from canaille.oidc.models import Token
|
from canaille.app import models
|
||||||
|
|
||||||
from . import client_credentials
|
from . import client_credentials
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ def test_password_flow_basic(testclient, user, client):
|
||||||
assert res.json["token_type"] == "Bearer"
|
assert res.json["token_type"] == "Bearer"
|
||||||
access_token = res.json["access_token"]
|
access_token = res.json["access_token"]
|
||||||
|
|
||||||
token = Token.get(access_token=access_token)
|
token = models.Token.get(access_token=access_token)
|
||||||
assert token is not None
|
assert token is not None
|
||||||
|
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
|
@ -52,7 +52,7 @@ def test_password_flow_post(testclient, user, client):
|
||||||
assert res.json["token_type"] == "Bearer"
|
assert res.json["token_type"] == "Bearer"
|
||||||
access_token = res.json["access_token"]
|
access_token = res.json["access_token"]
|
||||||
|
|
||||||
token = Token.get(access_token=access_token)
|
token = models.Token.get(access_token=access_token)
|
||||||
assert token is not None
|
assert token is not None
|
||||||
|
|
||||||
res = testclient.get(
|
res = testclient.get(
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from canaille.oidc.models import Token
|
from canaille.app import models
|
||||||
from werkzeug.security import gen_salt
|
from werkzeug.security import gen_salt
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ def test_token_list_pagination(testclient, logged_admin, client):
|
||||||
res.mustcontain("0 items")
|
res.mustcontain("0 items")
|
||||||
tokens = []
|
tokens = []
|
||||||
for _ in range(26):
|
for _ in range(26):
|
||||||
token = Token(
|
token = models.Token(
|
||||||
token_id=gen_salt(48),
|
token_id=gen_salt(48),
|
||||||
access_token="my-valid-token",
|
access_token="my-valid-token",
|
||||||
client=client,
|
client=client,
|
||||||
|
@ -75,7 +75,7 @@ def test_token_list_bad_pages(testclient, logged_admin):
|
||||||
|
|
||||||
|
|
||||||
def test_token_list_search(testclient, logged_admin, client):
|
def test_token_list_search(testclient, logged_admin, client):
|
||||||
token1 = Token(
|
token1 = models.Token(
|
||||||
token_id=gen_salt(48),
|
token_id=gen_salt(48),
|
||||||
access_token="this-token-is-ok",
|
access_token="this-token-is-ok",
|
||||||
client=client,
|
client=client,
|
||||||
|
@ -89,7 +89,7 @@ def test_token_list_search(testclient, logged_admin, client):
|
||||||
lifetime=3600,
|
lifetime=3600,
|
||||||
)
|
)
|
||||||
token1.save()
|
token1.save()
|
||||||
token2 = Token(
|
token2 = models.Token(
|
||||||
token_id=gen_salt(48),
|
token_id=gen_salt(48),
|
||||||
access_token="this-token-is-valid",
|
access_token="this-token-is-valid",
|
||||||
client=client,
|
client=client,
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
from urllib.parse import parse_qs
|
from urllib.parse import parse_qs
|
||||||
from urllib.parse import urlsplit
|
from urllib.parse import urlsplit
|
||||||
|
|
||||||
from canaille.oidc.models import AuthorizationCode
|
from canaille.app import models
|
||||||
from canaille.oidc.models import Token
|
|
||||||
|
|
||||||
from . import client_credentials
|
from . import client_credentials
|
||||||
|
|
||||||
|
@ -76,7 +75,7 @@ def test_full_flow(testclient, logged_user, client, user, other_client):
|
||||||
assert res.location.startswith(client.redirect_uris[0])
|
assert res.location.startswith(client.redirect_uris[0])
|
||||||
params = parse_qs(urlsplit(res.location).query)
|
params = parse_qs(urlsplit(res.location).query)
|
||||||
code = params["code"][0]
|
code = params["code"][0]
|
||||||
authcode = AuthorizationCode.get(code=code)
|
authcode = models.AuthorizationCode.get(code=code)
|
||||||
assert authcode is not None
|
assert authcode is not None
|
||||||
|
|
||||||
res = testclient.post(
|
res = testclient.post(
|
||||||
|
@ -92,7 +91,7 @@ def test_full_flow(testclient, logged_user, client, user, other_client):
|
||||||
)
|
)
|
||||||
access_token = res.json["access_token"]
|
access_token = res.json["access_token"]
|
||||||
|
|
||||||
token = Token.get(access_token=access_token)
|
token = models.Token.get(access_token=access_token)
|
||||||
assert token.client == client
|
assert token.client == client
|
||||||
assert token.subject == logged_user
|
assert token.subject == logged_user
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue