HttpsConnector: use RateLimitedStream

So that we can limit used bandwidth.

Signed-off-by: Dietmar Maurer <dietmar@proxmox.com>
This commit is contained in:
Dietmar Maurer 2021-11-03 11:22:07 +01:00
parent ded24b3f4c
commit 00ca0b7fae
2 changed files with 51 additions and 8 deletions

View File

@ -1,14 +1,14 @@
use anyhow::{bail, format_err, Error}; use anyhow::{bail, format_err, Error};
use std::os::unix::io::AsRawFd; use std::os::unix::io::AsRawFd;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::{Arc, Mutex};
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use futures::*; use futures::*;
use http::Uri; use http::Uri;
use hyper::client::HttpConnector; use hyper::client::HttpConnector;
use openssl::ssl::SslConnector; use openssl::ssl::SslConnector;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_openssl::SslStream; use tokio_openssl::SslStream;
@ -18,12 +18,16 @@ use crate::proxy_config::ProxyConfig;
use crate::tls::MaybeTlsStream; use crate::tls::MaybeTlsStream;
use crate::uri::build_authority; use crate::uri::build_authority;
use super::{RateLimiter, RateLimitedStream};
#[derive(Clone)] #[derive(Clone)]
pub struct HttpsConnector { pub struct HttpsConnector {
connector: HttpConnector, connector: HttpConnector,
ssl_connector: Arc<SslConnector>, ssl_connector: Arc<SslConnector>,
proxy: Option<ProxyConfig>, proxy: Option<ProxyConfig>,
tcp_keepalive: u32, tcp_keepalive: u32,
read_limiter: Option<Arc<Mutex<RateLimiter>>>,
write_limiter: Option<Arc<Mutex<RateLimiter>>>,
} }
impl HttpsConnector { impl HttpsConnector {
@ -38,6 +42,8 @@ impl HttpsConnector {
ssl_connector: Arc::new(ssl_connector), ssl_connector: Arc::new(ssl_connector),
proxy: None, proxy: None,
tcp_keepalive, tcp_keepalive,
read_limiter: None,
write_limiter: None,
} }
} }
@ -45,13 +51,21 @@ impl HttpsConnector {
self.proxy = Some(proxy); self.proxy = Some(proxy);
} }
async fn secure_stream( pub fn set_read_limiter(&mut self, limiter: Option<Arc<Mutex<RateLimiter>>>) {
tcp_stream: TcpStream, self.read_limiter = limiter;
}
pub fn set_write_limiter(&mut self, limiter: Option<Arc<Mutex<RateLimiter>>>) {
self.write_limiter = limiter;
}
async fn secure_stream<S: AsyncRead + AsyncWrite + Unpin>(
tcp_stream: S,
ssl_connector: &SslConnector, ssl_connector: &SslConnector,
host: &str, host: &str,
) -> Result<MaybeTlsStream<TcpStream>, Error> { ) -> Result<MaybeTlsStream<S>, Error> {
let config = ssl_connector.configure()?; let config = ssl_connector.configure()?;
let mut conn: SslStream<TcpStream> = SslStream::new(config.into_ssl(host)?, tcp_stream)?; let mut conn: SslStream<S> = SslStream::new(config.into_ssl(host)?, tcp_stream)?;
Pin::new(&mut conn).connect().await?; Pin::new(&mut conn).connect().await?;
Ok(MaybeTlsStream::Secured(conn)) Ok(MaybeTlsStream::Secured(conn))
} }
@ -107,7 +121,7 @@ impl HttpsConnector {
} }
impl hyper::service::Service<Uri> for HttpsConnector { impl hyper::service::Service<Uri> for HttpsConnector {
type Response = MaybeTlsStream<TcpStream>; type Response = MaybeTlsStream<RateLimitedStream<TcpStream>>;
type Error = Error; type Error = Error;
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
type Future = type Future =
@ -129,6 +143,9 @@ impl hyper::service::Service<Uri> for HttpsConnector {
}; };
let port = dst.port_u16().unwrap_or(if is_https { 443 } else { 80 }); let port = dst.port_u16().unwrap_or(if is_https { 443 } else { 80 });
let keepalive = self.tcp_keepalive; let keepalive = self.tcp_keepalive;
let read_limiter = self.read_limiter.clone();
let write_limiter = self.write_limiter.clone();
if let Some(ref proxy) = self.proxy { if let Some(ref proxy) = self.proxy {
let use_connect = is_https || proxy.force_connect; let use_connect = is_https || proxy.force_connect;
@ -152,12 +169,18 @@ impl hyper::service::Service<Uri> for HttpsConnector {
if use_connect { if use_connect {
async move { async move {
let mut tcp_stream = connector.call(proxy_uri).await.map_err(|err| { let tcp_stream = connector.call(proxy_uri).await.map_err(|err| {
format_err!("error connecting to {} - {}", proxy_authority, err) format_err!("error connecting to {} - {}", proxy_authority, err)
})?; })?;
let _ = set_tcp_keepalive(tcp_stream.as_raw_fd(), keepalive); let _ = set_tcp_keepalive(tcp_stream.as_raw_fd(), keepalive);
let mut tcp_stream = RateLimitedStream::with_limiter(
tcp_stream,
read_limiter,
write_limiter,
);
let mut connect_request = format!("CONNECT {0}:{1} HTTP/1.1\r\n", host, port); let mut connect_request = format!("CONNECT {0}:{1} HTTP/1.1\r\n", host, port);
if let Some(authorization) = authorization { if let Some(authorization) = authorization {
connect_request connect_request
@ -185,6 +208,12 @@ impl hyper::service::Service<Uri> for HttpsConnector {
let _ = set_tcp_keepalive(tcp_stream.as_raw_fd(), keepalive); let _ = set_tcp_keepalive(tcp_stream.as_raw_fd(), keepalive);
let tcp_stream = RateLimitedStream::with_limiter(
tcp_stream,
read_limiter,
write_limiter,
);
Ok(MaybeTlsStream::Proxied(tcp_stream)) Ok(MaybeTlsStream::Proxied(tcp_stream))
} }
.boxed() .boxed()
@ -199,6 +228,12 @@ impl hyper::service::Service<Uri> for HttpsConnector {
let _ = set_tcp_keepalive(tcp_stream.as_raw_fd(), keepalive); let _ = set_tcp_keepalive(tcp_stream.as_raw_fd(), keepalive);
let tcp_stream = RateLimitedStream::with_limiter(
tcp_stream,
read_limiter,
write_limiter,
);
if is_https { if is_https {
Self::secure_stream(tcp_stream, &ssl_connector, &host).await Self::secure_stream(tcp_stream, &ssl_connector, &host).await
} else { } else {

View File

@ -7,6 +7,7 @@ use std::io::IoSlice;
use futures::Future; use futures::Future;
use tokio::io::{ReadBuf, AsyncRead, AsyncWrite}; use tokio::io::{ReadBuf, AsyncRead, AsyncWrite};
use tokio::time::Sleep; use tokio::time::Sleep;
use hyper::client::connect::{Connection, Connected};
use std::task::{Context, Poll}; use std::task::{Context, Poll};
@ -174,3 +175,10 @@ impl <S: AsyncRead + Unpin> AsyncRead for RateLimitedStream<S> {
} }
} }
// we need this for the hyper http client
impl<S: Connection + AsyncRead + AsyncWrite + Unpin> Connection for RateLimitedStream<S> {
fn connected(&self) -> Connected {
self.stream.connected()
}
}