2019-02-15 12:56:07 +03:00
#!/usr/bin/env python3
2018-01-12 17:53:03 +03:00
#
# Unix SMB/CIFS implementation.
# Copyright (C) Volker Lendecke 2017
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
2019-01-30 05:10:45 +03:00
# Used by selftest to proxy DNS queries to the correct testenv DC.
# See selftest/target/README for more details.
2018-01-12 17:53:03 +03:00
# Based on the EchoServer example from python docs
import threading
import sys
import select
import socket
2019-01-23 11:34:40 +03:00
import time
2018-01-12 17:53:03 +03:00
from samba . dcerpc import dns
import samba . ndr as ndr
if sys . version_info [ 0 ] < 3 :
import SocketServer
sserver = SocketServer
else :
import socketserver
sserver = socketserver
2019-02-03 23:28:07 +03:00
# Socket timeout, in seconds, applied to queries forwarded to a DC.
# DnsHandler.handle() also logs any request that takes longer than one
# fifth of this value.
DNS_REQUEST_TIMEOUT = 10
2019-01-30 03:24:45 +03:00
2018-01-12 17:53:03 +03:00
class DnsHandler(sserver.BaseRequestHandler):
    """UDP request handler that proxies DNS queries to a testenv DC.

    The name in the first question of each query decides what happens:
    a handful of magic names used by the tests get canned behaviour,
    names ending in a known realm are forwarded to the address mapped to
    that realm, and everything else is answered with NXDOMAIN.

    The realm mapping is read from ``self.server.realm_to_ip_mappings``,
    which must be set on the server object before requests are served.
    """

    # Reverse lookup (numeric value -> constant name) for DNS_QTYPE_*.
    dns_qtype_strings = dict((v, k) for k, v in vars(dns).items()
                             if k.startswith('DNS_QTYPE_'))

    def dns_qtype_string(self, qtype):
        """Return a readable name for a qtype code.

        Codes without a matching DNS_QTYPE_* constant yield a placeholder
        instead of raising, so debug logging cannot abort a request.
        """
        return self.dns_qtype_strings.get(qtype,
                                          'DNS_QTYPE_UNKNOWN(%s)' % qtype)

    # Reverse lookup (numeric value -> constant name) for DNS_RCODE_*.
    dns_rcode_strings = dict((v, k) for k, v in vars(dns).items()
                             if k.startswith('DNS_RCODE_'))

    def dns_rcode_string(self, rcode):
        """Return a readable name for a response code.

        Codes without a matching DNS_RCODE_* constant yield a placeholder
        instead of raising, so debug logging cannot abort a request.
        """
        return self.dns_rcode_strings.get(rcode,
                                          'DNS_RCODE_UNKNOWN(%s)' % rcode)

    def dns_transaction_udp(self, packet, host):
        """Send a DNS query to ``host`` on UDP port 53 and return the reply.

        :param packet: the query, packed with ndr_pack() before sending
        :param host: address of the server to query
        :return: the reply unpacked as a dns.name_packet
        :raises socket.error: on any socket failure, including the
            DNS_REQUEST_TIMEOUT expiring; the error is printed and then
            re-raised to the caller
        """
        s = None
        try:
            send_packet = ndr.ndr_pack(packet)
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
            s.settimeout(DNS_REQUEST_TIMEOUT)
            s.connect((host, 53))
            s.sendall(send_packet, 0)
            recv_packet = s.recv(2048, 0)
            return ndr.ndr_unpack(dns.name_packet, recv_packet)
        except socket.error as err:
            print("Error sending to host %s for name %s: %s\n" %
                  (host, packet.questions[0].name, err.errno))
            raise
        finally:
            # Runs on both the return and the raise path.
            if s is not None:
                s.close()

    def get_pdc_ipv4_addr(self, lookup_name):
        """Map a DNS name to the IPv4 address of the PDC for its testenv.

        :param lookup_name: the name being looked up
        :return: the address mapped to the longest realm that
            ``lookup_name`` ends with, or None if no realm matches
        """
        realm_to_ip_mappings = self.server.realm_to_ip_mappings

        # Try the longest realms first so the most specific match wins.
        testenv_realms = sorted(realm_to_ip_mappings.keys(), key=len,
                                reverse=True)
        for realm in testenv_realms:
            if lookup_name.endswith(realm):
                # return the corresponding IP address for this realm's PDC
                return realm_to_ip_mappings[realm]

        return None

    def forwarder(self, name):
        """Decide how a query for ``name`` should be handled.

        :return: 'ignore', 'fail' or 'torture' for the special-case names
            used by tests, otherwise the result of get_pdc_ipv4_addr()
            for the lower-cased name (an address, or None)
        """
        lname = name.lower()

        # check for special cases used by tests (e.g. dns_forwarder.py)
        if lname.endswith('an-address-that-will-not-resolve'):
            return 'ignore'
        if lname.endswith('dsfsdfs'):
            return 'fail'
        # Drop the last two characters before matching, to
        # CATCH TORTURE100, TORTURE101, ...
        if lname.endswith("torture1", 0, len(lname) - 2):
            return 'torture'
        if lname.endswith('_none_.example.com'):
            return 'torture'
        if lname.endswith('torturedom.samba.example.com'):
            return 'torture'

        # return the testenv PDC matching the realm being requested
        return self.get_pdc_ipv4_addr(lname)

    def handle(self):
        """Answer one UDP DNS request.

        'ignore' names get no reply at all, 'fail' names get SERVFAIL,
        'torture' names and names with no matching realm get NXDOMAIN,
        and everything else is forwarded with dns_transaction_udp().
        Requests taking longer than DNS_REQUEST_TIMEOUT / 5 seconds are
        logged before the reply is sent.
        """
        start = time.monotonic()
        data, sock = self.request
        query = ndr.ndr_unpack(dns.name_packet, data)
        name = query.questions[0].name
        forwarder = self.forwarder(name)
        response = None

        # Compare by value: identity checks against string literals only
        # work by accident of interning.
        if forwarder == 'ignore':
            return
        elif forwarder == 'fail':
            # Leave response unset; the SERVFAIL reply is built below.
            pass
        elif forwarder in ['torture', None]:
            # The query packet is reused as the reply.
            response = query
            response.operation |= dns.DNS_FLAG_REPLY
            response.operation |= dns.DNS_FLAG_RECURSION_AVAIL
            response.operation |= dns.DNS_RCODE_NXDOMAIN
        else:
            response = self.dns_transaction_udp(query, forwarder)

        if response is None:
            response = query
            response.operation |= dns.DNS_FLAG_REPLY
            response.operation |= dns.DNS_FLAG_RECURSION_AVAIL
            response.operation |= dns.DNS_RCODE_SERVFAIL

        send_packet = ndr.ndr_pack(response)

        end = time.monotonic()
        tdiff = end - start
        errcode = response.operation & dns.DNS_RCODE
        # Only slow requests are worth logging.
        debug = tdiff > (DNS_REQUEST_TIMEOUT / 5)
        if debug:
            print("dns_hub: forwarder[%s] client[%s] name[%s][%s] %s "
                  "response.operation[0x%x] tdiff[%s]\n" %
                  (forwarder, self.client_address, name,
                   self.dns_qtype_string(query.questions[0].question_type),
                   self.dns_rcode_string(errcode), response.operation, tdiff))

        try:
            sock.sendto(send_packet, self.client_address)
        except socket.error as err:
            print("dns_hub: Error sending response to client[%s] "
                  "for name[%s] tdiff[%s]: %s\n" %
                  (self.client_address, name, tdiff, err))
