diff --git a/proxmox-rest-server/src/api_config.rs b/proxmox-rest-server/src/api_config.rs index ad9a8111..80589446 100644 --- a/proxmox-rest-server/src/api_config.rs +++ b/proxmox-rest-server/src/api_config.rs @@ -1,13 +1,16 @@ use std::collections::HashMap; use std::future::Future; +use std::io; use std::path::PathBuf; use std::pin::Pin; use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; use anyhow::{format_err, Error}; -use http::{HeaderMap, Method}; +use http::{HeaderMap, Method, Uri}; use hyper::http::request::Parts; use hyper::{Body, Response}; +use tower_service::Service; use proxmox_router::{Router, RpcEnvironmentType, UserInformation}; use proxmox_sys::fs::{create_path, CreateOptions}; @@ -25,6 +28,7 @@ pub struct ApiConfig { handlers: Vec, auth_handler: Option, index_handler: Option, + pub(crate) privileged_addr: Option, #[cfg(feature = "templates")] templates: templates::Templates, @@ -53,6 +57,7 @@ impl ApiConfig { handlers: Vec::new(), auth_handler: None, index_handler: None, + privileged_addr: None, #[cfg(feature = "templates")] templates: Default::default(), @@ -73,6 +78,12 @@ impl ApiConfig { self.auth_handler(AuthHandler::from_fn(func)) } + /// This is used for `protected` API calls to proxy to a more privileged service. + pub fn privileged_addr(mut self, addr: impl Into) -> Self { + self.privileged_addr = Some(addr.into()); + self + } + /// Set the index handler. pub fn index_handler(mut self, index_handler: IndexHandler) -> Self { self.index_handler = Some(index_handler); @@ -452,3 +463,156 @@ impl From for AuthError { AuthError::Generic(err) } } + +#[derive(Clone, Debug)] +/// For `protected` requests we support TCP or Unix connections. +pub enum PrivilegedAddr { + Tcp(std::net::SocketAddr), + Unix(std::os::unix::net::SocketAddr), +} + +impl From for PrivilegedAddr { + fn from(addr: std::net::SocketAddr) -> Self { + Self::Tcp(addr) + } +} + +impl From for PrivilegedAddr { + fn from(addr: std::os::unix::net::SocketAddr) -> Self { + Self::Unix(addr) + } +} + +impl Service for PrivilegedAddr { + type Response = PrivilegedSocket; + type Error = io::Error; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: Uri) -> Self::Future { + match self { + PrivilegedAddr::Tcp(addr) => { + let addr = addr.clone(); + Box::pin(async move { + tokio::net::TcpStream::connect(addr) + .await + .map(PrivilegedSocket::Tcp) + }) + } + PrivilegedAddr::Unix(addr) => { + let addr = addr.clone(); + Box::pin(async move { + tokio::net::UnixStream::connect(addr.as_pathname().ok_or_else(|| { + io::Error::new(io::ErrorKind::Other, "empty path for unix socket") + })?) + .await + .map(PrivilegedSocket::Unix) + }) + } + } + } +} + +/// A socket which is either a TCP stream or a UNIX stream. +pub enum PrivilegedSocket { + Tcp(tokio::net::TcpStream), + Unix(tokio::net::UnixStream), +} + +impl tokio::io::AsyncRead for PrivilegedSocket { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + Self::Tcp(s) => Pin::new(s).poll_read(cx, buf), + Self::Unix(s) => Pin::new(s).poll_read(cx, buf), + } + } +} + +impl tokio::io::AsyncWrite for PrivilegedSocket { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + Self::Tcp(s) => Pin::new(s).poll_write(cx, buf), + Self::Unix(s) => Pin::new(s).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Tcp(s) => Pin::new(s).poll_flush(cx), + Self::Unix(s) => Pin::new(s).poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Tcp(s) => Pin::new(s).poll_shutdown(cx), + Self::Unix(s) => Pin::new(s).poll_shutdown(cx), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + match self.get_mut() { + Self::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs), + Self::Unix(s) => Pin::new(s).poll_write_vectored(cx, bufs), + } + } + + fn is_write_vectored(&self) -> bool { + match self { + Self::Tcp(s) => s.is_write_vectored(), + Self::Unix(s) => s.is_write_vectored(), + } + } +} + +impl hyper::client::connect::Connection for PrivilegedSocket { + fn connected(&self) -> hyper::client::connect::Connected { + match self { + Self::Tcp(s) => s.connected(), + Self::Unix(_) => hyper::client::connect::Connected::new(), + } + } +} + +/// Implements hyper's `Accept` for `UnixListener`s. +pub struct UnixAcceptor { + listener: tokio::net::UnixListener, +} + +impl From for UnixAcceptor { + fn from(listener: tokio::net::UnixListener) -> Self { + Self { listener } + } +} + +impl hyper::server::accept::Accept for UnixAcceptor { + type Conn = tokio::net::UnixStream; + type Error = io::Error; + + fn poll_accept( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + Pin::new(&mut self.get_mut().listener) + .poll_accept(cx) + .map(|res| match res { + Ok((stream, _addr)) => Some(Ok(stream)), + Err(err) => Some(Err(err)), + }) + } +} diff --git a/proxmox-rest-server/src/lib.rs b/proxmox-rest-server/src/lib.rs index 1c64ffb4..ce9e4f15 100644 --- a/proxmox-rest-server/src/lib.rs +++ b/proxmox-rest-server/src/lib.rs @@ -45,7 +45,7 @@ mod file_logger; pub use file_logger::{FileLogOptions, FileLogger}; mod api_config; -pub use api_config::{ApiConfig, AuthError, AuthHandler, IndexHandler}; +pub use api_config::{ApiConfig, AuthError, AuthHandler, IndexHandler, UnixAcceptor}; mod rest; pub use rest::{Redirector, RestServer}; diff --git a/proxmox-rest-server/src/rest.rs b/proxmox-rest-server/src/rest.rs index 39f98e55..4900592d 100644 --- a/proxmox-rest-server/src/rest.rs +++ b/proxmox-rest-server/src/rest.rs @@ -8,7 +8,7 @@ use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; use anyhow::{bail, format_err, Error}; -use futures::future::{FutureExt, TryFutureExt}; +use futures::future::FutureExt; use futures::stream::TryStreamExt; use hyper::body::HttpBody; use hyper::header::{self, HeaderMap}; @@ -443,7 +443,8 @@ async fn get_request_parameters( struct NoLogExtension(); async fn proxy_protected_request( - info: &'static ApiMethod, + config: &ApiConfig, + info: &ApiMethod, mut parts: Parts, req_body: Body, peer: &std::net::SocketAddr, @@ -464,14 +465,16 @@ async fn proxy_protected_request( let reload_timezone = info.reload_timezone; - let resp = hyper::client::Client::new() - .request(request) - .map_err(Error::from) - .map_ok(|mut resp| { - resp.extensions_mut().insert(NoLogExtension()); - resp - }) - .await?; + let mut resp = match config.privileged_addr.clone() { + None => hyper::client::Client::new().request(request).await?, + Some(addr) => { + hyper::client::Client::builder() + .build(addr) + .request(request) + .await? + } + }; + resp.extensions_mut().insert(NoLogExtension()); if reload_timezone { unsafe { @@ -1024,7 +1027,7 @@ impl Formatted { let result = if api_method.protected && rpcenv.env_type == RpcEnvironmentType::PUBLIC { - proxy_protected_request(api_method, parts, body, peer).await + proxy_protected_request(config, api_method, parts, body, peer).await } else { handle_api_request(rpcenv, api_method, formatter, parts, body, uri_param).await }; @@ -1129,7 +1132,7 @@ impl Unformatted { let result = if api_method.protected && rpcenv.env_type == RpcEnvironmentType::PUBLIC { - proxy_protected_request(api_method, parts, body, peer).await + proxy_protected_request(config, api_method, parts, body, peer).await } else { handle_unformatted_api_request(rpcenv, api_method, parts, body, uri_param).await };