# Test joining as a DC and check the join was done right # # Copyright (C) Andrew Bartlett 2017 # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation; either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . # import samba import sys import shutil import os from samba.tests.dns_base import DNSTKeyTest from samba.join import DCJoinContext from samba.dcerpc import drsuapi, misc, dns from samba.credentials import Credentials from samba.provision import interface_ips_v4 def get_logger(name="subunit"): """Get a logger object.""" import logging logger = logging.getLogger(name) logger.addHandler(logging.StreamHandler(sys.stderr)) return logger class JoinTestCase(DNSTKeyTest): def setUp(self): self.server = samba.tests.env_get_var_value("SERVER") self.server_ip = samba.tests.env_get_var_value("SERVER_IP") super().setUp() self.lp = samba.tests.env_loadparm() self.creds = self.get_credentials() self.netbios_name = "jointest1" logger = get_logger() self.join_ctx = DCJoinContext(server=self.server, creds=self.creds, lp=self.get_loadparm(), netbios_name=self.netbios_name, targetdir=self.tempdir, domain=None, logger=logger, dns_backend="SAMBA_INTERNAL") self.join_ctx.userAccountControl = (samba.dsdb.UF_SERVER_TRUST_ACCOUNT | samba.dsdb.UF_TRUSTED_FOR_DELEGATION) self.join_ctx.replica_flags |= (drsuapi.DRSUAPI_DRS_WRIT_REP | drsuapi.DRSUAPI_DRS_FULL_SYNC_IN_PROGRESS) self.join_ctx.domain_replica_flags = self.join_ctx.replica_flags self.join_ctx.secure_channel_type = misc.SEC_CHAN_BDC self.join_ctx.cleanup_old_join() self.join_ctx.force_all_ips = True self.join_ctx.do_join() def tearDown(self): try: paths = self.join_ctx.paths except AttributeError: paths = None if paths is not None: shutil.rmtree(paths.private_dir) shutil.rmtree(paths.state_dir) self.rm_dirs("etc", "msg.lock", "bind-dns") self.rm_files("names.tdb") self.join_ctx.cleanup_old_join(force=True) super().tearDown() def test_join_makes_records(self): "create a query packet containing one query record via TCP" p = self.make_name_packet(dns.DNS_OPCODE_QUERY) questions = [] name = self.join_ctx.dnshostname q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN) questions.append(q) # Get expected IPs IPs = interface_ips_v4(self.lp, all_interfaces=True) self.finish_name_packet(p, questions) (response, response_packet) = self.dns_transaction_tcp(p, host=self.server_ip) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY) self.assertEqual(response.ancount, len(IPs)) questions = [] name = "%s._msdcs.%s" % (self.join_ctx.ntds_guid, self.join_ctx.dnsforest) q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN) questions.append(q) self.finish_name_packet(p, questions) (response, response_packet) = self.dns_transaction_tcp(p, host=self.server_ip) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY) self.assertEqual(response.ancount, 1 + len(IPs)) self.assertEqual(response.answers[0].rr_type, dns.DNS_QTYPE_CNAME) self.assertEqual(response.answers[0].rdata, self.join_ctx.dnshostname) self.assertEqual(response.answers[1].rr_type, dns.DNS_QTYPE_A) def test_join_records_can_update(self): dc_creds = Credentials() dc_creds.guess(self.join_ctx.lp) dc_creds.set_machine_account(self.join_ctx.lp) self.tkey_trans(creds=dc_creds) p = self.make_name_packet(dns.DNS_OPCODE_UPDATE) q = self.make_name_question(self.join_ctx.dnsdomain, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN) questions = [] questions.append(q) self.finish_name_packet(p, questions) updates = [] # Delete the old expected IPs IPs = interface_ips_v4(self.lp, all_interfaces=True) for IP in IPs[1:]: if ":" in IP: r = dns.res_rec() r.name = self.join_ctx.dnshostname r.rr_type = dns.DNS_QTYPE_AAAA r.rr_class = dns.DNS_QCLASS_NONE r.ttl = 0 r.length = 0xffff rdata = IP else: r = dns.res_rec() r.name = self.join_ctx.dnshostname r.rr_type = dns.DNS_QTYPE_A r.rr_class = dns.DNS_QCLASS_NONE r.ttl = 0 r.length = 0xffff rdata = IP r.rdata = rdata updates.append(r) p.nscount = len(updates) p.nsrecs = updates mac = self.sign_packet(p, self.tkey['name']) (response, response_p) = self.dns_transaction_udp(p, self.server_ip) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.verify_packet(response, response_p, mac) p = self.make_name_packet(dns.DNS_OPCODE_QUERY) questions = [] name = self.join_ctx.dnshostname q = self.make_name_question(name, dns.DNS_QTYPE_A, dns.DNS_QCLASS_IN) questions.append(q) self.finish_name_packet(p, questions) (response, response_packet) = self.dns_transaction_tcp(p, host=self.server_ip) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY) self.assertEqual(response.ancount, 1)