1
0
mirror of https://github.com/samba-team/samba.git synced 2024-12-23 17:34:34 +03:00

selftest: allow dns_hub.py to listen on more than one address

This makes it possible to serve ipv4 and ipv6 at the same time.

Signed-off-by: Stefan Metzmacher <metze@samba.org>
Reviewed-by: Andreas Schneider <asn@samba.org>
This commit is contained in:
Stefan Metzmacher 2020-03-11 16:55:33 +01:00 committed by Andreas Schneider
parent 2d1d67ed72
commit 25ba290d18

View File

@ -24,6 +24,7 @@ import threading
import sys import sys
import select import select
import socket import socket
import collections
import time import time
from samba.dcerpc import dns from samba.dcerpc import dns
import samba.ndr as ndr import samba.ndr as ndr
@ -158,44 +159,81 @@ class DnsHandler(sserver.BaseRequestHandler):
class server_thread(threading.Thread): class server_thread(threading.Thread):
def __init__(self, server): def __init__(self, server, name):
threading.Thread.__init__(self) threading.Thread.__init__(self, name=name)
self.server = server self.server = server
def run(self): def run(self):
print("dns_hub[%s]: before serve_forever()" % self.name)
self.server.serve_forever() self.server.serve_forever()
print("dns_hub: after serve_forever()") print("dns_hub[%s]: after serve_forever()" % self.name)
def stop(self):
print("dns_hub[%s]: before shutdown()" % self.name)
self.server.shutdown()
print("dns_hub[%s]: after shutdown()" % self.name)
class UDPV4Server(sserver.UDPServer):
address_family = socket.AF_INET
class UDPV6Server(sserver.UDPServer):
address_family = socket.AF_INET6
def main(): def main():
if len(sys.argv) < 4: if len(sys.argv) < 4:
print("Usage: dns_hub.py TIMEOUT HOST MAPPING") print("Usage: dns_hub.py TIMEOUT LISTENADDRESS[,LISTENADDRESS,...] MAPPING[,MAPPING,...]")
sys.exit(1) sys.exit(1)
timeout = int(sys.argv[1]) * 1000 timeout = int(sys.argv[1]) * 1000
timeout = min(timeout, 2**31 - 1) # poll with 32-bit int can't take more timeout = min(timeout, 2**31 - 1) # poll with 32-bit int can't take more
host = sys.argv[2] # we pass in the listen addresses as a comma-separated string.
listenaddresses = sys.argv[2].split(',')
server = sserver.UDPServer((host, int(53)), DnsHandler)
# we pass in the realm-to-IP mappings as a comma-separated key=value # we pass in the realm-to-IP mappings as a comma-separated key=value
# string. Convert this back into a dictionary that the DnsHandler can use # string. Convert this back into a dictionary that the DnsHandler can use
realm_mapping = dict(kv.split('=') for kv in sys.argv[3].split(',')) realm_mappings = collections.OrderedDict(kv.split('=') for kv in sys.argv[3].split(','))
server.realm_to_ip_mappings = realm_mapping
def prepare_server_thread(listenaddress, realm_mappings):
flags = socket.AddressInfo.AI_NUMERICHOST
flags |= socket.AddressInfo.AI_NUMERICSERV
flags |= socket.AddressInfo.AI_PASSIVE
addr_info = socket.getaddrinfo(listenaddress, int(53),
type=socket.SocketKind.SOCK_DGRAM,
flags=flags)
assert len(addr_info) == 1
if addr_info[0][0] == socket.AddressFamily.AF_INET6:
server = UDPV6Server(addr_info[0][4], DnsHandler)
else:
server = UDPV4Server(addr_info[0][4], DnsHandler)
# we pass in the realm-to-IP mappings as a comma-separated key=value
# string. Convert this back into a dictionary that the DnsHandler can use
server.realm_to_ip_mappings = realm_mappings
t = server_thread(server, name="UDP[%s]" % listenaddress)
return t
print("dns_hub will proxy DNS requests for the following realms:") print("dns_hub will proxy DNS requests for the following realms:")
for realm, ip in server.realm_to_ip_mappings.items(): for realm, ip in realm_mappings.items():
print(" {0} ==> {1}".format(realm, ip)) print(" {0} ==> {1}".format(realm, ip))
t = server_thread(server) print("dns_hub will listen on the following UDP addresses:")
t.start() threads = []
for listenaddress in listenaddresses:
print(" %s" % listenaddress)
t = prepare_server_thread(listenaddress, realm_mappings)
threads.append(t)
for t in threads:
t.start()
p = select.poll() p = select.poll()
stdin = sys.stdin.fileno() stdin = sys.stdin.fileno()
p.register(stdin, select.POLLIN) p.register(stdin, select.POLLIN)
p.poll(timeout) p.poll(timeout)
print("dns_hub: after poll()") print("dns_hub: after poll()")
server.shutdown() for t in threads:
t.join() t.stop()
for t in threads:
t.join()
print("dns_hub: before exit()") print("dns_hub: before exit()")
sys.exit(0) sys.exit(0)