LDAPObject dn attributes are automatically initialized

This commit is contained in:
Éloi Rivard 2023-03-08 23:53:53 +01:00
parent d201d6f617
commit 53581404ab
26 changed files with 238 additions and 195 deletions

View file

@ -32,7 +32,7 @@ def create_group(user):
flash(_("Group creation failed."), "error")
else:
group = Group()
group.member = [user.dn]
group.member = [user]
group.cn = [form.name.data]
group.description = [form.description.data]
group.save()

View file

@ -155,7 +155,7 @@ def validate_configuration(config):
group = Group(
cn=f"canaille_{uuid.uuid4()}",
member=[user.dn],
member=[user],
)
group.save(conn)
group.delete(conn)

View file

@ -6,7 +6,24 @@ from .utils import ldap_to_python
from .utils import python_to_ldap
class LDAPObject:
class LDAPObjectMetaclass(type):
ldap_to_python_class = {}
def __new__(cls, name, bases, attrs):
klass = super().__new__(cls, name, bases, attrs)
if attrs.get("object_class"):
for oc in attrs["object_class"]:
cls.ldap_to_python_class[oc] = klass
return klass
def __setattr__(cls, name, value):
super().__setattr__(name, value)
if name == "object_class":
for oc in value:
cls.ldap_to_python_class[oc] = cls
class LDAPObject(metaclass=LDAPObjectMetaclass):
_object_class_by_name = None
_attribute_type_by_name = None
_may = None
@ -22,23 +39,44 @@ class LDAPObject:
self.changes = {}
kwargs.setdefault("objectClass", self.object_class)
for name, value in kwargs.items():
setattr(self, name, value)
def __repr__(self):
return f"<{self.__class__.__name__} {self.rdn_attribute}={self.rdn_value}>"
return (
f"<{self.__class__.__name__} {self.rdn_attribute}={self.rdn_value}>"
if self.rdn_attribute
else "<LDAPOBject>"
)
def __eq__(self, other):
return (
if not (
isinstance(other, self.__class__)
and self.may() == other.may()
and self.must() == other.must()
and all(
getattr(self, attr) == getattr(other, attr)
hasattr(self, attr) == hasattr(other, attr)
for attr in self.may() + self.must()
if hasattr(self, attr) and hasattr(other, attr)
)
):
return False
self_attributes = self.python_attrs_to_ldap(
{
attr: getattr(self, attr)
for attr in self.may() + self.must()
if hasattr(self, attr)
}
)
other_attributes = other.python_attrs_to_ldap(
{
attr: getattr(other, attr)
for attr in self.may() + self.must()
if hasattr(self, attr)
}
)
return self_attributes == other_attributes
def __hash__(self):
return hash(self.dn)
@ -59,8 +97,7 @@ class LDAPObject:
# Lazy conversion from ldap format to python format
if any(isinstance(value, bytes) for value in self.attrs[name]):
ldap_attrs = LDAPObject.ldap_object_attributes()
syntax = ldap_attrs[name].syntax if name in ldap_attrs else None
syntax = self.attribute_ldap_syntax(name)
self.attrs[name] = [
ldap_to_python(value, syntax) for value in self.attrs[name]
]
@ -102,8 +139,6 @@ class LDAPObject:
return self._may
def must(self):
if not self._must:
self.update_ldap_attributes()
return self._must
@classmethod
@ -178,7 +213,6 @@ class LDAPObject:
@classmethod
def python_attrs_to_ldap(cls, attrs, encode=True):
ldap_attrs = LDAPObject.ldap_object_attributes()
if cls.attribute_table:
attrs = {
cls.attribute_table.get(name, name): values
@ -186,16 +220,23 @@ class LDAPObject:
}
return {
name: [
python_to_ldap(
value,
ldap_attrs[name].syntax if name in ldap_attrs else None,
encode=encode,
)
python_to_ldap(value, cls.attribute_ldap_syntax(name), encode=encode)
for value in (values if isinstance(values, list) else [values])
]
for name, values in attrs.items()
}
@classmethod
def attribute_ldap_syntax(cls, attribute_name):
ldap_attrs = LDAPObject.ldap_object_attributes()
if attribute_name not in ldap_attrs:
return None
if ldap_attrs[attribute_name].syntax:
return ldap_attrs[attribute_name].syntax
return cls.attribute_ldap_syntax(ldap_attrs[attribute_name].sup[0])
@classmethod
def get(cls, dn=None, filter=None, conn=None, **kwargs):
try:
@ -241,11 +282,20 @@ class LDAPObject:
objects = []
for _, args in result:
cls = cls.guess_class(args["objectClass"])
obj = cls()
obj.attrs = args
objects.append(obj)
return objects
@classmethod
def guess_class(cls, object_classes):
if cls == LDAPObject:
for oc in object_classes:
if oc.decode() in LDAPObjectMetaclass.ldap_to_python_class:
return LDAPObjectMetaclass.ldap_to_python_class[oc.decode()]
return cls
@classmethod
def update_ldap_attributes(cls):
all_object_classes = cls.ldap_object_classes()

View file

@ -6,7 +6,9 @@ LDAP_NULL_DATE = "000001010000Z"
class Syntax(str, Enum):
# fmt: off
BINARY = "1.3.6.1.4.1.1466.115.121.1.5"
BOOLEAN = "1.3.6.1.4.1.1466.115.121.1.7"
DISTINGUISHED_NAME = "1.3.6.1.4.1.1466.115.121.1.12"
DIRECTORY_STRING = "1.3.6.1.4.1.1466.115.121.1.15"
FAX_IMAGE = "1.3.6.1.4.1.1466.115.121.1.23"
GENERALIZED_TIME = "1.3.6.1.4.1.1466.115.121.1.24"
@ -22,6 +24,8 @@ class Syntax(str, Enum):
def ldap_to_python(value, syntax):
from .ldapobject import LDAPObject
if syntax == Syntax.GENERALIZED_TIME:
value = value.decode("utf-8")
if value == LDAP_NULL_DATE:
@ -39,6 +43,9 @@ def ldap_to_python(value, syntax):
if syntax == Syntax.BOOLEAN:
return value.decode("utf-8").upper() == "TRUE"
if syntax == Syntax.DISTINGUISHED_NAME:
return LDAPObject.get(dn=value.decode("utf-8"))
return value.decode("utf-8")
@ -59,6 +66,9 @@ def python_to_ldap(value, syntax, encode=True):
if syntax == Syntax.BOOLEAN and isinstance(value, bool):
value = "TRUE" if value else "FALSE"
if syntax == Syntax.DISTINGUISHED_NAME:
value = value.dn
if not value:
return None

View file

@ -190,16 +190,12 @@ class Group(LDAPObject):
return self[attribute][0]
def get_members(self, conn=None):
return [
User.get(dn=user_id, conn=conn)
for user_id in self.member
if User.get(dn=user_id, conn=conn)
]
return [member for member in self.member if member]
def add_member(self, user):
self.member = self.member + [user.dn]
self.member = self.member + [user]
self.save()
def remove_member(self, user):
self.member = [m for m in self.member if m != user.dn]
self.member = [m for m in self.member if m != user]
self.save()

View file

@ -71,7 +71,7 @@ def add(user):
if form["token_endpoint_auth_method"].data == "none"
else gen_salt(48),
)
client.audience = [client.dn]
client.audience = [client]
client.save()
flash(
_("The client has been created."),
@ -142,7 +142,7 @@ def client_edit(client_id):
software_version=form["software_version"].data,
jwk=form["jwk"].data,
jwks_uri=form["jwks_uri"].data,
audience=form["audience"].data,
audience=[Client.get(dn=dn) for dn in form["audience"].data],
preconsent=form["preconsent"].data,
)
client.save()

