๐Ÿฆ€ Functional Rust

195: Continuation-Passing Style

Difficulty: 3 Level: Advanced Transform any function to pass its result to a callback instead of returning โ€” unlocking early exit, async, and stack-safe recursion.

The Problem This Solves

Normal Rust functions work like this: you call them, they compute something, they return. You wait for the return. Everything stacks up โ€” literally. Each nested call adds a frame to the call stack. For most code this is fine. But three situations break this model badly: First: deep recursion. A tree with 100,000 levels, a recursive descent parser, a recursive interpreter โ€” these eat stack space call by call until Rust panics with a stack overflow. You could rewrite as a loop with an explicit stack, but that destroys the clarity of the recursive structure. Second: early exit across multiple levels. You want to abort the whole computation when you find a match, not bubble a flag up through every layer. In normal code you end up threading `Option` or `bool` through every function signature โ€” ugly and error-prone. Third: async and coroutines. How does a function "suspend" and "resume"? It has to be able to say "here's what to do with my result when I eventually produce it." That IS a continuation. CPS solves all three by making "what happens next" a first-class value you can pass, store, modify, and call. This example exists to solve exactly that pain.

The Intuition

Think of CPS as flipping functions inside out. Normal function: `compute(inputs) -> output` CPS function: `compute(inputs, callback)` โ€” callback receives the output
// Normal: returns a value
fn add(a: i32, b: i32) -> i32 { a + b }

// CPS: calls callback with the value
fn add_cps<R>(a: i32, b: i32, k: impl Fn(i32) -> R) -> R {
 k(a + b)  // "k" = continuation, the "what to do next"
}

// To use it: pass what you want done with the result
add_cps(3, 4, |sum| println!("Sum is {}", sum));  // prints "Sum is 7"
add_cps(3, 4, |sum| sum * 2);                     // returns 14
The magic: the continuation represents everything that happens after this point. When you make that explicit, you can: In Rust, you see CPS every day โ€” the `?` operator is syntactic sugar for CPS transformation on `Result`.

How It Works in Rust

Basic CPS factorial:
use std::rc::Rc;

// Rc<dyn Fn> lets continuations be shared (cloned) when both branches need the same k
fn factorial_cps(n: u64, k: Rc<dyn Fn(u64) -> u64>) -> u64 {
 if n <= 1 {
     k(1)  // base case: give 1 to the continuation
 } else {
     let k2 = k.clone();  // need to share k between the closure and recursive call
     factorial_cps(
         n - 1,
         Rc::new(move |result| k2(n * result))  // new continuation wraps the multiply
     )
 }
}

let id: Rc<dyn Fn(u64) -> u64> = Rc::new(|x| x);  // identity: "just return the result"
assert_eq!(factorial_cps(5, id), 120);
Two continuations = two outcomes (early exit):
fn find_cps<T: Copy>(
 pred: &dyn Fn(T) -> bool,
 list: &[T],
 found: &dyn Fn(T) -> Option<T>,    // continuation for success
 not_found: &dyn Fn() -> Option<T>, // continuation for failure
) -> Option<T> {
 if list.is_empty() {
     not_found()                     // exhausted: take the failure path
 } else if pred(list[0]) {
     found(list[0])                  // match: jump directly to success, skip the rest
 } else {
     find_cps(pred, &list[1..], found, not_found)
 }
}

// Find first even number โ€” stops as soon as one is found
let result = find_cps(
 &|x| x % 2 == 0,
 &[1, 3, 4, 5, 8],
 &|x| Some(x),   // found: return it
 &|| None,        // not found: return None
);
assert_eq!(result, Some(4));  // stops at 4, never looks at 5 or 8
Lifting any function into CPS (generic transformer):
fn lift_cps<A, B, R>(f: impl Fn(A) -> B, x: A, k: impl Fn(B) -> R) -> R {
 k(f(x))  // compute f(x), pass result to k
}

// Turn any normal function into CPS instantly:
lift_cps(|x: i32| x * x, 7, |r| println!("7ยฒ = {}", r));  // prints "7ยฒ = 49"
Tail-recursive fold in CPS:
fn fold_cps<A: Copy, B: Clone>(
 list: &[A],
 init: B,
 f: &dyn Fn(B, A) -> B,
 k: &dyn Fn(B) -> B,  // the final continuation: what to do with the result
) -> B {
 if list.is_empty() {
     k(init)  // done โ€” call the final continuation
 } else {
     let next = f(init, list[0]);
     fold_cps(&list[1..], next, f, k)  // tail position: no stack frame kept
 }
}

What This Unlocks

Key Differences

