canaille-globuzma/web/oauth2utils.py
2020-08-24 15:56:30 +02:00

282 lines
9 KiB
Python

import datetime
from authlib.integrations.flask_oauth2 import AuthorizationServer, ResourceProtector
from authlib.oauth2.rfc6749.grants import (
AuthorizationCodeGrant as _AuthorizationCodeGrant,
ResourceOwnerPasswordCredentialsGrant as _ResourceOwnerPasswordCredentialsGrant,
RefreshTokenGrant as _RefreshTokenGrant,
ImplicitGrant,
ClientCredentialsGrant,
)
from authlib.oauth2.rfc6750 import BearerTokenValidator as _BearerTokenValidator
from authlib.oauth2.rfc7009 import RevocationEndpoint as _RevocationEndpoint
from authlib.oauth2.rfc7636 import CodeChallenge
from authlib.oauth2.rfc7662 import IntrospectionEndpoint as _IntrospectionEndpoint
from authlib.oidc.core.grants import (
OpenIDCode as _OpenIDCode,
OpenIDImplicitGrant as _OpenIDImplicitGrant,
OpenIDHybridGrant as _OpenIDHybridGrant,
)
from authlib.oidc.core import UserInfo
from flask import current_app
from werkzeug.security import gen_salt
from .models import Client, AuthorizationCode, Token, User
def exists_nonce(nonce, req):
    """Tell whether *nonce* already appears on an authorization code of the client.

    :param nonce: the OpenID Connect nonce sent with the request.
    :param req: the OAuth2 request; only ``req.client_id`` is read.
    :return: ``True`` when at least one matching ``AuthorizationCode`` exists.
    """
    return bool(
        AuthorizationCode.filter(oauthClientID=req.client_id, oauthNonce=nonce)
    )
def get_jwt_config(grant):
    """Return the JWT signing settings taken from the ``JWT`` app config.

    :param grant: unused; accepted to match the grant hook signature.
    :return: a dict with ``key``, ``alg``, ``iss`` and ``exp`` entries copied
        from ``current_app.config["JWT"]``'s upper-cased keys.
    """
    jwt_settings = current_app.config["JWT"]
    return {name.lower(): jwt_settings[name] for name in ("KEY", "ALG", "ISS", "EXP")}
# Standard claims released for each OpenID Connect scope, in emission order.
_SCOPE_CLAIMS = (
    (
        "profile",
        (
            "name",
            "family_name",
            "given_name",
            "nickname",
            "preferred_username",
            "profile",
            "picture",
            "website",
            "gender",
            "birthdate",
            "zoneinfo",
            "locale",
            "updated_at",
        ),
    ),
    ("email", ("email", "email_verified")),
    ("address", ("address",)),
    ("phone", ("phone_number", "phone_number_verified")),
)


def generate_user_info(user, scope):
    """Build the OpenID Connect ``UserInfo`` claims for *user*.

    ``sub`` is always requested. The standard claims of the ``profile``,
    ``email``, ``address`` and ``phone`` scopes are added when that scope was
    granted. Each claim name is mapped to a user attribute through
    ``current_app.config["JWT"]["MAPPING"]``, keyed by the upper-cased claim
    name. A claim is omitted when it has no mapping or the attribute is
    missing or empty. A list-valued attribute contributes its first element.

    :param user: the user object whose attributes are read.
    :param scope: the granted scope, either a space-separated string or an
        iterable of scope names; ``None``/empty grants only ``sub``.
    :return: a ``UserInfo`` built from the collected claims.
    """
    # Compare whole scope names: a substring test on the raw string would let
    # e.g. "noprofile" release the profile claims.
    if isinstance(scope, str):
        granted = set(scope.split())
    else:
        granted = set(scope or ())

    fields = ["sub"]
    for scope_name, claims in _SCOPE_CLAIMS:
        if scope_name in granted:
            fields.extend(claims)

    mapping = current_app.config["JWT"]["MAPPING"]
    data = {}
    for field in fields:
        attribute = mapping.get(field.upper())
        if not attribute:
            continue
        value = getattr(user, attribute, None)
        if not value:
            continue
        # Multi-valued attributes come back as lists; release the first value.
        if isinstance(value, list):
            value = value[0]
        data[field] = value
    return UserInfo(**data)
