"""Authlib-based OAuth2 / OpenID Connect server wiring for Flask.

Defines the grant, bearer-token-validator and endpoint subclasses that connect
Authlib to the project's ``Client``, ``AuthorizationCode``, ``Token`` and
``User`` models, plus :func:`config_oauth`, which registers them on an app.
"""
import datetime

from authlib.integrations.flask_oauth2 import AuthorizationServer, ResourceProtector
from authlib.oauth2.rfc6749.grants import (
    AuthorizationCodeGrant as _AuthorizationCodeGrant,
    ResourceOwnerPasswordCredentialsGrant as _ResourceOwnerPasswordCredentialsGrant,
    RefreshTokenGrant as _RefreshTokenGrant,
    ImplicitGrant,
    ClientCredentialsGrant,
)
from authlib.oauth2.rfc6750 import BearerTokenValidator as _BearerTokenValidator
from authlib.oauth2.rfc7009 import RevocationEndpoint as _RevocationEndpoint
from authlib.oauth2.rfc7636 import CodeChallenge
from authlib.oauth2.rfc7662 import IntrospectionEndpoint as _IntrospectionEndpoint
from authlib.oidc.core.grants import (
    OpenIDCode as _OpenIDCode,
    OpenIDImplicitGrant as _OpenIDImplicitGrant,
    OpenIDHybridGrant as _OpenIDHybridGrant,
)
from authlib.oidc.core import UserInfo
from flask import current_app
from werkzeug.security import gen_salt

from .models import Client, AuthorizationCode, Token, User


def exists_nonce(nonce, req):
    """Return True if an authorization code for this client already uses ``nonce``.

    Shared by the OpenID grant extensions below through their
    ``exists_nonce`` hooks.
    """
    exists = AuthorizationCode.filter(oauthClientID=req.client_id, oauthNonce=nonce)
    return bool(exists)


def get_jwt_config(grant):
    """Return the JWT settings (key, alg, iss, exp) from the app's ``JWT`` config.

    ``grant`` is accepted to fit Authlib's hook signature but is not used.
    """
    jwt_config = current_app.config["JWT"]
    return {
        "key": jwt_config["KEY"],
        "alg": jwt_config["ALG"],
        "iss": jwt_config["ISS"],
        "exp": jwt_config["EXP"],
    }


def generate_user_info(user, scope):
    """Build an OpenID ``UserInfo`` for ``user`` limited to the claims ``scope`` allows.

    ``sub`` is always requested. The ``profile``, ``email``, ``address`` and
    ``phone`` scopes add their standard claims. Each claim is read from the
    user attribute named by ``JWT["MAPPING"][CLAIM.upper()]``. A claim is
    omitted when it has no mapping or the attribute is missing or falsy.
    List-valued attributes contribute only their first element.
    """
    fields = ["sub"]
    if "profile" in scope:
        fields += [
            "name",
            "family_name",
            "given_name",
            "nickname",
            "preferred_username",
            "profile",
            "picture",
            "website",
            "gender",
            "birthdate",
            "zoneinfo",
            "locale",
            "updated_at",
        ]
    if "email" in scope:
        fields += ["email", "email_verified"]
    if "address" in scope:
        fields += ["address"]
    if "phone" in scope:
        fields += ["phone_number", "phone_number_verified"]

    mapping = current_app.config["JWT"]["MAPPING"]
    data = {}
    for field in fields:
        ldap_field_match = mapping.get(field.upper())
        value = getattr(user, ldap_field_match, None) if ldap_field_match else None
        if not value:
            continue
        # Multi-valued attributes are reduced to their first value.
        if isinstance(value, list):
            value = value[0]
        data[field] = value
    return UserInfo(**data)