ConceptOCamlRust
Continuation type`'a -> 'b` (any function)`Rc<dyn Fn(A) -> B>` or `impl Fn`
Sharing continuationsAutomatic (GC)Needs `Rc` for multiple references
Stack safety with CPSYes (TCO guaranteed)No โ€” needs trampoline on top
The `?` operator`result >>= k` (bind)CPS in disguise
Recommended for prodYes (elegant in OCaml)Use iterators; CPS = educational
Each continuationStack-allocatedHeap-allocated `Box` or `Rc`
// Continuation-Passing Style (CPS) โ€” Transform Functions to Pass Results
//
// Instead of returning a value, a CPS function takes an extra argument `k`
// (the "continuation") and calls k(result) instead of returning.
//
// Benefits:
//   * Always tail-recursive โ€” no stack growth for deep recursion
//   * First-class control flow (early exit, backtracking)
//   * Foundation for coroutines, async, and continuations

use std::rc::Rc;

// โ”€โ”€ Factorial โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€

/// Direct style โ€” not tail recursive (stack grows with n)
fn factorial_direct(n: u64) -> u64 {
    if n <= 1 { 1 } else { n * factorial_direct(n - 1) }
}

/// CPS style โ€” always tail recursive.
/// We use Rc<dyn Fn> so continuations can be shared (cloned) in closures.
fn factorial_cps(n: u64, k: Rc<dyn Fn(u64) -> u64>) -> u64 {
    if n <= 1 {
        k(1)
    } else {
        let k2 = k.clone();
        factorial_cps(n - 1, Rc::new(move |result| k2(n * result)))
    }
}

// โ”€โ”€ Fibonacci โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€

/// Direct style
fn fibonacci_direct(n: u64) -> u64 {
    if n <= 1 { n } else { fibonacci_direct(n - 1) + fibonacci_direct(n - 2) }
}

/// CPS style โ€” Rc lets the continuation be shared between both branches
fn fibonacci_cps(n: u64, k: Rc<dyn Fn(u64) -> u64>) -> u64 {
    if n <= 1 {
        k(n)
    } else {
        let k2 = k.clone();
        fibonacci_cps(
            n - 1,
            Rc::new(move |a| {
                let k3 = k2.clone();
                fibonacci_cps(n - 2, Rc::new(move |b| k3(a + b)))
            }),
        )
    }
}

// โ”€โ”€ Sum of a list โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€

fn sum_list_cps(list: &[i64], acc: i64, k: &dyn Fn(i64) -> i64) -> i64 {
    if list.is_empty() {
        k(acc)
    } else {
        sum_list_cps(&list[1..], acc + list[0], k)
    }
}

// โ”€โ”€ Early exit: find first matching element โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€

/// CPS-style find: calls `found` on the first match, `not_found` if none.
/// Two continuations = two possible outcomes.
fn find_cps<T: Copy>(
    pred: &dyn Fn(T) -> bool,
    list: &[T],
    found: &dyn Fn(T) -> Option<T>,
    not_found: &dyn Fn() -> Option<T>,
) -> Option<T> {
    if list.is_empty() {
        not_found()
    } else if pred(list[0]) {
        found(list[0])
    } else {
        find_cps(pred, &list[1..], found, not_found)
    }
}

// โ”€โ”€ Map a list in CPS โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€

fn map_cps<A: Copy, B: Clone>(
    list: &[A],
    f: &dyn Fn(A) -> B,
    k: &dyn Fn(Vec<B>) -> Vec<B>,
) -> Vec<B> {
    k(list.iter().map(|&x| f(x)).collect())
}

// โ”€โ”€ CPS transform: "lift" any function into CPS โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€

fn lift_cps<A, B, R>(f: impl Fn(A) -> B, x: A, k: impl Fn(B) -> R) -> R {
    k(f(x))
}

// โ”€โ”€ Accumulate a list in CPS (tail-recursive) โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€

fn fold_cps<A: Copy, B: Clone>(
    list: &[A],
    init: B,
    f: &dyn Fn(B, A) -> B,
    k: &dyn Fn(B) -> B,
) -> B {
    if list.is_empty() {
        k(init)
    } else {
        let next = f(init, list[0]);
        fold_cps(&list[1..], next, f, k)
    }
}

