account oidc init rs

Signed-off-by: fufesou <shuanglongchen@yeah.net>
This commit is contained in:
fufesou 2022-10-19 22:48:51 +08:00
parent d89c1d3093
commit 3454454bd5
5 changed files with 289 additions and 0 deletions

2
Cargo.lock generated
View File

@ -4419,6 +4419,7 @@ dependencies = [
"system_shutdown",
"tray-item",
"trayicon",
"url",
"uuid",
"virtual_display",
"whoami",
@ -5492,6 +5493,7 @@ dependencies = [
"idna",
"matches",
"percent-encoding",
"serde 1.0.144",
]
[[package]]

View File

@ -64,6 +64,7 @@ wol-rs = "0.9.1"
flutter_rust_bridge = { git = "https://github.com/SoLongAndThanksForAllThePizza/flutter_rust_bridge", optional = true }
errno = "0.2.8"
rdev = { git = "https://github.com/asur4s/rdev" }
url = { version = "2.1", features = ["serde"] }
[target.'cfg(not(target_os = "linux"))'.dependencies]
reqwest = { version = "0.11", features = ["json", "rustls-tls"], default-features=false }

42
src/hbbs_http.rs Normal file
View File

@ -0,0 +1,42 @@
use hbb_common::{
anyhow::{self, bail},
tokio, ResultType,
};
use reqwest::Response;
use serde_derive::Deserialize;
use serde_json::{Map, Value};
use serde::de::DeserializeOwned;
pub mod account;
pub enum HbbHttpResponse<T> {
ErrorFormat,
Error(String),
DataTypeFormat,
Data(T),
}
#[tokio::main(flavor = "current_thread")]
async fn resp_to_serde_map(resp: Response) -> reqwest::Result<Map<String, Value>> {
resp.json().await
}
impl<T: DeserializeOwned> TryFrom<Response> for HbbHttpResponse<T> {
type Error = reqwest::Error;
fn try_from(resp: Response) -> Result<Self, <Self as TryFrom<Response>>::Error> {
let map = resp_to_serde_map(resp)?;
if let Some(error) = map.get("error") {
if let Some(err) = error.as_str() {
Ok(Self::Error(err.to_owned()))
} else {
Ok(Self::ErrorFormat)
}
} else {
match serde_json::from_value(Value::Object(map)) {
Ok(v) => Ok(Self::Data(v)),
Err(_) => Ok(Self::DataTypeFormat),
}
}
}
}

242
src/hbbs_http/account.rs Normal file
View File

@ -0,0 +1,242 @@
use super::HbbHttpResponse;
use hbb_common::{config::Config, log, sleep, tokio, tokio::sync::RwLock, ResultType};
use serde_derive::Deserialize;
use std::{
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
use url::Url;
lazy_static::lazy_static! {
static ref API_SERVER: String = crate::get_api_server(
Config::get_option("api-server"), Config::get_option("custom-rendezvous-server"));
static ref OIDC_SESSION: Arc<RwLock<OidcSession>> = Arc::new(RwLock::new(OidcSession::new()));
}
const QUERY_INTERVAL_SECS: f32 = 1.0;
const QUERY_TIMEOUT_SECS: u64 = 60;
#[derive(Deserialize, Clone)]
pub struct OidcAuthUrl {
code: String,
url: Url,
}
#[derive(Debug, Deserialize, Default, Clone)]
pub struct UserPayload {
pub id: String,
pub name: String,
pub email: Option<String>,
pub note: Option<String>,
pub status: Option<i64>,
pub grp: Option<String>,
pub is_admin: Option<bool>,
}
#[derive(Debug, Deserialize, Clone)]
pub struct AuthBody {
access_token: String,
token_type: String,
user: UserPayload,
}
#[derive(Copy, Clone)]
pub enum OidcState {
// initial request
OidcRequest = 1,
// initial request failed
OidcRequestFailed = 2,
// request succeeded, loop querying
OidcQuerying = 11,
// loop querying failed
OidcQueryFailed = 12,
// query sucess before
OidcNotExists = 13,
// query timeout
OidcQueryTimeout = 14,
// already login
OidcLogin = 21,
}
pub struct OidcSession {
client: reqwest::Client,
state: OidcState,
failed_msg: String,
code_url: Option<OidcAuthUrl>,
auth_body: Option<AuthBody>,
keep_querying: bool,
running: bool,
query_timeout: Duration,
}
impl OidcSession {
fn new() -> Self {
Self {
client: reqwest::Client::new(),
state: OidcState::OidcRequest,
failed_msg: "".to_owned(),
code_url: None,
auth_body: None,
keep_querying: false,
running: false,
query_timeout: Duration::from_secs(QUERY_TIMEOUT_SECS),
}
}
async fn auth(op: &str, id: &str, uuid: &str) -> ResultType<HbbHttpResponse<OidcAuthUrl>> {
Ok(OIDC_SESSION
.read()
.await
.client
.post(format!("{}/api/oidc/auth", *API_SERVER))
.json(&HashMap::from([("op", op), ("id", id), ("uuid", uuid)]))
.send()
.await?
.try_into()?)
}
async fn query(code: &str, id: &str, uuid: &str) -> ResultType<HbbHttpResponse<AuthBody>> {
let url = reqwest::Url::parse_with_params(
&format!("{}/api/oidc/auth-query", *API_SERVER),
&[("code", code), ("id", id), ("uuid", uuid)],
)?;
Ok(OIDC_SESSION
.read()
.await
.client
.get(url)
.send()
.await?
.try_into()?)
}
fn reset(&mut self) {
self.state = OidcState::OidcRequest;
self.failed_msg = "".to_owned();
self.keep_querying = true;
self.running = false;
self.code_url = None;
self.auth_body = None;
}
async fn before_task(&mut self) {
self.reset();
self.running = true;
}
async fn after_task(&mut self) {
self.running = false;
}
async fn auth_task(op: String, id: String, uuid: String) {
let code_url = match Self::auth(&op, &id, &uuid).await {
Ok(HbbHttpResponse::<_>::Data(code_url)) => code_url,
Ok(HbbHttpResponse::<_>::Error(err)) => {
OIDC_SESSION
.write()
.await
.set_state(OidcState::OidcRequestFailed, err);
return;
}
Ok(_) => {
OIDC_SESSION.write().await.set_state(
OidcState::OidcRequestFailed,
"Invalid auth response".to_owned(),
);
return;
}
Err(err) => {
OIDC_SESSION
.write()
.await
.set_state(OidcState::OidcRequestFailed, err.to_string());
return;
}
};
OIDC_SESSION
.write()
.await
.set_state(OidcState::OidcQuerying, "".to_owned());
OIDC_SESSION.write().await.code_url = Some(code_url.clone());
let begin = Instant::now();
let query_timeout = OIDC_SESSION.read().await.query_timeout;
while OIDC_SESSION.read().await.keep_querying && begin.elapsed() < query_timeout {
match Self::query(&code_url.code, &id, &uuid).await {
Ok(HbbHttpResponse::<_>::Data(auth_body)) => {
OIDC_SESSION
.write()
.await
.set_state(OidcState::OidcLogin, "".to_owned());
OIDC_SESSION.write().await.auth_body = Some(auth_body);
return;
// to-do, set access-token
}
Ok(HbbHttpResponse::<_>::Error(err)) => {
if err.contains("No authed oidc is found") {
// ignore, keep querying
} else {
OIDC_SESSION
.write()
.await
.set_state(OidcState::OidcQueryFailed, err);
return;
}
}
Ok(_) => {
// ignore
}
Err(err) => {
log::trace!("Failed query oidc {}", err);
// ignore
}
}
sleep(QUERY_INTERVAL_SECS).await;
}
if begin.elapsed() >= query_timeout {
OIDC_SESSION
.write()
.await
.set_state(OidcState::OidcQueryTimeout, "timeout".to_owned());
}
// no need to handle "keep_querying == false"
}
fn set_state(&mut self, state: OidcState, failed_msg: String) {
self.state = state;
self.failed_msg = failed_msg;
}
pub async fn account_auth(op: String, id: String, uuid: String) {
if OIDC_SESSION.read().await.running {
OIDC_SESSION.write().await.keep_querying = false;
}
let wait_secs = 0.3;
sleep(wait_secs).await;
while OIDC_SESSION.read().await.running {
sleep(wait_secs).await;
}
tokio::spawn(async move {
OIDC_SESSION.write().await.before_task().await;
Self::auth_task(op, id, uuid).await;
OIDC_SESSION.write().await.after_task().await;
});
}
fn get_result_(&self) -> (u8, String, Option<AuthBody>) {
(
self.state as u8,
self.failed_msg.clone(),
self.auth_body.clone(),
)
}
pub async fn get_result() -> (u8, String, Option<AuthBody>) {
OIDC_SESSION.read().await.get_result_()
}
}

View File

@ -48,6 +48,8 @@ mod ui_cm_interface;
mod ui_interface;
mod ui_session_interface;
mod hbbs_http;
#[cfg(windows)]
pub mod clipboard_file;