1
0
mirror of https://github.com/samba-team/samba.git synced 2025-02-02 09:47:23 +03:00

netcmd: models: Model.query adds optional polymorphic flag for returning specific class types

This defaults to False, query the User class returns only User instances.

    User.query(samdb)

When set to True, query the User class can return User, Computer, ManagedServiceAccount instances.

    User.query(samdb, polymorphic=True)

If polymorphic is False the same records are still returned but records will always be interpreted as the model that is being queried only, rather than a more specific model that matches that object class.

Signed-off-by: Rob van der Linde <rob@catalyst.net.nz>
Reviewed-by: Andrew Bartlett <abartlet@samba.org>
Reviewed-by: Douglas Bagnall <douglas.bagnall@catalyst.net.nz>
This commit is contained in:
Rob van der Linde 2024-02-20 16:45:45 +13:00 committed by Andrew Bartlett
parent ccce7e7c03
commit ca973caa28
2 changed files with 34 additions and 8 deletions

View File

@ -226,10 +226,18 @@ class Model(metaclass=ModelMeta):
return expression
@classmethod
def query(cls, ldb, **kwargs):
def query(cls, ldb, polymorphic=False, **kwargs):
"""Returns a search query for this model.
NOTE: If polymorphic is enabled then querying will return instances
of that specific model, for example querying User can return Computer
and ManagedServiceAccount instances.
By default, polymorphic querying is disabled, and querying User
will only return User instances.
:param ldb: Ldb connection
:param polymorphic: If true enables polymorphic querying (see note)
:param kwargs: Search criteria as keyword args
"""
base_dn = cls.get_search_dn(ldb)
@ -244,7 +252,7 @@ class Model(metaclass=ModelMeta):
raise NotFound(f"Container does not exist: {base_dn}")
raise
return Query(cls, ldb, result)
return Query(cls, ldb, result, polymorphic)
@classmethod
def get(cls, ldb, **kwargs):

View File

@ -22,6 +22,7 @@
import re
from .constants import MODELS
from .exceptions import NotFound, MultipleObjectsReturned
RE_SPLIT_CAMELCASE = re.compile(r"[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))")
@ -30,27 +31,44 @@ RE_SPLIT_CAMELCASE = re.compile(r"[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))")
class Query:
"""Simple Query class used by the `Model.query` method."""
def __init__(self, model, ldb, result):
def __init__(self, model, ldb, result, polymorphic):
self.model = model
self.ldb = ldb
self.result = result
self.count = result.count
self.name = " ".join(RE_SPLIT_CAMELCASE.findall(model.__name__)).lower()
self.polymorphic = polymorphic
def __iter__(self):
"""Loop over Query class yields Model instances."""
for message in self.result:
yield self.model.from_message(self.ldb, message)
yield self._model_from_message(message)
def _model_from_message(self, message):
"""Returns the model class to use to construct instances.
If polymorphic query is enabled it will use the last item from
the objectClass list.
Otherwise, it will use the model from the queryset.
"""
if self.polymorphic:
object_class = str(message["objectClass"][-1])
model = MODELS.get(object_class, self.model)
else:
model = self.model
return model.from_message(self.ldb, message)
def first(self):
"""Returns the first item in the Query or None for no results."""
if self.count:
return self.model.from_message(self.ldb, self.result[0])
return self._model_from_message(self.result[0])
def last(self):
"""Returns the last item in the Query or None for no results."""
if self.count:
return self.model.from_message(self.ldb, self.result[-1])
return self._model_from_message(self.result[-1])
def get(self):
"""Returns one item or None if no results were found.
@ -62,7 +80,7 @@ class Query:
raise MultipleObjectsReturned(
f"More than one {self.name} objects returned (got {self.count}).")
elif self.count:
return self.model.from_message(self.ldb, self.result[0])
return self._model_from_message(self.result[0])
def one(self):
"""Must return EXACTLY one item or raise an exception.
@ -78,4 +96,4 @@ class Query:
raise MultipleObjectsReturned(
f"More than one {self.name} objects returned (got {self.count}).")
else:
return self.model.from_message(self.ldb, self.result[0])
return self._model_from_message(self.result[0])