Implement a rate limiting stream (AsyncRead, AsyncWrite)

Signed-off-by: Dietmar Maurer <dietmar@proxmox.com>
This commit is contained in:
Dietmar Maurer 2021-11-03 08:48:32 +01:00
parent e848148f5c
commit c94ad247b1
3 changed files with 225 additions and 0 deletions

View File

@ -2,6 +2,12 @@
//!
//! Contains a lightweight wrapper around `hyper` with support for TLS connections.
mod rate_limiter;
pub use rate_limiter::RateLimiter;
mod rate_limited_stream;
pub use rate_limited_stream::RateLimitedStream;
mod connector;
pub use connector::HttpsConnector;

View File

@ -0,0 +1,144 @@
use std::pin::Pin;
use std::marker::Unpin;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use futures::Future;
use tokio::io::{ReadBuf, AsyncRead, AsyncWrite};
use tokio::time::Sleep;
use std::task::{Context, Poll};
use super::RateLimiter;
/// A rate limited stream using [RateLimiter]
pub struct RateLimitedStream<S> {
read_limiter: Option<Arc<Mutex<RateLimiter>>>,
read_delay: Option<Pin<Box<Sleep>>>,
write_limiter: Option<Arc<Mutex<RateLimiter>>>,
write_delay: Option<Pin<Box<Sleep>>>,
stream: S,
}
impl <S> RateLimitedStream<S> {
const MIN_DELAY: Duration = Duration::from_millis(20);
/// Creates a new instance with reads and writes limited to the same `rate`.
pub fn new(stream: S, rate: u64, bucket_size: u64) -> Self {
let now = Instant::now();
let read_limiter = Arc::new(Mutex::new(RateLimiter::with_start_time(rate, bucket_size, now)));
let write_limiter = Arc::new(Mutex::new(RateLimiter::with_start_time(rate, bucket_size, now)));
Self::with_limiter(stream, Some(read_limiter), Some(write_limiter))
}
/// Creates a new instance with specified [RateLimiters] for reads and writes.
pub fn with_limiter(
stream: S,
read_limiter: Option<Arc<Mutex<RateLimiter>>>,
write_limiter: Option<Arc<Mutex<RateLimiter>>>,
) -> Self {
Self {
read_limiter,
read_delay: None,
write_limiter,
write_delay: None,
stream,
}
}
}
impl <S: AsyncWrite + Unpin> AsyncWrite for RateLimitedStream<S> {
fn poll_write(
self: Pin<&mut Self>,
ctx: &mut Context<'_>,
buf: &[u8]
) -> Poll<Result<usize, std::io::Error>> {
let this = self.get_mut();
let is_ready = match this.write_delay {
Some(ref mut future) => {
future.as_mut().poll(ctx).is_ready()
}
None => true,
};
if !is_ready { return Poll::Pending; }
this.write_delay = None;
let result = Pin::new(&mut this.stream).poll_write(ctx, buf);
if let Some(ref write_limiter) = this.write_limiter {
if let Poll::Ready(Ok(count)) = &result {
let now = Instant::now();
let delay = write_limiter.lock().unwrap()
.register_traffic(now, *count as u64);
if delay >= Self::MIN_DELAY {
let sleep = tokio::time::sleep(delay);
this.write_delay = Some(Box::pin(sleep));
}
}
}
result
}
fn poll_flush(
self: Pin<&mut Self>,
ctx: &mut Context<'_>
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
Pin::new(&mut this.stream).poll_flush(ctx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
ctx: &mut Context<'_>
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
Pin::new(&mut this.stream).poll_shutdown(ctx)
}
}
impl <S: AsyncRead + Unpin> AsyncRead for RateLimitedStream<S> {
fn poll_read(
self: Pin<&mut Self>,
ctx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
let is_ready = match this.read_delay {
Some(ref mut future) => {
future.as_mut().poll(ctx).is_ready()
}
None => true,
};
if !is_ready { return Poll::Pending; }
this.read_delay = None;
let filled_len = buf.filled().len();
let result = Pin::new(&mut this.stream).poll_read(ctx, buf);
if let Some(ref read_limiter) = this.read_limiter {
if let Poll::Ready(Ok(())) = &result {
let count = buf.filled().len() - filled_len;
let now = Instant::now();
let delay = read_limiter.lock().unwrap()
.register_traffic(now, count as u64);
if delay >= Self::MIN_DELAY {
let sleep = tokio::time::sleep(delay);
this.read_delay = Some(Box::pin(sleep));
}
}
}
result
}
}

View File

@ -0,0 +1,75 @@
use std::time::{Duration, Instant};
use std::convert::TryInto;
/// Token bucket based rate limiter
pub struct RateLimiter {
rate: u64, // tokens/second
start_time: Instant,
traffic: u64, // overall traffic
bucket_size: u64,
last_update: Instant,
consumed_tokens: u64,
}
impl RateLimiter {
const NO_DELAY: Duration = Duration::from_millis(0);
/// Creates a new instance, using [Instant::now] as start time.
pub fn new(rate: u64, bucket_size: u64) -> Self {
let start_time = Instant::now();
Self::with_start_time(rate, bucket_size, start_time)
}
/// Creates a new instance with specified `rate`, `bucket_size` and `start_time`.
pub fn with_start_time(rate: u64, bucket_size: u64, start_time: Instant) -> Self {
Self {
rate,
start_time,
traffic: 0,
bucket_size,
last_update: start_time,
// start with empty bucket (all tokens consumed)
consumed_tokens: bucket_size,
}
}
/// Returns the average rate (since `start_time`)
pub fn average_rate(&self, current_time: Instant) -> f64 {
let time_diff = (current_time - self.start_time).as_secs_f64();
if time_diff <= 0.0 {
0.0
} else {
(self.traffic as f64) / time_diff
}
}
fn refill_bucket(&mut self, current_time: Instant) {
let time_diff = (current_time - self.last_update).as_nanos();
if time_diff <= 0 {
//log::error!("update_time: got negative time diff");
return;
}
self.last_update = current_time;
let allowed_traffic = ((time_diff.saturating_mul(self.rate as u128)) / 1_000_000_000)
.try_into().unwrap_or(u64::MAX);
self.consumed_tokens = self.consumed_tokens.saturating_sub(allowed_traffic);
}
/// Register traffic, returning a proposed delay to reach the expected rate.
pub fn register_traffic(&mut self, current_time: Instant, data_len: u64) -> Duration {
self.refill_bucket(current_time);
self.traffic += data_len;
self.consumed_tokens += data_len;
if self.consumed_tokens <= self.bucket_size {
return Self::NO_DELAY;
}
Duration::from_nanos((self.consumed_tokens - self.bucket_size).saturating_mul(1_000_000_000)/ self.rate)
}
}