mirror of
https://github.com/samba-team/samba.git
synced 2025-01-10 01:18:15 +03:00
25e6d7c6a3
Signed-off-by: Douglas Bagnall <douglas.bagnall@catalyst.net.nz> Reviewed-by: Andrew Bartlett <abartlet@samba.org>
2415 lines
85 KiB
Python
2415 lines
85 KiB
Python
# -*- encoding: utf-8 -*-
|
|
# Samba traffic replay and learning
|
|
#
|
|
# Copyright (C) Catalyst IT Ltd. 2017
|
|
#
|
|
# This program is free software; you can redistribute it and/or modify
|
|
# it under the terms of the GNU General Public License as published by
|
|
# the Free Software Foundation; either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU General Public License
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
#
|
|
|
|
import time
|
|
import os
|
|
import random
|
|
import json
|
|
import math
|
|
import sys
|
|
import signal
|
|
from errno import ECHILD, ESRCH
|
|
|
|
from collections import OrderedDict, Counter, defaultdict, namedtuple
|
|
from dns.resolver import query as dns_query
|
|
|
|
from samba.emulate import traffic_packets
|
|
from samba.samdb import SamDB
|
|
import ldb
|
|
from ldb import LdbError
|
|
from samba.dcerpc import ClientConnection
|
|
from samba.dcerpc import security, drsuapi, lsa
|
|
from samba.dcerpc import netlogon
|
|
from samba.dcerpc.netlogon import netr_Authenticator
|
|
from samba.dcerpc import srvsvc
|
|
from samba.dcerpc import samr
|
|
from samba.drs_utils import drs_DsBind
|
|
import traceback
|
|
from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
|
|
from samba.auth import system_session
|
|
from samba.dsdb import (
|
|
UF_NORMAL_ACCOUNT,
|
|
UF_SERVER_TRUST_ACCOUNT,
|
|
UF_TRUSTED_FOR_DELEGATION,
|
|
UF_WORKSTATION_TRUST_ACCOUNT
|
|
)
|
|
from samba.dcerpc.misc import SEC_CHAN_BDC
|
|
from samba import gensec
|
|
from samba import sd_utils
|
|
from samba.common import get_string
|
|
from samba.logger import get_samba_logger
|
|
import bisect
|
|
|
|
# Model file versioning.
CURRENT_MODEL_VERSION = 2   # save as this
REQUIRED_MODEL_VERSION = 2  # load accepts this or greater

# Subtracted from every computed sleep duration in replay_with_delay().
SLEEP_OVERHEAD = 0.0003

# we don't use None, because it complicates [de]serialisation
NON_PACKET = '-'

# (protocol, opcode) pairs that suggest the sender is a client, mapped to
# the score Packet.client_score() returns for them.
CLIENT_CLUES = {
    ('dns', '0'): 1.0,      # query
    ('smb', '0x72'): 1.0,   # Negotiate protocol
    ('ldap', '0'): 1.0,     # bind
    ('ldap', '3'): 1.0,     # searchRequest
    ('ldap', '2'): 1.0,     # unbindRequest
    ('cldap', '3'): 1.0,
    ('dcerpc', '11'): 1.0,  # bind
    ('dcerpc', '14'): 1.0,  # Alter_context
    ('nbns', '0'): 1.0,     # query
}

# (protocol, opcode) pairs that suggest the sender is a server;
# Packet.client_score() returns the negated value for them.
SERVER_CLUES = {
    ('dns', '1'): 1.0,      # response
    ('ldap', '1'): 1.0,     # bind response
    ('ldap', '4'): 1.0,     # search result
    ('ldap', '5'): 1.0,     # search done
    ('cldap', '5'): 1.0,
    ('dcerpc', '12'): 1.0,  # bind_ack
    ('dcerpc', '13'): 1.0,  # bind_nak
    ('dcerpc', '15'): 1.0,  # Alter_context response
}

# Protocols whose packets is_a_real_packet() always rejects.
SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}

WAIT_SCALE = 10.0
WAIT_THRESHOLD = 1.0 / WAIT_SCALE
NO_WAIT_LOG_TIME_RANGE = (-10, -3)

# DEBUG_LEVEL can be changed by scripts with -d
DEBUG_LEVEL = 0

LOGGER = get_samba_logger(name=__name__)
|
|
|
|
|
|
def debug(level, msg, *args):
    """Print a formatted debug message to standard error.

    :param level: The debug level; the message is printed only if it is
                  <= the module's DEBUG_LEVEL (which scripts can change
                  with the -d option).
    :param msg:   The message to be logged, can contain C-style format
                  specifiers.
    :param args:  The parameters required by the format specifiers. When
                  no args are given, msg is printed verbatim (no
                  %-interpolation is attempted).
    """
    if level > DEBUG_LEVEL:
        return

    text = msg % tuple(args) if args else msg
    print(text, file=sys.stderr)
|
|
|
|
|
|
def debug_lineno(*args):
    """Print an unformatted log message to stderr, prefixed with the
    caller's function name and line number.

    Each argument is printed on its own line, followed by a blank line.
    """
    # With limit=2 the stack summary holds [caller, this function]; the
    # lookup must stay in this frame so that entry 0 is our caller.
    caller = traceback.extract_stack(limit=2)[0]
    header = " %s:\033[01;33m%s \033[00m" % (caller[2], caller[1])
    print(header, end=' ', file=sys.stderr)

    for item in args:
        print(item, file=sys.stderr)

    print(file=sys.stderr)
    sys.stderr.flush()
|
|
|
|
|
|
def random_colour_print(seeds):
    """Return a function that prints its arguments to stderr, one per line.

    If seeds is non-empty, each line is wrapped in a 256-colour escape
    sequence whose colour is derived from a sort of hash of the seed
    values. With no seeds the lines are printed uncoloured.

    In both cases the returned function prints only while DEBUG_LEVEL > 0.
    """
    if not seeds:
        def plain(*args):
            if DEBUG_LEVEL > 0:
                for item in args:
                    print(item, file=sys.stderr)

        return plain

    # Fold the seeds into a number in [0, 214); 18 is added below to pick
    # the colour index.
    mix = 214
    for seed in seeds:
        mix = ((mix + 17) * seed) % 214

    prefix = "\033[38;5;%dm" % (18 + mix)

    def coloured(*args):
        if DEBUG_LEVEL > 0:
            for item in args:
                print("%s%s\033[00m" % (prefix, item), file=sys.stderr)

    return coloured
|
|
|
|
|
|
class FakePacketError(Exception):
    """Raised when a packet does not fit the conversation it is added to
    (Conversation.add_packet raises it on an endpoint mismatch)."""
|
|
|
|
|
|
class Packet(object):
    """Details of a network packet.

    Attributes (all set by __init__):
      timestamp      time of the packet (a float when built by from_line)
      ip_protocol    IP protocol field, kept as given
      stream_number  stream identifier, kept as given
      src, dest      endpoint identifiers (ints when built by from_line)
      protocol       protocol name, e.g. 'ldap'
      opcode         protocol opcode, kept as a string
      desc           human readable description
      extra          list of any additional fields
      endpoints      (src, dest) ordered so the smaller value comes first
    """
    __slots__ = ('timestamp',
                 'ip_protocol',
                 'stream_number',
                 'src',
                 'dest',
                 'protocol',
                 'opcode',
                 'desc',
                 'extra',
                 'endpoints')

    def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
                 protocol, opcode, desc, extra):
        self.timestamp = timestamp
        self.ip_protocol = ip_protocol
        self.stream_number = stream_number
        self.src = src
        self.dest = dest
        self.protocol = protocol
        self.opcode = opcode
        self.desc = desc
        self.extra = extra
        # Direction-independent key: the same pair whichever way the
        # packet travels.
        if self.src < self.dest:
            self.endpoints = (self.src, self.dest)
        else:
            self.endpoints = (self.dest, self.src)

    @classmethod
    def from_line(cls, line):
        """Build a Packet from one tab-separated summary line.

        The first eight fields are timestamp, ip_protocol, stream_number,
        src, dest, protocol, opcode and desc; anything after that is kept
        as the 'extra' list. Raises ValueError if there are fewer than
        eight fields or the numeric fields do not parse.
        """
        fields = line.rstrip('\n').split('\t')
        (timestamp,
         ip_protocol,
         stream_number,
         src,
         dest,
         protocol,
         opcode,
         desc) = fields[:8]
        extra = fields[8:]

        timestamp = float(timestamp)
        src = int(src)
        dest = int(dest)

        return cls(timestamp, ip_protocol, stream_number, src, dest,
                   protocol, opcode, desc, extra)

    def as_summary(self, time_offset=0.0):
        """Format the packet as a traffic_summary line.

        :param time_offset: added to the packet timestamp
        :return: a (time, line) tuple, where time is the offset timestamp
        """
        extra = '\t'.join(self.extra)
        t = self.timestamp + time_offset
        return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
                (t,
                 self.ip_protocol,
                 self.stream_number or '',
                 self.src,
                 self.dest,
                 self.protocol,
                 self.opcode,
                 self.desc,
                 extra))

    def __str__(self):
        return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
                (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
                 self.stream_number, self.protocol, self.opcode, self.desc,
                 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))

    def __repr__(self):
        return "<Packet @%s>" % self

    def copy(self):
        """Return a new Packet with the same field values.

        The 'extra' list is shared with the original, not duplicated.
        """
        return self.__class__(self.timestamp,
                              self.ip_protocol,
                              self.stream_number,
                              self.src,
                              self.dest,
                              self.protocol,
                              self.opcode,
                              self.desc,
                              self.extra)

    def as_packet_type(self):
        """Return the 'protocol:opcode' string for this packet."""
        t = '%s:%s' % (self.protocol, self.opcode)
        return t

    def client_score(self):
        """A positive number means we think it is a client; a negative number
        means we think it is a server. Zero means no idea. range: -1 to 1.
        """
        key = (self.protocol, self.opcode)
        if key in CLIENT_CLUES:
            return CLIENT_CLUES[key]
        if key in SERVER_CLUES:
            return -SERVER_CLUES[key]
        return 0.0

    def play(self, conversation, context):
        """Send the packet over the network, if required.

        Some packets are ignored, i.e. for protocols not handled,
        server response messages, or messages that are generated by the
        protocol layer associated with other packets.

        The handler is looked up in traffic_packets by the name
        packet_<protocol>_<opcode>. A timing line is printed to stdout
        when the handler returns a true value or raises.
        """
        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)

        except AttributeError:
            print("Conversation(%s) Missing handler %s" %
                  (conversation.conversation_id, fn_name),
                  file=sys.stderr)
            return

        # Don't display a message for kerberos packets, they're not directly
        # generated they're used to indicate kerberos should be used
        if self.protocol != "kerberos":
            debug(2, "Conversation(%s) Calling handler %s" %
                  (conversation.conversation_id, fn_name))

        start = time.time()
        try:
            if fn(self, conversation, context):
                # Only collect timing data for functions that generate
                # network traffic, or fail
                end = time.time()
                duration = end - start
                print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
                      (end, conversation.conversation_id, self.protocol,
                       self.opcode, duration))
        except Exception as e:
            # A failing handler must not abort the replay; record the
            # failure (with the exception text) in the timing output.
            end = time.time()
            duration = end - start
            print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
                  (end, conversation.conversation_id, self.protocol,
                   self.opcode, duration, e))

    def __cmp__(self, other):
        return self.timestamp - other.timestamp

    def __lt__(self, other):
        # Python 3 never calls __cmp__, so provide the one rich comparison
        # that sorted(), list.sort(), bisect and heapq rely on.
        if not isinstance(other, Packet):
            return NotImplemented
        return self.timestamp < other.timestamp

    def is_really_a_packet(self, missing_packet_stats=None):
        """Return is_a_real_packet() for this packet's protocol and opcode.

        missing_packet_stats is accepted but not used.
        """
        return is_a_real_packet(self.protocol, self.opcode)
