1
0
mirror of https://github.com/samba-team/samba.git synced 2025-08-04 08:22:08 +03:00

selftest: merge DNSTest boilerplate

This will help unifying dns.py and dns_tkey.py to use common subclasses

The code was originally copied, but has since divereged.  This handles
that divergence.

Signed-off-by: Andrew Bartlett <abartlet@samba.org>
Reviewed-by: Garming Sam <garming@catalyst.net.nz>
This commit is contained in:
Andrew Bartlett
2017-06-09 10:00:09 +12:00
parent 589a6621ee
commit 11ba6f8cde
2 changed files with 196 additions and 88 deletions

View File

@ -63,12 +63,8 @@ creds.set_krb_forwardable(credentials.NO_KRB_FORWARDABLE)
class DNSTest(TestCase): class DNSTest(TestCase):
def setUp(self): def setUp(self):
global server, server_ip, lp, creds
super(DNSTest, self).setUp() super(DNSTest, self).setUp()
self.server = server_name self.timeout = None
self.server_ip = server_ip
self.lp = lp
self.creds = creds
def errstr(self, errcode): def errstr(self, errcode):
"Return a readable error code" "Return a readable error code"
@ -84,30 +80,42 @@ class DNSTest(TestCase):
"NXRRSET", "NXRRSET",
"NOTAUTH", "NOTAUTH",
"NOTZONE", "NOTZONE",
"0x0B",
"0x0C",
"0x0D",
"0x0E",
"0x0F",
"BADSIG",
"BADKEY"
] ]
return string_codes[errcode] return string_codes[errcode]
def assert_rcode_equals(self, rcode, expected):
"Helper function to check return code"
self.assertEquals(rcode, expected, "Expected RCODE %s, got %s" %
(self.errstr(expected), self.errstr(rcode)))
def assert_dns_rcode_equals(self, packet, rcode): def assert_dns_rcode_equals(self, packet, rcode):
"Helper function to check return code" "Helper function to check return code"
p_errcode = packet.operation & 0x000F p_errcode = packet.operation & 0x000F
self.assertEquals(p_errcode, rcode, "Expected RCODE %s, got %s" % self.assertEquals(p_errcode, rcode, "Expected RCODE %s, got %s" %
(self.errstr(rcode), self.errstr(p_errcode))) (self.errstr(rcode), self.errstr(p_errcode)))
def assert_dns_opcode_equals(self, packet, opcode): def assert_dns_opcode_equals(self, packet, opcode):
"Helper function to check opcode" "Helper function to check opcode"
p_opcode = packet.operation & 0x7800 p_opcode = packet.operation & 0x7800
self.assertEquals(p_opcode, opcode, "Expected OPCODE %s, got %s" % self.assertEquals(p_opcode, opcode, "Expected OPCODE %s, got %s" %
(opcode, p_opcode)) (opcode, p_opcode))
def make_name_packet(self, opcode, qid=None): def make_name_packet(self, opcode, qid=None):
"Helper creating a dns.name_packet" "Helper creating a dns.name_packet"
p = dns.name_packet() p = dns.name_packet()
if qid is None: if qid is None:
p.id = random.randint(0x0, 0xffff) p.id = random.randint(0x0, 0xff00)
p.operation = opcode p.operation = opcode
p.questions = [] p.questions = []
p.additional = []
return p return p
def finish_name_packet(self, packet, questions): def finish_name_packet(self, packet, questions):
@ -135,10 +143,12 @@ class DNSTest(TestCase):
"Helper to get dns domain" "Helper to get dns domain"
return self.creds.get_realm().lower() return self.creds.get_realm().lower()
def dns_transaction_udp(self, packet, host=server_ip, def dns_transaction_udp(self, packet, host,
dump=False, timeout=timeout): dump=False, timeout=None):
"send a DNS query and read the reply" "send a DNS query and read the reply"
s = None s = None
if timeout is None:
timeout = self.timeout
try: try:
send_packet = ndr.ndr_pack(packet) send_packet = ndr.ndr_pack(packet)
if dump: if dump:
@ -146,19 +156,22 @@ class DNSTest(TestCase):
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0) s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
s.settimeout(timeout) s.settimeout(timeout)
s.connect((host, 53)) s.connect((host, 53))
s.send(send_packet, 0) s.sendall(send_packet, 0)
recv_packet = s.recv(2048, 0) recv_packet = s.recv(2048, 0)
if dump: if dump:
print self.hexdump(recv_packet) print self.hexdump(recv_packet)
return ndr.ndr_unpack(dns.name_packet, recv_packet) response = ndr.ndr_unpack(dns.name_packet, recv_packet)
return (response, recv_packet)
finally: finally:
if s is not None: if s is not None:
s.close() s.close()
def dns_transaction_tcp(self, packet, host=server_ip, def dns_transaction_tcp(self, packet, host,
dump=False, timeout=timeout): dump=False, timeout=None):
"send a DNS query and read the reply" "send a DNS query and read the reply, also return the raw packet"
s = None s = None
if timeout is None:
timeout = self.timeout
try: try:
send_packet = ndr.ndr_pack(packet) send_packet = ndr.ndr_pack(packet)
if dump: if dump:
@ -168,14 +181,22 @@ class DNSTest(TestCase):
s.connect((host, 53)) s.connect((host, 53))
tcp_packet = struct.pack('!H', len(send_packet)) tcp_packet = struct.pack('!H', len(send_packet))
tcp_packet += send_packet tcp_packet += send_packet
s.send(tcp_packet, 0) s.sendall(tcp_packet)
recv_packet = s.recv(0xffff + 2, 0) recv_packet = s.recv(0xffff + 2, 0)
if dump: if dump:
print self.hexdump(recv_packet) print self.hexdump(recv_packet)
return ndr.ndr_unpack(dns.name_packet, recv_packet[2:]) response = ndr.ndr_unpack(dns.name_packet, recv_packet[2:])
finally: finally:
if s is not None: if s is not None:
s.close() s.close()
# unpacking and packing again should produce same bytestream
my_packet = ndr.ndr_pack(response)
self.assertEquals(my_packet, recv_packet[2:])
return (response, recv_packet[2:])
def make_txt_update(self, prefix, txt_array): def make_txt_update(self, prefix, txt_array):
p = self.make_name_packet(dns.DNS_OPCODE_UPDATE) p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
@ -210,12 +231,21 @@ class DNSTest(TestCase):
questions.append(q) questions.append(q)
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.assertEquals(response.ancount, 1) self.assertEquals(response.ancount, 1)
self.assertEquals(response.answers[0].rdata.txt.str, txt_array) self.assertEquals(response.answers[0].rdata.txt.str, txt_array)
class TestSimpleQueries(DNSTest): class TestSimpleQueries(DNSTest):
def setUp(self):
super(TestSimpleQueries, self).setUp()
global server, server_ip, lp, creds, timeout
self.server = server_name
self.server_ip = server_ip
self.lp = lp
self.creds = creds
self.timeout = timeout
def test_one_a_query(self): def test_one_a_query(self):
"create a query packet containing one query record" "create a query packet containing one query record"
@ -228,7 +258,7 @@ class TestSimpleQueries(DNSTest):
questions.append(q) questions.append(q)
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY) self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
self.assertEquals(response.ancount, 1) self.assertEquals(response.ancount, 1)
@ -246,7 +276,7 @@ class TestSimpleQueries(DNSTest):
questions.append(q) questions.append(q)
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY) self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
self.assertEquals(response.ancount, 1) self.assertEquals(response.ancount, 1)
@ -264,7 +294,7 @@ class TestSimpleQueries(DNSTest):
questions.append(q) questions.append(q)
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
response = self.dns_transaction_tcp(p) (response, response_packet) = self.dns_transaction_tcp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY) self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
self.assertEquals(response.ancount, 1) self.assertEquals(response.ancount, 1)
@ -282,7 +312,7 @@ class TestSimpleQueries(DNSTest):
questions.append(q) questions.append(q)
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY) self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
self.assertEquals(response.ancount, 0) self.assertEquals(response.ancount, 0)
@ -296,7 +326,7 @@ class TestSimpleQueries(DNSTest):
questions.append(q) questions.append(q)
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY) self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
self.assertEquals(response.ancount, 0) self.assertEquals(response.ancount, 0)
@ -316,7 +346,7 @@ class TestSimpleQueries(DNSTest):
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
try: try:
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)
except socket.timeout: except socket.timeout:
# Windows chooses not to respond to incorrectly formatted queries. # Windows chooses not to respond to incorrectly formatted queries.
@ -336,7 +366,7 @@ class TestSimpleQueries(DNSTest):
questions.append(q) questions.append(q)
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
num_answers = 1 num_answers = 1
dc_ipv6 = os.getenv('SERVER_IPV6') dc_ipv6 = os.getenv('SERVER_IPV6')
@ -362,7 +392,7 @@ class TestSimpleQueries(DNSTest):
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
try: try:
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NOTIMP) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NOTIMP)
except socket.timeout: except socket.timeout:
# Windows chooses not to respond to incorrectly formatted queries. # Windows chooses not to respond to incorrectly formatted queries.
@ -381,7 +411,7 @@ class TestSimpleQueries(DNSTest):
questions.append(q) questions.append(q)
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY) self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
# We don't get SOA records for single hosts # We don't get SOA records for single hosts
@ -399,7 +429,7 @@ class TestSimpleQueries(DNSTest):
questions.append(q) questions.append(q)
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY) self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
self.assertEquals(response.ancount, 1) self.assertEquals(response.ancount, 1)
@ -407,6 +437,14 @@ class TestSimpleQueries(DNSTest):
class TestDNSUpdates(DNSTest): class TestDNSUpdates(DNSTest):
def setUp(self):
super(TestDNSUpdates, self).setUp()
global server, server_ip, lp, creds, timeout
self.server = server_name
self.server_ip = server_ip
self.lp = lp
self.creds = creds
self.timeout = timeout
def test_two_updates(self): def test_two_updates(self):
"create two update requests" "create two update requests"
@ -423,7 +461,7 @@ class TestDNSUpdates(DNSTest):
self.finish_name_packet(p, updates) self.finish_name_packet(p, updates)
try: try:
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)
except socket.timeout: except socket.timeout:
# Windows chooses not to respond to incorrectly formatted queries. # Windows chooses not to respond to incorrectly formatted queries.
@ -442,7 +480,7 @@ class TestDNSUpdates(DNSTest):
updates.append(u) updates.append(u)
self.finish_name_packet(p, updates) self.finish_name_packet(p, updates)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NOTIMP) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NOTIMP)
def test_update_prereq_with_non_null_ttl(self): def test_update_prereq_with_non_null_ttl(self):
@ -469,7 +507,7 @@ class TestDNSUpdates(DNSTest):
p.answers = prereqs p.answers = prereqs
try: try:
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)
except socket.timeout: except socket.timeout:
# Windows chooses not to respond to incorrectly formatted queries. # Windows chooses not to respond to incorrectly formatted queries.
@ -501,7 +539,7 @@ class TestDNSUpdates(DNSTest):
p.ancount = len(prereqs) p.ancount = len(prereqs)
p.answers = prereqs p.answers = prereqs
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXRRSET) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXRRSET)
def test_update_prereq_nonexisting_name(self): def test_update_prereq_nonexisting_name(self):
@ -527,14 +565,14 @@ class TestDNSUpdates(DNSTest):
p.ancount = len(prereqs) p.ancount = len(prereqs)
p.answers = prereqs p.answers = prereqs
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXRRSET) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXRRSET)
def test_update_add_txt_record(self): def test_update_add_txt_record(self):
"test adding records works" "test adding records works"
prefix, txt = 'textrec', ['"This is a test"'] prefix, txt = 'textrec', ['"This is a test"']
p = self.make_txt_update(prefix, txt) p = self.make_txt_update(prefix, txt)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.check_query_txt(prefix, txt) self.check_query_txt(prefix, txt)
@ -566,7 +604,7 @@ class TestDNSUpdates(DNSTest):
p.nscount = len(updates) p.nscount = len(updates)
p.nsrecs = updates p.nsrecs = updates
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
# Now check the record is around # Now check the record is around
@ -576,7 +614,7 @@ class TestDNSUpdates(DNSTest):
questions.append(q) questions.append(q)
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
# Now delete the record # Now delete the record
@ -602,7 +640,7 @@ class TestDNSUpdates(DNSTest):
p.nscount = len(updates) p.nscount = len(updates)
p.nsrecs = updates p.nsrecs = updates
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
# And finally check it's gone # And finally check it's gone
@ -613,7 +651,7 @@ class TestDNSUpdates(DNSTest):
questions.append(q) questions.append(q)
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
def test_readd_record(self): def test_readd_record(self):
@ -644,7 +682,7 @@ class TestDNSUpdates(DNSTest):
p.nscount = len(updates) p.nscount = len(updates)
p.nsrecs = updates p.nsrecs = updates
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
# Now check the record is around # Now check the record is around
@ -654,7 +692,7 @@ class TestDNSUpdates(DNSTest):
questions.append(q) questions.append(q)
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
# Now delete the record # Now delete the record
@ -680,7 +718,7 @@ class TestDNSUpdates(DNSTest):
p.nscount = len(updates) p.nscount = len(updates)
p.nsrecs = updates p.nsrecs = updates
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
# check it's gone # check it's gone
@ -691,7 +729,7 @@ class TestDNSUpdates(DNSTest):
questions.append(q) questions.append(q)
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
# recreate the record # recreate the record
@ -717,7 +755,7 @@ class TestDNSUpdates(DNSTest):
p.nscount = len(updates) p.nscount = len(updates)
p.nsrecs = updates p.nsrecs = updates
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
# Now check the record is around # Now check the record is around
@ -727,7 +765,7 @@ class TestDNSUpdates(DNSTest):
questions.append(q) questions.append(q)
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
def test_update_add_mx_record(self): def test_update_add_mx_record(self):
@ -756,7 +794,7 @@ class TestDNSUpdates(DNSTest):
p.nscount = len(updates) p.nscount = len(updates)
p.nsrecs = updates p.nsrecs = updates
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
p = self.make_name_packet(dns.DNS_OPCODE_QUERY) p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
@ -767,7 +805,7 @@ class TestDNSUpdates(DNSTest):
questions.append(q) questions.append(q)
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.assertEqual(response.ancount, 1) self.assertEqual(response.ancount, 1)
ans = response.answers[0] ans = response.answers[0]
@ -795,22 +833,22 @@ class TestComplexQueries(DNSTest):
updates = [r] updates = [r]
p.nscount = 1 p.nscount = 1
p.nsrecs = updates p.nsrecs = updates
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
def setUp(self): def setUp(self):
super(TestComplexQueries, self).setUp() super(TestComplexQueries, self).setUp()
def tearDown(self): global server, server_ip, lp, creds, timeout
super(TestComplexQueries, self).tearDown() self.server = server_name
self.server_ip = server_ip
self.lp = lp
self.creds = creds
self.timeout = timeout
def test_one_a_query(self): def test_one_a_query(self):
"create a query packet containing one query record" "create a query packet containing one query record"
name = "cname_test.%s" % self.get_dns_domain()
rdata = "%s.%s" % (self.server, self.get_dns_domain())
self.make_dns_update(name, rdata, dns.DNS_QTYPE_CNAME)
try: try:
# Create the record # Create the record
@ -828,7 +866,7 @@ class TestComplexQueries(DNSTest):
questions.append(q) questions.append(q)
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY) self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
self.assertEquals(response.ancount, 2) self.assertEquals(response.ancount, 2)
@ -862,7 +900,7 @@ class TestComplexQueries(DNSTest):
p.nscount = len(updates) p.nscount = len(updates)
p.nsrecs = updates p.nsrecs = updates
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
def test_cname_two_chain(self): def test_cname_two_chain(self):
@ -880,7 +918,7 @@ class TestComplexQueries(DNSTest):
questions.append(q) questions.append(q)
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY) self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
self.assertEquals(response.ancount, 3) self.assertEquals(response.ancount, 3)
@ -921,7 +959,7 @@ class TestComplexQueries(DNSTest):
questions.append(q) questions.append(q)
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY) self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
@ -938,6 +976,14 @@ class TestComplexQueries(DNSTest):
self.assertEquals(response.answers[1].rdata, name0) self.assertEquals(response.answers[1].rdata, name0)
class TestInvalidQueries(DNSTest): class TestInvalidQueries(DNSTest):
def setUp(self):
super(TestInvalidQueries, self).setUp()
global server, server_ip, lp, creds, timeout
self.server = server_name
self.server_ip = server_ip
self.lp = lp
self.creds = creds
self.timeout = timeout
def test_one_a_query(self): def test_one_a_query(self):
"send 0 bytes follows by create a query packet containing one query record" "send 0 bytes follows by create a query packet containing one query record"
@ -960,7 +1006,7 @@ class TestInvalidQueries(DNSTest):
questions.append(q) questions.append(q)
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY) self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
self.assertEquals(response.ancount, 1) self.assertEquals(response.ancount, 1)
@ -1006,6 +1052,13 @@ class TestInvalidQueries(DNSTest):
class TestZones(DNSTest): class TestZones(DNSTest):
def setUp(self): def setUp(self):
super(TestZones, self).setUp() super(TestZones, self).setUp()
global server, server_ip, lp, creds, timeout
self.server = server_name
self.server_ip = server_ip
self.lp = lp
self.creds = creds
self.timeout = timeout
self.zone = "test.lan" self.zone = "test.lan"
self.rpc_conn = dnsserver.dnsserver("ncacn_ip_tcp:%s[sign]" % (self.server_ip), self.rpc_conn = dnsserver.dnsserver("ncacn_ip_tcp:%s[sign]" % (self.server_ip),
self.lp, self.creds) self.lp, self.creds)
@ -1056,21 +1109,21 @@ class TestZones(DNSTest):
questions.append(q) questions.append(q)
self.finish_name_packet(p, questions) self.finish_name_packet(p, questions)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
# Windows returns OK while BIND logically seems to return NXDOMAIN # Windows returns OK while BIND logically seems to return NXDOMAIN
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY) self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
self.assertEquals(response.ancount, 0) self.assertEquals(response.ancount, 0)
self.create_zone(zone) self.create_zone(zone)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY) self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
self.assertEquals(response.ancount, 1) self.assertEquals(response.ancount, 1)
self.assertEquals(response.answers[0].rr_type, dns.DNS_QTYPE_SOA) self.assertEquals(response.answers[0].rr_type, dns.DNS_QTYPE_SOA)
self.delete_zone(zone) self.delete_zone(zone)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY) self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
self.assertEquals(response.ancount, 0) self.assertEquals(response.ancount, 0)
@ -1078,6 +1131,11 @@ class TestZones(DNSTest):
class TestRPCRoundtrip(DNSTest): class TestRPCRoundtrip(DNSTest):
def setUp(self): def setUp(self):
super(TestRPCRoundtrip, self).setUp() super(TestRPCRoundtrip, self).setUp()
global server, server_ip, lp, creds
self.server = server_name
self.server_ip = server_ip
self.lp = lp
self.creds = creds
self.rpc_conn = dnsserver.dnsserver("ncacn_ip_tcp:%s[sign]" % (self.server_ip), self.rpc_conn = dnsserver.dnsserver("ncacn_ip_tcp:%s[sign]" % (self.server_ip),
self.lp, self.creds) self.lp, self.creds)
@ -1110,7 +1168,7 @@ class TestRPCRoundtrip(DNSTest):
"test adding records works" "test adding records works"
prefix, txt = 'pad1textrec', ['"This is a test"', '', ''] prefix, txt = 'pad1textrec', ['"This is a test"', '', '']
p = self.make_txt_update(prefix, txt) p = self.make_txt_update(prefix, txt)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.check_query_txt(prefix, txt) self.check_query_txt(prefix, txt)
self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip, self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@ -1120,7 +1178,7 @@ class TestRPCRoundtrip(DNSTest):
prefix, txt = 'pad2textrec', ['"This is a test"', '', '', 'more text'] prefix, txt = 'pad2textrec', ['"This is a test"', '', '', 'more text']
p = self.make_txt_update(prefix, txt) p = self.make_txt_update(prefix, txt)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.check_query_txt(prefix, txt) self.check_query_txt(prefix, txt)
self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip, self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@ -1130,7 +1188,7 @@ class TestRPCRoundtrip(DNSTest):
prefix, txt = 'pad3textrec', ['', '', '"This is a test"'] prefix, txt = 'pad3textrec', ['', '', '"This is a test"']
p = self.make_txt_update(prefix, txt) p = self.make_txt_update(prefix, txt)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.check_query_txt(prefix, txt) self.check_query_txt(prefix, txt)
self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip, self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@ -1209,7 +1267,7 @@ class TestRPCRoundtrip(DNSTest):
"test adding records works" "test adding records works"
prefix, txt = 'nulltextrec', ['NULL\x00BYTE'] prefix, txt = 'nulltextrec', ['NULL\x00BYTE']
p = self.make_txt_update(prefix, txt) p = self.make_txt_update(prefix, txt)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.check_query_txt(prefix, ['NULL']) self.check_query_txt(prefix, ['NULL'])
self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip, self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@ -1219,7 +1277,7 @@ class TestRPCRoundtrip(DNSTest):
prefix, txt = 'nulltextrec2', ['NULL\x00BYTE', 'NULL\x00BYTE'] prefix, txt = 'nulltextrec2', ['NULL\x00BYTE', 'NULL\x00BYTE']
p = self.make_txt_update(prefix, txt) p = self.make_txt_update(prefix, txt)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.check_query_txt(prefix, ['NULL', 'NULL']) self.check_query_txt(prefix, ['NULL', 'NULL'])
self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip, self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@ -1254,7 +1312,7 @@ class TestRPCRoundtrip(DNSTest):
"test adding records works" "test adding records works"
prefix, txt = 'hextextrec', ['HIGH\xFFBYTE'] prefix, txt = 'hextextrec', ['HIGH\xFFBYTE']
p = self.make_txt_update(prefix, txt) p = self.make_txt_update(prefix, txt)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.check_query_txt(prefix, txt) self.check_query_txt(prefix, txt)
self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip, self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@ -1289,7 +1347,7 @@ class TestRPCRoundtrip(DNSTest):
"test adding records works" "test adding records works"
prefix, txt = 'slashtextrec', ['Th\\=is=is a test'] prefix, txt = 'slashtextrec', ['Th\\=is=is a test']
p = self.make_txt_update(prefix, txt) p = self.make_txt_update(prefix, txt)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.check_query_txt(prefix, txt) self.check_query_txt(prefix, txt)
self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip, self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@ -1329,7 +1387,7 @@ class TestRPCRoundtrip(DNSTest):
prefix, txt = 'textrec2', ['"This is a test"', prefix, txt = 'textrec2', ['"This is a test"',
'"and this is a test, too"'] '"and this is a test, too"']
p = self.make_txt_update(prefix, txt) p = self.make_txt_update(prefix, txt)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.check_query_txt(prefix, txt) self.check_query_txt(prefix, txt)
self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip, self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@ -1368,7 +1426,7 @@ class TestRPCRoundtrip(DNSTest):
"test adding two txt records works" "test adding two txt records works"
prefix, txt = 'emptytextrec', [] prefix, txt = 'emptytextrec', []
p = self.make_txt_update(prefix, txt) p = self.make_txt_update(prefix, txt)
response = self.dns_transaction_udp(p) (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK) self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.check_query_txt(prefix, txt) self.check_query_txt(prefix, txt)
self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip, self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,

