mirror of
https://github.com/samba-team/samba.git
synced 2024-12-23 17:34:34 +03:00
09fba1f3e3
Signed-off-by: Andreas Schneider <asn@samba.org> Reviewed-by: Noel Power <npower@samba.org>
252 lines
9.0 KiB
Python
Executable File
252 lines
9.0 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
#
|
|
# Unix SMB/CIFS implementation.
|
|
# Copyright (C) Volker Lendecke 2017
|
|
#
|
|
# This program is free software; you can redistribute it and/or modify
|
|
# it under the terms of the GNU General Public License as published by
|
|
# the Free Software Foundation; either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU General Public License
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
#
|
|
# Used by selftest to proxy DNS queries to the correct testenv DC.
|
|
# See selftest/target/README for more details.
|
|
# Based on the EchoServer example from python docs
|
|
|
|
import threading
|
|
import sys
|
|
import select
|
|
import socket
|
|
import collections
|
|
import time
|
|
from samba.dcerpc import dns
|
|
import samba.ndr as ndr
|
|
|
|
# Python 2 is not supported: the rest of this script relies on Python 3-only
# APIs (time.monotonic, socket.AddressInfo, ...), so import socketserver
# directly. The ``sserver`` alias is kept for the class definitions below.
import socketserver
sserver = socketserver

# Timeout in seconds for a single forwarded DNS query (used as the socket
# timeout); a fifth of it is also the threshold for logging slow answers.
DNS_REQUEST_TIMEOUT = 10

