# Unix SMB/CIFS implementation.
# Copyright (C) Kai Blin  2011
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

import errno
import optparse
import os
import random
import socket
import subprocess
import sys
import time

import samba
import samba.getopt as options
import samba.ndr as ndr
from samba import credentials
from samba.dcerpc import dns
from samba.tests import TestCase, source_tree_topdir
from samba.tests.subunitrun import SubunitOptions, TestProgram

# Port used for the second toy forwarder (the first one listens on 53).
DNS_PORT2 = 54

parser = optparse.OptionParser(
    "dns_forwarder.py <server name> <server ip> (dns forwarder)+ [options]")
sambaopts = options.SambaOptions(parser)
parser.add_option_group(sambaopts)

# This timeout only has relevance when testing against Windows
# Format errors tend to return patchy responses, so a timeout is needed.
parser.add_option("--timeout", type="int", dest="timeout",
                  help="Specify timeout for DNS requests")

# use command line creds if available
credopts = options.CredentialsOptions(parser)
parser.add_option_group(credopts)
subunitopts = SubunitOptions(parser)
parser.add_option_group(subunitopts)

opts, args = parser.parse_args()

lp = sambaopts.get_loadparm()
creds = credopts.get_credentials(lp)

# Seconds to wait for a reply from the server under test.  optparse leaves
# this as None when --timeout is not given, and socket.settimeout(None)
# means "block indefinitely".
timeout = opts.timeout

if len(args) < 3:
    parser.print_usage()
    sys.exit(1)

server_name = args[0]
server_ip = args[1]
# Every remaining positional argument is an address a toy forwarder is
# started on by the tests below.
dns_servers = args[2:]

creds.set_krb_forwardable(credentials.NO_KRB_FORWARDABLE)


class DNSTest(TestCase):
    """Common helpers for building, sending and checking DNS packets."""

    # Reverse map of RCODE value -> symbolic name, used in failure messages.
    errcodes = {v: k for k, v in vars(dns).items()
                if k.startswith('DNS_RCODE_')}

    def assert_dns_rcode_equals(self, packet, rcode):
        """Assert that ``packet`` carries the response code ``rcode``.

        Codes missing from ``errcodes`` are shown numerically rather than
        raising KeyError, so the real assertion failure is never masked.
        """
        p_errcode = packet.operation & dns.DNS_RCODE
        self.assertEqual(p_errcode, rcode,
                         "Expected RCODE %s, got %s" %
                         (self.errcodes.get(rcode, rcode),
                          self.errcodes.get(p_errcode, p_errcode)))

    def assert_dns_opcode_equals(self, packet, opcode):
        """Assert that ``packet`` carries the opcode ``opcode``."""
        p_opcode = packet.operation & dns.DNS_OPCODE
        self.assertEqual(p_opcode, opcode,
                         "Expected OPCODE %s, got %s" % (opcode, p_opcode))

    def make_name_packet(self, opcode, qid=None):
        """Create an empty dns.name_packet with the given opcode.

        :param opcode: value stored in the packet's ``operation`` field.
        :param qid: query id to use; a random 16-bit id when None.
        :return: the new packet, with an empty question list.
        """
        p = dns.name_packet()
        # Honour an explicit qid; only pick a random id when none is given.
        if qid is None:
            qid = random.randint(0x0, 0xffff)
        p.id = qid
        p.operation = opcode
        p.questions = []
        return p

    def finish_name_packet(self, packet, questions):
        """Attach ``questions`` to ``packet`` and set its question count."""
        packet.qdcount = len(questions)
        packet.questions = questions

    def make_name_question(self, name, qtype, qclass):
        """Create a dns.name_question for ``name``/``qtype``/``qclass``."""
        q = dns.name_question()
        q.name = name
        q.question_type = qtype
        q.question_class = qclass
        return q

    def get_dns_domain(self):
        """Return the lower-cased Kerberos realm, used as the DNS domain.

        Relies on ``self.creds``, which subclasses set in setUp().
        """
        return self.creds.get_realm().lower()

    def dns_transaction_udp(self, packet, host=server_ip, dump=False,
                            timeout=timeout, port=53):
        """Send ``packet`` over UDP to ``host``:``port`` and return the reply.

        :param packet: the dns.name_packet to send.
        :param host: destination address (defaults to the server under test).
        :param dump: print hex dumps of the request and of the reply.
        :param timeout: socket timeout in seconds (None blocks forever).
        :param port: destination port; defaults to the standard DNS port.
        :return: the reply unpacked as a dns.name_packet.
        """
        send_packet = ndr.ndr_pack(packet)
        if dump:
            print(self.hexdump(send_packet))
        # The context manager closes the socket on every exit path.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0) as s:
            s.settimeout(timeout)
            s.connect((host, port))
            s.send(send_packet, 0)
            recv_packet = s.recv(2048, 0)
        if dump:
            print(self.hexdump(recv_packet))
        return ndr.ndr_unpack(dns.name_packet, recv_packet)

    def make_cname_update(self, key, value):
        """Add a CNAME record ``key`` -> ``value`` with a DNS UPDATE.

        The update targets the zone returned by get_dns_domain() on the
        server under test; the test fails unless it answers RCODE OK.
        """
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)

        # The zone section of an UPDATE names the zone as an SOA question.
        name = self.get_dns_domain()
        u = self.make_name_question(name, dns.DNS_QTYPE_SOA,
                                    dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [u])

        r = dns.res_rec()
        r.name = key
        r.rr_type = dns.DNS_QTYPE_CNAME
        r.rr_class = dns.DNS_QCLASS_IN
        r.ttl = 900
        # NOTE(review): 0xffff presumably lets the NDR layer compute the
        # real rdata length -- confirm.
        r.length = 0xffff
        r.rdata = value
        p.nscount = 1
        p.nsrecs = [r]

        response = self.dns_transaction_udp(p)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)


