2019-02-15 12:56:07 +03:00
#!/usr/bin/env python3
2018-01-12 17:53:03 +03:00
#
# Unix SMB/CIFS implementation.
# Copyright (C) Volker Lendecke 2017
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
2019-01-30 05:10:45 +03:00
# Used by selftest to proxy DNS queries to the correct testenv DC.
# See selftest/target/README for more details.
2018-01-12 17:53:03 +03:00
# Based on the EchoServer example from python docs
import threading
import sys
import select
import socket
2020-03-11 18:55:33 +03:00
import collections
2019-01-23 11:34:40 +03:00
import time
2018-01-12 17:53:03 +03:00
from samba . dcerpc import dns
import samba . ndr as ndr
if sys . version_info [ 0 ] < 3 :
import SocketServer
sserver = SocketServer
else :
import socketserver
sserver = socketserver
2019-02-03 23:28:07 +03:00
# Seconds to wait for a forwarder's reply (passed to socket.settimeout() in
# DnsHandler.dns_transaction_udp). One fifth of this value is also the
# elapsed-time threshold above which DnsHandler.handle() logs a request.
DNS_REQUEST_TIMEOUT = 10

# Make sure the script dies immediately when hitting control-C (restore the
# default SIGINT action) rather than raising KeyboardInterrupt. As far as
# this file shows, the hub writes no persistent state, so dying abruptly
# is safe.
import signal
signal.signal(signal.SIGINT, signal.SIG_DFL)
2019-01-30 03:24:45 +03:00
2018-01-12 17:53:03 +03:00
class DnsHandler(sserver.BaseRequestHandler):
    """Answer one UDP DNS query locally or proxy it to a testenv PDC.

    The owning server must carry a ``realm_to_ip_mappings`` attribute
    (realm -> PDC IP address string); it is used to pick the forwarder.
    """

    # Reverse maps from numeric DNS constants to their symbolic names, built
    # from the DNS_QTYPE_* / DNS_RCODE_* attributes of samba.dcerpc.dns.
    dns_qtype_strings = {v: k for k, v in vars(dns).items()
                         if k.startswith('DNS_QTYPE_')}
    dns_rcode_strings = {v: k for k, v in vars(dns).items()
                         if k.startswith('DNS_RCODE_')}

    def dns_qtype_string(self, qtype):
        "Return a readable qtype code (numeric fallback for unknown codes)"
        # A plain [] lookup would raise KeyError for unmapped values and,
        # because this is used while logging in handle(), drop the reply.
        return self.dns_qtype_strings.get(qtype, "UNKNOWN_QTYPE(%s)" % qtype)

    def dns_rcode_string(self, rcode):
        "Return a readable error code (numeric fallback for unknown codes)"
        return self.dns_rcode_strings.get(rcode, "UNKNOWN_RCODE(%s)" % rcode)

    def dns_transaction_udp(self, packet, host):
        """Send a DNS query to host:53 over UDP and return the parsed reply.

        :param packet: dns.name_packet to send.
        :param host: numeric IPv4/IPv6 address of the server to query.
        :return: the reply unpacked as a dns.name_packet.
        :raises OSError: on address, send, receive or timeout errors.
        """
        flags = socket.AddressInfo.AI_NUMERICHOST
        flags |= socket.AddressInfo.AI_NUMERICSERV
        flags |= socket.AddressInfo.AI_PASSIVE
        addr_info = socket.getaddrinfo(host, 53,
                                       type=socket.SocketKind.SOCK_DGRAM,
                                       flags=flags)
        if len(addr_info) != 1:
            # Raise OSError (not assert) so handle() answers with SERVFAIL.
            raise OSError("forwarder %s resolved to %d addresses, expected 1"
                          % (host, len(addr_info)))

        s = None
        try:
            send_packet = ndr.ndr_pack(packet)
            s = socket.socket(addr_info[0][0], addr_info[0][1], 0)
            s.settimeout(DNS_REQUEST_TIMEOUT)
            s.connect(addr_info[0][4])
            s.sendall(send_packet, 0)
            # recv() on a UDP socket silently truncates a datagram larger
            # than the buffer, so allow the maximum UDP payload size.
            recv_packet = s.recv(65535, 0)
            return ndr.ndr_unpack(dns.name_packet, recv_packet)
        except socket.error as err:
            # Print the exception itself: err.errno is None for timeouts.
            print("Error sending to host %s for name %s: %s\n" %
                  (host, packet.questions[0].name, err))
            raise
        finally:
            if s is not None:
                s.close()

    def get_pdc_ipv4_addr(self, lookup_name):
        """Map a DNS name to the PDC IP address of the testenv realm owning it.

        Returns None when the name is not inside any configured realm.
        """
        realm_to_ip_mappings = self.server.realm_to_ip_mappings

        # Try the longest realms first so the most specific match wins.
        for realm in sorted(realm_to_ip_mappings, key=len, reverse=True):
            # Match only on a label boundary: 'xsamba.example.com' is not
            # a name inside the realm 'samba.example.com'.
            if lookup_name == realm or lookup_name.endswith('.' + realm):
                # return the corresponding IP address for this realm's PDC
                return realm_to_ip_mappings[realm]

        return None

    def forwarder(self, name):
        """Decide how to handle a query for name.

        Returns 'ignore' (send no reply), 'fail' (reply SERVFAIL),
        'torture' (reply NXDOMAIN), a PDC IP address to forward to, or
        None when no realm matches.
        """
        lname = name.lower()
        # check for special cases used by tests (e.g. dns_forwarder.py)
        if lname.endswith('an-address-that-will-not-resolve'):
            return 'ignore'
        if lname.endswith('dsfsdfs'):
            return 'fail'
        if lname.endswith("torture1", 0, len(lname) - 2):
            # CATCH TORTURE100, TORTURE101, ...
            return 'torture'
        if lname.endswith('_none_.example.com'):
            return 'torture'
        if lname.endswith('torturedom.samba.example.com'):
            return 'torture'
        # return the testenv PDC matching the realm being requested
        return self.get_pdc_ipv4_addr(lname)

    def handle(self):
        """Answer one query received on the server's UDP socket.

        'ignore' sends nothing; 'fail' or a forwarding error yields SERVFAIL;
        'torture' or an unknown realm yields NXDOMAIN; otherwise the
        forwarder's reply is relayed back to the client.
        """
        start = time.monotonic()
        data, sock = self.request
        query = ndr.ndr_unpack(dns.name_packet, data)
        name = query.questions[0].name
        forwarder = self.forwarder(name)
        response = None

        if forwarder == 'ignore':
            return
        elif forwarder == 'fail':
            pass
        elif forwarder in ['torture', None]:
            response = query
            response.operation |= dns.DNS_FLAG_REPLY
            response.operation |= dns.DNS_FLAG_RECURSION_AVAIL
            response.operation |= dns.DNS_RCODE_NXDOMAIN
        else:
            try:
                response = self.dns_transaction_udp(query, forwarder)
            except OSError as err:
                print("dns_hub: Error sending dns query to forwarder[%s] for name[%s]: %s" %
                      (forwarder, name, err))

        if response is None:
            response = query
            response.operation |= dns.DNS_FLAG_REPLY
            response.operation |= dns.DNS_FLAG_RECURSION_AVAIL
            response.operation |= dns.DNS_RCODE_SERVFAIL
        send_packet = ndr.ndr_pack(response)

        tdiff = time.monotonic() - start
        errcode = response.operation & dns.DNS_RCODE
        # Only log requests that took a noticeable fraction of the timeout.
        if tdiff > (DNS_REQUEST_TIMEOUT / 5):
            print("dns_hub: forwarder[%s] client[%s] name[%s][%s] %s response.operation[0x%x] tdiff[%s]\n" %
                  (forwarder, self.client_address, name,
                   self.dns_qtype_string(query.questions[0].question_type),
                   self.dns_rcode_string(errcode), response.operation, tdiff))

        try:
            sock.sendto(send_packet, self.client_address)
        except socket.error as err:
            print("dns_hub: Error sending response to client[%s] for name[%s] tdiff[%s]: %s\n" %
                  (self.client_address, name, tdiff, err))