|
|
|
|
|
|
def is_a_real_packet(protocol, opcode):
    """Return True if the packet should be kept for replay.

    False means the packet can be ignored: removing it will have no
    effect on the replay. That is the case when

      * the protocol is in SKIPPED_PROTOCOLS,
      * it is an ldap packet with an empty opcode (a continuation),
      * traffic_packets has no packet_<protocol>_<opcode> handler, or
      * the handler is traffic_packets.null_packet.
    """
    if protocol in SKIPPED_PROTOCOLS:
        # Ignore any packets for the protocols we're not interested in.
        return False
    if protocol == "ldap" and opcode == '':
        # skip ldap continuation packets
        return False

    fn_name = 'packet_%s_%s' % (protocol, opcode)
    fn = getattr(traffic_packets, fn_name, None)
    if fn is None:
        # Logger.debug() takes no 'file' keyword; passing one raises
        # TypeError as soon as DEBUG logging is enabled.
        LOGGER.debug("missing packet %s", fn_name)
        return False
    if fn is traffic_packets.null_packet:
        return False
    return True
|
|
|
|
|
|
def is_a_traffic_generating_packet(protocol, opcode):
    """Return true if a packet generates traffic in its own right.

    'wait' packets and the (protocol, opcode) pairs listed below are
    rejected outright: some of them will generate traffic in certain
    contexts (e.g. ldap unbind after a bind) but not if the conversation
    consists only of these packets. Everything else is decided by
    is_a_real_packet().
    """
    passive_pairs = (
        ('kerberos', ''),
        ('ldap', '2'),
        ('dcerpc', '15'),
        ('dcerpc', '16'),
    )
    if protocol == 'wait' or (protocol, opcode) in passive_pairs:
        return False

    return is_a_real_packet(protocol, opcode)