# Make sure the script dies immediately when hitting control-C, rather than
# raising KeyboardInterrupt in the main thread while the server threads keep
# running. The hub only forwards UDP packets and keeps no state that needs
# cleaning up, so dying abruptly is safe.
import signal
signal.signal(signal.SIGINT, signal.SIG_DFL)
|
|
|
|
class DnsHandler(sserver.BaseRequestHandler):
    """Answer a single DNS query that arrived at the hub.

    socketserver creates one instance per received UDP datagram.
    ``self.request`` is the ``(data, socket)`` pair for that datagram, and
    ``self.server.realm_to_ip_mappings`` maps each testenv DNS realm to the
    address of the PDC that should answer queries for it.
    """

    # A UDP reply is read with a single recv(); any bytes beyond the buffer
    # are silently discarded, and the truncated packet would most likely
    # fail to unpack. Use a buffer large enough for any UDP datagram.
    _UDP_RECV_SIZE = 65535

    # Reverse lookup tables: numeric value -> symbolic name of the
    # DNS_QTYPE_* / DNS_RCODE_* constants exported by samba.dcerpc.dns.
    dns_qtype_strings = dict((v, k) for k, v in vars(dns).items() if k.startswith('DNS_QTYPE_'))
    dns_rcode_strings = dict((v, k) for k, v in vars(dns).items() if k.startswith('DNS_RCODE_'))

    def dns_qtype_string(self, qtype):
        """Return a readable name for the query type ``qtype``.

        Unknown values are rendered numerically instead of raising KeyError:
        this is only used for logging and must not stop the reply being sent.
        """
        return self.dns_qtype_strings.get(qtype, "DNS_QTYPE_UNKNOWN(%s)" % qtype)

    def dns_rcode_string(self, rcode):
        """Return a readable name for the response code ``rcode``.

        Unknown values are rendered numerically instead of raising KeyError.
        """
        return self.dns_rcode_strings.get(rcode, "DNS_RCODE_UNKNOWN(%s)" % rcode)

    def dns_transaction_udp(self, packet, host):
        """Send the DNS query ``packet`` to ``host`` port 53 over UDP.

        ``host`` must be a numeric IPv4/IPv6 address. Returns the unpacked
        reply as a ``dns.name_packet``. Socket errors (including the
        DNS_REQUEST_TIMEOUT timeout) are printed and re-raised.
        """
        # AI_PASSIVE is not set: it only matters for wildcard bind addresses,
        # never for a numeric destination.
        flags = socket.AddressInfo.AI_NUMERICHOST
        flags |= socket.AddressInfo.AI_NUMERICSERV
        addr_info = socket.getaddrinfo(host, 53,
                                       type=socket.SocketKind.SOCK_DGRAM,
                                       flags=flags)
        assert len(addr_info) == 1
        family, socktype, _, _, sockaddr = addr_info[0]

        send_packet = ndr.ndr_pack(packet)
        try:
            with socket.socket(family, socktype, 0) as s:
                s.settimeout(DNS_REQUEST_TIMEOUT)
                s.connect(sockaddr)
                s.sendall(send_packet, 0)
                recv_packet = s.recv(self._UDP_RECV_SIZE, 0)
        except socket.error as err:
            # Log the exception itself: err.errno is None for timeouts.
            print("Error sending to host %s for name %s: %s\n" %
                  (host, packet.questions[0].name, err))
            raise
        return ndr.ndr_unpack(dns.name_packet, recv_packet)

    def get_pdc_ipv4_addr(self, lookup_name):
        """Map a DNS name to the IPv4 address of the PDC for its testenv realm.

        A realm matches when ``lookup_name`` equals it or is a subdomain of
        it (case-insensitively). The longest matching realm wins, so nested
        realms resolve to their own PDC. Returns None when nothing matches.
        """
        name = lookup_name.lower()
        realm_to_ip_mappings = self.server.realm_to_ip_mappings

        # longest realm first, so the most specific match is found first
        for realm in reversed(sorted(realm_to_ip_mappings, key=len)):
            suffix = realm.lower()
            # require a label boundary: "xsamba.example.com" must not match
            # the realm "samba.example.com"
            if name == suffix or name.endswith('.' + suffix):
                return realm_to_ip_mappings[realm]

        return None

    def forwarder(self, name):
        """Decide how the query for ``name`` is answered.

        Returns 'ignore' (drop without answering), 'fail' (answer SERVFAIL),
        'torture' (answer NXDOMAIN locally), the address of the testenv PDC
        to forward to, or None when no realm matches (answered NXDOMAIN).
        """
        lname = name.lower()

        # check for special cases used by tests (e.g. dns_forwarder.py)
        if lname.endswith('an-address-that-will-not-resolve'):
            return 'ignore'
        if lname.endswith('dsfsdfs'):
            return 'fail'
        # lname[:-2] ending in "torture1" catches TORTURE100, TORTURE101, ...
        if lname.endswith("torture1", 0, len(lname) - 2):
            return 'torture'
        if lname.endswith(('_none_.example.com', 'torturedom.samba.example.com')):
            return 'torture'

        # return the testenv PDC matching the realm being requested
        return self.get_pdc_ipv4_addr(lname)

    @staticmethod
    def _local_reply(query, rcode):
        """Turn ``query`` in place into a reply carrying ``rcode``; return it."""
        # Clear any rcode bits the client sent so ``rcode`` is reported exactly.
        query.operation &= ~dns.DNS_RCODE
        query.operation |= dns.DNS_FLAG_REPLY
        query.operation |= dns.DNS_FLAG_RECURSION_AVAIL
        query.operation |= rcode
        return query

    def handle(self):
        """Answer one query: drop it, answer it locally, or proxy it to a PDC."""
        start = time.monotonic()
        data, sock = self.request
        query = ndr.ndr_unpack(dns.name_packet, data)
        name = query.questions[0].name
        forwarder = self.forwarder(name)
        response = None

        if forwarder == 'ignore':
            # never answer, so the client runs into its own timeout
            return
        elif forwarder == 'fail':
            pass  # answered with SERVFAIL below
        elif forwarder in ['torture', None]:
            response = self._local_reply(query, dns.DNS_RCODE_NXDOMAIN)
        else:
            try:
                response = self.dns_transaction_udp(query, forwarder)
            except OSError as err:
                print("dns_hub: Error sending dns query to forwarder[%s] for name[%s]: %s" %
                      (forwarder, name, err))

        if response is None:
            response = self._local_reply(query, dns.DNS_RCODE_SERVFAIL)

        send_packet = ndr.ndr_pack(response)

        tdiff = time.monotonic() - start
        # only report answers slower than a fifth of the forwarding timeout
        if tdiff > (DNS_REQUEST_TIMEOUT / 5):
            errcode = response.operation & dns.DNS_RCODE
            print("dns_hub: forwarder[%s] client[%s] name[%s][%s] %s response.operation[0x%x] tdiff[%s]\n" %
                  (forwarder, self.client_address, name,
                   self.dns_qtype_string(query.questions[0].question_type),
                   self.dns_rcode_string(errcode), response.operation, tdiff))

        try:
            sock.sendto(send_packet, self.client_address)
        except socket.error as err:
            print("dns_hub: Error sending response to client[%s] for name[%s] tdiff[%s]: %s\n" %
                  (self.client_address, name, tdiff, err))