View File

@ -29,6 +29,7 @@ from samba import credentials
from samba.dcerpc import dns, dnsp from samba.dcerpc import dns, dnsp
from samba.tests.subunitrun import SubunitOptions, TestProgram from samba.tests.subunitrun import SubunitOptions, TestProgram
from samba import gensec, tests from samba import gensec, tests
from samba.tests import TestCase
parser = optparse.OptionParser("dns.py <server name> <server ip> [options]") parser = optparse.OptionParser("dns.py <server name> <server ip> [options]")
sambaopts = options.SambaOptions(parser) sambaopts = options.SambaOptions(parser)
@ -56,21 +57,11 @@ server_name = args[0]
server_ip = args[1] server_ip = args[1]
class DNSTest(tests.TestCase): class DNSTest(TestCase):
def setUp(self): def setUp(self):
super(DNSTest, self).setUp() super(DNSTest, self).setUp()
self.server = server_name self.timeout = None
self.server_ip = server_ip
self.settings = {}
self.settings["lp_ctx"] = self.lp_ctx = tests.env_loadparm()
self.settings["target_hostname"] = self.server
self.creds = credentials.Credentials()
self.creds.guess(self.lp_ctx)
self.creds.set_username(tests.env_get_var_value('USERNAME'))
self.creds.set_password(tests.env_get_var_value('PASSWORD'))
self.creds.set_kerberos_state(credentials.MUST_USE_KERBEROS)
self.newrecname = "tkeytsig.%s" % self.get_dns_domain()
def errstr(self, errcode): def errstr(self, errcode):
"Return a readable error code" "Return a readable error code"
@ -150,9 +141,11 @@ class DNSTest(tests.TestCase):
return self.creds.get_realm().lower() return self.creds.get_realm().lower()
def dns_transaction_udp(self, packet, host, def dns_transaction_udp(self, packet, host,
dump=False, timeout=timeout): dump=False, timeout=None):
"send a DNS query and read the reply" "send a DNS query and read the reply"
s = None s = None
if timeout is None:
timeout = self.timeout
try: try:
send_packet = ndr.ndr_pack(packet) send_packet = ndr.ndr_pack(packet)
if dump: if dump:
@ -171,9 +164,11 @@ class DNSTest(tests.TestCase):
s.close() s.close()
def dns_transaction_tcp(self, packet, host, def dns_transaction_tcp(self, packet, host,
dump=False, timeout=timeout): dump=False, timeout=None):
"send a DNS query and read the reply, also return the raw packet" "send a DNS query and read the reply, also return the raw packet"
s = None s = None
if timeout is None:
timeout = self.timeout
try: try:
send_packet = ndr.ndr_pack(packet) send_packet = ndr.ndr_pack(packet)
if dump: if dump:
@ -200,6 +195,61 @@ class DNSTest(tests.TestCase):
return (response, recv_packet[2:]) return (response, recv_packet[2:])
def make_txt_update(self, prefix, txt_array):
p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
updates = []
name = self.get_dns_domain()
u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
updates.append(u)
self.finish_name_packet(p, updates)
updates = []
r = dns.res_rec()
r.name = "%s.%s" % (prefix, self.get_dns_domain())
r.rr_type = dns.DNS_QTYPE_TXT
r.rr_class = dns.DNS_QCLASS_IN
r.ttl = 900
r.length = 0xffff
rdata = self.make_txt_record(txt_array)
r.rdata = rdata
updates.append(r)
p.nscount = len(updates)
p.nsrecs = updates
return p
def check_query_txt(self, prefix, txt_array):
name = "%s.%s" % (prefix, self.get_dns_domain())
p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
questions = []
q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
questions.append(q)
self.finish_name_packet(p, questions)
(response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
self.assertEquals(response.ancount, 1)
self.assertEquals(response.answers[0].rdata.txt.str, txt_array)
class DNSTKeyTest(DNSTest):
def setUp(self):
super(DNSTKeyTest, self).setUp()
self.server = server_name
self.server_ip = server_ip
self.settings = {}
self.settings["lp_ctx"] = self.lp_ctx = tests.env_loadparm()
self.settings["target_hostname"] = self.server
self.creds = credentials.Credentials()
self.creds.guess(self.lp_ctx)
self.creds.set_username(tests.env_get_var_value('USERNAME'))
self.creds.set_password(tests.env_get_var_value('PASSWORD'))
self.creds.set_kerberos_state(credentials.MUST_USE_KERBEROS)
self.newrecname = "tkeytsig.%s" % self.get_dns_domain()
def tkey_trans(self): def tkey_trans(self):
"Do a TKEY transaction and establish a gensec context" "Do a TKEY transaction and establish a gensec context"
@ -410,7 +460,7 @@ class DNSTest(tests.TestCase):
return p return p
class TestDNSUpdates(DNSTest): class TestDNSUpdates(DNSTKeyTest):
def test_tkey(self): def test_tkey(self):
"test DNS TKEY handshake" "test DNS TKEY handshake"