|
|
|
|
|
|
class ReplayContext(object):
    """State/Context for a conversation between an simulated client and a
       server. Some of the context is shared amongst all conversations
       and should be generated before the fork, while other context is
       specific to a particular conversation and should be generated
       *after* the fork, in generate_process_local_config().
    """
    def __init__(self,
                 server=None,
                 lp=None,
                 creds=None,
                 total_conversations=None,
                 badpassword_frequency=None,
                 prefer_kerberos=None,
                 tempdir=None,
                 statsdir=None,
                 ou=None,
                 base_dn=None,
                 domain=os.environ.get("DOMAIN"),
                 domain_sid=None,
                 instance_id=None):
        """Store the shared settings and build the LDAP search tables.

        Note that construction connects to the server over LDAP (see
        generate_ldap_search_tables) and reads the realm from lp, so lp
        must be a usable loadparm object despite its None default.
        The domain default is read from the DOMAIN environment variable
        once, when the class is defined.
        """
        self.server = server
        self.netlogon_connection = None
        self.creds = creds
        self.lp = lp
        if prefer_kerberos:
            self.kerberos_state = MUST_USE_KERBEROS
        else:
            self.kerberos_state = DONT_USE_KERBEROS
        self.ou = ou
        self.base_dn = base_dn
        self.domain = domain
        self.statsdir = statsdir
        self.global_tempdir = tempdir
        self.domain_sid = domain_sid
        self.realm = lp.get('realm')
        self.instance_id = instance_id

        # Bad password attempt controls
        self.badpassword_frequency = badpassword_frequency
        self.last_lsarpc_bad = False
        self.last_lsarpc_named_bad = False
        self.last_simple_bind_bad = False
        self.last_bind_bad = False
        self.last_srvsvc_bad = False
        self.last_drsuapi_bad = False
        self.last_netlogon_bad = False
        self.last_samlogon_bad = False
        self.total_conversations = total_conversations
        self.generate_ldap_search_tables()

    def generate_ldap_search_tables(self):
        """Query the server and populate the tables used to fake searches.

        Sets self.dn_map (DN pattern -> list of DNs),
        self.attribute_clue_map (attribute name -> list of DNs) and
        self.search_filters (attribute name -> LDAP filter string).
        """
        session = system_session()

        db = SamDB(url="ldap://%s" % self.server,
                   session_info=session,
                   credentials=self.creds,
                   lp=self.lp)

        res = db.search(db.domain_dn(),
                        scope=ldb.SCOPE_SUBTREE,
                        controls=["paged_results:1:1000"],
                        attrs=['dn'])

        # find a list of dns for each pattern
        # e.g. CN,CN,CN,DC,DC
        dn_map = {}
        attribute_clue_map = {
            'invocationId': []
        }

        for r in res:
            dn = str(r.dn)
            # The pattern is the first two characters of each DN
            # component, upper-cased.
            pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
            dns = dn_map.setdefault(pattern, [])
            dns.append(dn)
            if dn.startswith('CN=NTDS Settings,'):
                attribute_clue_map['invocationId'].append(dn)

        # extend the map in case we are working with a different
        # number of DC components.
        for k in list(dn_map.keys()):
            if k[-3:] != ',DC':
                continue
            # Strip every trailing ',DC', then re-add one to five of them,
            # aliasing each variant to the original list.
            p = k[:-3]
            while p[-3:] == ',DC':
                p = p[:-3]
            for _ in range(5):
                p += ',DC'
                if p != k and p in dn_map:
                    print('dn_map collision %s %s' % (k, p),
                          file=sys.stderr)
                    continue
                dn_map[p] = dn_map[k]

        self.dn_map = dn_map
        self.attribute_clue_map = attribute_clue_map

        # pre-populate DN-based search filters (it's simplest to generate them
        # once, when the test starts). These are used by guess_search_filter()
        # to avoid full-scans
        self.search_filters = {}

        # lookup all the GPO DNs
        res = db.search(db.domain_dn(), scope=ldb.SCOPE_SUBTREE, attrs=['dn'],
                        expression='(objectclass=groupPolicyContainer)')
        gpos_by_dn = "".join("(distinguishedName={0})".format(msg['dn'])
                             for msg in res)

        # a search for the 'gPCFileSysPath' attribute is probably a GPO search
        # (as per the MS-GPOL spec) which searches for GPOs by DN
        self.search_filters['gPCFileSysPath'] = "(|{0})".format(gpos_by_dn)

        # likewise, a search for gpLink is probably the Domain SOM search part
        # of the MS-GPOL, in which case it's looking up a few OUs by DN
        ou_str = ""
        for ou in ["Domain Controllers,", "traffic_replay,", ""]:
            ou_str += "(distinguishedName={0}{1})".format(ou, db.domain_dn())
        self.search_filters['gpLink'] = "(|{0})".format(ou_str)

        # The CEP Web Service can query the AD DC to get pKICertificateTemplate
        # objects (as per MS-WCCE)
        self.search_filters['pKIExtendedKeyUsage'] = \
            '(objectCategory=pKICertificateTemplate)'

        # assume that anything querying the usnChanged is some kind of
        # synchronization tool, e.g. AD Change Detection Connector
        res = db.search('', scope=ldb.SCOPE_BASE,
                        attrs=['highestCommittedUSN'])
        self.search_filters['usnChanged'] = \
            '(usnChanged>={0})'.format(res[0]['highestCommittedUSN'])

    # The traffic_learner script doesn't preserve the LDAP search filter, and
    # having no filter can result in a full DB scan. This is costly for a large
    # DB, and not necessarily representative of real world traffic. As there
    # several standard LDAP queries that get used by AD tools, we can apply
    # some logic and guess what the search filter might have been originally.
    def guess_search_filter(self, attrs, dn_sig, dn):
        """Return a plausible LDAP filter string for a replayed search.

        :param attrs: the attributes the search asked for
        :param dn_sig: the DN pattern of the search base, e.g. 'DC,DC'
        :param dn: the search base DN (not used by the guess)
        """
        # there are some standard spec-based searches that query fairly unique
        # attributes. Check if the search is likely one of these
        for key, search_filter in self.search_filters.items():
            if key in attrs:
                return search_filter

        # if it's the top-level domain, assume we're looking up a single user,
        # e.g. like powershell Get-ADUser or a similar tool
        if dn_sig == 'DC,DC':
            # random.random() is a float in [0, 1), so taking it modulo
            # total_conversations changed nothing; pick a real integer
            # index in [0, total_conversations) instead.
            # NOTE(review): presumably these ids index the generated user
            # accounts -- confirm the range against the account generator.
            random_user_id = random.randrange(self.total_conversations)
            account_name = user_name(self.instance_id, random_user_id)
            return '(&(sAMAccountName=%s)(objectClass=user))' % account_name

        # otherwise just return everything in the sub-tree
        return '(objectClass=*)'

    def generate_process_local_config(self, account, conversation):
        """Set up the per-conversation state (connection lists, account
        details, a private temp directory and credentials).

        Also points the private, lock and state directories of self.lp at
        the new temp directory and disables TLS peer verification there.
        """
        self.ldap_connections = []
        self.dcerpc_connections = []
        self.lsarpc_connections = []
        self.lsarpc_connections_named = []
        self.drsuapi_connections = []
        self.srvsvc_connections = []
        self.samr_contexts = []
        self.netbios_name = account.netbios_name
        self.machinepass = account.machinepass
        self.username = account.username
        self.userpass = account.userpass

        self.tempdir = mk_masked_dir(self.global_tempdir,
                                     'conversation-%d' %
                                     conversation.conversation_id)

        self.lp.set("private dir", self.tempdir)
        self.lp.set("lock dir", self.tempdir)
        self.lp.set("state directory", self.tempdir)
        self.lp.set("tls verify peer", "no_check")

        self.remoteAddress = "/root/ncalrpc_as_system"
        self.samlogon_dn = ("cn=%s,%s" %
                            (self.netbios_name, self.ou))
        self.user_dn = ("cn=%s,%s" %
                        (self.username, self.ou))

        self.generate_machine_creds()
        self.generate_user_creds()

    def with_random_bad_credentials(self, f, good, bad, failed_last_time):
        """Execute the supplied logon function, randomly choosing the
           bad credentials.

           Based on the frequency in badpassword_frequency randomly perform the
           function with the supplied bad credentials.
           If run with bad credentials, the function is re-run with the good
           credentials.
           failed_last_time is used to prevent consecutive bad credential
           attempts. So the over all bad credential frequency will be lower
           than that requested, but not significantly.

           :return: (result of f(good), whether a bad attempt was made)
        """
        # A bad attempt is only considered when the previous call did not
        # make one; the returned flag is True exactly when this call did,
        # so an attempt is suppressed for one call and no longer.
        if (not failed_last_time and
            self.badpassword_frequency and
            random.random() < self.badpassword_frequency):
            try:
                f(bad)
            except Exception:
                # Ignore any exceptions as the operation may fail
                # as it's being performed with bad credentials
                pass
            failed_last_time = True
        else:
            failed_last_time = False

        result = f(good)
        return (result, failed_last_time)

    def generate_user_creds(self):
        """Generate the conversation specific user Credentials.

        Each Conversation has an associated user account used to simulate
        any non Administrative user traffic.

        Generates user credentials with good and bad passwords and ldap
        simple bind credentials with good and bad passwords.
        The bad password is the good one with its last four characters
        removed.
        """
        self.user_creds = Credentials()
        self.user_creds.guess(self.lp)
        self.user_creds.set_workstation(self.netbios_name)
        self.user_creds.set_password(self.userpass)
        self.user_creds.set_username(self.username)
        self.user_creds.set_domain(self.domain)
        self.user_creds.set_kerberos_state(self.kerberos_state)

        self.user_creds_bad = Credentials()
        self.user_creds_bad.guess(self.lp)
        self.user_creds_bad.set_workstation(self.netbios_name)
        self.user_creds_bad.set_password(self.userpass[:-4])
        self.user_creds_bad.set_username(self.username)
        self.user_creds_bad.set_kerberos_state(self.kerberos_state)

        # Credentials for ldap simple bind.
        self.simple_bind_creds = Credentials()
        self.simple_bind_creds.guess(self.lp)
        self.simple_bind_creds.set_workstation(self.netbios_name)
        self.simple_bind_creds.set_password(self.userpass)
        self.simple_bind_creds.set_username(self.username)
        self.simple_bind_creds.set_gensec_features(
            self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
        self.simple_bind_creds.set_kerberos_state(self.kerberos_state)
        self.simple_bind_creds.set_bind_dn(self.user_dn)

        self.simple_bind_creds_bad = Credentials()
        self.simple_bind_creds_bad.guess(self.lp)
        self.simple_bind_creds_bad.set_workstation(self.netbios_name)
        self.simple_bind_creds_bad.set_password(self.userpass[:-4])
        self.simple_bind_creds_bad.set_username(self.username)
        self.simple_bind_creds_bad.set_gensec_features(
            self.simple_bind_creds_bad.get_gensec_features() |
            gensec.FEATURE_SEAL)
        self.simple_bind_creds_bad.set_kerberos_state(self.kerberos_state)
        self.simple_bind_creds_bad.set_bind_dn(self.user_dn)

    def generate_machine_creds(self):
        """Generate the conversation specific machine Credentials.

        Each Conversation has an associated machine account.

        Generates machine credentials with good and bad passwords.
        The bad password is the good one with its last four characters
        removed.
        """

        self.machine_creds = Credentials()
        self.machine_creds.guess(self.lp)
        self.machine_creds.set_workstation(self.netbios_name)
        self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds.set_password(self.machinepass)
        self.machine_creds.set_username(self.netbios_name + "$")
        self.machine_creds.set_domain(self.domain)
        self.machine_creds.set_kerberos_state(self.kerberos_state)

        self.machine_creds_bad = Credentials()
        self.machine_creds_bad.guess(self.lp)
        self.machine_creds_bad.set_workstation(self.netbios_name)
        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds_bad.set_password(self.machinepass[:-4])
        self.machine_creds_bad.set_username(self.netbios_name + "$")
        self.machine_creds_bad.set_kerberos_state(self.kerberos_state)

    def get_matching_dn(self, pattern, attributes=None):
        """Return a DN on the server that matches the given pattern.

        If attributes has a non-empty entry in attribute_clue_map, a
        random DN from that entry is returned. Otherwise the upper-cased
        pattern is looked up in dn_map; on a miss, the leading component
        (three characters, e.g. 'CN,') is chopped off and the lookup is
        retried. dn_map already holds variants with different numbers of
        DC suffixes (see generate_ldap_search_tables).
        e.g if there is no CN,CN,CN,CN,DC,DC we change to CN,CN,CN,DC,DC
        and as last resort (including for an empty pattern) we use the
        base_dn.
        """
        attr_clue = self.attribute_clue_map.get(attributes)
        if attr_clue:
            return random.choice(attr_clue)

        pattern = pattern.upper()
        while pattern:
            if pattern in self.dn_map:
                return random.choice(self.dn_map[pattern])
            # chop one off the front and try it all again.
            pattern = pattern[3:]

        return self.base_dn

    def get_dcerpc_connection(self, new=False):
        """Return the latest dcerpc connection, creating one if there is
        none yet or if new is true."""
        guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
        if self.dcerpc_connections and not new:
            return self.dcerpc_connections[-1]
        c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
                             (guid, 1), self.lp)
        self.dcerpc_connections.append(c)
        return c

    def get_srvsvc_connection(self, new=False):
        """Return the latest srvsvc connection, creating one (with the
        user credentials) if there is none yet or if new is true."""
        if self.srvsvc_connections and not new:
            return self.srvsvc_connections[-1]

        def connect(creds):
            return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
                                 self.lp,
                                 creds)

        (c, self.last_srvsvc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_srvsvc_bad)

        self.srvsvc_connections.append(c)
        return c

    def get_lsarpc_connection(self, new=False):
        """Return the latest lsarpc (tcp, schannel) connection, creating
        one with the machine credentials if there is none yet or if new
        is true."""
        if self.lsarpc_connections and not new:
            return self.lsarpc_connections[-1]

        def connect(creds):
            binding_options = 'schannel,seal,sign'
            return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
                              (self.server, binding_options),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_bad)

        self.lsarpc_connections.append(c)
        return c

    def get_lsarpc_named_pipe_connection(self, new=False):
        """Return the latest lsarpc named-pipe connection, creating one
        with the machine credentials if there is none yet or if new is
        true."""
        if self.lsarpc_connections_named and not new:
            return self.lsarpc_connections_named[-1]

        def connect(creds):
            return lsa.lsarpc("ncacn_np:%s" % (self.server),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_named_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_named_bad)

        self.lsarpc_connections_named.append(c)
        return c

    def get_drsuapi_connection_pair(self, new=False, unbind=False):
        """get a (drs, drs_handle) tuple

        The unbind parameter is accepted but not used here.
        """
        if self.drsuapi_connections and not new:
            c = self.drsuapi_connections[-1]
            return c

        def connect(creds):
            binding_options = 'seal'
            binding_string = "ncacn_ip_tcp:%s[%s]" %\
                             (self.server, binding_options)
            return drsuapi.drsuapi(binding_string, self.lp, creds)

        (drs, self.last_drsuapi_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_drsuapi_bad)

        # The second item returned by drs_DsBind is not needed here.
        (drs_handle, _) = drs_DsBind(drs)
        c = (drs, drs_handle)
        self.drsuapi_connections.append(c)
        return c

    def get_ldap_connection(self, new=False, simple=False):
        """Return the latest LDAP connection, creating one if there is
        none yet or if new is true.

        :param simple: connect over ldaps:// with the simple bind
                       credentials rather than over ldap:// with the
                       user credentials
        """
        if self.ldap_connections and not new:
            return self.ldap_connections[-1]

        def simple_bind(creds):
            """
            To run simple bind against Windows, we need to run
            following commands in PowerShell:

                Install-windowsfeature ADCS-Cert-Authority
                Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
                Restart-Computer

            """
            return SamDB('ldaps://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)

        def sasl_bind(creds):
            return SamDB('ldap://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)
        if simple:
            (samdb, self.last_simple_bind_bad) = \
                self.with_random_bad_credentials(simple_bind,
                                                 self.simple_bind_creds,
                                                 self.simple_bind_creds_bad,
                                                 self.last_simple_bind_bad)
        else:
            (samdb, self.last_bind_bad) = \
                self.with_random_bad_credentials(sasl_bind,
                                                 self.user_creds,
                                                 self.user_creds_bad,
                                                 self.last_bind_bad)

        self.ldap_connections.append(samdb)
        return samdb

    def get_samr_context(self, new=False):
        """Return the latest SamrContext, creating one if there is none
        yet or if new is true."""
        if not self.samr_contexts or new:
            self.samr_contexts.append(
                SamrContext(self.server, lp=self.lp, creds=self.creds))
        return self.samr_contexts[-1]

    def get_netlogon_connection(self):
        """Return the netlogon connection, creating it with the machine
        credentials on first use."""

        if self.netlogon_connection:
            return self.netlogon_connection

        def connect(creds):
            return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
                                     (self.server),
                                     self.lp,
                                     creds)
        (c, self.last_netlogon_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_netlogon_bad)
        self.netlogon_connection = c
        return c

    def guess_a_dns_lookup(self):
        """Return a (name, record type) pair to look up: the realm's A
        record."""
        return (self.realm, 'A')

    def get_authenticator(self):
        """Return a (current, subsequent) pair of netr_Authenticator
        objects; current is filled in from the machine credentials and
        subsequent is left freshly constructed."""
        auth = self.machine_creds.new_client_authenticator()
        current = netr_Authenticator()
        current.cred.data = list(auth["credential"])
        current.timestamp = auth["timestamp"]

        subsequent = netr_Authenticator()
        return (current, subsequent)

    def write_stats(self, filename, **kwargs):
        """Write arbitrary key/value pairs to a file in our stats directory in
        order for them to be picked up later by another process working out
        statistics.

        Each pair is written on its own line as 'key: value'.
        """
        filename = os.path.join(self.statsdir, filename)
        # 'with' closes the file even if formatting or writing raises.
        with open(filename, 'w') as f:
            for k, v in kwargs.items():
                print("%s: %s" % (k, v), file=f)
|
|
|
|
|
|
class SamrContext(object):
    """State/Context associated with a samr connection.

    The connection and the connect handle are created lazily by
    get_connection() and get_handle(); the remaining handle attributes
    start as None and are not set by this class.
    """
    def __init__(self, server, lp=None, creds=None):
        # How to reach the server.
        self.server = server
        self.lp = lp
        self.creds = creds

        # Filled in on demand.
        self.connection = None
        self.handle = None
        self.domain_handle = None
        self.domain_sid = None
        self.group_handle = None
        self.user_handle = None
        self.rids = None

    def get_connection(self):
        """Return the samr connection, opening it on first use."""
        if self.connection:
            return self.connection

        binding = "ncacn_ip_tcp:%s[seal]" % (self.server)
        self.connection = samr.samr(binding,
                                    lp_ctx=self.lp,
                                    credentials=self.creds)
        return self.connection

    def get_handle(self):
        """Return the handle from Connect2, obtaining it on first use."""
        if self.handle:
            return self.handle

        conn = self.get_connection()
        self.handle = conn.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
        return self.handle
|
|
|
|
|
|
class Conversation(object):
|
|
"""Details of a converation between a simulated client and a server."""
|
|
def __init__(self, start_time=None, endpoints=None, seq=(),
|
|
conversation_id=None):
|
|
self.start_time = start_time
|
|
self.endpoints = endpoints
|
|
self.packets = []
|
|
self.msg = random_colour_print(endpoints)
|
|
self.client_balance = 0.0
|
|
self.conversation_id = conversation_id
|
|
for p in seq:
|
|
self.add_short_packet(*p)
|
|
|
|
def __cmp__(self, other):
|
|
if self.start_time is None:
|
|
if other.start_time is None:
|
|
return 0
|
|
return -1
|
|
if other.start_time is None:
|
|
return 1
|
|
return self.start_time - other.start_time
|
|
|
|
def add_packet(self, packet):
|
|
"""Add a packet object to this conversation, making a local copy with
|
|
a conversation-relative timestamp."""
|
|
p = packet.copy()
|
|
|
|
if self.start_time is None:
|
|
self.start_time = p.timestamp
|
|
|
|
if self.endpoints is None:
|
|
self.endpoints = p.endpoints
|
|
|
|
if p.endpoints != self.endpoints:
|
|
raise FakePacketError("Conversation endpoints %s don't match"
|
|
"packet endpoints %s" %
|
|
(self.endpoints, p.endpoints))
|
|
|
|
p.timestamp -= self.start_time
|
|
|
|
if p.src == p.endpoints[0]:
|
|
self.client_balance -= p.client_score()
|
|
else:
|
|
self.client_balance += p.client_score()
|
|
|
|
if p.is_really_a_packet():
|
|
self.packets.append(p)
|
|
|
|
def add_short_packet(self, timestamp, protocol, opcode, extra,
|
|
client=True, skip_unused_packets=True):
|
|
"""Create a packet from a timestamp, and 'protocol:opcode' pair, and a
|
|
(possibly empty) list of extra data. If client is True, assume
|
|
this packet is from the client to the server.
|
|
"""
|
|
if skip_unused_packets and not is_a_real_packet(protocol, opcode):
|
|
return
|
|
|
|
src, dest = self.guess_client_server()
|
|
if not client:
|
|
src, dest = dest, src
|
|
key = (protocol, opcode)
|
|
desc = OP_DESCRIPTIONS.get(key, '')
|
|
ip_protocol = IP_PROTOCOLS.get(protocol, '06')
|
|
packet = Packet(timestamp - self.start_time, ip_protocol,
|
|
'', src, dest,
|
|
protocol, opcode, desc, extra)
|
|
# XXX we're assuming the timestamp is already adjusted for
|
|
# this conversation?
|
|
# XXX should we adjust client balance for guessed packets?
|
|
if packet.src == packet.endpoints[0]:
|
|
self.client_balance -= packet.client_score()
|
|
else:
|
|
self.client_balance += packet.client_score()
|
|
if packet.is_really_a_packet():
|
|
self.packets.append(packet)
|
|
|
|
def __str__(self):
|
|
return ("<Conversation %s %s starting %.3f %d packets>" %
|
|
(self.conversation_id, self.endpoints, self.start_time,
|
|
len(self.packets)))
|
|
|
|
__repr__ = __str__
|
|
|
|
def __iter__(self):
|
|
return iter(self.packets)
|
|
|
|
def __len__(self):
|
|
return len(self.packets)
|
|
|
|
def get_duration(self):
|
|
if len(self.packets) < 2:
|
|
return 0
|
|
return self.packets[-1].timestamp - self.packets[0].timestamp
|
|
|
|
def replay_as_summary_lines(self):
|
|
return [p.as_summary(self.start_time) for p in self.packets]
|
|
|
|
def replay_with_delay(self, start, context=None, account=None):
|
|
"""Replay the conversation at the right time.
|
|
(We're already in a fork)."""
|
|
# first we sleep until the first packet
|
|
t = self.start_time
|
|
now = time.time() - start
|
|
gap = t - now
|
|
sleep_time = gap - SLEEP_OVERHEAD
|
|
if sleep_time > 0:
|
|
time.sleep(sleep_time)
|
|
|
|
miss = (time.time() - start) - t
|
|
self.msg("starting %s [miss %.3f]" % (self, miss))
|
|
|
|
max_gap = 0.0
|
|
max_sleep_miss = 0.0
|
|
# packet times are relative to conversation start
|
|
p_start = time.time()
|
|
for p in self.packets:
|
|
now = time.time() - p_start
|
|
gap = now - p.timestamp
|
|
if gap > max_gap:
|
|
max_gap = gap
|
|
if gap < 0:
|
|
sleep_time = -gap - SLEEP_OVERHEAD
|
|
if sleep_time > 0:
|
|
time.sleep(sleep_time)
|
|
t = time.time() - p_start
|
|
if t - p.timestamp > max_sleep_miss:
|
|
max_sleep_miss = t - p.timestamp
|
|
|
|
p.play(self, context)
|
|
|
|
return max_gap, miss, max_sleep_miss
|
|
|
|
def guess_client_server(self, server_clue=None):
|
|
"""Have a go at deciding who is the server and who is the client.
|
|
returns (client, server)
|
|
"""
|
|
a, b = self.endpoints
|
|
|
|
if self.client_balance < 0:
|
|
return (a, b)
|
|
|
|
# in the absence of a clue, we will fall through to assuming
|
|
# the lowest number is the server (which is usually true).
|
|
|
|
if self.client_balance == 0 and server_clue == b:
|
|
return (a, b)
|
|
|
|
return (b, a)
|
|
|
|
def forget_packets_outside_window(self, s, e):
    """Prune any packets outside the time window we're interested in

    The start time becomes the timestamp of the first surviving packet,
    or None when no packet survives.

    :param s: start of the window (inclusive)
    :param e: end of the window (inclusive)
    """
    kept = []
    for packet in self.packets:
        if s <= packet.timestamp <= e:
            kept.append(packet)
    self.packets = kept

    if kept:
        self.start_time = kept[0].timestamp
    else:
        self.start_time = None
|
|
|
|
def renormalise_times(self, start_time):
    """Adjust the packet start times relative to the new start time.

    :param start_time: the amount subtracted from every packet timestamp
                       and, if it is set, from the conversation start time.
    """
    if self.start_time is not None:
        self.start_time -= start_time

    for packet in self.packets:
        packet.timestamp -= start_time
|
|
|
|
|
|
class DnsHammer(Conversation):
    """A lightweight conversation that generates a lot of dns:0 packets on
    the fly"""

    def __init__(self, dns_rate, duration, query_file=None):
        """Choose the query times up front.

        :param dns_rate: queries per unit of duration
        :param duration: queries are spread uniformly over [0, duration]
        :param query_file: optional file of query choices, see
                           _get_query_choices()
        """
        packet_count = int(dns_rate * duration)
        self.times = sorted(random.uniform(0, duration)
                            for _ in range(packet_count))
        self.rate = dns_rate
        self.duration = duration
        self.start_time = 0
        self.query_choices = self._get_query_choices(query_file=query_file)

    def __str__(self):
        summary = (len(self.times), self.duration, self.rate)
        return "<DnsHammer %d packets over %.1fs (rate %.2f)>" % summary

    def _get_query_choices(self, query_file=None):
        """
        Read dns query choices from a file, or return default

        Each non-blank line of the file that does not start with '#' must
        hold four comma-separated fields: opcode, rname, rtype, exist.

        rname may contain format string like `{realm}`
        realm can be fetched from context.realm
        """
        if not query_file:
            return [
                (0, '{realm}', 'A', 'yes'),
                (1, '{realm}', 'NS', 'yes'),
                (2, '*.{realm}', 'A', 'no'),
                (3, '*.{realm}', 'NS', 'no'),
                (10, '_msdcs.{realm}', 'A', 'yes'),
                (11, '_msdcs.{realm}', 'NS', 'yes'),
                (20, 'nx.realm.com', 'A', 'no'),
                (21, 'nx.realm.com', 'NS', 'no'),
                (22, '*.nx.realm.com', 'A', 'no'),
                (23, '*.nx.realm.com', 'NS', 'no'),
            ]

        with open(query_file, 'r') as f:
            text = f.read()

        choices = []
        for raw in text.splitlines():
            entry = raw.strip()
            if not entry or entry.startswith('#'):
                # blank line or comment
                continue
            fields = entry.split(',')
            assert len(fields) == 4
            choices.append(fields)
        return choices

    def replay(self, context=None):
        """Send one randomly chosen DNS query at each of the chosen times.

        A tab-separated result line is printed for every query, whether
        or not it succeeded.
        """
        assert context
        assert context.realm
        began = time.time()
        for due in self.times:
            nap = due - (time.time() - began) - SLEEP_OVERHEAD
            if nap > 0:
                time.sleep(nap)

            opcode, rname, rtype, exist = random.choice(self.query_choices)
            rname = rname.format(realm=context.realm)
            success = True
            packet_start = time.time()
            try:
                answers = dns_query(rname, rtype)
                if exist == 'yes' and len(answers) == 0:
                    # expect answers but didn't get, fail
                    success = False
            except Exception:
                success = False
            finally:
                end = time.time()
                elapsed = end - packet_start
                print("%f\tDNS\tdns\t%s\t%f\t%s\t" %
                      (end, opcode, elapsed, success))
|
|
|
|
|
|
def ingest_summaries(files, dns_mode='count'):
    """Load traffic summary files and generate Conversations from them.

    :param files: an iterable of file names and/or open file objects. Each
                  one is read line by line and closed afterwards.
    :param dns_mode: unless this is 'include', DNS packets are only counted
                     per opcode and are left out of the conversations.
    :return: a tuple (conversations, mean_interval, duration, dns_counts).
             conversations is a list of the non-empty Conversation objects;
             mean_interval is the number of conversations divided by the
             duration (0 when the duration is zero); duration is the time
             between the earliest and the latest packet; dns_counts maps
             DNS opcodes to the number of DNS packets that were set aside.
             The same four-element shape is returned when no packets were
             found.
    """

    dns_counts = defaultdict(int)
    packets = []
    for f in files:
        if isinstance(f, str):
            f = open(f)
        try:
            print("Ingesting %s" % (f.name,), file=sys.stderr)
            for line in f:
                p = Packet.from_line(line)
                if p.protocol == 'dns' and dns_mode != 'include':
                    dns_counts[p.opcode] += 1
                else:
                    packets.append(p)
        finally:
            # close the file even if a line fails to parse
            f.close()

    if not packets:
        # keep the same shape as the normal return value so that callers
        # can unpack the result unconditionally
        return [], 0, 0.0, dns_counts

    start_time = min(p.timestamp for p in packets)
    last_packet = max(p.timestamp for p in packets)

    print("gathering packets into conversations", file=sys.stderr)
    conversations = OrderedDict()
    for i, p in enumerate(packets):
        p.timestamp -= start_time
        c = conversations.get(p.endpoints)
        if c is None:
            c = Conversation(conversation_id=(i + 2))
            conversations[p.endpoints] = c
        c.add_packet(p)

    # We only care about conversations with actual traffic, so we
    # filter out conversations with nothing to say. We do that here,
    # rather than earlier, because those empty packets contain useful
    # hints as to which end of the conversation was the client.
    conversation_list = [c for c in conversations.values() if len(c) != 0]

    # This is obviously not correct, as many conversations will appear
    # to start roughly simultaneously at the beginning of the snapshot.
    # To which we say: oh well, so be it.
    duration = float(last_packet - start_time)
    if duration > 0:
        mean_interval = len(conversations) / duration
    else:
        # every packet has the same timestamp; avoid dividing by zero
        mean_interval = 0

    return conversation_list, mean_interval, duration, dns_counts
|
|
|
|
|
|
def guess_server_address(conversations):
    """Guess which address is the server: the most common endpoint.

    :param conversations: an iterable of objects with an `endpoints`
                          attribute holding the addresses involved.
    :return: the address that appears in the most conversations, or None
             if there are no endpoints at all.
    """
    addresses = Counter()
    for c in conversations:
        addresses.update(c.endpoints)

    if not addresses:
        return None

    # most_common() yields (address, count) pairs; only the address is
    # useful as a clue to compare against conversation endpoints.
    address, _count = addresses.most_common(1)[0]
    return address
|
|
|
|
|
|
def stringify_keys(x):
    """Return a copy of x with each key joined into one tab-separated
    string.

    The keys of x must be sequences of strings.
    """
    return {'\t'.join(key): value for key, value in x.items()}
|
|
|
|
|
|
def unstringify_keys(x):
    """Return a copy of x with each key split on tabs into a tuple of
    strings.
    """
    return {tuple(str(key).split('\t')): value for key, value in x.items()}
|
|
|
|
|
|
class TrafficModel(object):
    """An n-gram model of client packet sequences.

    The model is built up by learn() or load(), written out by save(), and
    used to make up new packet sequences by
    generate_conversation_sequences().
    """

    def __init__(self, n=3):
        """:param n: the n-gram size; keys are tuples of n - 1 states."""
        # maps a key (tuple of n - 1 states) to a list of following states
        self.ngrams = {}
        # maps a packet type to a list of "extra" field tuples seen for it
        self.query_details = {}
        self.n = n
        self.dns_opcounts = defaultdict(int)
        self.cumulative_duration = 0.0
        # [number of packets, time span they were seen over]
        self.packet_rate = [0, 1]

    def learn(self, conversations, dns_opcounts=None):
        """Add the client packets of the given conversations to the model.

        :param conversations: a sequence of Conversation objects
        :param dns_opcounts: optional mapping of DNS opcode to count, which
                             is added to the model's DNS counts.
        """
        if dns_opcounts is None:
            dns_opcounts = {}
        prev = 0.0
        cum_duration = 0.0
        key = (NON_PACKET,) * (self.n - 1)

        server = guess_server_address(conversations)

        for k, v in dns_opcounts.items():
            self.dns_opcounts[k] += v

        if len(conversations) > 1:
            first = conversations[0].start_time
            total = 0
            last = first + 0.1
            for c in conversations:
                total += len(c)
                last = max(last, c.packets[-1].timestamp)

            self.packet_rate[0] = total
            self.packet_rate[1] = last - first

        for c in conversations:
            # NOTE(review): `server` is rebound by every call, so from the
            # second conversation on the clue is the previous
            # conversation's server rather than the initial guess --
            # confirm this is intended.
            client, server = c.guess_client_server(server)
            cum_duration += c.get_duration()
            key = (NON_PACKET,) * (self.n - 1)
            for p in c:
                if p.src != client:
                    continue

                elapsed = p.timestamp - prev
                prev = p.timestamp
                if elapsed > WAIT_THRESHOLD:
                    # add the wait as an extra state
                    wait = 'wait:%d' % (math.log(max(1.0,
                                                     elapsed * WAIT_SCALE)))
                    self.ngrams.setdefault(key, []).append(wait)
                    key = key[1:] + (wait,)

                short_p = p.as_packet_type()
                self.query_details.setdefault(short_p,
                                              []).append(tuple(p.extra))
                self.ngrams.setdefault(key, []).append(short_p)
                key = key[1:] + (short_p,)

        self.cumulative_duration += cum_duration
        # add in the end
        self.ngrams.setdefault(key, []).append(NON_PACKET)

    def save(self, f):
        """Write the model as JSON.

        :param f: a file name or an open, writable file object. A file
                  opened here from a name is closed before returning; a
                  file object passed in is left open for the caller.
        """
        ngrams = {}
        for k, v in self.ngrams.items():
            k = '\t'.join(k)
            ngrams[k] = dict(Counter(v))

        query_details = {}
        for k, v in self.query_details.items():
            query_details[k] = dict(Counter('\t'.join(x) if x else '-'
                                            for x in v))

        d = {
            'ngrams': ngrams,
            'query_details': query_details,
            'cumulative_duration': self.cumulative_duration,
            'packet_rate': self.packet_rate,
            'version': CURRENT_MODEL_VERSION
        }
        d['dns'] = self.dns_opcounts

        if isinstance(f, str):
            # we opened it, so we make sure it is flushed and closed
            with open(f, 'w') as fh:
                json.dump(d, fh, indent=2)
        else:
            json.dump(d, f, indent=2)

    def load(self, f):
        """Merge a saved JSON model into this one.

        :param f: a file name or an open, readable file object. A file
                  opened here from a name is closed once it is parsed.
        :raises ValueError: if the file has no version number, or one that
                            is lower than REQUIRED_MODEL_VERSION.
        """
        if isinstance(f, str):
            with open(f) as fh:
                d = json.load(fh)
        else:
            d = json.load(f)

        try:
            version = d["version"]
            if version < REQUIRED_MODEL_VERSION:
                raise ValueError("the model file is version %d; "
                                 "version %d is required" %
                                 (version, REQUIRED_MODEL_VERSION))
        except KeyError:
            raise ValueError("the model file lacks a version number; "
                             "version %d is required" %
                             (REQUIRED_MODEL_VERSION))

        for k, v in d['ngrams'].items():
            k = tuple(str(k).split('\t'))
            values = self.ngrams.setdefault(k, [])
            for p, count in v.items():
                values.extend([str(p)] * count)
            values.sort()

        for k, v in d['query_details'].items():
            values = self.query_details.setdefault(str(k), [])
            for p, count in v.items():
                if p == '-':
                    # '-' is how save() writes an empty tuple
                    values.extend([()] * count)
                else:
                    values.extend([tuple(str(p).split('\t'))] * count)
            values.sort()

        if 'dns' in d:
            for k, v in d['dns'].items():
                self.dns_opcounts[k] += v

        self.cumulative_duration = d['cumulative_duration']
        self.packet_rate = d['packet_rate']

    def construct_conversation_sequence(self, timestamp=0.0,
                                        hard_stop=None,
                                        replay_speed=1,
                                        ignore_before=0,
                                        persistence=0):
        """Construct an individual conversation packet sequence from the
        model.

        :param timestamp: the time at which the sequence starts
        :param hard_stop: if not None, stop at the first packet that would
                          fall after this time
        :param replay_speed: waits are divided by this factor
        :param ignore_before: packets before this time are not recorded;
                              None means timestamp - 1
        :param persistence: the chance of carrying on with a random wait
                            when the model says the conversation has ended
        :return: a list of (timestamp, protocol, opcode, extra) tuples
        """
        c = []
        key = (NON_PACKET,) * (self.n - 1)
        if ignore_before is None:
            ignore_before = timestamp - 1

        while True:
            p = random.choice(self.ngrams.get(key, (NON_PACKET,)))
            if p == NON_PACKET:
                if timestamp < ignore_before:
                    break
                if random.random() > persistence:
                    print("ending after %s (persistence %.1f)" % (key, persistence),
                          file=sys.stderr)
                    break

                p = 'wait:%d' % random.randrange(5, 12)
                print("trying %s instead of end" % p, file=sys.stderr)

            if p in self.query_details:
                extra = random.choice(self.query_details[p])
            else:
                extra = []

            protocol, opcode = p.split(':', 1)
            if protocol == 'wait':
                log_wait_time = int(opcode) + random.random()
                wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed)
                timestamp += wait
            else:
                log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
                wait = math.exp(log_wait) / replay_speed
                timestamp += wait
                if hard_stop is not None and timestamp > hard_stop:
                    break
                if timestamp >= ignore_before:
                    c.append((timestamp, protocol, opcode, extra))

            key = key[1:] + (p,)
            if key[-2][:5] == 'wait:' and key[-1][:5] == 'wait:':
                # two waits in a row can only be caused by "persistence"
                # tricks, and will not result in any packets being found.
                # Instead we pretend this is a fresh start.
                key = (NON_PACKET,) * (self.n - 1)

        return c

    def scale_to_packet_rate(self, scale):
        """Convert a scale factor into packets per unit time, using the
        learnt packet rate."""
        rate_n, rate_t = self.packet_rate
        return scale * rate_n / rate_t

    def packet_rate_to_scale(self, pps):
        """Convert packets per unit time into a scale factor; the inverse
        of scale_to_packet_rate()."""
        rate_n, rate_t = self.packet_rate
        return pps * rate_t / rate_n

    def generate_conversation_sequences(self, packet_rate, duration, replay_speed=1,
                                        persistence=0):
        """Generate a list of conversation descriptions from the model.

        Sequences are generated until they hold at least
        int(packet_rate * duration) packets in total.

        :return: a list of packet sequences, sorted by start time
        """

        # We run the simulation for ten times as long as our desired
        # duration, and take the section at the end.
        lead_in = 9 * duration
        target_packets = int(packet_rate * duration)
        conversations = []
        n_packets = 0

        while n_packets < target_packets:
            start = random.uniform(-lead_in, duration)
            c = self.construct_conversation_sequence(start,
                                                     hard_stop=duration,
                                                     replay_speed=replay_speed,
                                                     ignore_before=0,
                                                     persistence=persistence)
            # will these "packets" generate actual traffic?
            # some (e.g. ldap unbind) will not generate anything
            # if the previous packets are not there, and if the
            # conversation only has those it wastes a process doing nothing.
            for timestamp, protocol, opcode, extra in c:
                if is_a_traffic_generating_packet(protocol, opcode):
                    break
            else:
                continue

            conversations.append(c)
            n_packets += len(c)

        scale = self.packet_rate_to_scale(packet_rate)
        print(("we have %d packets (target %d) in %d conversations at %.1f/s "
               "(scale %f)" % (n_packets, target_packets, len(conversations),
                               packet_rate, scale)),
              file=sys.stderr)
        conversations.sort()  # sorts by first element == start time
        return conversations
