🦀 Functional Rust

595: Trampoline Pattern

Difficulty: 4
Level: Advanced

Turn deep recursion into a loop by returning thunks instead of recursing: stack-overflow prevention without TCO.

The Problem This Solves

Rust does not guarantee tail-call optimisation (TCO). A recursive function that calls itself in tail position may still push a new stack frame on every call. If you count down from a million that way, you will likely hit a stack overflow, at least in a debug build, even though the logic is perfectly tail-recursive. Functional languages such as OCaml and Haskell avoid the problem through guaranteed tail calls or lazy evaluation. Rust makes you handle it explicitly. The trampoline pattern is the idiomatic solution. Instead of calling the next step of the recursion, the function returns a description of that step: a thunk, i.e. a zero-argument closure. A driver loop repeatedly calls thunks until it gets a final value. The stack stays at constant depth, and the heap holds the pending computation.
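To make the problem concrete, here is a minimal sketch of the naive version. The helper name `count` is hypothetical and used only for illustration. The recursive call is in tail position, yet nothing in the language promises that it becomes a jump, so each call may consume a stack frame.

```rust
// Naive tail-recursive countdown (hypothetical helper, for illustration).
// The call to `count` is the last thing the function does, but Rust does not
// guarantee tail-call optimisation, so each call may push a new stack frame.
fn count(n: u64) -> u64 {
    if n == 0 { 0 } else { count(n - 1) }
}

fn main() {
    // A modest depth is fine on a normal thread stack.
    println!("count(10_000) = {}", count(10_000));
    // count(1_000_000) may overflow the stack, especially in a debug build.
    // Whether it does depends on the optimiser, which is exactly the problem.
}
```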

The Intuition

Picture a ball on a trampoline. It hits the surface, bounces up, and comes back down. Its height never builds up; it just keeps bouncing at the same level. The driver loop is the trampoline surface. Each recursive "call" is a bounce: you land, get sent back up (execute a thunk), and land again, until you're caught (the `Done` case).
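The metaphor can be made concrete by counting bounces. This sketch uses the same `Bounce` type that the next section introduces. `run_counted` is a hypothetical variant of the driver loop that also reports how many thunks it executed.

```rust
enum Bounce<T> {
    Done(T),
    More(Box<dyn FnOnce() -> Bounce<T>>),
}

// Hypothetical driver variant: returns the final value plus the number of
// thunks ("bounces") it ran before reaching Done.
fn run_counted<T>(mut b: Bounce<T>) -> (T, usize) {
    let mut bounces = 0;
    loop {
        match b {
            Bounce::Done(v) => return (v, bounces),
            Bounce::More(th) => {
                bounces += 1; // one bounce = one thunk executed
                b = th();
            }
        }
    }
}

// Trampolined countdown: each step returns a thunk for the next one.
fn count_t(n: u64) -> Bounce<u64> {
    if n == 0 { Bounce::Done(0) } else { Bounce::More(Box::new(move || count_t(n - 1))) }
}

fn main() {
    let (v, bounces) = run_counted(count_t(3));
    println!("value = {v}, bounces = {bounces}"); // value = 0, bounces = 3
}
```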

How It Works in Rust

1. Define the `Bounce` type, a sum type of "finished" or "more work":
```rust
enum Bounce<T> {
    Done(T),
    More(Box<dyn FnOnce() -> Bounce<T>>),
}
```
2. Driver loop: it calls thunks iteratively, so the stack never grows:
```rust
fn run<T>(mut b: Bounce<T>) -> T {
    loop {
        match b {
            Bounce::Done(v)  => return v,  // caught: hand back the final value
            Bounce::More(th) => b = th(),  // bounce: run one step, keep its result
        }
    }
}
```
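3. Stack-safe factorial: return a thunk instead of recursing:
```rust
fn fact_t(n: u64, acc: u64) -> Bounce<u64> {
    if n == 0 { Bounce::Done(acc) }
    else      { Bounce::More(Box::new(move || fact_t(n - 1, n * acc))) }
}

let result = run(fact_t(20, 1)); // 2432902008176640000
// fact_t(1_000_000, 1) would not overflow the stack either, but `n * acc`
// overflows u64 for any n above 20 (a panic in debug builds).
```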
4. Mutually recursive functions: each one "calls" the other by returning `More`:
```rust
fn even_t(n: u64) -> Bounce<bool> {
    if n == 0 { Bounce::Done(true) }
    else      { Bounce::More(Box::new(move || odd_t(n - 1))) }
}
fn odd_t(n: u64) -> Bounce<bool> {
    if n == 0 { Bounce::Done(false) }
    else      { Bounce::More(Box::new(move || even_t(n - 1))) }
}
```
5. Cost: each step allocates a `Box`. For performance-critical paths, consider the `stacker` crate, which grows the stack at runtime, or rewrite the recursion as a loop.
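For comparison, the iterative reformulation mentioned in step 5 might look like the minimal sketch below. The accumulator becomes a mutable local, and the tail call becomes the next loop iteration. `fact_iter` and `is_even_iter` are illustrative names, not part of the trampoline code.

```rust
// Tail-recursive factorial rewritten as a loop: the parameters (n, acc)
// become mutable locals, and the tail call becomes the next iteration.
fn fact_iter(mut n: u64) -> u64 {
    let mut acc = 1;
    while n > 0 {
        acc *= n; // same update as `n * acc` in fact_t
        n -= 1;
    }
    acc
}

// The even/odd mutual recursion flattened into one loop that tracks
// which of the two functions would currently be running.
fn is_even_iter(mut n: u64) -> bool {
    let mut in_even = true; // true: "inside even_t", false: "inside odd_t"
    while n > 0 {
        in_even = !in_even; // each step hands control to the other function
        n -= 1;
    }
    // Base cases: even_t(0) is true and odd_t(0) is false.
    in_even
}

fn main() {
    println!("20! = {}", fact_iter(20));
    println!("even(101) = {}", is_even_iter(101));
}
```

In practice `n % 2 == 0` is the obvious closed form. The loop is shown only to mirror the mutual recursion step by step, with no allocation per step.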

What This Unlocks

Key Differences

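| Concept | OCaml | Rust |
|---|---|---|
| Tail recursion | TCO guaranteed | Not guaranteed → use a trampoline |
| Thunk type | `unit -> 'a bounce` (a closure) | `Box<dyn FnOnce() -> Bounce<T>>` |
| Driver loop | Tail-recursive `go`, which runs in constant stack | Explicit `loop` in `run()` |
| Allocation | One GC-managed closure per step | One `Box` per step |
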
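The table gives `FnOnce` as the Rust thunk type. One reason that choice matters is that a thunk runs exactly once, so it may move owned state, such as a growing `Vec`, into the next step. Here is a minimal sketch; `collect_t` is a hypothetical example function.

```rust
enum Bounce<T> {
    Done(T),
    More(Box<dyn FnOnce() -> Bounce<T>>),
}

fn run<T>(mut b: Bounce<T>) -> T {
    loop {
        match b {
            Bounce::Done(v) => return v,
            Bounce::More(th) => b = th(),
        }
    }
}

// Builds [n, n-1, ..., 1] by trampolined recursion. The closure moves `acc`
// into the next call, so it can only be called once. That makes it FnOnce:
// it would not fit a `Box<dyn Fn() -> Bounce<T>>` without cloning the vector.
fn collect_t(n: u64, mut acc: Vec<u64>) -> Bounce<Vec<u64>> {
    if n == 0 {
        Bounce::Done(acc)
    } else {
        acc.push(n);
        Bounce::More(Box::new(move || collect_t(n - 1, acc)))
    }
}

fn main() {
    println!("{:?}", run(collect_t(3, Vec::new()))); // [3, 2, 1]
}
```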

```rust
// Trampoline type: either a finished value or a thunk for the next step
enum Bounce<T> {
    Done(T),
    More(Box<dyn FnOnce() -> Bounce<T>>),
}

// Driver loop: runs thunks until one returns Done
fn run<T>(mut b: Bounce<T>) -> T {
    loop {
        match b {
            Bounce::Done(v)   => return v,
            Bounce::More(th)  => b = th(),
        }
    }
}

// Stack-safe factorial (results fit in u64 only up to n = 20)
fn fact_t(n: u64, acc: u64) -> Bounce<u64> {
    if n == 0 { Bounce::Done(acc) }
    else      { Bounce::More(Box::new(move || fact_t(n - 1, n * acc))) }
}

// Mutually recursive even/odd: stack-safe
fn even_t(n: u64) -> Bounce<bool> {
    if n == 0 { Bounce::Done(true)  }
    else      { Bounce::More(Box::new(move || odd_t(n - 1))) }
}

fn odd_t(n: u64) -> Bounce<bool> {
    if n == 0 { Bounce::Done(false) }
    else      { Bounce::More(Box::new(move || even_t(n - 1))) }
}

// Countdown: a plain recursive version could overflow the stack at large n
fn count_t(n: u64) -> Bounce<u64> {
    if n == 0 { Bounce::Done(0) }
    else      { Bounce::More(Box::new(move || count_t(n - 1))) }
}

fn main() {
    println!("5! = {}", run(fact_t(5, 1)));
    println!("20! = {}", run(fact_t(20, 1)));
    println!("even(100) = {}", run(even_t(100)));
    println!("even(101) = {}", run(even_t(101)));
    // Stack-safe at depth 100_000
    println!("count(100000) = {}", run(count_t(100_000)));
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test] fn fact5()     { assert_eq!(run(fact_t(5, 1)), 120); }
    #[test] fn even100()   { assert!(run(even_t(100))); }
    #[test] fn odd101()    { assert!(run(odd_t(101)));  }
    #[test] fn deep()      { assert_eq!(run(count_t(50_000)), 0); }
}
```
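The `deep` test above runs on the test harness's default thread stack. A stricter check is sketched below: it runs a million bounces on a deliberately small stack. This assumes that 256 KiB is far too little for a million real nested calls, while the trampoline needs only a few frames. `std::thread::Builder::stack_size` is a standard-library API. The depth, the stack size, and the helper name `deep_on_small_stack` are arbitrary illustrative choices.

```rust
use std::thread;

enum Bounce<T> {
    Done(T),
    More(Box<dyn FnOnce() -> Bounce<T>>),
}

fn run<T>(mut b: Bounce<T>) -> T {
    loop {
        match b {
            Bounce::Done(v) => return v,
            Bounce::More(th) => b = th(),
        }
    }
}

fn count_t(n: u64) -> Bounce<u64> {
    if n == 0 { Bounce::Done(0) } else { Bounce::More(Box::new(move || count_t(n - 1))) }
}

// Runs a million-step countdown on a thread with a 256 KiB stack.
// `run` never recurses, and each thunk returns before the next one runs,
// so the stack depth stays constant however large n is.
fn deep_on_small_stack() -> u64 {
    thread::Builder::new()
        .stack_size(256 * 1024)
        .spawn(|| run(count_t(1_000_000)))
        .expect("failed to spawn thread")
        .join()
        .expect("trampolined countdown panicked")
}

fn main() {
    println!("count(1_000_000) on a 256 KiB stack = {}", deep_on_small_stack());
}
```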
```ocaml
(* Trampoline in OCaml *)
type 'a bounce = Done of 'a | Bounce of (unit -> 'a bounce)

let run t =
  let rec go = function
    | Done v    -> v
    | Bounce th -> go (th ())
  in go t

let rec fact_t n acc =
  if n <= 0 then Done acc
  else Bounce (fun () -> fact_t (n-1) (n*acc))

let rec even_t n =
  if n = 0 then Done true
  else Bounce (fun () -> odd_t  (n-1))
and odd_t n =
  if n = 0 then Done false
  else Bounce (fun () -> even_t (n-1))

let () =
  (* 20! fits in OCaml's 63-bit int on 64-bit platforms; 100! would wrap around *)
  Printf.printf "20! = %d\n" (run (fact_t 20 1));
  Printf.printf "even 100: %b\n" (run (even_t 100));
  Printf.printf "even 101: %b\n" (run (even_t 101))
```