ldaputils serialization refactoring

This commit is contained in:
Éloi Rivard 2021-12-08 14:58:12 +01:00
parent 1e9712a08e
commit ce6ccc0d3d

View file

@ -150,13 +150,28 @@ class LDAPObject:
return cls._attribute_type_by_name return cls._attribute_type_by_name
@staticmethod
def ldap_attrs_to_python(attrs):
return {
name: [value.decode("utf-8") for value in values]
for name, values in attrs.items()
}
@staticmethod
def python_attrs_to_ldap(attrs):
return {
name: [
value.encode("utf-8") if isinstance(value, str) else value
for value in values
]
for name, values in attrs.items()
}
def reload(self, conn=None): def reload(self, conn=None):
conn = conn or self.ldap() conn = conn or self.ldap()
result = conn.search_s(self.dn, ldap.SCOPE_SUBTREE) result = conn.search_s(self.dn, ldap.SCOPE_SUBTREE)
self.changes = {} self.changes = {}
self.attrs = { self.attrs = self.ldap_attrs_to_python(result[0][1])
k: [elt.decode("utf-8") for elt in v] for k, v in result[0][1].items()
}
def save(self, conn=None): def save(self, conn=None):
conn = conn or self.ldap() conn = conn or self.ldap()
@ -165,39 +180,31 @@ class LDAPObject:
except ldap.NO_SUCH_OBJECT: except ldap.NO_SUCH_OBJECT:
match = False match = False
# Object already exists in the LDAP database
if match: if match:
mods = { changes = {
k: v name: value
for k, v in self.changes.items() for name, value in self.changes.items()
if v and v[0] and self.attrs.get(k) != v if value and value[0] and self.attrs.get(name) != value
} }
attributes = [ changes = self.python_attrs_to_ldap(changes)
( modlist = [
ldap.MOD_REPLACE, (ldap.MOD_REPLACE, name, values) for name, values in changes.items()
k,
[elt.encode("utf-8") if isinstance(elt, str) else elt for elt in v],
)
for k, v in mods.items()
] ]
conn.modify_s(self.dn, attributes) conn.modify_s(self.dn, modlist)
# Object does not exist yet in the LDAP database
else: else:
mods = {} changes = {
for k, v in self.attrs.items(): name: value
if v and v[0]: for name, value in {**self.attrs, **self.changes}.items()
mods[k] = v if value and value[0]
for k, v in self.changes.items(): }
if v and v[0]: changes = self.python_attrs_to_ldap(changes)
mods[k] = v attributes = [(name, values) for name, values in changes.items()]
attributes = [
(k, [elt.encode("utf-8") if isinstance(elt, str) else elt for elt in v])
for k, v in mods.items()
]
conn.add_s(self.dn, attributes) conn.add_s(self.dn, attributes)
for k, v in self.changes.items(): self.attrs = {**self.attrs, **self.changes}
self.attrs[k] = v
self.changes = {} self.changes = {}
@classmethod @classmethod
@ -233,7 +240,9 @@ class LDAPObject:
else: else:
values = [ldap.filter.escape_filter_chars(v) for v in value] values = [ldap.filter.escape_filter_chars(v) for v in value]
arg_filter += "(|" + "".join([f"({key}={v})" for v in values]) + ")" arg_filter += (
"(|" + "".join([f"({key}={value})" for value in values]) + ")"
)
if not filter: if not filter:
filter = "" filter = ""
@ -244,12 +253,7 @@ class LDAPObject:
base = base or f"{cls.base},{cls.root_dn}" base = base or f"{cls.base},{cls.root_dn}"
result = conn.search_s(base, ldap.SCOPE_SUBTREE, ldapfilter or None) result = conn.search_s(base, ldap.SCOPE_SUBTREE, ldapfilter or None)
return [ return [cls(**cls.ldap_attrs_to_python(args)) for _, args in result]
cls(
**{k: [elt.decode("utf-8") for elt in v] for k, v in args.items()},
)
for _, args in result
]
def __getattr__(self, name): def __getattr__(self, name):
if (not self.may or name not in self.may) and ( if (not self.may or name not in self.may) and (