def save_authorization_code(code, request):
    """Persist a new ``AuthorizationCode`` built from ``request`` and return its value.

    The redirect URI falls back to the client's first registered URI when
    the request did not provide one. The nonce and the PKCE
    challenge/method are read from the request data and may be None.
    """
    now = datetime.datetime.now()
    authorization_code = AuthorizationCode(
        oauthCode=code,
        oauthSubject=request.user,
        oauthClientID=request.client.oauthClientID,
        oauthRedirectURI=request.redirect_uri or request.client.oauthRedirectURIs[0],
        oauthScope=request.scope,
        oauthNonce=request.data.get("nonce"),
        # NOTE(review): a naive local time is stamped with a "Z" (UTC) suffix;
        # confirm how the models parse/compare it before changing either side.
        oauthAuthorizationDate=now.strftime("%Y%m%d%H%M%SZ"),
        # NOTE(review): 84000 s (~23.3 h) may have been meant as 86400; codes
        # are usually much shorter-lived — confirm the intended lifetime.
        oauthAuthorizationLifetime=str(84000),
        oauthCodeChallenge=request.data.get("code_challenge"),
        oauthCodeChallengeMethod=request.data.get("code_challenge_method"),
    )
    authorization_code.save()
    return authorization_code.oauthCode


class AuthorizationCodeGrant(_AuthorizationCodeGrant):
    """Authorization-code grant backed by the ``AuthorizationCode`` model."""

    def save_authorization_code(self, code, request):
        return save_authorization_code(code, request)

    def query_authorization_code(self, code, client):
        """Return the unexpired code issued to ``client``, or None."""
        item = AuthorizationCode.filter(
            oauthCode=code, oauthClientID=client.oauthClientID
        )
        if item and not item[0].is_expired():
            return item[0]
        return None

    def delete_authorization_code(self, authorization_code):
        authorization_code.delete()

    def authenticate_user(self, authorization_code):
        return User.get(authorization_code.oauthSubject)


class OpenIDCode(_OpenIDCode):
    """OpenID Connect extension for the authorization-code grant."""

    def exists_nonce(self, nonce, request):
        return exists_nonce(nonce, request)

    def get_jwt_config(self, grant):
        return get_jwt_config(grant)

    def generate_user_info(self, user, scope):
        return generate_user_info(user, scope)


class PasswordGrant(_ResourceOwnerPasswordCredentialsGrant):
    """Resource-owner password grant delegating to ``User.authenticate``."""

    def authenticate_user(self, username, password):
        return User.authenticate(username, password)


class RefreshTokenGrant(_RefreshTokenGrant):
    """Refresh-token grant backed by the ``Token`` model."""

    def authenticate_refresh_token(self, refresh_token):
        """Return the token owning ``refresh_token`` if it is still active, else None."""
        token = Token.filter(oauthRefreshToken=refresh_token)
        if token and token[0].is_refresh_token_active():
            return token[0]
        return None

    def authenticate_user(self, credential):
        return User.get(credential.oauthSubject)

    def revoke_old_credential(self, credential):
        """Mark the superseded token as revoked.

        Uses the same ``revoked`` flag that ``BearerTokenValidator`` and
        ``RevocationEndpoint`` read and write.
        """
        credential.revoked = True
        credential.save()


class OpenIDImplicitGrant(_OpenIDImplicitGrant):
    """OpenID Connect implicit grant."""

    def exists_nonce(self, nonce, request):
        return exists_nonce(nonce, request)

    def get_jwt_config(self, grant=None):
        return get_jwt_config(grant)

    def generate_user_info(self, user, scope):
        # NOTE(review): here ``user`` is resolved through ``User.get`` first,
        # unlike in OpenIDCode — presumably it arrives as an identifier.
        user = User.get(user)
        return generate_user_info(user, scope)


class OpenIDHybridGrant(_OpenIDHybridGrant):
    """OpenID Connect hybrid grant."""

    def create_authorization_code(self, client, grant_user, request):
        """Generate a random 48-character code, persist it and return it."""
        code = gen_salt(48)
        return self.save_authorization_code(code, request)

    def save_authorization_code(self, code, request):
        return save_authorization_code(code, request)

    def exists_nonce(self, nonce, request):
        return exists_nonce(nonce, request)

    def get_jwt_config(self, grant=None):
        return get_jwt_config(grant)

    def generate_user_info(self, user, scope):
        user = User.get(user)
        return generate_user_info(user, scope)


def query_client(client_id):
    """Return the ``Client`` registered under ``client_id``."""
    return Client.get(client_id)