|
|
|
|
|
|
def seq_to_conversations(seq, server=1, client=2):
    """Turn each non-empty packet sequence into a Conversation.

    Each conversation starts at the timestamp of its first packet, and gets
    the next client number, counting up from `client`; empty sequences are
    skipped without using up a number.
    """
    conversations = []
    next_client = client
    for s in seq:
        if not s:
            continue
        conversations.append(Conversation(s[0][0], (server, next_client), s))
        next_client += 1
    return conversations
|
|
|
|
|
|
# Maps a protocol name to the two-character IP protocol field written
# into a summary line by expand_short_packet(), which falls back to '06'
# for protocols not listed here.
# NOTE(review): the values look like hexadecimal IP protocol numbers
# (06 = TCP, 11 = UDP) -- confirm.
IP_PROTOCOLS = {
    'dns': '11',
    'rpc_netlogon': '06',
    'kerberos': '06',      # ratio 16248:258
    'smb': '06',
    'smb2': '06',
    'ldap': '06',
    'cldap': '11',
    'lsarpc': '06',
    'samr': '06',
    'dcerpc': '06',
    'epm': '06',
    'drsuapi': '06',
    'browser': '11',
    'smb_netlogon': '11',
    'srvsvc': '06',
    'nbns': '11',
}
|
|
|
|
# Maps a (protocol, opcode) pair, both strings, to the description that
# expand_short_packet() puts in a summary line. Pairs that are not listed
# get an empty description there.
OP_DESCRIPTIONS = {
    ('browser', '0x01'): 'Host Announcement (0x01)',
    ('browser', '0x02'): 'Request Announcement (0x02)',
    ('browser', '0x08'): 'Browser Election Request (0x08)',
    ('browser', '0x09'): 'Get Backup List Request (0x09)',
    ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
    ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
    ('cldap', '3'): 'searchRequest',
    ('cldap', '5'): 'searchResDone',
    ('dcerpc', '0'): 'Request',
    ('dcerpc', '11'): 'Bind',
    ('dcerpc', '12'): 'Bind_ack',
    ('dcerpc', '13'): 'Bind_nak',
    ('dcerpc', '14'): 'Alter_context',
    ('dcerpc', '15'): 'Alter_context_resp',
    ('dcerpc', '16'): 'AUTH3',
    ('dcerpc', '2'): 'Response',
    ('dns', '0'): 'query',
    ('dns', '1'): 'response',
    ('drsuapi', '0'): 'DsBind',
    ('drsuapi', '12'): 'DsCrackNames',
    ('drsuapi', '13'): 'DsWriteAccountSpn',
    ('drsuapi', '1'): 'DsUnbind',
    ('drsuapi', '2'): 'DsReplicaSync',
    ('drsuapi', '3'): 'DsGetNCChanges',
    ('drsuapi', '4'): 'DsReplicaUpdateRefs',
    ('epm', '3'): 'Map',
    ('kerberos', ''): '',
    ('ldap', '0'): 'bindRequest',
    ('ldap', '1'): 'bindResponse',
    ('ldap', '2'): 'unbindRequest',
    ('ldap', '3'): 'searchRequest',
    ('ldap', '4'): 'searchResEntry',
    ('ldap', '5'): 'searchResDone',
    ('ldap', ''): '*** Unknown ***',
    ('lsarpc', '14'): 'lsa_LookupNames',
    ('lsarpc', '15'): 'lsa_LookupSids',
    ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
    ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
    ('lsarpc', '6'): 'lsa_OpenPolicy',
    ('lsarpc', '76'): 'lsa_LookupSids3',
    ('lsarpc', '77'): 'lsa_LookupNames4',
    ('nbns', '0'): 'query',
    ('nbns', '1'): 'response',
    ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
    ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
    ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
    ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
    ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
    ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
    ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
    ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
    ('samr', '0',): 'Connect',
    ('samr', '16'): 'GetAliasMembership',
    ('samr', '17'): 'LookupNames',
    ('samr', '18'): 'LookupRids',
    ('samr', '19'): 'OpenGroup',
    ('samr', '1'): 'Close',
    ('samr', '25'): 'QueryGroupMember',
    ('samr', '34'): 'OpenUser',
    ('samr', '36'): 'QueryUserInfo',
    ('samr', '39'): 'GetGroupsForUser',
    ('samr', '3'): 'QuerySecurity',
    ('samr', '5'): 'LookupDomain',
    ('samr', '64'): 'Connect5',
    ('samr', '6'): 'EnumDomains',
    ('samr', '7'): 'OpenDomain',
    ('samr', '8'): 'QueryDomainInfo',
    ('smb', '0x04'): 'Close (0x04)',
    ('smb', '0x24'): 'Locking AndX (0x24)',
    ('smb', '0x2e'): 'Read AndX (0x2e)',
    ('smb', '0x32'): 'Trans2 (0x32)',
    ('smb', '0x71'): 'Tree Disconnect (0x71)',
    ('smb', '0x72'): 'Negotiate Protocol (0x72)',
    ('smb', '0x73'): 'Session Setup AndX (0x73)',
    ('smb', '0x74'): 'Logoff AndX (0x74)',
    ('smb', '0x75'): 'Tree Connect AndX (0x75)',
    ('smb', '0xa2'): 'NT Create AndX (0xa2)',
    ('smb2', '0'): 'NegotiateProtocol',
    ('smb2', '11'): 'Ioctl',
    ('smb2', '14'): 'Find',
    ('smb2', '16'): 'GetInfo',
    ('smb2', '18'): 'Break',
    ('smb2', '1'): 'SessionSetup',
    ('smb2', '2'): 'SessionLogoff',
    ('smb2', '3'): 'TreeConnect',
    ('smb2', '4'): 'TreeDisconnect',
    ('smb2', '5'): 'Create',
    ('smb2', '6'): 'Close',
    ('smb2', '8'): 'Read',
    ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
    ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
                               'user unknown (0x17)'),
    ('srvsvc', '16'): 'NetShareGetInfo',
    ('srvsvc', '21'): 'NetSrvGetInfo',
}
|
|
|
|
|
|
def expand_short_packet(p, timestamp, src, dest, extra):
    """Expand a short "protocol:opcode" packet into a tab-separated
    summary line.

    The IP protocol and description fields are looked up in IP_PROTOCOLS
    and OP_DESCRIPTIONS, defaulting to '06' and '' respectively. All the
    arguments, and the items of `extra`, must already be strings.
    """
    protocol, opcode = p.split(':', 1)
    fields = [
        timestamp,
        IP_PROTOCOLS.get(protocol, '06'),
        '',
        src,
        dest,
        protocol,
        opcode,
        OP_DESCRIPTIONS.get((protocol, opcode), ''),
    ]
    fields += extra
    return '\t'.join(fields)
