diff --git a/proxmox-router/src/lib.rs b/proxmox-router/src/lib.rs index b0b67bcb..da2f018f 100644 --- a/proxmox-router/src/lib.rs +++ b/proxmox-router/src/lib.rs @@ -29,3 +29,5 @@ pub use serializable_return::SerializableReturn; // make list_subdirs_api_method! work without an explicit proxmox-schema dependency: #[doc(hidden)] pub use proxmox_schema::ObjectSchema as ListSubdirsObjectSchema; + +pub mod stream; diff --git a/proxmox-router/src/router.rs b/proxmox-router/src/router.rs index 4585f503..4491e59b 100644 --- a/proxmox-router/src/router.rs +++ b/proxmox-router/src/router.rs @@ -78,24 +78,6 @@ pub type SerializingApiHandlerFn = &'static (dyn Fn( + Sync + 'static); -#[derive(Serialize)] -#[serde(rename_all = "kebab-case")] -enum RecordEntry<'a> { - /// A successful record. - Data(&'a Value), - /// An error entry. - Error(Value), -} - -impl<'a> From<&'a Result> for RecordEntry<'a> { - fn from(res: &'a Result) -> Self { - match res { - Ok(value) => Self::Data(value), - Err(err) => Self::Error(err.to_string().into()), - } - } -} - /// A record for a streaming API call. This contains a `Result` and allows formatting /// as a `json-seq` formatted string. /// @@ -105,21 +87,39 @@ impl<'a> From<&'a Result> for RecordEntry<'a> { /// API. pub struct Record { // direct access is only for the CLI code - pub(crate) data: Result, + pub(crate) data: crate::stream::Record, } impl Record { /// Create a new successful record from a serializeable element. - pub fn new(data: &T) -> Self { + pub fn new(data: T) -> Self { Self { - data: Ok(serde_json::to_value(data).expect("failed to create json string")), + data: crate::stream::Record::Data( + serde_json::to_value(data).expect("failed to create json string"), + ), } } /// Create a new error record from an error value. pub fn error>(error: E) -> Self { Self { - data: Err(error.into()), + data: crate::stream::Record::Error(error.into().to_string().into()), + } + } + + /// Create a new error record from an error message. 
+ pub fn error_msg(msg: T) -> Self { + Self { + data: crate::stream::Record::Error(msg.to_string().into()), + } + } + + /// Create a new structured error record from an error value. + pub fn error_value(error: T) -> Self { + Self { + data: crate::stream::Record::Error( + serde_json::to_value(error).expect("failed to create json string"), + ), } } @@ -145,13 +145,26 @@ impl Record { // Don't return special objects that can fail to serialize. // As for "normal" data - we don't expect spurious errors, otherwise they could also happen // when serializing *errors*... - serde_json::to_writer(&mut data, &RecordEntry::from(&self.data)) - .expect("failed to create JSON record"); + serde_json::to_writer(&mut data, &self.data).expect("failed to create JSON record"); data.push(b'\n'); data } } +impl From> for Record +where + T: Serialize, +{ + fn from(data: crate::stream::Record) -> Self { + match data { + crate::stream::Record::Data(data) => Self::new(data), + crate::stream::Record::Error(err) => Self { + data: crate::stream::Record::Error(err), + }, + } + } +} + /// A synchronous API handler returns an [`Iterator`] over items which should be serialized. 
/// /// ``` @@ -194,7 +207,7 @@ impl SyncStream { pub fn try_collect(self) -> Result { let mut acc = Vec::new(); for i in self.inner { - acc.push(i.data?); + acc.push(i.data.into_result()?); } Ok(Value::Array(acc)) } @@ -339,7 +352,7 @@ impl Stream { let mut acc = Vec::new(); while let Some(i) = self.inner.next().await { - acc.push(i.data?); + acc.push(i.data.into_result()?); } Ok(Value::Array(acc)) } diff --git a/proxmox-router/src/stream.rs b/proxmox-router/src/stream.rs new file mode 100644 index 00000000..49b1fba4 --- /dev/null +++ b/proxmox-router/src/stream.rs @@ -0,0 +1,339 @@ +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::task::{ready, Poll}; + +use anyhow::{format_err, Context as _, Error}; +use futures::io::{AsyncBufRead, AsyncBufReadExt, AsyncRead, BufReader}; +use hyper::body::{Body, Bytes}; +use serde::{Deserialize, Serialize}; + +pub struct Records +where + R: Send + Sync, +{ + inner: RecordsInner, +} + +impl Records { + /// Create a *new buffered reader* to create a record stream from an [`AsyncRead`]. + /// Note: If the underlying type already implements [`AsyncBufRead`], use [`Records::::from`] + /// instead! 
+ pub fn new(reader: T) -> Records> + where + T: AsyncRead + Send + Sync + Unpin + 'static, + { + BufReader::new(reader).into() + } +} + +impl Records +where + R: AsyncBufRead + Send + Sync + Unpin + 'static, +{ + pub fn json(self) -> JsonRecords + where + T: for<'de> Deserialize<'de>, + { + self.inner.into() + } +} + +impl Records { + pub fn from_body(body: Body) -> Self { + Self::from(BodyBufReader::from(body)) + } +} + +impl From for Records +where + R: AsyncBufRead + Send + Sync + Unpin + 'static, +{ + fn from(reader: R) -> Self { + Self { + inner: reader.into(), + } + } +} + +enum RecordsInner { + New(R), + Reading(Pin, R)>>> + Send + Sync>>), + Done, +} + +impl From for RecordsInner +where + R: AsyncBufRead + Send + Sync + Unpin + 'static, +{ + fn from(reader: R) -> Self { + Self::New(reader) + } +} + +impl futures::Stream for RecordsInner +where + R: AsyncBufRead + Send + Sync + Unpin + 'static, +{ + type Item = io::Result>; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + loop { + return match std::mem::replace(&mut *self, Self::Done) { + Self::New(mut reader) => { + let fut = Box::pin(async move { + let mut linebuf = Vec::new(); + loop { + if reader.read_until(b'\x1E', &mut linebuf).await? 
== 0 { + return Ok(None); + } + linebuf.pop(); // pop off the record separator + if linebuf.is_empty() { + continue; + } + return Ok(Some((linebuf, reader))); + } + }); + *self = Self::Reading(fut); + continue; + } + Self::Reading(mut fut) => match fut.as_mut().poll(cx) { + Poll::Ready(Ok(None)) => Poll::Ready(None), + Poll::Ready(Ok(Some((data, reader)))) => { + *self = Self::New(reader); + Poll::Ready(Some(Ok(data))) + } + Poll::Ready(Err(err)) => { + *self = Self::Done; + Poll::Ready(Some(Err(err))) + } + Poll::Pending => { + *self = Self::Reading(fut); + Poll::Pending + } + }, + Self::Done => Poll::Ready(None), + }; + } + } +} + +pub struct JsonRecords +where + R: Send + Sync, +{ + inner: JsonRecordsInner, +} + +impl JsonRecords +where + R: Send + Sync, +{ + pub fn from_vec(list: Vec) -> Self { + Self { + inner: JsonRecordsInner::Fixed(list.into_iter()), + } + } +} + +enum JsonRecordsInner { + Stream(RecordsInner), + Fixed(std::vec::IntoIter), +} + +impl From for JsonRecords +where + R: AsyncBufRead + Send + Sync + Unpin + 'static, + T: for<'de> Deserialize<'de>, +{ + fn from(reader: R) -> Self { + Self::from(RecordsInner::from(reader)) + } +} + +impl JsonRecords +where + T: for<'de> Deserialize<'de>, +{ + pub fn from_body(body: Body) -> Self { + Self::from(BodyBufReader::from(body)) + } +} + +impl From> for JsonRecords +where + R: AsyncBufRead + Send + Sync + Unpin + 'static, + T: for<'de> Deserialize<'de>, +{ + fn from(inner: RecordsInner) -> Self { + Self { + inner: JsonRecordsInner::Stream(inner), + } + } +} + +impl futures::Stream for JsonRecords +where + R: AsyncBufRead + Send + Sync + Unpin + 'static, + T: Unpin + for<'de> Deserialize<'de>, +{ + type Item = Result, Error>; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll, Error>>> { + let this = match &mut self.get_mut().inner { + JsonRecordsInner::Stream(this) => this, + JsonRecordsInner::Fixed(iter) => { + return Poll::Ready(iter.next().map(|item| 
Ok(Record::Data(item)))); + } + }; + + loop { + match ready!(Pin::new(&mut *this).poll_next(cx)) { + None => return Poll::Ready(None), + Some(Err(err)) => return Poll::Ready(Some(Err(err.into()))), + Some(Ok(data)) => { + let data = std::str::from_utf8(&data) + .map_err(|_| format_err!("non-utf8 json data in record element"))? + .trim(); + if data.is_empty() { + continue; + } + return Poll::Ready(Some( + serde_json::from_str(data) + .with_context(|| format!("bad json in record element: {data:?}")), + )); + } + } + } + } +} + +/// An adapter to turn a [`hyper::Body`] into an `AsyncRead`/`AsyncBufRead` for use with the +/// [`Records`] type. +pub struct BodyBufReader { + reader: Option, + buf_at: Option<(Bytes, usize)>, +} + +impl BodyBufReader { + pub fn records(self) -> Records { + self.into() + } + + pub fn json_records(self) -> JsonRecords + where + T: for<'de> Deserialize<'de>, + { + self.into() + } + + pub fn new(body: Body) -> Self { + Self { + reader: Some(body), + buf_at: None, + } + } +} + +impl From for BodyBufReader { + fn from(body: Body) -> Self { + Self::new(body) + } +} + +impl AsyncRead for BodyBufReader { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut [u8], + ) -> Poll> { + use std::io::Read; + let mut current_data = ready!(self.as_mut().poll_fill_buf(cx))?; + let nread = current_data.read(buf)?; + self.consume(nread); + Poll::Ready(Ok(nread)) + } +} + +impl AsyncBufRead for BodyBufReader { + fn poll_fill_buf( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + use hyper::body::HttpBody; + + let Self { + ref mut reader, + ref mut buf_at, + } = Pin::into_inner(self); + loop { + // If we currently have a buffer, use it: + if let Some((buf, at)) = buf_at { + return Poll::Ready(Ok(&buf[*at..])); + }; + + let result = match reader { + None => return Poll::Ready(Ok(&[])), + Some(reader) => ready!(Pin::new(reader).poll_data(cx)), + }; + + match result { + Some(Ok(bytes)) => { + *buf_at = 
Some((bytes, 0)); + } + Some(Err(err)) => { + *reader = None; + return Poll::Ready(Err(io::Error::other(err))); + } + None => { + *reader = None; + return Poll::Ready(Ok(&[])); + } + } + } + } + + fn consume(mut self: Pin<&mut Self>, amt: usize) { + if let Some((buf, at)) = self.buf_at.as_mut() { + *at = (*at + amt).min(buf.len()); + if *at == buf.len() { + self.buf_at = None; + } + } + } +} + +/// Streamed JSON records can contain either "data" or an error. +/// +/// Errors can be a simple string or structured data. +/// +/// For convenience, an [`into_result()`](Record::into_result) method is provided to turn the +/// record into a regular `Result`, where the error is converted to either a message or, for +/// structured errors, a json representation. +#[derive(Deserialize, Serialize)] +#[serde(rename_all = "kebab-case")] +pub enum Record { + /// A successful record. + Data(T), + /// An error entry. + Error(serde_json::Value), +} + +impl Record { + pub fn into_result(self) -> Result { + match self { + Self::Data(data) => Ok(data), + Self::Error(serde_json::Value::String(s)) => Err(Error::msg(s)), + Self::Error(other) => match serde_json::to_string(&other) { + Ok(s) => Err(Error::msg(s)), + Err(err) => Err(Error::from(err)), + }, + } + } +}