🦀 Functional Rust

335: Waker and Context

Difficulty: 5 · Level: Master. How the executor knows to re-poll a future — calling `wake()` on a clone of `cx.waker()` schedules a re-poll.

The Problem This Solves

The async executor's job is to poll futures until they complete. But polling continuously (busy-waiting) wastes 100% of a CPU core. The executor needs a way to be notified when a future that previously returned `Poll::Pending` is now ready to make progress. That notification mechanism is the `Waker`. When your future returns `Pending`, it stores the waker from the `Context` (`cx.waker().clone()`). When the external event happens — the timer fires, the I/O is ready, the channel receives a value — code somewhere calls `waker.wake()`, which tells the executor: "re-poll this future now." The executor then schedules the future for another poll, and this time it returns `Poll::Ready`. Without `Waker`, every async runtime would have to busy-poll all pending futures. With it, the runtime can park idle and only wake up tasks that have actual work to do.

The Intuition

It's like a JavaScript `Promise` resolver. When you create a Promise:
new Promise((resolve, reject) => {
setTimeout(() => resolve(42), 100);
});
The executor doesn't know the promise is ready until `resolve` is called. In Rust, `waker.wake()` plays the role of `resolve` — it's the signal from outside that tells the executor "this future is ready to continue." `Context` is just the carrier for the `Waker`. Every time a future is polled, the executor passes a fresh `Context` containing the waker for that task.

How It Works in Rust

// Sketch of the consumer half: a future backed by state it shares with a resolver.
// (`SharedState` is not declared in this excerpt; the code uses its `value` and `waker` fields.)
struct ExternalFuture { state: Arc<Mutex<SharedState>> }

impl Future for ExternalFuture {
 type Output = i32;
 // Returns the value if one has been stored; otherwise registers the
 // caller's waker in the shared state and reports `Pending`.
 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<i32> {
     let mut s = self.state.lock().unwrap();
     if let Some(v) = s.value {
         Poll::Ready(v)        // value is available — done
     } else {
         // Store the waker so the resolver can wake us later
         s.waker = Some(cx.waker().clone());
         Poll::Pending         // not ready — suspend
     }
 }
}

// Sketch of the producer half: holds the same shared state as the future.
struct Resolver { state: Arc<Mutex<SharedState>> }

impl Resolver {
 // Consumes the resolver, stores `value` in the shared state and, if a
 // waker was registered by an earlier poll, invokes it.
 fn fulfill(self, value: i32) {
     let mut s = self.state.lock().unwrap();
     s.value = Some(value);
     // Wake the executor: "this future is ready to poll again"
     if let Some(w) = s.waker.take() { w.wake(); }
 }
}
The `block_on` executor in this example implements real waker notifications using an `AtomicBool` flag — after a poll returns `Pending` it stops polling and waits, and it polls again only once `wake()` has set the flag. This is a minimal but correct executor.

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Completion signal | `Lwt.wakeup resolver value` (explicit resolver) | `waker.wake()` (executor re-polls the future) |
| Suspension | `Lwt.wait ()` returns a suspended thread + resolver | `Poll::Pending` + stored `Waker` |
| Executor notification | Internal to Lwt scheduler | Explicit: caller must call `waker.wake()` |
| Shared state | `ref` / `Mutex` | `Arc<Mutex<SharedState>>` |
//! # Waker and Context
//!
//! How the executor knows to re-poll a future — calling `wake()` on a clone of `cx.waker()` schedules a re-poll.

use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};

/// State shared between a future and the resolver that completes it.
///
/// Both halves hold the same `Arc<Mutex<SharedState<T>>>`; the producer
/// writes `value`, the consumer leaves its `waker` behind while it waits.
pub struct SharedState<T> {
    /// The produced value; `None` until one has been stored.
    pub value: Option<T>,
    /// Waker registered by a poll that found no value yet, if any.
    pub waker: Option<Waker>,
}

// Written by hand instead of `#[derive(Default)]`: the derive would add a
// `T: Default` bound, although an empty state never has to construct a `T`.
impl<T> Default for SharedState<T> {
    /// Returns an empty state: no value stored and no waker registered.
    fn default() -> Self {
        SharedState {
            value: None,
            waker: None,
        }
    }
}