2019-01-18 23:14:28 +03:00
2018-01-12 17:53:03 +03:00
class server_thread(threading.Thread):
    """Thread that drives a single socketserver's serve_forever() loop.

    stop() asks that loop to finish, then closes the server's socket.
    """

    def __init__(self, server, name):
        super().__init__(name=name)
        self.server = server

    def _log(self, event):
        # All progress messages share the "dns_hub[<thread name>]: " prefix.
        print("dns_hub[{0}]: {1}".format(self.name, event))

    def run(self):
        self._log("before serve_forever()")
        self.server.serve_forever()
        self._log("after serve_forever()")

    def stop(self):
        srv = self.server
        self._log("before shutdown()")
        srv.shutdown()
        self._log("after shutdown()")
        srv.server_close()
2020-03-11 18:55:33 +03:00
class UDPV6Server(sserver.UDPServer):
    """UDP socketserver bound to an IPv6 (AF_INET6) socket."""
    address_family = socket.AF_INET6


class UDPV4Server(sserver.UDPServer):
    """UDP socketserver bound to an IPv4 (AF_INET) socket."""
    address_family = socket.AF_INET
2019-01-30 03:24:45 +03:00
2018-01-12 17:53:03 +03:00
def _parse_realm_mappings(arg):
    """Parse a 'REALM=IP[,REALM=IP,...]' string into an ordered dict.

    :param arg: comma-separated REALM=IP pairs, as given on the command line.
    :return: collections.OrderedDict of realm -> PDC IP in command-line
        order (a repeated realm keeps its last IP).
    :raises ValueError: if an entry lacks '=' or has an empty realm or IP.
    """
    mappings = collections.OrderedDict()
    for entry in arg.split(','):
        realm, sep, ip = entry.partition('=')
        if not sep or not realm or not ip:
            raise ValueError("invalid realm mapping %r, expected REALM=IP"
                             % entry)
        mappings[realm] = ip
    return mappings