def save_token(token, request):
    """Persist a freshly issued Authlib ``token`` dict as a ``Token`` record.

    The subject is stored because ``RefreshTokenGrant.authenticate_user``
    and ``IntrospectionEndpoint.introspect_token`` read
    ``token.oauthSubject``. ``request.user`` may be None for grants without
    a resource owner (e.g. client credentials).
    """
    now = datetime.datetime.now()
    t = Token(
        oauthTokenType=token["token_type"],
        oauthAccessToken=token["access_token"],
        # NOTE(review): naive local time stamped with "Z"; see
        # save_authorization_code.
        oauthIssueDate=now.strftime("%Y%m%d%H%M%SZ"),
        oauthTokenLifetime=str(token["expires_in"]),
        oauthScope=token["scope"],
        oauthClientID=request.client.oauthClientID,
        oauthRefreshToken=token.get("refresh_token"),
        oauthSubject=request.user,
    )
    t.save()


class BearerTokenValidator(_BearerTokenValidator):
    """Validates bearer access tokens against the ``Token`` model."""

    def authenticate_token(self, token_string):
        return Token.get(token_string)

    def request_invalid(self, request):
        return False

    def token_revoked(self, token):
        return token.revoked


def _find_token(token, token_type_hint, **filters):
    """Return the first ``Token`` whose access or refresh value equals ``token``.

    The hinted kind is tried first. The other kind is tried next, because
    RFC 7009/7662 treat the hint as advisory only. Extra ``filters`` (e.g.
    ``oauthClientID``) are applied to both lookups. Returns None if nothing
    matches.
    """
    if token_type_hint == "refresh_token":
        lookup_order = ("oauthRefreshToken", "oauthAccessToken")
    else:
        lookup_order = ("oauthAccessToken", "oauthRefreshToken")
    for attribute in lookup_order:
        items = Token.filter(**{attribute: token}, **filters)
        if items:
            return items[0]
    return None


class RevocationEndpoint(_RevocationEndpoint):
    """RFC 7009 token revocation endpoint."""

    def query_token(self, token, token_type_hint, client):
        """Return the single token issued to ``client`` that matches ``token``, or None."""
        return _find_token(token, token_type_hint, oauthClientID=client.oauthClientID)

    def revoke_token(self, token):
        token.revoked = True
        token.save()


class IntrospectionEndpoint(_IntrospectionEndpoint):
    """RFC 7662 token introspection endpoint."""

    def query_token(self, token, token_type_hint, client):
        """Return the token matching ``token`` if it was issued to ``client``, else None."""
        tok = _find_token(token, token_type_hint)
        # TODO: optionally let privileged clients introspect other clients'
        # tokens (was sketched as a ``has_introspect_permission`` check).
        if tok and tok.oauthClientID == client.oauthClientID:
            return tok
        return None

    def introspect_token(self, token):
        """Build the introspection response payload for ``token``.

        NOTE(review): ``User.get(token.oauthSubject)`` is assumed to succeed;
        tokens with no subject (e.g. client credentials) may break here.
        """
        return {
            "active": True,
            "client_id": token.oauthClientID,
            "token_type": token.oauthTokenType,
            "username": User.get(token.oauthSubject).name,
            "scope": token.get_scope(),
            "sub": token.oauthSubject,
            "aud": token.oauthClientID,
            "iss": current_app.config["JWT"]["ISS"],
            "exp": token.get_expires_at(),
            "iat": token.get_issued_at(),
        }


authorization = AuthorizationServer()
require_oauth = ResourceProtector()


def config_oauth(app):
    """Register every grant, the bearer validator and both endpoints on ``app``."""
    authorization.init_app(app, query_client=query_client, save_token=save_token)

    authorization.register_grant(PasswordGrant)
    authorization.register_grant(ImplicitGrant)
    authorization.register_grant(RefreshTokenGrant)
    authorization.register_grant(ClientCredentialsGrant)

    authorization.register_grant(
        AuthorizationCodeGrant,
        [OpenIDCode(require_nonce=True), CodeChallenge(required=False)],
    )
    authorization.register_grant(OpenIDImplicitGrant)
    authorization.register_grant(OpenIDHybridGrant)

    require_oauth.register_token_validator(BearerTokenValidator())

    authorization.register_endpoint(IntrospectionEndpoint)
    authorization.register_endpoint(RevocationEndpoint)