custom executor

Signed-off-by: Wolfgang Bumiller <w.bumiller@proxmox.com>
Wolfgang Bumiller 2019-10-25 12:16:21 +02:00
parent 4003b0f418
commit 64d527abe6
7 changed files with 239 additions and 403 deletions


Cargo.toml

@@ -9,7 +9,6 @@ authors = [
 [dependencies]
 bitflags = "1.1"
 failure = { version = "0.1", default-features = false, features = ["std"] }
-futures-executor-preview = "0.3.0-alpha"
 lazy_static = "1.3"
 libc = "0.2"
 nix = "0.15"

src/executor.rs (new file, 236 lines)

@@ -0,0 +1,236 @@
use std::cell::RefCell;
use std::collections::VecDeque;
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::{Arc, Condvar, Mutex, Weak};
use std::task::{Context, Poll};
use std::thread::JoinHandle;

type BoxFut = Box<dyn Future<Output = ()> + Send + 'static>;

#[derive(Clone)]
struct Task(Arc<TaskInner>);

impl Task {
    fn into_raw(this: Task) -> *const TaskInner {
        Arc::into_raw(this.0)
    }

    unsafe fn from_raw(ptr: *const TaskInner) -> Self {
        Self(Arc::from_raw(ptr))
    }

    fn wake(self) {
        if let Some(queue) = self.0.queue.upgrade() {
            queue.queue(self);
        }
    }

    fn into_raw_waker(this: Task) -> std::task::RawWaker {
        std::task::RawWaker::new(
            Task::into_raw(this) as *const (),
            &std::task::RawWakerVTable::new(
                waker_clone_fn,
                waker_wake_fn,
                waker_wake_by_ref_fn,
                waker_drop_fn,
            ),
        )
    }
}

struct TaskInner {
    future: Mutex<Option<BoxFut>>,
    queue: Weak<TaskQueue>,
}

struct TaskQueue {
    queue: Mutex<VecDeque<Task>>,
    queue_cv: Condvar,
}

impl TaskQueue {
    fn new() -> Self {
        Self {
            queue: Mutex::new(VecDeque::with_capacity(32)),
            queue_cv: Condvar::new(),
        }
    }

    fn new_task(self: Arc<TaskQueue>, future: BoxFut) {
        let task = Task(Arc::new(TaskInner {
            future: Mutex::new(Some(future)),
            queue: Arc::downgrade(&self),
        }));
        self.queue(task);
    }

    fn queue(&self, task: Task) {
        let mut queue = self.queue.lock().unwrap();
        queue.push_back(task);
        self.queue_cv.notify_one();
    }

    /// Blocks until a task is available
    fn get_task(&self) -> Task {
        let mut queue = self.queue.lock().unwrap();
        loop {
            if let Some(task) = queue.pop_front() {
                return task;
            } else {
                queue = self.queue_cv.wait(queue).unwrap();
            }
        }
    }
}
/// A minimal thread-pool executor: worker threads pull boxed futures off a shared FIFO queue.
pub struct ThreadPool {
    threads: Mutex<Vec<JoinHandle<()>>>,
    queue: Arc<TaskQueue>,
}

impl ThreadPool {
    pub fn new() -> io::Result<Self> {
        let count = 2; //num_cpus()?;

        let queue = Arc::new(TaskQueue::new());

        let mut threads = Vec::new();
        for thread_id in 0..count {
            threads.push(std::thread::spawn({
                let queue = Arc::clone(&queue);
                move || thread_main(queue, thread_id)
            }));
        }

        Ok(Self {
            threads: Mutex::new(threads),
            queue,
        })
    }

    pub fn spawn_ok<T>(&self, future: T)
    where
        T: Future<Output = ()> + Send + 'static,
    {
        self.do_spawn(Box::new(future));
    }

    fn do_spawn(&self, future: BoxFut) {
        Arc::clone(&self.queue).new_task(future);
    }

    /// Spawns `future` on the pool and blocks the calling thread until it completes,
    /// returning its output.
    pub fn run<R, T>(&self, future: T) -> R
    where
        T: Future<Output = R> + Send + 'static,
        R: Send + 'static,
    {
        let mutex: Arc<Mutex<Option<R>>> = Arc::new(Mutex::new(None));
        let cv = Arc::new(Condvar::new());
        let mut guard = mutex.lock().unwrap();
        self.spawn_ok({
            let mutex = Arc::clone(&mutex);
            let cv = Arc::clone(&cv);
            async move {
                let result = future.await;
                *(mutex.lock().unwrap()) = Some(result);
                cv.notify_all();
            }
        });

        loop {
            guard = cv.wait(guard).unwrap();
            if let Some(result) = guard.take() {
                return result;
            }
        }
    }
}
thread_local! {
    static CURRENT_QUEUE: RefCell<*const TaskQueue> = RefCell::new(std::ptr::null());
    static CURRENT_TASK: RefCell<*const Task> = RefCell::new(std::ptr::null());
}

fn thread_main(task_queue: Arc<TaskQueue>, _thread_id: usize) {
    CURRENT_QUEUE.with(|q| *q.borrow_mut() = task_queue.as_ref() as *const TaskQueue);

    // Waker used while polling on this thread: it carries no data itself and instead
    // looks up the task being polled via the CURRENT_TASK/CURRENT_QUEUE thread-locals.
    let local_waker = unsafe {
        std::task::Waker::from_raw(std::task::RawWaker::new(
            std::ptr::null(),
            &std::task::RawWakerVTable::new(
                local_waker_clone_fn,
                local_waker_wake_fn,
                local_waker_wake_fn,
                local_waker_drop_fn,
            ),
        ))
    };

    let mut context = Context::from_waker(&local_waker);

    loop {
        let task: Task = task_queue.get_task();
        let task: Pin<&Task> = Pin::new(&task);
        let task = task.get_ref();
        CURRENT_TASK.with(|c| *c.borrow_mut() = task as *const Task);

        // Take the future out of the task; on Poll::Pending it is put back so that a
        // later wake-up can re-queue the task.
        let mut task_future = task.0.future.lock().unwrap();
        match task_future.take() {
            Some(mut future) => {
                let pin = unsafe { Pin::new_unchecked(&mut *future) };
                match pin.poll(&mut context) {
                    Poll::Ready(()) => (), // done with that task
                    Poll::Pending => {
                        *task_future = Some(future);
                    }
                }
            }
            None => eprintln!("task polled after ready"),
        }
    }
}
unsafe fn local_waker_clone_fn(_: *const ()) -> std::task::RawWaker {
    let task: Task = CURRENT_TASK.with(|t| Task::clone(&**t.borrow()));
    Task::into_raw_waker(task)
}

unsafe fn local_waker_wake_fn(_: *const ()) {
    let task: Task = CURRENT_TASK.with(|t| Task::clone(&**t.borrow()));
    CURRENT_QUEUE.with(|q| (**q.borrow()).queue(task));
}

unsafe fn local_waker_drop_fn(_: *const ()) {}

unsafe fn waker_clone_fn(this: *const ()) -> std::task::RawWaker {
    let this = Task::from_raw(this as *const TaskInner);
    let clone = this.clone();
    let _ = Task::into_raw(this);
    Task::into_raw_waker(clone)
}

unsafe fn waker_wake_fn(this: *const ()) {
    let this = Task::from_raw(this as *const TaskInner);
    this.wake();
}

unsafe fn waker_wake_by_ref_fn(this: *const ()) {
    let this = Task::from_raw(this as *const TaskInner);
    this.clone().wake();
    let _ = Task::into_raw(this);
}

unsafe fn waker_drop_fn(this: *const ()) {
    let _this = Task::from_raw(this as *const TaskInner);
}

pub fn num_cpus() -> io::Result<usize> {
    let rc = unsafe { libc::sysconf(libc::_SC_NPROCESSORS_ONLN) };
    if rc < 0 {
        Err(io::Error::last_os_error())
    } else {
        Ok(rc as usize)
    }
}
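For orientation, a minimal usage sketch of the API introduced above (not part of the commit). It assumes the module is reachable as `executor`, matching the src/main.rs hunk further below; the `demo` function is hypothetical.

// Hypothetical example only: exercises ThreadPool::new, spawn_ok and run as defined above.
fn demo() -> std::io::Result<()> {
    let pool = executor::ThreadPool::new()?;

    // Fire-and-forget task; one of the worker threads polls it to completion.
    pool.spawn_ok(async {
        eprintln!("hello from the custom executor");
    });

    // Block the calling thread until the future finishes and hand back its output.
    let answer: u32 = pool.run(async { 21 * 2 });
    assert_eq!(answer, 42);
    Ok(())
}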


(deleted file, 14 lines)

@@ -1,14 +0,0 @@
use std::io;

pub mod ring;
pub mod slot_list;
pub mod thread_pool;

pub fn num_cpus() -> io::Result<usize> {
    let rc = unsafe { libc::sysconf(libc::_SC_NPROCESSORS_ONLN) };
    if rc < 0 {
        Err(io::Error::last_os_error())
    } else {
        Ok(rc as usize)
    }
}


(deleted file, 109 lines)

@@ -1,109 +0,0 @@
use std::mem::MaybeUninit;
use std::ptr;
use std::sync::atomic::{fence, AtomicBool, AtomicUsize, Ordering};

// We only perform a handful of memory read/writes in push()/pop(), so we use spin locks for
// performance reasons:
struct SpinLock(AtomicBool);
struct SpinLockGuard<'a>(&'a AtomicBool);

impl SpinLock {
    const fn new() -> Self {
        Self(AtomicBool::new(false))
    }

    fn lock(&self) -> SpinLockGuard {
        while self.0.compare_and_swap(false, true, Ordering::Acquire) {
            // spin
        }
        SpinLockGuard(&self.0)
    }
}

impl Drop for SpinLockGuard<'_> {
    fn drop(&mut self) {
        self.0.store(false, Ordering::Release);
    }
}

pub struct Ring<T> {
    head: usize,
    tail: usize,
    mask: usize,
    data: Box<[MaybeUninit<T>]>,
    push_lock: SpinLock,
    pop_lock: SpinLock,
}

impl<T> Ring<T> {
    pub fn new(size: usize) -> Self {
        if size < 2 || size.count_ones() != 1 {
            panic!("Ring size must be a power of two!");
        }

        let mut data = Vec::with_capacity(size);
        unsafe {
            data.set_len(size);
        }

        Self {
            head: 0,
            tail: 0,
            mask: size - 1,
            data: data.into_boxed_slice(),
            push_lock: SpinLock::new(),
            pop_lock: SpinLock::new(),
        }
    }

    pub fn len(&self) -> usize {
        fence(Ordering::Acquire);
        self.tail - self.head
    }

    #[inline]
    fn atomic_tail(&self) -> &AtomicUsize {
        unsafe { &*(&self.tail as *const usize as *const AtomicUsize) }
    }

    #[inline]
    fn atomic_head(&self) -> &AtomicUsize {
        unsafe { &*(&self.head as *const usize as *const AtomicUsize) }
    }

    pub fn try_push(&self, data: T) -> bool {
        let _guard = self.push_lock.lock();

        let tail = self.atomic_tail().load(Ordering::Acquire);
        let head = self.head;
        if tail - head == self.data.len() {
            return false;
        }

        unsafe {
            ptr::write(self.data[tail & self.mask].as_ptr() as *mut T, data);
        }

        self.atomic_tail().store(tail + 1, Ordering::Release);
        true
    }

    pub fn try_pop(&self) -> Option<T> {
        let _guard = self.pop_lock.lock();

        let head = self.atomic_head().load(Ordering::Acquire);
        let tail = self.tail;
        if tail - head == 0 {
            return None;
        }

        let data = unsafe { ptr::read(self.data[head & self.mask].as_ptr()) };

        self.atomic_head().store(head + 1, Ordering::Release);
        Some(data)
    }
}


(deleted file, 35 lines)

@@ -1,35 +0,0 @@
pub struct SlotList<T> {
    tasks: Vec<Option<T>>,
    free_slots: Vec<usize>,
}

impl<T> SlotList<T> {
    pub fn new() -> Self {
        Self {
            tasks: Vec::new(),
            free_slots: Vec::new(),
        }
    }

    pub fn add(&mut self, data: T) -> usize {
        if let Some(id) = self.free_slots.pop() {
            let old = self.tasks[id].replace(data);
            assert!(old.is_none());
            id
        } else {
            let id = self.tasks.len();
            self.tasks.push(Some(data));
            id
        }
    }

    pub fn remove(&mut self, id: usize) -> T {
        let entry = self.tasks[id].take().unwrap();
        self.free_slots.push(id);
        entry
    }

    pub fn get(&self, id: usize) -> Option<&T> {
        self.tasks[id].as_ref()
    }
}


(deleted file, 241 lines)

@@ -1,241 +0,0 @@
use std::cell::RefCell;
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::{Arc, Mutex, RwLock};
use std::task::{Context, Poll};
use std::thread::JoinHandle;

use super::num_cpus;
use super::ring::Ring;
use super::slot_list::SlotList;

type BoxFut = Box<dyn Future<Output = ()> + Send + 'static>;
type TaskId = usize;

struct Task {
    id: TaskId,
    pool: Arc<ThreadPool>,
    future: Option<(BoxFut, std::task::Waker)>,
}

pub struct ThreadPool {
    inner: Arc<ThreadPoolInner>,
}

impl ThreadPool {
    pub fn new() -> io::Result<Self> {
        let count = num_cpus()?;

        let inner = Arc::new(ThreadPoolInner {
            threads: Mutex::new(Vec::new()),
            tasks: RwLock::new(SlotList::new()),
            overflow: RwLock::new(Vec::new()),
        });

        let mut threads = inner.threads.lock().unwrap();
        for thread_id in 0..count {
            threads.push(Thread::new(Arc::clone(&inner), thread_id));
        }
        drop(threads);

        Ok(ThreadPool { inner })
    }

    pub fn spawn<T>(&self, future: T)
    where
        T: Future<Output = ()> + Send + 'static,
    {
        self.inner.spawn(Box::new(future))
    }
}

struct ThreadPoolInner {
    threads: Mutex<Vec<Thread>>,
    tasks: RwLock<SlotList<BoxFut>>,
    overflow: RwLock<Vec<TaskId>>,
}

unsafe impl Sync for ThreadPoolInner {}

impl ThreadPoolInner {
    fn create_task(&self, future: BoxFut) -> TaskId {
        self.tasks.write().unwrap().add(future)
    }

    fn spawn(&self, future: BoxFut) {
        self.queue_task(self.create_task(future))
    }

    fn queue_task(&self, task: TaskId) {
        let threads = self.threads.lock().unwrap();
        let shortest = threads
            .iter()
            .min_by(|a, b| a.task_count().cmp(&b.task_count()))
            .expect("thread pool should not be empty");
        if !shortest.try_queue(task) {
            drop(threads);
            self.overflow.write().unwrap().push(task);
        }
    }

    fn create_waker(self: Arc<Self>, task_id: TaskId) -> std::task::RawWaker {
        let waker = Box::new(Waker {
            pool: self,
            task_id,
        });
        std::task::RawWaker::new(Box::leak(waker) as *mut Waker as *mut (), &WAKER_VTABLE)
    }
}

struct Thread {
    handle: JoinHandle<()>,
    inner: Arc<ThreadInner>,
}

impl Thread {
    fn new(pool: Arc<ThreadPoolInner>, id: usize) -> Self {
        let inner = Arc::new(ThreadInner {
            id,
            ring: Ring::new(32),
            pool,
        });

        let handle = std::thread::spawn({
            let inner = Arc::clone(&inner);
            move || inner.thread_main()
        });

        Thread { handle, inner }
    }

    fn task_count(&self) -> usize {
        self.inner.task_count()
    }

    fn try_queue(&self, task: TaskId) -> bool {
        self.inner.try_queue(task)
    }
}

struct ThreadInner {
    id: usize,
    ring: Ring<TaskId>,
    pool: Arc<ThreadPoolInner>,
}

thread_local! {
    static THREAD_INNER: RefCell<*const ThreadInner> = RefCell::new(std::ptr::null());
}

impl ThreadInner {
    fn thread_main(self: Arc<Self>) {
        THREAD_INNER.with(|inner| {
            *inner.borrow_mut() = self.as_ref() as *const Self;
        });

        loop {
            if let Some(task_id) = self.ring.try_pop() {
                self.poll_task(task_id);
            }
        }
    }

    fn poll_task(&self, task_id: TaskId) {
        //let future = {
        //    let task = self.pool.tasks.read().unwrap().get(task_id).unwrap();
        //    if let Some(future) = task.future.as_ref() {
        //        future.as_ref() as *const (dyn Future<Output = ()> + Send + 'static)
        //            as *mut (dyn Future<Output = ()> + Send + 'static)
        //    } else {
        //        return;
        //    }
        //};

        let waker = unsafe {
            std::task::Waker::from_raw(std::task::RawWaker::new(
                task_id as *const (),
                &std::task::RawWakerVTable::new(
                    local_waker_clone_fn,
                    local_waker_wake_fn,
                    local_waker_wake_by_ref_fn,
                    local_waker_drop_fn,
                ),
            ))
        };
        let mut context = Context::from_waker(&waker);

        let future = {
            self.pool
                .tasks
                .read()
                .unwrap()
                .get(task_id)
                .unwrap()
                .as_ref() as *const (dyn Future<Output = ()> + Send + 'static)
                as *mut (dyn Future<Output = ()> + Send + 'static)
        };

        if let Poll::Ready(value) = unsafe { Pin::new_unchecked(&mut *future) }.poll(&mut context) {
            let task = self.pool.tasks.write().unwrap().remove(task_id);
        }
    }

    fn task_count(&self) -> usize {
        self.ring.len()
    }

    fn try_queue(&self, task: TaskId) -> bool {
        self.ring.try_push(task)
    }
}

struct RefWaker<'a> {
    pool: &'a ThreadPoolInner,
    task_id: TaskId,
}

struct Waker {
    pool: Arc<ThreadPoolInner>,
    task_id: TaskId,
}

pub struct Work {}

const WAKER_VTABLE: std::task::RawWakerVTable = std::task::RawWakerVTable::new(
    waker_clone_fn,
    waker_wake_fn,
    waker_wake_by_ref_fn,
    waker_drop_fn,
);

unsafe fn waker_clone_fn(_this: *const ()) -> std::task::RawWaker {
    panic!("TODO");
}

unsafe fn waker_wake_fn(_this: *const ()) {
    panic!("TODO");
}

unsafe fn waker_wake_by_ref_fn(_this: *const ()) {
    panic!("TODO");
}

unsafe fn waker_drop_fn(_this: *const ()) {
    panic!("TODO");
}

unsafe fn local_waker_clone_fn(_this: *const ()) -> std::task::RawWaker {
    panic!("TODO");
}

unsafe fn local_waker_wake_fn(_this: *const ()) {
    panic!("TODO");
}

unsafe fn local_waker_wake_by_ref_fn(_this: *const ()) {
    panic!("TODO");
}

unsafe fn local_waker_drop_fn(_this: *const ()) {
    panic!("TODO");
}


src/main.rs

@@ -23,9 +23,9 @@ pub mod tools;
 use io_uring::socket::SeqPacketListener;
-static mut EXECUTOR: *mut futures_executor::ThreadPool = std::ptr::null_mut();
+static mut EXECUTOR: *mut executor::ThreadPool = std::ptr::null_mut();
-pub fn executor() -> &'static futures_executor::ThreadPool {
+pub fn executor() -> &'static executor::ThreadPool {
     unsafe { &*EXECUTOR }
 }
@@ -34,7 +34,7 @@ pub fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
 }
 fn main() {
-    let mut executor = futures_executor::ThreadPool::new().expect("spawning worker threadpool");
+    let mut executor = executor::ThreadPool::new().expect("spawning worker threadpool");
     unsafe {
         EXECUTOR = &mut executor;
     }