/// A future that completes when an external source provides a value.
///
/// The value arrives through the `SharedState` this future shares with its
/// resolver; until then every poll reports [`Poll::Pending`].
pub struct ExternalFuture<T> {
    state: Arc<Mutex<SharedState<T>>>,
}

impl<T> Future for ExternalFuture<T> {
    type Output = T;

    /// Returns `Poll::Ready(value)` once a value has been stored in the
    /// shared state, otherwise registers `cx`'s waker there and returns
    /// `Poll::Pending`.
    ///
    /// The value is moved out of the shared state, so `T` does not need to
    /// be `Clone`. As a consequence, polling again after `Ready` was
    /// returned yields `Pending`.
    ///
    /// # Panics
    ///
    /// Panics if the state mutex is poisoned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
        let mut state = self.state.lock().unwrap();

        // Hand the value over by move rather than by clone.
        if let Some(value) = state.value.take() {
            return Poll::Ready(value);
        }

        // Only replace the stored waker when it would wake a different task;
        // re-polling with the same waker then costs no extra clone.
        let current = cx.waker();
        let already_registered = state
            .waker
            .as_ref()
            .map_or(false, |stored| stored.will_wake(current));
        if !already_registered {
            state.waker = Some(current.clone());
        }
        Poll::Pending
    }
}

/// A resolver that can fulfill the associated future.
///
/// It holds the same shared state as the future. Nothing in this type wakes
/// the future if the resolver is dropped without calling
/// [`fulfill`](Resolver::fulfill).
pub struct Resolver<T> {
    state: Arc<Mutex<SharedState<T>>>,
}

impl<T> Resolver<T> {
    /// Fulfill the future with a value, waking the executor.
    ///
    /// Stores `value` in the shared state and, if a waker was registered
    /// there, invokes it. Consumes the resolver, so a value can be delivered
    /// at most once through it.
    ///
    /// # Panics
    ///
    /// Panics if the state mutex is poisoned.
    pub fn fulfill(self, value: T) {
        let mut state = self.state.lock().unwrap();
        state.value = Some(value);
        let pending_waker = state.waker.take();

        // Release the lock before waking: a waker may poll the future (or
        // otherwise touch the shared state) right away, and must not find
        // the mutex still held by us.
        drop(state);

        if let Some(waker) = pending_waker {
            waker.wake();
        }
    }
}

/// Build a connected `(ExternalFuture, Resolver)` pair.
///
/// Both halves point at one freshly allocated, empty `SharedState`: no value
/// has been stored and no waker has been registered yet.
pub fn make_future<T>() -> (ExternalFuture<T>, Resolver<T>) {
    let empty = SharedState {
        value: None,
        waker: None,
    };
    let shared = Arc::new(Mutex::new(empty));

    let future = ExternalFuture {
        state: Arc::clone(&shared),
    };
    let resolver = Resolver { state: shared };
    (future, resolver)
}

