rest-server: Refactor AcceptBuilder
, provide support for optional TLS
The new public function `accept_tls_optional()` is added, which accepts both plain TCP streams and TCP streams running TLS. Plain TCP streams are sent along via a separate channel in order to clearly distinguish between "secure" and "insecure" connections. Furthermore, instead of `AcceptBuilder` itself holding a reference to an `SslAcceptor`, its public functions now take the acceptor as an argument. The public functions' names are changed to distinguish between their functionality in a more explicit manner: * `accept()` --> `accept_tls()` * NEW --> `accept_tls_optional()` Signed-off-by: Max Carrara <m.carrara@proxmox.com> Tested-by: Lukas Wagner <l.wagner@proxmox.com> Reviewed-by: Lukas Wagner <l.wagner@proxmox.com> Signed-off-by: Wolfgang Bumiller <w.bumiller@proxmox.com>
This commit is contained in:
parent
8eff15b0b0
commit
57b4c4624b
@ -8,15 +8,16 @@ use std::pin::Pin;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use anyhow::Error;
|
||||
use anyhow::{format_err, Context as _, Error};
|
||||
use futures::FutureExt;
|
||||
use hyper::server::accept;
|
||||
use openssl::ec::{EcGroup, EcKey};
|
||||
use openssl::nid::Nid;
|
||||
use openssl::pkey::{PKey, Private};
|
||||
use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod};
|
||||
use openssl::x509::X509;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_openssl::SslStream;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
@ -133,10 +134,14 @@ impl TlsAcceptorBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "rate-limited-stream")]
|
||||
type ClientStreamResult = Pin<Box<SslStream<RateLimitedStream<TcpStream>>>>;
|
||||
#[cfg(not(feature = "rate-limited-stream"))]
|
||||
type ClientStreamResult = Pin<Box<SslStream<TcpStream>>>;
|
||||
type InsecureClientStream = TcpStream;
|
||||
#[cfg(feature = "rate-limited-stream")]
|
||||
type InsecureClientStream = RateLimitedStream<TcpStream>;
|
||||
|
||||
type InsecureClientStreamResult = Pin<Box<InsecureClientStream>>;
|
||||
|
||||
type ClientStreamResult = Pin<Box<SslStream<InsecureClientStream>>>;
|
||||
|
||||
#[cfg(feature = "rate-limited-stream")]
|
||||
type LookupRateLimiter = dyn Fn(std::net::SocketAddr) -> (Option<SharedRateLimit>, Option<SharedRateLimit>)
|
||||
@ -145,7 +150,6 @@ type LookupRateLimiter = dyn Fn(std::net::SocketAddr) -> (Option<SharedRateLimit
|
||||
+ 'static;
|
||||
|
||||
pub struct AcceptBuilder {
|
||||
acceptor: Arc<Mutex<SslAcceptor>>,
|
||||
debug: bool,
|
||||
tcp_keepalive_time: u32,
|
||||
max_pending_accepts: usize,
|
||||
@ -154,16 +158,9 @@ pub struct AcceptBuilder {
|
||||
lookup_rate_limiter: Option<Arc<LookupRateLimiter>>,
|
||||
}
|
||||
|
||||
impl AcceptBuilder {
|
||||
pub fn new() -> Result<Self, Error> {
|
||||
Ok(Self::with_acceptor(Arc::new(Mutex::new(
|
||||
TlsAcceptorBuilder::new().build()?,
|
||||
))))
|
||||
}
|
||||
|
||||
pub fn with_acceptor(acceptor: Arc<Mutex<SslAcceptor>>) -> Self {
|
||||
impl Default for AcceptBuilder {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
acceptor,
|
||||
debug: false,
|
||||
tcp_keepalive_time: 120,
|
||||
max_pending_accepts: 1024,
|
||||
@ -172,6 +169,12 @@ impl AcceptBuilder {
|
||||
lookup_rate_limiter: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AcceptBuilder {
|
||||
pub fn new() -> Self {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
pub fn debug(mut self, debug: bool) -> Self {
|
||||
self.debug = debug;
|
||||
@ -193,114 +196,312 @@ impl AcceptBuilder {
|
||||
self.lookup_rate_limiter = Some(lookup_rate_limiter);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub fn accept(
|
||||
impl AcceptBuilder {
|
||||
pub fn accept_tls(
|
||||
self,
|
||||
listener: TcpListener,
|
||||
) -> impl hyper::server::accept::Accept<Conn = ClientStreamResult, Error = Error> {
|
||||
let (sender, receiver) = tokio::sync::mpsc::channel(self.max_pending_accepts);
|
||||
acceptor: Arc<Mutex<SslAcceptor>>,
|
||||
) -> impl accept::Accept<Conn = ClientStreamResult, Error = Error> {
|
||||
let (secure_sender, secure_receiver) = mpsc::channel(self.max_pending_accepts);
|
||||
|
||||
tokio::spawn(self.accept_connections(listener, sender));
|
||||
tokio::spawn(self.accept_connections(listener, acceptor, secure_sender.into()));
|
||||
|
||||
//receiver
|
||||
hyper::server::accept::from_stream(ReceiverStream::new(receiver))
|
||||
accept::from_stream(ReceiverStream::new(secure_receiver))
|
||||
}
|
||||
|
||||
pub fn accept_tls_optional(
|
||||
self,
|
||||
listener: TcpListener,
|
||||
acceptor: Arc<Mutex<SslAcceptor>>,
|
||||
) -> (
|
||||
impl accept::Accept<Conn = ClientStreamResult, Error = Error>,
|
||||
impl accept::Accept<Conn = InsecureClientStreamResult, Error = Error>,
|
||||
) {
|
||||
let (secure_sender, secure_receiver) = mpsc::channel(self.max_pending_accepts);
|
||||
let (insecure_sender, insecure_receiver) = mpsc::channel(self.max_pending_accepts);
|
||||
|
||||
tokio::spawn(self.accept_connections(
|
||||
listener,
|
||||
acceptor,
|
||||
(secure_sender, insecure_sender).into(),
|
||||
));
|
||||
|
||||
(
|
||||
accept::from_stream(ReceiverStream::new(secure_receiver)),
|
||||
accept::from_stream(ReceiverStream::new(insecure_receiver)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
type ClientSender = mpsc::Sender<Result<ClientStreamResult, Error>>;
|
||||
type InsecureClientSender = mpsc::Sender<Result<InsecureClientStreamResult, Error>>;
|
||||
|
||||
enum Sender {
|
||||
Secure(ClientSender),
|
||||
SecureAndInsecure(ClientSender, InsecureClientSender),
|
||||
}
|
||||
|
||||
impl From<ClientSender> for Sender {
|
||||
fn from(sender: ClientSender) -> Self {
|
||||
Sender::Secure(sender)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(ClientSender, InsecureClientSender)> for Sender {
|
||||
fn from(senders: (ClientSender, InsecureClientSender)) -> Self {
|
||||
Sender::SecureAndInsecure(senders.0, senders.1)
|
||||
}
|
||||
}
|
||||
|
||||
impl AcceptBuilder {
|
||||
async fn accept_connections(
|
||||
self,
|
||||
listener: TcpListener,
|
||||
sender: tokio::sync::mpsc::Sender<Result<ClientStreamResult, Error>>,
|
||||
acceptor: Arc<Mutex<SslAcceptor>>,
|
||||
sender: Sender,
|
||||
) {
|
||||
let accept_counter = Arc::new(());
|
||||
let mut shutdown_future = crate::shutdown_future().fuse();
|
||||
|
||||
loop {
|
||||
let (sock, peer) = futures::select! {
|
||||
res = listener.accept().fuse() => match res {
|
||||
Ok(conn) => conn,
|
||||
let socket = futures::select! {
|
||||
res = self.try_setup_socket(&listener).fuse() => match res {
|
||||
Ok(socket) => socket,
|
||||
Err(err) => {
|
||||
eprintln!("error accepting tcp connection: {err}");
|
||||
log::error!("couldn't set up TCP socket: {err}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
_ = shutdown_future => break,
|
||||
};
|
||||
#[cfg(not(feature = "rate-limited-stream"))]
|
||||
{
|
||||
let _ = &peer;
|
||||
}
|
||||
|
||||
sock.set_nodelay(true).unwrap();
|
||||
let _ = proxmox_sys::linux::socket::set_tcp_keepalive(
|
||||
sock.as_raw_fd(),
|
||||
self.tcp_keepalive_time,
|
||||
);
|
||||
|
||||
#[cfg(feature = "rate-limited-stream")]
|
||||
let sock = match self.lookup_rate_limiter.clone() {
|
||||
Some(lookup) => {
|
||||
RateLimitedStream::with_limiter_update_cb(sock, move || lookup(peer))
|
||||
}
|
||||
None => RateLimitedStream::with_limiter(sock, None, None),
|
||||
_ = shutdown_future => break,
|
||||
};
|
||||
|
||||
let ssl = {
|
||||
// limit acceptor_guard scope
|
||||
// Acceptor can be reloaded using the command socket "reload-certificate" command
|
||||
let acceptor_guard = self.acceptor.lock().unwrap();
|
||||
|
||||
match openssl::ssl::Ssl::new(acceptor_guard.context()) {
|
||||
Ok(ssl) => ssl,
|
||||
Err(err) => {
|
||||
eprintln!("failed to create Ssl object from Acceptor context - {err}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let stream = match tokio_openssl::SslStream::new(ssl, sock) {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
eprintln!("failed to create SslStream using ssl and connection socket - {err}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let mut stream = Box::pin(stream);
|
||||
let sender = sender.clone();
|
||||
let acceptor = Arc::clone(&acceptor);
|
||||
let accept_counter = Arc::clone(&accept_counter);
|
||||
|
||||
if Arc::strong_count(&accept_counter) > self.max_pending_accepts {
|
||||
eprintln!("connection rejected - too many open connections");
|
||||
log::error!("connection rejected - too many open connections");
|
||||
continue;
|
||||
}
|
||||
|
||||
let accept_counter = Arc::clone(&accept_counter);
|
||||
tokio::spawn(async move {
|
||||
let accept_future =
|
||||
tokio::time::timeout(Duration::new(10, 0), stream.as_mut().accept());
|
||||
match sender {
|
||||
Sender::Secure(ref secure_sender) => {
|
||||
let accept_future = Self::do_accept_tls(
|
||||
socket,
|
||||
acceptor,
|
||||
accept_counter,
|
||||
self.debug,
|
||||
secure_sender.clone(),
|
||||
);
|
||||
|
||||
let result = accept_future.await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(())) => {
|
||||
if sender.send(Ok(stream)).await.is_err() && self.debug {
|
||||
log::error!("detect closed connection channel");
|
||||
}
|
||||
}
|
||||
Ok(Err(err)) => {
|
||||
if self.debug {
|
||||
log::error!("https handshake failed - {err}");
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
if self.debug {
|
||||
log::error!("https handshake timeout");
|
||||
}
|
||||
}
|
||||
tokio::spawn(accept_future);
|
||||
}
|
||||
Sender::SecureAndInsecure(ref secure_sender, ref insecure_sender) => {
|
||||
let accept_future = Self::do_accept_tls_optional(
|
||||
socket,
|
||||
acceptor,
|
||||
accept_counter,
|
||||
self.debug,
|
||||
secure_sender.clone(),
|
||||
insecure_sender.clone(),
|
||||
);
|
||||
|
||||
drop(accept_counter); // decrease reference count
|
||||
});
|
||||
tokio::spawn(accept_future);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async fn try_setup_socket(
|
||||
&self,
|
||||
listener: &TcpListener,
|
||||
) -> Result<InsecureClientStream, Error> {
|
||||
let (socket, peer) = match listener.accept().await {
|
||||
Ok(connection) => connection,
|
||||
Err(error) => {
|
||||
return Err(format_err!(error)).context("error while accepting tcp stream")
|
||||
}
|
||||
};
|
||||
|
||||
socket
|
||||
.set_nodelay(true)
|
||||
.context("error while setting TCP_NODELAY on socket")?;
|
||||
|
||||
proxmox_sys::linux::socket::set_tcp_keepalive(socket.as_raw_fd(), self.tcp_keepalive_time)
|
||||
.context("error while setting SO_KEEPALIVE on socket")?;
|
||||
|
||||
#[cfg(feature = "rate-limited-stream")]
|
||||
let socket = match self.lookup_rate_limiter.clone() {
|
||||
Some(lookup) => RateLimitedStream::with_limiter_update_cb(socket, move || lookup(peer)),
|
||||
None => RateLimitedStream::with_limiter(socket, None, None),
|
||||
};
|
||||
|
||||
#[cfg(not(feature = "rate-limited-stream"))]
|
||||
let _peer = peer;
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
async fn do_accept_tls(
|
||||
socket: InsecureClientStream,
|
||||
acceptor: Arc<Mutex<SslAcceptor>>,
|
||||
accept_counter: Arc<()>,
|
||||
debug: bool,
|
||||
secure_sender: ClientSender,
|
||||
) {
|
||||
let ssl = {
|
||||
// limit acceptor_guard scope
|
||||
// Acceptor can be reloaded using the command socket "reload-certificate" command
|
||||
let acceptor_guard = acceptor.lock().unwrap();
|
||||
|
||||
match openssl::ssl::Ssl::new(acceptor_guard.context()) {
|
||||
Ok(ssl) => ssl,
|
||||
Err(err) => {
|
||||
log::error!("failed to create Ssl object from Acceptor context - {err}");
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let secure_stream = match tokio_openssl::SslStream::new(ssl, socket) {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
log::error!("failed to create SslStream using ssl and connection socket - {err}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let mut secure_stream = Box::pin(secure_stream);
|
||||
|
||||
let accept_future =
|
||||
tokio::time::timeout(Duration::new(10, 0), secure_stream.as_mut().accept());
|
||||
|
||||
let result = accept_future.await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(())) => {
|
||||
if secure_sender.send(Ok(secure_stream)).await.is_err() && debug {
|
||||
log::error!("detected closed connection channel");
|
||||
}
|
||||
}
|
||||
Ok(Err(err)) => {
|
||||
if debug {
|
||||
log::error!("https handshake failed - {err}");
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
if debug {
|
||||
log::error!("https handshake timeout");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
drop(accept_counter); // decrease reference count
|
||||
}
|
||||
|
||||
async fn do_accept_tls_optional(
|
||||
socket: InsecureClientStream,
|
||||
acceptor: Arc<Mutex<SslAcceptor>>,
|
||||
accept_counter: Arc<()>,
|
||||
debug: bool,
|
||||
secure_sender: ClientSender,
|
||||
insecure_sender: InsecureClientSender,
|
||||
) {
|
||||
let client_initiates_handshake = {
|
||||
#[cfg(feature = "rate-limited-stream")]
|
||||
let socket = socket.inner();
|
||||
|
||||
#[cfg(not(feature = "rate-limited-stream"))]
|
||||
let socket = &socket;
|
||||
|
||||
match Self::wait_for_client_tls_handshake(socket).await {
|
||||
Ok(initiates_handshake) => initiates_handshake,
|
||||
Err(err) => {
|
||||
log::error!("error checking for TLS handshake: {err}");
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if !client_initiates_handshake {
|
||||
let insecure_stream = Box::pin(socket);
|
||||
|
||||
if insecure_sender.send(Ok(insecure_stream)).await.is_err() && debug {
|
||||
log::error!("detected closed connection channel")
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
Self::do_accept_tls(socket, acceptor, accept_counter, debug, secure_sender).await
|
||||
}
|
||||
|
||||
async fn wait_for_client_tls_handshake(incoming_stream: &TcpStream) -> Result<bool, Error> {
|
||||
const MS_TIMEOUT: u64 = 1000;
|
||||
const BYTES_BUF_SIZE: usize = 128;
|
||||
|
||||
let mut buf = [0; BYTES_BUF_SIZE];
|
||||
let mut last_peek_size = 0;
|
||||
|
||||
let future = async {
|
||||
loop {
|
||||
let peek_size = incoming_stream
|
||||
.peek(&mut buf)
|
||||
.await
|
||||
.context("couldn't peek into incoming tcp stream")?;
|
||||
|
||||
if contains_tls_handshake_fragment(&buf) {
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
// No more new data came in
|
||||
if peek_size == last_peek_size {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
last_peek_size = peek_size;
|
||||
|
||||
// explicitly yield to event loop; this future otherwise blocks ad infinitum
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
};
|
||||
|
||||
tokio::time::timeout(Duration::from_millis(MS_TIMEOUT), future)
|
||||
.await
|
||||
.unwrap_or(Ok(false))
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks whether an [SSL 3.0 / TLS plaintext fragment][0] being part of a
|
||||
/// SSL / TLS handshake is contained in the given buffer.
|
||||
///
|
||||
/// Such a fragment might look as follows:
|
||||
/// ```ignore
|
||||
/// [0x16, 0x3, 0x1, 0x02, 0x00, ...]
|
||||
/// // | | | |_____|
|
||||
/// // | | | \__ content length interpreted as u16
|
||||
/// // | | | must not exceed 0x4000 (2^14) bytes
|
||||
/// // | | |
|
||||
/// // | | \__ any minor version
|
||||
/// // | |
|
||||
/// // | \__ major version 3
|
||||
/// // |
|
||||
/// // \__ content type is handshake(22)
|
||||
/// ```
|
||||
///
|
||||
/// If a slice like this is detected at the beginning of the given buffer,
|
||||
/// a TLS handshake is most definitely being made.
|
||||
///
|
||||
/// [0]: https://datatracker.ietf.org/doc/html/rfc6101#section-5.2
|
||||
#[inline]
|
||||
fn contains_tls_handshake_fragment(buf: &[u8]) -> bool {
|
||||
const SLICE_LENGTH: usize = 5;
|
||||
const CONTENT_SIZE: u16 = 1 << 14; // max length of a TLS plaintext fragment
|
||||
|
||||
if buf.len() < SLICE_LENGTH {
|
||||
return false;
|
||||
}
|
||||
|
||||
buf[0] == 0x16 && buf[1] == 0x3 && (((buf[3] as u16) << 8) + buf[4] as u16) <= CONTENT_SIZE
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user