2019-02-15 22:56:07 +13:00
#!/usr/bin/env python3
2018-01-12 15:53:03 +01:00
#
# Unix SMB/CIFS implementation.
# Copyright (C) Volker Lendecke 2017
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
2019-01-30 15:10:45 +13:00
# Used by selftest to proxy DNS queries to the correct testenv DC.
# See selftest/target/README for more details.
2018-01-12 15:53:03 +01:00
# Based on the EchoServer example from python docs
import threading
import sys
import select
import socket
2020-03-11 16:55:33 +01:00
import collections
2019-01-23 09:34:40 +01:00
import time
2018-01-12 15:53:03 +01:00
from samba . dcerpc import dns
import samba . ndr as ndr
2024-05-28 19:39:33 +12:00
import socketserver
sserver = socketserver
2018-01-12 15:53:03 +01:00
2019-02-04 09:28:07 +13:00
# Timeout in seconds for each query forwarded to a testenv DC (it is handed
# to socket.settimeout()).  Requests slower than a fifth of this value are
# reported in the debug output of DnsHandler.handle().
DNS_REQUEST_TIMEOUT = 10

# make sure the script dies immediately when hitting control-C,
# rather than raising KeyboardInterrupt.  This script performs no database
# or file writes, so being killed outright is safe.
import signal
signal.signal(signal.SIGINT, signal.SIG_DFL)
2019-01-30 13:24:45 +13:00
2018-01-12 15:53:03 +01:00
class DnsHandler(sserver.BaseRequestHandler):
    """Answer one UDP DNS query, normally by forwarding it to a testenv DC.

    socketserver instantiates this class for every incoming request and
    calls handle().  For a UDP server ``self.request`` is the
    ``(data, socket)`` pair unpacked in handle().  The owning server object
    must carry a ``realm_to_ip_mappings`` mapping of realm suffix -> PDC
    address, which get_pdc_ipv4_addr() reads.
    """

    # Reverse lookup tables (numeric value -> constant name) for the
    # DNS_QTYPE_* and DNS_RCODE_* constants exported by samba.dcerpc.dns.
    # They are only used to make the slow-request debug output readable.
    dns_qtype_strings = {v: k for k, v in vars(dns).items()
                         if k.startswith('DNS_QTYPE_')}
    dns_rcode_strings = {v: k for k, v in vars(dns).items()
                         if k.startswith('DNS_RCODE_')}

    def dns_qtype_string(self, qtype):
        """Return a readable qtype code.

        Codes without a matching DNS_QTYPE_* constant are rendered
        numerically rather than raising KeyError, so producing a debug
        message can never abort sending the reply.
        """
        return self.dns_qtype_strings.get(qtype,
                                          'DNS_QTYPE_UNKNOWN(%s)' % qtype)

    def dns_rcode_string(self, rcode):
        """Return a readable error code (numeric fallback if unknown)."""
        return self.dns_rcode_strings.get(rcode,
                                          'DNS_RCODE_UNKNOWN(%s)' % rcode)

    def dns_transaction_udp(self, packet, host):
        """Send a DNS query to *host* (a numeric address) and read the reply.

        Returns the unpacked reply packet.  Socket errors, including the
        DNS_REQUEST_TIMEOUT expiring, are logged and re-raised as OSError.
        """
        # host must be a numeric address; no name resolution is performed.
        flags = (socket.AddressInfo.AI_NUMERICHOST |
                 socket.AddressInfo.AI_NUMERICSERV |
                 socket.AddressInfo.AI_PASSIVE)
        addr_info = socket.getaddrinfo(host, 53,
                                       type=socket.SocketKind.SOCK_DGRAM,
                                       flags=flags)
        assert len(addr_info) == 1
        family, socktype, _proto, _canonname, sockaddr = addr_info[0]

        send_packet = ndr.ndr_pack(packet)
        try:
            with socket.socket(family, socktype, 0) as s:
                s.settimeout(DNS_REQUEST_TIMEOUT)
                # connect() so that recv() only accepts replies from host
                s.connect(sockaddr)
                s.sendall(send_packet, 0)
                recv_packet = s.recv(2048, 0)
        except socket.error as err:
            # log the exception itself: err.errno is None for timeouts
            print("Error sending to host %s for name %s: %s\n" %
                  (host, packet.questions[0].name, err))
            raise
        return ndr.ndr_unpack(dns.name_packet, recv_packet)

    def get_pdc_ipv4_addr(self, lookup_name):
        """Maps a DNS realm to the IPv4 address of the PDC for that testenv.

        Returns the address mapped to the longest realm that *lookup_name*
        ends with, or None when no realm matches.  DNS names are
        case-insensitive, so the comparison ignores case.
        """
        realm_to_ip_mappings = self.server.realm_to_ip_mappings
        name = lookup_name.lower()

        # sort the realms so we find the longest-match first
        for realm in reversed(sorted(realm_to_ip_mappings, key=len)):
            if name.endswith(realm.lower()):
                # return the corresponding IP address for this realm's PDC
                return realm_to_ip_mappings[realm]

        return None

    def forwarder(self, name):
        """Decide how the query for *name* is answered.

        Returns one of the markers used by handle() for names that specific
        tests rely on -- 'ignore' (send no reply), 'fail' (reply SERVFAIL),
        'torture' (reply NXDOMAIN) -- otherwise the PDC address for the
        matching realm, or None when no realm matches (also NXDOMAIN).
        """
        lname = name.lower()

        # check for special cases used by tests (e.g. dns_forwarder.py)
        if lname.endswith('an-address-that-will-not-resolve'):
            return 'ignore'
        if lname.endswith('dsfsdfs'):
            return 'fail'
        if lname.endswith("torture1", 0, len(lname) - 2):
            # ignoring the last two characters the name ends in "torture1":
            # CATCH TORTURE100, TORTURE101, ...
            return 'torture'
        if lname.endswith(('_none_.example.com',
                           'torturedom.samba.example.com')):
            return 'torture'

        # return the testenv PDC matching the realm being requested
        return self.get_pdc_ipv4_addr(lname)

    @staticmethod
    def _make_reply(query, rcode):
        """Turn *query* into a reply carrying *rcode* (in place); return it."""
        query.operation |= dns.DNS_FLAG_REPLY
        query.operation |= dns.DNS_FLAG_RECURSION_AVAIL
        query.operation |= rcode
        return query

    def handle(self):
        """Answer one datagram: forward it, or synthesize an error reply."""
        start = time.monotonic()
        data, sock = self.request
        query = ndr.ndr_unpack(dns.name_packet, data)
        name = query.questions[0].name
        forwarder = self.forwarder(name)
        response = None

        if forwarder == 'ignore':
            # deliberately leave the client without any answer
            return
        elif forwarder == 'fail':
            # leave response unset: it is answered with SERVFAIL below
            pass
        elif forwarder in ['torture', None]:
            response = self._make_reply(query, dns.DNS_RCODE_NXDOMAIN)
        else:
            try:
                response = self.dns_transaction_udp(query, forwarder)
            except OSError as err:
                print("dns_hub: Error sending dns query to forwarder[%s] for name[%s]: %s" %
                      (forwarder, name, err))

        if response is None:
            response = self._make_reply(query, dns.DNS_RCODE_SERVFAIL)

        send_packet = ndr.ndr_pack(response)

        tdiff = time.monotonic() - start
        # only report requests that used a noticeable share of the timeout
        if tdiff > (DNS_REQUEST_TIMEOUT / 5):
            errcode = response.operation & dns.DNS_RCODE
            print("dns_hub: forwarder[%s] client[%s] name[%s][%s] %s response.operation[0x%x] tdiff[%s]\n" %
                  (forwarder, self.client_address, name,
                   self.dns_qtype_string(query.questions[0].question_type),
                   self.dns_rcode_string(errcode), response.operation, tdiff))

        try:
            sock.sendto(send_packet, self.client_address)
        except socket.error as err:
            print("dns_hub: Error sending response to client[%s] for name[%s] tdiff[%s]: %s\n" %
                  (self.client_address, name, tdiff, err))