/// A minimal executor that drives a single future to completion on the
/// calling thread.
///
/// The future is polled with a [`Waker`] tied to the current thread. When a
/// poll returns [`Poll::Pending`], the thread parks (sleeps) instead of
/// spinning; invoking the waker, from any thread, records the notification
/// and unparks it, after which the future is polled again.
///
/// Returns the future's output. If the future returns `Pending` and its
/// waker is never invoked, this function does not return.
pub fn block_on<F: Future>(fut: F) -> F::Output {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::task::Wake;
    use std::thread::{self, Thread};

    /// Wake-up signal for the thread blocked in `block_on`.
    struct ThreadSignal {
        /// Handle used to unpark the blocked thread.
        thread: Thread,
        /// Set by a wake-up and cleared by the executor loop. It guards
        /// against spurious returns from `park` and remembers wake-ups that
        /// arrive before the thread has parked.
        notified: AtomicBool,
    }

    impl ThreadSignal {
        /// Records the notification, then unparks the executor thread.
        fn notify(&self) {
            self.notified.store(true, Ordering::Release);
            self.thread.unpark();
        }
    }

    // Implementing `Wake` lets the standard library build the `Waker`
    // (including its reference counting), so no hand-written vtable or
    // `unsafe` is needed.
    impl Wake for ThreadSignal {
        fn wake(self: Arc<Self>) {
            self.notify();
        }

        fn wake_by_ref(self: &Arc<Self>) {
            self.notify();
        }
    }

    let signal = Arc::new(ThreadSignal {
        thread: thread::current(),
        // Starts cleared: the first poll happens unconditionally below, so
        // no initial notification is required.
        notified: AtomicBool::new(false),
    });
    let waker = Waker::from(Arc::clone(&signal));
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);

    loop {
        if let Poll::Ready(value) = fut.as_mut().poll(&mut cx) {
            return value;
        }
        // Sleep until woken. `park` may return spuriously, so only leave the
        // loop once a wake-up has actually been recorded; consuming the flag
        // here re-arms it for the next pending poll.
        while !signal.notified.swap(false, Ordering::AcqRel) {
            thread::park();
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// A value delivered from another thread after a short delay reaches the
    /// caller blocked in `block_on`.
    #[test]
    fn test_external_future() {
        let (future, resolver) = make_future::<i32>();
        let delay = Duration::from_millis(5);

        thread::spawn(move || {
            thread::sleep(delay);
            resolver.fulfill(42);
        });

        let received = block_on(future);
        assert_eq!(received, 42);
    }

    /// A future whose value was stored before the first poll completes
    /// right away.
    #[test]
    fn test_immediate_fulfill() {
        let (future, resolver) = make_future::<String>();
        resolver.fulfill("hello".to_string());

        let received = block_on(future);
        assert_eq!(received, "hello");
    }

    /// Two independent pairs, fulfilled from separate threads, each deliver
    /// their own value.
    #[test]
    fn test_multiple_futures() {
        let (first_future, first_resolver) = make_future::<i32>();
        let (second_future, second_resolver) = make_future::<i32>();

        thread::spawn(move || {
            thread::sleep(Duration::from_millis(5));
            first_resolver.fulfill(1);
        });
        thread::spawn(move || {
            thread::sleep(Duration::from_millis(3));
            second_resolver.fulfill(2);
        });

        let first = block_on(first_future);
        assert_eq!(first, 1);
        let second = block_on(second_future);
        assert_eq!(second, 2);
    }

    /// The default shared state holds neither a value nor a waker.
    #[test]
    fn test_shared_state_default() {
        let empty: SharedState<i32> = SharedState::default();
        assert!(empty.value.is_none());
        assert!(empty.waker.is_none());
    }
}
(* OCaml: manual continuation-passing *)

(* A promise is either resolved with its value, or still pending together
   with the callbacks registered so far (most recently registered first). *)
type 'a promise = Fulfilled of 'a | Pending of ('a -> unit) list

(* [make ()] creates a fresh pending promise and returns [(wait, fulfill)].

   [wait f] runs [f v] immediately if the promise already holds [v];
   otherwise it records [f] to be called on fulfilment.

   [fulfill v] resolves the promise with [v] and calls every recorded
   callback, in the order in which they were registered. It raises
   [Failure "already"] if the promise was resolved before. *)
let make () =
  let state = ref (Pending []) in
  let wait f =
    match !state with
    | Fulfilled v -> f v
    | Pending ws -> state := Pending (f :: ws)
  in
  let fulfill v =
    match !state with
    | Fulfilled _ -> failwith "already"
    | Pending ws ->
      state := Fulfilled v;
      (* [ws] was built by consing, so it is newest-first; reverse it so the
         waiters are notified in registration order. *)
      List.iter (fun f -> f v) (List.rev ws)
  in
  (wait, fulfill)

(* Demo: register two printing waiters on one promise, then resolve it
   with 42 so that both of them run. *)
let () =
  let wait, fulfill = make () in
  let waiter1 v = Printf.printf "Waiter1: %d\n" v in
  let waiter2 v = Printf.printf "Waiter2: %d\n" v in
  wait waiter1;
  wait waiter2;
  fulfill 42