2019-01-18 23:14:28 +03:00
2018-01-12 17:53:03 +03:00
class server_thread(threading.Thread):
    """Thread whose body is the wrapped server's serve_forever() loop."""

    def __init__(self, server):
        """Remember ``server``; nothing is started until start() is called."""
        super(server_thread, self).__init__()
        self.server = server

    def run(self):
        """Serve requests until serve_forever() returns, then log that."""
        srv = self.server
        srv.serve_forever()
        print("dns_hub: after serve_forever()")
2019-01-30 03:24:45 +03:00
2018-01-12 17:53:03 +03:00
def main():
    """Run the DNS hub until stdin has an event or the timeout expires.

    Command-line arguments:
      argv[1]  timeout in seconds (converted to milliseconds for poll())
      argv[2]  address to bind the UDP server to, on port 53
      argv[3]  realm-to-IP mappings as a comma-separated list of
               realm=ip pairs
    """
    timeout = int(sys.argv[1]) * 1000
    timeout = min(timeout, 2**31 - 1)  # poll with 32-bit int can't take more
    host = sys.argv[2]

    server = sserver.UDPServer((host, 53), DnsHandler)
    try:
        # we pass in the realm-to-IP mappings as a comma-separated key=value
        # string. Convert this back into a dictionary that the DnsHandler
        # can use. Split on the first '=' only, so a value containing '='
        # does not break the dict() construction.
        realm_mapping = dict(kv.split('=', 1)
                             for kv in sys.argv[3].split(','))
        server.realm_to_ip_mappings = realm_mapping

        print("dns_hub will proxy DNS requests for the following realms:")
        for realm, ip in server.realm_to_ip_mappings.items():
            print("{0} ==> {1}".format(realm, ip))

        t = server_thread(server)
        t.start()
        try:
            p = select.poll()
            stdin = sys.stdin.fileno()
            p.register(stdin, select.POLLIN)
            p.poll(timeout)
            print("dns_hub: after poll()")
        finally:
            # Always stop the (non-daemon) server thread, even if polling
            # raised, so the process is able to exit.
            server.shutdown()
            t.join()
    finally:
        # shutdown() only stops serve_forever(); release the socket too.
        server.server_close()

    print("dns_hub: before exit()")
    sys.exit(0)


main()