|
|
|
|
|
|
def flushing_signal_handler(signal, frame):
    """Signal handler closes standard out and error.

    Triggered by a sigterm, ensures that the log messages are flushed
    to disk and not lost. The process then exits immediately with
    status 0.
    """
    for stream in (sys.stderr, sys.stdout):
        stream.close()
    os._exit(0)
|
|
|
|
|
|
def replay_seq_in_fork(cs, start, context, account, client_id, server_id=1):
    """Fork a new process and replay the conversation sequence.

    :param cs: the packet sequence; cs[0][0] is the time of the first packet
    :param start: the wall-clock time that counts as time zero
    :param context: the replay context; it supplies the stats directory and
                    the per-process configuration
    :param account: the account passed to the per-process configuration
    :param client_id: used as the conversation id and the client endpoint
    :param server_id: used as the server endpoint
    :return: the child's pid, in the parent. The child never returns: it
             exits with status 0, or 1 if an exception was raised.
    """
    # We will need to reseed the random number generator or all the
    # clients will end up using the same sequence of random
    # numbers. random.randint() is mixed in so the initial seed will
    # have an effect here.
    seed = client_id * 1000 + random.randint(0, 999)

    # flush our buffers so messages won't be written by both sides
    sys.stdout.flush()
    sys.stderr.flush()
    pid = os.fork()
    if pid != 0:
        return pid

    # we must never return, or we'll end up running parts of the
    # parent's clean-up code. So we work in a try...finally, and
    # try to print any exceptions.
    #
    # These are set before the try so that the except and finally
    # clauses can rely on them whichever statement fails.
    status = 0
    c = None
    try:
        random.seed(seed)
        endpoints = (server_id, client_id)
        t = cs[0][0]
        c = Conversation(t, endpoints, seq=cs, conversation_id=client_id)
        signal.signal(signal.SIGTERM, flushing_signal_handler)

        context.generate_process_local_config(account, c)
        sys.stdin.close()
        os.close(0)
        filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
                                c.conversation_id)
        f = open(filename, 'w')
        try:
            sys.stdout.close()
            os.close(1)
        except IOError as e:
            LOGGER.info("stdout closing failed with %s" % e)

        # from here on, print() writes to the per-conversation stats file
        sys.stdout = f
        now = time.time() - start
        gap = t - now
        sleep_time = gap - SLEEP_OVERHEAD
        if sleep_time > 0:
            time.sleep(sleep_time)

        max_lag, start_lag, max_sleep_miss = c.replay_with_delay(start=start,
                                                                 context=context)
        print("Maximum lag: %f" % max_lag)
        print("Start lag: %f" % start_lag)
        print("Max sleep miss: %f" % max_sleep_miss)

    except Exception:
        status = 1
        if c is not None:
            description = c
        else:
            # the failure happened before the conversation existed
            description = "<client %d>" % client_id
        print(("EXCEPTION in child PID %d, conversation %s" %
               (os.getpid(), description)),
              file=sys.stderr)
        # the first positional parameter of print_exc() is the traceback
        # limit, so the stream has to be given by keyword
        traceback.print_exc(file=sys.stderr)
        sys.stderr.flush()
    finally:
        sys.stderr.close()
        sys.stdout.close()
        os._exit(status)
