1
0
mirror of https://github.com/samba-team/samba.git synced 2025-08-04 08:22:08 +03:00

python:tests/rpcd_witness_samba_only: add tests for 'net witness list'

Signed-off-by: Stefan Metzmacher <metze@samba.org>
Reviewed-by: Günther Deschner <gd@samba.org>
This commit is contained in:
Stefan Metzmacher
2024-01-12 17:30:41 +01:00
parent 46fdeca696
commit 3e70b31f01

View File

@ -111,7 +111,11 @@ class RpcdWitnessSambaTests(BlackboxTestCase):
node["ip"], common_binding_args64)
self.nodes.append(node)
self.all_registrations = None
def tearDown(self):
self.destroy_all_registrations()
if self.disabled_idx != -1:
self.enable_node(self.disabled_idx)
@ -165,6 +169,49 @@ class RpcdWitnessSambaTests(BlackboxTestCase):
if dump_status:
self.dump_ctdb_status_all()
def call_net_witness_subcmd(self, subcmd,
as_json=False,
registration=None,
net_name=None,
share_name=None,
ip_address=None,
client_computer=None):
COMMAND = "UID_WRAPPER_ROOT=1 bin/net witness"
argv = "%s %s" % (COMMAND, subcmd)
if as_json:
argv += " --json"
if registration is not None:
argv += " --witness-registration='%s'" % (
registration.uuid)
if net_name is not None:
argv += " --witness-net-name='%s'" % (net_name)
if share_name is not None:
argv += " --witness-share-name='%s'" % (share_name)
if ip_address is not None:
argv += " --witness-ip-address='%s'" % (ip_address)
if client_computer is not None:
argv += " --witness-client-computer-name='%s'" % (client_computer)
try:
if self.verbose:
print("Calling: %s" % argv)
out = self.check_output(argv)
except samba.tests.BlackboxProcessError as e:
self.fail("Error calling [%s]: %s" % (argv, e))
out_str = get_string(out)
if not as_json:
return out_str
json_out = json.loads(out_str)
return json_out
@classmethod
def _define_GetInterfaceList_test(cls, conn_idx, disable_idx, ndr64=False):
if disable_idx != -1:
@ -454,6 +501,436 @@ class RpcdWitnessSambaTests(BlackboxTestCase):
if num != werror.WERR_NOT_FOUND:
raise
def prepare_all_registrations(self):
self.assertIsNone(self.all_registrations)
regs = []
for node_idx in range(0, self.num_nodes):
node = self.nodes[node_idx]
for ndr64 in [False, True]:
if ndr64:
binding_string = node["binding_string64"]
ndr_name = "NDR64"
else:
binding_string = node["binding_string32"]
ndr_name = "NDR32"
conn = witness.witness(binding_string, self.lp, self.remote_creds)
conn_ip = node["ip"]
net_name = self.server_hostname
ip_address = node["ip"]
share_name = self.cluster_share
computer_name = "test-net-witness-list-%s-%s" % (
node_idx, ndr_name)
flags = witness.WITNESS_REGISTER_NONE
timeout = 15
reg_version = witness.WITNESS_V1
reg = {
'node_idx': node_idx,
'ndr64': ndr64,
'binding_string': binding_string,
'conn_ip': conn_ip,
'version': reg_version,
'net_name': net_name,
'share_name': None,
'ip_address': ip_address,
'computer_name': computer_name,
'flags': 0,
'timeout': 0,
'conn': conn,
'context': None,
}
regs.append(reg)
reg_version = witness.WITNESS_V2
reg = {
'node_idx': node_idx,
'ndr64': ndr64,
'binding_string': binding_string,
'conn_ip': conn_ip,
'version': reg_version,
'net_name': net_name,
'share_name': None,
'ip_address': ip_address,
'computer_name': computer_name,
'flags': flags,
'timeout': timeout,
'conn': conn,
'context': None,
}
regs.append(reg)
reg = {
'node_idx': node_idx,
'ndr64': ndr64,
'binding_string': binding_string,
'conn_ip': conn_ip,
'version': reg_version,
'net_name': net_name,
'share_name': share_name,
'ip_address': ip_address,
'computer_name': computer_name,
'flags': flags,
'timeout': timeout,
'conn': conn,
'context': None,
}
regs.append(reg)
self.all_registrations = regs
return regs
def close_all_registrations(self):
self.assertIsNotNone(self.all_registrations)
for reg in self.all_registrations:
conn = reg['conn']
reg_context = reg['context']
if reg_context is not None:
conn.UnRegister(reg_context)
reg_context = None
reg['context'] = reg_context
def open_all_registrations(self):
self.assertIsNotNone(self.all_registrations)
for reg in self.all_registrations:
conn = reg['conn']
reg_context = reg['context']
self.assertIsNone(reg_context)
reg_version = reg['version']
if reg_version == witness.WITNESS_V1:
reg_context = conn.Register(reg_version,
reg['net_name'],
reg['ip_address'],
reg['computer_name'])
elif reg_version == witness.WITNESS_V2:
reg_context = conn.RegisterEx(reg_version,
reg['net_name'],
reg['share_name'],
reg['ip_address'],
reg['computer_name'],
reg['flags'],
reg['timeout'])
self.assertIsNotNone(reg_context)
reg['context'] = reg_context
def destroy_all_registrations(self):
if self.all_registrations is None:
return
for reg in self.all_registrations:
conn = reg['conn']
reg_context = reg['context']
if reg_context is not None:
conn.UnRegister(reg_context)
reg_context = None
reg['context'] = reg_context
conn = None
reg['conn'] = conn
self.all_registrations = None
def assertJsonReg(self, json_reg, reg):
self.assertEqual(json_reg['version'], "0x%08x" % reg['version'])
self.assertEqual(json_reg['net_name'], reg['net_name'])
if reg['share_name']:
self.assertEqual(json_reg['share_name'], reg['share_name'])
else:
self.assertIsNone(json_reg['share_name'])
self.assertEqual(json_reg['client_computer_name'], reg['computer_name'])
self.assertIn('flags', json_reg)
json_flags = json_reg['flags']
if reg['flags'] & witness.WITNESS_REGISTER_IP_NOTIFICATION:
expected_ip_notifaction = True
else:
expected_ip_notifaction = False
self.assertEqual(json_flags['WITNESS_REGISTER_IP_NOTIFICATION'],
expected_ip_notifaction)
self.assertEqual(json_flags['int'], reg['flags'])
self.assertEqual(json_flags['hex'], "0x%08x" % reg['flags'])
self.assertEqual(len(json_flags.keys()), 3)
self.assertEqual(json_reg['timeout'], reg['timeout'])
self.assertIn('context_handle', json_reg)
json_context = json_reg['context_handle']
self.assertEqual(json_context['uuid'], str(reg['context'].uuid))
self.assertEqual(json_context['handle_type'], reg['context'].handle_type)
self.assertEqual(len(json_context.keys()), 2)
self.assertIn('server_id', json_reg)
json_server_id = json_reg['server_id']
self.assertIn('pid', json_server_id)
self.assertIn('task_id', json_server_id)
self.assertEqual(json_server_id['vnn'], reg['node_idx'])
self.assertIn('unique_id', json_server_id)
self.assertEqual(len(json_server_id.keys()), 4)
self.assertIn('auth', json_reg)
json_auth = json_reg['auth']
self.assertEqual(json_auth['account_name'], self.remote_user)
self.assertEqual(json_auth['domain_name'], self.remote_domain)
self.assertIn('account_sid', json_auth)
self.assertEqual(len(json_auth.keys()), 3)
self.assertIn('connection', json_reg)
json_conn = json_reg['connection']
self.assertIn('local_address', json_conn)
self.assertIn(reg['conn_ip'], json_conn['local_address'])
self.assertIn('remote_address', json_conn)
self.assertEqual(len(json_conn.keys()), 2)
self.assertIn('registration_time', json_reg)
self.assertEqual(len(json_reg.keys()), 12)
def max_common_prefix(self, strings):
if len(strings) == 0:
return ""
def string_match_len(s1, s2):
idx = 0
for i in range(0, min(len(s1), len(s2))):
c1 = s1[i:i+1]
c2 = s2[i:i+1]
if c1 != c2:
break
idx = i
return idx
prefix = None
for s in strings:
if prefix is None:
prefix = s
continue
l = string_match_len(prefix, s)
prefix = prefix[0:l+1]
return prefix
def check_net_witness_output(self,
cmd,
regs,
registration_idx=None,
net_name=None,
share_name=None,
ip_address=None,
client_computer=None):
self.open_all_registrations()
if registration_idx is not None:
registration = regs[registration_idx]['context']
self.assertIsNotNone(registration)
else:
registration = None
plain_res = self.call_net_witness_subcmd(cmd,
registration=registration,
net_name=net_name,
share_name=share_name,
ip_address=ip_address,
client_computer=client_computer)
if self.verbose:
print("%s" % plain_res)
plain_lines = plain_res.splitlines()
num_headlines = 2
self.assertEqual(len(plain_lines), num_headlines+len(regs))
plain_lines = plain_lines[num_headlines:]
self.assertEqual(len(plain_lines), len(regs))
for reg in regs:
reg_uuid = reg['context'].uuid
expected_line = "%-36s " % reg_uuid
expected_line += "%-20s " % reg['net_name']
if reg['share_name']:
expected_share = reg['share_name']
else:
expected_share = "''"
expected_line += "%-15s " % expected_share
expected_line += "%-20s " % reg['ip_address']
expected_line += "%s" % reg['computer_name']
line = None
for l in plain_lines:
if not l.startswith(str(reg_uuid)):
continue
self.assertIsNone(line)
line = l
self.assertEqual(line, expected_line)
self.assertIsNotNone(line)
self.close_all_registrations()
self.open_all_registrations()
if registration_idx is not None:
registration = regs[registration_idx]['context']
self.assertIsNotNone(registration)
else:
registration = None
json_res = self.call_net_witness_subcmd(cmd,
as_json=True,
registration=registration,
net_name=net_name,
share_name=share_name,
ip_address=ip_address,
client_computer=client_computer)
num_filters = 0
if registration:
num_filters += 1
if net_name:
num_filters += 1
if share_name:
num_filters += 1
if ip_address:
num_filters += 1
if client_computer:
num_filters += 1
num_toplevel = 2
self.assertIn('filters', json_res);
self.assertIn('registrations', json_res);
self.assertEqual(len(json_res.keys()), num_toplevel)
json_filters = json_res['filters']
self.assertEqual(len(json_filters.keys()), num_filters)
if registration:
self.assertEqual(json_filters['--witness-registration'],
str(registration.uuid))
if net_name:
self.assertEqual(json_filters['--witness-net-name'],
net_name)
if share_name:
self.assertEqual(json_filters['--witness-share-name'],
share_name)
if ip_address:
self.assertEqual(json_filters['--witness-ip-address'],
ip_address)
if client_computer:
self.assertEqual(json_filters['--witness-client-computer-name'],
client_computer)
json_regs = json_res['registrations']
self.assertEqual(len(json_regs.keys()), len(regs))
for reg in regs:
reg_uuid = reg['context'].uuid
self.assertIn(str(reg_uuid), json_regs)
json_reg = json_regs[str(reg_uuid)]
self.assertJsonReg(json_reg, reg)
self.close_all_registrations()
def check_combinations(self, check_func, only_shares=False):
all_regs = self.prepare_all_registrations()
share_name_regs = {}
all_share_name_regs = []
no_share_name_regs = []
for reg in all_regs:
if reg['share_name'] is not None:
if reg['share_name'] not in share_name_regs:
share_name_regs[reg['share_name']] = []
share_name_regs[reg['share_name']].append(reg)
all_share_name_regs.append(reg)
else:
no_share_name_regs.append(reg)
ip_address_regs = {}
computer_name_regs = {}
for reg in all_regs:
if reg['ip_address'] not in ip_address_regs:
ip_address_regs[reg['ip_address']] = []
ip_address_regs[reg['ip_address']].append(reg)
if reg['computer_name'] not in computer_name_regs:
computer_name_regs[reg['computer_name']] = []
computer_name_regs[reg['computer_name']].append(reg)
all_share_names = '|'.join(share_name_regs.keys())
common_share_name = self.max_common_prefix(share_name_regs.keys())
all_ip_addresses = '|'.join(ip_address_regs.keys())
common_ip_address = self.max_common_prefix(ip_address_regs.keys())
all_computer_names = '|'.join(computer_name_regs.keys())
common_computer_name = self.max_common_prefix(computer_name_regs.keys())
check_func(all_regs)
check_func(all_regs,
net_name=self.server_hostname)
check_func(all_regs,
ip_address=all_ip_addresses)
check_func(all_regs,
client_computer=all_computer_names)
check_func(all_regs,
net_name=self.server_hostname,
ip_address=all_ip_addresses,
client_computer=all_computer_names)
check_func(all_regs,
net_name='.*',
share_name='.*',
ip_address='.*',
client_computer='.*')
check_func(all_regs,
share_name='^$|%s.*' % common_share_name,
ip_address='%s.*' % common_ip_address,
client_computer='%s.*' % common_computer_name)
check_func(all_share_name_regs,
share_name=all_share_names)
check_func(all_share_name_regs,
share_name='%s.*' % common_share_name)
check_func(no_share_name_regs,
share_name='^$')
for share_name in share_name_regs.keys():
regs = share_name_regs[share_name]
check_func(regs, share_name=share_name)
for ip_address in ip_address_regs.keys():
regs = ip_address_regs[ip_address]
check_func(regs, ip_address=ip_address)
for computer_name in computer_name_regs.keys():
regs = computer_name_regs[computer_name]
check_func(regs, client_computer=computer_name)
for reg in all_regs:
regs = [reg]
check_func(regs,
registration_idx=0)
check_func(regs,
registration_idx=0,
net_name=reg['net_name'],
share_name=reg['share_name'],
ip_address=reg['ip_address'],
client_computer=reg['computer_name'])
def test_net_witness_list(self):
def check_list(regs,
registration_idx=None,
net_name=None,
share_name=None,
ip_address=None,
client_computer=None):
return self.check_net_witness_output('list',
regs,
registration_idx=registration_idx,
net_name=net_name,
share_name=share_name,
ip_address=ip_address,
client_computer=client_computer)
self.check_combinations(check_list)
if __name__ == "__main__":
import unittest
unittest.main()