diff --git a/python/samba/samba3/__init__.py b/python/samba/samba3/__init__.py index acccff4e296..0165909f45d 100644 --- a/python/samba/samba3/__init__.py +++ b/python/samba/samba3/__init__.py @@ -25,37 +25,41 @@ REGISTRY_DB_VERSION = 1 import os import struct import tdb +import ntdb import passdb import param as s3param -def fetch_uint32(tdb, key): +def fetch_uint32(db, key): try: - data = tdb[key] + data = db[key] except KeyError: return None assert len(data) == 4 return struct.unpack("<L", data)[0] -def fetch_int32(tdb, key): +def fetch_int32(db, key): try: - data = tdb[key] + data = db[key] except KeyError: return None assert len(data) == 4 return struct.unpack("<l", data)[0] -class TdbDatabase(object): - """Simple Samba 3 TDB database reader.""" +class DbDatabase(object): + """Simple Samba 3 TDB/NTDB database reader.""" def __init__(self, file): """Open a file. - :param file: Path of the file to open. + :param file: Path of the file to open, appending .tdb or .ntdb. """ - self.tdb = tdb.Tdb(file, flags=os.O_RDONLY) + if os.path.exists(file + ".ntdb"): + self.db = ntdb.Ntdb(file + ".ntdb", flags=os.O_RDONLY) + else: + self.db = tdb.Tdb(file + ".tdb", flags=os.O_RDONLY) self._check_version() def _check_version(self): @@ -63,10 +67,10 @@ class TdbDatabase(object): def close(self): """Close resources associated with this object.""" - self.tdb.close() + self.db.close() -class Registry(TdbDatabase): +class Registry(DbDatabase): """Simple read-only support for reading the Samba3 registry. :note: This object uses the same syntax for registry key paths as @@ -80,7 +84,7 @@ class Registry(TdbDatabase): def keys(self): """Return list with all the keys.""" - return [k.rstrip("\x00") for k in self.tdb.iterkeys() if not k.startswith(REGISTRY_VALUE_PREFIX)] + return [k.rstrip("\x00") for k in self.db.iterkeys() if not k.startswith(REGISTRY_VALUE_PREFIX)] def subkeys(self, key): """Retrieve the subkeys for the specified key. @@ -88,7 +92,7 @@ class Registry(TdbDatabase): :param key: Key path. :return: list with key names """ - data = self.tdb.get("%s\x00" % key) + data = self.db.get("%s\x00" % key) if data is None: return [] (num, ) = struct.unpack("<L", data[0:4]) @@ -104,7 +108,7 @@ class Registry(TdbDatabase): :param key: Key to retrieve values for. :return: Dictionary with value names as key, tuple with type and data as value.""" - data = self.tdb.get("%s/%s\x00" % (REGISTRY_VALUE_PREFIX, key)) + data = self.db.get("%s/%s\x00" % (REGISTRY_VALUE_PREFIX, key)) if data is None: return {} ret = {} @@ -135,15 +139,15 @@ IDMAP_USER_PREFIX = "UID " # idmap version determines auto-conversion IDMAP_VERSION_V2 = 2 -class IdmapDatabase(TdbDatabase): +class IdmapDatabase(DbDatabase): """Samba 3 ID map database reader.""" def _check_version(self): - assert fetch_int32(self.tdb, "IDMAP_VERSION\0") == IDMAP_VERSION_V2 + assert fetch_int32(self.db, "IDMAP_VERSION\0") == IDMAP_VERSION_V2 def ids(self): """Retrieve a list of all ids in this database.""" - for k in self.tdb.iterkeys(): + for k in self.db.iterkeys(): if k.startswith(IDMAP_USER_PREFIX): yield k.rstrip("\0").split(" ") if k.startswith(IDMAP_GROUP_PREFIX): @@ -151,13 +155,13 @@ class IdmapDatabase(TdbDatabase): def uids(self): """Retrieve a list of all uids in this database.""" - for k in self.tdb.iterkeys(): + for k in self.db.iterkeys(): if k.startswith(IDMAP_USER_PREFIX): yield int(k[len(IDMAP_USER_PREFIX):].rstrip("\0")) def gids(self): """Retrieve a list of all gids in this database.""" - for k in self.tdb.iterkeys(): + for k in self.db.iterkeys(): if k.startswith(IDMAP_GROUP_PREFIX): yield int(k[len(IDMAP_GROUP_PREFIX):].rstrip("\0")) @@ -167,7 +171,7 @@ class IdmapDatabase(TdbDatabase): :param xid: UID or GID to retrive SID for. :param id_type: Type of id specified - 'UID' or 'GID' """ - data = self.tdb.get("%s %s\0" % (id_type, str(xid))) + data = self.db.get("%s %s\0" % (id_type, str(xid))) if data is None: return data return data.rstrip("\0") @@ -178,43 +182,43 @@ class IdmapDatabase(TdbDatabase): :param uid: UID to retrieve SID for. :return: A SID or None if no mapping was found. """ - data = self.tdb.get("%s%d\0" % (IDMAP_USER_PREFIX, uid)) + data = self.db.get("%s%d\0" % (IDMAP_USER_PREFIX, uid)) if data is None: return data return data.rstrip("\0") def get_group_sid(self, gid): - data = self.tdb.get("%s%d\0" % (IDMAP_GROUP_PREFIX, gid)) + data = self.db.get("%s%d\0" % (IDMAP_GROUP_PREFIX, gid)) if data is None: return data return data.rstrip("\0") def get_user_hwm(self): """Obtain the user high-water mark.""" - return fetch_uint32(self.tdb, IDMAP_HWM_USER) + return fetch_uint32(self.db, IDMAP_HWM_USER) def get_group_hwm(self): """Obtain the group high-water mark.""" - return fetch_uint32(self.tdb, IDMAP_HWM_GROUP) + return fetch_uint32(self.db, IDMAP_HWM_GROUP) -class SecretsDatabase(TdbDatabase): +class SecretsDatabase(DbDatabase): """Samba 3 Secrets database reader.""" def get_auth_password(self): - return self.tdb.get("SECRETS/AUTH_PASSWORD") + return self.db.get("SECRETS/AUTH_PASSWORD") def get_auth_domain(self): - return self.tdb.get("SECRETS/AUTH_DOMAIN") + return self.db.get("SECRETS/AUTH_DOMAIN") def get_auth_user(self): - return self.tdb.get("SECRETS/AUTH_USER") + return self.db.get("SECRETS/AUTH_USER") def get_domain_guid(self, host): - return self.tdb.get("SECRETS/DOMGUID/%s" % host) + return self.db.get("SECRETS/DOMGUID/%s" % host) def ldap_dns(self): - for k in self.tdb.iterkeys(): + for k in self.db.iterkeys(): if k.startswith("SECRETS/LDAP_BIND_PW/"): yield k[len("SECRETS/LDAP_BIND_PW/"):].rstrip("\0") @@ -223,59 +227,59 @@ class SecretsDatabase(TdbDatabase): :return: Iterator over the names of domains in this database. """ - for k in self.tdb.iterkeys(): + for k in self.db.iterkeys(): if k.startswith("SECRETS/SID/"): yield k[len("SECRETS/SID/"):].rstrip("\0") def get_ldap_bind_pw(self, host): - return self.tdb.get("SECRETS/LDAP_BIND_PW/%s" % host) + return self.db.get("SECRETS/LDAP_BIND_PW/%s" % host) def get_afs_keyfile(self, host): - return self.tdb.get("SECRETS/AFS_KEYFILE/%s" % host) + return self.db.get("SECRETS/AFS_KEYFILE/%s" % host) def get_machine_sec_channel_type(self, host): - return fetch_uint32(self.tdb, "SECRETS/MACHINE_SEC_CHANNEL_TYPE/%s" % host) + return fetch_uint32(self.db, "SECRETS/MACHINE_SEC_CHANNEL_TYPE/%s" % host) def get_machine_last_change_time(self, host): - return fetch_uint32(self.tdb, "SECRETS/MACHINE_LAST_CHANGE_TIME/%s" % host) + return fetch_uint32(self.db, "SECRETS/MACHINE_LAST_CHANGE_TIME/%s" % host) def get_machine_password(self, host): - return self.tdb.get("SECRETS/MACHINE_PASSWORD/%s" % host) + return self.db.get("SECRETS/MACHINE_PASSWORD/%s" % host) def get_machine_acc(self, host): - return self.tdb.get("SECRETS/$MACHINE.ACC/%s" % host) + return self.db.get("SECRETS/$MACHINE.ACC/%s" % host) def get_domtrust_acc(self, host): - return self.tdb.get("SECRETS/$DOMTRUST.ACC/%s" % host) + return self.db.get("SECRETS/$DOMTRUST.ACC/%s" % host) def trusted_domains(self): - for k in self.tdb.iterkeys(): + for k in self.db.iterkeys(): if k.startswith("SECRETS/$DOMTRUST.ACC/"): yield k[len("SECRETS/$DOMTRUST.ACC/"):].rstrip("\0") def get_random_seed(self): - return self.tdb.get("INFO/random_seed") + return self.db.get("INFO/random_seed") def get_sid(self, host): - return self.tdb.get("SECRETS/SID/%s" % host.upper()) + return self.db.get("SECRETS/SID/%s" % host.upper()) SHARE_DATABASE_VERSION_V1 = 1 SHARE_DATABASE_VERSION_V2 = 2 -class ShareInfoDatabase(TdbDatabase): +class ShareInfoDatabase(DbDatabase): """Samba 3 Share Info database reader.""" def _check_version(self): - assert fetch_int32(self.tdb, "INFO/version\0") in (SHARE_DATABASE_VERSION_V1, SHARE_DATABASE_VERSION_V2) + assert fetch_int32(self.db, "INFO/version\0") in (SHARE_DATABASE_VERSION_V1, SHARE_DATABASE_VERSION_V2) def get_secdesc(self, name): """Obtain the security descriptor on a particular share. :param name: Name of the share """ - secdesc = self.tdb.get("SECDESC/%s" % name) + secdesc = self.db.get("SECDESC/%s" % name) # FIXME: Run ndr_pull_security_descriptor return secdesc @@ -390,16 +394,16 @@ class Samba3(object): return passdb.PDB(self.lp.get('passdb backend')) def get_registry(self): - return Registry(self.statedir_path("registry.tdb")) + return Registry(self.statedir_path("registry")) def get_secrets_db(self): - return SecretsDatabase(self.privatedir_path("secrets.tdb")) + return SecretsDatabase(self.privatedir_path("secrets")) def get_shareinfo_db(self): - return ShareInfoDatabase(self.statedir_path("share_info.tdb")) + return ShareInfoDatabase(self.statedir_path("share_info")) def get_idmap_db(self): - return IdmapDatabase(self.statedir_path("winbindd_idmap.tdb")) + return IdmapDatabase(self.statedir_path("winbindd_idmap")) def get_wins_db(self): return WinsDatabase(self.statedir_path("wins.dat")) diff --git a/python/samba/tests/samba3.py b/python/samba/tests/samba3.py index 0a7f13c66fa..51d76dd94c0 100644 --- a/python/samba/tests/samba3.py +++ b/python/samba/tests/samba3.py @@ -39,7 +39,7 @@ class RegistryTestCase(TestCase): def setUp(self): super(RegistryTestCase, self).setUp() - self.registry = Registry(os.path.join(DATADIR, "registry.tdb")) + self.registry = Registry(os.path.join(DATADIR, "registry")) def tearDown(self): self.registry.close() @@ -194,7 +194,7 @@ class IdmapDbTestCase(TestCase): def setUp(self): super(IdmapDbTestCase, self).setUp() self.idmapdb = IdmapDatabase(os.path.join(DATADIR, - "winbindd_idmap.tdb")) + "winbindd_idmap")) def test_user_hwm(self): self.assertEquals(10000, self.idmapdb.get_user_hwm())