|
|
|
|
|
|
def dnshammer_in_fork(dns_rate, duration, context, query_file=None):
    """Fork a new process that runs a DnsHammer.

    The child's standard output is redirected to the 'stats-dns' file in
    the context's stats directory.

    :param dns_rate: passed to DnsHammer
    :param duration: passed to DnsHammer
    :param context: the replay context; it supplies the stats directory
                    and is handed to DnsHammer.replay()
    :param query_file: optional file of DNS query choices for DnsHammer
    :return: the child's pid, in the parent. The child never returns: it
             exits with status 0, or 1 if an exception was raised.
    """
    sys.stdout.flush()
    sys.stderr.flush()
    pid = os.fork()
    if pid != 0:
        return pid

    # From here on we are the child, which must never return into the
    # parent's code, so everything happens inside the try...finally.
    status = 0
    try:
        sys.stdin.close()
        os.close(0)

        try:
            sys.stdout.close()
            os.close(1)
        except IOError as e:
            LOGGER.warning("stdout closing failed with %s" % e)
        filename = os.path.join(context.statsdir, 'stats-dns')
        sys.stdout = open(filename, 'w')

        signal.signal(signal.SIGTERM, flushing_signal_handler)
        hammer = DnsHammer(dns_rate, duration, query_file=query_file)
        hammer.replay(context=context)
    except Exception:
        status = 1
        print(("EXCEPTION in child PID %d, the DNS hammer" % (os.getpid())),
              file=sys.stderr)
        # the first positional parameter of print_exc() is the traceback
        # limit, so the stream has to be given by keyword
        traceback.print_exc(file=sys.stderr)
    finally:
        sys.stderr.close()
        sys.stdout.close()
        os._exit(status)
|
|
|
|
|
|
def replay(conversation_seq,
           host=None,
           creds=None,
           lp=None,
           accounts=None,
           dns_rate=0,
           dns_query_file=None,
           duration=None,
           latency_timeout=1.0,
           stop_on_any_error=False,
           **kwargs):
    """Replay the conversation sequences, each in its own forked process.

    :param conversation_seq: a list of packet sequences, one per child
    :param host: the server, passed to ReplayContext
    :param creds: credentials, passed to ReplayContext
    :param lp: passed to ReplayContext
    :param accounts: a sequence with at least one account per conversation
    :param dns_rate: if non-zero, a DNS hammer is forked with this rate
    :param dns_query_file: optional file of query choices for the DNS hammer
    :param duration: how long to run; if None, the time of the last packet
                     of the last conversation plus latency_timeout
    :param latency_timeout: see duration
    :param stop_on_any_error: stop waiting as soon as a child exits with a
                              non-zero status
    :param kwargs: further keyword arguments for ReplayContext
    :raises ValueError: if there are fewer accounts than conversations
    """

    context = ReplayContext(server=host,
                            creds=creds,
                            lp=lp,
                            total_conversations=len(conversation_seq),
                            **kwargs)

    if len(accounts) < len(conversation_seq):
        raise ValueError(("we have %d accounts but %d conversations" %
                          (len(accounts), len(conversation_seq))))

    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    os.setpgrp()

    # we delay the start by a bit to allow all the forks to get up and
    # running.
    delay = len(conversation_seq) * 0.02
    start = time.time() + delay

    if duration is None:
        # end slightly after the last packet of the last conversation
        # to start. Conversations other than the last could still be
        # going, but we don't care.
        duration = conversation_seq[-1][-1][0] + latency_timeout

    print("We will start in %.1f seconds" % delay,
          file=sys.stderr)
    print("We will stop after %.1f seconds" % (duration + delay),
          file=sys.stderr)
    print("runtime %.1f seconds" % duration,
          file=sys.stderr)

    # give one second grace for packets to finish before killing begins
    end = start + duration + 1.0

    LOGGER.info("Replaying traffic for %u conversations over %d seconds"
                % (len(conversation_seq), duration))

    context.write_stats('intentions',
                        Planned_conversations=len(conversation_seq),
                        Planned_packets=sum(len(x) for x in conversation_seq))

    # maps the pid of each live child to its client id (1 for the DNS hammer)
    children = {}
    try:
        if dns_rate:
            pid = dnshammer_in_fork(dns_rate, duration, context,
                                    query_file=dns_query_file)
            children[pid] = 1

        for i, cs in enumerate(conversation_seq):
            account = accounts[i]
            client_id = i + 2
            pid = replay_seq_in_fork(cs, start, context, account, client_id)
            children[pid] = client_id

        # HERE, we are past all the forks
        t = time.time()
        print("all forks done in %.1f seconds, waiting %.1f" %
              (t - start + delay, t - start),
              file=sys.stderr)

        while time.time() < end and children:
            time.sleep(0.003)
            try:
                pid, status = os.waitpid(-1, os.WNOHANG)
            except OSError as e:
                if e.errno != ECHILD:  # no child processes
                    raise
                break
            if pid:
                c = children.pop(pid, None)
                if DEBUG_LEVEL > 0:
                    print(("process %d finished conversation %d;"
                           " %d to go" %
                           (pid, c, len(children))), file=sys.stderr)
                if stop_on_any_error and status != 0:
                    break

    except Exception:
        print("EXCEPTION in parent", file=sys.stderr)
        traceback.print_exc()
    finally:
        context.write_stats('unfinished',
                            Unfinished_conversations=len(children))

        # SIGTERM twice, then SIGKILL, reaping whatever exits after each
        for s in (15, 15, 9):
            print(("killing %d children with -%d" %
                   (len(children), s)), file=sys.stderr)
            for pid in children:
                try:
                    os.kill(pid, s)
                except OSError as e:
                    if e.errno != ESRCH:  # don't fail if it has already died
                        raise
            time.sleep(0.5)
            end = time.time() + 1
            no_child_processes = False
            while children:
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != ECHILD:
                        raise
                    # There is nothing left to reap. Stop here rather
                    # than carry on with a pid left over from earlier.
                    no_child_processes = True
                    break
                if pid != 0:
                    c = children.pop(pid, None)
                    if c is None:
                        print("children is %s, no pid found" % children)
                        sys.stderr.flush()
                        sys.stdout.flush()
                        os._exit(1)
                    print(("kill -%d %d KILLED conversation; "
                           "%d to go" %
                           (s, pid, len(children))),
                          file=sys.stderr)
                if time.time() >= end:
                    break

            if not children or no_child_processes:
                break
            time.sleep(1)

        if children:
            print("%d children are missing" % len(children),
                  file=sys.stderr)

        # there may be stragglers that were forked just as ^C was hit
        # and don't appear in the list of children. We can get them
        # with killpg, but that will also kill us, so this is^H^H would be
        # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
        # so as not to have to fuss around writing signal handlers.
        try:
            os.killpg(0, 2)
        except KeyboardInterrupt:
            print("ignoring fake ^C", file=sys.stderr)