def save_authorization_code(code, request):
    """Persist *code* as an ``AuthorizationCode`` entry and return the code string.

    :param code: the generated authorization code value.
    :param request: the OAuth2 request; its user, client, redirect URI, scope
        and the ``nonce`` / ``code_challenge`` / ``code_challenge_method``
        form values are recorded.
    :return: the stored ``oauthCode`` value.
    """
    client = request.client
    params = request.data
    # NOTE(review): the timestamp is naive local time but carries a "Z" (UTC)
    # suffix, and the lifetime is 84000 s (~23 h, possibly meant 86400).
    # Confirm both against AuthorizationCode.is_expired() before changing.
    issued_at = datetime.datetime.now().strftime("%Y%m%d%H%M%SZ")
    authorization_code = AuthorizationCode(
        oauthCode=code,
        oauthSubject=request.user,
        oauthClientID=client.oauthClientID,
        # Fall back to the client's first registered URI when none was sent.
        oauthRedirectURI=request.redirect_uri or client.oauthRedirectURIs[0],
        oauthScope=request.scope,
        oauthNonce=params.get("nonce"),
        oauthAuthorizationDate=issued_at,
        oauthAuthorizationLifetime=str(84000),
        oauthCodeChallenge=params.get("code_challenge"),
        oauthCodeChallengeMethod=params.get("code_challenge_method"),
    )
    authorization_code.save()
    return authorization_code.oauthCode
class AuthorizationCodeGrant(_AuthorizationCodeGrant):
    """Authorization code grant whose codes are ``AuthorizationCode`` entries."""

    def save_authorization_code(self, code, request):
        """Store *code* through the module-level helper and return it."""
        return save_authorization_code(code, request)

    def query_authorization_code(self, code, client):
        """Return the unexpired code *code* issued to *client*, else ``None``."""
        matches = AuthorizationCode.filter(
            oauthCode=code, oauthClientID=client.oauthClientID
        )
        if not matches:
            return None
        candidate = matches[0]
        return None if candidate.is_expired() else candidate

    def delete_authorization_code(self, authorization_code):
        """Remove a code once it has been exchanged."""
        authorization_code.delete()

    def authenticate_user(self, authorization_code):
        """Load the user recorded as the code's subject."""
        subject = authorization_code.oauthSubject
        return User.get(subject)
class OpenIDCode(_OpenIDCode):
    """OpenID Connect extension for the authorization code grant.

    Each hook forwards to the module-level function of the same name.
    """

    def exists_nonce(self, nonce, request):
        """Tell whether *nonce* was already used by the requesting client."""
        already_used = exists_nonce(nonce, request)
        return already_used

    def get_jwt_config(self, grant):
        """Return the ID token signing settings."""
        config = get_jwt_config(grant)
        return config

    def generate_user_info(self, user, scope):
        """Build the claims for *user*, which is passed through unchanged."""
        info = generate_user_info(user, scope)
        return info
class PasswordGrant(_ResourceOwnerPasswordCredentialsGrant):
    """Resource owner password credentials grant."""

    def authenticate_user(self, username, password):
        """Return whatever ``User.authenticate`` yields for these credentials."""
        authenticated = User.authenticate(username, password)
        return authenticated
class RefreshTokenGrant(_RefreshTokenGrant):
    """Refresh token grant backed by ``Token`` entries."""

    def authenticate_refresh_token(self, refresh_token):
        """Return the token carrying *refresh_token* if it is still active, else ``None``."""
        candidates = Token.filter(oauthRefreshToken=refresh_token)
        if not candidates:
            return None
        stored = candidates[0]
        if not stored.is_refresh_token_active():
            return None
        return stored

    def authenticate_user(self, credential):
        """Load the user recorded as the credential's subject."""
        # NOTE(review): requires tokens to be saved with oauthSubject set.
        return User.get(credential.oauthSubject)

    def revoke_old_credential(self, credential):
        """Intentionally a no-op: the previous token is not revoked yet."""
        # TODO: implement revocation.
        # NOTE(review): if the base grant does not issue a new refresh token
        # (authlib's INCLUDE_NEW_REFRESH_TOKEN), revoking here would leave the
        # client unable to refresh again. Confirm before implementing.
        pass
