From cad1b6c27483dde146369bf03ed8c11ee939c01e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89loi=20Rivard?= Date: Mon, 6 Dec 2021 15:40:30 +0100 Subject: [PATCH] Escape filters --- canaille/conf/config.sample.toml | 2 +- canaille/ldaputils.py | 26 ++++++++++++++++++-------- canaille/models.py | 10 ++++++++-- demo/conf/canaille.toml | 2 +- tests/conftest.py | 4 ++-- tests/test_account.py | 6 +++--- tests/test_authorization_code_flow.py | 10 +++++----- tests/test_hybrid_flow.py | 4 ++-- tests/test_implicit_flow.py | 6 +++--- tests/test_password_flow.py | 8 ++++---- 10 files changed, 47 insertions(+), 31 deletions(-) diff --git a/canaille/conf/config.sample.toml b/canaille/conf/config.sample.toml index 425688b3..ca6bc949 100644 --- a/canaille/conf/config.sample.toml +++ b/canaille/conf/config.sample.toml @@ -77,7 +77,7 @@ GROUP_CLASS = "groupOfNames" GROUP_NAME_ATTRIBUTE = "cn" # A filter to check if a user belongs to a group -GROUP_USER_FILTER = "(member={user.dn})" +GROUP_USER_FILTER = "member={user.dn}" # You can define access controls that define what users can do on canaille # An access control consists in a FILTER to match users, a list of PERMISSIONS diff --git a/canaille/ldaputils.py b/canaille/ldaputils.py index cda48789..d383b872 100644 --- a/canaille/ldaputils.py +++ b/canaille/ldaputils.py @@ -1,4 +1,5 @@ import ldap +import ldap.filter from flask import g @@ -221,15 +222,24 @@ class LDAPObject: else "" ) arg_filter = "" - for k, v in kwargs.items(): - if not isinstance(v, list): - arg_filter += f"({k}={v})" - elif len(v) == 1: - arg_filter += f"({k}={v[0]})" - else: - arg_filter += "(|" + "".join([f"({k}={_v})" for _v in v]) + ")" + for key, value in kwargs.items(): + if not isinstance(value, list): + escaped_value = ldap.filter.escape_filter_chars(value) + arg_filter += f"({key}={escaped_value})" + + elif len(value) == 1: + escaped_value = ldap.filter.escape_filter_chars(value[0]) + arg_filter += f"({key}={escaped_value})" + + else: + values = [ldap.filter.escape_filter_chars(v) for v in value] + arg_filter += "(|" + "".join([f"({key}={v})" for v in values]) + ")" + + if not filter: + filter = "" + elif not filter.startswith("(") and not filter.endswith(")"): + filter = f"({filter})" - filter = filter or "" ldapfilter = f"(&{class_filter}{arg_filter}{filter})" base = base or f"{cls.base},{cls.root_dn}" result = conn.search_s(base, ldap.SCOPE_SUBTREE, ldapfilter or None) diff --git a/canaille/models.py b/canaille/models.py index 1e41b6a8..a53661af 100644 --- a/canaille/models.py +++ b/canaille/models.py @@ -1,5 +1,6 @@ import datetime import ldap +import ldap.filter import uuid from authlib.oauth2.rfc6749 import ( ClientMixin, @@ -26,7 +27,11 @@ class User(LDAPObject): conn = conn or cls.ldap() if login: - filter = current_app.config["LDAP"].get("USER_FILTER").format(login=login) + filter = ( + current_app.config["LDAP"] + .get("USER_FILTER") + .format(login=ldap.filter.escape_filter_chars(login)) + ) user = super().get(dn, filter, conn) if user: @@ -39,7 +44,8 @@ class User(LDAPObject): group_filter = current_app.config["LDAP"]["GROUP_USER_FILTER"].format( user=self ) - self._groups = Group.filter(filter=group_filter, conn=conn) + escaped_group_filter = ldap.filter.escape_filter_chars(group_filter) + self._groups = Group.filter(filter=escaped_group_filter, conn=conn) except KeyError: pass diff --git a/demo/conf/canaille.toml b/demo/conf/canaille.toml index 7c1c1d78..75de7d8c 100644 --- a/demo/conf/canaille.toml +++ b/demo/conf/canaille.toml @@ -79,7 +79,7 @@ GROUP_CLASS = "groupOfNames" GROUP_NAME_ATTRIBUTE = "cn" # A filter to check if a user belongs to a group -GROUP_USER_FILTER = "(member={user.dn})" +GROUP_USER_FILTER = "member={user.dn}" # You can define access controls that define what users can do on canaille # An access control consists in a FILTER to match users, a list of PERMISSIONS diff --git a/tests/conftest.py b/tests/conftest.py index 6d9b209b..73551b36 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -141,7 +141,7 @@ def configuration(slapd_server, smtpd, keypair_path): "GROUP_BASE": "ou=groups", "GROUP_CLASS": "groupOfNames", "GROUP_NAME_ATTRIBUTE": "cn", - "GROUP_USER_FILTER": "(member={user.dn})", + "GROUP_USER_FILTER": "member={user.dn}", "TIMEOUT": 0.1, }, "ACL": { @@ -329,7 +329,7 @@ def user(app, slapd_connection): User.ocs_by_name(slapd_connection) u = User( objectClass=["inetOrgPerson"], - cn="John Doe", + cn="John (johnny) Doe", sn="Doe", uid="user", mail="john@doe.com", diff --git a/tests/test_account.py b/tests/test_account.py index fe7c7f2f..f0394104 100644 --- a/tests/test_account.py +++ b/tests/test_account.py @@ -7,12 +7,12 @@ def test_signin_and_out(testclient, slapd_connection, user): res = testclient.get("/login", status=200) - res.form["login"] = "John Doe" + res.form["login"] = "John (johnny) Doe" res = res.form.submit(status=302) res = res.follow(status=200) with testclient.session_transaction() as session: - assert "John Doe" == session.get("attempt_login") + assert "John (johnny) Doe" == session.get("attempt_login") res.form["password"] = "correct horse battery staple" res = res.form.submit() @@ -37,7 +37,7 @@ def test_signin_wrong_password(testclient, slapd_connection, user): res = testclient.get("/login", status=200) - res.form["login"] = "John Doe" + res.form["login"] = "John (johnny) Doe" res = res.form.submit(status=302) res = res.follow(status=200) res.form["password"] = "incorrect horse" diff --git a/tests/test_authorization_code_flow.py b/tests/test_authorization_code_flow.py index 98017b41..1c4a24bd 100644 --- a/tests/test_authorization_code_flow.py +++ b/tests/test_authorization_code_flow.py @@ -57,7 +57,7 @@ def test_authorization_code_flow( status=200, ) assert { - "name": "John Doe", + "name": "John (johnny) Doe", "family_name": "Doe", "sub": "user", "groups": [], @@ -116,7 +116,7 @@ def test_authorization_code_flow_preconsented( status=200, ) assert { - "name": "John Doe", + "name": "John (johnny) Doe", "family_name": "Doe", "sub": "user", "groups": [], @@ -179,7 +179,7 @@ def test_logout_login(testclient, slapd_connection, logged_user, client): status=200, ) assert { - "name": "John Doe", + "name": "John (johnny) Doe", "family_name": "Doe", "sub": "user", "groups": [], @@ -243,7 +243,7 @@ def test_refresh_token(testclient, slapd_connection, logged_user, client): status=200, ) assert { - "name": "John Doe", + "name": "John (johnny) Doe", "family_name": "Doe", "sub": "user", "groups": [], @@ -302,7 +302,7 @@ def test_code_challenge(testclient, slapd_connection, logged_user, client): status=200, ) assert { - "name": "John Doe", + "name": "John (johnny) Doe", "family_name": "Doe", "sub": "user", "groups": [], diff --git a/tests/test_hybrid_flow.py b/tests/test_hybrid_flow.py index ad54ef41..5351a7fc 100644 --- a/tests/test_hybrid_flow.py +++ b/tests/test_hybrid_flow.py @@ -43,7 +43,7 @@ def test_oauth_hybrid(testclient, slapd_connection, user, client): status=200, ) assert { - "name": "John Doe", + "name": "John (johnny) Doe", "family_name": "Doe", "sub": "user", "groups": [], @@ -89,7 +89,7 @@ def test_oidc_hybrid( status=200, ) assert { - "name": "John Doe", + "name": "John (johnny) Doe", "family_name": "Doe", "sub": "user", "groups": [], diff --git a/tests/test_implicit_flow.py b/tests/test_implicit_flow.py index 04110709..a737c33e 100644 --- a/tests/test_implicit_flow.py +++ b/tests/test_implicit_flow.py @@ -41,7 +41,7 @@ def test_oauth_implicit(testclient, slapd_connection, user, client): ) assert "application/json" == res.content_type assert { - "name": "John Doe", + "name": "John (johnny) Doe", "sub": "user", "family_name": "Doe", "groups": [], @@ -100,7 +100,7 @@ def test_oidc_implicit( ) assert "application/json" == res.content_type assert { - "name": "John Doe", + "name": "John (johnny) Doe", "sub": "user", "family_name": "Doe", "groups": [], @@ -160,7 +160,7 @@ def test_oidc_implicit_with_group( ) assert "application/json" == res.content_type assert { - "name": "John Doe", + "name": "John (johnny) Doe", "sub": "user", "family_name": "Doe", "groups": ["foo"], diff --git a/tests/test_password_flow.py b/tests/test_password_flow.py index d57cb9e2..31fe373c 100644 --- a/tests/test_password_flow.py +++ b/tests/test_password_flow.py @@ -7,7 +7,7 @@ def test_password_flow_basic(testclient, slapd_connection, user, client): "/oauth/token", params=dict( grant_type="password", - username="John Doe", + username="John (johnny) Doe", password="correct horse battery staple", scope="profile", ), @@ -28,7 +28,7 @@ def test_password_flow_basic(testclient, slapd_connection, user, client): status=200, ) assert { - "name": "John Doe", + "name": "John (johnny) Doe", "sub": "user", "family_name": "Doe", "groups": [], @@ -43,7 +43,7 @@ def test_password_flow_post(testclient, slapd_connection, user, client): "/oauth/token", params=dict( grant_type="password", - username="John Doe", + username="John (johnny) Doe", password="correct horse battery staple", scope="profile", client_id=client.oauthClientID, @@ -65,7 +65,7 @@ def test_password_flow_post(testclient, slapd_connection, user, client): status=200, ) assert { - "name": "John Doe", + "name": "John (johnny) Doe", "sub": "user", "family_name": "Doe", "groups": [],