|
|
|
|
|
|
def openLdb(host, creds, lp):
    """Open a SamDB connection to the host over ldap://, using a system
    session, the given credentials and the paged_searches module."""
    session = system_session()
    url = "ldap://%s" % host
    return SamDB(url=url,
                 session_info=session,
                 options=['modules:paged_searches'],
                 credentials=creds,
                 lp=lp)
|
|
|
|
|
|
def ou_name(ldb, instance_id):
    """Generate an ou name from the instance id

    The ou sits under ou=traffic_replay in the database's domain DN.
    """
    domain_dn = ldb.domain_dn()
    return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id, domain_dn)
|
|
|
|
|
|
def create_ou(ldb, instance_id):
    """Create an ou, all created user and machine accounts will belong to it.

    This allows all the created resources to be cleaned up easily.

    The parent ou is added first, then the instance ou; an LdbError with
    status 68 is ignored for either, anything else is re-raised.

    :return: the DN of the instance ou
    """
    ou = ou_name(ldb, instance_id)
    parent = ou.split(',', 1)[1]
    for dn in (parent, ou):
        try:
            ldb.add({"dn": dn,
                     "objectclass": "organizationalunit"})
        except LdbError as e:
            (status, _) = e.args
            # ignore already exists
            if status != 68:
                raise
    return ou
|
|
|
|
|
|
# ConversationAccounts holds the machine and user account details that
# go with a conversation.
#
# It is a named tuple in order to reduce shared memory usage.
ConversationAccounts = namedtuple(
    'ConversationAccounts',
    ['netbios_name', 'machinepass', 'username', 'userpass'])
|
|
|
|
|
|
def generate_replay_accounts(ldb, instance_id, number, password):
    """Generate a series of unique machine and user account names.

    :param ldb: not used here
    :param instance_id: passed to machine_name() and user_name()
    :param number: how many accounts to generate, numbered from 1
    :param password: used as both the machine and the user password
    :return: a list of ConversationAccounts
    """
    return [ConversationAccounts(machine_name(instance_id, i),
                                 password,
                                 user_name(instance_id, i),
                                 password)
            for i in range(1, number + 1)]
|
|
|
|
|
|
def create_machine_account(ldb, instance_id, netbios_name, machinepass,
                           traffic_account=True):
    """Create a machine account via ldap.

    The account is added as a computer object in the instance's ou.

    :param traffic_account: selects the userAccountControl flags, see below
    """
    if traffic_account:
        # we set these bits for the machine account otherwise the replayed
        # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
        flags = UF_TRUSTED_FOR_DELEGATION | UF_SERVER_TRUST_ACCOUNT
    else:
        flags = UF_WORKSTATION_TRUST_ACCOUNT

    dn = "cn=%s,%s" % (netbios_name, ou_name(ldb, instance_id))
    # the password is sent as a double-quoted UTF-16-LE string
    quoted_password = '"%s"' % get_string(machinepass)

    ldb.add({
        "dn": dn,
        "objectclass": "computer",
        "sAMAccountName": "%s$" % netbios_name,
        "userAccountControl": str(flags),
        "unicodePwd": quoted_password.encode('utf-16-le')})
|
|
|
|
|
|
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap.

    The account is added as a user object in the instance's ou, and is then
    given an extra ACE on its own object.
    """
    user_dn = "cn=%s,%s" % (username, ou_name(ldb, instance_id))
    # the password is sent as a double-quoted UTF-16-LE string
    quoted_password = '"%s"' % get_string(userpass)
    record = {
        "dn": user_dn,
        "objectclass": "user",
        "sAMAccountName": username,
        "userAccountControl": str(UF_NORMAL_ACCOUNT),
        "unicodePwd": quoted_password.encode('utf-16-le')
    }
    ldb.add(record)

    # grant user write permission to do things like write account SPN
    sd_utils.SDUtils(ldb).dacl_add_ace(user_dn, "(A;;WP;;;PS)")
|
|
|
|
|
|
def create_group(ldb, instance_id, name):
    """Create a group via ldap.

    The group is added to the instance's ou, with `name` as both its cn
    and its sAMAccountName.
    """
    group_dn = "cn=%s,%s" % (name, ou_name(ldb, instance_id))
    record = {
        "dn": group_dn,
        "objectclass": "group",
        "sAMAccountName": name,
    }
    ldb.add(record)
|
|
|
|
|
|
def user_name(instance_id, i):
    """Generate a user name based on the instance id and the index i."""
    ids = (instance_id, i)
    return "STGU-%d-%d" % ids
|
|
|
|
|
|
def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
    """Search objectclass, return attr in a set

    :param ldb: the database to search
    :param objectclass: the objectClass value to match
    :param attr: the attribute to collect from each result
    :return: a set holding str() of that attribute for every result
    """
    query = "(objectClass={})".format(objectclass)
    results = ldb.search(expression=query, attrs=[attr])
    values = set()
    for obj in results:
        values.add(str(obj[attr]))
    return values
|
|
|
|
|
|
def generate_users(ldb, instance_id, number, password):
    """Add users to the server

    Names are tried from the highest index down to 1, and those that
    already exist as a user sAMAccountName are skipped. Progress is logged
    after every 50 users created.

    :return: the number of users that were created
    """
    existing = search_objectclass(ldb, objectclass='user')
    created = 0
    for i in reversed(range(1, number + 1)):
        name = user_name(instance_id, i)
        if name in existing:
            continue
        create_user_account(ldb, instance_id, name, password)
        created += 1
        if created % 50 == 0:
            LOGGER.info("Created %u/%u users" % (created, number))

    return created
|
|
|
|
|
|
def machine_name(instance_id, i, traffic_account=True):
    """Generate a machine account name from instance id.

    Traffic accounts correspond to a given user, and use different
    userAccountControl flags to ensure packets get processed correctly
    by the DC.

    The other accounts are just computer accounts that simulate a
    semi-realistic network. These use the default computer
    userAccountControl flags, so they get a different account name so that
    we don't try to use them when generating packets.
    """
    template = "STGM-%d-%d" if traffic_account else "PC-%d-%d"
    return template % (instance_id, i)
|
|
|
|
|
|
def generate_machine_accounts(ldb, instance_id, number, password,
                              traffic_account=True):
    """Add machine accounts to the server

    Names are tried from the highest index down to 1, and those whose
    "name$" already exists as a computer sAMAccountName are skipped.
    Progress is logged after every 50 accounts created.

    :return: the number of machine accounts that were created
    """
    existing = search_objectclass(ldb, objectclass='computer')
    created = 0
    for i in reversed(range(1, number + 1)):
        name = machine_name(instance_id, i, traffic_account)
        if name + "$" in existing:
            continue
        create_machine_account(ldb, instance_id, name, password,
                               traffic_account)
        created += 1
        if created % 50 == 0:
            LOGGER.info("Created %u/%u machine accounts" % (created, number))

    return created
|
|
|
|
|
|
def group_name(instance_id, i):
    """Generate a group name from the instance id and the index i."""
    ids = (instance_id, i)
    return "STGG-%d-%d" % ids
|
|
|
|
|
|
def generate_groups(ldb, instance_id, number):
    """Create the required number of groups on the server.

    Names are tried from the highest index down to 1, and those that
    already exist as a group sAMAccountName are skipped. Progress is logged
    after every 1000 groups created.

    :return: the number of groups that were created
    """
    existing = search_objectclass(ldb, objectclass='group')
    created = 0
    for i in reversed(range(1, number + 1)):
        name = group_name(instance_id, i)
        if name in existing:
            continue
        create_group(ldb, instance_id, name)
        created += 1
        if created % 1000 == 0:
            LOGGER.info("Created %u/%u groups" % (created, number))

    return created
|
|
|
|
|
|
def clean_up_accounts(ldb, instance_id):
    """Remove the created accounts and groups from the server.

    The instance's ou is deleted with the tree_delete control. An LdbError
    with status 32 is ignored; anything else is re-raised.
    """
    target = ou_name(ldb, instance_id)
    try:
        ldb.delete(target, ["tree_delete:1"])
    except LdbError as err:
        (code, _) = err.args
        if code == 32:
            # ignore does not exist
            return
        raise
|
|
|
|
|
|
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships, max_members,
                              machine_accounts, traffic_accounts=True):
    """Generate the required users and groups, allocating the users to
    those groups.

    The instance's ou is created first, then the users, the machine
    accounts, the groups (if any are wanted) and the group memberships (if
    any are wanted). A summary of what was added is logged at the end.
    """
    memberships_added = groups_added = 0

    create_ou(ldb, instance_id)

    LOGGER.info("Generating dummy user accounts")
    users_added = generate_users(ldb, instance_id, number_of_users, password)

    LOGGER.info("Generating dummy machine accounts")
    computers_added = generate_machine_accounts(ldb, instance_id,
                                                machine_accounts, password,
                                                traffic_accounts)

    if number_of_groups > 0:
        LOGGER.info("Generating dummy groups")
        groups_added = generate_groups(ldb, instance_id, number_of_groups)

    if group_memberships > 0:
        LOGGER.info("Assigning users to groups")
        assignments = GroupAssignments(number_of_groups,
                                       groups_added,
                                       number_of_users,
                                       users_added,
                                       group_memberships,
                                       max_members)
        LOGGER.info("Adding users to groups")
        add_users_to_groups(ldb, instance_id, assignments)
        memberships_added = assignments.total()

    no_members_expected = (groups_added > 0 and
                           users_added == 0 and
                           number_of_groups != groups_added)
    if no_members_expected:
        LOGGER.warning("The added groups will contain no members")

    summary = (users_added, computers_added, groups_added, memberships_added)
    LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
                summary)
|
|
|
|
|
|
class GroupAssignments(object):
    """Plan which users are put into which groups.

    Users and groups are referred to by 1-based numbers. Memberships are
    drawn at random from two weighted distributions (one over users, one
    over groups) and stored per group, so that the caller can write each
    group's members to the DB in bulk.
    """

    def __init__(self, number_of_groups, groups_added, number_of_users,
                 users_added, group_memberships, max_members):
        """Build the distributions and generate all the assignments.

        :param number_of_groups: total number of groups
        :param groups_added: how many of those groups are newly created
            (the rest are treated as already existing)
        :param number_of_users: total number of users
        :param users_added: how many of those users are newly created
        :param group_memberships: number of memberships wanted when all
            the users are new; scaled down by users_added/number_of_users
        :param max_members: cap on the members assigned to one group; a
            false value (0 or None) means no cap
        """

        self.count = 0
        self.generate_group_distribution(number_of_groups)
        self.generate_user_distribution(number_of_users, group_memberships)
        self.max_members = max_members
        self.assignments = defaultdict(list)
        self.assign_groups(number_of_groups, groups_added, number_of_users,
                           users_added, group_memberships)

    def cumulative_distribution(self, weights):
        """Convert weights into a cumulative distribution ending at 1.0.

        Returns None when the weights sum to zero (including when the
        list is empty), as no distribution can be built from them.
        """
        # make sure the probabilities conform to a cumulative distribution
        # spread between 0.0 and 1.0. Dividing by the weighted total gives
        # each probability a proportional share of 1.0. Higher
        # probabilities get a bigger share, so are more likely to be
        # picked. We use the cumulative value, so we can use
        # random.random() as a simple index into the list
        dist = []
        total = sum(weights)
        if total == 0:
            return None

        cumulative = 0.0
        for probability in weights:
            cumulative += probability
            dist.append(cumulative / total)
        return dist

    def generate_user_distribution(self, num_users, num_memberships):
        """Probability distribution of a user belonging to a group.

        The result is stored in self.user_dist.
        """
        # Assign a weighted probability to each user. Use the Pareto
        # Distribution so that some users are in a lot of groups, and the
        # bulk of users are in only a few groups. If we're assigning a large
        # number of group memberships, use a higher shape. This means slightly
        # fewer outlying users that are in large numbers of groups. The aim is
        # to have no users belonging to more than ~500 groups.
        if num_memberships > 5000000:
            shape = 3.0
        elif num_memberships > 2000000:
            shape = 2.5
        elif num_memberships > 300000:
            shape = 2.25
        else:
            shape = 1.75

        weights = [random.paretovariate(shape) for _ in range(num_users)]

        # convert the weights to a cumulative distribution between 0.0 and 1.0
        self.user_dist = self.cumulative_distribution(weights)

    def generate_group_distribution(self, n):
        """Probability distribution of a group containing a user.

        The raw weights are kept in self.group_weights (so a group can
        later be zeroed out by cap_group_membership()) and the cumulative
        form is stored in self.group_dist.
        """

        # Assign a weighted probability to each group. Probability decreases
        # as the group-ID increases
        weights = [1 / (x**1.3) for x in range(1, n + 1)]

        # convert the weights to a cumulative distribution between 0.0 and 1.0
        self.group_weights = weights
        self.group_dist = self.cumulative_distribution(weights)

    def generate_random_membership(self):
        """Returns a randomly generated user-group membership"""

        # the list items are cumulative distribution values between 0.0 and
        # 1.0, which makes random() a handy way to index the list to get a
        # weighted random user/group. (Here the user/group returned are
        # zero-based array indexes)
        user = bisect.bisect(self.user_dist, random.random())
        group = bisect.bisect(self.group_dist, random.random())

        return user, group

    def users_in_group(self, group):
        """Return the list of user numbers assigned to the given group."""
        return self.assignments[group]

    def get_groups(self):
        """Return the group numbers that have an assignment list."""
        return self.assignments.keys()

    def cap_group_membership(self, group, max_members):
        """Prevent the group's membership from exceeding the max specified"""
        num_members = len(self.assignments[group])
        if num_members >= max_members:
            LOGGER.info("Group {0} has {1} members".format(group, num_members))

            # remove this group and then recalculate the cumulative
            # distribution, so this group is no longer selected
            self.group_weights[group - 1] = 0
            new_dist = self.cumulative_distribution(self.group_weights)
            self.group_dist = new_dist

    def add_assignment(self, user, group):
        """Record that user belongs to group, ignoring duplicates."""
        # the assignments are stored in a dictionary where key=group,
        # value=list-of-users-in-group (indexing by group-ID allows us to
        # optimize for DB membership writes)
        if user not in self.assignments[group]:
            self.assignments[group].append(user)
            self.count += 1

        # check if there's a cap on how big the groups can grow
        if self.max_members:
            self.cap_group_membership(group, self.max_members)

    def assign_groups(self, number_of_groups, groups_added,
                      number_of_users, users_added, group_memberships):
        """Allocate users to groups.

        The intention is to have a few users that belong to most groups, while
        the majority of users belong to a few groups.

        A few groups will contain most users, with the remaining only having a
        few users.

        Only memberships that involve a new user or a new group are
        recorded. The number of memberships generated is limited to what
        can actually be reached, so the loop always terminates.
        """

        if group_memberships <= 0:
            return

        # Without any users or any groups there is nothing to assign.
        # Returning here also avoids dividing by zero below and sampling
        # from a distribution that is None.
        if number_of_users <= 0 or number_of_groups <= 0:
            return

        # Calculate the number of group memberships required
        group_memberships = math.ceil(
            float(group_memberships) *
            (float(users_added) / float(number_of_users)))

        # Work out how many distinct memberships can possibly be recorded.
        # A new group can take any user, whereas an existing group can only
        # take new users, and no group grows beyond max_members. Asking for
        # more than this would make the loop below spin forever, because
        # duplicate memberships are never counted.
        new_groups = max(0, min(groups_added, number_of_groups))
        old_groups = number_of_groups - new_groups
        new_users = max(0, min(users_added, number_of_users))
        per_new_group = number_of_users
        per_old_group = new_users
        if self.max_members:
            per_new_group = min(per_new_group, self.max_members)
            per_old_group = min(per_old_group, self.max_members)
        achievable = new_groups * per_new_group + old_groups * per_old_group
        group_memberships = min(group_memberships, achievable)

        # zero-based index of the last pre-existing user/group (-1 if none)
        existing_users = number_of_users - users_added - 1
        existing_groups = number_of_groups - groups_added - 1
        while self.total() < group_memberships:
            user, group = self.generate_random_membership()

            if group > existing_groups or user > existing_users:
                # the + 1 converts the array index to the corresponding
                # group or user number
                self.add_assignment(user + 1, group + 1)

    def total(self):
        """Return the number of memberships recorded so far."""
        return self.count