2019-01-19 09:14:28 +13:00
2018-01-12 15:53:03 +01:00
class server_thread(threading.Thread):
    """Thread running the serve_forever() loop of one socketserver.

    The thread's name prefixes every progress message it prints.
    """

    def __init__(self, server, name):
        super().__init__(name=name)
        # socketserver instance driven by run() and torn down by stop()
        self.server = server

    def run(self):
        """Serve requests until stop() ends the loop."""
        print("dns_hub[%s]: before serve_forever()" % self.name)
        self.server.serve_forever()
        print("dns_hub[%s]: after serve_forever()" % self.name)

    def stop(self):
        """End the serve loop, then close the server's socket.

        Must be called from a thread other than this one: shutdown()
        waits for serve_forever() to return.
        """
        srv = self.server
        print("dns_hub[%s]: before shutdown()" % self.name)
        srv.shutdown()
        print("dns_hub[%s]: after shutdown()" % self.name)
        srv.server_close()
2020-03-11 16:55:33 +01:00
class UDPV4Server(sserver.UDPServer):
    """UDP server whose listening socket uses the IPv4 address family."""
    address_family = socket.AddressFamily.AF_INET
class UDPV6Server(sserver.UDPServer):
    """UDP server whose listening socket uses the IPv6 address family."""
    address_family = socket.AddressFamily.AF_INET6
