From ff0cf5eda14e5a480fdcaf166e7839851e9da9e0 Mon Sep 17 00:00:00 2001 From: David Mulder Date: Thu, 22 Aug 2024 08:16:44 -0600 Subject: [PATCH] Add tests for rust himmelblaud build Signed-off-by: David Mulder Reviewed-by: Alexander Bokovoy --- rust/Cargo.lock | 1 + rust/himmelblaud/Cargo.toml | 3 + rust/himmelblaud/build.rs | 3 + rust/himmelblaud/src/cache.rs | 148 +++++++++++++++++- rust/himmelblaud/src/himmelblaud.rs | 51 ++++++ .../src/himmelblaud/himmelblaud_getgrent.rs | 107 ++++++++++++- .../src/himmelblaud/himmelblaud_getgrgid.rs | 104 ++++++++++++ .../src/himmelblaud/himmelblaud_getgrnam.rs | 111 ++++++++++++- .../src/himmelblaud/himmelblaud_getpwent.rs | 120 +++++++++++++- .../src/himmelblaud/himmelblaud_getpwnam.rs | 112 ++++++++++++- .../src/himmelblaud/himmelblaud_getpwuid.rs | 111 +++++++++++++ rust/himmelblaud/src/main.rs | 6 + rust/himmelblaud/src/utils.rs | 68 ++++++++ rust/tdb/src/lib.rs | 13 +- 14 files changed, 947 insertions(+), 11 deletions(-) create mode 100644 rust/himmelblaud/build.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 2f757783da6..6f6b3ea31ce 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -819,6 +819,7 @@ dependencies = [ "serde_json", "sock", "tdb", + "tempfile", "tokio", "tokio-util", "version", diff --git a/rust/himmelblaud/Cargo.toml b/rust/himmelblaud/Cargo.toml index f76170d190a..5806e51af55 100644 --- a/rust/himmelblaud/Cargo.toml +++ b/rust/himmelblaud/Cargo.toml @@ -26,3 +26,6 @@ libc = { workspace = true } [build-dependencies] version = { path = "../version" } + +[dev-dependencies] +tempfile = "3.12.0" diff --git a/rust/himmelblaud/build.rs b/rust/himmelblaud/build.rs new file mode 100644 index 00000000000..da232e3c90e --- /dev/null +++ b/rust/himmelblaud/build.rs @@ -0,0 +1,3 @@ +fn main() { + println!("cargo:rustc-env=LD_LIBRARY_PATH=../../bin/shared:../../bin/shared/private/"); +} diff --git a/rust/himmelblaud/src/cache.rs b/rust/himmelblaud/src/cache.rs index 12a2cd97640..5a344e5f0dd 100644 --- a/rust/himmelblaud/src/cache.rs +++ b/rust/himmelblaud/src/cache.rs @@ -167,7 +167,7 @@ impl BasicCache { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub(crate) struct UserEntry { pub(crate) upn: String, pub(crate) uuid: String, @@ -294,6 +294,16 @@ impl GroupEntry { } } +#[cfg(test)] +impl GroupEntry { + pub fn new(uuid: &str) -> Self { + GroupEntry { + uuid: uuid.to_string(), + members: HashSet::new(), + } + } +} + pub(crate) struct GroupCache { cache: BasicCache, } @@ -330,7 +340,7 @@ impl GroupCache { pub(crate) fn merge_groups( &mut self, member: &str, - entries: Vec, + mut entries: Vec, ) -> Result<(), Box> { // We need to ensure the member is removed from non-intersecting // groups, otherwise we don't honor group membership removals. @@ -369,6 +379,11 @@ impl GroupCache { } } + // Ensure the member is added to the listed groups + for group in &mut entries { + group.add_member(member); + } + // Now add the new entries, merging with existing memberships for group in entries { match self.cache.fetch::(&group.uuid) { @@ -513,3 +528,132 @@ impl PrivateCache { self.cache.store_bytes(&device_id_tag, device_id.as_bytes()) } } + +#[cfg(test)] +mod tests { + use super::*; + use kanidm_hsm_crypto::soft::SoftTpm; + use std::str::FromStr; + use tempfile::tempdir; + + #[test] + fn test_basic_cache_new() { + let dir = tempdir().unwrap(); + let cache_path = dir.path().join("test.tdb"); + let cache = BasicCache::new(cache_path.to_str().unwrap()); + assert!(cache.is_ok()); + } + + #[test] + fn test_basic_cache_store_fetch_str() { + let dir = tempdir().unwrap(); + let cache_path = dir.path().join("test.tdb"); + let mut cache = BasicCache::new(cache_path.to_str().unwrap()).unwrap(); + + let key = "test_key"; + let value = "test_value"; + cache.store_bytes(key, value.as_bytes()).unwrap(); + let fetched_value = cache.fetch_str(key).unwrap(); + assert_eq!(fetched_value, value); + } + + #[test] + fn test_basic_cache_store_fetch() { + let dir = tempdir().unwrap(); + let cache_path = dir.path().join("test.tdb"); + let mut cache = BasicCache::new(cache_path.to_str().unwrap()).unwrap(); + + let key = "test_key"; + let value = UserEntry { + upn: "user@test.com".to_string(), + uuid: "f63a43c7-b783-4da9-acb4-89f8ebfc49e9".to_string(), + name: "Test User".to_string(), + }; + + cache.store(key, &value).unwrap(); + let fetched_value: Option = cache.fetch(key); + assert!(fetched_value.is_some()); + let fetched_value = fetched_value.unwrap(); + assert_eq!(fetched_value.upn, value.upn); + assert_eq!(fetched_value.uuid, value.uuid); + assert_eq!(fetched_value.name, value.name); + } + + #[test] + fn test_user_cache_store_fetch() { + let dir = tempdir().unwrap(); + let cache_path = dir.path().join("test.tdb"); + let mut cache = UserCache::new(cache_path.to_str().unwrap()).unwrap(); + + let entry = UserEntry { + upn: "user@test.com".to_string(), + uuid: "f63a43c7-b783-4da9-acb4-89f8ebfc49e9".to_string(), + name: "Test User".to_string(), + }; + + cache.store(entry.clone()).unwrap(); + let fetched_entry = cache.fetch(&entry.upn); + assert!(fetched_entry.is_some()); + let fetched_entry = fetched_entry.unwrap(); + assert_eq!(fetched_entry.upn, entry.upn); + assert_eq!(fetched_entry.uuid, entry.uuid); + assert_eq!(fetched_entry.name, entry.name); + } + + #[test] + fn test_uid_cache_store_fetch() { + let dir = tempdir().unwrap(); + let cache_path = dir.path().join("test.tdb"); + let mut cache = UidCache::new(cache_path.to_str().unwrap()).unwrap(); + + let uid: uid_t = 1000; + let upn = "user@test.com"; + + cache.store(uid, upn).unwrap(); + let fetched_upn = cache.fetch(uid); + assert!(fetched_upn.is_some()); + assert_eq!(fetched_upn.unwrap(), upn); + } + + #[test] + fn test_group_cache_store_fetch() { + let dir = tempdir().unwrap(); + let cache_path = dir.path().join("test.tdb"); + let mut cache = GroupCache::new(cache_path.to_str().unwrap()).unwrap(); + + let mut group = GroupEntry { + uuid: "5f8be63a-a379-4324-9f42-9ea40bed9d7f".to_string(), + members: HashSet::new(), + }; + group.add_member("user@test.com"); + + cache.cache.store(&group.uuid, &group).unwrap(); + let fetched_group = cache.fetch(&group.uuid); + assert!(fetched_group.is_some()); + let fetched_group = fetched_group.unwrap(); + assert_eq!(fetched_group.uuid, group.uuid); + assert!(fetched_group.members.contains("user@test.com")); + } + + #[test] + fn test_private_cache_loadable_machine_key_fetch_or_create() { + let dir = tempdir().unwrap(); + let cache_path = dir.path().join("test.tdb"); + let mut cache = + PrivateCache::new(cache_path.to_str().unwrap()).unwrap(); + + let mut hsm = BoxedDynTpm::new(SoftTpm::new()); + let auth_str = AuthValue::generate().expect("Failed to create hex pin"); + let auth_value = AuthValue::from_str(&auth_str) + .expect("Unable to create auth value"); + + let result = + cache.loadable_machine_key_fetch_or_create(&mut hsm, &auth_value); + assert!(result.is_ok()); + + let fetched_key = cache + .cache + .fetch::("loadable_machine_key"); + assert!(fetched_key.is_some()); + } +} diff --git a/rust/himmelblaud/src/himmelblaud.rs b/rust/himmelblaud/src/himmelblaud.rs index 3dca776b998..36541d80c1b 100644 --- a/rust/himmelblaud/src/himmelblaud.rs +++ b/rust/himmelblaud/src/himmelblaud.rs @@ -19,6 +19,7 @@ along with this program. If not, see . */ use crate::cache::{GroupCache, PrivateCache, UidCache, UserCache}; +#[cfg(not(test))] use crate::himmelblaud::himmelblaud_pam_auth::AuthSession; use bytes::{BufMut, BytesMut}; use dbg::{DBG_DEBUG, DBG_ERR, DBG_WARNING}; @@ -37,6 +38,7 @@ use tokio::net::UnixStream; use tokio::sync::Mutex; use tokio_util::codec::{Decoder, Encoder, Framed}; +#[cfg(not(test))] pub(crate) struct Resolver { realm: String, tenant_id: String, @@ -52,6 +54,7 @@ pub(crate) struct Resolver { client: Arc>, } +#[cfg(not(test))] impl Resolver { pub(crate) fn new( realm: &str, @@ -84,6 +87,46 @@ impl Resolver { } } +// The test environment is unable to communicate with Entra ID, therefore +// we alter the resolver to only test the cache interactions. + +#[cfg(test)] +pub(crate) struct Resolver { + realm: String, + tenant_id: String, + lp: LoadParm, + idmap: Idmap, + pcache: PrivateCache, + user_cache: UserCache, + uid_cache: UidCache, + group_cache: GroupCache, +} + +#[cfg(test)] +impl Resolver { + pub(crate) fn new( + realm: &str, + tenant_id: &str, + lp: LoadParm, + idmap: Idmap, + pcache: PrivateCache, + user_cache: UserCache, + uid_cache: UidCache, + group_cache: GroupCache, + ) -> Self { + Resolver { + realm: realm.to_string(), + tenant_id: tenant_id.to_string(), + lp, + idmap, + pcache, + user_cache, + uid_cache, + group_cache, + } + } +} + struct ClientCodec; impl Decoder for ClientCodec { @@ -142,11 +185,13 @@ pub(crate) async fn handle_client( }; let mut reqs = Framed::new(stream, ClientCodec::new()); + #[cfg(not(test))] let mut pam_auth_session_state = None; while let Some(Ok(req)) = reqs.next().await { let mut resolver = resolver.lock().await; let resp = match req { + #[cfg(not(test))] Request::PamAuthenticateInit(account_id) => { DBG_DEBUG!("pam authenticate init"); @@ -168,6 +213,7 @@ pub(crate) async fn handle_client( } } } + #[cfg(not(test))] Request::PamAuthenticateStep(pam_next_req) => { DBG_DEBUG!("pam authenticate step"); match &mut pam_auth_session_state { @@ -220,10 +266,13 @@ pub(crate) async fn handle_client( resolver.getgrnam(&grp_id).await? } Request::NssGroupByGid(gid) => resolver.getgrgid(gid).await?, + #[cfg(not(test))] Request::PamAccountAllowed(account_id) => { resolver.pam_acct_mgmt(&account_id).await? } Request::PamAccountBeginSession(_account_id) => Response::Success, + #[cfg(test)] + _ => Response::Error, }; reqs.send(resp).await?; reqs.flush().await?; @@ -240,5 +289,7 @@ mod himmelblaud_getgrnam; mod himmelblaud_getpwent; mod himmelblaud_getpwnam; mod himmelblaud_getpwuid; +#[cfg(not(test))] mod himmelblaud_pam_acct_mgmt; +#[cfg(not(test))] mod himmelblaud_pam_auth; diff --git a/rust/himmelblaud/src/himmelblaud/himmelblaud_getgrent.rs b/rust/himmelblaud/src/himmelblaud/himmelblaud_getgrent.rs index 9f0c615e15d..67e02eed463 100644 --- a/rust/himmelblaud/src/himmelblaud/himmelblaud_getgrent.rs +++ b/rust/himmelblaud/src/himmelblaud/himmelblaud_getgrent.rs @@ -31,7 +31,7 @@ impl Resolver { let name = entry.uuid.clone(); let gid = self .idmap - .gen_to_unix(&self.tenant_id, &entry.uuid.to_uppercase()) + .gen_to_unix(&self.tenant_id, &entry.uuid) .map_err(|e| { DBG_ERR!("{:?}", e); Box::new(NT_STATUS_NO_SUCH_GROUP) @@ -47,3 +47,108 @@ impl Resolver { Ok(Response::NssGroups(res)) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::cache::GroupEntry; + use crate::{GroupCache, PrivateCache, UidCache, UserCache}; + use idmap::Idmap; + use param::LoadParm; + use std::collections::HashSet; + use tempfile::tempdir; + + #[tokio::test] + async fn test_getgrent() { + // Create a temporary directory for the cache + let dir = tempdir().unwrap(); + + // Initialize the caches + let private_cache_path = dir + .path() + .join("himmelblau.tdb") + .to_str() + .unwrap() + .to_string(); + let pcache = PrivateCache::new(&private_cache_path).unwrap(); + let user_cache_path = dir + .path() + .join("himmelblau_users.tdb") + .to_str() + .unwrap() + .to_string(); + let user_cache = UserCache::new(&user_cache_path).unwrap(); + let uid_cache_path = dir + .path() + .join("uid_cache.tdb") + .to_str() + .unwrap() + .to_string(); + let uid_cache = UidCache::new(&uid_cache_path).unwrap(); + let group_cache_path = dir + .path() + .join("himmelblau_groups.tdb") + .to_str() + .unwrap() + .to_string(); + let mut group_cache = GroupCache::new(&group_cache_path).unwrap(); + + // Insert dummy GroupEntries into the cache + let group_uuid1 = "c490c3ea-fd98-4d45-b6aa-2a3520f804fa"; + let group_uuid2 = "f7a51b58-84de-42a3-b5b1-967b17c04f89"; + let dummy_group1 = GroupEntry::new(group_uuid1); + let dummy_group2 = GroupEntry::new(group_uuid2); + group_cache + .merge_groups("user1@test.com", vec![dummy_group1.clone()]) + .unwrap(); + group_cache + .merge_groups("user2@test.com", vec![dummy_group2.clone()]) + .unwrap(); + + // Initialize the Idmap with dummy configuration + let realm = "test.com"; + let tenant_id = "89a61bb7-d1b9-4356-a1e0-75d88e06f14e"; + let mut idmap = Idmap::new().unwrap(); + idmap + .add_gen_domain(realm, tenant_id, (1000, 2000)) + .unwrap(); + + // Initialize dummy configuration + let lp = LoadParm::new(None).expect("Failed loading default config"); + + // Initialize the Resolver + let mut resolver = Resolver { + realm: realm.to_string(), + tenant_id: tenant_id.to_string(), + lp, + idmap, + pcache, + user_cache, + uid_cache, + group_cache, + }; + + // Test the getgrent function + let result = resolver.getgrent().await.unwrap(); + + match result { + Response::NssGroups(mut groups) => { + groups.sort_by(|a, b| a.name.cmp(&b.name)); + assert_eq!(groups.len(), 2); + + let group1 = &groups[0]; + assert_eq!(group1.name, dummy_group1.uuid); + assert_eq!(group1.gid, 1388); + assert_eq!(group1.members, vec!["user1@test.com".to_string()]); + + let group2 = &groups[1]; + assert_eq!(group2.name, dummy_group2.uuid); + assert_eq!(group2.gid, 1593); + assert_eq!(group2.members, vec!["user2@test.com".to_string()]); + } + other => { + panic!("Expected NssGroups with a list of groups: {:?}", other) + } + } + } +} diff --git a/rust/himmelblaud/src/himmelblaud/himmelblaud_getgrgid.rs b/rust/himmelblaud/src/himmelblaud/himmelblaud_getgrgid.rs index 9ddf4082931..f921446a653 100644 --- a/rust/himmelblaud/src/himmelblaud/himmelblaud_getgrgid.rs +++ b/rust/himmelblaud/src/himmelblaud/himmelblaud_getgrgid.rs @@ -41,3 +41,107 @@ impl Resolver { Ok(Response::NssGroup(None)) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::cache::GroupEntry; + use crate::{GroupCache, PrivateCache, UidCache, UserCache}; + use idmap::Idmap; + use param::LoadParm; + use std::collections::HashSet; + use tempfile::tempdir; + + #[tokio::test] + async fn test_getgrgid() { + // Create a temporary directory for the cache + let dir = tempdir().unwrap(); + + // Initialize the caches + let private_cache_path = dir + .path() + .join("himmelblau.tdb") + .to_str() + .unwrap() + .to_string(); + let pcache = PrivateCache::new(&private_cache_path).unwrap(); + let user_cache_path = dir + .path() + .join("himmelblau_users.tdb") + .to_str() + .unwrap() + .to_string(); + let user_cache = UserCache::new(&user_cache_path).unwrap(); + let uid_cache_path = dir + .path() + .join("uid_cache.tdb") + .to_str() + .unwrap() + .to_string(); + let mut uid_cache = UidCache::new(&uid_cache_path).unwrap(); + let group_cache_path = dir + .path() + .join("himmelblau_groups.tdb") + .to_str() + .unwrap() + .to_string(); + let mut group_cache = GroupCache::new(&group_cache_path).unwrap(); + + // Initialize the Idmap with dummy configuration + let realm = "test.com"; + let tenant_id = "89a61bb7-d1b9-4356-a1e0-75d88e06f14e"; + let mut idmap = Idmap::new().unwrap(); + idmap + .add_gen_domain(realm, tenant_id, (1000, 2000)) + .unwrap(); + + // Insert a dummy GroupEntry into the cache + let group_uuid = "c490c3ea-fd98-4d45-b6aa-2a3520f804fa".to_string(); + let dummy_gid = idmap + .gen_to_unix(tenant_id, &group_uuid) + .expect("Failed to map group gid"); + // Store the calculated gid -> uuid map in the cache + uid_cache + .store(dummy_gid, &group_uuid) + .expect("Failed to store group gid"); + let dummy_group = GroupEntry::new(&group_uuid); + group_cache + .merge_groups("user1@test.com", vec![dummy_group.clone()]) + .unwrap(); + + // Initialize dummy configuration + let lp = LoadParm::new(None).expect("Failed loading default config"); + + // Initialize the Resolver + let mut resolver = Resolver { + realm: realm.to_string(), + tenant_id: tenant_id.to_string(), + lp, + idmap, + pcache, + user_cache, + uid_cache, + group_cache, + }; + + // Test the getgrgid function with a gid that exists + let result = resolver.getgrgid(dummy_gid).await.unwrap(); + + match result { + Response::NssGroup(Some(group)) => { + assert_eq!(group.name, dummy_group.uuid); + assert_eq!(group.gid, dummy_gid); + assert_eq!(group.members, vec!["user1@test.com".to_string()]); + } + other => panic!("Expected NssGroup with Some(group): {:?}", other), + } + + // Test the getgrgid function with a gid that does not exist + let nonexistent_gid: gid_t = 1600; + let result = resolver.getgrgid(nonexistent_gid).await.unwrap(); + match result { + Response::NssGroup(None) => {} // This is the expected result + _ => panic!("Expected NssGroup with None"), + } + } +} diff --git a/rust/himmelblaud/src/himmelblaud/himmelblaud_getgrnam.rs b/rust/himmelblaud/src/himmelblaud/himmelblaud_getgrnam.rs index cafc6a4f685..16b1de58365 100644 --- a/rust/himmelblaud/src/himmelblaud/himmelblaud_getgrnam.rs +++ b/rust/himmelblaud/src/himmelblaud/himmelblaud_getgrnam.rs @@ -34,13 +34,13 @@ impl Resolver { }; let gid = self .idmap - .gen_to_unix(&self.tenant_id, &entry.uuid.to_uppercase()) + .gen_to_unix(&self.tenant_id, &entry.uuid) .map_err(|e| { DBG_ERR!("{:?}", e); Box::new(NT_STATUS_INVALID_TOKEN) })?; // Store the calculated gid -> uuid map in the cache - self.uid_cache.store(gid, &entry.uuid.to_uppercase())?; + self.uid_cache.store(gid, &entry.uuid)?; let group = Group { name: entry.uuid.clone(), passwd: "x".to_string(), @@ -50,3 +50,110 @@ impl Resolver { return Ok(Response::NssGroup(Some(group))); } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::cache::GroupEntry; + use crate::{GroupCache, PrivateCache, UidCache, UserCache}; + use idmap::Idmap; + use param::LoadParm; + use std::collections::HashSet; + use tempfile::tempdir; + + #[tokio::test] + async fn test_getgrnam() { + // Create a temporary directory for the cache + let dir = tempdir().unwrap(); + + // Initialize the caches + let private_cache_path = dir + .path() + .join("himmelblau.tdb") + .to_str() + .unwrap() + .to_string(); + let pcache = PrivateCache::new(&private_cache_path).unwrap(); + let user_cache_path = dir + .path() + .join("himmelblau_users.tdb") + .to_str() + .unwrap() + .to_string(); + let user_cache = UserCache::new(&user_cache_path).unwrap(); + let uid_cache_path = dir + .path() + .join("uid_cache.tdb") + .to_str() + .unwrap() + .to_string(); + let uid_cache = UidCache::new(&uid_cache_path).unwrap(); + let group_cache_path = dir + .path() + .join("himmelblau_groups.tdb") + .to_str() + .unwrap() + .to_string(); + let mut group_cache = GroupCache::new(&group_cache_path).unwrap(); + + // Insert a dummy GroupEntry into the cache + let group_uuid = "c490c3ea-fd98-4d45-b6aa-2a3520f804fa"; + let dummy_group = GroupEntry::new(group_uuid); + group_cache + .merge_groups("user1@test.com", vec![dummy_group.clone()]) + .unwrap(); + group_cache + .merge_groups("user2@test.com", vec![dummy_group.clone()]) + .unwrap(); + + // Initialize the Idmap with dummy configuration + let realm = "test.com"; + let tenant_id = "89a61bb7-d1b9-4356-a1e0-75d88e06f14e"; + let mut idmap = Idmap::new().unwrap(); + idmap + .add_gen_domain(realm, tenant_id, (1000, 2000)) + .unwrap(); + + // Initialize dummy configuration + let lp = LoadParm::new(None).expect("Failed loading default config"); + + // Initialize the Resolver + let mut resolver = Resolver { + realm: realm.to_string(), + tenant_id: tenant_id.to_string(), + lp, + idmap, + pcache, + user_cache, + uid_cache, + group_cache, + }; + + // Test the getgrnam function with a group that exists + let result = resolver.getgrnam(group_uuid).await.unwrap(); + + match result { + Response::NssGroup(Some(mut group)) => { + group.members.sort(); + assert_eq!(group.name, dummy_group.uuid); + assert_eq!(group.gid, 1388); + assert_eq!( + group.members, + vec![ + "user1@test.com".to_string(), + "user2@test.com".to_string() + ] + ); + } + other => panic!("Expected NssGroup with Some(group): {:?}", other), + } + + // Test the getgrnam function with a group that does not exist + let nonexistent_group_uuid = "2ea8f1d4-1b94-4003-865b-cb247a8a1f5d"; + let result = resolver.getgrnam(nonexistent_group_uuid).await.unwrap(); + match result { + Response::NssGroup(None) => {} // This is the expected result + _ => panic!("Expected NssGroup with None"), + } + } +} diff --git a/rust/himmelblaud/src/himmelblaud/himmelblaud_getpwent.rs b/rust/himmelblaud/src/himmelblaud/himmelblaud_getpwent.rs index b618c150643..e8f7d9406ea 100644 --- a/rust/himmelblaud/src/himmelblaud/himmelblaud_getpwent.rs +++ b/rust/himmelblaud/src/himmelblaud/himmelblaud_getpwent.rs @@ -53,7 +53,7 @@ impl Resolver { for entry in user_entries { let uid = self .idmap - .gen_to_unix(&self.tenant_id, &entry.upn.to_lowercase()) + .gen_to_unix(&self.tenant_id, &entry.upn) .map_err(|e| { DBG_ERR!("{:?}", e); Box::new(NT_STATUS_INVALID_TOKEN) @@ -90,3 +90,121 @@ impl Resolver { Ok(Response::NssAccounts(res)) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::cache::UserEntry; + use crate::{GroupCache, PrivateCache, UidCache, UserCache}; + use idmap::Idmap; + use param::LoadParm; + use std::collections::HashSet; + use tempfile::tempdir; + + #[tokio::test] + async fn test_getpwent() { + // Create a temporary directory for the cache + let dir = tempdir().unwrap(); + + // Initialize the caches + let private_cache_path = dir + .path() + .join("himmelblau.tdb") + .to_str() + .unwrap() + .to_string(); + let pcache = PrivateCache::new(&private_cache_path).unwrap(); + let user_cache_path = dir + .path() + .join("himmelblau_users.tdb") + .to_str() + .unwrap() + .to_string(); + let mut user_cache = UserCache::new(&user_cache_path).unwrap(); + let uid_cache_path = dir + .path() + .join("uid_cache.tdb") + .to_str() + .unwrap() + .to_string(); + let uid_cache = UidCache::new(&uid_cache_path).unwrap(); + let group_cache_path = dir + .path() + .join("himmelblau_groups.tdb") + .to_str() + .unwrap() + .to_string(); + let group_cache = GroupCache::new(&group_cache_path).unwrap(); + + // Insert dummy UserEntrys into the cache + let dummy_user = UserEntry { + upn: "user1@test.com".to_string(), + uuid: "731e9af3-668d-4033-afd1-9f09b9120cc7".to_string(), + name: "User One".to_string(), + }; + user_cache + .store(dummy_user.clone()) + .expect("Failed storing user in cache"); + + let dummy_user2 = UserEntry { + upn: "user2@test.com".to_string(), + uuid: "7be6c0c5-5763-4633-aecf-f8c460b338fd".to_string(), + name: "User Two".to_string(), + }; + user_cache + .store(dummy_user2.clone()) + .expect("Failed storing user in cache"); + + // Initialize the Idmap with dummy configuration + let realm = "test.com"; + let tenant_id = "89a61bb7-d1b9-4356-a1e0-75d88e06f14e"; + let mut idmap = Idmap::new().unwrap(); + idmap + .add_gen_domain(realm, tenant_id, (1000, 2000)) + .unwrap(); + + // Initialize dummy configuration for LoadParm + let lp = LoadParm::new(None).expect("Failed loading default config"); + + // Initialize the Resolver + let mut resolver = Resolver { + realm: realm.to_string(), + tenant_id: tenant_id.to_string(), + lp, + idmap, + pcache, + user_cache, + uid_cache, + group_cache, + }; + + // Test the getpwent function + let result = resolver.getpwent().await.unwrap(); + + match result { + Response::NssAccounts(accounts) => { + assert_eq!(accounts.len(), 2); + + let account1 = &accounts[0]; + assert_eq!(account1.name, dummy_user.upn); + assert_eq!(account1.uid, 1316); + assert_eq!(account1.gid, 1316); + assert_eq!(account1.gecos, dummy_user.name); + assert_eq!(account1.dir, "/home/test.com/user1"); + assert_eq!(account1.shell, "/bin/false"); + + let account2 = &accounts[1]; + assert_eq!(account2.name, dummy_user2.upn); + assert_eq!(account2.uid, 1671); + assert_eq!(account2.gid, 1671); + assert_eq!(account2.gecos, dummy_user2.name); + assert_eq!(account2.dir, "/home/test.com/user2"); + assert_eq!(account2.shell, "/bin/false"); + } + other => panic!( + "Expected NssAccounts with a list of accounts: {:?}", + other + ), + } + } +} diff --git a/rust/himmelblaud/src/himmelblaud/himmelblaud_getpwnam.rs b/rust/himmelblaud/src/himmelblaud/himmelblaud_getpwnam.rs index 2389418b002..576a62e78e6 100644 --- a/rust/himmelblaud/src/himmelblaud/himmelblaud_getpwnam.rs +++ b/rust/himmelblaud/src/himmelblaud/himmelblaud_getpwnam.rs @@ -52,10 +52,8 @@ impl Resolver { DBG_ERR!("Failed to discover template shell. Is it set?"); Box::new(NT_STATUS_NOT_A_DIRECTORY) })?; - let uid = self - .idmap - .gen_to_unix(&self.tenant_id, &upn.to_lowercase()) - .map_err(|e| { + let uid = + self.idmap.gen_to_unix(&self.tenant_id, &upn).map_err(|e| { DBG_ERR!("{:?}", e); Box::new(NT_STATUS_INVALID_TOKEN) })?; @@ -93,6 +91,7 @@ impl Resolver { // based on whether the upn exists in Entra ID. let entry = match self.user_cache.fetch(account_id) { Some(entry) => entry, + #[cfg(not(test))] None => { // Check if the user exists in Entra ID let exists = match self @@ -115,9 +114,114 @@ impl Resolver { } return Ok(Response::NssAccount(None)); } + #[cfg(test)] + None => return Ok(Response::NssAccount(None)), }; return Ok(Response::NssAccount(Some( self.create_passwd_from_upn(&entry.upn, &entry.name)?, ))); } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::cache::UserEntry; + use crate::{GroupCache, PrivateCache, UidCache, UserCache}; + use idmap::Idmap; + use param::LoadParm; + use tempfile::tempdir; + + #[tokio::test] + async fn test_getpwnam() { + // Create a temporary directory for the cache + let dir = tempdir().unwrap(); + + // Initialize the caches + let private_cache_path = dir + .path() + .join("himmelblau.tdb") + .to_str() + .unwrap() + .to_string(); + let pcache = PrivateCache::new(&private_cache_path).unwrap(); + let user_cache_path = dir + .path() + .join("himmelblau_users.tdb") + .to_str() + .unwrap() + .to_string(); + let mut user_cache = UserCache::new(&user_cache_path).unwrap(); + let uid_cache_path = dir + .path() + .join("uid_cache.tdb") + .to_str() + .unwrap() + .to_string(); + let uid_cache = UidCache::new(&uid_cache_path).unwrap(); + let group_cache_path = dir + .path() + .join("himmelblau_groups.tdb") + .to_str() + .unwrap() + .to_string(); + let group_cache = GroupCache::new(&group_cache_path).unwrap(); + + // Insert a dummy UserEntry into the cache + let dummy_user = UserEntry { + upn: "user1@test.com".to_string(), + uuid: "731e9af3-668d-4033-afd1-9f09b9120cc7".to_string(), + name: "User One".to_string(), + }; + let _ = user_cache.store(dummy_user.clone()); + + // Initialize the Idmap with dummy configuration + let realm = "test.com"; + let tenant_id = "89a61bb7-d1b9-4356-a1e0-75d88e06f14e"; + let mut idmap = Idmap::new().unwrap(); + idmap + .add_gen_domain(realm, tenant_id, (1000, 2000)) + .unwrap(); + + // Initialize dummy configuration for LoadParm + let lp = LoadParm::new(None).expect("Failed loading default config"); + + // Initialize the Resolver + let mut resolver = Resolver { + realm: realm.to_string(), + tenant_id: tenant_id.to_string(), + lp, + idmap, + pcache, + user_cache, + uid_cache, + group_cache, + }; + + // Test the getpwnam function with a user that exists in the cache + let result = resolver.getpwnam(&dummy_user.upn).await.unwrap(); + + match result { + Response::NssAccount(Some(account)) => { + assert_eq!(account.name, dummy_user.upn); + assert_eq!(account.uid, 1316); + assert_eq!(account.gid, 1316); + assert_eq!(account.gecos, dummy_user.name); + assert_eq!(account.dir, "/home/test.com/user1"); + assert_eq!(account.shell, "/bin/false"); + } + other => { + panic!("Expected NssAccount with Some(account): {:?}", other) + } + } + + // Test the getpwnam function with a user that does not exist in the cache + let nonexistent_user_upn = "nonexistent@test.com"; + let result = resolver.getpwnam(nonexistent_user_upn).await.unwrap(); + + match result { + Response::NssAccount(None) => {} // This is the expected result + other => panic!("Expected NssAccount with None: {:?}", other), + } + } +} diff --git a/rust/himmelblaud/src/himmelblaud/himmelblaud_getpwuid.rs b/rust/himmelblaud/src/himmelblaud/himmelblaud_getpwuid.rs index b43948def09..f93f55cff1e 100644 --- a/rust/himmelblaud/src/himmelblaud/himmelblaud_getpwuid.rs +++ b/rust/himmelblaud/src/himmelblaud/himmelblaud_getpwuid.rs @@ -43,3 +43,114 @@ impl Resolver { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::cache::UserEntry; + use crate::{GroupCache, PrivateCache, UidCache, UserCache}; + use idmap::Idmap; + use param::LoadParm; + use tempfile::tempdir; + + #[tokio::test] + async fn test_getpwuid() { + // Create a temporary directory for the cache + let dir = tempdir().unwrap(); + + // Initialize the caches + let private_cache_path = dir + .path() + .join("himmelblau.tdb") + .to_str() + .unwrap() + .to_string(); + let pcache = PrivateCache::new(&private_cache_path).unwrap(); + let user_cache_path = dir + .path() + .join("himmelblau_users.tdb") + .to_str() + .unwrap() + .to_string(); + let mut user_cache = UserCache::new(&user_cache_path).unwrap(); + let uid_cache_path = dir + .path() + .join("uid_cache.tdb") + .to_str() + .unwrap() + .to_string(); + let mut uid_cache = UidCache::new(&uid_cache_path).unwrap(); + let group_cache_path = dir + .path() + .join("himmelblau_groups.tdb") + .to_str() + .unwrap() + .to_string(); + let group_cache = GroupCache::new(&group_cache_path).unwrap(); + + // Insert a dummy UserEntry into the cache + let dummy_user = UserEntry { + upn: "user1@test.com".to_string(), + uuid: "731e9af3-668d-4033-afd1-9f09b9120cc7".to_string(), + name: "User One".to_string(), + }; + let _ = user_cache.store(dummy_user.clone()); + + // Initialize the Idmap with dummy configuration + let realm = "test.com"; + let tenant_id = "89a61bb7-d1b9-4356-a1e0-75d88e06f14e"; + let mut idmap = Idmap::new().unwrap(); + idmap + .add_gen_domain(realm, tenant_id, (1000, 2000)) + .unwrap(); + + let uid = idmap + .gen_to_unix(tenant_id, &dummy_user.upn) + .expect("Failed to generate uid for user"); + // Store the calculated uid -> upn map in the cache + uid_cache + .store(uid, &dummy_user.upn) + .expect("Failed storing generated uid in the cache"); + + // Initialize dummy configuration for LoadParm + let lp = LoadParm::new(None).expect("Failed loading default config"); + + // Initialize the Resolver + let mut resolver = Resolver { + realm: realm.to_string(), + tenant_id: tenant_id.to_string(), + lp, + idmap, + pcache, + user_cache, + uid_cache, + group_cache, + }; + + // Test the getpwuid function with a uid that exists in the cache + let result = resolver.getpwuid(uid).await.unwrap(); + + match result { + Response::NssAccount(Some(account)) => { + assert_eq!(account.name, dummy_user.upn); + assert_eq!(account.uid, uid); + assert_eq!(account.gid, uid); + assert_eq!(account.gecos, dummy_user.name); + assert_eq!(account.dir, "/home/test.com/user1"); + assert_eq!(account.shell, "/bin/false"); + } + other => { + panic!("Expected NssAccount with Some(account): {:?}", other) + } + } + + // Test the getpwuid function with a uid that does not exist in the cache + let nonexistent_uid = 9999; + let result = resolver.getpwuid(nonexistent_uid).await.unwrap(); + + match result { + Response::NssAccount(None) => {} // This is the expected result + other => panic!("Expected NssAccount with None: {:?}", other), + } + } +} diff --git a/rust/himmelblaud/src/main.rs b/rust/himmelblaud/src/main.rs index afbb4f50238..2ac065841f3 100644 --- a/rust/himmelblaud/src/main.rs +++ b/rust/himmelblaud/src/main.rs @@ -18,6 +18,11 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . */ + +// Ignore unused/dead code when running cargo test +#![cfg_attr(test, allow(unused_imports))] +#![cfg_attr(test, allow(dead_code))] + use clap::{Arg, ArgAction, Command}; use dbg::*; use himmelblau::graph::Graph; @@ -41,6 +46,7 @@ mod himmelblaud; use cache::{GroupCache, PrivateCache, UidCache, UserCache}; mod utils; +#[cfg(not(test))] #[tokio::main(flavor = "current_thread")] async fn main() -> ExitCode { let clap_args = Command::new("himmelblaud") diff --git a/rust/himmelblaud/src/utils.rs b/rust/himmelblaud/src/utils.rs index ff0fcb4e2db..b452b4844d0 100644 --- a/rust/himmelblaud/src/utils.rs +++ b/rust/himmelblaud/src/utils.rs @@ -79,3 +79,71 @@ pub(crate) async fn hsm_pin_fetch_or_create( Box::new(NT_STATUS_UNSUCCESSFUL) }) } + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + use tempfile::tempdir; + use tokio::fs; + + #[test] + fn test_split_username_success() { + let username = "user@domain.com"; + let result = split_username(username); + assert!(result.is_ok()); + let (user, domain) = result.unwrap(); + assert_eq!(user, "user"); + assert_eq!(domain, "domain.com"); + } + + #[test] + fn test_split_username_failure() { + let username = "invalid_username"; + let result = split_username(username); + assert!(result.is_err()); + assert_eq!(*result.unwrap_err(), NT_STATUS_INVALID_USER_PRINCIPAL_NAME); + } + + #[tokio::test] + async fn test_hsm_pin_fetch_or_create_generate() { + let dir = tempdir().unwrap(); + let path = dir.path().join("hsm_pin"); + + let result = hsm_pin_fetch_or_create(path.to_str().unwrap()).await; + assert!(result.is_ok()); + + // Verify that the file is created and contains a valid auth value + let saved_pin = fs::read(path).await.expect("Auth value missing"); + AuthValue::try_from(saved_pin.as_slice()) + .expect("Failed parsing auth value"); + } + + #[tokio::test] + async fn test_hsm_pin_fetch_or_create_invalid_path() { + let result = hsm_pin_fetch_or_create("invalid_path\0").await; + assert!(result.is_err()); + match result { + Err(e) => assert_eq!(*e, NT_STATUS_UNSUCCESSFUL), + Ok(_) => panic!("Expected error but got success"), + } + } + + #[tokio::test] + async fn test_hsm_pin_fetch_or_create_invalid_auth_value() { + let dir = tempdir().unwrap(); + let path = dir.path().join("hsm_pin"); + + // Write invalid content to the file + let mut file = std::fs::File::create(&path).unwrap(); + file.write_all(b"invalid_auth_value").unwrap(); + + // Test reading the invalid file + let result = hsm_pin_fetch_or_create(path.to_str().unwrap()).await; + assert!(result.is_err()); + match result { + Err(e) => assert_eq!(*e, NT_STATUS_UNSUCCESSFUL), + Ok(_) => panic!("Expected error but got success"), + } + } +} diff --git a/rust/tdb/src/lib.rs b/rust/tdb/src/lib.rs index 9514e282453..cb9af44d8ed 100644 --- a/rust/tdb/src/lib.rs +++ b/rust/tdb/src/lib.rs @@ -26,6 +26,7 @@ use ntstatus_gen::NT_STATUS_UNSUCCESSFUL; use std::error::Error; use std::ffi::c_void; use std::fmt; +use std::path::PathBuf; use std::sync::{Arc, Mutex}; mod ffi { @@ -111,12 +112,22 @@ impl Tdb { open_flags: Option, mode: Option, ) -> Result> { + let path = PathBuf::from(name); let tdb = unsafe { ffi::tdb_open( wrap_string(name), hash_size.unwrap_or(0), tdb_flags.unwrap_or(ffi::TDB_DEFAULT as i32), - open_flags.unwrap_or(libc::O_RDWR), + match open_flags { + Some(open_flags) => open_flags, + None => { + if path.exists() { + libc::O_RDWR + } else { + libc::O_RDWR | libc::O_CREAT + } + } + }, mode.unwrap_or(0o600), ) };