forked from Github-Mirrors/canaille
refactor: ldap objects id attribute is based on entryUUID instead of dn
This commit is contained in:
parent
7b054bb571
commit
ec7a721336
11 changed files with 48 additions and 34 deletions
|
@ -174,9 +174,6 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
|
|||
|
||||
ldap_name = self.python_attribute_to_ldap(name)
|
||||
|
||||
if ldap_name == "dn":
|
||||
return self.dn_for(self.rdn_value)
|
||||
|
||||
python_single_value = typing.get_origin(self.attributes[name]) is not list
|
||||
ldap_value = self.get_ldap_attribute(ldap_name)
|
||||
return cardinalize_attribute(python_single_value, ldap_value)
|
||||
|
@ -306,18 +303,18 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
|
|||
return cls._attribute_type_by_name
|
||||
|
||||
@classmethod
|
||||
def get(cls, id=None, filter=None, **kwargs):
|
||||
def get(cls, dn=None, filter=None, **kwargs):
|
||||
try:
|
||||
return cls.query(id, filter, **kwargs)[0]
|
||||
return cls.query(dn, filter, **kwargs)[0]
|
||||
except (IndexError, ldap.NO_SUCH_OBJECT):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def query(cls, id=None, filter=None, **kwargs):
|
||||
def query(cls, dn=None, filter=None, **kwargs):
|
||||
conn = Backend.get().connection
|
||||
|
||||
base = id or kwargs.get("id")
|
||||
if base is None:
|
||||
base = dn
|
||||
if dn is None:
|
||||
base = f"{cls.base},{cls.root_dn}"
|
||||
elif "=" not in base:
|
||||
base = ldap.dn.escape_dn_chars(base)
|
||||
|
@ -330,15 +327,17 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
|
|||
)
|
||||
if class_filter:
|
||||
class_filter = f"(|{class_filter})"
|
||||
|
||||
arg_filter = ""
|
||||
kwargs = python_attrs_to_ldap(
|
||||
ldap_args = python_attrs_to_ldap(
|
||||
{
|
||||
cls.python_attribute_to_ldap(name): values
|
||||
for name, values in kwargs.items()
|
||||
if values is not None
|
||||
},
|
||||
encode=False,
|
||||
)
|
||||
for key, value in kwargs.items():
|
||||
for key, value in ldap_args.items():
|
||||
if len(value) == 1:
|
||||
escaped_value = ldap.filter.escape_filter_chars(value[0])
|
||||
arg_filter += f"({key}={escaped_value})"
|
||||
|
@ -420,7 +419,6 @@ class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
|
|||
attributes = ["objectClass"] + [
|
||||
self.python_attribute_to_ldap(name) for name in self.attributes
|
||||
]
|
||||
attributes.remove("dn")
|
||||
read_post_control = PostReadControl(criticality=True, attrList=attributes)
|
||||
|
||||
# Object already exists in the LDAP database
|
||||
|
|
|
@ -13,7 +13,7 @@ from .ldapobject import LDAPObject
|
|||
|
||||
class User(canaille.core.models.User, LDAPObject):
|
||||
attribute_map = {
|
||||
"id": "dn",
|
||||
"id": "entryUUID",
|
||||
"created": "createTimestamp",
|
||||
"last_modified": "modifyTimestamp",
|
||||
"user_name": "uid",
|
||||
|
@ -59,7 +59,10 @@ class User(canaille.core.models.User, LDAPObject):
|
|||
if isinstance(filter_, dict):
|
||||
# not super generic, but how can we improve this? ¯\_(ツ)_/¯
|
||||
if "groups" in filter_ and "=" not in filter_.get("groups"):
|
||||
filter_["groups"] = Group.dn_for(filter_["groups"])
|
||||
group_by_id = Group.get(id=filter_["groups"])
|
||||
filter_["groups"] = (
|
||||
group_by_id.dn if group_by_id else Group.dn_for(filter_["groups"])
|
||||
)
|
||||
|
||||
base = "".join(
|
||||
f"({cls.python_attribute_to_ldap(key)}={value})"
|
||||
|
@ -150,7 +153,7 @@ class User(canaille.core.models.User, LDAPObject):
|
|||
|
||||
old_groups = self.state.get(group_attr) or []
|
||||
new_groups = [
|
||||
v if isinstance(v, Group) else Group.get(id=v) for v in new_groups
|
||||
v if isinstance(v, Group) else Group.get(dn=v) for v in new_groups
|
||||
]
|
||||
to_add = set(new_groups) - set(old_groups)
|
||||
to_del = set(old_groups) - set(new_groups)
|
||||
|
@ -186,7 +189,7 @@ class User(canaille.core.models.User, LDAPObject):
|
|||
|
||||
class Group(canaille.core.models.Group, LDAPObject):
|
||||
attribute_map = {
|
||||
"id": "dn",
|
||||
"id": "entryUUID",
|
||||
"created": "createTimestamp",
|
||||
"last_modified": "modifyTimestamp",
|
||||
"display_name": "cn",
|
||||
|
@ -198,11 +201,6 @@ class Group(canaille.core.models.Group, LDAPObject):
|
|||
def identifier(self):
|
||||
return self.rdn_value
|
||||
|
||||
@property
|
||||
def display_name(self):
|
||||
attribute = current_app.config["CANAILLE_LDAP"]["GROUP_NAME_ATTRIBUTE"]
|
||||
return getattr(self, attribute)[0]
|
||||
|
||||
|
||||
class Client(canaille.oidc.models.Client, LDAPObject):
|
||||
ldap_object_class = ["oauthClient"]
|
||||
|
@ -235,7 +233,7 @@ class Client(canaille.oidc.models.Client, LDAPObject):
|
|||
}
|
||||
|
||||
attribute_map = {
|
||||
"id": "dn",
|
||||
"id": "entryUUID",
|
||||
"created": "createTimestamp",
|
||||
"last_modified": "modifyTimestamp",
|
||||
"preconsent": "oauthPreconsent",
|
||||
|
@ -256,7 +254,7 @@ class AuthorizationCode(canaille.oidc.models.AuthorizationCode, LDAPObject):
|
|||
base = "ou=authorizations,ou=oauth"
|
||||
rdn_attribute = "oauthAuthorizationCodeID"
|
||||
attribute_map = {
|
||||
"id": "dn",
|
||||
"id": "entryUUID",
|
||||
"created": "createTimestamp",
|
||||
"last_modified": "modifyTimestamp",
|
||||
"authorization_code_id": "oauthAuthorizationCodeID",
|
||||
|
@ -284,7 +282,7 @@ class Token(canaille.oidc.models.Token, LDAPObject):
|
|||
base = "ou=tokens,ou=oauth"
|
||||
rdn_attribute = "oauthTokenID"
|
||||
attribute_map = {
|
||||
"id": "dn",
|
||||
"id": "entryUUID",
|
||||
"created": "createTimestamp",
|
||||
"last_modified": "modifyTimestamp",
|
||||
"token_id": "oauthTokenID",
|
||||
|
@ -310,7 +308,7 @@ class Consent(canaille.oidc.models.Consent, LDAPObject):
|
|||
base = "ou=consents,ou=oauth"
|
||||
rdn_attribute = "cn"
|
||||
attribute_map = {
|
||||
"id": "dn",
|
||||
"id": "entryUUID",
|
||||
"created": "createTimestamp",
|
||||
"last_modified": "modifyTimestamp",
|
||||
"consent_id": "cn",
|
||||
|
|
|
@ -50,7 +50,7 @@ def ldap_to_python(value, syntax):
|
|||
return value.decode("utf-8").upper() == "TRUE"
|
||||
|
||||
if syntax == Syntax.DISTINGUISHED_NAME:
|
||||
return LDAPObject.get(id=value.decode("utf-8"))
|
||||
return LDAPObject.get(dn=value.decode("utf-8"))
|
||||
|
||||
return value.decode("utf-8")
|
||||
|
||||
|
@ -75,7 +75,7 @@ def python_to_ldap(value, syntax, encode=True):
|
|||
value = "TRUE" if value else "FALSE"
|
||||
|
||||
if syntax == Syntax.DISTINGUISHED_NAME:
|
||||
value = value.id if value else None
|
||||
value = value.dn if value else None
|
||||
|
||||
if not value:
|
||||
return None
|
||||
|
|
|
@ -19,7 +19,6 @@ class MemoryModel(BackendModel):
|
|||
"""Associates attribute values and ids."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.setdefault("id", str(uuid.uuid4()))
|
||||
self._state = {}
|
||||
self._cache = {}
|
||||
for attribute, value in kwargs.items():
|
||||
|
@ -116,6 +115,9 @@ class MemoryModel(BackendModel):
|
|||
return value
|
||||
|
||||
def save(self):
|
||||
if not self.id:
|
||||
self.id = str(uuid.uuid4())
|
||||
|
||||
self.last_modified = datetime.datetime.now(datetime.timezone.utc).replace(
|
||||
microsecond=0
|
||||
)
|
||||
|
|
|
@ -73,6 +73,7 @@ def add(user):
|
|||
if form["token_endpoint_auth_method"].data == "none"
|
||||
else gen_salt(48),
|
||||
)
|
||||
client.save()
|
||||
client.audience = [client]
|
||||
client.save()
|
||||
flash(
|
||||
|
|
|
@ -456,6 +456,7 @@ class ClientRegistrationEndpoint(ClientManagementMixin, _ClientRegistrationEndpo
|
|||
post_logout_redirect_uris=request.data.get("post_logout_redirect_uris"),
|
||||
**self.client_convert_data(**client_info, **client_metadata),
|
||||
)
|
||||
client.save()
|
||||
client.audience = [client]
|
||||
client.save()
|
||||
return client
|
||||
|
|
|
@ -119,6 +119,7 @@ def populate(app):
|
|||
response_types=["code", "id_token"],
|
||||
token_endpoint_auth_method="client_secret_basic",
|
||||
)
|
||||
client1.save()
|
||||
client1.audience = [client1]
|
||||
client1.save()
|
||||
|
||||
|
@ -142,6 +143,7 @@ def populate(app):
|
|||
token_endpoint_auth_method="client_secret_basic",
|
||||
preconsent=True,
|
||||
)
|
||||
client2.save()
|
||||
client2.audience = [client2]
|
||||
client2.save()
|
||||
|
||||
|
|
9
tests/backends/ldap/test_permissions.py
Normal file
9
tests/backends/ldap/test_permissions.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
def test_group_permissions_by_dn(testclient, user, foo_group):
|
||||
assert not user.can_manage_users
|
||||
|
||||
testclient.app.config["CANAILLE"]["ACL"]["ADMIN"]["FILTER"] = {
|
||||
"groups": foo_group.dn
|
||||
}
|
||||
user.reload()
|
||||
|
||||
assert user.can_manage_users
|
|
@ -48,14 +48,14 @@ def test_dn_when_leading_space_in_id_attribute(testclient, backend):
|
|||
)
|
||||
user.save()
|
||||
|
||||
dn = user.id
|
||||
dn = user.dn
|
||||
assert dn == "uid=user,ou=users,dc=mydomain,dc=tld"
|
||||
assert ldap.dn.is_dn(dn)
|
||||
assert ldap.dn.dn2str(ldap.dn.str2dn(dn)) == dn
|
||||
|
||||
assert user == models.User.get(user.identifier)
|
||||
assert user == models.User.get(user_name=user.identifier)
|
||||
assert user == models.User.get(id=dn)
|
||||
assert user == models.User.get(dn=dn)
|
||||
|
||||
user.delete()
|
||||
|
||||
|
@ -69,14 +69,14 @@ def test_special_chars_in_rdn(testclient, backend):
|
|||
)
|
||||
user.save()
|
||||
|
||||
dn = user.id
|
||||
dn = user.dn
|
||||
assert ldap.dn.is_dn(dn)
|
||||
assert ldap.dn.dn2str(ldap.dn.str2dn(dn)) == dn
|
||||
assert dn == "uid=\\#user,ou=users,dc=mydomain,dc=tld"
|
||||
|
||||
assert user == models.User.get(user.identifier)
|
||||
assert user == models.User.get(user_name=user.identifier)
|
||||
assert user == models.User.get(id=dn)
|
||||
assert user == models.User.get(dn=dn)
|
||||
|
||||
user.delete()
|
||||
|
||||
|
@ -184,13 +184,13 @@ def test_operational_attribute_conversion(backend):
|
|||
def test_guess_object_from_dn(backend, testclient, foo_group):
|
||||
foo_group.members = [foo_group]
|
||||
foo_group.save()
|
||||
dn = foo_group.id
|
||||
g = LDAPObject.get(id=dn)
|
||||
dn = foo_group.dn
|
||||
g = LDAPObject.get(dn=dn)
|
||||
assert isinstance(g, models.Group)
|
||||
assert g == foo_group
|
||||
assert g.display_name == foo_group.display_name
|
||||
|
||||
ou = LDAPObject.get(id=f"{models.Group.base},{models.Group.root_dn}")
|
||||
ou = LDAPObject.get(dn=f"{models.Group.base},{models.Group.root_dn}")
|
||||
assert isinstance(ou, LDAPObject)
|
||||
|
||||
|
||||
|
|
|
@ -35,6 +35,7 @@ def test_model_lifecycle(testclient, backend):
|
|||
formatted_name="formatted_name",
|
||||
)
|
||||
|
||||
assert not user.id
|
||||
assert not models.User.query()
|
||||
assert not models.User.query(id=user.id)
|
||||
assert not models.User.query(id="invalid")
|
||||
|
|
|
@ -67,6 +67,7 @@ def client(testclient, trusted_client, backend):
|
|||
token_endpoint_auth_method="client_secret_basic",
|
||||
post_logout_redirect_uris=["https://mydomain.tld/disconnected"],
|
||||
)
|
||||
c.save()
|
||||
c.audience = [c, trusted_client]
|
||||
c.save()
|
||||
|
||||
|
@ -104,6 +105,7 @@ def trusted_client(testclient, backend):
|
|||
post_logout_redirect_uris=["https://myotherdomain.tld/disconnected"],
|
||||
preconsent=True,
|
||||
)
|
||||
c.save()
|
||||
c.audience = [c]
|
||||
c.save()
|
||||
|
||||
|
|
Loading…
Reference in a new issue