🦀 Functional Rust

993: Thread Pool / Work Queue

Difficulty: Intermediate Category: Async / Concurrency FP Patterns Concept: Fixed N workers consuming tasks from a shared queue Key Insight: Wrap `Receiver<Task>` in `Arc<Mutex<>>` so multiple worker threads can compete for tasks; drop the `Sender` to signal shutdown — workers see `Err` from `recv()` and exit

Versions

| Directory | Description |
|-----------|-------------|
| `std/`    | Standard library version using `std::sync`, `std::thread` |
| `tokio/`  | Tokio async runtime version using `tokio::sync`, `tokio::spawn` |

Running

# Standard library version
cd std && cargo test

# Tokio version
cd tokio && cargo test
// 993: Thread Pool / Work Queue
// Fixed N workers consuming tasks from a shared mpsc channel

use std::sync::{mpsc, Arc, Mutex};
use std::thread;

/// A unit of work: any one-shot closure that can be sent to another thread.
type Task = Box<dyn FnOnce() + Send + 'static>;

/// Fixed-size pool of worker threads competing for tasks on one shared
/// mpsc channel. The pool owns the only external `Sender`; dropping it
/// (in `shutdown`) is the termination signal.
struct ThreadPool {
    sender: mpsc::Sender<Task>,
    workers: Vec<thread::JoinHandle<()>>,
}

impl ThreadPool {
    /// Spawn `size` workers that all pull tasks from one shared channel.
    ///
    /// # Panics
    /// Panics if `size` is zero.
    fn new(size: usize) -> Self {
        assert!(size > 0);
        let (sender, receiver) = mpsc::channel::<Task>();
        // mpsc has a single Receiver; Arc shares ownership across workers
        // and Mutex lets exactly one of them call recv() at a time.
        let shared_rx = Arc::new(Mutex::new(receiver));

        let mut workers = Vec::with_capacity(size);
        for _ in 0..size {
            let queue = Arc::clone(&shared_rx);
            workers.push(thread::spawn(move || loop {
                // Hold the lock only while pulling one task: the inner
                // scope drops the guard *before* the task runs, so a
                // slow task never blocks the other workers' recv().
                let next = {
                    let guard = queue.lock().unwrap();
                    guard.recv()
                };
                match next {
                    Ok(job) => job(),
                    // recv() fails only once every Sender is dropped,
                    // i.e. the pool was shut down.
                    Err(_) => break,
                }
            }));
        }

        ThreadPool { sender, workers }
    }

    /// Queue a task; whichever idle worker wins the lock runs it.
    fn execute<F: FnOnce() + Send + 'static>(&self, f: F) {
        self.sender.send(Box::new(f)).unwrap();
    }

    /// Close the channel and block until every worker has exited.
    fn shutdown(self) {
        // Dropping the last Sender closes the channel; each worker's next
        // recv() returns Err and its loop breaks.
        drop(self.sender);
        for handle in self.workers {
            handle.join().unwrap();
        }
    }
}

// --- Approach 1: Submit tasks that collect results ---
/// Square 1..=20 on a 4-worker pool, each task pushing its result into a
/// shared `Mutex<Vec<_>>`. Returns the squares in ascending order.
fn pool_squares() -> Vec<i64> {
    let pool = ThreadPool::new(4);
    let collected: Arc<Mutex<Vec<i64>>> = Arc::new(Mutex::new(Vec::new()));

    for n in 1i64..=20 {
        let sink = Arc::clone(&collected);
        pool.execute(move || sink.lock().unwrap().push(n * n));
    }

    // shutdown() joins every worker, so all 20 pushes have completed
    // before we read the vector.
    pool.shutdown();

    let mut out = collected.lock().unwrap().clone();
    out.sort();
    out
}

