drop EventedFd/PolledFd helpers

And use tokio's AsyncFd correctly.

And restore SOCK_NONBLOCK on the receiver.

Signed-off-by: Wolfgang Bumiller <w.bumiller@proxmox.com>
This commit is contained in:
Wolfgang Bumiller 2022-07-18 12:13:59 +02:00
parent 7b1d2aa594
commit 7d6927b680
4 changed files with 109 additions and 162 deletions

View File

@ -1,5 +1,47 @@
use std::io;
use std::os::unix::io::{AsRawFd, RawFd};
use tokio::io::unix::AsyncFd;
use crate::tools::Fd;
pub mod cmsg;
pub mod pipe;
pub mod polled_fd;
pub mod rw_traits;
pub mod seq_packet;
pub async fn wrap_read<R, F>(async_fd: &AsyncFd<Fd>, mut call: F) -> io::Result<R>
where
F: FnMut(RawFd) -> io::Result<R>,
{
let fd = async_fd.as_raw_fd();
loop {
let mut guard = async_fd.readable().await?;
match call(fd) {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
guard.clear_ready();
continue;
}
other => return other,
}
}
}
pub async fn wrap_write<R, F>(async_fd: &AsyncFd<Fd>, mut call: F) -> io::Result<R>
where
F: FnMut(RawFd) -> io::Result<R>,
{
let fd = async_fd.as_raw_fd();
loop {
let mut guard = async_fd.writable().await?;
match call(fd) {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
guard.clear_ready();
continue;
}
other => return other,
}
}
}

View File