def main():
    """Run the DNS hub until stdin polls ready or TIMEOUT seconds elapse.

    Command line: TIMEOUT LISTENADDRESS[,LISTENADDRESS,...] MAPPING[,MAPPING,...]
    where each MAPPING is REALM=IP. Exits with status 1 on bad arguments.
    """
    if len(sys.argv) < 4:
        print("Usage: dns_hub.py TIMEOUT LISTENADDRESS[,LISTENADDRESS,...] MAPPING[,MAPPING,...]",
              file=sys.stderr)
        sys.exit(1)

    # TIMEOUT is given in seconds; poll() below takes milliseconds.
    timeout = int(sys.argv[1]) * 1000
    timeout = min(timeout, 2**31 - 1)  # poll with 32-bit int can't take more

    # we pass in the listen addresses as a comma-separated string.
    listenaddresses = sys.argv[2].split(',')

    # we pass in the realm-to-IP mappings as a comma-separated key=value
    # string. Convert this back into a dictionary that the DnsHandler can use
    try:
        realm_mappings = _parse_realm_mappings(sys.argv[3])
    except ValueError as err:
        print("dns_hub: %s" % err, file=sys.stderr)
        sys.exit(1)

    def prepare_server_thread(listenaddress, realm_mappings):
        """Create (but do not start) a DNS server thread on listenaddress:53."""
        flags = socket.AddressInfo.AI_NUMERICHOST
        flags |= socket.AddressInfo.AI_NUMERICSERV
        flags |= socket.AddressInfo.AI_PASSIVE
        addr_info = socket.getaddrinfo(listenaddress, 53,
                                       type=socket.SocketKind.SOCK_DGRAM,
                                       flags=flags)
        # Explicit check instead of assert, which is stripped under -O.
        if len(addr_info) != 1:
            raise ValueError("listen address %r resolved to %d addresses, "
                             "expected exactly one"
                             % (listenaddress, len(addr_info)))
        family, sockaddr = addr_info[0][0], addr_info[0][4]
        if family == socket.AddressFamily.AF_INET6:
            server = UDPV6Server(sockaddr, DnsHandler)
        else:
            server = UDPV4Server(sockaddr, DnsHandler)
        # DnsHandler reads the mappings via self.server.realm_to_ip_mappings.
        server.realm_to_ip_mappings = realm_mappings
        return server_thread(server, name="UDP[%s]" % listenaddress)

    print("dns_hub will proxy DNS requests for the following realms:")
    for realm, ip in realm_mappings.items():
        print("    {0} ==> {1}".format(realm, ip))

    print("dns_hub will listen on the following UDP addresses:")
    threads = []
    for listenaddress in listenaddresses:
        print("    %s" % listenaddress)
        threads.append(prepare_server_thread(listenaddress, realm_mappings))

    for t in threads:
        t.start()

    # Block until stdin polls ready (e.g. the parent closes it) or timeout.
    p = select.poll()
    stdin = sys.stdin.fileno()
    p.register(stdin, select.POLLIN)
    p.poll(timeout)
    print("dns_hub: after poll()")

    for t in threads:
        t.stop()
    for t in threads:
        t.join()

    print("dns_hub: before exit()")
    sys.exit(0)


if __name__ == "__main__":
    main()