def contact_real_server(host, port):
    """Return an IPv4 UDP socket connected to ``host``:``port``.

    The caller is responsible for closing it.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
    s.connect((host, port))
    return s


class TestDnsForwarding(DNSTest):
    """Exercise the server's forwarding through toy forwarder processes."""

    # A name the server under test has no record for (a non-recursive
    # query for it is expected to return NXDOMAIN).
    UNKNOWN_NAME = "dsfsfds.dsfsdfs"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Toy-server processes started by this test, killed in tearDown().
        self.subprocesses = []

    def setUp(self):
        super().setUp()
        self.server = server_name
        self.server_ip = server_ip
        self.lp = lp
        self.creds = creds

    def start_toy_server(self, host, port, id):
        """Start a toy forwarder on ``host``:``port`` answering with ``id``.

        Returns a UDP socket connected to the toy server's control port,
        which is closed automatically at test cleanup.  Fails the test if
        the server process exits early or never becomes reachable.
        """
        python = sys.executable
        p = subprocess.Popen([python,
                              os.path.join(source_tree_topdir(),
                                           'python/samba/tests/'
                                           'dns_forwarder_helpers/server.py'),
                              host, str(port), id])
        self.subprocesses.append(p)
        if ':' in host:
            s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM, 0)
        else:
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
        self.addCleanup(s.close)

        for _ in range(300):
            time.sleep(0.05)
            # Popen.returncode is only updated by poll()/wait(); without
            # polling, an early exit of the server could never be noticed.
            if p.poll() is not None:
                self.fail("Toy server has managed to die already!")
            s.connect((host, port))
            try:
                s.send(b'timeout 0', 0)
            except socket.error as e:
                if e.errno in (errno.ECONNREFUSED, errno.EHOSTUNREACH):
                    # Not listening yet; try again.
                    continue
            return s

        self.fail("Toy server on %s port %d never became reachable"
                  % (host, port))

    def tearDown(self):
        super().tearDown()
        for p in self.subprocesses:
            p.kill()
            # Reap the child so it does not linger as a zombie.
            p.wait()

    def _require_two_forwarders(self):
        """Skip the current test unless at least two forwarders were given."""
        if len(dns_servers) < 2:
            self.skipTest("needs at least two DNS forwarders, got %d"
                          % len(dns_servers))

    def _make_query(self, name, qtype, recursion_desired=True):
        """Build a single-question IN-class query for ``name``/``qtype``."""
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        q = self.make_name_question(name, qtype, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])
        if recursion_desired:
            p.operation |= dns.DNS_FLAG_RECURSION_DESIRED
        return p

    def _query_server(self, name, qtype, recursion_desired=True):
        """Query the server under test and return the unpacked reply.

        Fails the test if no reply arrives within ``timeout`` seconds.
        """
        ad = contact_real_server(server_ip, 53)
        self.addCleanup(ad.close)
        p = self._make_query(name, qtype, recursion_desired)
        ad.send(ndr.ndr_pack(p), 0)
        ad.settimeout(timeout)
        try:
            data = ad.recv(0xffff + 2, 0)
        except socket.timeout:
            self.fail("DNS server is too slow (timeout %s)" % timeout)
        return ndr.ndr_unpack(dns.name_packet, data)

    def test_comatose_forwarder(self):
        s = self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        s.send(b"timeout 1000000", 0)

        # make DNS query directly to the (now unresponsive) toy forwarder
        name = "an-address-that-will-not-resolve"
        p = self._make_query(name, dns.DNS_QTYPE_TXT,
                             recursion_desired=False)
        s.send(ndr.ndr_pack(p), 0)
        s.settimeout(1)
        # Expected forwarder to be dead
        with self.assertRaises(socket.timeout,
                               msg="DNS forwarder should have been inactive"):
            s.recv(0xffff + 2, 0)

    def test_no_active_forwarder(self):
        data = self._query_server(self.UNKNOWN_NAME, dns.DNS_QTYPE_TXT)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_SERVFAIL)
        self.assertEqual(data.ancount, 0)

    def test_no_flag_recursive_forwarder(self):
        # Leave off the recursive flag
        data = self._query_server(self.UNKNOWN_NAME, dns.DNS_QTYPE_TXT,
                                  recursion_desired=False)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_NXDOMAIN)
        self.assertEqual(data.ancount, 0)

    def test_single_forwarder(self):
        self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        data = self._query_server(self.UNKNOWN_NAME, dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder1', data.answers[0].rdata)

    def test_single_forwarder_not_actually_there(self):
        data = self._query_server(self.UNKNOWN_NAME, dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_SERVFAIL)

    def test_single_forwarder_waiting_forever(self):
        s = self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        s.send(b'timeout 10000', 0)
        data = self._query_server(self.UNKNOWN_NAME, dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_SERVFAIL)

    def test_double_forwarder_first_frozen(self):
        self._require_two_forwarders()
        s1 = self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        self.start_toy_server(dns_servers[1], DNS_PORT2, 'forwarder2')
        s1.send(b'timeout 1000', 0)
        data = self._query_server(self.UNKNOWN_NAME, dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder2', data.answers[0].rdata)

    def test_double_forwarder_first_down(self):
        self._require_two_forwarders()
        self.start_toy_server(dns_servers[1], DNS_PORT2, 'forwarder2')
        data = self._query_server(self.UNKNOWN_NAME, dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder2', data.answers[0].rdata)

    def test_double_forwarder_both_slow(self):
        self._require_two_forwarders()
        s1 = self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        s2 = self.start_toy_server(dns_servers[1], DNS_PORT2, 'forwarder2')
        s1.send(b'timeout 1.5', 0)
        s2.send(b'timeout 1.5', 0)
        data = self._query_server(self.UNKNOWN_NAME, dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder1', data.answers[0].rdata)

    def test_cname(self):
        self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        data = self._query_server("resolve.cname", dns.DNS_QTYPE_CNAME)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual(len(data.answers), 1)
        self.assertEqual('forwarder1', data.answers[0].rdata)

    def test_double_cname(self):
        self.start_toy_server(dns_servers[0], 53, 'forwarder1')

        name = 'resolve.cname.%s' % self.get_dns_domain()
        self.make_cname_update(name, self.UNKNOWN_NAME)

        data = self._query_server(name, dns.DNS_QTYPE_A)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder1', data.answers[1].rdata)

    def test_cname_forwarding_with_slow_server(self):
        self._require_two_forwarders()
        s1 = self.start_toy_server(dns_servers[0], 53, 'forwarder1')
        self.start_toy_server(dns_servers[1], DNS_PORT2, 'forwarder2')
        s1.send(b'timeout 10000', 0)

        name = 'resolve.cname.%s' % self.get_dns_domain()
        self.make_cname_update(name, self.UNKNOWN_NAME)

        data = self._query_server(name, dns.DNS_QTYPE_A)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder2', data.answers[-1].rdata)

    def test_cname_forwarding_with_server_down(self):
        self._require_two_forwarders()
        self.start_toy_server(dns_servers[1], DNS_PORT2, 'forwarder2')

        name1 = 'resolve1.cname.%s' % self.get_dns_domain()
        name2 = 'resolve2.cname.%s' % self.get_dns_domain()
        self.make_cname_update(name1, name2)
        self.make_cname_update(name2, self.UNKNOWN_NAME)

        data = self._query_server(name1, dns.DNS_QTYPE_A)
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual('forwarder2', data.answers[-1].rdata)

    def test_cname_forwarding_with_lots_of_cnames(self):
        name3 = 'resolve3.cname.%s' % self.get_dns_domain()
        self.start_toy_server(dns_servers[0], 53, name3)

        name1 = 'resolve1.cname.%s' % self.get_dns_domain()
        name2 = 'resolve2.cname.%s' % self.get_dns_domain()
        self.make_cname_update(name1, name2)
        self.make_cname_update(name3, name1)
        self.make_cname_update(name2, self.UNKNOWN_NAME)

        data = self._query_server(name1, dns.DNS_QTYPE_A)
        # This should cause a loop in Windows
        # (which is restricted by a 20 CNAME limit)
        #
        # The reason it doesn't here is because forwarded CNAME have no
        # additional processing in the internal DNS server.
        self.assert_dns_rcode_equals(data, dns.DNS_RCODE_OK)
        self.assertEqual(name3, data.answers[-1].rdata)


TestProgram(module=__name__, opts=subunitopts)