1
0
mirror of https://github.com/samba-team/samba.git synced 2024-12-22 13:34:15 +03:00

source4/scripting/python/samba/samba3: handle ntdb files.

Upgrading old Samba 3 instances seems like a place where we don't have
to read ntdb files, but Andrew Bartlett points out that you can run a
Samba 4.0 and even a 4.1 'classic' domain and desire to migrate that
to the AD DC.

So make this upgrade code generic: if it finds an ntdb file, read
that, otherwise read the tdb file.

Cc: Jelmer Vernooij <jelmer@samba.org>
Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
Reviewed-by: Jeremy Allison <jra@samba.org>
This commit is contained in:
Rusty Russell 2013-04-11 17:12:09 +09:30 committed by Jeremy Allison
parent 5b15d10795
commit 1cf46d2e35
2 changed files with 53 additions and 49 deletions

View File

@ -25,37 +25,41 @@ REGISTRY_DB_VERSION = 1
import os import os
import struct import struct
import tdb import tdb
import ntdb
import passdb import passdb
import param as s3param import param as s3param
def fetch_uint32(tdb, key): def fetch_uint32(db, key):
try: try:
data = tdb[key] data = db[key]
except KeyError: except KeyError:
return None return None
assert len(data) == 4 assert len(data) == 4
return struct.unpack("<L", data)[0] return struct.unpack("<L", data)[0]
def fetch_int32(tdb, key): def fetch_int32(db, key):
try: try:
data = tdb[key] data = db[key]
except KeyError: except KeyError:
return None return None
assert len(data) == 4 assert len(data) == 4
return struct.unpack("<l", data)[0] return struct.unpack("<l", data)[0]
class TdbDatabase(object): class DbDatabase(object):
"""Simple Samba 3 TDB database reader.""" """Simple Samba 3 TDB/NTDB database reader."""
def __init__(self, file): def __init__(self, file):
"""Open a file. """Open a file.
:param file: Path of the file to open. :param file: Path of the file to open, appending .tdb or .ntdb.
""" """
self.tdb = tdb.Tdb(file, flags=os.O_RDONLY) if os.path.exists(file + ".ntdb"):
self.db = ntdb.Ntdb(file + ".ntdb", flags=os.O_RDONLY)
else:
self.db = tdb.Tdb(file + ".tdb", flags=os.O_RDONLY)
self._check_version() self._check_version()
def _check_version(self): def _check_version(self):
@ -63,10 +67,10 @@ class TdbDatabase(object):
def close(self): def close(self):
"""Close resources associated with this object.""" """Close resources associated with this object."""
self.tdb.close() self.db.close()
class Registry(TdbDatabase): class Registry(DbDatabase):
"""Simple read-only support for reading the Samba3 registry. """Simple read-only support for reading the Samba3 registry.
:note: This object uses the same syntax for registry key paths as :note: This object uses the same syntax for registry key paths as
@ -80,7 +84,7 @@ class Registry(TdbDatabase):
def keys(self): def keys(self):
"""Return list with all the keys.""" """Return list with all the keys."""
return [k.rstrip("\x00") for k in self.tdb.iterkeys() if not k.startswith(REGISTRY_VALUE_PREFIX)] return [k.rstrip("\x00") for k in self.db.iterkeys() if not k.startswith(REGISTRY_VALUE_PREFIX)]
def subkeys(self, key): def subkeys(self, key):
"""Retrieve the subkeys for the specified key. """Retrieve the subkeys for the specified key.
@ -88,7 +92,7 @@ class Registry(TdbDatabase):
:param key: Key path. :param key: Key path.
:return: list with key names :return: list with key names
""" """
data = self.tdb.get("%s\x00" % key) data = self.db.get("%s\x00" % key)
if data is None: if data is None:
return [] return []
(num, ) = struct.unpack("<L", data[0:4]) (num, ) = struct.unpack("<L", data[0:4])
@ -104,7 +108,7 @@ class Registry(TdbDatabase):
:param key: Key to retrieve values for. :param key: Key to retrieve values for.
:return: Dictionary with value names as key, tuple with type and :return: Dictionary with value names as key, tuple with type and
data as value.""" data as value."""
data = self.tdb.get("%s/%s\x00" % (REGISTRY_VALUE_PREFIX, key)) data = self.db.get("%s/%s\x00" % (REGISTRY_VALUE_PREFIX, key))
if data is None: if data is None:
return {} return {}
ret = {} ret = {}
@ -135,15 +139,15 @@ IDMAP_USER_PREFIX = "UID "
# idmap version determines auto-conversion # idmap version determines auto-conversion
IDMAP_VERSION_V2 = 2 IDMAP_VERSION_V2 = 2
class IdmapDatabase(TdbDatabase): class IdmapDatabase(DbDatabase):
"""Samba 3 ID map database reader.""" """Samba 3 ID map database reader."""
def _check_version(self): def _check_version(self):
assert fetch_int32(self.tdb, "IDMAP_VERSION\0") == IDMAP_VERSION_V2 assert fetch_int32(self.db, "IDMAP_VERSION\0") == IDMAP_VERSION_V2
def ids(self): def ids(self):
"""Retrieve a list of all ids in this database.""" """Retrieve a list of all ids in this database."""
for k in self.tdb.iterkeys(): for k in self.db.iterkeys():
if k.startswith(IDMAP_USER_PREFIX): if k.startswith(IDMAP_USER_PREFIX):
yield k.rstrip("\0").split(" ") yield k.rstrip("\0").split(" ")
if k.startswith(IDMAP_GROUP_PREFIX): if k.startswith(IDMAP_GROUP_PREFIX):
@ -151,13 +155,13 @@ class IdmapDatabase(TdbDatabase):
def uids(self): def uids(self):
"""Retrieve a list of all uids in this database.""" """Retrieve a list of all uids in this database."""
for k in self.tdb.iterkeys(): for k in self.db.iterkeys():
if k.startswith(IDMAP_USER_PREFIX): if k.startswith(IDMAP_USER_PREFIX):
yield int(k[len(IDMAP_USER_PREFIX):].rstrip("\0")) yield int(k[len(IDMAP_USER_PREFIX):].rstrip("\0"))
def gids(self): def gids(self):
"""Retrieve a list of all gids in this database.""" """Retrieve a list of all gids in this database."""
for k in self.tdb.iterkeys(): for k in self.db.iterkeys():
if k.startswith(IDMAP_GROUP_PREFIX): if k.startswith(IDMAP_GROUP_PREFIX):
yield int(k[len(IDMAP_GROUP_PREFIX):].rstrip("\0")) yield int(k[len(IDMAP_GROUP_PREFIX):].rstrip("\0"))
@ -167,7 +171,7 @@ class IdmapDatabase(TdbDatabase):
:param xid: UID or GID to retrive SID for. :param xid: UID or GID to retrive SID for.
:param id_type: Type of id specified - 'UID' or 'GID' :param id_type: Type of id specified - 'UID' or 'GID'
""" """
data = self.tdb.get("%s %s\0" % (id_type, str(xid))) data = self.db.get("%s %s\0" % (id_type, str(xid)))
if data is None: if data is None:
return data return data
return data.rstrip("\0") return data.rstrip("\0")
@ -178,43 +182,43 @@ class IdmapDatabase(TdbDatabase):
:param uid: UID to retrieve SID for. :param uid: UID to retrieve SID for.
:return: A SID or None if no mapping was found. :return: A SID or None if no mapping was found.
""" """
data = self.tdb.get("%s%d\0" % (IDMAP_USER_PREFIX, uid)) data = self.db.get("%s%d\0" % (IDMAP_USER_PREFIX, uid))
if data is None: if data is None:
return data return data
return data.rstrip("\0") return data.rstrip("\0")
def get_group_sid(self, gid): def get_group_sid(self, gid):
data = self.tdb.get("%s%d\0" % (IDMAP_GROUP_PREFIX, gid)) data = self.db.get("%s%d\0" % (IDMAP_GROUP_PREFIX, gid))
if data is None: if data is None:
return data return data
return data.rstrip("\0") return data.rstrip("\0")
def get_user_hwm(self): def get_user_hwm(self):
"""Obtain the user high-water mark.""" """Obtain the user high-water mark."""
return fetch_uint32(self.tdb, IDMAP_HWM_USER) return fetch_uint32(self.db, IDMAP_HWM_USER)
def get_group_hwm(self): def get_group_hwm(self):
"""Obtain the group high-water mark.""" """Obtain the group high-water mark."""
return fetch_uint32(self.tdb, IDMAP_HWM_GROUP) return fetch_uint32(self.db, IDMAP_HWM_GROUP)
class SecretsDatabase(TdbDatabase): class SecretsDatabase(DbDatabase):
"""Samba 3 Secrets database reader.""" """Samba 3 Secrets database reader."""
def get_auth_password(self): def get_auth_password(self):
return self.tdb.get("SECRETS/AUTH_PASSWORD") return self.db.get("SECRETS/AUTH_PASSWORD")
def get_auth_domain(self): def get_auth_domain(self):
return self.tdb.get("SECRETS/AUTH_DOMAIN") return self.db.get("SECRETS/AUTH_DOMAIN")
def get_auth_user(self): def get_auth_user(self):
return self.tdb.get("SECRETS/AUTH_USER") return self.db.get("SECRETS/AUTH_USER")
def get_domain_guid(self, host): def get_domain_guid(self, host):
return self.tdb.get("SECRETS/DOMGUID/%s" % host) return self.db.get("SECRETS/DOMGUID/%s" % host)
def ldap_dns(self): def ldap_dns(self):
for k in self.tdb.iterkeys(): for k in self.db.iterkeys():
if k.startswith("SECRETS/LDAP_BIND_PW/"): if k.startswith("SECRETS/LDAP_BIND_PW/"):
yield k[len("SECRETS/LDAP_BIND_PW/"):].rstrip("\0") yield k[len("SECRETS/LDAP_BIND_PW/"):].rstrip("\0")
@ -223,59 +227,59 @@ class SecretsDatabase(TdbDatabase):
:return: Iterator over the names of domains in this database. :return: Iterator over the names of domains in this database.
""" """
for k in self.tdb.iterkeys(): for k in self.db.iterkeys():
if k.startswith("SECRETS/SID/"): if k.startswith("SECRETS/SID/"):
yield k[len("SECRETS/SID/"):].rstrip("\0") yield k[len("SECRETS/SID/"):].rstrip("\0")
def get_ldap_bind_pw(self, host): def get_ldap_bind_pw(self, host):
return self.tdb.get("SECRETS/LDAP_BIND_PW/%s" % host) return self.db.get("SECRETS/LDAP_BIND_PW/%s" % host)
def get_afs_keyfile(self, host): def get_afs_keyfile(self, host):
return self.tdb.get("SECRETS/AFS_KEYFILE/%s" % host) return self.db.get("SECRETS/AFS_KEYFILE/%s" % host)
def get_machine_sec_channel_type(self, host): def get_machine_sec_channel_type(self, host):
return fetch_uint32(self.tdb, "SECRETS/MACHINE_SEC_CHANNEL_TYPE/%s" % host) return fetch_uint32(self.db, "SECRETS/MACHINE_SEC_CHANNEL_TYPE/%s" % host)
def get_machine_last_change_time(self, host): def get_machine_last_change_time(self, host):
return fetch_uint32(self.tdb, "SECRETS/MACHINE_LAST_CHANGE_TIME/%s" % host) return fetch_uint32(self.db, "SECRETS/MACHINE_LAST_CHANGE_TIME/%s" % host)
def get_machine_password(self, host): def get_machine_password(self, host):
return self.tdb.get("SECRETS/MACHINE_PASSWORD/%s" % host) return self.db.get("SECRETS/MACHINE_PASSWORD/%s" % host)
def get_machine_acc(self, host): def get_machine_acc(self, host):
return self.tdb.get("SECRETS/$MACHINE.ACC/%s" % host) return self.db.get("SECRETS/$MACHINE.ACC/%s" % host)
def get_domtrust_acc(self, host): def get_domtrust_acc(self, host):
return self.tdb.get("SECRETS/$DOMTRUST.ACC/%s" % host) return self.db.get("SECRETS/$DOMTRUST.ACC/%s" % host)
def trusted_domains(self): def trusted_domains(self):
for k in self.tdb.iterkeys(): for k in self.db.iterkeys():
if k.startswith("SECRETS/$DOMTRUST.ACC/"): if k.startswith("SECRETS/$DOMTRUST.ACC/"):
yield k[len("SECRETS/$DOMTRUST.ACC/"):].rstrip("\0") yield k[len("SECRETS/$DOMTRUST.ACC/"):].rstrip("\0")
def get_random_seed(self): def get_random_seed(self):
return self.tdb.get("INFO/random_seed") return self.db.get("INFO/random_seed")
def get_sid(self, host): def get_sid(self, host):
return self.tdb.get("SECRETS/SID/%s" % host.upper()) return self.db.get("SECRETS/SID/%s" % host.upper())
SHARE_DATABASE_VERSION_V1 = 1 SHARE_DATABASE_VERSION_V1 = 1
SHARE_DATABASE_VERSION_V2 = 2 SHARE_DATABASE_VERSION_V2 = 2
class ShareInfoDatabase(TdbDatabase): class ShareInfoDatabase(DbDatabase):
"""Samba 3 Share Info database reader.""" """Samba 3 Share Info database reader."""
def _check_version(self): def _check_version(self):
assert fetch_int32(self.tdb, "INFO/version\0") in (SHARE_DATABASE_VERSION_V1, SHARE_DATABASE_VERSION_V2) assert fetch_int32(self.db, "INFO/version\0") in (SHARE_DATABASE_VERSION_V1, SHARE_DATABASE_VERSION_V2)
def get_secdesc(self, name): def get_secdesc(self, name):
"""Obtain the security descriptor on a particular share. """Obtain the security descriptor on a particular share.
:param name: Name of the share :param name: Name of the share
""" """
secdesc = self.tdb.get("SECDESC/%s" % name) secdesc = self.db.get("SECDESC/%s" % name)
# FIXME: Run ndr_pull_security_descriptor # FIXME: Run ndr_pull_security_descriptor
return secdesc return secdesc
@ -390,16 +394,16 @@ class Samba3(object):
return passdb.PDB(self.lp.get('passdb backend')) return passdb.PDB(self.lp.get('passdb backend'))
def get_registry(self): def get_registry(self):
return Registry(self.statedir_path("registry.tdb")) return Registry(self.statedir_path("registry"))
def get_secrets_db(self): def get_secrets_db(self):
return SecretsDatabase(self.privatedir_path("secrets.tdb")) return SecretsDatabase(self.privatedir_path("secrets"))
def get_shareinfo_db(self): def get_shareinfo_db(self):
return ShareInfoDatabase(self.statedir_path("share_info.tdb")) return ShareInfoDatabase(self.statedir_path("share_info"))
def get_idmap_db(self): def get_idmap_db(self):
return IdmapDatabase(self.statedir_path("winbindd_idmap.tdb")) return IdmapDatabase(self.statedir_path("winbindd_idmap"))
def get_wins_db(self): def get_wins_db(self):
return WinsDatabase(self.statedir_path("wins.dat")) return WinsDatabase(self.statedir_path("wins.dat"))

View File

@ -39,7 +39,7 @@ class RegistryTestCase(TestCase):
def setUp(self): def setUp(self):
super(RegistryTestCase, self).setUp() super(RegistryTestCase, self).setUp()
self.registry = Registry(os.path.join(DATADIR, "registry.tdb")) self.registry = Registry(os.path.join(DATADIR, "registry"))
def tearDown(self): def tearDown(self):
self.registry.close() self.registry.close()
@ -194,7 +194,7 @@ class IdmapDbTestCase(TestCase):
def setUp(self): def setUp(self):
super(IdmapDbTestCase, self).setUp() super(IdmapDbTestCase, self).setUp()
self.idmapdb = IdmapDatabase(os.path.join(DATADIR, self.idmapdb = IdmapDatabase(os.path.join(DATADIR,
"winbindd_idmap.tdb")) "winbindd_idmap"))
def test_user_hwm(self): def test_user_hwm(self):
self.assertEquals(10000, self.idmapdb.get_user_hwm()) self.assertEquals(10000, self.idmapdb.get_user_hwm())