canaille-globuzma/canaille/backends/ldap/ldapobject.py
2024-04-28 16:18:36 +02:00

483 lines
15 KiB
Python

import itertools
import typing
from collections.abc import Iterable
import ldap.dn
import ldap.filter
from ldap.controls.readentry import PostReadControl
from canaille.backends.models import BackendModel
from .backend import Backend
from .utils import cardinalize_attribute
from .utils import ldap_to_python
from .utils import listify
from .utils import python_to_ldap
def python_attrs_to_ldap(attrs, encode=True, null_allowed=True):
formatted_attrs = {
name: [
python_to_ldap(value, attribute_ldap_syntax(name), encode=encode)
for value in listify(values)
]
for name, values in attrs.items()
}
if not null_allowed:
formatted_attrs = {
key: [value for value in values if value]
for key, values in formatted_attrs.items()
if values
}
return formatted_attrs
def attribute_ldap_syntax(attribute_name):
ldap_attrs = LDAPObject.ldap_object_attributes()
if attribute_name not in ldap_attrs:
return None
if ldap_attrs[attribute_name].syntax:
return ldap_attrs[attribute_name].syntax
return attribute_ldap_syntax(ldap_attrs[attribute_name].sup[0])
class LDAPObjectMetaclass(type):
ldap_to_python_class = {}
def __new__(cls, name, bases, attrs):
klass = super().__new__(cls, name, bases, attrs)
if attrs.get("ldap_object_class"):
for oc in attrs["ldap_object_class"]:
cls.ldap_to_python_class[oc] = klass
return klass
def __setattr__(cls, name, value):
super().__setattr__(name, value)
if name == "ldap_object_class":
for oc in value:
cls.ldap_to_python_class[oc] = cls
class LDAPObjectQuery:
def __init__(self, klass, items):
self.klass = klass
self.items = items
def __len__(self):
return len(self.items)
def __getitem__(self, item):
if isinstance(item, slice):
return (self.decorate(obj[1]) for obj in self.items[item])
return self.decorate(self.items[item][1])
def __iter__(self):
return (self.decorate(obj[1]) for obj in self.items)
def __eq__(self, other):
if isinstance(other, Iterable):
return all(
a == b
for a, b in itertools.zip_longest(
iter(self), iter(other), fillvalue=object()
)
)
return super().__eq__(other)
def __bool__(self):
return bool(self.items)
def decorate(self, args):
klass = self.guess_class(self.klass, args["objectClass"])
obj = klass()
obj.state = args
obj.exists = True
return obj
def guess_class(self, klass, object_classes):
if klass == LDAPObject:
models = [
LDAPObjectMetaclass.ldap_to_python_class[oc.decode()]
for oc in object_classes
if oc.decode() in LDAPObjectMetaclass.ldap_to_python_class
]
return models[0]
return klass
class LDAPObject(BackendModel, metaclass=LDAPObjectMetaclass):
_object_class_by_name = None
_attribute_type_by_name = None
_may = None
_must = None
base = None
root_dn = None
rdn_attribute = None
attribute_map = None
ldap_object_class = None
def __init__(self, dn=None, **kwargs):
self.state = {}
self.changes = {}
self.exists = False
for name, value in kwargs.items():
setattr(self, name, value)
def __repr__(self):
attribute_name = self.ldap_attribute_to_python(self.rdn_attribute)
return (
f"<{self.__class__.__name__} {attribute_name}={self.rdn_value}>"
if self.rdn_attribute
else "<LDAPOBject>"
)
def __eq__(self, other):
ldap_attributes = self.may() + self.must()
if not (
isinstance(other, self.__class__)
and self.may() == other.may()
and self.must() == other.must()
and all(
self.has_ldap_attribute(attr) == other.has_ldap_attribute(attr)
for attr in ldap_attributes
)
):
return False
self_attributes = python_attrs_to_ldap(
{
attr: self.get_ldap_attribute(attr)
for attr in ldap_attributes
if self.has_ldap_attribute(attr)
}
)
other_attributes = python_attrs_to_ldap(
{
attr: other.get_ldap_attribute(attr)
for attr in ldap_attributes
if other.has_ldap_attribute(attr)
}
)
return self_attributes == other_attributes
def __hash__(self):
return hash(self.id)
def __getattribute__(self, name):
if name == "attributes" or name not in self.attributes:
return super().__getattribute__(name)
ldap_name = self.python_attribute_to_ldap(name)
python_single_value = typing.get_origin(self.attributes[name]) is not list
ldap_value = self.get_ldap_attribute(ldap_name)
return cardinalize_attribute(python_single_value, ldap_value)
def __setattr__(self, name, value):
if name not in self.attributes:
super().__setattr__(name, value)
ldap_name = self.python_attribute_to_ldap(name)
self.set_ldap_attribute(ldap_name, value)
def has_ldap_attribute(self, name):
return name in self.ldap_object_attributes() and (
name in self.changes or name in self.state
)
def get_ldap_attribute(self, name):
if name in self.changes:
return self.changes[name]
if not self.state.get(name):
return None
# Lazy conversion from ldap format to python format
if any(isinstance(value, bytes) for value in self.state[name]):
syntax = attribute_ldap_syntax(name)
self.state[name] = [
ldap_to_python(value, syntax) for value in self.state[name]
]
return self.state.get(name)
def set_ldap_attribute(self, name, value):
if name not in self.ldap_object_attributes():
return
value = listify(value)
self.changes[name] = value
@property
def rdn_value(self):
value = self.get_ldap_attribute(self.rdn_attribute)
return (value[0] if isinstance(value, list) else value).strip()
@property
def dn(self):
return self.dn_for(self.rdn_value)
@classmethod
def dn_for(cls, rdn):
return f"{cls.rdn_attribute}={ldap.dn.escape_dn_chars(rdn)},{cls.base},{cls.root_dn}"
@classmethod
def may(cls):
if not cls._may:
cls.update_ldap_attributes()
return cls._may
@classmethod
def must(cls):
return cls._must
@classmethod
def install(cls):
conn = Backend.get().connection
cls.ldap_object_classes(conn)
cls.ldap_object_attributes(conn)
acc = ""
for organizationalUnit in cls.base.split(",")[::-1]:
v = organizationalUnit.split("=")[1]
dn = f"{organizationalUnit}{acc},{cls.root_dn}"
acc = f",{organizationalUnit}"
try:
conn.add_s(
dn,
[
("objectClass", [b"organizationalUnit"]),
("ou", [v.encode("utf-8")]),
],
)
except ldap.ALREADY_EXISTS:
pass
@classmethod
def ldap_object_classes(cls, force=False):
if cls._object_class_by_name and not force:
return cls._object_class_by_name
conn = Backend.get().connection
res = conn.search_s(
"cn=subschema", ldap.SCOPE_BASE, "(objectclass=*)", ["*", "+"]
)
subschema_entry = res[0]
subschema_subentry = ldap.cidict.cidict(subschema_entry[1])
subschema = ldap.schema.SubSchema(subschema_subentry)
object_class_oids = subschema.listall(ldap.schema.models.ObjectClass)
cls._object_class_by_name = {}
for oid in object_class_oids:
object_class = subschema.get_obj(ldap.schema.models.ObjectClass, oid)
for name in object_class.names:
cls._object_class_by_name[name] = object_class
return cls._object_class_by_name
@classmethod
def ldap_object_attributes(cls, force=False):
if cls._attribute_type_by_name and not force:
return cls._attribute_type_by_name
conn = Backend.get().connection
res = conn.search_s(
"cn=subschema", ldap.SCOPE_BASE, "(objectclass=*)", ["*", "+"]
)
subschema_entry = res[0]
subschema_subentry = ldap.cidict.cidict(subschema_entry[1])
subschema = ldap.schema.SubSchema(subschema_subentry)
attribute_type_oids = subschema.listall(ldap.schema.models.AttributeType)
cls._attribute_type_by_name = {}
for oid in attribute_type_oids:
object_class = subschema.get_obj(ldap.schema.models.AttributeType, oid)
for name in object_class.names:
cls._attribute_type_by_name[name] = object_class
return cls._attribute_type_by_name
@classmethod
def get(cls, identifier=None, /, **kwargs):
try:
return cls.query(identifier, **kwargs)[0]
except (IndexError, ldap.NO_SUCH_OBJECT):
if identifier and cls.base:
return (
cls.get(**{cls.identifier_attribute: identifier})
or cls.get(id=identifier)
or None
)
return None
@classmethod
def query(cls, dn=None, filter=None, **kwargs):
conn = Backend.get().connection
base = dn
if dn is None:
base = f"{cls.base},{cls.root_dn}"
elif "=" not in base:
base = ldap.dn.escape_dn_chars(base)
base = f"{cls.rdn_attribute}={base},{cls.base},{cls.root_dn}"
class_filter = (
"".join([f"(objectClass={oc})" for oc in cls.ldap_object_class])
if getattr(cls, "ldap_object_class")
else ""
)
if class_filter:
class_filter = f"(|{class_filter})"
arg_filter = ""
ldap_args = python_attrs_to_ldap(
{
cls.python_attribute_to_ldap(name): values
for name, values in kwargs.items()
if values is not None
},
encode=False,
)
for key, value in ldap_args.items():
if len(value) == 1:
escaped_value = ldap.filter.escape_filter_chars(value[0])
arg_filter += f"({key}={escaped_value})"
else:
values = [ldap.filter.escape_filter_chars(v) for v in value]
arg_filter += (
"(|" + "".join([f"({key}={value})" for value in values]) + ")"
)
if not filter:
filter = ""
ldapfilter = f"(&{class_filter}{arg_filter}{filter})"
base = base or f"{cls.base},{cls.root_dn}"
try:
result = conn.search_s(
base, ldap.SCOPE_SUBTREE, ldapfilter or None, ["+", "*"]
)
except ldap.NO_SUCH_OBJECT:
result = []
return LDAPObjectQuery(cls, result)
@classmethod
def fuzzy(cls, query, attributes=None, **kwargs):
query = ldap.filter.escape_filter_chars(query)
attributes = attributes or cls.may() + cls.must()
attributes = [cls.python_attribute_to_ldap(name) for name in attributes]
filter = (
"(|" + "".join(f"({attribute}=*{query}*)" for attribute in attributes) + ")"
)
return cls.query(filter=filter, **kwargs)
@classmethod
def update_ldap_attributes(cls):
all_object_classes = cls.ldap_object_classes()
this_object_classes = {
all_object_classes[name] for name in cls.ldap_object_class
}
done = set()
cls._may = []
cls._must = []
while len(this_object_classes) > 0:
object_class = this_object_classes.pop()
done.add(object_class)
this_object_classes |= {
all_object_classes[ocsup]
for ocsup in object_class.sup
if ocsup not in done
}
cls._may.extend(object_class.may)
cls._must.extend(object_class.must)
cls._may = list(set(cls._may))
cls._must = list(set(cls._must))
@classmethod
def ldap_attribute_to_python(cls, name):
reverse_attribute_map = {v: k for k, v in (cls.attribute_map or {}).items()}
return reverse_attribute_map.get(name, name)
@classmethod
def python_attribute_to_ldap(cls, name):
return cls.attribute_map.get(name, name) if cls.attribute_map else None
def reload(self):
conn = Backend.get().connection
result = conn.search_s(self.dn, ldap.SCOPE_SUBTREE, None, ["+", "*"])
self.changes = {}
self.state = result[0][1]
def save(self):
conn = Backend.get().connection
current_object_classes = self.get_ldap_attribute("objectClass") or []
self.set_ldap_attribute(
"objectClass",
list(set(self.ldap_object_class) | set(current_object_classes)),
)
# PostReadControl allows to read the updated object attributes on creation/edition
attributes = ["objectClass"] + [
self.python_attribute_to_ldap(name) for name in self.attributes
]
read_post_control = PostReadControl(criticality=True, attrList=attributes)
# Object already exists in the LDAP database
if self.exists:
deletions = [
name
for name, value in self.changes.items()
if (
value is None
or value == []
or (isinstance(value, list) and len(value) == 1 and not value[0])
)
and name in self.state
]
changes = {
name: value
for name, value in self.changes.items()
if name not in deletions and self.state.get(name) != value
}
formatted_changes = python_attrs_to_ldap(changes, null_allowed=False)
modlist = [(ldap.MOD_DELETE, name, None) for name in deletions] + [
(ldap.MOD_REPLACE, name, values)
for name, values in formatted_changes.items()
]
_, _, _, [result] = conn.modify_ext_s(
self.dn, modlist, serverctrls=[read_post_control]
)
# Object does not exist yet in the LDAP database
else:
changes = {
name: value
for name, value in {**self.state, **self.changes}.items()
if value and value[0]
}
formatted_changes = python_attrs_to_ldap(changes, null_allowed=False)
modlist = [(name, values) for name, values in formatted_changes.items()]
_, _, _, [result] = conn.add_ext_s(
self.dn, modlist, serverctrls=[read_post_control]
)
self.exists = True
self.state = {**result.entry, **self.changes}
self.changes = {}
def delete(self):
conn = Backend.get().connection
try:
conn.delete_s(self.dn)
except ldap.NO_SUCH_OBJECT:
pass