From cbcc8039d1525c2807b09818081f034bcb38a2a9 Mon Sep 17 00:00:00 2001 From: Rob van der Linde Date: Thu, 18 Jan 2024 15:47:52 +1300 Subject: [PATCH] netcmd: models: fix build_expression did not work with EnumField Signed-off-by: Rob van der Linde Reviewed-by: Douglas Bagnall Reviewed-by: Andrew Bartlett --- python/samba/netcmd/domain/models/fields.py | 4 ++ .../samba/tests/samba_tool/domain_models.py | 39 ++++++++++++++++++- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/python/samba/netcmd/domain/models/fields.py b/python/samba/netcmd/domain/models/fields.py index e9f05296800..c02562e7c37 100644 --- a/python/samba/netcmd/domain/models/fields.py +++ b/python/samba/netcmd/domain/models/fields.py @@ -209,6 +209,10 @@ class EnumField(Field): else: return MessageElement(str(value.value), flags, self.name) + def expression(self, value): + """Returns the ldb search expression for this field.""" + return f"({self.name}={binary_encode(str(value.value))})" + class DateTimeField(Field): """A field for parsing ldb timestamps into Python datetime.""" diff --git a/python/samba/tests/samba_tool/domain_models.py b/python/samba/tests/samba_tool/domain_models.py index d58f47bfd9a..45d6095c775 100644 --- a/python/samba/tests/samba_tool/domain_models.py +++ b/python/samba/tests/samba_tool/domain_models.py @@ -27,8 +27,8 @@ from xml.etree import ElementTree from ldb import FLAG_MOD_ADD, MessageElement, SCOPE_ONELEVEL from samba.dcerpc import security from samba.dcerpc.misc import GUID -from samba.netcmd.domain.models import (Group, Site, User, StrongNTLMPolicy, - fields) +from samba.netcmd.domain.models import (AccountType, Group, Site, User, + StrongNTLMPolicy, fields) from samba.ndr import ndr_pack, ndr_unpack from .base import SambaToolCmdTest @@ -37,6 +37,41 @@ HOST = "ldap://{DC_SERVER}".format(**os.environ) CREDS = "-U{DC_USERNAME}%{DC_PASSWORD}".format(**os.environ) +class ModelTests(SambaToolCmdTest): + + @classmethod + def setUpClass(cls): + cls.samdb = cls.getSamDB("-H", HOST, CREDS) + super().setUpClass() + + def test_query_count(self): + """Test count property on Query object without converting to a list.""" + groups = Group.query(self.samdb) + self.assertEqual(groups.count, len(list(groups))) + + def test_query_filter_bool(self): + """Tests filtering by a BooleanField.""" + total = Group.query(self.samdb).count + system_groups = Group.query(self.samdb, + is_critical_system_object=True).count + user_groups = Group.query(self.samdb, + is_critical_system_object=False).count + self.assertNotEqual(system_groups, 0) + self.assertNotEqual(user_groups, 0) + self.assertEqual(system_groups + user_groups, total) + + def test_query_filter_enum(self): + """Tests filtering by an EnumField.""" + robots_vs_humans = User.query(self.samdb).count + robots = User.query(self.samdb, + account_type=AccountType.WORKSTATION_TRUST).count + humans = User.query(self.samdb, + account_type=AccountType.NORMAL_ACCOUNT).count + self.assertNotEqual(robots, 0) + self.assertNotEqual(humans, 0) + self.assertEqual(robots + humans, robots_vs_humans) + + class FieldTestMixin: """Tests a model field to ensure it behaves correctly in both directions.