forked from Github-Mirrors/canaille
wip
This commit is contained in:
parent
4954cfd096
commit
150c9a3b3e
7 changed files with 223 additions and 350 deletions
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
|
@ -0,0 +1,6 @@
|
|||
*.sqlite
|
||||
*.pyc
|
||||
venv/*
|
||||
.*@neomake*
|
||||
.ash_history
|
||||
.python_history
|
14
app.py
14
app.py
|
@ -1,9 +1,11 @@
|
|||
from website.app import create_app
|
||||
|
||||
|
||||
app = create_app({
|
||||
'SECRET_KEY': 'secret',
|
||||
'OAUTH2_REFRESH_TOKEN_GENERATOR': True,
|
||||
'SQLALCHEMY_TRACK_MODIFICATIONS': False,
|
||||
'SQLALCHEMY_DATABASE_URI': 'sqlite:///db.sqlite',
|
||||
})
|
||||
app = create_app(
|
||||
{
|
||||
"SECRET_KEY": "secret",
|
||||
"OAUTH2_REFRESH_TOKEN_GENERATOR": True,
|
||||
"SQLALCHEMY_TRACK_MODIFICATIONS": False,
|
||||
"SQLALCHEMY_DATABASE_URI": "sqlite:///db.sqlite",
|
||||
}
|
||||
)
|
||||
|
|
|
@ -238,6 +238,15 @@ olcAttributeTypes: ( 1.3.6.1.4.1.40805.1.1.28 NAME 'oauthRefreshToken'
|
|||
SINGLE-VALUE
|
||||
USAGE userApplications
|
||||
X-ORIGIN 'OAuth 2.0' )
|
||||
olcAttributeTypes: ( 1.3.6.1.4.1.40805.1.1.29 NAME 'oauthTokenEndpointAuthMethod'
|
||||
DESC 'OAuth 2.0 Token endpoint authentication method'
|
||||
EQUALITY caseExactMatch
|
||||
ORDERING caseExactOrderingMatch
|
||||
SUBSTR caseExactSubstringsMatch
|
||||
SYNTAX 1.3.6.1.4.1.1466.115.121.1.15
|
||||
SINGLE-VALUE
|
||||
USAGE userApplications
|
||||
X-ORIGIN 'OAuth 2.0 Dynamic Client Registration Protocol' )
|
||||
olcObjectClasses: ( 1.3.6.1.4.1.40805.1.2.1 NAME 'oauthClient'
|
||||
DESC 'OAuth 2.0 Authorization Code'
|
||||
SUP top
|
||||
|
@ -247,6 +256,7 @@ olcObjectClasses: ( 1.3.6.1.4.1.40805.1.2.1 NAME 'oauthClient'
|
|||
oauthClientName $
|
||||
oauthClientContact $
|
||||
oauthClientURI $
|
||||
oauthRedirectURI $
|
||||
oauthLogoURI $
|
||||
oauthIssueDate $
|
||||
oauthClientSecret $
|
||||
|
@ -258,6 +268,7 @@ olcObjectClasses: ( 1.3.6.1.4.1.40805.1.2.1 NAME 'oauthClient'
|
|||
oauthPolicyURI $
|
||||
oauthJWKURI $
|
||||
oauthJWK $
|
||||
oauthTokenEndpointAuthMethod $
|
||||
oauthSoftwareID $
|
||||
oauthSoftwareVersion )
|
||||
)
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import ldap
|
||||
import os
|
||||
from flask import Flask, g
|
||||
#from .models import db
|
||||
|
||||
# from .models import db
|
||||
from .oauth2 import config_oauth
|
||||
from .routes import bp
|
||||
|
||||
|
@ -39,11 +40,11 @@ def setup_app(app):
|
|||
g.ldap.unbind_s()
|
||||
return response
|
||||
|
||||
# # Create tables if they do not exist already
|
||||
# @app.before_first_request
|
||||
# def create_tables():
|
||||
# db.create_all()
|
||||
#
|
||||
# db.init_app(app)
|
||||
# # Create tables if they do not exist already
|
||||
# @app.before_first_request
|
||||
# def create_tables():
|
||||
# db.create_all()
|
||||
#
|
||||
# db.init_app(app)
|
||||
config_oauth(app)
|
||||
app.register_blueprint(bp, url_prefix="")
|
||||
|
|
|
@ -1,223 +1,14 @@
|
|||
import ldap
|
||||
import time
|
||||
import datetime
|
||||
from flask import g
|
||||
from authlib.common.encoding import json_loads, json_dumps
|
||||
from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope
|
||||
from authlib.oauth2.rfc6749 import (
|
||||
ClientMixin,
|
||||
TokenMixin,
|
||||
AuthorizationCodeMixin,
|
||||
)
|
||||
#class OAuth2Client(db.Model, ClientMixin):
|
||||
# __tablename__ = 'oauth2_client'
|
||||
#
|
||||
# id = db.Column(db.Integer, primary_key=True)
|
||||
# user_id = db.Column(
|
||||
# db.Integer, db.ForeignKey('user.id', ondelete='CASCADE'))
|
||||
# user = db.relationship('User')
|
||||
#
|
||||
# client_id = db.Column(db.String(48), index=True)
|
||||
# client_secret = db.Column(db.String(120))
|
||||
# client_id_issued_at = db.Column(db.Integer, nullable=False, default=0)
|
||||
# client_secret_expires_at = db.Column(db.Integer, nullable=False, default=0)
|
||||
# _client_metadata = db.Column('client_metadata', db.Text)
|
||||
#
|
||||
# @property
|
||||
# def client_info(self):
|
||||
# """Implementation for Client Info in OAuth 2.0 Dynamic Client
|
||||
# Registration Protocol via `Section 3.2.1`_.
|
||||
# .. _`Section 3.2.1`: https://tools.ietf.org/html/rfc7591#section-3.2.1
|
||||
# """
|
||||
# return dict(
|
||||
# client_id=self.client_id,
|
||||
# client_secret=self.client_secret,
|
||||
# client_id_issued_at=self.client_id_issued_at,
|
||||
# client_secret_expires_at=self.client_secret_expires_at,
|
||||
# )
|
||||
#
|
||||
# @property
|
||||
# def client_metadata(self):
|
||||
# if 'client_metadata' in self.__dict__:
|
||||
# return self.__dict__['client_metadata']
|
||||
# if self._client_metadata:
|
||||
# data = json_loads(self._client_metadata)
|
||||
# self.__dict__['client_metadata'] = data
|
||||
# return data
|
||||
# return {}
|
||||
#
|
||||
# def set_client_metadata(self, value):
|
||||
# self._client_metadata = json_dumps(value)
|
||||
#
|
||||
# @property
|
||||
# def redirect_uris(self):
|
||||
# return self.client_metadata.get('redirect_uris', [])
|
||||
#
|
||||
# @property
|
||||
# def token_endpoint_auth_method(self):
|
||||
# return self.client_metadata.get(
|
||||
# 'token_endpoint_auth_method',
|
||||
# 'client_secret_basic'
|
||||
# )
|
||||
#
|
||||
# @property
|
||||
# def grant_types(self):
|
||||
# return self.client_metadata.get('grant_types', [])
|
||||
#
|
||||
# @property
|
||||
# def response_types(self):
|
||||
# return self.client_metadata.get('response_types', [])
|
||||
#
|
||||
# @property
|
||||
# def client_name(self):
|
||||
# return self.client_metadata.get('client_name')
|
||||
#
|
||||
# @property
|
||||
# def client_uri(self):
|
||||
# return self.client_metadata.get('client_uri')
|
||||
#
|
||||
# @property
|
||||
# def logo_uri(self):
|
||||
# return self.client_metadata.get('logo_uri')
|
||||
#
|
||||
# @property
|
||||
# def scope(self):
|
||||
# return self.client_metadata.get('scope', '')
|
||||
#
|
||||
# @property
|
||||
# def contacts(self):
|
||||
# return self.client_metadata.get('contacts', [])
|
||||
#
|
||||
# @property
|
||||
# def tos_uri(self):
|
||||
# return self.client_metadata.get('tos_uri')
|
||||
#
|
||||
# @property
|
||||
# def policy_uri(self):
|
||||
# return self.client_metadata.get('policy_uri')
|
||||
#
|
||||
# @property
|
||||
# def jwks_uri(self):
|
||||
# return self.client_metadata.get('jwks_uri')
|
||||
#
|
||||
# @property
|
||||
# def jwks(self):
|
||||
# return self.client_metadata.get('jwks', [])
|
||||
#
|
||||
# @property
|
||||
# def software_id(self):
|
||||
# return self.client_metadata.get('software_id')
|
||||
#
|
||||
# @property
|
||||
# def software_version(self):
|
||||
# return self.client_metadata.get('software_version')
|
||||
#
|
||||
# def get_client_id(self):
|
||||
# return self.client_id
|
||||
#
|
||||
# def get_default_redirect_uri(self):
|
||||
# if self.redirect_uris:
|
||||
# return self.redirect_uris[0]
|
||||
#
|
||||
# def get_allowed_scope(self, scope):
|
||||
# if not scope:
|
||||
# return ''
|
||||
# allowed = set(self.scope.split())
|
||||
# scopes = scope_to_list(scope)
|
||||
# return list_to_scope([s for s in scopes if s in allowed])
|
||||
#
|
||||
# def check_redirect_uri(self, redirect_uri):
|
||||
# return redirect_uri in self.redirect_uris
|
||||
#
|
||||
# def has_client_secret(self):
|
||||
# return bool(self.client_secret)
|
||||
#
|
||||
# def check_client_secret(self, client_secret):
|
||||
# return self.client_secret == client_secret
|
||||
#
|
||||
# def check_token_endpoint_auth_method(self, method):
|
||||
# return self.token_endpoint_auth_method == method
|
||||
#
|
||||
# def check_response_type(self, response_type):
|
||||
# return response_type in self.response_types
|
||||
#
|
||||
# def check_grant_type(self, grant_type):
|
||||
# return grant_type in self.grant_types
|
||||
#
|
||||
#
|
||||
#class OAuth2AuthorizationCode(db.Model, AuthorizationCodeMixin):
|
||||
# __tablename__ = 'oauth2_code'
|
||||
#
|
||||
# id = db.Column(db.Integer, primary_key=True)
|
||||
# user_id = db.Column(
|
||||
# db.Integer, db.ForeignKey('user.id', ondelete='CASCADE'))
|
||||
# user = db.relationship('User')
|
||||
#
|
||||
# code = db.Column(db.String(120), unique=True, nullable=False)
|
||||
# client_id = db.Column(db.String(48))
|
||||
# redirect_uri = db.Column(db.Text, default='')
|
||||
# response_type = db.Column(db.Text, default='')
|
||||
# scope = db.Column(db.Text, default='')
|
||||
# nonce = db.Column(db.Text)
|
||||
# auth_time = db.Column(
|
||||
# db.Integer, nullable=False,
|
||||
# default=lambda: int(time.time())
|
||||
# )
|
||||
#
|
||||
# code_challenge = db.Column(db.Text)
|
||||
# code_challenge_method = db.Column(db.String(48))
|
||||
#
|
||||
# def is_expired(self):
|
||||
# return self.auth_time + 300 < time.time()
|
||||
#
|
||||
# def get_redirect_uri(self):
|
||||
# return self.redirect_uri
|
||||
#
|
||||
# def get_scope(self):
|
||||
# return self.scope
|
||||
#
|
||||
# def get_auth_time(self):
|
||||
# return self.auth_time
|
||||
#
|
||||
# def get_nonce(self):
|
||||
# return self.nonce
|
||||
#
|
||||
#
|
||||
#class OAuth2Token(db.Model, TokenMixin):
|
||||
# __tablename__ = 'oauth2_token'
|
||||
#
|
||||
# id = db.Column(db.Integer, primary_key=True)
|
||||
# user_id = db.Column(
|
||||
# db.Integer, db.ForeignKey('user.id', ondelete='CASCADE'))
|
||||
# user = db.relationship('User')
|
||||
#
|
||||
# client_id = db.Column(db.String(48))
|
||||
# token_type = db.Column(db.String(40))
|
||||
# access_token = db.Column(db.String(255), unique=True, nullable=False)
|
||||
# refresh_token = db.Column(db.String(255), index=True)
|
||||
# scope = db.Column(db.Text, default='')
|
||||
# revoked = db.Column(db.Boolean, default=False)
|
||||
# issued_at = db.Column(
|
||||
# db.Integer, nullable=False, default=lambda: int(time.time())
|
||||
# )
|
||||
# expires_in = db.Column(db.Integer, nullable=False, default=0)
|
||||
#
|
||||
# def is_refresh_token_active(self):
|
||||
# if self.revoked:
|
||||
# return False
|
||||
# expires_at = self.issued_at + self.expires_in * 2
|
||||
# return expires_at >= time.time()
|
||||
#
|
||||
# def get_client_id(self):
|
||||
# return self.client_id
|
||||
#
|
||||
# def get_scope(self):
|
||||
# return self.scope
|
||||
#
|
||||
# def get_expires_in(self):
|
||||
# return self.expires_in
|
||||
#
|
||||
# def get_expires_at(self):
|
||||
# return self.issued_at + self.expires_in
|
||||
|
||||
|
||||
class LDAPObjectHelper:
|
||||
_object_class_by_name = None
|
||||
|
@ -226,7 +17,7 @@ class LDAPObjectHelper:
|
|||
base = None
|
||||
id = None
|
||||
|
||||
#TODO If ldap attribute is SINGLE-VALUE, do not bother with lists
|
||||
# TODO If ldap attribute is SINGLE-VALUE, do not bother with lists
|
||||
|
||||
def __init__(self, dn=None, **kwargs):
|
||||
self.attrs = {}
|
||||
|
@ -253,7 +44,9 @@ class LDAPObjectHelper:
|
|||
if cls._object_class_by_name:
|
||||
return cls._object_class_by_name
|
||||
|
||||
res = g.ldap.search_s("cn=subschema", ldap.SCOPE_BASE, "(objectclass=*)", ["*", "+"])
|
||||
res = g.ldap.search_s(
|
||||
"cn=subschema", ldap.SCOPE_BASE, "(objectclass=*)", ["*", "+"]
|
||||
)
|
||||
subschema_entry = res[0]
|
||||
subschema_subentry = ldap.cidict.cidict(subschema_entry[1])
|
||||
subschema = ldap.schema.SubSchema(subschema_subentry)
|
||||
|
@ -308,9 +101,7 @@ class LDAPObjectHelper:
|
|||
result = g.ldap.search_s(base or cls.base, ldap.SCOPE_SUBTREE, ldapfilter)
|
||||
|
||||
return [
|
||||
cls(
|
||||
**{k: [elt.decode("utf-8") for elt in v] for k, v in args.items()},
|
||||
)
|
||||
cls(**{k: [elt.decode("utf-8") for elt in v] for k, v in args.items()},)
|
||||
for _, args in result
|
||||
]
|
||||
|
||||
|
@ -340,7 +131,7 @@ class User(LDAPObjectHelper):
|
|||
|
||||
|
||||
class Client(LDAPObjectHelper, ClientMixin):
|
||||
objectClass = ["oauthClientIdentity"]
|
||||
objectClass = ["oauthClient"]
|
||||
base = "ou=clients,dc=mydomain,dc=tld"
|
||||
id = "oauthClientID"
|
||||
|
||||
|
@ -351,7 +142,7 @@ class Client(LDAPObjectHelper, ClientMixin):
|
|||
return self.oauthRedirectURI[0]
|
||||
|
||||
def get_allowed_scope(self, scope):
|
||||
return self.oauthScopeValue[0]
|
||||
return self.oauthScope[0]
|
||||
|
||||
def check_redirect_uri(self, redirect_uri):
|
||||
return redirect_uri in self.oauthRedirectURI
|
||||
|
@ -371,10 +162,43 @@ class Client(LDAPObjectHelper, ClientMixin):
|
|||
def check_grant_type(self, grant_type):
|
||||
return grant_type in self.oauthGrantType
|
||||
|
||||
@property
|
||||
def client_info(self):
|
||||
return dict(
|
||||
client_id=self.client_id,
|
||||
client_secret=self.client_secret,
|
||||
client_id_issued_at=self.client_id_issued_at,
|
||||
client_secret_expires_at=self.client_secret_expires_at,
|
||||
)
|
||||
|
||||
@property
|
||||
def client_metadata(self):
|
||||
if "client_metadata" in self.__dict__:
|
||||
return self.__dict__["client_metadata"]
|
||||
if self._client_metadata:
|
||||
data = json_loads(self._client_metadata)
|
||||
self.__dict__["client_metadata"] = data
|
||||
return data
|
||||
return {}
|
||||
|
||||
def set_client_metadata(self, value):
|
||||
self._client_metadata = json_dumps(value)
|
||||
|
||||
@property
|
||||
def redirect_uris(self):
|
||||
return self.client_metadata.get("redirect_uris", [])
|
||||
|
||||
@property
|
||||
def token_endpoint_auth_method(self):
|
||||
return self.client_metadata.get(
|
||||
"token_endpoint_auth_method", "client_secret_basic"
|
||||
)
|
||||
|
||||
|
||||
class AuthorizationCode(LDAPObjectHelper, AuthorizationCodeMixin):
|
||||
objectClass = ["oauth2Authz"]
|
||||
objectClass = ["oauth2AuthorizationCode"]
|
||||
base = "ou=authorizations,dc=mydomain,dc=tld"
|
||||
id = "oauthCode"
|
||||
|
||||
def get_redirect_uri(self):
|
||||
return Client.get(self.authzClientID[0]).oauthRedirectURI[0]
|
||||
|
@ -382,11 +206,26 @@ class AuthorizationCode(LDAPObjectHelper, AuthorizationCodeMixin):
|
|||
def get_scope(self):
|
||||
return self.oauth2ScopeValue[0]
|
||||
|
||||
def is_refresh_token_active(self):
|
||||
if self.revoked:
|
||||
return False
|
||||
expires_at = self.issued_at + self.expires_in * 2
|
||||
return expires_at >= time.time()
|
||||
|
||||
def get_client_id(self):
|
||||
return self.client_id
|
||||
|
||||
def get_expires_in(self):
|
||||
return self.expires_in
|
||||
|
||||
def get_expires_at(self):
|
||||
return self.issued_at + self.expires_in
|
||||
|
||||
|
||||
class Token(LDAPObjectHelper, TokenMixin):
|
||||
objectClass = ["oauth2IdAccessToken", "oauth2AuthzAux"]
|
||||
objectClass = ["oauth2Token"]
|
||||
base = "ou=tokens,dc=mydomain,dc=tld"
|
||||
id = "authzAccessToken"
|
||||
id = "oauthToken"
|
||||
|
||||
def get_client_id(self):
|
||||
return self.authzClientID[0]
|
||||
|
@ -398,6 +237,14 @@ class Token(LDAPObjectHelper, TokenMixin):
|
|||
return int(self.authzAccessTokenLifetime[0])
|
||||
|
||||
def get_expires_at(self):
|
||||
issue_date = datetime.datetime.strptime(self.authzAccessTokenIssueDate[0], "%Y%m%d%H%M%SZ")
|
||||
issue_date = datetime.datetime.strptime(
|
||||
self.authzAccessTokenIssueDate[0], "%Y%m%d%H%M%SZ"
|
||||
)
|
||||
issue_timestamp = (issue_date - datetime.datetime(1970, 1, 1)).total_seconds()
|
||||
return issue_timestamp + int(self.authzAccessTokenLifetime[0])
|
||||
|
||||
def is_refresh_token_active(self):
|
||||
if self.revoked:
|
||||
return False
|
||||
expires_at = self.issued_at + self.expires_in * 2
|
||||
return expires_at >= time.time()
|
||||
|
|
|
@ -1,66 +1,90 @@
|
|||
import datetime
|
||||
from authlib.integrations.flask_oauth2 import (
|
||||
AuthorizationServer,
|
||||
ResourceProtector,
|
||||
from authlib.integrations.flask_oauth2 import AuthorizationServer, ResourceProtector
|
||||
from authlib.oauth2.rfc6749.grants import (
|
||||
AuthorizationCodeGrant as _AuthorizationCodeGrant,
|
||||
ResourceOwnerPasswordCredentialsGrant as _ResourceOwnerPasswordCredentialsGrant,
|
||||
RefreshTokenGrant as _RefreshTokenGrant,
|
||||
)
|
||||
from authlib.oauth2.rfc6749 import grants, util
|
||||
from authlib.oauth2.rfc6750 import BearerTokenValidator
|
||||
from authlib.oauth2.rfc7009 import RevocationEndpoint
|
||||
from authlib.oauth2.rfc7636 import CodeChallenge
|
||||
from .models import User, Client, Authorization, Token
|
||||
from authlib.oauth2.rfc6750 import BearerTokenValidator as _BearerTokenValidator
|
||||
from authlib.oidc.core.grants import (
|
||||
OpenIDCode as _OpenIDCode,
|
||||
OpenIDImplicitGrant as _OpenIDImplicitGrant,
|
||||
OpenIDHybridGrant as _OpenIDHybridGrant,
|
||||
)
|
||||
from authlib.oidc.core import UserInfo
|
||||
from werkzeug.security import gen_salt
|
||||
from .models import Client, AuthorizationCode, Token, User
|
||||
|
||||
DUMMY_JWT_CONFIG = {
|
||||
"key": "secret-key",
|
||||
"alg": "HS256",
|
||||
"iss": "https://authlib.org",
|
||||
"exp": 3600,
|
||||
}
|
||||
|
||||
|
||||
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
|
||||
TOKEN_ENDPOINT_AUTH_METHODS = [
|
||||
"client_secret_basic",
|
||||
"client_secret_post",
|
||||
"none",
|
||||
]
|
||||
def exists_nonce(nonce, req):
|
||||
exists = AuthorizationCode.query.filter_by(
|
||||
client_id=req.client_id, nonce=nonce
|
||||
).first()
|
||||
return bool(exists)
|
||||
|
||||
def save_authorization_code(self, code, request):
|
||||
raise NotImplementedError()
|
||||
code_challenge = request.data.get("code_challenge")
|
||||
code_challenge_method = request.data.get("code_challenge_method")
|
||||
auth_code = Authorization(
|
||||
code=code,
|
||||
client_id=request.client.client_id,
|
||||
redirect_uri=request.redirect_uri,
|
||||
scope=request.scope,
|
||||
user_id=request.user.id,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
)
|
||||
#db.session.add(auth_code)
|
||||
#db.session.commit()
|
||||
return auth_code
|
||||
|
||||
def query_authorization_code(self, code, client):
|
||||
raise NotImplementedError()
|
||||
auth_code = Authorization.query.filter_by(
|
||||
def generate_user_info(user, scope):
|
||||
return UserInfo(sub=str(user.id), name=user.username)
|
||||
|
||||
|
||||
def create_authorization_code(client, grant_user, request):
|
||||
raise NotImplementedError()
|
||||
code = gen_salt(48)
|
||||
nonce = request.data.get("nonce")
|
||||
item = AuthorizationCode(
|
||||
code=code,
|
||||
client_id=client.client_id,
|
||||
redirect_uri=request.redirect_uri,
|
||||
scope=request.scope,
|
||||
user_id=grant_user.id,
|
||||
nonce=nonce,
|
||||
)
|
||||
return code
|
||||
|
||||
|
||||
class AuthorizationCodeGrant(_AuthorizationCodeGrant):
|
||||
def create_authorization_code(self, client, grant_user, request):
|
||||
return create_authorization_code(client, grant_user, request)
|
||||
|
||||
def parse_authorization_code(self, code, client):
|
||||
item = AuthorizationCode.query.filter_by(
|
||||
code=code, client_id=client.client_id
|
||||
).first()
|
||||
if auth_code and not auth_code.is_expired():
|
||||
return auth_code
|
||||
if item and not item.is_expired():
|
||||
return item
|
||||
|
||||
def delete_authorization_code(self, authorization_code):
|
||||
raise NotImplementedError()
|
||||
pass
|
||||
#db.session.delete(authorization_code)
|
||||
#db.session.commit()
|
||||
|
||||
def authenticate_user(self, authorization_code):
|
||||
raise NotImplementedError()
|
||||
return User.query.get(authorization_code.user_id)
|
||||
|
||||
|
||||
class PasswordGrant(grants.ResourceOwnerPasswordCredentialsGrant):
|
||||
class OpenIDCode(_OpenIDCode):
|
||||
def exists_nonce(self, nonce, request):
|
||||
return exists_nonce(nonce, request)
|
||||
|
||||
def get_jwt_config(self, grant):
|
||||
return DUMMY_JWT_CONFIG
|
||||
|
||||
def generate_user_info(self, user, scope):
|
||||
return generate_user_info(user, scope)
|
||||
|
||||
|
||||
class PasswordGrant(_ResourceOwnerPasswordCredentialsGrant):
|
||||
def authenticate_user(self, username, password):
|
||||
user = User.get(username)
|
||||
if user is not None and user.check_password(password):
|
||||
return user
|
||||
|
||||
|
||||
class RefreshTokenGrant(grants.RefreshTokenGrant):
|
||||
class RefreshTokenGrant(_RefreshTokenGrant):
|
||||
def authenticate_refresh_token(self, refresh_token):
|
||||
raise NotImplementedError()
|
||||
token = Token.query.filter_by(refresh_token=refresh_token).first()
|
||||
|
@ -74,8 +98,30 @@ class RefreshTokenGrant(grants.RefreshTokenGrant):
|
|||
def revoke_old_credential(self, credential):
|
||||
raise NotImplementedError()
|
||||
credential.revoked = True
|
||||
#db.session.add(credential)
|
||||
#db.session.commit()
|
||||
|
||||
class ImplicitGrant(_OpenIDImplicitGrant):
|
||||
def exists_nonce(self, nonce, request):
|
||||
return exists_nonce(nonce, request)
|
||||
|
||||
def get_jwt_config(self, grant):
|
||||
return DUMMY_JWT_CONFIG
|
||||
|
||||
def generate_user_info(self, user, scope):
|
||||
return generate_user_info(user, scope)
|
||||
|
||||
|
||||
class HybridGrant(_OpenIDHybridGrant):
|
||||
def create_authorization_code(self, client, grant_user, request):
|
||||
return create_authorization_code(client, grant_user, request)
|
||||
|
||||
def exists_nonce(self, nonce, request):
|
||||
return exists_nonce(nonce, request)
|
||||
|
||||
def get_jwt_config(self):
|
||||
return DUMMY_JWT_CONFIG
|
||||
|
||||
def generate_user_info(self, user, scope):
|
||||
return generate_user_info(user, scope)
|
||||
|
||||
|
||||
def query_client(client_id):
|
||||
|
@ -83,28 +129,10 @@ def query_client(client_id):
|
|||
|
||||
|
||||
def save_token(token, request):
|
||||
client_id, client_secret = util.extract_basic_authorization(request.headers)
|
||||
t = Token(
|
||||
authzAccessToken=token['access_token'],
|
||||
authzScopeValue=token['scope'],
|
||||
authzAccessTokenIssueDate=datetime.datetime.now().strftime("%Y%m%d%H%M%SZ"),
|
||||
authzSubject=request.user.dn,
|
||||
authzClientID=client_id,
|
||||
authzRefreshTokenSecret=token['refresh_token'],
|
||||
authzAccessTokenLifetime=str(token['expires_in']),
|
||||
# ??? = token['type']
|
||||
)
|
||||
t.save()
|
||||
return t
|
||||
raise NotImplementedError()
|
||||
|
||||
class RevocationEndpoint(RevocationEndpoint):
|
||||
def query_token(self, token, token_type_hint, client):
|
||||
raise NotImplementedError()
|
||||
|
||||
def revoke_token(self, token):
|
||||
raise NotImplementedError()
|
||||
|
||||
class BearerTokenValidator(BearerTokenValidator):
|
||||
class BearerTokenValidator(_BearerTokenValidator):
|
||||
def authenticate_token(self, token_string):
|
||||
return Token.get(token_string)
|
||||
|
||||
|
@ -114,20 +142,19 @@ class BearerTokenValidator(BearerTokenValidator):
|
|||
def token_revoked(self, token):
|
||||
return False
|
||||
|
||||
authorization = AuthorizationServer(query_client=query_client, save_token=save_token)
|
||||
|
||||
authorization = AuthorizationServer()
|
||||
require_oauth = ResourceProtector()
|
||||
|
||||
|
||||
def config_oauth(app):
|
||||
authorization.init_app(app)
|
||||
authorization.init_app(app, query_client=query_client, save_token=save_token)
|
||||
|
||||
# support all grants
|
||||
authorization.register_grant(grants.ImplicitGrant)
|
||||
authorization.register_grant(grants.ClientCredentialsGrant)
|
||||
authorization.register_grant(AuthorizationCodeGrant, [CodeChallenge(required=True)])
|
||||
authorization.register_grant(
|
||||
AuthorizationCodeGrant, [OpenIDCode(require_nonce=True)]
|
||||
)
|
||||
authorization.register_grant(ImplicitGrant)
|
||||
authorization.register_grant(HybridGrant)
|
||||
authorization.register_grant(PasswordGrant)
|
||||
authorization.register_grant(RefreshTokenGrant)
|
||||
|
||||
authorization.register_endpoint(RevocationEndpoint)
|
||||
|
||||
require_oauth.register_token_validator(BearerTokenValidator())
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
import datetime
|
||||
from flask import Blueprint, request, session, url_for
|
||||
from flask import Blueprint, request, session
|
||||
from flask import render_template, redirect, jsonify
|
||||
from werkzeug.security import gen_salt
|
||||
from authlib.integrations.flask_oauth2 import current_token
|
||||
from authlib.oauth2 import OAuth2Error
|
||||
from .models import User, Client
|
||||
from .oauth2 import authorization, require_oauth
|
||||
|
@ -13,36 +12,31 @@ bp = Blueprint(__name__, 'home')
|
|||
|
||||
def current_user():
|
||||
if 'user_dn' in session:
|
||||
return User.get(session['user_dn'])
|
||||
return User.get(session["user_dn"])
|
||||
return None
|
||||
|
||||
|
||||
def split_by_crlf(s):
|
||||
return [v for v in s.splitlines() if v]
|
||||
|
||||
|
||||
@bp.route('/', methods=('GET', 'POST'))
|
||||
def home():
|
||||
if request.method == 'POST':
|
||||
username = request.form.get('username')
|
||||
user = User.filter(cn=username)
|
||||
user = User.get(username)
|
||||
|
||||
if not user:
|
||||
user = User(cn=username, sn=username)
|
||||
user.save()
|
||||
else:
|
||||
user = user[0]
|
||||
session["user_dn"] = user.dn
|
||||
return redirect('/')
|
||||
|
||||
user = current_user()
|
||||
clients = Client.filter()
|
||||
if user:
|
||||
clients = Client.filter()
|
||||
else:
|
||||
clients = []
|
||||
return render_template('home.html', user=user, clients=clients)
|
||||
|
||||
|
||||
@bp.route('/logout')
|
||||
def logout():
|
||||
del session['id']
|
||||
return redirect('/')
|
||||
def split_by_crlf(s):
|
||||
return [v for v in s.splitlines() if v]
|
||||
|
||||
|
||||
@bp.route('/create_client', methods=('GET', 'POST'))
|
||||
|
@ -52,69 +46,54 @@ def create_client():
|
|||
return redirect('/')
|
||||
if request.method == 'GET':
|
||||
return render_template('create_client.html')
|
||||
|
||||
form = request.form
|
||||
client_id = gen_salt(24)
|
||||
client_id_issued_at = datetime.datetime.now().strftime("%Y%m%d%H%M%SZ")
|
||||
client = Client(
|
||||
oauthClientID=client_id,
|
||||
oauthClientIDIssueTime=client_id_issued_at,
|
||||
oauthIssueDate=client_id_issued_at,
|
||||
oauthClientName=form["client_name"],
|
||||
oauthClientURI=form["client_uri"],
|
||||
oauthGrantType=split_by_crlf(form["grant_type"]),
|
||||
oauthRedirectURI=split_by_crlf(form["redirect_uri"]),
|
||||
oauthResponseType=split_by_crlf(form["response_type"]),
|
||||
oauthScopeValue=form["scope"],
|
||||
oauthTokenEndpointAuthMethod=form["token_endpoint_auth_method"]
|
||||
oauthScope=form["scope"],
|
||||
oauthTokenEndpointAuthMethod=form["token_endpoint_auth_method"],
|
||||
oauthClientSecret='' if form['token_endpoint_auth_method'] == 'none' else gen_salt(48),
|
||||
)
|
||||
|
||||
if form['token_endpoint_auth_method'] == 'none':
|
||||
client.oauthClientSecret = ''
|
||||
else:
|
||||
client.oauthClientSecret = gen_salt(48)
|
||||
|
||||
client.save()
|
||||
|
||||
#db.session.add(client)
|
||||
#db.session.commit()
|
||||
return redirect('/')
|
||||
|
||||
|
||||
@bp.route('/oauth/authorize', methods=['GET', 'POST'])
|
||||
def authorize():
|
||||
user = current_user()
|
||||
# if user log status is not true (Auth server), then to log it in
|
||||
if not user:
|
||||
return redirect(url_for('website.routes.home', next=request.url))
|
||||
if request.method == 'GET':
|
||||
try:
|
||||
grant = authorization.validate_consent_request(end_user=user)
|
||||
except OAuth2Error as error:
|
||||
return error.error
|
||||
return jsonify(dict(error.get_body()))
|
||||
return render_template('authorize.html', user=user, grant=grant)
|
||||
if not user and 'username' in request.form:
|
||||
username = request.form.get('username')
|
||||
user = User.query.filter_by(username=username).first()
|
||||
user = User.get(username)
|
||||
if request.form['confirm']:
|
||||
grant_user = user
|
||||
else:
|
||||
grant_user = None
|
||||
return authorization.create_authorization_response(grant_user=grant_user)
|
||||
|
||||
@bp.route('/logout')
|
||||
def logout():
|
||||
del session['user_dn']
|
||||
return redirect('/')
|
||||
|
||||
@bp.route('/oauth/token', methods=['POST'])
|
||||
def issue_token():
|
||||
return authorization.create_token_response()
|
||||
|
||||
|
||||
@bp.route('/oauth/revoke', methods=['POST'])
|
||||
def revoke_token():
|
||||
return authorization.create_endpoint_response('revocation')
|
||||
|
||||
|
||||
@bp.route('/api/me')
|
||||
@require_oauth('profile')
|
||||
def api_me():
|
||||
user_dn = current_token.authzSubject[0]
|
||||
user = User.get(user_dn)
|
||||
return jsonify(id=user.cn, name=user.sn)
|
||||
return jsonify(foo="bar")
|
||||
|
|
Loading…
Reference in a new issue