// --- Approach 2: Work queue with return values via channel ---
/// Square each input on a 3-worker pool, with results flowing back over a
/// dedicated mpsc channel. Returns the squares sorted ascending.
fn pool_with_results(inputs: Vec<i32>) -> Vec<i32> {
    let pool = ThreadPool::new(3);
    let (result_tx, result_rx) = mpsc::channel::<i32>();

    let expected = inputs.len();
    for value in inputs {
        let tx = result_tx.clone();
        pool.execute(move || tx.send(value * value).unwrap());
    }
    // Drop our clone of the sender; the in-flight tasks hold the rest.
    // After shutdown() joins all workers, every sender is gone, so the
    // receiver's iterator below terminates.
    drop(result_tx);
    pool.shutdown();

    let mut squares: Vec<i32> = result_rx.iter().collect();
    // Every task must have reported exactly once.
    assert_eq!(squares.len(), expected);
    squares.sort();
    squares
}

/// Demo driver: run both approaches and print their results.
fn main() {
    let squares = pool_squares();
    let total: i64 = squares.iter().sum();
    println!("pool squares sum: {} ({} items)", total, squares.len());

    let squared = pool_with_results(vec![1, 2, 3, 4, 5]);
    println!("pool results: {:?}", squared);
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_squares_all_computed() {
        let squares = pool_squares();
        assert_eq!(squares.len(), 20);
        // 1^2 + 2^2 + ... + 20^2 = 2870
        let total: i64 = squares.iter().sum();
        assert_eq!(total, 2870);
    }

    #[test]
    fn test_pool_with_results() {
        assert_eq!(
            pool_with_results(vec![1, 2, 3, 4, 5]),
            vec![1, 4, 9, 16, 25]
        );
    }

    #[test]
    fn test_pool_empty_tasks() {
        // A pool that never receives work must still shut down promptly.
        ThreadPool::new(2).shutdown();
    }

    #[test]
    fn test_pool_single_worker() {
        let pool = ThreadPool::new(1);
        let seen = Arc::new(Mutex::new(Vec::new()));
        for n in 0..5 {
            let sink = Arc::clone(&seen);
            pool.execute(move || sink.lock().unwrap().push(n));
        }
        pool.shutdown();
        let mut observed = seen.lock().unwrap().clone();
        observed.sort();
        assert_eq!(observed, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_pool_more_tasks_than_workers() {
        // 100 tasks on 2 workers: every task must still run exactly once.
        let pool = ThreadPool::new(2);
        let hits = Arc::new(Mutex::new(0u32));
        for _ in 0..100 {
            let h = Arc::clone(&hits);
            pool.execute(move || *h.lock().unwrap() += 1);
        }
        pool.shutdown();
        assert_eq!(*hits.lock().unwrap(), 100);
    }
}
(* 993: Thread Pool / Work Queue *)
(* Fixed N workers consuming from a shared channel *)

(* A minimal multi-consumer channel: a FIFO queue guarded by a mutex,
   plus a condition variable that parks receivers while the queue is
   empty and the channel is still open. *)
type 'a chan = { q: 'a Queue.t; m: Mutex.t; cond: Condition.t; mutable closed: bool }

(* Fresh, open, empty channel. *)
let make_chan () =
  { q = Queue.create (); m = Mutex.create ();
    cond = Condition.create (); closed = false }

(* Run [f] with [c]'s mutex held; returns [f]'s result. *)
let with_chan_lock c f =
  Mutex.lock c.m;
  let result = f () in
  Mutex.unlock c.m;
  result

(* Enqueue [v] and wake one waiting receiver. *)
let send c v =
  with_chan_lock c (fun () ->
    Queue.push v c.q;
    Condition.signal c.cond)

(* Mark the channel closed and wake *every* waiter so each can re-check
   the closed flag, drain remaining items, or exit. *)
let close_chan c =
  with_chan_lock c (fun () ->
    c.closed <- true;
    Condition.broadcast c.cond)

(* Block until an item is available or the channel is closed and drained.
   Returns [Some item], or [None] once closed with nothing left. *)
let recv_work c =
  with_chan_lock c (fun () ->
    while Queue.is_empty c.q && not c.closed do
      Condition.wait c.cond c.m
    done;
    if Queue.is_empty c.q then None else Some (Queue.pop c.q))

(* --- Thread pool: spawn N workers, each pulls from shared queue --- *)

(* A pool is the shared work channel plus the handles of its workers. *)
type 'a pool = {
  work_chan: ('a -> unit) chan;
  workers: Thread.t list;
}

(* Spawn [n] workers. Each loops: receive a task, run it, repeat;
   [None] from the channel (closed and drained) ends the worker. *)
let make_pool n =
  let work_chan = make_chan () in
  let worker () =
    let rec drain () =
      match recv_work work_chan with
      | Some task -> task (); drain ()
      | None -> ()  (* channel closed and empty: exit *)
    in
    drain ()
  in
  { work_chan; workers = List.init n (fun _ -> Thread.create worker ()) }

(* Hand a task to whichever idle worker grabs it first. *)
let submit pool task = send pool.work_chan task

(* Close the channel (workers finish any queued tasks, then see [None])
   and wait for every worker thread to terminate. *)
let shutdown pool =
  close_chan pool.work_chan;
  List.iter Thread.join pool.workers

(* --- Approach 1: Submit 20 tasks to a pool of 4 workers --- *)

let () =
  let collected = ref [] in
  let collected_mutex = Mutex.create () in
  let pool = make_pool 4 in

  (* 20 independent squaring tasks race onto the shared result list. *)
  for i = 1 to 20 do
    submit pool (fun () ->
      let square = i * i in
      Mutex.lock collected_mutex;
      collected := square :: !collected;
      Mutex.unlock collected_mutex)
  done;

  (* shutdown joins all workers, so every task has finished before we read. *)
  shutdown pool;

  let sorted = List.sort compare !collected in
  (* 1^2..20^2 = 1,4,9,...,400 — sum = 2870 *)
  let total = List.fold_left (+) 0 sorted in
  assert (List.length sorted = 20);
  assert (total = 2870);
  Printf.printf "Approach 1 (pool of 4, 20 tasks): sum=%d\n" total

(* Reached only if every assert above held. *)
let () = print_endline "โœ“ All tests passed"

📊 Detailed Comparison

Thread Pool / Work Queue โ€” Comparison

Core Insight

A thread pool reuses a fixed number of threads for many tasks, avoiding thread-creation overhead. The shared queue distributes work; each worker races to grab the next task. Shutdown = close the channel.

OCaml Approach

  • `Queue` + `Mutex` + `Condition` for the work channel
  • Each worker: loop calling `recv_work` (blocks on condition variable)
  • `close_chan` sets `closed = true` + broadcasts to wake all workers
  • Workers see `None` on closed+empty channel and exit
  • Tasks are `unit -> unit` closures

Rust Approach

  • `mpsc::channel::<Task>()` where `Task = Box<dyn FnOnce() + Send>`
  • `Arc<Mutex<Receiver<Task>>>` โ€” workers compete to lock and receive
  • Drop `Sender` to close channel โ€” workers get `Err` from `recv()` and break
  • `JoinHandle` collected; `shutdown()` joins all workers
  • Rayon or tokio for production use; this is the minimal std pattern

Comparison Table

ConceptOCamlRust
Task type`unit -> unit``Box<dyn FnOnce() + Send + 'static>`
Shared queue`Queue` + `Mutex` + `Condition``mpsc::channel` + `Arc<Mutex<Rx>>`
Worker loop`while recv_work ... do task ()``loop { match recv() { Ok(t) => t(), Err(_) => break } }`
Shutdown signal`close_chan` + condition broadcastDrop `Sender` — channel closes
Worker count`List.init n Thread.create``(0..n).map(spawn).collect()`
Result collection`Mutex`-protected listSeparate `mpsc::channel` or `Mutex<Vec>`
Production versionDomain pool (OCaml 5)Rayon / tokio

std vs tokio

Aspectstd versiontokio version
RuntimeOS threads via `std::thread`Async tasks on tokio runtime
Synchronization`std::sync::Mutex`, `Condvar``tokio::sync::Mutex`, channels
Channels`std::sync::mpsc` (unbounded)`tokio::sync::mpsc` (bounded, async)
BlockingThread blocks on lock/recvTask yields, runtime switches tasks
OverheadOne OS thread per workerMany tasks per thread (M:N)
Best forCPU-bound, simple concurrencyI/O-bound, high-concurrency servers