@ -5,9 +5,9 @@ use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::unix::AsyncFd;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::io::polled_fd::PolledFd;
use crate::io::rw_traits;
use crate::tools::Fd;
@ -44,9 +44,9 @@ pub fn pipe_fds() -> io::Result<(PipeFd<rw_traits::Read>, PipeFd<rw_traits::Writ
}
/// Tokio supported pipe file descriptor. `tokio::fs::File` requires tokio's complete file system
/// feature gate, so we just use this `PolledFd` wrapper.
/// feature gate, so we just use this `AsyncFd` wrapper.
pub struct Pipe<RW> {
fd: PolledFd,
fd: AsyncFd<Fd>,
_phantom: PhantomData<RW>,
}
@ -55,7 +55,7 @@ impl<RW> TryFrom<PipeFd<RW>> for Pipe<RW> {
fn try_from(fd: PipeFd<RW>) -> io::Result<Self> {
Ok(Self {
fd: PolledFd::new(fd.into_fd())?,
fd: AsyncFd::new(fd.into_fd())?,
_phantom: PhantomData,
})
}
@ -71,7 +71,7 @@ impl<RW> AsRawFd for Pipe<RW> {
impl<RW> IntoRawFd for Pipe<RW> {
#[inline]
fn into_raw_fd(self) -> RawFd {
self.fd.into_raw_fd()
self.fd.into_inner().into_raw_fd()
}
}
@ -87,16 +87,28 @@ impl<RW: rw_traits::HasRead> AsyncRead for Pipe<RW> {
cx: &mut Context<'_>,
buf: &mut ReadBuf,
) -> Poll<io::Result<()>> {
self.fd.wrap_read(cx, || {
let mut guard = ready!(self.fd.poll_read_ready(cx))?;
let fd = self.as_raw_fd();
let mem = buf.initialize_unfilled();
c_result!(unsafe { libc::read(fd, mem.as_mut_ptr() as *mut libc::c_void, mem.len()) })
.map(|received| {
match c_result!(unsafe { libc::read(fd, mem.as_mut_ptr() as *mut libc::c_void, mem.len()) })
{
Ok(received) => {
if received > 0 {
buf.advance(received as usize)
}
})
})
guard.retain_ready();
Poll::Ready(Ok(()))
}
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
guard.clear_ready();
Poll::Pending
}
Err(err) => {
guard.retain_ready();
Poll::Ready(Err(err))
}
}
}
}
@ -106,11 +118,24 @@ impl<RW: rw_traits::HasWrite> AsyncWrite for Pipe<RW> {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.fd.wrap_write(cx, || {
let mut guard = ready!(self.fd.poll_write_ready(cx))?;
let fd = self.as_raw_fd();
c_result!(unsafe { libc::write(fd, buf.as_ptr() as *const libc::c_void, buf.len()) })
.map(|res| res as usize)
})
match c_result!(unsafe { libc::write(fd, buf.as_ptr() as *const libc::c_void, buf.len()) })
{
Ok(res) => {
guard.retain_ready();
Poll::Ready(Ok(res as usize))
}
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
guard.clear_ready();
Poll::Pending
}
Err(err) => {
guard.retain_ready();
Poll::Ready(Err(err))
}
}
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {

View File

@ -1,101 +0,0 @@
use std::io;
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use std::task::{Context, Poll};
use tokio::io::unix::AsyncFd;
use crate::tools::Fd;
#[repr(transparent)]
pub struct EventedFd {
fd: Fd,
}
impl EventedFd {
#[inline]
pub fn new(fd: Fd) -> Self {
Self { fd }
}
}
impl AsRawFd for EventedFd {
#[inline]
fn as_raw_fd(&self) -> RawFd {
self.fd.as_raw_fd()
}
}
impl FromRawFd for EventedFd {
#[inline]
unsafe fn from_raw_fd(fd: RawFd) -> Self {
Self::new(unsafe { Fd::from_raw_fd(fd) })
}
}
impl IntoRawFd for EventedFd {
#[inline]
fn into_raw_fd(self) -> RawFd {
self.fd.into_raw_fd()
}
}
#[repr(transparent)]
pub struct PolledFd {
fd: AsyncFd<EventedFd>,
}
impl PolledFd {
pub fn new(fd: Fd) -> tokio::io::Result<Self> {
Ok(Self {
fd: AsyncFd::new(EventedFd::new(fd))?,
})
}
pub fn wrap_read<T>(
&self,
cx: &mut Context,
func: impl FnOnce() -> io::Result<T>,
) -> Poll<io::Result<T>> {
let mut ready_guard = ready!(self.fd.poll_read_ready(cx))?;
match func() {
Ok(out) => Poll::Ready(Ok(out)),
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
ready_guard.clear_ready();
Poll::Pending
}
Err(err) => Poll::Ready(Err(err)),
}
}
pub fn wrap_write<T>(
&self,
cx: &mut Context,
func: impl FnOnce() -> io::Result<T>,
) -> Poll<io::Result<T>> {
let mut ready_guard = ready!(self.fd.poll_write_ready(cx))?;
match func() {
Ok(out) => Poll::Ready(Ok(out)),
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
ready_guard.clear_ready();
Poll::Pending
}
Err(err) => Poll::Ready(Err(err)),
}
}
}
impl AsRawFd for PolledFd {
#[inline]
fn as_raw_fd(&self) -> RawFd {
self.fd.get_ref().as_raw_fd()
}
}
impl IntoRawFd for PolledFd {
#[inline]
fn into_raw_fd(self) -> RawFd {
// for the kind of resource we're managing it should always be possible to extract it from
// its driver
self.fd.into_inner().into_raw_fd()
}
}

View File