View file

@ -20,19 +20,17 @@ bp = Blueprint("consents", __name__, url_prefix="/consent")
@bp.route("/")
@user_needed()
def consents(user):
consents = Consent.query(subject=user.dn)
client_dns = {t.client for t in consents}
clients = {dn: Client.get(dn) for dn in client_dns}
consents = Consent.query(subject=user)
clients = {t.client for t in consents}
preconsented = [
client
for client in Client.query()
if client.preconsent and client.dn not in clients
if client.preconsent and client not in clients
]
return render_template(
"oidc/user/consent_list.html",
consents=consents,
clients=clients,
menuitem="consents",
scope_details=SCOPE_DETAILS,
ignored_scopes=["openid"],
@ -45,7 +43,7 @@ def consents(user):
def revoke(user, consent_id):
consent = Consent.get(consent_id)
if not consent or consent.subject != user.dn:
if not consent or consent.subject != user:
flash(_("Could not revoke this access"), "error")
elif consent.revokation_date:
@ -63,7 +61,7 @@ def revoke(user, consent_id):
def restore(user, consent_id):
consent = Consent.get(consent_id)
if not consent or consent.subject != user.dn:
if not consent or consent.subject != user:
flash(_("Could not restore this access"), "error")
elif not consent.revokation_date:
@ -88,14 +86,14 @@ def revoke_preconsent(user, client_id):
flash(_("Could not revoke this access"), "error")
return redirect(url_for("oidc.consents.consents"))
consent = Consent.get(client=client.dn, subject=user.dn)
consent = Consent.get(client=client, subject=user)
if consent:
return redirect(url_for("oidc.consents.revoke", consent_id=consent.cn[0]))
consent = Consent(
cn=str(uuid.uuid4()),
client=client.dn,
subject=user.dn,
client=client,
subject=user,
scope=client.scope,
)
consent.revoke()

View file

@ -90,8 +90,8 @@ def authorize():
# CONSENT
consents = Consent.query(
client=client.dn,
subject=user.dn,
client=client,
subject=user,
)
consent = consents[0] if consents else None
@ -101,7 +101,7 @@ def authorize():
or (consent and all(scope in set(consent.scope) for scope in scopes))
and not consent.revoked
):
return authorization.create_authorization_response(grant_user=user.dn)
return authorization.create_authorization_response(grant_user=user)
elif request.args.get("prompt") == "none":
response = {"error": "consent_required"}
@ -135,7 +135,7 @@ def authorize():
grant_user = None
if request.form["answer"] == "accept":
grant_user = user.dn
grant_user = user
if consent:
if consent.revoked:
@ -146,8 +146,8 @@ def authorize():
else:
consent = Consent(
cn=str(uuid.uuid4()),
client=client.dn,
subject=user.dn,
client=client,
subject=user,
scope=scopes,
issue_date=datetime.datetime.now(),
)