2019-01-30 13:24:45 +13:00
2018-01-12 15:53:03 +01:00
def main():
    """Run the DNS hub until stdin signals or the timeout expires.

    Usage: dns_hub.py TIMEOUT LISTENADDRESS[,LISTENADDRESS,...] MAPPING[,MAPPING,...]

    TIMEOUT is multiplied by 1000 for poll(), which takes milliseconds, so
    it is given in seconds.  Each MAPPING is ``realm=address``.  One UDP
    server thread is started per listen address; once poll() reports an
    event on stdin or TIMEOUT expires, every server is shut down and the
    process exits with status 0.
    """
    if len(sys.argv) < 4:
        print("Usage: dns_hub.py TIMEOUT LISTENADDRESS[,LISTENADDRESS,...] MAPPING[,MAPPING,...]")
        sys.exit(1)

    timeout = int(sys.argv[1]) * 1000
    timeout = min(timeout, 2**31 - 1)  # poll with 32-bit int can't take more

    # we pass in the listen addresses as a comma-separated string.
    # Empty entries (e.g. from a trailing comma) are ignored.
    listenaddresses = [a for a in sys.argv[2].split(',') if a]

    # we pass in the realm-to-IP mappings as a comma-separated key=value
    # string. Convert this back into a dictionary that the DnsHandler can
    # use.  Only the first '=' separates key from value; empty entries
    # are ignored.
    realm_mappings = collections.OrderedDict(
        kv.split('=', 1) for kv in sys.argv[3].split(',') if kv)

    def prepare_server_thread(listenaddress, realm_mappings):
        """Create (but do not start) a UDP DNS server thread for one address."""
        # listenaddress must be numeric; no name resolution is performed.
        flags = (socket.AddressInfo.AI_NUMERICHOST |
                 socket.AddressInfo.AI_NUMERICSERV |
                 socket.AddressInfo.AI_PASSIVE)
        addr_info = socket.getaddrinfo(listenaddress, 53,
                                       type=socket.SocketKind.SOCK_DGRAM,
                                       flags=flags)
        assert len(addr_info) == 1
        if addr_info[0][0] == socket.AddressFamily.AF_INET6:
            server = UDPV6Server(addr_info[0][4], DnsHandler)
        else:
            server = UDPV4Server(addr_info[0][4], DnsHandler)

        # DnsHandler looks the realm mappings up through its server object
        server.realm_to_ip_mappings = realm_mappings
        return server_thread(server, name="UDP[%s]" % listenaddress)

    print("dns_hub will proxy DNS requests for the following realms:")
    for realm, ip in realm_mappings.items():
        print("  {0} ==> {1}".format(realm, ip))

    print("dns_hub will listen on the following UDP addresses:")
    threads = []
    for listenaddress in listenaddresses:
        print("  %s" % listenaddress)
        threads.append(prepare_server_thread(listenaddress, realm_mappings))

    for t in threads:
        t.start()

    # The server threads are not daemons: make sure they are shut down even
    # if waiting on stdin fails, otherwise the process would never exit.
    try:
        p = select.poll()
        stdin = sys.stdin.fileno()
        p.register(stdin, select.POLLIN)
        p.poll(timeout)
        print("dns_hub: after poll()")
    finally:
        for t in threads:
            t.stop()
        for t in threads:
            t.join()

    print("dns_hub: before exit()")
    sys.exit(0)
# Only start the hub when run as a script, not when imported.
if __name__ == "__main__":
    main()