@ -1,13 +1,11 @@
use std::io::{self, IoSlice, IoSliceMut};
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::ptr;
use std::task::{Context, Poll};
use anyhow::Error;
use nix::sys::socket::{self, AddressFamily, SockFlag, SockType, SockaddrLike};
use tokio::io::unix::AsyncFd;
use crate::io::polled_fd::PolledFd;
use crate::poll_fn::poll_fn;
use crate::tools::AssertSendSync;
use crate::tools::Fd;
@ -22,7 +20,7 @@ fn seq_packet_socket(flags: SockFlag) -> nix::Result<Fd> {
}
pub struct SeqPacketListener {
fd: PolledFd,
fd: AsyncFd<Fd>,
}
impl AsRawFd for SeqPacketListener {
@ -38,14 +36,13 @@ impl SeqPacketListener {
socket::bind(fd.as_raw_fd(), address)?;
socket::listen(fd.as_raw_fd(), 16)?;
let fd = PolledFd::new(fd)?;
let fd = AsyncFd::new(fd)?;
Ok(Self { fd })
}
pub fn poll_accept(&mut self, cx: &mut Context) -> Poll<io::Result<SeqPacketSocket>> {
let fd = self.as_raw_fd();
let res = self.fd.wrap_read(cx, || {
pub async fn accept(&mut self) -> io::Result<SeqPacketSocket> {
let fd = super::wrap_read(&self.fd, |fd| {
c_result!(unsafe {
libc::accept4(
fd,
@ -54,22 +51,16 @@ impl SeqPacketListener {
libc::SOCK_CLOEXEC | libc::SOCK_NONBLOCK,
)
})
.map(|fd| unsafe { Fd::from_raw_fd(fd as RawFd) })
});
match res {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(fd)) => Poll::Ready(SeqPacketSocket::new(fd)),
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
}
}
})
.await?;
pub async fn accept(&mut self) -> io::Result<SeqPacketSocket> {
poll_fn(move |cx| self.poll_accept(cx)).await
let fd = unsafe { Fd::from_raw_fd(fd as RawFd) };
SeqPacketSocket::new(fd)
}
}
pub struct SeqPacketSocket {
fd: PolledFd,
fd: AsyncFd<Fd>,
}
impl AsRawFd for SeqPacketSocket {
@ -82,21 +73,16 @@ impl AsRawFd for SeqPacketSocket {
impl SeqPacketSocket {
pub fn new(fd: Fd) -> io::Result<Self> {
Ok(Self {
fd: PolledFd::new(fd)?,
fd: AsyncFd::new(fd)?,
})
}
pub fn poll_sendmsg(
&self,
cx: &mut Context,
msg: &AssertSendSync<libc::msghdr>,
) -> Poll<io::Result<usize>> {
let fd = self.fd.as_raw_fd();
self.fd.wrap_write(cx, || {
async fn sendmsg(&self, msg: &AssertSendSync<libc::msghdr>) -> io::Result<usize> {
let rc = super::wrap_write(&self.fd, |fd| {
c_result!(unsafe { libc::sendmsg(fd, &msg.0 as *const libc::msghdr, 0) })
.map(|rc| rc as usize)
})
.await?;
Ok(rc as usize)
}
pub async fn sendmsg_vectored(&self, iov: &[IoSlice<'_>]) -> io::Result<usize> {
@ -110,20 +96,15 @@ impl SeqPacketSocket {
msg_flags: 0,
});
poll_fn(move |cx| self.poll_sendmsg(cx, &msg)).await
self.sendmsg(&msg).await
}
pub fn poll_recvmsg(
&self,
cx: &mut Context,
msg: &mut AssertSendSync<libc::msghdr>,
) -> Poll<io::Result<usize>> {
let fd = self.fd.as_raw_fd();
self.fd.wrap_read(cx, || {
async fn recvmsg(&self, msg: &mut AssertSendSync<libc::msghdr>) -> io::Result<usize> {
let rc = super::wrap_read(&self.fd, move |fd| {
c_result!(unsafe { libc::recvmsg(fd, &mut msg.0 as *mut libc::msghdr, 0) })
.map(|rc| rc as usize)
})
.await?;
Ok(rc as usize)
}
// clippy is wrong about this one
@ -143,7 +124,7 @@ impl SeqPacketSocket {
msg_flags: libc::MSG_CMSG_CLOEXEC,
});
let data_size = poll_fn(|cx| self.poll_recvmsg(cx, &mut msg)).await?;
let data_size = self.recvmsg(&mut msg).await?;
Ok((data_size, msg.0.msg_controllen as usize))
}