View file

@ -96,13 +96,13 @@ class Client(LDAPObject, ClientMixin):
return metadata
def delete(self):
for consent in Consent.query(client=self.dn):
for consent in Consent.query(client=self):
consent.delete()
for code in AuthorizationCode.query(client=self.dn):
for code in AuthorizationCode.query(client=self):
code.delete()
for token in Token.query(client=self.dn):
for token in Token.query(client=self):
token.delete()
super().delete()
@ -206,7 +206,7 @@ class Token(LDAPObject, TokenMixin):
return bool(self.revokation_date)
def check_client(self, client):
return client.client_id == Client.get(self.client).client_id
return client.client_id == self.client.client_id
class Consent(LDAPObject):

View file

@ -43,7 +43,8 @@ AUTHORIZATION_CODE_LIFETIME = 84400
def exists_nonce(nonce, req):
exists = AuthorizationCode.query(client=req.client_id, nonce=nonce)
client = Client.get(dn=req.client_id)
exists = AuthorizationCode.query(client=client, nonce=nonce)
return bool(exists)
@ -98,7 +99,6 @@ def claims_from_scope(scope):
def generate_user_info(user, scope):
user = User.get(dn=user)
claims = claims_from_scope(scope)
data = generate_user_claims(user, claims)
return UserInfo(**data)
@ -131,7 +131,7 @@ def save_authorization_code(code, request):
authorization_code_id=gen_salt(48),
code=code,
subject=request.user,
client=request.client.dn,
client=request.client,
redirect_uri=request.redirect_uri or request.client.redirect_uris[0],
scope=scope,
nonce=nonce,
@ -151,7 +151,7 @@ class AuthorizationCodeGrant(_AuthorizationCodeGrant):
return save_authorization_code(code, request)
def query_authorization_code(self, code, client):
item = AuthorizationCode.query(code=code, client=client.dn)
item = AuthorizationCode.query(code=code, client=client)
if item and not item[0].is_expired():
return item[0]
@ -159,9 +159,7 @@ class AuthorizationCodeGrant(_AuthorizationCodeGrant):
authorization_code.delete()
def authenticate_user(self, authorization_code):
user = User.get(dn=authorization_code.subject)
if user:
return user.dn
return authorization_code.subject
class OpenIDCode(_OpenIDCode):
@ -176,16 +174,14 @@ class OpenIDCode(_OpenIDCode):
def get_audiences(self, request):
client = request.client
return [Client.get(aud).client_id for aud in client.audience]
return [aud.client_id for aud in client.audience]
class PasswordGrant(_ResourceOwnerPasswordCredentialsGrant):
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]
def authenticate_user(self, username, password):
user = User.authenticate(username, password)
if user:
return user.dn
return User.authenticate(username, password)
class RefreshTokenGrant(_RefreshTokenGrant):
@ -195,9 +191,7 @@ class RefreshTokenGrant(_RefreshTokenGrant):
return token[0]
def authenticate_user(self, credential):
user = User.get(dn=credential.subject)
if user:
return user.dn
return credential.subject
def revoke_old_credential(self, credential):
credential.revokation_date = datetime.datetime.now()
@ -216,7 +210,7 @@ class OpenIDImplicitGrant(_OpenIDImplicitGrant):
def get_audiences(self, request):
client = request.client
return [Client.get(aud).client_id for aud in client.audience]
return [aud.client_id for aud in client.audience]
class OpenIDHybridGrant(_OpenIDHybridGrant):
@ -234,7 +228,7 @@ class OpenIDHybridGrant(_OpenIDHybridGrant):
def get_audiences(self, request):
client = request.client
return [Client.get(aud).client_id for aud in client.audience]
return [aud.client_id for aud in client.audience]
def query_client(client_id):
@ -250,7 +244,7 @@ def save_token(token, request):
issue_date=now,
lifetime=token["expires_in"],
scope=token["scope"],
client=request.client.dn,
client=request.client,
refresh_token=token.get("refresh_token"),
subject=request.user,
audience=request.client.audience,
@ -294,19 +288,17 @@ class IntrospectionEndpoint(_IntrospectionEndpoint):
return query_token(token, token_type_hint)
def check_permission(self, token, client, request):
return client.dn in token.audience
return client in token.audience
def introspect_token(self, token):
client_id = Client.get(token.client).client_id
user = User.get(dn=token.subject)
audience = [Client.get(aud).client_id for aud in token.audience]
audience = [aud.client_id for aud in token.audience]
return {
"active": True,
"client_id": client_id,
"client_id": token.client.client_id,
"token_type": token.type,
"username": user.name,
"username": token.subject.name,
"scope": token.get_scope(),
"sub": user.uid[0],
"sub": token.subject.uid[0],
"aud": audience,
"iss": get_issuer(),
"exp": token.get_expires_at(),
@ -402,7 +394,7 @@ require_oauth = ResourceProtector()
def generate_access_token(client, grant_type, user, scope):
audience = [Client.get(dn).client_id for dn in client.audience]
audience = [client.client_id for client in client.audience]
bearer_token_generator = authorization._token_generators["default"]
kwargs = {
"token": {},

View file

@ -20,11 +20,9 @@ bp = Blueprint("tokens", __name__, url_prefix="/admin/token")
@permissions_needed("manage_oidc")
def index(user):
tokens = Token.query()
items = (
(token, Client.get(token.client), User.get(dn=token.subject))
for token in tokens
return render_template(
"oidc/admin/token_list.html", tokens=tokens, menuitem="admin"
)
return render_template("oidc/admin/token_list.html", items=items, menuitem="admin")
@bp.route("/<token_id>", methods=["GET", "POST"])
@ -35,15 +33,9 @@ def view(user, token_id):
if not token:
abort(404)
token_client = Client.get(token.client)
token_user = User.get(dn=token.subject)
token_audience = [Client.get(aud) for aud in token.audience]
return render_template(
"oidc/admin/token_view.html",
token=token,
token_client=token_client,
token_user=token_user,
token_audience=token_audience,
menuitem="admin",
)

View file

@ -45,7 +45,7 @@ def fake_groups(nb=1, nb_users_max=1):
description=fake.sentence(),
)
nb_users = random.randrange(1, nb_users_max + 1)
group.member = list({random.choice(users).dn for _ in range(nb_users)})
group.member = list({random.choice(users) for _ in range(nb_users)})
group.save()
groups.append(group)
return groups

View file

@ -20,7 +20,7 @@
<th>{% trans %}Subject{% endtrans %}</th>
<th>{% trans %}Created{% endtrans %}</th>
</thead>
{% for token, client, user in items %}
{% for token in tokens %}
<tr>
<td>
<a href="{{ url_for('oidc.tokens.view', token_id=token.token_id) }}">
@ -28,13 +28,13 @@
</a>
</td>
<td>
<a href="{{ url_for('oidc.clients.edit', client_id=client.client_id) }}">
{{ client.client_name }}
<a href="{{ url_for('oidc.clients.edit', client_id=token.client.client_id) }}">
{{ token.client.client_name }}
</a>
</td>
<td>
<a href="{{ url_for("account.profile_edition", username=user.uid[0]) }}">
{{ user.uid[0] }}
<a href="{{ url_for("account.profile_edition", username=token.subject.uid[0]) }}">
{{ token.subject.uid[0] }}
</a>
</td>
<td data-order="{{ token.issue_date|timestamp }}">{{ token.issue_date }}</td>

View file

@ -33,16 +33,16 @@
<tr>
<td>{{ _("Client") }}</td>
<td>
<a href="{{ url_for("oidc.clients.edit", client_id=token_client.client_id) }}">
{{ token_client.client_name }}
<a href="{{ url_for("oidc.clients.edit", client_id=token.client.client_id) }}">
{{ token.client.client_name }}
</a>
</td>
</tr>
<tr>
<td>{{ _("Subject") }}</td>
<td>
<a href="{{ url_for("account.profile_edition", username=token_user.uid[0]) }}">
{{ token_user.name }} - {{ token_user.uid[0] }}
<a href="{{ url_for("account.profile_edition", username=token.subject.uid[0]) }}">
{{ token.subject.name }} - {{ token.subject.uid[0] }}
</a>
</td>
</tr>
@ -60,7 +60,7 @@
<td>{{ _("Audience") }}</td>
<td>
<ul class="ui list">
{% for client in token_audience %}
{% for client in token.audience %}
<li class="item">
<a href="{{ url_for("oidc.clients.edit", client_id=client.dn) }}">
{{ client.client_name }}

View file

@ -25,16 +25,15 @@
{% if consents %}
<div class="ui centered cards">
{% for consent in consents %}
{% set client = clients[consent.client] %}
<div class="ui card">
<div class="content">
{% if client.logo_uri %}
<img class="right floated mini ui image" src="{{ client.logo_uri }}">
{% if consent.client.logo_uri %}
<img class="right floated mini ui image" src="{{ consent.client.logo_uri }}">
{% endif %}
{% if client.client_uri %}
<a href="{{ client.client_uri }}" class="header">{{ client.client_name }}</a>
{% if consent.client.client_uri %}
<a href="{{ consent.client.client_uri }}" class="header">{{ consent.client.client_name }}</a>
{% else %}
<div class="header">{{ client.client_name }}</div>
<div class="header">{{ consent.client.client_name }}</div>
{% endif %}
{% if consent.issue_date %}
<div class="meta">{% trans %}From:{% endtrans %} {{ consent.issue_date.strftime("%d/%m/%Y %H:%M:%S") }}</div>
@ -66,18 +65,18 @@
</div>
</div>
</div>
{% if client.tos_uri or client.policy_uri %}
{% if consent.client.tos_uri or consent.client.policy_uri %}
<div class="extra content">
{% if client.policy_uri %}
{% if consent.client.policy_uri %}
<span class="right floated">
<i class="mask icon"></i>
<a href="{{ client.policy_uri }}">{% trans %}Policy{% endtrans %}</a>
<a href="{{ consent.client.policy_uri }}">{% trans %}Policy{% endtrans %}</a>
</span>
{% endif %}
{% if client.tos_uri %}
{% if consent.client.tos_uri %}
<span>
<i class="file signature icon"></i>
<a href="{{ client.tos_uri }}">{% trans %}Terms of service{% endtrans %}</a>
<a href="{{ consent.client.tos_uri }}">{% trans %}Terms of service{% endtrans %}</a>
</span>
{% endif %}
</div>

View file

@ -233,7 +233,7 @@ def logged_moderator(moderator, testclient):
def foo_group(app, user, slapd_connection):
Group.ldap_object_classes(slapd_connection)
group = Group(
member=[user.dn],
member=[user],
cn="foo",
)
group.save()
@ -247,7 +247,7 @@ def foo_group(app, user, slapd_connection):
def bar_group(app, admin, slapd_connection):
Group.ldap_object_classes(slapd_connection)
group = Group(
member=[admin.dn],
member=[admin],
cn="bar",
)
group.save()

View file

@ -117,3 +117,15 @@ def test_operational_attribute_conversion(slapd_connection):
"oauthClientName": [b"foobar_name"],
"invalidAttribute": [b"foobar"],
}
def test_guess_object_from_dn(slapd_connection, testclient, foo_group):
foo_group.member = [foo_group]
foo_group.save()
g = LDAPObject.get(dn=foo_group.dn)
assert isinstance(g, Group)
assert g == foo_group
assert g.cn == foo_group.cn
ou = LDAPObject.get(dn=f"{Group.base},{Group.root_dn}")
assert isinstance(g, LDAPObject)

View file

@ -109,7 +109,7 @@ def client(testclient, other_client, slapd_connection):
token_endpoint_auth_method="client_secret_basic",
post_logout_redirect_uris=["https://mydomain.tld/disconnected"],
)
c.audience = [c.dn, other_client.dn]
c.audience = [c, other_client]
c.save()
yield c
@ -145,7 +145,7 @@ def other_client(testclient, slapd_connection):
token_endpoint_auth_method="client_secret_basic",
post_logout_redirect_uris=["https://myotherdomain.tld/disconnected"],
)
c.audience = [c.dn]
c.audience = [c]
c.save()
yield c
@ -157,8 +157,8 @@ def authorization(testclient, user, client, slapd_connection):
a = AuthorizationCode(
authorization_code_id=gen_salt(48),
code="my-code",
client=client.dn,
subject=user.dn,
client=client,
subject=user,
redirect_uri="https://foo.bar/callback",
response_type="code",
scope="openid profile",
@ -179,9 +179,9 @@ def token(testclient, client, user, slapd_connection):
t = Token(
token_id=gen_salt(48),
access_token=gen_salt(48),
audience=[client.dn],
client=client.dn,
subject=user.dn,
audience=[client],
client=client,
subject=user,
token_type=None,
refresh_token=gen_salt(48),
scope="openid profile",
@ -197,7 +197,7 @@ def token(testclient, client, user, slapd_connection):
def id_token(testclient, client, user, slapd_connection):
return generate_id_token(
{},
generate_user_info(user.dn, client.scope),
generate_user_info(user, client.scope),
aud=client.client_id,
**get_jwt_config(None)
)
@ -207,8 +207,8 @@ def id_token(testclient, client, user, slapd_connection):
def consent(testclient, client, user, slapd_connection):
t = Consent(
cn=str(uuid.uuid4()),
client=client.dn,
subject=user.dn,
client=client,
subject=user,
scope=["openid", "profile"],
issue_date=datetime.datetime.now(),
)

View file

@ -47,7 +47,7 @@ def test_authorization_code_flow(
"phone",
}
consents = Consent.query(client=client.dn, subject=logged_user.dn)
consents = Consent.query(client=client, subject=logged_user)
assert set(consents[0].scope) == {
"openid",
"profile",
@ -71,8 +71,8 @@ def test_authorization_code_flow(
access_token = res.json["access_token"]
token = Token.get(access_token=access_token)
assert token.client == client.dn
assert token.subject == logged_user.dn
assert token.client == client
assert token.subject == logged_user
assert set(token.scope[0].split(" ")) == {
"openid",
"profile",
@ -140,7 +140,7 @@ def test_authorization_code_flow_with_redirect_uri(
code = params["code"][0]
authcode = AuthorizationCode.get(code=code)
assert authcode is not None
consents = Consent.query(client=client.dn, subject=logged_user.dn)
consents = Consent.query(client=client, subject=logged_user)
res = testclient.post(
"/oauth/token",
@ -156,8 +156,8 @@ def test_authorization_code_flow_with_redirect_uri(
access_token = res.json["access_token"]
token = Token.get(access_token=access_token)
assert token.client == client.dn
assert token.subject == logged_user.dn
assert token.client == client
assert token.subject == logged_user
for consent in consents:
consent.delete()
@ -188,7 +188,7 @@ def test_authorization_code_flow_preconsented(
authcode = AuthorizationCode.get(code=code)
assert authcode is not None
consents = Consent.query(client=client.dn, subject=logged_user.dn)
consents = Consent.query(client=client, subject=logged_user)
assert not consents
res = testclient.post(
@ -205,8 +205,8 @@ def test_authorization_code_flow_preconsented(
access_token = res.json["access_token"]
token = Token.get(access_token=access_token)
assert token.client == client.dn
assert token.subject == logged_user.dn
assert token.client == client
assert token.subject == logged_user
id_token = res.json["id_token"]
claims = jwt.decode(id_token, keypair[1])
@ -257,7 +257,7 @@ def test_logout_login(testclient, logged_user, client):
authcode = AuthorizationCode.get(code=code)
assert authcode is not None
consents = Consent.query(client=client.dn, subject=logged_user.dn)
consents = Consent.query(client=client, subject=logged_user)
assert "profile" in consents[0].scope
res = testclient.post(
@ -274,8 +274,8 @@ def test_logout_login(testclient, logged_user, client):
access_token = res.json["access_token"]
token = Token.get(access_token=access_token)
assert token.client == client.dn
assert token.subject == logged_user.dn
assert token.client == client
assert token.subject == logged_user
res = testclient.get(
"/oauth/userinfo",
@ -338,7 +338,7 @@ def test_refresh_token(testclient, user, client):
authcode = AuthorizationCode.get(code=code)
assert authcode is not None
consents = Consent.query(client=client.dn, subject=user.dn)
consents = Consent.query(client=client, subject=user)
assert "profile" in consents[0].scope
with freezegun.freeze_time("2020-01-01 00:01:00"):
@ -418,7 +418,7 @@ def test_code_challenge(testclient, logged_user, client):
authcode = AuthorizationCode.get(code=code)
assert authcode is not None
consents = Consent.query(client=client.dn, subject=logged_user.dn)
consents = Consent.query(client=client, subject=logged_user)
assert "profile" in consents[0].scope
res = testclient.post(
@ -436,8 +436,8 @@ def test_code_challenge(testclient, logged_user, client):
access_token = res.json["access_token"]
token = Token.get(access_token=access_token)
assert token.client == client.dn
assert token.subject == logged_user.dn
assert token.client == client
assert token.subject == logged_user
res = testclient.get(
"/oauth/userinfo",
@ -477,7 +477,7 @@ def test_authorization_code_flow_when_consent_already_given(
authcode = AuthorizationCode.get(code=code)
assert authcode is not None
consents = Consent.query(client=client.dn, subject=logged_user.dn)
consents = Consent.query(client=client, subject=logged_user)
assert "profile" in consents[0].scope
res = testclient.post(
@ -535,7 +535,7 @@ def test_authorization_code_flow_when_consent_already_given_but_for_a_smaller_sc
authcode = AuthorizationCode.get(code=code)
assert authcode is not None
consents = Consent.query(client=client.dn, subject=logged_user.dn)
consents = Consent.query(client=client, subject=logged_user)
assert "profile" in consents[0].scope
assert "groups" not in consents[0].scope
@ -571,7 +571,7 @@ def test_authorization_code_flow_when_consent_already_given_but_for_a_smaller_sc
authcode = AuthorizationCode.get(code=code)
assert authcode is not None
consents = Consent.query(client=client.dn, subject=logged_user.dn)
consents = Consent.query(client=client, subject=logged_user)
assert "profile" in consents[0].scope
assert "groups" in consents[0].scope
@ -606,8 +606,8 @@ def test_authorization_code_flow_but_user_cannot_use_oidc(
def test_prompt_none(testclient, logged_user, client):
consent = Consent(
cn=str(uuid.uuid4()),
client=client.dn,
subject=logged_user.dn,
client=client,
subject=logged_user,
scope=["openid", "profile"],
)
consent.save()
@ -633,8 +633,8 @@ def test_prompt_none(testclient, logged_user, client):
def test_prompt_not_logged(testclient, user, client):
consent = Consent(
cn=str(uuid.uuid4()),
client=client.dn,
subject=user.dn,
client=client,
subject=user,
scope=["openid", "profile"],
)
consent.save()
@ -732,7 +732,7 @@ def test_authorization_code_request_scope_too_large(
"profile",
}
consents = Consent.query(client=other_client.dn, subject=logged_user.dn)
consents = Consent.query(client=other_client, subject=logged_user)
assert set(consents[0].scope) == {
"openid",
"profile",
@ -752,8 +752,8 @@ def test_authorization_code_request_scope_too_large(
access_token = res.json["access_token"]
token = Token.get(access_token=access_token)
assert token.client == other_client.dn
assert token.subject == logged_user.dn
assert token.client == other_client
assert token.subject == logged_user
assert set(token.scope[0].split(" ")) == {
"openid",
"profile",
@ -965,7 +965,7 @@ def test_token_default_expiration_date(testclient, logged_user, client, keypair)
claims = jwt.decode(id_token, keypair[1])
assert claims["exp"] - claims["iat"] == 3600
consents = Consent.query(client=client.dn, subject=logged_user.dn)
consents = Consent.query(client=client, subject=logged_user)
for consent in consents:
consent.delete()
@ -1023,6 +1023,6 @@ def test_token_custom_expiration_date(testclient, logged_user, client, keypair):
claims = jwt.decode(id_token, keypair[1])
assert claims["exp"] - claims["iat"] == 6000
consents = Consent.query(client=client.dn, subject=logged_user.dn)
consents = Consent.query(client=client, subject=logged_user)
for consent in consents:
consent.delete()

View file

@ -11,8 +11,8 @@ def test_clean_command(testclient, slapd_connection, client, user):
valid_code = AuthorizationCode(
authorization_code_id=gen_salt(48),
code="my-valid-code",
client=client.dn,
subject=user.dn,
client=client,
subject=user,
redirect_uri="https://foo.bar/callback",
response_type="code",
scope="openid profile",
@ -27,8 +27,8 @@ def test_clean_command(testclient, slapd_connection, client, user):
expired_code = AuthorizationCode(
authorization_code_id=gen_salt(48),
code="my-expired-code",
client=client.dn,
subject=user.dn,
client=client,
subject=user,
redirect_uri="https://foo.bar/callback",
response_type="code",
scope="openid profile",
@ -47,8 +47,8 @@ def test_clean_command(testclient, slapd_connection, client, user):
valid_token = Token(
token_id=gen_salt(48),
access_token="my-valid-token",
client=client.dn,
subject=user.dn,
client=client,
subject=user,
type=None,
refresh_token=gen_salt(48),
scope="openid profile",
@ -59,8 +59,8 @@ def test_clean_command(testclient, slapd_connection, client, user):
expired_token = Token(
token_id=gen_salt(48),
access_token="my-expired-token",
client=client.dn,
subject=user.dn,
client=client,
subject=user,
type=None,
refresh_token=gen_salt(48),
scope="openid profile",

View file

@ -53,7 +53,7 @@ def test_client_add(testclient, logged_admin):
client_id = res.forms["readonly"]["client_id"].value
client = Client.get(client_id)
data["audience"] = [client.dn]
data["audience"] = [client]
for k, v in data.items():
client_value = getattr(client, k)
if k == "scope":
@ -106,6 +106,7 @@ def test_client_edit(testclient, client, logged_admin, other_client):
assert ("success", "The client has been edited.") in res.flashes
client = Client.get(client.dn)
data["audience"] = [client, other_client]
for k, v in data.items():
client_value = getattr(client, k)
if k == "scope":
@ -131,16 +132,12 @@ def test_client_delete(testclient, logged_admin):
client = Client(client_id="client_id")
client.save()
token = Token(
token_id="id", client=client.dn, issue_datetime=datetime.datetime.utcnow()
token_id="id", client=client, issue_datetime=datetime.datetime.utcnow()
)
token.save()
consent = Consent(
cn="cn", subject=logged_admin.dn, client=client.dn, scope="openid"
)
consent = Consent(cn="cn", subject=logged_admin, client=client, scope="openid")
consent.save()
code = AuthorizationCode(
authorization_code_id="id", client=client.dn, subject=client.dn
)
code = AuthorizationCode(authorization_code_id="id", client=client, subject=client)
res = testclient.get("/admin/client/edit/" + client.client_id)
res = res.forms["clientadd"].submit(name="action", value="delete").follow()

View file

@ -118,10 +118,9 @@ def test_oidc_authorization_after_revokation(
res = res.form.submit(name="answer", value="accept", status=302)
Consent.query()
consents = Consent.query(client=client.dn, subject=logged_user.dn)
assert consents[0].dn == consent.dn
consents = Consent.query(client=client, subject=logged_user)
consent.reload()
assert consents[0] == consent
assert not consent.revoked
params = parse_qs(urlsplit(res.location).query)
@ -140,8 +139,8 @@ def test_oidc_authorization_after_revokation(
access_token = res.json["access_token"]
token = Token.get(access_token=access_token)
assert token.client == client.dn
assert token.subject == logged_user.dn
assert token.client == client
assert token.subject == logged_user
def test_preconsented_client_appears_in_consent_list(testclient, client, logged_user):
@ -166,8 +165,8 @@ def test_revoke_preconsented_client(testclient, client, logged_user, token):
assert ("success", "The access has been revoked") in res.flashes
consent = Consent.get()
assert consent.client == client.dn
assert consent.subject == logged_user.dn
assert consent.client == client
assert consent.subject == logged_user
assert consent.scope == ["openid", "email", "profile", "groups", "address", "phone"]
assert not consent.issue_date
token.reload()

View file

@ -159,7 +159,7 @@ def test_end_session_invalid_client_id(
def test_client_hint_invalid(testclient, slapd_connection, logged_user, client):
id_token = generate_id_token(
{},
generate_user_info(logged_user.dn, client.scope),
generate_user_info(logged_user, client.scope),
aud="invalid-client-id",
**get_jwt_config(None),
)
@ -266,7 +266,7 @@ def test_jwt_not_issued_here(
def test_client_hint_mismatch(testclient, slapd_connection, logged_user, client):
id_token = generate_id_token(
{},
generate_user_info(logged_user.dn, client.scope),
generate_user_info(logged_user, client.scope),
aud="another_client_id",
**get_jwt_config(None),
)
@ -299,7 +299,7 @@ def test_bad_user_id_token_mismatch(
id_token = generate_id_token(
{},
generate_user_info(admin.dn, client.scope),
generate_user_info(admin, client.scope),
aud=client.client_id,
**get_jwt_config(None),
)

View file

@ -93,8 +93,8 @@ def test_full_flow(testclient, logged_user, client, user, other_client):
access_token = res.json["access_token"]
token = Token.get(access_token=access_token)
assert token.client == client.dn
assert token.subject == logged_user.dn
assert token.client == client
assert token.subject == logged_user
res = testclient.post(
"/oauth/introspect",

View file

@ -126,15 +126,13 @@ def test_get_members_filters_non_existent_user(
testclient, logged_moderator, foo_group, user
):
# an LDAP group can be inconsistent by containing members which doesn't exist
non_existent_user_id = user.dn.replace(user.name, "yolo")
foo_group.member = foo_group.member + [non_existent_user_id]
non_existent_user = User(cn="foo", sn="bar")
foo_group.member = foo_group.member + [non_existent_user]
foo_group.save()
foo_members = foo_group.get_members()
foo_group.get_members()
assert foo_group.member == [user.dn, non_existent_user_id]
assert len(foo_members) == 1
assert foo_members[0].dn == user.dn
assert foo_group.member == [user, non_existent_user]
testclient.get("/groups/foo", status=200)

View file

@ -36,8 +36,8 @@ def test_edition(
("cn=bar,ou=groups,dc=mydomain,dc=tld", False, "bar"),
}
assert logged_user.groups == [foo_group]
assert foo_group.member == [logged_user.dn]
assert bar_group.member == [admin.dn]
assert foo_group.member == [logged_user]
assert bar_group.member == [admin]
assert res.form["groups"].attrs["readonly"]
assert res.form["uid"].attrs["readonly"]
@ -76,8 +76,8 @@ def test_edition(
foo_group.reload()
bar_group.reload()
assert logged_user.groups == [foo_group]
assert foo_group.member == [logged_user.dn]
assert bar_group.member == [admin.dn]
assert foo_group.member == [logged_user]
assert bar_group.member == [admin]
assert logged_user.check_password("correct horse battery staple")
@ -285,8 +285,8 @@ def test_user_creation_edition_and_deletion(
foo_group.reload()
bar_group.reload()
assert george.dn in set(foo_group.member)
assert george.dn in set(bar_group.member)
assert george in set(foo_group.member)
assert george in set(bar_group.member)
assert set(george.groups) == {foo_group, bar_group}
assert "george" in testclient.get("/users", status=200).text
assert "george" in testclient.get("/users", status=200).text