forked from Github-Mirrors/canaille
Remember consents
This commit is contained in:
parent
fd6fb648df
commit
00a0557f2e
7 changed files with 196 additions and 8 deletions
|
@ -21,7 +21,7 @@ from flask_babel import Babel
|
|||
from .flaskutils import current_user
|
||||
from .ldaputils import LDAPObjectHelper
|
||||
from .oauth2utils import config_oauth
|
||||
from .models import User, Token, AuthorizationCode, Client
|
||||
from .models import User, Token, AuthorizationCode, Client, Consent
|
||||
|
||||
try: # pragma: no cover
|
||||
import sentry_sdk
|
||||
|
@ -101,6 +101,7 @@ def setup_ldap_tree(app):
|
|||
Token.initialize(conn)
|
||||
AuthorizationCode.initialize(conn)
|
||||
Client.initialize(conn)
|
||||
Consent.initialize(conn)
|
||||
conn.unbind_s()
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import ldap
|
||||
import datetime
|
||||
import ldap
|
||||
import uuid
|
||||
from authlib.common.encoding import json_loads, json_dumps
|
||||
from authlib.oauth2.rfc6749 import (
|
||||
ClientMixin,
|
||||
|
@ -222,3 +223,15 @@ class Token(LDAPObjectHelper, TokenMixin):
|
|||
return False
|
||||
|
||||
return self.expire_date >= datetime.datetime.now()
|
||||
|
||||
|
||||
class Consent(LDAPObjectHelper):
|
||||
objectClass = ["oauthConsent"]
|
||||
base = "ou=consents,ou=oauth"
|
||||
id = "cn"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if "cn" not in kwargs:
|
||||
kwargs["cn"] = str(uuid.uuid4())
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
|
|
@ -3,7 +3,7 @@ from authlib.oauth2 import OAuth2Error
|
|||
from flask import Blueprint, request, session, redirect
|
||||
from flask import render_template, jsonify, flash, current_app
|
||||
from flask_babel import gettext
|
||||
from .models import User, Client
|
||||
from .models import User, Client, Consent
|
||||
from .oauth2utils import authorization, IntrospectionEndpoint, RevocationEndpoint
|
||||
from .forms import LoginForm
|
||||
from .flaskutils import current_user
|
||||
|
@ -16,8 +16,14 @@ bp = Blueprint(__name__, "oauth")
|
|||
def authorize():
|
||||
user = current_user()
|
||||
client = Client.get(request.values["client_id"])
|
||||
scopes = request.args.get("scope", "").split(" ")
|
||||
|
||||
# LOGIN
|
||||
|
||||
if not user:
|
||||
if request.args.get("prompt") == "none":
|
||||
return jsonify({"error": "login_required"})
|
||||
|
||||
form = LoginForm(request.form or None)
|
||||
if request.method == "GET":
|
||||
return render_template("login.html", form=form, menu=False)
|
||||
|
@ -30,12 +36,37 @@ def authorize():
|
|||
|
||||
return redirect(request.url)
|
||||
|
||||
# CONSENT
|
||||
|
||||
consents = Consent.filter(
|
||||
oauthClient=client.dn,
|
||||
oauthSubject=user.dn,
|
||||
)
|
||||
consent = consents[0] if consents else None
|
||||
|
||||
if request.method == "GET":
|
||||
if consent and all(scope in set(consent.oauthScope) for scope in scopes):
|
||||
return authorization.create_authorization_response(grant_user=user.dn)
|
||||
|
||||
elif request.args.get("prompt") == "none":
|
||||
return jsonify({"error": "consent_required"})
|
||||
|
||||
try:
|
||||
grant = authorization.validate_consent_request(end_user=user)
|
||||
except OAuth2Error as error:
|
||||
return jsonify(dict(error.get_body()))
|
||||
|
||||
if consent:
|
||||
consent.oauthScope = list(set(scopes + consents[0].oauthScope))
|
||||
else:
|
||||
consent = Consent(
|
||||
oauthClient=client.dn,
|
||||
oauthSubject=user.dn,
|
||||
oauthScope=scopes,
|
||||
)
|
||||
|
||||
consent.save()
|
||||
|
||||
return render_template(
|
||||
"authorize.html", user=user, grant=grant, client=client, menu=False
|
||||
)
|
||||
|
|
|
@ -353,3 +353,14 @@ olcObjectClasses: ( 1.3.6.1.4.1.56207.1.2.3 NAME 'oauthToken'
|
|||
oauthTokenLifetime $
|
||||
oauthRevoked )
|
||||
X-ORIGIN 'OAuth 2.0' )
|
||||
olcObjectClasses: ( 1.3.6.1.4.1.56207.1.2.4 NAME 'oauthConsent'
|
||||
DESC 'OAuth 2.0 User consents'
|
||||
SUP top
|
||||
STRUCTURAL
|
||||
MUST (
|
||||
cn $
|
||||
oauthSubject $
|
||||
oauthClient $
|
||||
oauthScope
|
||||
)
|
||||
X-ORIGIN 'OAuth 2.0' )
|
||||
|
|
|
@ -350,3 +350,14 @@ objectclass ( 1.3.6.1.4.1.56207.1.2.3 NAME 'oauthToken'
|
|||
oauthTokenLifetime $
|
||||
oauthRevoked )
|
||||
X-ORIGIN 'OAuth 2.0' )
|
||||
objectclass ( 1.3.6.1.4.1.56207.1.2.4 NAME 'oauthConsent'
|
||||
DESC 'OAuth 2.0 User consents'
|
||||
SUP top
|
||||
STRUCTURAL
|
||||
MUST (
|
||||
cn $
|
||||
oauthSubject $
|
||||
oauthClient $
|
||||
oauthScope
|
||||
)
|
||||
X-ORIGIN 'OAuth 2.0' )
|
||||
|
|
|
@ -9,7 +9,7 @@ from cryptography.hazmat.backends import default_backend as crypto_default_backe
|
|||
from flask_webtest import TestApp
|
||||
from werkzeug.security import gen_salt
|
||||
from oidc_ldap_bridge import create_app
|
||||
from oidc_ldap_bridge.models import User, Client, Token, AuthorizationCode
|
||||
from oidc_ldap_bridge.models import User, Client, Token, AuthorizationCode, Consent
|
||||
from oidc_ldap_bridge.ldaputils import LDAPObjectHelper
|
||||
|
||||
|
||||
|
@ -253,3 +253,10 @@ def logged_admin(admin, testclient):
|
|||
with testclient.session_transaction() as sess:
|
||||
sess["user_dn"] = admin.dn
|
||||
return admin
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanups(slapd_connection):
|
||||
yield
|
||||
for consent in Consent.filter(conn=slapd_connection):
|
||||
consent.delete(conn=slapd_connection)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from . import client_credentials
|
||||
from authlib.oauth2.rfc7636 import create_s256_code_challenge
|
||||
from urllib.parse import urlsplit, parse_qs
|
||||
from oidc_ldap_bridge.models import AuthorizationCode, Token
|
||||
from oidc_ldap_bridge.models import AuthorizationCode, Token, Consent
|
||||
from werkzeug.security import gen_salt
|
||||
|
||||
|
||||
|
@ -76,9 +76,6 @@ def test_logout_login(testclient, slapd_connection, logged_user, client):
|
|||
res = res.form.submit()
|
||||
assert 302 == res.status_code
|
||||
res = res.follow()
|
||||
assert 200 == res.status_code
|
||||
|
||||
res = res.forms["accept"].submit()
|
||||
assert 302 == res.status_code
|
||||
|
||||
assert res.location.startswith(client.oauthRedirectURIs[0])
|
||||
|
@ -219,3 +216,120 @@ def test_code_challenge(testclient, slapd_connection, logged_user, client):
|
|||
|
||||
client.oauthTokenEndpointAuthMethod = "client_secret_basic"
|
||||
client.save(slapd_connection)
|
||||
|
||||
|
||||
def test_authorization_code_flow_when_consent_already_given(
|
||||
testclient, slapd_connection, logged_user, client
|
||||
):
|
||||
assert not Consent.filter(conn=slapd_connection)
|
||||
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
params=dict(
|
||||
response_type="code",
|
||||
client_id=client.oauthClientID,
|
||||
scope="profile",
|
||||
nonce="somenonce",
|
||||
),
|
||||
)
|
||||
assert 200 == res.status_code
|
||||
|
||||
res = res.forms["accept"].submit()
|
||||
assert 302 == res.status_code
|
||||
|
||||
assert res.location.startswith(client.oauthRedirectURIs[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code, conn=slapd_connection)
|
||||
assert authcode is not None
|
||||
|
||||
consents = Consent.filter(
|
||||
oauthClient=client.dn, oauthSubject=logged_user.dn, conn=slapd_connection
|
||||
)
|
||||
assert "profile" in consents[0].oauthScope
|
||||
|
||||
res = testclient.post(
|
||||
"/oauth/token",
|
||||
params=dict(
|
||||
grant_type="authorization_code",
|
||||
code=code,
|
||||
scope="profile",
|
||||
redirect_uri=client.oauthRedirectURIs[0],
|
||||
),
|
||||
headers={"Authorization": f"Basic {client_credentials(client)}"},
|
||||
)
|
||||
assert 200 == res.status_code
|
||||
assert "access_token" in res.json
|
||||
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
params=dict(
|
||||
response_type="code",
|
||||
client_id=client.oauthClientID,
|
||||
scope="profile",
|
||||
nonce="somenonce",
|
||||
),
|
||||
)
|
||||
assert 302 == res.status_code
|
||||
assert res.location.startswith(client.oauthRedirectURIs[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
assert "code" in params
|
||||
|
||||
|
||||
def test_prompt_none(testclient, slapd_connection, logged_user, client):
|
||||
Consent(
|
||||
oauthClient=client.dn,
|
||||
oauthSubject=logged_user.dn,
|
||||
oauthScope=["openid", "profile"],
|
||||
).save(conn=slapd_connection)
|
||||
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
params=dict(
|
||||
response_type="code",
|
||||
client_id=client.oauthClientID,
|
||||
scope="profile",
|
||||
nonce="somenonce",
|
||||
prompt="none",
|
||||
),
|
||||
)
|
||||
assert 302 == res.status_code
|
||||
assert res.location.startswith(client.oauthRedirectURIs[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
assert "code" in params
|
||||
|
||||
|
||||
def test_prompt_not_logged(testclient, slapd_connection, user, client):
|
||||
Consent(
|
||||
oauthClient=client.dn,
|
||||
oauthSubject=user.dn,
|
||||
oauthScope=["openid", "profile"],
|
||||
).save(conn=slapd_connection)
|
||||
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
params=dict(
|
||||
response_type="code",
|
||||
client_id=client.oauthClientID,
|
||||
scope="profile",
|
||||
nonce="somenonce",
|
||||
prompt="none",
|
||||
),
|
||||
)
|
||||
assert 200 == res.status_code
|
||||
assert "login_required" == res.json.get("error")
|
||||
|
||||
|
||||
def test_prompt_no_consent(testclient, slapd_connection, logged_user, client):
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
params=dict(
|
||||
response_type="code",
|
||||
client_id=client.oauthClientID,
|
||||
scope="profile",
|
||||
nonce="somenonce",
|
||||
prompt="none",
|
||||
),
|
||||
)
|
||||
assert 200 == res.status_code
|
||||
assert "consent_required" == res.json.get("error")
|
||||
|
|
Loading…
Reference in a new issue