1
0
mirror of https://github.com/samba-team/samba.git synced 2025-01-25 06:04:04 +03:00
Andreas Schneider a15b8611ce python:samba:emulate: Fix code spelling
Signed-off-by: Andreas Schneider <asn@samba.org>
Reviewed-by: Andrew Bartlett <abartlet@samba.org>
2023-06-23 13:44:31 +00:00

2416 lines
85 KiB
Python

# -*- encoding: utf-8 -*-
# Samba traffic replay and learning
#
# Copyright (C) Catalyst IT Ltd. 2017
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
import time
import os
import random
import json
import math
import sys
import signal
from errno import ECHILD, ESRCH
from collections import OrderedDict, Counter, defaultdict, namedtuple
from dns.resolver import query as dns_query
from samba.emulate import traffic_packets
from samba.samdb import SamDB
import ldb
from ldb import LdbError
from samba.dcerpc import ClientConnection
from samba.dcerpc import security, drsuapi, lsa
from samba.dcerpc import netlogon
from samba.dcerpc.netlogon import netr_Authenticator
from samba.dcerpc import srvsvc
from samba.dcerpc import samr
from samba.drs_utils import drs_DsBind
import traceback
from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
from samba.auth import system_session
from samba.dsdb import (
UF_NORMAL_ACCOUNT,
UF_SERVER_TRUST_ACCOUNT,
UF_TRUSTED_FOR_DELEGATION,
UF_WORKSTATION_TRUST_ACCOUNT
)
from samba.dcerpc.misc import SEC_CHAN_BDC
from samba import gensec
from samba import sd_utils
from samba.common import get_string
from samba.logger import get_samba_logger
import bisect
CURRENT_MODEL_VERSION = 2 # save as this
REQUIRED_MODEL_VERSION = 2 # load accepts this or greater
SLEEP_OVERHEAD = 3e-4
# we don't use None, because it complicates [de]serialisation
NON_PACKET = '-'
CLIENT_CLUES = {
('dns', '0'): 1.0, # query
('smb', '0x72'): 1.0, # Negotiate protocol
('ldap', '0'): 1.0, # bind
('ldap', '3'): 1.0, # searchRequest
('ldap', '2'): 1.0, # unbindRequest
('cldap', '3'): 1.0,
('dcerpc', '11'): 1.0, # bind
('dcerpc', '14'): 1.0, # Alter_context
('nbns', '0'): 1.0, # query
}
SERVER_CLUES = {
('dns', '1'): 1.0, # response
('ldap', '1'): 1.0, # bind response
('ldap', '4'): 1.0, # search result
('ldap', '5'): 1.0, # search done
('cldap', '5'): 1.0,
('dcerpc', '12'): 1.0, # bind_ack
('dcerpc', '13'): 1.0, # bind_nak
('dcerpc', '15'): 1.0, # Alter_context response
}
SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}
WAIT_SCALE = 10.0
WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
NO_WAIT_LOG_TIME_RANGE = (-10, -3)
# DEBUG_LEVEL can be changed by scripts with -d
DEBUG_LEVEL = 0
LOGGER = get_samba_logger(name=__name__)
def debug(level, msg, *args):
"""Print a formatted debug message to standard error.
:param level: The debug level, message will be printed if it is <= the
currently set debug level. The debug level can be set with
the -d option.
:param msg: The message to be logged, can contain C-Style format
specifiers
:param args: The parameters required by the format specifiers
"""
if level <= DEBUG_LEVEL:
if not args:
print(msg, file=sys.stderr)
else:
print(msg % tuple(args), file=sys.stderr)
def debug_lineno(*args):
""" Print an unformatted log message to stderr, containing the line number
"""
tb = traceback.extract_stack(limit=2)
print((" %s:" "\033[01;33m"
"%s " "\033[00m" % (tb[0][2], tb[0][1])), end=' ',
file=sys.stderr)
for a in args:
print(a, file=sys.stderr)
print(file=sys.stderr)
sys.stderr.flush()
def random_colour_print(seeds):
"""Return a function that prints a coloured line to stderr. The colour
of the line depends on a sort of hash of the integer arguments."""
if seeds:
s = 214
for x in seeds:
s += 17
s *= x
s %= 214
prefix = "\033[38;5;%dm" % (18 + s)
def p(*args):
if DEBUG_LEVEL > 0:
for a in args:
print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
else:
def p(*args):
if DEBUG_LEVEL > 0:
for a in args:
print(a, file=sys.stderr)
return p
class FakePacketError(Exception):
pass
class Packet(object):
"""Details of a network packet"""
__slots__ = ('timestamp',
'ip_protocol',
'stream_number',
'src',
'dest',
'protocol',
'opcode',
'desc',
'extra',
'endpoints')
def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
protocol, opcode, desc, extra):
self.timestamp = timestamp
self.ip_protocol = ip_protocol
self.stream_number = stream_number
self.src = src
self.dest = dest
self.protocol = protocol
self.opcode = opcode
self.desc = desc
self.extra = extra
if self.src < self.dest:
self.endpoints = (self.src, self.dest)
else:
self.endpoints = (self.dest, self.src)
@classmethod
def from_line(cls, line):
fields = line.rstrip('\n').split('\t')
(timestamp,
ip_protocol,
stream_number,
src,
dest,
protocol,
opcode,
desc) = fields[:8]
extra = fields[8:]
timestamp = float(timestamp)
src = int(src)
dest = int(dest)
return cls(timestamp, ip_protocol, stream_number, src, dest,
protocol, opcode, desc, extra)
def as_summary(self, time_offset=0.0):
"""Format the packet as a traffic_summary line.
"""
extra = '\t'.join(self.extra)
t = self.timestamp + time_offset
return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
(t,
self.ip_protocol,
self.stream_number or '',
self.src,
self.dest,
self.protocol,
self.opcode,
self.desc,
extra))
def __str__(self):
return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
(self.timestamp, self.src, self.dest, self.ip_protocol or '-',
self.stream_number, self.protocol, self.opcode, self.desc,
('«' + ' '.join(self.extra) + '»' if self.extra else '')))
def __repr__(self):
return "<Packet @%s>" % self
def copy(self):
return self.__class__(self.timestamp,
self.ip_protocol,
self.stream_number,
self.src,
self.dest,
self.protocol,
self.opcode,
self.desc,
self.extra)
def as_packet_type(self):
t = '%s:%s' % (self.protocol, self.opcode)
return t
def client_score(self):
"""A positive number means we think it is a client; a negative number
means we think it is a server. Zero means no idea. range: -1 to 1.
"""
key = (self.protocol, self.opcode)
if key in CLIENT_CLUES:
return CLIENT_CLUES[key]
if key in SERVER_CLUES:
return -SERVER_CLUES[key]
return 0.0
def play(self, conversation, context):
"""Send the packet over the network, if required.
Some packets are ignored, i.e. for protocols not handled,
server response messages, or messages that are generated by the
protocol layer associated with other packets.
"""
fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
try:
fn = getattr(traffic_packets, fn_name)
except AttributeError as e:
print("Conversation(%s) Missing handler %s" %
(conversation.conversation_id, fn_name),
file=sys.stderr)
return
# Don't display a message for kerberos packets, they're not directly
# generated they're used to indicate kerberos should be used
if self.protocol != "kerberos":
debug(2, "Conversation(%s) Calling handler %s" %
(conversation.conversation_id, fn_name))
start = time.time()
try:
if fn(self, conversation, context):
# Only collect timing data for functions that generate
# network traffic, or fail
end = time.time()
duration = end - start
print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
(end, conversation.conversation_id, self.protocol,
self.opcode, duration))
except Exception as e:
end = time.time()
duration = end - start
print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
(end, conversation.conversation_id, self.protocol,
self.opcode, duration, e))
def __cmp__(self, other):
return self.timestamp - other.timestamp
def is_really_a_packet(self, missing_packet_stats=None):
return is_a_real_packet(self.protocol, self.opcode)
def is_a_real_packet(protocol, opcode):
"""Is the packet one that can be ignored?
If so removing it will have no effect on the replay
"""
if protocol in SKIPPED_PROTOCOLS:
# Ignore any packets for the protocols we're not interested in.
return False
if protocol == "ldap" and opcode == '':
# skip ldap continuation packets
return False
fn_name = 'packet_%s_%s' % (protocol, opcode)
fn = getattr(traffic_packets, fn_name, None)
if fn is None:
LOGGER.debug("missing packet %s" % fn_name, file=sys.stderr)
return False
if fn is traffic_packets.null_packet:
return False
return True
def is_a_traffic_generating_packet(protocol, opcode):
"""Return true if a packet generates traffic in its own right. Some of
these will generate traffic in certain contexts (e.g. ldap unbind
after a bind) but not if the conversation consists only of these packets.
"""
if protocol == 'wait':
return False
if (protocol, opcode) in (
('kerberos', ''),
('ldap', '2'),
('dcerpc', '15'),
('dcerpc', '16')):
return False
return is_a_real_packet(protocol, opcode)
class ReplayContext(object):
"""State/Context for a conversation between an simulated client and a
server. Some of the context is shared amongst all conversations
and should be generated before the fork, while other context is
specific to a particular conversation and should be generated
*after* the fork, in generate_process_local_config().
"""
def __init__(self,
server=None,
lp=None,
creds=None,
total_conversations=None,
badpassword_frequency=None,
prefer_kerberos=None,
tempdir=None,
statsdir=None,
ou=None,
base_dn=None,
domain=os.environ.get("DOMAIN"),
domain_sid=None,
instance_id=None):
self.server = server
self.netlogon_connection = None
self.creds = creds
self.lp = lp
if prefer_kerberos:
self.kerberos_state = MUST_USE_KERBEROS
else:
self.kerberos_state = DONT_USE_KERBEROS
self.ou = ou
self.base_dn = base_dn
self.domain = domain
self.statsdir = statsdir
self.global_tempdir = tempdir
self.domain_sid = domain_sid
self.realm = lp.get('realm')
self.instance_id = instance_id
# Bad password attempt controls
self.badpassword_frequency = badpassword_frequency
self.last_lsarpc_bad = False
self.last_lsarpc_named_bad = False
self.last_simple_bind_bad = False
self.last_bind_bad = False
self.last_srvsvc_bad = False
self.last_drsuapi_bad = False
self.last_netlogon_bad = False
self.last_samlogon_bad = False
self.total_conversations = total_conversations
self.generate_ldap_search_tables()
def generate_ldap_search_tables(self):
session = system_session()
db = SamDB(url="ldap://%s" % self.server,
session_info=session,
credentials=self.creds,
lp=self.lp)
res = db.search(db.domain_dn(),
scope=ldb.SCOPE_SUBTREE,
controls=["paged_results:1:1000"],
attrs=['dn'])
# find a list of dns for each pattern
# e.g. CN,CN,CN,DC,DC
dn_map = {}
attribute_clue_map = {
'invocationId': []
}
for r in res:
dn = str(r.dn)
pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
dns = dn_map.setdefault(pattern, [])
dns.append(dn)
if dn.startswith('CN=NTDS Settings,'):
attribute_clue_map['invocationId'].append(dn)
# extend the map in case we are working with a different
# number of DC components.
# for k, v in self.dn_map.items():
# print >>sys.stderr, k, len(v)
for k in list(dn_map.keys()):
if k[-3:] != ',DC':
continue
p = k[:-3]
while p[-3:] == ',DC':
p = p[:-3]
for i in range(5):
p += ',DC'
if p != k and p in dn_map:
print('dn_map collision %s %s' % (k, p),
file=sys.stderr)
continue
dn_map[p] = dn_map[k]
self.dn_map = dn_map
self.attribute_clue_map = attribute_clue_map
# pre-populate DN-based search filters (it's simplest to generate them
# once, when the test starts). These are used by guess_search_filter()
# to avoid full-scans
self.search_filters = {}
# lookup all the GPO DNs
res = db.search(db.domain_dn(), scope=ldb.SCOPE_SUBTREE, attrs=['dn'],
expression='(objectclass=groupPolicyContainer)')
gpos_by_dn = "".join("(distinguishedName={0})".format(msg['dn']) for msg in res)
# a search for the 'gPCFileSysPath' attribute is probably a GPO search
# (as per the MS-GPOL spec) which searches for GPOs by DN
self.search_filters['gPCFileSysPath'] = "(|{0})".format(gpos_by_dn)
# likewise, a search for gpLink is probably the Domain SOM search part
# of the MS-GPOL, in which case it's looking up a few OUs by DN
ou_str = ""
for ou in ["Domain Controllers,", "traffic_replay,", ""]:
ou_str += "(distinguishedName={0}{1})".format(ou, db.domain_dn())
self.search_filters['gpLink'] = "(|{0})".format(ou_str)
# The CEP Web Service can query the AD DC to get pKICertificateTemplate
# objects (as per MS-WCCE)
self.search_filters['pKIExtendedKeyUsage'] = \
'(objectCategory=pKICertificateTemplate)'
# assume that anything querying the usnChanged is some kind of
# synchronization tool, e.g. AD Change Detection Connector
res = db.search('', scope=ldb.SCOPE_BASE, attrs=['highestCommittedUSN'])
self.search_filters['usnChanged'] = \
'(usnChanged>={0})'.format(res[0]['highestCommittedUSN'])
# The traffic_learner script doesn't preserve the LDAP search filter, and
# having no filter can result in a full DB scan. This is costly for a large
# DB, and not necessarily representative of real world traffic. As there
# several standard LDAP queries that get used by AD tools, we can apply
# some logic and guess what the search filter might have been originally.
def guess_search_filter(self, attrs, dn_sig, dn):
# there are some standard spec-based searches that query fairly unique
# attributes. Check if the search is likely one of these
for key in self.search_filters.keys():
if key in attrs:
return self.search_filters[key]
# if it's the top-level domain, assume we're looking up a single user,
# e.g. like powershell Get-ADUser or a similar tool
if dn_sig == 'DC,DC':
random_user_id = random.random() % self.total_conversations
account_name = user_name(self.instance_id, random_user_id)
return '(&(sAMAccountName=%s)(objectClass=user))' % account_name
# otherwise just return everything in the sub-tree
return '(objectClass=*)'
def generate_process_local_config(self, account, conversation):
self.ldap_connections = []
self.dcerpc_connections = []
self.lsarpc_connections = []
self.lsarpc_connections_named = []
self.drsuapi_connections = []
self.srvsvc_connections = []
self.samr_contexts = []
self.netbios_name = account.netbios_name
self.machinepass = account.machinepass
self.username = account.username
self.userpass = account.userpass
self.tempdir = mk_masked_dir(self.global_tempdir,
'conversation-%d' %
conversation.conversation_id)
self.lp.set("private dir", self.tempdir)
self.lp.set("lock dir", self.tempdir)
self.lp.set("state directory", self.tempdir)
self.lp.set("tls verify peer", "no_check")
self.remoteAddress = "/root/ncalrpc_as_system"
self.samlogon_dn = ("cn=%s,%s" %
(self.netbios_name, self.ou))
self.user_dn = ("cn=%s,%s" %
(self.username, self.ou))
self.generate_machine_creds()
self.generate_user_creds()
def with_random_bad_credentials(self, f, good, bad, failed_last_time):
"""Execute the supplied logon function, randomly choosing the
bad credentials.
Based on the frequency in badpassword_frequency randomly perform the
function with the supplied bad credentials.
If run with bad credentials, the function is re-run with the good
credentials.
failed_last_time is used to prevent consecutive bad credential
attempts. So the over all bad credential frequency will be lower
than that requested, but not significantly.
"""
if not failed_last_time:
if (self.badpassword_frequency and
random.random() < self.badpassword_frequency):
try:
f(bad)
except Exception:
# Ignore any exceptions as the operation may fail
# as it's being performed with bad credentials
pass
failed_last_time = True
else:
failed_last_time = False
result = f(good)
return (result, failed_last_time)
def generate_user_creds(self):
"""Generate the conversation specific user Credentials.
Each Conversation has an associated user account used to simulate
any non Administrative user traffic.
Generates user credentials with good and bad passwords and ldap
simple bind credentials with good and bad passwords.
"""
self.user_creds = Credentials()
self.user_creds.guess(self.lp)
self.user_creds.set_workstation(self.netbios_name)
self.user_creds.set_password(self.userpass)
self.user_creds.set_username(self.username)
self.user_creds.set_domain(self.domain)
self.user_creds.set_kerberos_state(self.kerberos_state)
self.user_creds_bad = Credentials()
self.user_creds_bad.guess(self.lp)
self.user_creds_bad.set_workstation(self.netbios_name)
self.user_creds_bad.set_password(self.userpass[:-4])
self.user_creds_bad.set_username(self.username)
self.user_creds_bad.set_kerberos_state(self.kerberos_state)
# Credentials for ldap simple bind.
self.simple_bind_creds = Credentials()
self.simple_bind_creds.guess(self.lp)
self.simple_bind_creds.set_workstation(self.netbios_name)
self.simple_bind_creds.set_password(self.userpass)
self.simple_bind_creds.set_username(self.username)
self.simple_bind_creds.set_gensec_features(
self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
self.simple_bind_creds.set_kerberos_state(self.kerberos_state)
self.simple_bind_creds.set_bind_dn(self.user_dn)
self.simple_bind_creds_bad = Credentials()
self.simple_bind_creds_bad.guess(self.lp)
self.simple_bind_creds_bad.set_workstation(self.netbios_name)
self.simple_bind_creds_bad.set_password(self.userpass[:-4])
self.simple_bind_creds_bad.set_username(self.username)
self.simple_bind_creds_bad.set_gensec_features(
self.simple_bind_creds_bad.get_gensec_features() |
gensec.FEATURE_SEAL)
self.simple_bind_creds_bad.set_kerberos_state(self.kerberos_state)
self.simple_bind_creds_bad.set_bind_dn(self.user_dn)
def generate_machine_creds(self):
"""Generate the conversation specific machine Credentials.
Each Conversation has an associated machine account.
Generates machine credentials with good and bad passwords.
"""
self.machine_creds = Credentials()
self.machine_creds.guess(self.lp)
self.machine_creds.set_workstation(self.netbios_name)
self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
self.machine_creds.set_password(self.machinepass)
self.machine_creds.set_username(self.netbios_name + "$")
self.machine_creds.set_domain(self.domain)
self.machine_creds.set_kerberos_state(self.kerberos_state)
self.machine_creds_bad = Credentials()
self.machine_creds_bad.guess(self.lp)
self.machine_creds_bad.set_workstation(self.netbios_name)
self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
self.machine_creds_bad.set_password(self.machinepass[:-4])
self.machine_creds_bad.set_username(self.netbios_name + "$")
self.machine_creds_bad.set_kerberos_state(self.kerberos_state)
def get_matching_dn(self, pattern, attributes=None):
# If the pattern is an empty string, we assume ROOTDSE,
# Otherwise we try adding or removing DC suffixes, then
# shorter leading patterns until we hit one.
# e.g if there is no CN,CN,CN,CN,DC,DC
# we first try CN,CN,CN,CN,DC
# and CN,CN,CN,CN,DC,DC,DC
# then change to CN,CN,CN,DC,DC
# and as last resort we use the base_dn
attr_clue = self.attribute_clue_map.get(attributes)
if attr_clue:
return random.choice(attr_clue)
pattern = pattern.upper()
while pattern:
if pattern in self.dn_map:
return random.choice(self.dn_map[pattern])
# chop one off the front and try it all again.
pattern = pattern[3:]
return self.base_dn
def get_dcerpc_connection(self, new=False):
guid = '12345678-1234-abcd-ef00-01234567cffb' # RPC_NETLOGON UUID
if self.dcerpc_connections and not new:
return self.dcerpc_connections[-1]
c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
(guid, 1), self.lp)
self.dcerpc_connections.append(c)
return c
def get_srvsvc_connection(self, new=False):
if self.srvsvc_connections and not new:
return self.srvsvc_connections[-1]
def connect(creds):
return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
self.lp,
creds)
(c, self.last_srvsvc_bad) = \
self.with_random_bad_credentials(connect,
self.user_creds,
self.user_creds_bad,
self.last_srvsvc_bad)
self.srvsvc_connections.append(c)
return c
def get_lsarpc_connection(self, new=False):
if self.lsarpc_connections and not new:
return self.lsarpc_connections[-1]
def connect(creds):
binding_options = 'schannel,seal,sign'
return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
(self.server, binding_options),
self.lp,
creds)
(c, self.last_lsarpc_bad) = \
self.with_random_bad_credentials(connect,
self.machine_creds,
self.machine_creds_bad,
self.last_lsarpc_bad)
self.lsarpc_connections.append(c)
return c
def get_lsarpc_named_pipe_connection(self, new=False):
if self.lsarpc_connections_named and not new:
return self.lsarpc_connections_named[-1]
def connect(creds):
return lsa.lsarpc("ncacn_np:%s" % (self.server),
self.lp,
creds)
(c, self.last_lsarpc_named_bad) = \
self.with_random_bad_credentials(connect,
self.machine_creds,
self.machine_creds_bad,
self.last_lsarpc_named_bad)
self.lsarpc_connections_named.append(c)
return c
def get_drsuapi_connection_pair(self, new=False, unbind=False):
"""get a (drs, drs_handle) tuple"""
if self.drsuapi_connections and not new:
c = self.drsuapi_connections[-1]
return c
def connect(creds):
binding_options = 'seal'
binding_string = "ncacn_ip_tcp:%s[%s]" %\
(self.server, binding_options)
return drsuapi.drsuapi(binding_string, self.lp, creds)
(drs, self.last_drsuapi_bad) = \
self.with_random_bad_credentials(connect,
self.user_creds,
self.user_creds_bad,
self.last_drsuapi_bad)
(drs_handle, supported_extensions) = drs_DsBind(drs)
c = (drs, drs_handle)
self.drsuapi_connections.append(c)
return c
def get_ldap_connection(self, new=False, simple=False):
if self.ldap_connections and not new:
return self.ldap_connections[-1]
def simple_bind(creds):
"""
To run simple bind against Windows, we need to run
following commands in PowerShell:
Install-windowsfeature ADCS-Cert-Authority
Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
Restart-Computer
"""
return SamDB('ldaps://%s' % self.server,
credentials=creds,
lp=self.lp)
def sasl_bind(creds):
return SamDB('ldap://%s' % self.server,
credentials=creds,
lp=self.lp)
if simple:
(samdb, self.last_simple_bind_bad) = \
self.with_random_bad_credentials(simple_bind,
self.simple_bind_creds,
self.simple_bind_creds_bad,
self.last_simple_bind_bad)
else:
(samdb, self.last_bind_bad) = \
self.with_random_bad_credentials(sasl_bind,
self.user_creds,
self.user_creds_bad,
self.last_bind_bad)
self.ldap_connections.append(samdb)
return samdb
def get_samr_context(self, new=False):
if not self.samr_contexts or new:
self.samr_contexts.append(
SamrContext(self.server, lp=self.lp, creds=self.creds))
return self.samr_contexts[-1]
def get_netlogon_connection(self):
if self.netlogon_connection:
return self.netlogon_connection
def connect(creds):
return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
(self.server),
self.lp,
creds)
(c, self.last_netlogon_bad) = \
self.with_random_bad_credentials(connect,
self.machine_creds,
self.machine_creds_bad,
self.last_netlogon_bad)
self.netlogon_connection = c
return c
def guess_a_dns_lookup(self):
return (self.realm, 'A')
def get_authenticator(self):
auth = self.machine_creds.new_client_authenticator()
current = netr_Authenticator()
current.cred.data = [x if isinstance(x, int) else ord(x)
for x in auth["credential"]]
current.timestamp = auth["timestamp"]
subsequent = netr_Authenticator()
return (current, subsequent)
def write_stats(self, filename, **kwargs):
"""Write arbitrary key/value pairs to a file in our stats directory in
order for them to be picked up later by another process working out
statistics."""
filename = os.path.join(self.statsdir, filename)
f = open(filename, 'w')
for k, v in kwargs.items():
print("%s: %s" % (k, v), file=f)
f.close()
class SamrContext(object):
"""State/Context associated with a samr connection.
"""
def __init__(self, server, lp=None, creds=None):
self.connection = None
self.handle = None
self.domain_handle = None
self.domain_sid = None
self.group_handle = None
self.user_handle = None
self.rids = None
self.server = server
self.lp = lp
self.creds = creds
def get_connection(self):
if not self.connection:
self.connection = samr.samr(
"ncacn_ip_tcp:%s[seal]" % (self.server),
lp_ctx=self.lp,
credentials=self.creds)
return self.connection
def get_handle(self):
if not self.handle:
c = self.get_connection()
self.handle = c.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
return self.handle
class Conversation(object):
"""Details of a converation between a simulated client and a server."""
def __init__(self, start_time=None, endpoints=None, seq=(),
conversation_id=None):
self.start_time = start_time
self.endpoints = endpoints
self.packets = []
self.msg = random_colour_print(endpoints)
self.client_balance = 0.0
self.conversation_id = conversation_id
for p in seq:
self.add_short_packet(*p)
def __cmp__(self, other):
if self.start_time is None:
if other.start_time is None:
return 0
return -1
if other.start_time is None:
return 1
return self.start_time - other.start_time
def add_packet(self, packet):
"""Add a packet object to this conversation, making a local copy with
a conversation-relative timestamp."""
p = packet.copy()
if self.start_time is None:
self.start_time = p.timestamp
if self.endpoints is None:
self.endpoints = p.endpoints
if p.endpoints != self.endpoints:
raise FakePacketError("Conversation endpoints %s don't match"
"packet endpoints %s" %
(self.endpoints, p.endpoints))
p.timestamp -= self.start_time
if p.src == p.endpoints[0]:
self.client_balance -= p.client_score()
else:
self.client_balance += p.client_score()
if p.is_really_a_packet():
self.packets.append(p)
def add_short_packet(self, timestamp, protocol, opcode, extra,
client=True, skip_unused_packets=True):
"""Create a packet from a timestamp, and 'protocol:opcode' pair, and a
(possibly empty) list of extra data. If client is True, assume
this packet is from the client to the server.
"""
if skip_unused_packets and not is_a_real_packet(protocol, opcode):
return
src, dest = self.guess_client_server()
if not client:
src, dest = dest, src
key = (protocol, opcode)
desc = OP_DESCRIPTIONS.get(key, '')
ip_protocol = IP_PROTOCOLS.get(protocol, '06')
packet = Packet(timestamp - self.start_time, ip_protocol,
'', src, dest,
protocol, opcode, desc, extra)
# XXX we're assuming the timestamp is already adjusted for
# this conversation?
# XXX should we adjust client balance for guessed packets?
if packet.src == packet.endpoints[0]:
self.client_balance -= packet.client_score()
else:
self.client_balance += packet.client_score()
if packet.is_really_a_packet():
self.packets.append(packet)
def __str__(self):
return ("<Conversation %s %s starting %.3f %d packets>" %
(self.conversation_id, self.endpoints, self.start_time,
len(self.packets)))
__repr__ = __str__
def __iter__(self):
return iter(self.packets)
def __len__(self):
return len(self.packets)
def get_duration(self):
if len(self.packets) < 2:
return 0
return self.packets[-1].timestamp - self.packets[0].timestamp
def replay_as_summary_lines(self):
return [p.as_summary(self.start_time) for p in self.packets]
def replay_with_delay(self, start, context=None, account=None):
"""Replay the conversation at the right time.
(We're already in a fork)."""
# first we sleep until the first packet
t = self.start_time
now = time.time() - start
gap = t - now
sleep_time = gap - SLEEP_OVERHEAD
if sleep_time > 0:
time.sleep(sleep_time)
miss = (time.time() - start) - t
self.msg("starting %s [miss %.3f]" % (self, miss))
max_gap = 0.0
max_sleep_miss = 0.0
# packet times are relative to conversation start
p_start = time.time()
for p in self.packets:
now = time.time() - p_start
gap = now - p.timestamp
if gap > max_gap:
max_gap = gap
if gap < 0:
sleep_time = -gap - SLEEP_OVERHEAD
if sleep_time > 0:
time.sleep(sleep_time)
t = time.time() - p_start
if t - p.timestamp > max_sleep_miss:
max_sleep_miss = t - p.timestamp
p.play(self, context)
return max_gap, miss, max_sleep_miss
def guess_client_server(self, server_clue=None):
"""Have a go at deciding who is the server and who is the client.
returns (client, server)
"""
a, b = self.endpoints
if self.client_balance < 0:
return (a, b)
# in the absence of a clue, we will fall through to assuming
# the lowest number is the server (which is usually true).
if self.client_balance == 0 and server_clue == b:
return (a, b)
return (b, a)
def forget_packets_outside_window(self, s, e):
"""Prune any packets outside the time window we're interested in
:param s: start of the window
:param e: end of the window
"""
self.packets = [p for p in self.packets if s <= p.timestamp <= e]
self.start_time = self.packets[0].timestamp if self.packets else None
def renormalise_times(self, start_time):
"""Adjust the packet start times relative to the new start time."""
for p in self.packets:
p.timestamp -= start_time
if self.start_time is not None:
self.start_time -= start_time
class DnsHammer(Conversation):
"""A lightweight conversation that generates a lot of dns:0 packets on
the fly"""
def __init__(self, dns_rate, duration, query_file=None):
n = int(dns_rate * duration)
self.times = [random.uniform(0, duration) for i in range(n)]
self.times.sort()
self.rate = dns_rate
self.duration = duration
self.start_time = 0
self.query_choices = self._get_query_choices(query_file=query_file)
def __str__(self):
return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
(len(self.times), self.duration, self.rate))
def _get_query_choices(self, query_file=None):
"""
Read dns query choices from a file, or return default
rname may contain format string like `{realm}`
realm can be fetched from context.realm
"""
if query_file:
with open(query_file, 'r') as f:
text = f.read()
choices = []
for line in text.splitlines():
line = line.strip()
if line and not line.startswith('#'):
args = line.split(',')
assert len(args) == 4
choices.append(args)
return choices
else:
return [
(0, '{realm}', 'A', 'yes'),
(1, '{realm}', 'NS', 'yes'),
(2, '*.{realm}', 'A', 'no'),
(3, '*.{realm}', 'NS', 'no'),
(10, '_msdcs.{realm}', 'A', 'yes'),
(11, '_msdcs.{realm}', 'NS', 'yes'),
(20, 'nx.realm.com', 'A', 'no'),
(21, 'nx.realm.com', 'NS', 'no'),
(22, '*.nx.realm.com', 'A', 'no'),
(23, '*.nx.realm.com', 'NS', 'no'),
]
def replay(self, context=None):
assert context
assert context.realm
start = time.time()
for t in self.times:
now = time.time() - start
gap = t - now
sleep_time = gap - SLEEP_OVERHEAD
if sleep_time > 0:
time.sleep(sleep_time)
opcode, rname, rtype, exist = random.choice(self.query_choices)
rname = rname.format(realm=context.realm)
success = True
packet_start = time.time()
try:
answers = dns_query(rname, rtype)
if exist == 'yes' and not len(answers):
# expect answers but didn't get, fail
success = False
except Exception:
success = False
finally:
end = time.time()
duration = end - packet_start
print("%f\tDNS\tdns\t%s\t%f\t%s\t" % (end, opcode, duration, success))
def ingest_summaries(files, dns_mode='count'):
"""Load a summary traffic summary file and generated Converations from it.
"""
dns_counts = defaultdict(int)
packets = []
for f in files:
if isinstance(f, str):
f = open(f)
print("Ingesting %s" % (f.name,), file=sys.stderr)
for line in f:
p = Packet.from_line(line)
if p.protocol == 'dns' and dns_mode != 'include':
dns_counts[p.opcode] += 1
else:
packets.append(p)
f.close()
if not packets:
return [], 0
start_time = min(p.timestamp for p in packets)
last_packet = max(p.timestamp for p in packets)
print("gathering packets into conversations", file=sys.stderr)
conversations = OrderedDict()
for i, p in enumerate(packets):
p.timestamp -= start_time
c = conversations.get(p.endpoints)
if c is None:
c = Conversation(conversation_id=(i + 2))
conversations[p.endpoints] = c
c.add_packet(p)
# We only care about conversations with actual traffic, so we
# filter out conversations with nothing to say. We do that here,
# rather than earlier, because those empty packets contain useful
# hints as to which end of the conversation was the client.
conversation_list = []
for c in conversations.values():
if len(c) != 0:
conversation_list.append(c)
# This is obviously not correct, as many conversations will appear
# to start roughly simultaneously at the beginning of the snapshot.
# To which we say: oh well, so be it.
duration = float(last_packet - start_time)
mean_interval = len(conversations) / duration
return conversation_list, mean_interval, duration, dns_counts
def guess_server_address(conversations):
# we guess the most common address.
addresses = Counter()
for c in conversations:
addresses.update(c.endpoints)
if addresses:
return addresses.most_common(1)[0]
def stringify_keys(x):
y = {}
for k, v in x.items():
k2 = '\t'.join(k)
y[k2] = v
return y
def unstringify_keys(x):
y = {}
for k, v in x.items():
t = tuple(str(k).split('\t'))
y[t] = v
return y
class TrafficModel(object):
def __init__(self, n=3):
self.ngrams = {}
self.query_details = {}
self.n = n
self.dns_opcounts = defaultdict(int)
self.cumulative_duration = 0.0
self.packet_rate = [0, 1]
def learn(self, conversations, dns_opcounts=None):
if dns_opcounts is None:
dns_opcounts = {}
prev = 0.0
cum_duration = 0.0
key = (NON_PACKET,) * (self.n - 1)
server = guess_server_address(conversations)
for k, v in dns_opcounts.items():
self.dns_opcounts[k] += v
if len(conversations) > 1:
first = conversations[0].start_time
total = 0
last = first + 0.1
for c in conversations:
total += len(c)
last = max(last, c.packets[-1].timestamp)
self.packet_rate[0] = total
self.packet_rate[1] = last - first
for c in conversations:
client, server = c.guess_client_server(server)
cum_duration += c.get_duration()
key = (NON_PACKET,) * (self.n - 1)
for p in c:
if p.src != client:
continue
elapsed = p.timestamp - prev
prev = p.timestamp
if elapsed > WAIT_THRESHOLD:
# add the wait as an extra state
wait = 'wait:%d' % (math.log(max(1.0,
elapsed * WAIT_SCALE)))
self.ngrams.setdefault(key, []).append(wait)
key = key[1:] + (wait,)
short_p = p.as_packet_type()
self.query_details.setdefault(short_p,
[]).append(tuple(p.extra))
self.ngrams.setdefault(key, []).append(short_p)
key = key[1:] + (short_p,)
self.cumulative_duration += cum_duration
# add in the end
self.ngrams.setdefault(key, []).append(NON_PACKET)
def save(self, f):
ngrams = {}
for k, v in self.ngrams.items():
k = '\t'.join(k)
ngrams[k] = dict(Counter(v))
query_details = {}
for k, v in self.query_details.items():
query_details[k] = dict(Counter('\t'.join(x) if x else '-'
for x in v))
d = {
'ngrams': ngrams,
'query_details': query_details,
'cumulative_duration': self.cumulative_duration,
'packet_rate': self.packet_rate,
'version': CURRENT_MODEL_VERSION
}
d['dns'] = self.dns_opcounts
if isinstance(f, str):
f = open(f, 'w')
json.dump(d, f, indent=2)
def load(self, f):
if isinstance(f, str):
f = open(f)
d = json.load(f)
try:
version = d["version"]
if version < REQUIRED_MODEL_VERSION:
raise ValueError("the model file is version %d; "
"version %d is required" %
(version, REQUIRED_MODEL_VERSION))
except KeyError:
raise ValueError("the model file lacks a version number; "
"version %d is required" %
(REQUIRED_MODEL_VERSION))
for k, v in d['ngrams'].items():
k = tuple(str(k).split('\t'))
values = self.ngrams.setdefault(k, [])
for p, count in v.items():
values.extend([str(p)] * count)
values.sort()
for k, v in d['query_details'].items():
values = self.query_details.setdefault(str(k), [])
for p, count in v.items():
if p == '-':
values.extend([()] * count)
else:
values.extend([tuple(str(p).split('\t'))] * count)
values.sort()
if 'dns' in d:
for k, v in d['dns'].items():
self.dns_opcounts[k] += v
self.cumulative_duration = d['cumulative_duration']
self.packet_rate = d['packet_rate']
def construct_conversation_sequence(self, timestamp=0.0,
hard_stop=None,
replay_speed=1,
ignore_before=0,
persistence=0):
"""Construct an individual conversation packet sequence from the
model.
"""
c = []
key = (NON_PACKET,) * (self.n - 1)
if ignore_before is None:
ignore_before = timestamp - 1
while True:
p = random.choice(self.ngrams.get(key, (NON_PACKET,)))
if p == NON_PACKET:
if timestamp < ignore_before:
break
if random.random() > persistence:
print("ending after %s (persistence %.1f)" % (key, persistence),
file=sys.stderr)
break
p = 'wait:%d' % random.randrange(5, 12)
print("trying %s instead of end" % p, file=sys.stderr)
if p in self.query_details:
extra = random.choice(self.query_details[p])
else:
extra = []
protocol, opcode = p.split(':', 1)
if protocol == 'wait':
log_wait_time = int(opcode) + random.random()
wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed)
timestamp += wait
else:
log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
wait = math.exp(log_wait) / replay_speed
timestamp += wait
if hard_stop is not None and timestamp > hard_stop:
break
if timestamp >= ignore_before:
c.append((timestamp, protocol, opcode, extra))
key = key[1:] + (p,)
if key[-2][:5] == 'wait:' and key[-1][:5] == 'wait:':
# two waits in a row can only be caused by "persistence"
# tricks, and will not result in any packets being found.
# Instead we pretend this is a fresh start.
key = (NON_PACKET,) * (self.n - 1)
return c
def scale_to_packet_rate(self, scale):
rate_n, rate_t = self.packet_rate
return scale * rate_n / rate_t
def packet_rate_to_scale(self, pps):
rate_n, rate_t = self.packet_rate
return pps * rate_t / rate_n
def generate_conversation_sequences(self, packet_rate, duration, replay_speed=1,
persistence=0):
"""Generate a list of conversation descriptions from the model."""
# We run the simulation for ten times as long as our desired
# duration, and take the section at the end.
lead_in = 9 * duration
target_packets = int(packet_rate * duration)
conversations = []
n_packets = 0
while n_packets < target_packets:
start = random.uniform(-lead_in, duration)
c = self.construct_conversation_sequence(start,
hard_stop=duration,
replay_speed=replay_speed,
ignore_before=0,
persistence=persistence)
# will these "packets" generate actual traffic?
# some (e.g. ldap unbind) will not generate anything
# if the previous packets are not there, and if the
# conversation only has those it wastes a process doing nothing.
for timestamp, protocol, opcode, extra in c:
if is_a_traffic_generating_packet(protocol, opcode):
break
else:
continue
conversations.append(c)
n_packets += len(c)
scale = self.packet_rate_to_scale(packet_rate)
print(("we have %d packets (target %d) in %d conversations at %.1f/s "
"(scale %f)" % (n_packets, target_packets, len(conversations),
packet_rate, scale)),
file=sys.stderr)
conversations.sort() # sorts by first element == start time
return conversations
def seq_to_conversations(seq, server=1, client=2):
conversations = []
for s in seq:
if s:
c = Conversation(s[0][0], (server, client), s)
client += 1
conversations.append(c)
return conversations
IP_PROTOCOLS = {
'dns': '11',
'rpc_netlogon': '06',
'kerberos': '06', # ratio 16248:258
'smb': '06',
'smb2': '06',
'ldap': '06',
'cldap': '11',
'lsarpc': '06',
'samr': '06',
'dcerpc': '06',
'epm': '06',
'drsuapi': '06',
'browser': '11',
'smb_netlogon': '11',
'srvsvc': '06',
'nbns': '11',
}
OP_DESCRIPTIONS = {
('browser', '0x01'): 'Host Announcement (0x01)',
('browser', '0x02'): 'Request Announcement (0x02)',
('browser', '0x08'): 'Browser Election Request (0x08)',
('browser', '0x09'): 'Get Backup List Request (0x09)',
('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
('browser', '0x0f'): 'Local Master Announcement (0x0f)',
('cldap', '3'): 'searchRequest',
('cldap', '5'): 'searchResDone',
('dcerpc', '0'): 'Request',
('dcerpc', '11'): 'Bind',
('dcerpc', '12'): 'Bind_ack',
('dcerpc', '13'): 'Bind_nak',
('dcerpc', '14'): 'Alter_context',
('dcerpc', '15'): 'Alter_context_resp',
('dcerpc', '16'): 'AUTH3',
('dcerpc', '2'): 'Response',
('dns', '0'): 'query',
('dns', '1'): 'response',
('drsuapi', '0'): 'DsBind',
('drsuapi', '12'): 'DsCrackNames',
('drsuapi', '13'): 'DsWriteAccountSpn',
('drsuapi', '1'): 'DsUnbind',
('drsuapi', '2'): 'DsReplicaSync',
('drsuapi', '3'): 'DsGetNCChanges',
('drsuapi', '4'): 'DsReplicaUpdateRefs',
('epm', '3'): 'Map',
('kerberos', ''): '',
('ldap', '0'): 'bindRequest',
('ldap', '1'): 'bindResponse',
('ldap', '2'): 'unbindRequest',
('ldap', '3'): 'searchRequest',
('ldap', '4'): 'searchResEntry',
('ldap', '5'): 'searchResDone',
('ldap', ''): '*** Unknown ***',
('lsarpc', '14'): 'lsa_LookupNames',
('lsarpc', '15'): 'lsa_LookupSids',
('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
('lsarpc', '6'): 'lsa_OpenPolicy',
('lsarpc', '76'): 'lsa_LookupSids3',
('lsarpc', '77'): 'lsa_LookupNames4',
('nbns', '0'): 'query',
('nbns', '1'): 'response',
('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
('rpc_netlogon', '4'): 'NetrServerReqChallenge',
('samr', '0',): 'Connect',
('samr', '16'): 'GetAliasMembership',
('samr', '17'): 'LookupNames',
('samr', '18'): 'LookupRids',
('samr', '19'): 'OpenGroup',
('samr', '1'): 'Close',
('samr', '25'): 'QueryGroupMember',
('samr', '34'): 'OpenUser',
('samr', '36'): 'QueryUserInfo',
('samr', '39'): 'GetGroupsForUser',
('samr', '3'): 'QuerySecurity',
('samr', '5'): 'LookupDomain',
('samr', '64'): 'Connect5',
('samr', '6'): 'EnumDomains',
('samr', '7'): 'OpenDomain',
('samr', '8'): 'QueryDomainInfo',
('smb', '0x04'): 'Close (0x04)',
('smb', '0x24'): 'Locking AndX (0x24)',
('smb', '0x2e'): 'Read AndX (0x2e)',
('smb', '0x32'): 'Trans2 (0x32)',
('smb', '0x71'): 'Tree Disconnect (0x71)',
('smb', '0x72'): 'Negotiate Protocol (0x72)',
('smb', '0x73'): 'Session Setup AndX (0x73)',
('smb', '0x74'): 'Logoff AndX (0x74)',
('smb', '0x75'): 'Tree Connect AndX (0x75)',
('smb', '0xa2'): 'NT Create AndX (0xa2)',
('smb2', '0'): 'NegotiateProtocol',
('smb2', '11'): 'Ioctl',
('smb2', '14'): 'Find',
('smb2', '16'): 'GetInfo',
('smb2', '18'): 'Break',
('smb2', '1'): 'SessionSetup',
('smb2', '2'): 'SessionLogoff',
('smb2', '3'): 'TreeConnect',
('smb2', '4'): 'TreeDisconnect',
('smb2', '5'): 'Create',
('smb2', '6'): 'Close',
('smb2', '8'): 'Read',
('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
'user unknown (0x17)'),
('srvsvc', '16'): 'NetShareGetInfo',
('srvsvc', '21'): 'NetSrvGetInfo',
}
def expand_short_packet(p, timestamp, src, dest, extra):
protocol, opcode = p.split(':', 1)
desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
ip_protocol = IP_PROTOCOLS.get(protocol, '06')
line = [timestamp, ip_protocol, '', src, dest, protocol, opcode, desc]
line.extend(extra)
return '\t'.join(line)
def flushing_signal_handler(signal, frame):
"""Signal handler closes standard out and error.
Triggered by a sigterm, ensures that the log messages are flushed
to disk and not lost.
"""
sys.stderr.close()
sys.stdout.close()
os._exit(0)
def replay_seq_in_fork(cs, start, context, account, client_id, server_id=1):
"""Fork a new process and replay the conversation sequence."""
# We will need to reseed the random number generator or all the
# clients will end up using the same sequence of random
# numbers. random.randint() is mixed in so the initial seed will
# have an effect here.
seed = client_id * 1000 + random.randint(0, 999)
# flush our buffers so messages won't be written by both sides
sys.stdout.flush()
sys.stderr.flush()
pid = os.fork()
if pid != 0:
return pid
# we must never return, or we'll end up running parts of the
# parent's clean-up code. So we work in a try...finally, and
# try to print any exceptions.
try:
random.seed(seed)
endpoints = (server_id, client_id)
status = 0
t = cs[0][0]
c = Conversation(t, endpoints, seq=cs, conversation_id=client_id)
signal.signal(signal.SIGTERM, flushing_signal_handler)
context.generate_process_local_config(account, c)
sys.stdin.close()
os.close(0)
filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
c.conversation_id)
f = open(filename, 'w')
try:
sys.stdout.close()
os.close(1)
except IOError as e:
LOGGER.info("stdout closing failed with %s" % e)
sys.stdout = f
now = time.time() - start
gap = t - now
sleep_time = gap - SLEEP_OVERHEAD
if sleep_time > 0:
time.sleep(sleep_time)
max_lag, start_lag, max_sleep_miss = c.replay_with_delay(start=start,
context=context)
print("Maximum lag: %f" % max_lag)
print("Start lag: %f" % start_lag)
print("Max sleep miss: %f" % max_sleep_miss)
except Exception:
status = 1
print(("EXCEPTION in child PID %d, conversation %s" % (os.getpid(), c)),
file=sys.stderr)
traceback.print_exc(sys.stderr)
sys.stderr.flush()
finally:
sys.stderr.close()
sys.stdout.close()
os._exit(status)
def dnshammer_in_fork(dns_rate, duration, context, query_file=None):
sys.stdout.flush()
sys.stderr.flush()
pid = os.fork()
if pid != 0:
return pid
sys.stdin.close()
os.close(0)
try:
sys.stdout.close()
os.close(1)
except IOError as e:
LOGGER.warn("stdout closing failed with %s" % e)
filename = os.path.join(context.statsdir, 'stats-dns')
sys.stdout = open(filename, 'w')
try:
status = 0
signal.signal(signal.SIGTERM, flushing_signal_handler)
hammer = DnsHammer(dns_rate, duration, query_file=query_file)
hammer.replay(context=context)
except Exception:
status = 1
print(("EXCEPTION in child PID %d, the DNS hammer" % (os.getpid())),
file=sys.stderr)
traceback.print_exc(sys.stderr)
finally:
sys.stderr.close()
sys.stdout.close()
os._exit(status)
def replay(conversation_seq,
host=None,
creds=None,
lp=None,
accounts=None,
dns_rate=0,
dns_query_file=None,
duration=None,
latency_timeout=1.0,
stop_on_any_error=False,
**kwargs):
context = ReplayContext(server=host,
creds=creds,
lp=lp,
total_conversations=len(conversation_seq),
**kwargs)
if len(accounts) < len(conversation_seq):
raise ValueError(("we have %d accounts but %d conversations" %
(len(accounts), len(conversation_seq))))
# Set the process group so that the calling scripts are not killed
# when the forked child processes are killed.
os.setpgrp()
# we delay the start by a bit to allow all the forks to get up and
# running.
delay = len(conversation_seq) * 0.02
start = time.time() + delay
if duration is None:
# end slightly after the last packet of the last conversation
# to start. Conversations other than the last could still be
# going, but we don't care.
duration = conversation_seq[-1][-1][0] + latency_timeout
print("We will start in %.1f seconds" % delay,
file=sys.stderr)
print("We will stop after %.1f seconds" % (duration + delay),
file=sys.stderr)
print("runtime %.1f seconds" % duration,
file=sys.stderr)
# give one second grace for packets to finish before killing begins
end = start + duration + 1.0
LOGGER.info("Replaying traffic for %u conversations over %d seconds"
% (len(conversation_seq), duration))
context.write_stats('intentions',
Planned_conversations=len(conversation_seq),
Planned_packets=sum(len(x) for x in conversation_seq))
children = {}
try:
if dns_rate:
pid = dnshammer_in_fork(dns_rate, duration, context,
query_file=dns_query_file)
children[pid] = 1
for i, cs in enumerate(conversation_seq):
account = accounts[i]
client_id = i + 2
pid = replay_seq_in_fork(cs, start, context, account, client_id)
children[pid] = client_id
# HERE, we are past all the forks
t = time.time()
print("all forks done in %.1f seconds, waiting %.1f" %
(t - start + delay, t - start),
file=sys.stderr)
while time.time() < end and children:
time.sleep(0.003)
try:
pid, status = os.waitpid(-1, os.WNOHANG)
except OSError as e:
if e.errno != ECHILD: # no child processes
raise
break
if pid:
c = children.pop(pid, None)
if DEBUG_LEVEL > 0:
print(("process %d finished conversation %d;"
" %d to go" %
(pid, c, len(children))), file=sys.stderr)
if stop_on_any_error and status != 0:
break
except Exception:
print("EXCEPTION in parent", file=sys.stderr)
traceback.print_exc()
finally:
context.write_stats('unfinished',
Unfinished_conversations=len(children))
for s in (15, 15, 9):
print(("killing %d children with -%d" %
(len(children), s)), file=sys.stderr)
for pid in children:
try:
os.kill(pid, s)
except OSError as e:
if e.errno != ESRCH: # don't fail if it has already died
raise
time.sleep(0.5)
end = time.time() + 1
while children:
try:
pid, status = os.waitpid(-1, os.WNOHANG)
except OSError as e:
if e.errno != ECHILD:
raise
if pid != 0:
c = children.pop(pid, None)
if c is None:
print("children is %s, no pid found" % children)
sys.stderr.flush()
sys.stdout.flush()
os._exit(1)
print(("kill -%d %d KILLED conversation; "
"%d to go" %
(s, pid, len(children))),
file=sys.stderr)
if time.time() >= end:
break
if not children:
break
time.sleep(1)
if children:
print("%d children are missing" % len(children),
file=sys.stderr)
# there may be stragglers that were forked just as ^C was hit
# and don't appear in the list of children. We can get them
# with killpg, but that will also kill us, so this is^H^H would be
# goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
# so as not to have to fuss around writing signal handlers.
try:
os.killpg(0, 2)
except KeyboardInterrupt:
print("ignoring fake ^C", file=sys.stderr)
def openLdb(host, creds, lp):
session = system_session()
ldb = SamDB(url="ldap://%s" % host,
session_info=session,
options=['modules:paged_searches'],
credentials=creds,
lp=lp)
return ldb
def ou_name(ldb, instance_id):
"""Generate an ou name from the instance id"""
return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id,
ldb.domain_dn())
def create_ou(ldb, instance_id):
"""Create an ou, all created user and machine accounts will belong to it.
This allows all the created resources to be cleaned up easily.
"""
ou = ou_name(ldb, instance_id)
try:
ldb.add({"dn": ou.split(',', 1)[1],
"objectclass": "organizationalunit"})
except LdbError as e:
(status, _) = e.args
# ignore already exists
if status != 68:
raise
try:
ldb.add({"dn": ou,
"objectclass": "organizationalunit"})
except LdbError as e:
(status, _) = e.args
# ignore already exists
if status != 68:
raise
return ou
# ConversationAccounts holds details of the machine and user accounts
# associated with a conversation.
#
# We use a named tuple to reduce shared memory usage.
ConversationAccounts = namedtuple('ConversationAccounts',
('netbios_name',
'machinepass',
'username',
'userpass'))
def generate_replay_accounts(ldb, instance_id, number, password):
"""Generate a series of unique machine and user account names."""
accounts = []
for i in range(1, number + 1):
netbios_name = machine_name(instance_id, i)
username = user_name(instance_id, i)
account = ConversationAccounts(netbios_name, password, username,
password)
accounts.append(account)
return accounts
def create_machine_account(ldb, instance_id, netbios_name, machinepass,
traffic_account=True):
"""Create a machine account via ldap."""
ou = ou_name(ldb, instance_id)
dn = "cn=%s,%s" % (netbios_name, ou)
utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')
if traffic_account:
# we set these bits for the machine account otherwise the replayed
# traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
account_controls = str(UF_TRUSTED_FOR_DELEGATION |
UF_SERVER_TRUST_ACCOUNT)
else:
account_controls = str(UF_WORKSTATION_TRUST_ACCOUNT)
ldb.add({
"dn": dn,
"objectclass": "computer",
"sAMAccountName": "%s$" % netbios_name,
"userAccountControl": account_controls,
"unicodePwd": utf16pw})
def create_user_account(ldb, instance_id, username, userpass):
"""Create a user account via ldap."""
ou = ou_name(ldb, instance_id)
user_dn = "cn=%s,%s" % (username, ou)
utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
ldb.add({
"dn": user_dn,
"objectclass": "user",
"sAMAccountName": username,
"userAccountControl": str(UF_NORMAL_ACCOUNT),
"unicodePwd": utf16pw
})
# grant user write permission to do things like write account SPN
sdutils = sd_utils.SDUtils(ldb)
sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")
def create_group(ldb, instance_id, name):
"""Create a group via ldap."""
ou = ou_name(ldb, instance_id)
dn = "cn=%s,%s" % (name, ou)
ldb.add({
"dn": dn,
"objectclass": "group",
"sAMAccountName": name,
})
def user_name(instance_id, i):
"""Generate a user name based in the instance id"""
return "STGU-%d-%d" % (instance_id, i)
def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
"""Search objectclass, return attr in a set"""
objs = ldb.search(
expression="(objectClass={})".format(objectclass),
attrs=[attr]
)
return {str(obj[attr]) for obj in objs}
def generate_users(ldb, instance_id, number, password):
"""Add users to the server"""
existing_objects = search_objectclass(ldb, objectclass='user')
users = 0
for i in range(number, 0, -1):
name = user_name(instance_id, i)
if name not in existing_objects:
create_user_account(ldb, instance_id, name, password)
users += 1
if users % 50 == 0:
LOGGER.info("Created %u/%u users" % (users, number))
return users
def machine_name(instance_id, i, traffic_account=True):
"""Generate a machine account name from instance id."""
if traffic_account:
# traffic accounts correspond to a given user, and use different
# userAccountControl flags to ensure packets get processed correctly
# by the DC
return "STGM-%d-%d" % (instance_id, i)
else:
# Otherwise we're just generating computer accounts to simulate a
# semi-realistic network. These use the default computer
# userAccountControl flags, so we use a different account name so that
# we don't try to use them when generating packets
return "PC-%d-%d" % (instance_id, i)
def generate_machine_accounts(ldb, instance_id, number, password,
traffic_account=True):
"""Add machine accounts to the server"""
existing_objects = search_objectclass(ldb, objectclass='computer')
added = 0
for i in range(number, 0, -1):
name = machine_name(instance_id, i, traffic_account)
if name + "$" not in existing_objects:
create_machine_account(ldb, instance_id, name, password,
traffic_account)
added += 1
if added % 50 == 0:
LOGGER.info("Created %u/%u machine accounts" % (added, number))
return added
def group_name(instance_id, i):
"""Generate a group name from instance id."""
return "STGG-%d-%d" % (instance_id, i)
def generate_groups(ldb, instance_id, number):
"""Create the required number of groups on the server."""
existing_objects = search_objectclass(ldb, objectclass='group')
groups = 0
for i in range(number, 0, -1):
name = group_name(instance_id, i)
if name not in existing_objects:
create_group(ldb, instance_id, name)
groups += 1
if groups % 1000 == 0:
LOGGER.info("Created %u/%u groups" % (groups, number))
return groups
def clean_up_accounts(ldb, instance_id):
"""Remove the created accounts and groups from the server."""
ou = ou_name(ldb, instance_id)
try:
ldb.delete(ou, ["tree_delete:1"])
except LdbError as e:
(status, _) = e.args
# ignore does not exist
if status != 32:
raise
def generate_users_and_groups(ldb, instance_id, password,
number_of_users, number_of_groups,
group_memberships, max_members,
machine_accounts, traffic_accounts=True):
"""Generate the required users and groups, allocating the users to
those groups."""
memberships_added = 0
groups_added = 0
computers_added = 0
create_ou(ldb, instance_id)
LOGGER.info("Generating dummy user accounts")
users_added = generate_users(ldb, instance_id, number_of_users, password)
LOGGER.info("Generating dummy machine accounts")
computers_added = generate_machine_accounts(ldb, instance_id,
machine_accounts, password,
traffic_accounts)
if number_of_groups > 0:
LOGGER.info("Generating dummy groups")
groups_added = generate_groups(ldb, instance_id, number_of_groups)
if group_memberships > 0:
LOGGER.info("Assigning users to groups")
assignments = GroupAssignments(number_of_groups,
groups_added,
number_of_users,
users_added,
group_memberships,
max_members)
LOGGER.info("Adding users to groups")
add_users_to_groups(ldb, instance_id, assignments)
memberships_added = assignments.total()
if (groups_added > 0 and users_added == 0 and
number_of_groups != groups_added):
LOGGER.warning("The added groups will contain no members")
LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
(users_added, computers_added, groups_added,
memberships_added))
class GroupAssignments(object):
def __init__(self, number_of_groups, groups_added, number_of_users,
users_added, group_memberships, max_members):
self.count = 0
self.generate_group_distribution(number_of_groups)
self.generate_user_distribution(number_of_users, group_memberships)
self.max_members = max_members
self.assignments = defaultdict(list)
self.assign_groups(number_of_groups, groups_added, number_of_users,
users_added, group_memberships)
def cumulative_distribution(self, weights):
# make sure the probabilities conform to a cumulative distribution
# spread between 0.0 and 1.0. Dividing by the weighted total gives each
# probability a proportional share of 1.0. Higher probabilities get a
# bigger share, so are more likely to be picked. We use the cumulative
# value, so we can use random.random() as a simple index into the list
dist = []
total = sum(weights)
if total == 0:
return None
cumulative = 0.0
for probability in weights:
cumulative += probability
dist.append(cumulative / total)
return dist
def generate_user_distribution(self, num_users, num_memberships):
"""Probability distribution of a user belonging to a group.
"""
# Assign a weighted probability to each user. Use the Pareto
# Distribution so that some users are in a lot of groups, and the
# bulk of users are in only a few groups. If we're assigning a large
# number of group memberships, use a higher shape. This means slightly
# fewer outlying users that are in large numbers of groups. The aim is
# to have no users belonging to more than ~500 groups.
if num_memberships > 5000000:
shape = 3.0
elif num_memberships > 2000000:
shape = 2.5
elif num_memberships > 300000:
shape = 2.25
else:
shape = 1.75
weights = []
for x in range(1, num_users + 1):
p = random.paretovariate(shape)
weights.append(p)
# convert the weights to a cumulative distribution between 0.0 and 1.0
self.user_dist = self.cumulative_distribution(weights)
def generate_group_distribution(self, n):
"""Probability distribution of a group containing a user."""
# Assign a weighted probability to each user. Probability decreases
# as the group-ID increases
weights = []
for x in range(1, n + 1):
p = 1 / (x**1.3)
weights.append(p)
# convert the weights to a cumulative distribution between 0.0 and 1.0
self.group_weights = weights
self.group_dist = self.cumulative_distribution(weights)
def generate_random_membership(self):
"""Returns a randomly generated user-group membership"""
# the list items are cumulative distribution values between 0.0 and
# 1.0, which makes random() a handy way to index the list to get a
# weighted random user/group. (Here the user/group returned are
# zero-based array indexes)
user = bisect.bisect(self.user_dist, random.random())
group = bisect.bisect(self.group_dist, random.random())
return user, group
def users_in_group(self, group):
return self.assignments[group]
def get_groups(self):
return self.assignments.keys()
def cap_group_membership(self, group, max_members):
"""Prevent the group's membership from exceeding the max specified"""
num_members = len(self.assignments[group])
if num_members >= max_members:
LOGGER.info("Group {0} has {1} members".format(group, num_members))
# remove this group and then recalculate the cumulative
# distribution, so this group is no longer selected
self.group_weights[group - 1] = 0
new_dist = self.cumulative_distribution(self.group_weights)
self.group_dist = new_dist
def add_assignment(self, user, group):
# the assignments are stored in a dictionary where key=group,
# value=list-of-users-in-group (indexing by group-ID allows us to
# optimize for DB membership writes)
if user not in self.assignments[group]:
self.assignments[group].append(user)
self.count += 1
# check if there'a cap on how big the groups can grow
if self.max_members:
self.cap_group_membership(group, self.max_members)
def assign_groups(self, number_of_groups, groups_added,
number_of_users, users_added, group_memberships):
"""Allocate users to groups.
The intention is to have a few users that belong to most groups, while
the majority of users belong to a few groups.
A few groups will contain most users, with the remaining only having a
few users.
"""
if group_memberships <= 0:
return
# Calculate the number of group menberships required
group_memberships = math.ceil(
float(group_memberships) *
(float(users_added) / float(number_of_users)))
if self.max_members:
group_memberships = min(group_memberships,
self.max_members * number_of_groups)
existing_users = number_of_users - users_added - 1
existing_groups = number_of_groups - groups_added - 1
while self.total() < group_memberships:
user, group = self.generate_random_membership()
if group > existing_groups or user > existing_users:
# the + 1 converts the array index to the corresponding
# group or user number
self.add_assignment(user + 1, group + 1)
def total(self):
return self.count
def add_users_to_groups(db, instance_id, assignments):
"""Takes the assignments of users to groups and applies them to the DB."""
total = assignments.total()
count = 0
added = 0
for group in assignments.get_groups():
users_in_group = assignments.users_in_group(group)
if len(users_in_group) == 0:
continue
# Split up the users into chunks, so we write no more than 1K at a
# time. (Minimizing the DB modifies is more efficient, but writing
# 10K+ users to a single group becomes inefficient memory-wise)
for chunk in range(0, len(users_in_group), 1000):
chunk_of_users = users_in_group[chunk:chunk + 1000]
add_group_members(db, instance_id, group, chunk_of_users)
added += len(chunk_of_users)
count += 1
if count % 50 == 0:
LOGGER.info("Added %u/%u memberships" % (added, total))
def add_group_members(db, instance_id, group, users_in_group):
"""Adds the given users to group specified."""
ou = ou_name(db, instance_id)
def build_dn(name):
return("cn=%s,%s" % (name, ou))
group_dn = build_dn(group_name(instance_id, group))
m = ldb.Message()
m.dn = ldb.Dn(db, group_dn)
for user in users_in_group:
user_dn = build_dn(user_name(instance_id, user))
idx = "member-" + str(user)
m[idx] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
db.modify(m)
def generate_stats(statsdir, timing_file):
"""Generate and print the summary stats for a run."""
first = sys.float_info.max
last = 0
successful = 0
failed = 0
latencies = {}
failures = Counter()
unique_conversations = set()
if timing_file is not None:
tw = timing_file.write
else:
def tw(x):
pass
tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")
float_values = {
'Maximum lag': 0,
'Start lag': 0,
'Max sleep miss': 0,
}
int_values = {
'Planned_conversations': 0,
'Planned_packets': 0,
'Unfinished_conversations': 0,
}
for filename in os.listdir(statsdir):
path = os.path.join(statsdir, filename)
with open(path, 'r') as f:
for line in f:
try:
fields = line.rstrip('\n').split('\t')
conversation = fields[1]
protocol = fields[2]
packet_type = fields[3]
latency = float(fields[4])
t = float(fields[0])
first = min(t - latency, first)
last = max(t, last)
op = (protocol, packet_type)
latencies.setdefault(op, []).append(latency)
if fields[5] == 'True':
successful += 1
else:
failed += 1
failures[op] += 1
unique_conversations.add(conversation)
tw(line)
except (ValueError, IndexError):
if ':' in line:
k, v = line.split(':', 1)
if k in float_values:
float_values[k] = max(float(v),
float_values[k])
elif k in int_values:
int_values[k] = max(int(v),
int_values[k])
else:
print(line, file=sys.stderr)
else:
# not a valid line print and ignore
print(line, file=sys.stderr)
duration = last - first
if successful == 0:
success_rate = 0
else:
success_rate = successful / duration
if failed == 0:
failure_rate = 0
else:
failure_rate = failed / duration
conversations = len(unique_conversations)
print("Total conversations: %10d" % conversations)
print("Successful operations: %10d (%.3f per second)"
% (successful, success_rate))
print("Failed operations: %10d (%.3f per second)"
% (failed, failure_rate))
for k, v in sorted(float_values.items()):
print("%-28s %f" % (k.replace('_', ' ') + ':', v))
for k, v in sorted(int_values.items()):
print("%-28s %d" % (k.replace('_', ' ') + ':', v))
print("Protocol Op Code Description "
" Count Failed Mean Median "
"95% Range Max")
ops = {}
for proto, packet in latencies:
if proto not in ops:
ops[proto] = set()
ops[proto].add(packet)
protocols = sorted(ops.keys())
for protocol in protocols:
packet_types = sorted(ops[protocol], key=opcode_key)
for packet_type in packet_types:
op = (protocol, packet_type)
values = latencies[op]
values = sorted(values)
count = len(values)
failed = failures[op]
mean = sum(values) / count
median = calc_percentile(values, 0.50)
percentile = calc_percentile(values, 0.95)
rng = values[-1] - values[0]
maxv = values[-1]
desc = OP_DESCRIPTIONS.get(op, '')
print("%-12s %4s %-35s %12d %12d %12.6f "
"%12.6f %12.6f %12.6f %12.6f"
% (protocol,
packet_type,
desc,
count,
failed,
mean,
median,
percentile,
rng,
maxv))
def opcode_key(v):
"""Sort key for the operation code to ensure that it sorts numerically"""
try:
return "%03d" % int(v)
except ValueError:
return v
def calc_percentile(values, percentile):
"""Calculate the specified percentile from the list of values.
Assumes the list is sorted in ascending order.
"""
if not values:
return 0
k = (len(values) - 1) * percentile
f = math.floor(k)
c = math.ceil(k)
if f == c:
return values[int(k)]
d0 = values[int(f)] * (c - k)
d1 = values[int(c)] * (k - f)
return d0 + d1
def mk_masked_dir(*path):
"""In a testenv we end up with 0777 directories that look an alarming
green colour with ls. Use umask to avoid that."""
# py3 os.mkdir can do this
d = os.path.join(*path)
mask = os.umask(0o077)
os.mkdir(d)
os.umask(mask)
return d