class OpenIDImplicitGrant(_OpenIDImplicitGrant):
    """OpenID Connect implicit flow.

    Here the grant user is an identifier, which is resolved with
    ``User.get`` before claims are built.
    """

    def exists_nonce(self, nonce, request):
        """Tell whether *nonce* was already used by the requesting client."""
        return exists_nonce(nonce, request)

    def get_jwt_config(self, grant=None):
        """Return the ID token signing settings."""
        return get_jwt_config(grant)

    def generate_user_info(self, user, scope):
        """Resolve *user* to a ``User`` and build its claims."""
        return generate_user_info(User.get(user), scope)
class OpenIDHybridGrant(_OpenIDHybridGrant):
    """OpenID Connect hybrid flow.

    Codes are stored like those of the code grant. The grant user is an
    identifier that is resolved with ``User.get`` before claims are built.
    """

    def create_authorization_code(self, client, grant_user, request):
        """Generate a random 48-character code, store it and return it."""
        return self.save_authorization_code(gen_salt(48), request)

    def save_authorization_code(self, code, request):
        """Store *code* through the module-level helper and return it."""
        return save_authorization_code(code, request)

    def exists_nonce(self, nonce, request):
        """Tell whether *nonce* was already used by the requesting client."""
        return exists_nonce(nonce, request)

    def get_jwt_config(self, grant=None):
        """Return the ID token signing settings."""
        return get_jwt_config(grant)

    def generate_user_info(self, user, scope):
        """Resolve *user* to a ``User`` and build its claims."""
        return generate_user_info(User.get(user), scope)
def query_client(client_id):
    """Look up the registered ``Client`` for *client_id*."""
    client = Client.get(client_id)
    return client
def save_token(token, request):
    """Persist a freshly issued token as a ``Token`` entry.

    :param token: the issued token dict. ``token_type``, ``access_token`` and
        ``expires_in`` are required. ``scope`` and ``refresh_token`` may be
        absent and are then stored as ``None``.
    :param request: the OAuth2 request; ``request.client`` identifies the
        client the token was issued to.
    """
    now = datetime.datetime.now()
    t = Token(
        oauthTokenType=token["token_type"],
        oauthAccessToken=token["access_token"],
        oauthIssueDate=now.strftime("%Y%m%d%H%M%SZ"),
        oauthTokenLifetime=str(token["expires_in"]),
        # The token dict may carry no "scope" key (presumably when nothing
        # was granted); indexing it would raise KeyError and abort issuance.
        oauthScope=token.get("scope"),
        oauthClientID=request.client.oauthClientID,
        oauthRefreshToken=token.get("refresh_token"),
        # NOTE(review): oauthSubject is not stored, although the refresh token
        # grant and introspection read token.oauthSubject. Confirm which form
        # of request.user the Token model accepts, then record it here.
    )
    t.save()
class BearerTokenValidator(_BearerTokenValidator):
    """Validate bearer access tokens against stored ``Token`` entries."""

    def authenticate_token(self, token_string):
        """Fetch the ``Token`` identified by *token_string*."""
        found = Token.get(token_string)
        return found

    def request_invalid(self, request):
        """Perform no request-level validation: always report valid."""
        return False

    def token_revoked(self, token):
        """Report the token's ``revoked`` flag."""
        is_revoked = token.revoked
        return is_revoked