|
|
|
|
|
|
class server_thread(threading.Thread):
    """Background thread that drives one socketserver's serve_forever() loop.

    End the loop by calling stop() from another thread, then join() it.
    """

    def __init__(self, server, name):
        super().__init__(name=name)
        # server whose serve_forever() loop this thread runs
        self.server = server

    def run(self):
        """Serve requests until stop() is called, logging entry and exit."""
        print("dns_hub[{0}]: before serve_forever()".format(self.name))
        self.server.serve_forever()
        print("dns_hub[{0}]: after serve_forever()".format(self.name))

    def stop(self):
        """Stop the serve_forever() loop, waiting until it has exited.

        Must be called from a thread other than this one; socketserver's
        shutdown() deadlocks when called from inside serve_forever().
        """
        print("dns_hub[{0}]: before shutdown()".format(self.name))
        self.server.shutdown()
        print("dns_hub[{0}]: after shutdown()".format(self.name))
|
|
|
|
class UDPV4Server(sserver.UDPServer):
    """UDP server whose listening socket is IPv4 (AF_INET)."""

    address_family = socket.AddressFamily.AF_INET
|
|
|
|
class UDPV6Server(sserver.UDPServer):
    """UDP server whose listening socket is IPv6 (AF_INET6)."""

    address_family = socket.AddressFamily.AF_INET6
|
|
|
|
def main():
    """Run the DNS hub.

    Command line: TIMEOUT LISTENADDRESS[,LISTENADDRESS,...] MAPPING[,MAPPING,...]

    TIMEOUT is in seconds. Each LISTENADDRESS must be a numeric IPv4/IPv6
    address; a UDP server on port 53 is started for each one. Each MAPPING
    is a REALM=ADDRESS pair naming the PDC that answers queries for REALM.

    The servers run until TIMEOUT expires or stdin reports an event (e.g.
    the parent closes it). Then they are shut down and the process exits
    with status 0. Missing arguments print usage and exit with status 1.
    """
    if len(sys.argv) < 4:
        print("Usage: dns_hub.py TIMEOUT LISTENADDRESS[,LISTENADDRESS,...] MAPPING[,MAPPING,...]")
        sys.exit(1)

    timeout = int(sys.argv[1]) * 1000
    timeout = min(timeout, 2**31 - 1)  # poll with 32-bit int can't take more
    # we pass in the listen addresses as a comma-separated string.
    listenaddresses = sys.argv[2].split(',')
    # we pass in the realm-to-IP mappings as a comma-separated key=value
    # string. Convert this back into a dictionary that the DnsHandler can use
    realm_mappings = collections.OrderedDict(kv.split('=') for kv in sys.argv[3].split(','))

    def prepare_server_thread(listenaddress, realm_mappings):
        """Create (without starting) a server thread for UDP port 53 on listenaddress."""
        flags = socket.AddressInfo.AI_NUMERICHOST
        flags |= socket.AddressInfo.AI_NUMERICSERV
        flags |= socket.AddressInfo.AI_PASSIVE
        addr_info = socket.getaddrinfo(listenaddress, 53,
                                       type=socket.SocketKind.SOCK_DGRAM,
                                       flags=flags)
        assert len(addr_info) == 1
        family, _, _, _, sockaddr = addr_info[0]
        if family == socket.AddressFamily.AF_INET6:
            server = UDPV6Server(sockaddr, DnsHandler)
        else:
            server = UDPV4Server(sockaddr, DnsHandler)

        # DnsHandler reads the mappings via self.server.realm_to_ip_mappings
        server.realm_to_ip_mappings = realm_mappings
        return server_thread(server, name="UDP[%s]" % listenaddress)

    print("dns_hub will proxy DNS requests for the following realms:")
    for realm, ip in realm_mappings.items():
        print(" {0} ==> {1}".format(realm, ip))

    print("dns_hub will listen on the following UDP addresses:")
    threads = []
    for listenaddress in listenaddresses:
        print(" %s" % listenaddress)
        t = prepare_server_thread(listenaddress, realm_mappings)
        threads.append(t)

    for t in threads:
        t.start()

    # Wait until the timeout expires or stdin becomes readable/hangs up,
    # which is how the parent process tells the hub to go away.
    p = select.poll()
    stdin = sys.stdin.fileno()
    p.register(stdin, select.POLLIN)
    p.poll(timeout)
    print("dns_hub: after poll()")

    for t in threads:
        t.stop()
    for t in threads:
        t.join()
        # the serve loop has exited; release the listening socket
        t.server.server_close()

    print("dns_hub: before exit()")
    sys.exit(0)
|
|
|
|
# Only start the hub when run as a script, not when imported.
if __name__ == "__main__":
    main()
|