forked from Github-Mirrors/canaille
Implemented refresh grant
This commit is contained in:
parent
60d30e258b
commit
ee23c5ec32
6 changed files with 128 additions and 19 deletions
|
@ -124,7 +124,13 @@ def client(app, slapd_connection):
|
|||
oauthLogoURI="https://mydomain.tld/logo.png",
|
||||
oauthIssueDate=datetime.datetime.now().strftime("%Y%m%d%H%S%MZ"),
|
||||
oauthClientSecret=gen_salt(48),
|
||||
oauthGrantType=["password", "authorization_code", "implicit", "hybrid"],
|
||||
oauthGrantType=[
|
||||
"password",
|
||||
"authorization_code",
|
||||
"implicit",
|
||||
"hybrid",
|
||||
"refresh_token",
|
||||
],
|
||||
oauthResponseType=["code", "token", "id_token"],
|
||||
oauthScope=["openid", "profile"],
|
||||
oauthTermsOfServiceURI="https://mydomain.tld/tos",
|
||||
|
|
|
@ -3,7 +3,7 @@ from urllib.parse import urlsplit, parse_qs
|
|||
from web.models import AuthorizationCode, Token
|
||||
|
||||
|
||||
def test_success(testclient, slapd_connection, user, client):
|
||||
def test_authorization_code_flow(testclient, slapd_connection, logged_user, client):
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
params=dict(
|
||||
|
@ -15,11 +15,57 @@ def test_success(testclient, slapd_connection, user, client):
|
|||
)
|
||||
assert 200 == res.status_code
|
||||
|
||||
res.form["login"] = user.name
|
||||
res = res.forms["accept"].submit()
|
||||
assert 302 == res.status_code
|
||||
|
||||
assert res.location.startswith(client.oauthRedirectURIs[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code, conn=slapd_connection)
|
||||
assert authcode is not None
|
||||
|
||||
res = testclient.post(
|
||||
"/oauth/token",
|
||||
params=dict(
|
||||
grant_type="authorization_code",
|
||||
code=code,
|
||||
scope="profile",
|
||||
redirect_uri=client.oauthRedirectURIs[0],
|
||||
),
|
||||
headers={"Authorization": f"Basic {client_credentials(client)}"},
|
||||
)
|
||||
assert 200 == res.status_code
|
||||
access_token = res.json["access_token"]
|
||||
|
||||
token = Token.get(access_token, conn=slapd_connection)
|
||||
assert token is not None
|
||||
|
||||
res = testclient.get("/api/me", headers={"Authorization": f"Bearer {access_token}"})
|
||||
assert 200 == res.status_code
|
||||
assert {"foo": "bar"} == res.json
|
||||
|
||||
|
||||
def test_logout_login(testclient, slapd_connection, logged_user, client):
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
params=dict(
|
||||
response_type="code",
|
||||
client_id=client.oauthClientID,
|
||||
scope="profile",
|
||||
nonce="somenonce",
|
||||
),
|
||||
)
|
||||
assert 200 == res.status_code
|
||||
|
||||
res = res.forms["logout"].submit()
|
||||
assert 302 == res.status_code
|
||||
res = res.follow()
|
||||
assert 200 == res.status_code
|
||||
|
||||
res.form["login"] = logged_user.name
|
||||
res.form["password"] = "correct horse battery staple"
|
||||
res = res.form.submit()
|
||||
assert 302 == res.status_code
|
||||
|
||||
res = res.follow()
|
||||
assert 200 == res.status_code
|
||||
|
||||
|
@ -51,3 +97,57 @@ def test_success(testclient, slapd_connection, user, client):
|
|||
res = testclient.get("/api/me", headers={"Authorization": f"Bearer {access_token}"})
|
||||
assert 200 == res.status_code
|
||||
assert {"foo": "bar"} == res.json
|
||||
|
||||
|
||||
def test_refresh_token(testclient, slapd_connection, logged_user, client):
|
||||
res = testclient.get(
|
||||
"/oauth/authorize",
|
||||
params=dict(
|
||||
response_type="code",
|
||||
client_id=client.oauthClientID,
|
||||
scope="profile",
|
||||
nonce="somenonce",
|
||||
),
|
||||
)
|
||||
assert 200 == res.status_code
|
||||
|
||||
res = res.forms["accept"].submit()
|
||||
assert 302 == res.status_code
|
||||
|
||||
assert res.location.startswith(client.oauthRedirectURIs[0])
|
||||
params = parse_qs(urlsplit(res.location).query)
|
||||
code = params["code"][0]
|
||||
authcode = AuthorizationCode.get(code, conn=slapd_connection)
|
||||
assert authcode is not None
|
||||
|
||||
res = testclient.post(
|
||||
"/oauth/token",
|
||||
params=dict(
|
||||
grant_type="authorization_code",
|
||||
code=code,
|
||||
scope="profile",
|
||||
redirect_uri=client.oauthRedirectURIs[0],
|
||||
),
|
||||
headers={"Authorization": f"Basic {client_credentials(client)}"},
|
||||
)
|
||||
assert 200 == res.status_code
|
||||
access_token = res.json["access_token"]
|
||||
token = Token.get(access_token, conn=slapd_connection)
|
||||
assert token is not None
|
||||
|
||||
print("------------------------------------")
|
||||
res = testclient.post(
|
||||
"/oauth/token",
|
||||
params=dict(
|
||||
grant_type="refresh_token", refresh_token=res.json["refresh_token"],
|
||||
),
|
||||
headers={"Authorization": f"Basic {client_credentials(client)}"},
|
||||
)
|
||||
assert 200 == res.status_code
|
||||
access_token = res.json["access_token"]
|
||||
token = Token.get(access_token, conn=slapd_connection)
|
||||
assert token is not None
|
||||
|
||||
res = testclient.get("/api/me", headers={"Authorization": f"Bearer {access_token}"})
|
||||
assert 200 == res.status_code
|
||||
assert {"foo": "bar"} == res.json
|
||||
|
|
|
@ -2,7 +2,7 @@ from . import client_credentials
|
|||
from web.models import Token
|
||||
|
||||
|
||||
def test_success(testclient, slapd_connection, user, client):
|
||||
def test_password_flow(testclient, slapd_connection, user, client):
|
||||
res = testclient.post(
|
||||
"/oauth/token",
|
||||
params=dict(
|
||||
|
|
|
@ -48,8 +48,9 @@ class ClientAdd(FlaskForm):
|
|||
("authorization_code", "authorization_code"),
|
||||
("implicit", "implicit"),
|
||||
("hybrid", "hybrid"),
|
||||
("refresh_token", "refresh_token"),
|
||||
],
|
||||
default=["authorization_code"],
|
||||
default=["authorization_code", "refresh_token"],
|
||||
)
|
||||
oauthScope = wtforms.TextField(
|
||||
gettext("Scope"),
|
||||
|
|
|
@ -172,7 +172,7 @@ class Token(LDAPObjectHelper, TokenMixin):
|
|||
id = "oauthAccessToken"
|
||||
|
||||
def get_client_id(self):
|
||||
return self.authzClientID
|
||||
return self.oauthClientID
|
||||
|
||||
def get_scope(self):
|
||||
return " ".join(self.oauthScope)
|
||||
|
@ -186,7 +186,10 @@ class Token(LDAPObjectHelper, TokenMixin):
|
|||
return issue_timestamp + int(self.oauthTokenLifetime)
|
||||
|
||||
def is_refresh_token_active(self):
|
||||
if self.revoked:
|
||||
return False
|
||||
expires_at = self.issued_at + self.expires_in * 2
|
||||
return expires_at >= time.time()
|
||||
# if self.revoked:
|
||||
# return False
|
||||
return (
|
||||
datetime.datetime.strptime(self.oauthIssueDate, "%Y%m%d%H%M%SZ")
|
||||
+ datetime.timedelta(seconds=int(self.oauthTokenLifetime))
|
||||
>= datetime.datetime.now()
|
||||
)
|
||||
|
|
|
@ -122,18 +122,16 @@ class PasswordGrant(_ResourceOwnerPasswordCredentialsGrant):
|
|||
|
||||
class RefreshTokenGrant(_RefreshTokenGrant):
|
||||
def authenticate_refresh_token(self, refresh_token):
|
||||
raise NotImplementedError()
|
||||
token = Token.query.filter_by(refresh_token=refresh_token).first()
|
||||
if token and token.is_refresh_token_active():
|
||||
return token
|
||||
token = Token.filter(oauthRefreshToken=refresh_token)
|
||||
if token and token[0].is_refresh_token_active():
|
||||
return token[0]
|
||||
|
||||
def authenticate_user(self, credential):
|
||||
raise NotImplementedError()
|
||||
return User.query.get(credential.user_id)
|
||||
return User.get(credential.oauthSubject)
|
||||
|
||||
def revoke_old_credential(self, credential):
|
||||
raise NotImplementedError()
|
||||
credential.revoked = True
|
||||
# TODO: implement revokation
|
||||
pass
|
||||
|
||||
|
||||
class OpenIDImplicitGrant(_OpenIDImplicitGrant):
|
||||
|
@ -206,6 +204,7 @@ def config_oauth(app):
|
|||
|
||||
authorization.register_grant(PasswordGrant)
|
||||
authorization.register_grant(ImplicitGrant)
|
||||
authorization.register_grant(RefreshTokenGrant)
|
||||
authorization.register_grant(ClientCredentialsGrant)
|
||||
|
||||
authorization.register_grant(
|
||||
|
|
Loading…
Reference in a new issue