fn main() {
    let id: Rc<dyn Fn(u64) -> u64> = Rc::new(|x| x);

    // โ”€โ”€ Factorial โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
    println!("5! direct = {}", factorial_direct(5));
    let r = factorial_cps(5, id.clone());
    println!("CPS 5! = {}", r);
    let r10 = factorial_cps(10, id.clone());
    println!("CPS 10! = {}", r10);

    println!();

    // โ”€โ”€ Fibonacci โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
    println!("fib(10) direct = {}", fibonacci_direct(10));
    let fib10 = fibonacci_cps(10, id.clone());
    println!("CPS fib(10) = {}", fib10);

    println!();

    // โ”€โ”€ Sum list โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
    let lst = vec![1_i64, 2, 3, 4, 5];
    let sum = sum_list_cps(&lst, 0, &|x| x);
    println!("sum [1..5] = {}", sum);

    println!();

    // โ”€โ”€ Find โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
    let nums = [1_i64, 3, 5, 8, 9, 12];
    let first_even = find_cps(
        &|x| x % 2 == 0,
        &nums,
        &|x| Some(x),
        &|| None,
    );
    println!("First even: {:?}", first_even);

    let gt100 = find_cps(
        &|x: i64| x > 100,
        &nums,
        &|x| Some(x),
        &|| None,
    );
    println!("First > 100: {:?}", gt100);

    println!();

    // โ”€โ”€ Lift โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
    lift_cps(|x: i32| x * x, 7, |r| println!("7ยฒ = {}", r));

    // โ”€โ”€ Map โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
    let squares = map_cps(&[1_i32, 2, 3, 4, 5], &|x| x * x, &|v| v);
    println!("squares: {:?}", squares);

    // โ”€โ”€ Fold โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
    let product = fold_cps(&[1_i32, 2, 3, 4, 5], 1, &|acc, x| acc * x, &|x| x);
    println!("product [1..5] = {}", product);
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::rc::Rc;

    #[test]
    fn test_factorial_cps_matches_direct() {
        let id: Rc<dyn Fn(u64) -> u64> = Rc::new(|x| x);
        for n in 1..=10u64 {
            let direct = factorial_direct(n);
            let cps = factorial_cps(n, id.clone());
            assert_eq!(direct, cps, "n={}", n);
        }
    }

    #[test]
    fn test_fibonacci_cps_matches_direct() {
        let id: Rc<dyn Fn(u64) -> u64> = Rc::new(|x| x);
        for n in 0..=10u64 {
            let direct = fibonacci_direct(n);
            let cps = fibonacci_cps(n, id.clone());
            assert_eq!(direct, cps, "n={}", n);
        }
    }

    #[test]
    fn test_sum_list_cps() {
        let result = sum_list_cps(&[1, 2, 3, 4, 5], 0, &|s| s);
        assert_eq!(result, 15);
    }

    #[test]
    fn test_find_cps_found() {
        let nums = [1_i64, 3, 4, 5];
        let result = find_cps(&|x| x % 2 == 0, &nums, &|x| Some(x), &|| None);
        assert_eq!(result, Some(4));
    }

    #[test]
    fn test_find_cps_not_found() {
        let nums = [1_i64, 3, 5];
        let result = find_cps(&|x| x % 2 == 0, &nums, &|x| Some(x), &|| None);
        assert_eq!(result, None);
    }

    #[test]
    fn test_lift_cps() {
        let result = lift_cps(|x: i32| x + 1, 41, |r| r);
        assert_eq!(result, 42);
    }

    #[test]
    fn test_fold_cps_product() {
        let r = fold_cps(&[1_i32, 2, 3, 4, 5], 1, &|acc, x| acc * x, &|x| x);
        assert_eq!(r, 120);
    }
}
(* Continuation-passing style (CPS): instead of returning a value,
   pass it to a continuation function 'k'. Enables:
   - Tail-call optimization for any recursion
   - First-class control flow
   - Building coroutines, generators *)

(* Direct style *)
let rec factorial_direct n =
  if n <= 1 then 1 else n * factorial_direct (n - 1)

(* CPS style: always tail recursive *)
let factorial_cps n k =
  let rec go n k =
    if n <= 1 then k 1
    else go (n - 1) (fun result -> k (n * result))
  in go n k

(* Fibonacci in CPS *)
let fibonacci_cps n k =
  let rec go n k =
    if n <= 1 then k n
    else
      go (n - 1) (fun a ->
        go (n - 2) (fun b ->
          k (a + b)))
  in go n k

(* Early exit: CPS allows "throwing" to a continuation *)
let find_cps pred lst found not_found =
  let rec go = function
    | []     -> not_found ()
    | x :: xs -> if pred x then found x else go xs
  in go lst

let () =
  Printf.printf "5! = %d\n" (factorial_direct 5);
  factorial_cps 5 (Printf.printf "CPS 5! = %d\n");
  fibonacci_cps 10 (Printf.printf "CPS fib(10) = %d\n");

  let lst = [1; 3; 5; 8; 9; 12] in
  find_cps (fun x -> x mod 2 = 0) lst
    (fun x -> Printf.printf "First even: %d\n" x)
    (fun () -> Printf.printf "No evens\n")