class RevocationEndpoint(_RevocationEndpoint):
    """RFC 7009 token revocation backed by ``Token`` entries."""

    def query_token(self, token, token_type_hint, client):
        """Return the ``Token`` of *client* whose access or refresh token is *token*.

        The hinted token type is searched first. If nothing matches, the
        other type is searched as well, as RFC 7009 §2.1 requires.

        :param token: the token value sent by the client.
        :param token_type_hint: ``"access_token"``, ``"refresh_token"`` or ``None``.
        :param client: the authenticated client; only its tokens are searched.
        :return: the first matching ``Token`` object, or ``None``.
        """
        if token_type_hint == "refresh_token":
            attributes = ("oauthRefreshToken", "oauthAccessToken")
        else:
            attributes = ("oauthAccessToken", "oauthRefreshToken")

        for attribute in attributes:
            matches = Token.filter(
                oauthClientID=client.oauthClientID, **{attribute: token}
            )
            if matches:
                # Return the token itself, not the list of matches:
                # revoke_token() sets an attribute on it and saves it.
                return matches[0]
        return None

    def revoke_token(self, token):
        """Flag *token* as revoked and persist the change."""
        token.revoked = True
        token.save()
class IntrospectionEndpoint(_IntrospectionEndpoint):
    """RFC 7662 token introspection backed by ``Token`` entries."""

    def query_token(self, token, token_type_hint, client):
        """Find the token to introspect, or ``None``.

        The hinted token type is searched first, then the other one, as
        RFC 7662 §2.1 requires. A token is only returned to the client it
        was issued to.

        :param token: the token value sent by the client.
        :param token_type_hint: ``"access_token"``, ``"refresh_token"`` or ``None``.
        :param client: the authenticated client making the request.
        """
        if token_type_hint == "refresh_token":
            attributes = ("oauthRefreshToken", "oauthAccessToken")
        else:
            attributes = ("oauthAccessToken", "oauthRefreshToken")

        found = None
        for attribute in attributes:
            matches = Token.filter(**{attribute: token})
            if matches:
                found = matches[0]
                break

        # TODO: also let clients holding an explicit introspection permission
        # inspect tokens issued to other clients.
        if found is not None and found.oauthClientID == client.oauthClientID:
            return found
        return None

    def introspect_token(self, token):
        """Build the introspection response payload for *token*.

        ``username`` is included only when the token has a subject that
        resolves to a user. Tokens issued without a user would otherwise
        crash on ``None.name``.
        """
        subject = token.oauthSubject
        user = User.get(subject) if subject else None
        payload = {
            "active": True,
            "client_id": token.oauthClientID,
            "token_type": token.oauthTokenType,
            "scope": token.get_scope(),
            "sub": subject,
            "aud": token.oauthClientID,
            "iss": current_app.config["JWT"]["ISS"],
            "exp": token.get_expires_at(),
            "iat": token.get_issued_at(),
        }
        if user is not None:
            payload["username"] = user.name
        return payload
# Resource protector; config_oauth() registers the bearer token validator on it.
require_oauth = ResourceProtector()
# Authorization server; config_oauth() binds it to the app and its grants.
authorization = AuthorizationServer()
def config_oauth(app):
    """Bind the OAuth2/OIDC server to *app* and register grants and endpoints.

    Grants are registered in a fixed order because the server may try them
    in registration order when selecting a grant. Keep the table order.
    """
    authorization.init_app(app, query_client=query_client, save_token=save_token)

    grant_table = (
        (PasswordGrant, None),
        (ImplicitGrant, None),
        (RefreshTokenGrant, None),
        (ClientCredentialsGrant, None),
        (
            AuthorizationCodeGrant,
            [OpenIDCode(require_nonce=True), CodeChallenge(required=False)],
        ),
        (OpenIDImplicitGrant, None),
        (OpenIDHybridGrant, None),
    )
    for grant_class, extensions in grant_table:
        if extensions is None:
            authorization.register_grant(grant_class)
        else:
            authorization.register_grant(grant_class, extensions)

    require_oauth.register_token_validator(BearerTokenValidator())

    for endpoint_class in (IntrospectionEndpoint, RevocationEndpoint):
        authorization.register_endpoint(endpoint_class)