From 64d527abe6278638b8434210d3dcc2bc0057b0a8 Mon Sep 17 00:00:00 2001
From: Wolfgang Bumiller
Date: Fri, 25 Oct 2019 12:16:21 +0200
Subject: [PATCH] custom executor

Signed-off-by: Wolfgang Bumiller
---
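Replace the futures-executor-preview dependency with a small hand-rolled
thread-pool executor: a task is a boxed future guarded by a Mutex, all
workers share one VecDeque run queue protected by a Mutex/Condvar pair,
worker threads block on the condvar until work arrives, and waking a
task simply re-queues it. ThreadPool::run() parks the calling thread on
its own condvar until the spawned future has stored its result.

Illustrative usage of the new API (not part of the patch; demo() is a
hypothetical caller):

    fn demo() -> std::io::Result<()> {
        let pool = executor::ThreadPool::new()?;

        // Fire and forget; the future must be Send + 'static with () output.
        pool.spawn_ok(async {
            println!("hello from the pool");
        });

        // Block the calling thread until the future completes on the pool.
        let answer: u32 = pool.run(async { 21 * 2 });
        assert_eq!(answer, 42);
        Ok(())
    }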
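The per-task waker stores a leaked Arc<TaskInner> as the RawWaker data
pointer, so every vtable function must balance Arc::from_raw against
Arc::into_raw to keep the reference count honest; waker_clone_fn in
src/executor.rs reconstructs the Arc, clones it, then leaks the original
again. The same discipline on a plain Arc<T>, as a minimal illustrative
sketch:

    use std::sync::Arc;

    unsafe fn clone_raw<T>(ptr: *const T) -> *const T {
        let this = Arc::from_raw(ptr); // take ownership back, count unchanged
        let clone = Arc::clone(&this); // count += 1 for the new waker
        let _ = Arc::into_raw(this);   // leak the original again, do not drop
        Arc::into_raw(clone)           // owning pointer for the cloned waker
    }

The thread-local waker used while polling takes a shortcut: its data
pointer is null, and cloning it materializes a real Arc-backed waker
from the CURRENT_TASK thread-local instead.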
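A future that wakes itself before returning Pending exercises the
re-queue path: the worker holds the task's future mutex across poll(),
so a wake during poll() only makes the next worker block on that mutex
until the future has been stored back. An illustrative test future, not
part of the patch:

    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    struct YieldOnce(bool);

    impl Future for YieldOnce {
        type Output = ();
        fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
            if self.0 {
                Poll::Ready(())
            } else {
                self.0 = true;
                cx.waker().wake_by_ref(); // re-queues this task via the local waker
                Poll::Pending
            }
        }
    }

Driving it with pool.run(YieldOnce(false)) should poll twice and return.
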
 Cargo.toml                  |   1 -
 src/executor.rs             | 236 +++++++++++++++++++++++++++++++++++
 src/executor/mod.rs         |  14 ---
 src/executor/ring.rs        | 109 ----------------
 src/executor/slot_list.rs   |  35 ------
 src/executor/thread_pool.rs | 241 ------------------------------------
 src/main.rs                 |   6 +-
 7 files changed, 239 insertions(+), 403 deletions(-)
 create mode 100644 src/executor.rs
 delete mode 100644 src/executor/mod.rs
 delete mode 100644 src/executor/ring.rs
 delete mode 100644 src/executor/slot_list.rs
 delete mode 100644 src/executor/thread_pool.rs

diff --git a/Cargo.toml b/Cargo.toml
index 640cdde..6e4bc10 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,7 +9,6 @@ authors = [
 [dependencies]
 bitflags = "1.1"
 failure = { version = "0.1", default-features = false, features = ["std"] }
-futures-executor-preview = "0.3.0-alpha"
 lazy_static = "1.3"
 libc = "0.2"
 nix = "0.15"

diff --git a/src/executor.rs b/src/executor.rs
new file mode 100644
index 0000000..d3e86c7
--- /dev/null
+++ b/src/executor.rs
@@ -0,0 +1,236 @@
+use std::cell::RefCell;
+use std::collections::VecDeque;
+use std::future::Future;
+use std::io;
+use std::pin::Pin;
+use std::sync::{Arc, Condvar, Mutex, Weak};
+use std::task::{Context, Poll};
+use std::thread::JoinHandle;
+
+type BoxFut = Box<dyn Future<Output = ()> + Send + 'static>;
+
+#[derive(Clone)]
+struct Task(Arc<TaskInner>);
+
+impl Task {
+    fn into_raw(this: Task) -> *const TaskInner {
+        Arc::into_raw(this.0)
+    }
+
+    unsafe fn from_raw(ptr: *const TaskInner) -> Self {
+        Self(Arc::from_raw(ptr))
+    }
+
+    fn wake(self) {
+        if let Some(queue) = self.0.queue.upgrade() {
+            queue.queue(self);
+        }
+    }
+
+    fn into_raw_waker(this: Task) -> std::task::RawWaker {
+        std::task::RawWaker::new(
+            Task::into_raw(this) as *const (),
+            &std::task::RawWakerVTable::new(
+                waker_clone_fn,
+                waker_wake_fn,
+                waker_wake_by_ref_fn,
+                waker_drop_fn,
+            ),
+        )
+    }
+}
+
+struct TaskInner {
+    future: Mutex<Option<BoxFut>>,
+    queue: Weak<TaskQueue>,
+}
+
+struct TaskQueue {
+    queue: Mutex<VecDeque<Task>>,
+    queue_cv: Condvar,
+}
+
+impl TaskQueue {
+    fn new() -> Self {
+        Self {
+            queue: Mutex::new(VecDeque::with_capacity(32)),
+            queue_cv: Condvar::new(),
+        }
+    }
+
+    fn new_task(self: Arc<Self>, future: BoxFut) {
+        let task = Task(Arc::new(TaskInner {
+            future: Mutex::new(Some(future)),
+            queue: Arc::downgrade(&self),
+        }));
+
+        self.queue(task);
+    }
+
+    fn queue(&self, task: Task) {
+        let mut queue = self.queue.lock().unwrap();
+        queue.push_back(task);
+        self.queue_cv.notify_one();
+    }
+
+    /// Blocks until a task is available
+    fn get_task(&self) -> Task {
+        let mut queue = self.queue.lock().unwrap();
+        loop {
+            if let Some(task) = queue.pop_front() {
+                return task;
+            } else {
+                queue = self.queue_cv.wait(queue).unwrap();
+            }
+        }
+    }
+}
+
+pub struct ThreadPool {
+    threads: Mutex<Vec<JoinHandle<()>>>,
+    queue: Arc<TaskQueue>,
+}
+
+impl ThreadPool {
+    pub fn new() -> io::Result<Self> {
+        let count = 2; //num_cpus()?;
+
+        let queue = Arc::new(TaskQueue::new());
+
+        let mut threads = Vec::new();
+        for thread_id in 0..count {
+            threads.push(std::thread::spawn({
+                let queue = Arc::clone(&queue);
+                move || thread_main(queue, thread_id)
+            }));
+        }
+
+        Ok(Self {
+            threads: Mutex::new(threads),
+            queue,
+        })
+    }
+
+    pub fn spawn_ok<T>(&self, future: T)
+    where
+        T: Future<Output = ()> + Send + 'static,
+    {
+        self.do_spawn(Box::new(future));
+    }
+
+    fn do_spawn(&self, future: BoxFut) {
+        Arc::clone(&self.queue).new_task(future);
+    }
+
+    pub fn run<T, R>(&self, future: T) -> R
+    where
+        T: Future<Output = R> + Send + 'static,
+        R: Send + 'static,
+    {
+        let mutex: Arc<Mutex<Option<R>>> = Arc::new(Mutex::new(None));
+        let cv = Arc::new(Condvar::new());
+        let mut guard = mutex.lock().unwrap();
+        self.spawn_ok({
+            let mutex = Arc::clone(&mutex);
+            let cv = Arc::clone(&cv);
+            async move {
+                let result = future.await;
+                *(mutex.lock().unwrap()) = Some(result);
+                cv.notify_all();
+            }
+        });
+        loop {
+            guard = cv.wait(guard).unwrap();
+            if let Some(result) = guard.take() {
+                return result;
+            }
+        }
+    }
+}
+
+thread_local! {
+    static CURRENT_QUEUE: RefCell<*const TaskQueue> = RefCell::new(std::ptr::null());
+    static CURRENT_TASK: RefCell<*const Task> = RefCell::new(std::ptr::null());
+}
+
+fn thread_main(task_queue: Arc<TaskQueue>, _thread_id: usize) {
+    CURRENT_QUEUE.with(|q| *q.borrow_mut() = task_queue.as_ref() as *const TaskQueue);
+
+    let local_waker = unsafe {
+        std::task::Waker::from_raw(std::task::RawWaker::new(
+            std::ptr::null(),
+            &std::task::RawWakerVTable::new(
+                local_waker_clone_fn,
+                local_waker_wake_fn,
+                local_waker_wake_fn,
+                local_waker_drop_fn,
+            ),
+        ))
+    };
+
+    let mut context = Context::from_waker(&local_waker);
+
+    loop {
+        let task: Task = task_queue.get_task();
+        let task: Pin<&Task> = Pin::new(&task);
+        let task = task.get_ref();
+        CURRENT_TASK.with(|c| *c.borrow_mut() = task as *const Task);
+
+        let mut task_future = task.0.future.lock().unwrap();
+        match task_future.take() {
+            Some(mut future) => {
+                let pin = unsafe { Pin::new_unchecked(&mut *future) };
+                match pin.poll(&mut context) {
+                    Poll::Ready(()) => (), // done with that task
+                    Poll::Pending => {
+                        *task_future = Some(future);
+                    }
+                }
+            }
+            None => eprintln!("task polled after ready"),
+        }
+    }
+}
+
+unsafe fn local_waker_clone_fn(_: *const ()) -> std::task::RawWaker {
+    let task: Task = CURRENT_TASK.with(|t| Task::clone(&**t.borrow()));
+    Task::into_raw_waker(task)
+}
+
+unsafe fn local_waker_wake_fn(_: *const ()) {
+    let task: Task = CURRENT_TASK.with(|t| Task::clone(&**t.borrow()));
+    CURRENT_QUEUE.with(|q| (**q.borrow()).queue(task));
+}
+
+unsafe fn local_waker_drop_fn(_: *const ()) {}
+
+unsafe fn waker_clone_fn(this: *const ()) -> std::task::RawWaker {
+    let this = Task::from_raw(this as *const TaskInner);
+    let clone = this.clone();
+    let _ = Task::into_raw(this);
+    Task::into_raw_waker(clone)
+}
+
+unsafe fn waker_wake_fn(this: *const ()) {
+    let this = Task::from_raw(this as *const TaskInner);
+    this.wake();
+}
+
+unsafe fn waker_wake_by_ref_fn(this: *const ()) {
+    let this = Task::from_raw(this as *const TaskInner);
+    this.clone().wake();
+    let _ = Task::into_raw(this);
+}
+
+unsafe fn waker_drop_fn(this: *const ()) {
+    let _this = Task::from_raw(this as *const TaskInner);
+}
+
+pub fn num_cpus() -> io::Result<usize> {
+    let rc = unsafe { libc::sysconf(libc::_SC_NPROCESSORS_ONLN) };
+    if rc < 0 {
+        Err(io::Error::last_os_error())
+    } else {
+        Ok(rc as usize)
+    }
+}
diff --git a/src/executor/mod.rs b/src/executor/mod.rs
deleted file mode 100644
index 7712e15..0000000
--- a/src/executor/mod.rs
+++ /dev/null
@@ -1,14 +0,0 @@
-use std::io;
-
-pub mod ring;
-pub mod slot_list;
-pub mod thread_pool;
-
-pub fn num_cpus() -> io::Result<usize> {
-    let rc = unsafe { libc::sysconf(libc::_SC_NPROCESSORS_ONLN) };
-    if rc < 0 {
-        Err(io::Error::last_os_error())
-    } else {
-        Ok(rc as usize)
-    }
-}

diff --git a/src/executor/ring.rs b/src/executor/ring.rs
deleted file mode 100644
index f56c0de..0000000
--- a/src/executor/ring.rs
+++ /dev/null
@@ -1,109 +0,0 @@
-use std::mem::MaybeUninit;
-use std::ptr;
-use std::sync::atomic::{fence, AtomicBool, AtomicUsize, Ordering};
-
-// We only perform a handful of memory read/writes in push()/pop(), so we use spin locks for
-// performance reasons:
-
-struct SpinLock(AtomicBool);
-struct SpinLockGuard<'a>(&'a AtomicBool);
-
-impl SpinLock {
-    const fn new() -> Self {
-        Self(AtomicBool::new(false))
-    }
-
-    fn lock(&self) -> SpinLockGuard {
-        while self.0.compare_and_swap(false, true, Ordering::Acquire) {
-            // spin
-        }
-        SpinLockGuard(&self.0)
-    }
-}
-
-impl Drop for SpinLockGuard<'_> {
-    fn drop(&mut self) {
-        self.0.store(false, Ordering::Release);
-    }
-}
-
-pub struct Ring<T> {
-    head: usize,
-    tail: usize,
-    mask: usize,
-    data: Box<[MaybeUninit<T>]>,
-    push_lock: SpinLock,
-    pop_lock: SpinLock,
-}
-
-impl<T> Ring<T> {
-    pub fn new(size: usize) -> Self {
-        if size < 2 || size.count_ones() != 1 {
-            panic!("Ring size must be a power of two!");
-        }
-
-        let mut data = Vec::with_capacity(size);
-        unsafe {
-            data.set_len(size);
-        }
-
-        Self {
-            head: 0,
-            tail: 0,
-            mask: size - 1,
-            data: data.into_boxed_slice(),
-            push_lock: SpinLock::new(),
-            pop_lock: SpinLock::new(),
-        }
-    }
-
-    pub fn len(&self) -> usize {
-        fence(Ordering::Acquire);
-        self.tail - self.head
-    }
-
-    #[inline]
-    fn atomic_tail(&self) -> &AtomicUsize {
-        unsafe { &*(&self.tail as *const usize as *const AtomicUsize) }
-    }
-
-    #[inline]
-    fn atomic_head(&self) -> &AtomicUsize {
-        unsafe { &*(&self.head as *const usize as *const AtomicUsize) }
-    }
-
-    pub fn try_push(&self, data: T) -> bool {
-        let _guard = self.push_lock.lock();
-
-        let tail = self.atomic_tail().load(Ordering::Acquire);
-        let head = self.head;
-
-        if tail - head == self.data.len() {
-            return false;
-        }
-
-        unsafe {
-            ptr::write(self.data[tail & self.mask].as_ptr() as *mut T, data);
-        }
-        self.atomic_tail().store(tail + 1, Ordering::Release);
-
-        true
-    }
-
-    pub fn try_pop(&self) -> Option<T> {
-        let _guard = self.pop_lock.lock();
-
-        let head = self.atomic_head().load(Ordering::Acquire);
-        let tail = self.tail;
-
-        if tail - head == 0 {
-            return None;
-        }
-
-        let data = unsafe { ptr::read(self.data[head & self.mask].as_ptr()) };
-
-        self.atomic_head().store(head + 1, Ordering::Release);
-
-        Some(data)
-    }
-}
diff --git a/src/executor/slot_list.rs b/src/executor/slot_list.rs
deleted file mode 100644
index ae7f084..0000000
--- a/src/executor/slot_list.rs
+++ /dev/null
@@ -1,35 +0,0 @@
-pub struct SlotList<T> {
-    tasks: Vec<Option<T>>,
-    free_slots: Vec<usize>,
-}
-
-impl<T> SlotList<T> {
-    pub fn new() -> Self {
-        Self {
-            tasks: Vec::new(),
-            free_slots: Vec::new(),
-        }
-    }
-
-    pub fn add(&mut self, data: T) -> usize {
-        if let Some(id) = self.free_slots.pop() {
-            let old = self.tasks[id].replace(data);
-            assert!(old.is_none());
-            id
-        } else {
-            let id = self.tasks.len();
-            self.tasks.push(Some(data));
-            id
-        }
-    }
-
-    pub fn remove(&mut self, id: usize) -> T {
-        let entry = self.tasks[id].take().unwrap();
-        self.free_slots.push(id);
-        entry
-    }
-
-    pub fn get(&self, id: usize) -> Option<&T> {
-        self.tasks[id].as_ref()
-    }
-}

diff --git a/src/executor/thread_pool.rs b/src/executor/thread_pool.rs
deleted file mode 100644
index 1d62220..0000000
--- a/src/executor/thread_pool.rs
+++ /dev/null
@@ -1,241 +0,0 @@
-use std::cell::RefCell;
-use std::future::Future;
-use std::io;
-use std::pin::Pin;
-use std::sync::{Arc, Mutex, RwLock};
-use std::task::{Context, Poll};
-use std::thread::JoinHandle;
-
-use super::num_cpus;
-use super::ring::Ring;
-use super::slot_list::SlotList;
-
-type BoxFut = Box<dyn Future<Output = ()> + Send + 'static>;
-type TaskId = usize;
-
-struct Task {
-    id: TaskId,
-    pool: Arc<ThreadPoolInner>,
-    future: Option<(BoxFut, std::task::Waker)>,
-}
-
-pub struct ThreadPool {
-    inner: Arc<ThreadPoolInner>,
-}
-
-impl ThreadPool {
-    pub fn new() -> io::Result<Self> {
-        let count = num_cpus()?;
-
-        let inner = Arc::new(ThreadPoolInner {
-            threads: Mutex::new(Vec::new()),
-            tasks: RwLock::new(SlotList::new()),
-            overflow: RwLock::new(Vec::new()),
-        });
-
-        let mut threads = inner.threads.lock().unwrap();
-        for thread_id in 0..count {
-            threads.push(Thread::new(Arc::clone(&inner), thread_id));
-        }
-        drop(threads);
-
-        Ok(ThreadPool { inner })
-    }
-
-    pub fn spawn<T>(&self, future: T)
-    where
-        T: Future<Output = ()> + Send + 'static,
-    {
-        self.inner.spawn(Box::new(future))
-    }
-}
-
-struct ThreadPoolInner {
-    threads: Mutex<Vec<Thread>>,
-    tasks: RwLock<SlotList<BoxFut>>,
-    overflow: RwLock<Vec<TaskId>>,
-}
-
-unsafe impl Sync for ThreadPoolInner {}
-
-impl ThreadPoolInner {
-    fn create_task(&self, future: BoxFut) -> TaskId {
-        self.tasks.write().unwrap().add(future)
-    }
-
-    fn spawn(&self, future: BoxFut) {
-        self.queue_task(self.create_task(future))
-    }
-
-    fn queue_task(&self, task: TaskId) {
-        let threads = self.threads.lock().unwrap();
-
-        let shortest = threads
-            .iter()
-            .min_by(|a, b| a.task_count().cmp(&b.task_count()))
-            .expect("thread pool should not be empty");
-
-        if !shortest.try_queue(task) {
-            drop(threads);
-            self.overflow.write().unwrap().push(task);
-        }
-    }
-
-    fn create_waker(self: Arc<Self>, task_id: TaskId) -> std::task::RawWaker {
-        let waker = Box::new(Waker {
-            pool: self,
-            task_id,
-        });
-        std::task::RawWaker::new(Box::leak(waker) as *mut Waker as *mut (), &WAKER_VTABLE)
-    }
-}
-
-struct Thread {
-    handle: JoinHandle<()>,
-    inner: Arc<ThreadInner>,
-}
-
-impl Thread {
-    fn new(pool: Arc<ThreadPoolInner>, id: usize) -> Self {
-        let inner = Arc::new(ThreadInner {
-            id,
-            ring: Ring::new(32),
-            pool,
-        });
-
-        let handle = std::thread::spawn({
-            let inner = Arc::clone(&inner);
-            move || inner.thread_main()
-        });
-        Thread { handle, inner }
-    }
-
-    fn task_count(&self) -> usize {
-        self.inner.task_count()
-    }
-
-    fn try_queue(&self, task: TaskId) -> bool {
-        self.inner.try_queue(task)
-    }
-}
-
-struct ThreadInner {
-    id: usize,
-    ring: Ring<TaskId>,
-    pool: Arc<ThreadPoolInner>,
-}
-
-thread_local! {
-    static THREAD_INNER: RefCell<*const ThreadInner> = RefCell::new(std::ptr::null());
-}
-
-impl ThreadInner {
-    fn thread_main(self: Arc<Self>) {
-        THREAD_INNER.with(|inner| {
-            *inner.borrow_mut() = self.as_ref() as *const Self;
-        });
-        loop {
-            if let Some(task_id) = self.ring.try_pop() {
-                self.poll_task(task_id);
-            }
-        }
-    }
-
-    fn poll_task(&self, task_id: TaskId) {
-        //let future = {
-        //    let task = self.pool.tasks.read().unwrap().get(task_id).unwrap();
-        //    if let Some(future) = task.future.as_ref() {
-        //        future.as_ref() as *const (dyn Future<Output = ()> + Send + 'static)
-        //            as *mut (dyn Future<Output = ()> + Send + 'static)
-        //    } else {
-        //        return;
-        //    }
-        //};
-        let waker = unsafe {
-            std::task::Waker::from_raw(std::task::RawWaker::new(
-                task_id as *const (),
-                &std::task::RawWakerVTable::new(
-                    local_waker_clone_fn,
-                    local_waker_wake_fn,
-                    local_waker_wake_by_ref_fn,
-                    local_waker_drop_fn,
-                ),
-            ))
-        };
-
-        let mut context = Context::from_waker(&waker);
-
-        let future = {
-            self.pool
-                .tasks
-                .read()
-                .unwrap()
-                .get(task_id)
-                .unwrap()
-                .as_ref() as *const (dyn Future<Output = ()> + Send + 'static)
-                as *mut (dyn Future<Output = ()> + Send + 'static)
-        };
-        if let Poll::Ready(value) = unsafe { Pin::new_unchecked(&mut *future) }.poll(&mut context) {
-            let task = self.pool.tasks.write().unwrap().remove(task_id);
-        }
-    }
-
-    fn task_count(&self) -> usize {
-        self.ring.len()
-    }
-
-    fn try_queue(&self, task: TaskId) -> bool {
-        self.ring.try_push(task)
-    }
-}
-
-struct RefWaker<'a> {
-    pool: &'a ThreadPoolInner,
-    task_id: TaskId,
-}
-
-struct Waker {
-    pool: Arc<ThreadPoolInner>,
-    task_id: TaskId,
-}
-
-pub struct Work {}
-
-const WAKER_VTABLE: std::task::RawWakerVTable = std::task::RawWakerVTable::new(
-    waker_clone_fn,
-    waker_wake_fn,
-    waker_wake_by_ref_fn,
-    waker_drop_fn,
-);
-
-unsafe fn waker_clone_fn(_this: *const ()) -> std::task::RawWaker {
-    panic!("TODO");
-}
-
-unsafe fn waker_wake_fn(_this: *const ()) {
-    panic!("TODO");
-}
-
-unsafe fn waker_wake_by_ref_fn(_this: *const ()) {
-    panic!("TODO");
-}
-
-unsafe fn waker_drop_fn(_this: *const ()) {
-    panic!("TODO");
-}
-
-unsafe fn local_waker_clone_fn(_this: *const ()) -> std::task::RawWaker {
-    panic!("TODO");
-}
-
-unsafe fn local_waker_wake_fn(_this: *const ()) {
-    panic!("TODO");
-}
-
-unsafe fn local_waker_wake_by_ref_fn(_this: *const ()) {
-    panic!("TODO");
-}
-
-unsafe fn local_waker_drop_fn(_this: *const ()) {
-    panic!("TODO");
-}

diff --git a/src/main.rs b/src/main.rs
index 8556a9d..a49ef6e 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -23,9 +23,9 @@ pub mod tools;
 
 use io_uring::socket::SeqPacketListener;
 
-static mut EXECUTOR: *mut futures_executor::ThreadPool = std::ptr::null_mut();
+static mut EXECUTOR: *mut executor::ThreadPool = std::ptr::null_mut();
 
-pub fn executor() -> &'static futures_executor::ThreadPool {
+pub fn executor() -> &'static executor::ThreadPool {
     unsafe { &*EXECUTOR }
 }
 
@@ -34,7 +34,7 @@ pub fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
 }
 
 fn main() {
-    let mut executor = futures_executor::ThreadPool::new().expect("spawning worker threadpool");
+    let mut executor = executor::ThreadPool::new().expect("spawning worker threadpool");
     unsafe {
         EXECUTOR = &mut executor;
     }