|
|
|
|
|
|
def add_users_to_groups(db, instance_id, assignments):
    """Apply the planned user-to-group assignments to the DB.

    Each group's members are written via add_group_members() in batches,
    and progress is logged after every 50 batches.
    """
    # Write no more than 1K users at a time. (Minimizing the DB modifies
    # is more efficient, but writing 10K+ users to a single group becomes
    # inefficient memory-wise)
    batch_size = 1000

    total = assignments.total()
    batches_written = 0
    added = 0

    for group in assignments.get_groups():
        members = assignments.users_in_group(group)
        member_count = len(members)

        # a group without members never enters this loop
        start = 0
        while start < member_count:
            batch = members[start:start + batch_size]
            add_group_members(db, instance_id, group, batch)

            added += len(batch)
            batches_written += 1
            if batches_written % 50 == 0:
                LOGGER.info("Added %u/%u memberships" % (added, total))
            start += batch_size
|
|
|
|
def add_group_members(db, instance_id, group, users_in_group):
    """Add the given users to the specified group in a single modify.

    One "member" add-element is put on the message per user, keyed by
    "member-<user>" so the elements stay distinct, and the whole message
    is then sent with one db.modify() call.
    """
    container = ou_name(db, instance_id)
    dn_template = "cn=%s,%s"

    group_dn = dn_template % (group_name(instance_id, group), container)
    msg = ldb.Message()
    msg.dn = ldb.Dn(db, group_dn)

    for user in users_in_group:
        member_dn = dn_template % (user_name(instance_id, user), container)
        key = "member-" + str(user)
        msg[key] = ldb.MessageElement(member_dn, ldb.FLAG_MOD_ADD, "member")

    db.modify(msg)
|
|
|
|
|
|
def _update_summary_value(line, float_values, int_values):
    """Fold a "key:value" summary line into the running maxima.

    Returns True if the line named a known key and carried a number that
    could be parsed, otherwise False so the caller can report the line.
    """
    key, sep, value = line.partition(':')
    if not sep:
        return False
    try:
        if key in float_values:
            float_values[key] = max(float(value), float_values[key])
        elif key in int_values:
            int_values[key] = max(int(value), int_values[key])
        else:
            return False
    except ValueError:
        # a known key with an unparseable value is just an invalid line
        return False
    return True


def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run.

    Every file in statsdir is read line by line. A line with at least six
    tab-separated fields (time, conversation, protocol, type, duration,
    successful) is a packet record: it is counted and echoed to
    timing_file. Any other line is tried as a "key:value" summary value,
    of which the maximum seen per key is kept. Lines that are neither are
    printed to stderr and ignored.

    :param statsdir: directory containing the per-conversation stats files
    :param timing_file: writable file object that receives a header line
        followed by all the packet records, or None to skip that output
    """
    first = sys.float_info.max
    last = 0
    successful = 0
    failed = 0
    latencies = {}
    failures = Counter()
    unique_conversations = set()
    if timing_file is not None:
        tw = timing_file.write
    else:
        def tw(x):
            pass

    tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")

    float_values = {
        'Maximum lag': 0,
        'Start lag': 0,
        'Max sleep miss': 0,
    }
    int_values = {
        'Planned_conversations': 0,
        'Planned_packets': 0,
        'Unfinished_conversations': 0,
    }

    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
            for line in f:
                fields = line.rstrip('\n').split('\t')
                try:
                    # Parse every field before touching the running
                    # totals, so a malformed record is never half-counted.
                    t = float(fields[0])
                    conversation = fields[1]
                    protocol = fields[2]
                    packet_type = fields[3]
                    latency = float(fields[4])
                    succeeded = fields[5] == 'True'
                except (ValueError, IndexError):
                    if not _update_summary_value(line, float_values,
                                                 int_values):
                        # not a valid line print and ignore
                        print(line, file=sys.stderr)
                    continue

                first = min(t - latency, first)
                last = max(t, last)

                op = (protocol, packet_type)
                latencies.setdefault(op, []).append(latency)
                if succeeded:
                    successful += 1
                else:
                    failed += 1
                    failures[op] += 1

                unique_conversations.add(conversation)

                tw(line)

    duration = last - first

    def per_second(n):
        # No rate can be given for a run without a positive duration
        # (e.g. a single zero-latency packet), so report 0 for it.
        if n == 0 or duration <= 0:
            return 0
        return n / duration

    success_rate = per_second(successful)
    failure_rate = per_second(failed)

    conversations = len(unique_conversations)

    print("Total conversations: %10d" % conversations)
    print("Successful operations: %10d (%.3f per second)"
          % (successful, success_rate))
    print("Failed operations: %10d (%.3f per second)"
          % (failed, failure_rate))

    for k, v in sorted(float_values.items()):
        print("%-28s %f" % (k.replace('_', ' ') + ':', v))
    for k, v in sorted(int_values.items()):
        print("%-28s %d" % (k.replace('_', ' ') + ':', v))

    print("Protocol Op Code Description "
          " Count Failed Mean Median "
          "95% Range Max")

    # group the packet types that were seen by protocol
    ops = defaultdict(set)
    for proto, packet in latencies:
        ops[proto].add(packet)

    for protocol in sorted(ops):
        for packet_type in sorted(ops[protocol], key=opcode_key):
            op = (protocol, packet_type)
            values = sorted(latencies[op])
            count = len(values)
            op_failures = failures[op]
            mean = sum(values) / count
            median = calc_percentile(values, 0.50)
            percentile = calc_percentile(values, 0.95)
            rng = values[-1] - values[0]
            maxv = values[-1]
            desc = OP_DESCRIPTIONS.get(op, '')
            print("%-12s %4s %-35s %12d %12d %12.6f "
                  "%12.6f %12.6f %12.6f %12.6f"
                  % (protocol,
                     packet_type,
                     desc,
                     count,
                     op_failures,
                     mean,
                     median,
                     percentile,
                     rng,
                     maxv))
|
|
|
|
|
|
def opcode_key(v):
    """Sort key that makes numeric operation codes sort numerically.

    A value that int() can parse is zero-padded to at least three digits;
    a value that int() rejects with ValueError is returned unchanged.
    """
    try:
        number = int(v)
    except ValueError:
        return v
    return "%03d" % number
|
|
|
|
|
|
def calc_percentile(values, percentile):
    """Return the given percentile of a list of values.

    The list must already be sorted in ascending order. percentile is a
    fraction (0.50 for the median). When the fractional rank falls
    between two elements the result is linearly interpolated between
    them. An empty list gives 0.
    """
    if not values:
        return 0

    rank = (len(values) - 1) * percentile
    below = math.floor(rank)
    above = math.ceil(rank)
    if below != above:
        # weight each neighbour by its closeness to the fractional rank
        lower_part = values[int(below)] * (above - rank)
        upper_part = values[int(above)] * (rank - below)
        return lower_part + upper_part

    # the rank is a whole number, so it names one element exactly
    return values[int(rank)]
|
|
|
|
|
|
def mk_masked_dir(*path):
    """Create a directory accessible only by its owner; return its path.

    In a testenv we end up with 0777 directories that look an alarming
    green colour with ls. Use umask to avoid that.

    The path components are joined with os.path.join(). The process umask
    is set to 0o077 just for the os.mkdir() call and is always put back
    afterwards, even if os.mkdir() raises (e.g. the directory exists).
    """
    d = os.path.join(*path)
    mask = os.umask(0o077)
    try:
        os.mkdir(d)
    finally:
        # the umask is process-wide, so never leave it changed
        os.umask(mask)
    return d
|