1
0
mirror of https://github.com/altlinux/gpupdate.git synced 2025-01-09 21:17:52 +03:00

Improved search for the required domain

This commit is contained in:
Valery Sinelnikov 2024-02-20 12:59:49 +04:00
parent 446fa532db
commit aeab315c3d

View File

@ -52,9 +52,8 @@ class smbcreds (smbopts):
self.credopts = options.CredentialsOptions(self.parser)
self.creds = self.credopts.get_credentials(self.lp, fallback_machine=True)
self.set_dc(dc_fqdn)
self.dc_site = SiteDomainScanner(self.creds, self.selected_dc).select_servers()
if self.dc_site is not None:
self.selected_dc = self.dc_site
self.dc_site_servers = SiteDomainScanner(self.creds, self.lp, self.selected_dc).select_servers()
def get_dc(self):
return self.selected_dc
@ -129,7 +128,11 @@ class smbcreds (smbopts):
gpos = self.get_gpos(username)
list_selected_dc = set()
list_selected_dc.add(self.selected_dc)
if len(self.dc_site_servers) > 1:
list_selected_dc.add(self.dc_site_servers.pop(0))
else:
list_selected_dc.add(self.selected_dc)
while list_selected_dc:
logdata = dict()
@ -143,16 +146,26 @@ class smbcreds (smbopts):
except NTSTATUSError as smb_exc:
logdata['smb_exc'] = str(smb_exc)
if not check_scroll_enabled():
log('F1', logdata)
raise smb_exc
self.selected_dc = get_dc_hostname(self.creds, self.lp)
if self.selected_dc not in list_selected_dc:
logdata['action'] = 'Search another dc'
log('W11', logdata)
list_selected_dc.add(self.selected_dc)
if self.dc_site_servers:
list_selected_dc.add(self.dc_site_servers.pop())
else:
log('F1', logdata)
raise smb_exc
else:
log('F1', logdata)
raise smb_exc
if len(self.dc_site_servers) > 1:
list_selected_dc.add(self.dc_site_servers.pop(0))
else:
self.selected_dc = get_dc_hostname(self.creds, self.lp)
if self.selected_dc not in list_selected_dc:
logdata['action'] = 'Search another dc'
log('W11', logdata)
list_selected_dc.add(self.selected_dc)
else:
if self.dc_site_servers:
list_selected_dc.add(self.dc_site_servers.pop())
else:
log('F1', logdata)
raise smb_exc
except Exception as exc:
logdata['exc'] = str(exc)
log('F1', logdata)
@ -161,13 +174,34 @@ class smbcreds (smbopts):
class SiteDomainScanner:
def __init__(self, smbcreds, dc):
self.smbcreds = smbcreds
parser = optparse.OptionParser(None)
sambaopts = options.SambaOptions(parser)
lp = sambaopts.get_loadparm()
def __init__(self, smbcreds, lp, dc):
self.samdb = SamDB(url='ldap://{}'.format(dc), session_info=system_session(), credentials=smbcreds, lp=lp)
@staticmethod
def _get_ldb_single_message_attr(ldb_message, attr_name, encoding='utf8'):
if attr_name in ldb_message:
return ldb_message[attr_name][0].decode(encoding)
else:
return None
@staticmethod
def _get_ldb_single_result_attr(ldb_result, attr_name, encoding='utf8'):
if len(ldb_result) == 1 and attr_name in ldb_result[0]:
return ldb_result[0][attr_name][0].decode(encoding)
else:
return None
def _get_server_hostname(self, ds_service_name):
ds_service_name_dn = ldb.Dn(self.samdb, ds_service_name)
server_dn = ds_service_name_dn.parent()
res = self.samdb.search(server_dn, scope=ldb.SCOPE_BASE)
return self._get_ldb_single_result_attr(res, 'dNSHostName')
def _search_pdc_emulator(self):
res = self.samdb.search(self.samdb.domain_dn(), scope=ldb.SCOPE_BASE)
pdc_settings_object = self._get_ldb_single_result_attr(res, 'fSMORoleOwner')
return self._get_server_hostname(pdc_settings_object)
def get_ip_addresses(self):
interface_list = netifaces.interfaces()
addresses = []
@ -184,7 +218,7 @@ class SiteDomainScanner:
config_dn = self.samdb.get_config_basedn()
subnet_dn.add_base(config_dn)
res = self.samdb.search(subnet_dn, ldb.SCOPE_ONELEVEL, expression='objectClass=subnet', attrs=['cn', 'siteObject'])
subnets = {ipaddress.ip_network(msg['cn'][0].decode('utf8')): msg['siteObject'][0].decode('utf8') for msg in res}
subnets = {ipaddress.ip_network(self._get_ldb_single_message_attr(msg, 'cn')): self._get_ldb_single_message_attr(msg, 'siteObject') for msg in res}
return subnets
def get_ad_site_servers(self, site):
@ -192,8 +226,8 @@ class SiteDomainScanner:
site_dn = ldb.Dn(self.samdb, site)
servers_dn.add_base(site_dn)
res = self.samdb.search(servers_dn, ldb.SCOPE_ONELEVEL, expression='objectClass=server', attrs=['dNSHostName'])
servers = [msg['dNSHostName'][0].decode('utf8') for msg in res]
return servers[0] if servers else None
servers = [self._get_ldb_single_message_attr(msg, 'dNSHostName') for msg in res]
return servers
def check_ip_in_subnets(self, ip_addresses, subnets_sites):
return next((subnets_sites[subnet] for subnet in subnets_sites.keys()
@ -205,12 +239,18 @@ class SiteDomainScanner:
subnets_sites = self.get_ad_subnets_sites()
our_site = self.check_ip_in_subnets(ip_addresses, subnets_sites)
pdc_emulator = self._search_pdc_emulator()
if our_site:
servers = self.get_ad_site_servers(our_site)
if pdc_emulator in servers:
pdc_index = servers.index(pdc_emulator)
servers.insert(0, servers.pop(pdc_index))
else:
servers.append(pdc_emulator)
return servers
else:
return None
return [pdc_emulator]
except Exception as e:
return None