// Continuation-Passing Style (CPS) โ Transform Functions to Pass Results
//
// Instead of returning a value, a CPS function takes an extra argument `k`
// (the "continuation") and calls k(result) instead of returning.
//
// Benefits:
// * Always tail-recursive โ no stack growth for deep recursion
// * First-class control flow (early exit, backtracking)
// * Foundation for coroutines, async, and continuations
use std::rc::Rc;
// โโ Factorial โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
/// Direct style โ not tail recursive (stack grows with n)
fn factorial_direct(n: u64) -> u64 {
if n <= 1 { 1 } else { n * factorial_direct(n - 1) }
}
/// CPS style โ always tail recursive.
/// We use Rc<dyn Fn> so continuations can be shared (cloned) in closures.
fn factorial_cps(n: u64, k: Rc<dyn Fn(u64) -> u64>) -> u64 {
if n <= 1 {
k(1)
} else {
let k2 = k.clone();
factorial_cps(n - 1, Rc::new(move |result| k2(n * result)))
}
}
// โโ Fibonacci โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
/// Direct style
fn fibonacci_direct(n: u64) -> u64 {
if n <= 1 { n } else { fibonacci_direct(n - 1) + fibonacci_direct(n - 2) }
}
/// CPS style โ Rc lets the continuation be shared between both branches
fn fibonacci_cps(n: u64, k: Rc<dyn Fn(u64) -> u64>) -> u64 {
if n <= 1 {
k(n)
} else {
let k2 = k.clone();
fibonacci_cps(
n - 1,
Rc::new(move |a| {
let k3 = k2.clone();
fibonacci_cps(n - 2, Rc::new(move |b| k3(a + b)))
}),
)
}
}
// โโ Sum of a list โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
fn sum_list_cps(list: &[i64], acc: i64, k: &dyn Fn(i64) -> i64) -> i64 {
if list.is_empty() {
k(acc)
} else {
sum_list_cps(&list[1..], acc + list[0], k)
}
}
// โโ Early exit: find first matching element โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
/// CPS-style find: calls `found` on the first match, `not_found` if none.
/// Two continuations = two possible outcomes.
fn find_cps<T: Copy>(
pred: &dyn Fn(T) -> bool,
list: &[T],
found: &dyn Fn(T) -> Option<T>,
not_found: &dyn Fn() -> Option<T>,
) -> Option<T> {
if list.is_empty() {
not_found()
} else if pred(list[0]) {
found(list[0])
} else {
find_cps(pred, &list[1..], found, not_found)
}
}
// โโ Map a list in CPS โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
fn map_cps<A: Copy, B: Clone>(
list: &[A],
f: &dyn Fn(A) -> B,
k: &dyn Fn(Vec<B>) -> Vec<B>,
) -> Vec<B> {
k(list.iter().map(|&x| f(x)).collect())
}
// โโ CPS transform: "lift" any function into CPS โโโโโโโโโโโโโโโโโโโโโโโโโโโ
fn lift_cps<A, B, R>(f: impl Fn(A) -> B, x: A, k: impl Fn(B) -> R) -> R {
k(f(x))
}
// โโ Accumulate a list in CPS (tail-recursive) โโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
fn fold_cps<A: Copy, B: Clone>(
list: &[A],
init: B,
f: &dyn Fn(B, A) -> B,
k: &dyn Fn(B) -> B,
) -> B {
if list.is_empty() {
k(init)
} else {
let next = f(init, list[0]);
fold_cps(&list[1..], next, f, k)
}
}
fn main() {
let id: Rc<dyn Fn(u64) -> u64> = Rc::new(|x| x);
// โโ Factorial โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
println!("5! direct = {}", factorial_direct(5));
let r = factorial_cps(5, id.clone());
println!("CPS 5! = {}", r);
let r10 = factorial_cps(10, id.clone());
println!("CPS 10! = {}", r10);
println!();
// โโ Fibonacci โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
println!("fib(10) direct = {}", fibonacci_direct(10));
let fib10 = fibonacci_cps(10, id.clone());
println!("CPS fib(10) = {}", fib10);
println!();
// โโ Sum list โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
let lst = vec![1_i64, 2, 3, 4, 5];
let sum = sum_list_cps(&lst, 0, &|x| x);
println!("sum [1..5] = {}", sum);
println!();
// โโ Find โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
let nums = [1_i64, 3, 5, 8, 9, 12];
let first_even = find_cps(
&|x| x % 2 == 0,
&nums,
&|x| Some(x),
&|| None,
);
println!("First even: {:?}", first_even);
let gt100 = find_cps(
&|x: i64| x > 100,
&nums,
&|x| Some(x),
&|| None,
);
println!("First > 100: {:?}", gt100);
println!();
// โโ Lift โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
lift_cps(|x: i32| x * x, 7, |r| println!("7ยฒ = {}", r));
// โโ Map โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
let squares = map_cps(&[1_i32, 2, 3, 4, 5], &|x| x * x, &|v| v);
println!("squares: {:?}", squares);
// โโ Fold โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
let product = fold_cps(&[1_i32, 2, 3, 4, 5], 1, &|acc, x| acc * x, &|x| x);
println!("product [1..5] = {}", product);
}
#[cfg(test)]
mod tests {
use super::*;
use std::rc::Rc;
#[test]
fn test_factorial_cps_matches_direct() {
let id: Rc<dyn Fn(u64) -> u64> = Rc::new(|x| x);
for n in 1..=10u64 {
let direct = factorial_direct(n);
let cps = factorial_cps(n, id.clone());
assert_eq!(direct, cps, "n={}", n);
}
}
#[test]
fn test_fibonacci_cps_matches_direct() {
let id: Rc<dyn Fn(u64) -> u64> = Rc::new(|x| x);
for n in 0..=10u64 {
let direct = fibonacci_direct(n);
let cps = fibonacci_cps(n, id.clone());
assert_eq!(direct, cps, "n={}", n);
}
}
#[test]
fn test_sum_list_cps() {
let result = sum_list_cps(&[1, 2, 3, 4, 5], 0, &|s| s);
assert_eq!(result, 15);
}
#[test]
fn test_find_cps_found() {
let nums = [1_i64, 3, 4, 5];
let result = find_cps(&|x| x % 2 == 0, &nums, &|x| Some(x), &|| None);
assert_eq!(result, Some(4));
}
#[test]
fn test_find_cps_not_found() {
let nums = [1_i64, 3, 5];
let result = find_cps(&|x| x % 2 == 0, &nums, &|x| Some(x), &|| None);
assert_eq!(result, None);
}
#[test]
fn test_lift_cps() {
let result = lift_cps(|x: i32| x + 1, 41, |r| r);
assert_eq!(result, 42);
}
#[test]
fn test_fold_cps_product() {
let r = fold_cps(&[1_i32, 2, 3, 4, 5], 1, &|acc, x| acc * x, &|x| x);
assert_eq!(r, 120);
}
}
(* Continuation-passing style (CPS): instead of returning a value,
pass it to a continuation function 'k'. Enables:
- Tail-call optimization for any recursion
- First-class control flow
- Building coroutines, generators *)
(* Direct style *)
let rec factorial_direct n =
if n <= 1 then 1 else n * factorial_direct (n - 1)
(* CPS style: always tail recursive *)
let factorial_cps n k =
let rec go n k =
if n <= 1 then k 1
else go (n - 1) (fun result -> k (n * result))
in go n k
(* Fibonacci in CPS *)
let fibonacci_cps n k =
let rec go n k =
if n <= 1 then k n
else
go (n - 1) (fun a ->
go (n - 2) (fun b ->
k (a + b)))
in go n k
(* Early exit: CPS allows "throwing" to a continuation *)
let find_cps pred lst found not_found =
let rec go = function
| [] -> not_found ()
| x :: xs -> if pred x then found x else go xs
in go lst
let () =
Printf.printf "5! = %d\n" (factorial_direct 5);
factorial_cps 5 (Printf.printf "CPS 5! = %d\n");
fibonacci_cps 10 (Printf.printf "CPS fib(10) = %d\n");
let lst = [1; 3; 5; 8; 9; 12] in
find_cps (fun x -> x mod 2 = 0) lst
(fun x -> Printf.printf "First even: %d\n" x)
(fun () -> Printf.printf "No evens\n")