// Example 215: Recursion Schemes Intro — Separating Recursion from Logic
// === Direct recursion: recursion mixed with logic ===
/// A plain recursive arithmetic expression tree.
///
/// Recursion is baked directly into the type: `Add` and `Mul` own boxed
/// sub-expressions, so every function over `Expr` has to re-implement the
/// traversal itself. `PartialEq`/`Eq` are derived so whole trees can be
/// compared structurally (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
enum Expr {
    /// An integer literal.
    Lit(i64),
    /// Sum of two sub-expressions.
    Add(Box<Expr>, Box<Expr>),
    /// Product of two sub-expressions.
    Mul(Box<Expr>, Box<Expr>),
}
// Approach 1: Every function reimplements the same traversal
/// Evaluates an [`Expr`] by hand-written structural recursion.
///
/// The left operand is evaluated before the right one. Arithmetic is plain
/// `i64` `+`/`*`, so overflow follows Rust's default rules (panic when
/// overflow checks are enabled, wrap-around otherwise).
fn eval(e: &Expr) -> i64 {
    match e {
        Expr::Lit(value) => *value,
        Expr::Add(lhs, rhs) => {
            let (left, right) = (eval(lhs), eval(rhs));
            left + right
        }
        Expr::Mul(lhs, rhs) => {
            let (left, right) = (eval(lhs), eval(rhs));
            left * right
        }
    }
}
/// Renders an [`Expr`] as a fully parenthesised infix string, e.g. `(1 + (2 * 3))`.
///
/// Every binary node gets its own pair of parentheses; literals are printed
/// with `i64`'s `Display` formatting.
fn show(e: &Expr) -> String {
    match e {
        Expr::Lit(value) => value.to_string(),
        Expr::Add(lhs, rhs) => {
            let (left, right) = (show(lhs), show(rhs));
            format!("({} + {})", left, right)
        }
        Expr::Mul(lhs, rhs) => {
            let (left, right) = (show(lhs), show(rhs));
            format!("({} * {})", left, right)
        }
    }
}
/// Height of an [`Expr`] tree counted in operator nodes.
///
/// A lone literal has depth 0; each `Add`/`Mul` adds one level on top of its
/// deeper child.
fn depth(e: &Expr) -> usize {
    match e {
        Expr::Lit(_) => 0,
        Expr::Add(lhs, rhs) | Expr::Mul(lhs, rhs) => {
            let deepest_child = std::cmp::max(depth(lhs), depth(rhs));
            deepest_child + 1
        }
    }
}
// Approach 2: Factor out recursion with a non-recursive functor
/// One layer of an expression, with the recursive positions abstracted as `A`.
///
/// `ExprF` is not recursive: the children of `AddF`/`MulF` are whatever `A`
/// is. Plugging in [`Fix`] rebuilds a whole tree; plugging in a result type
/// (`i64`, `String`, ...) yields a layer whose children were already folded.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ExprF<A> {
    /// An integer literal; carries no children.
    LitF(i64),
    /// Sum of two child slots.
    AddF(A, A),
    /// Product of two child slots.
    MulF(A, A),
}

impl<A> ExprF<A> {
    /// Functor map over the child slots, consuming the layer.
    ///
    /// Literals are rebuilt unchanged. `f` is applied to the left child and
    /// then to the right one (Rust evaluates operands left to right). It is
    /// `FnMut`, so stateful closures (counters, accumulators) are accepted as
    /// well as plain `Fn` closures and functions.
    fn map<B>(self, mut f: impl FnMut(A) -> B) -> ExprF<B> {
        match self {
            ExprF::LitF(n) => ExprF::LitF(n),
            ExprF::AddF(a, b) => ExprF::AddF(f(a), f(b)),
            ExprF::MulF(a, b) => ExprF::MulF(f(a), f(b)),
        }
    }

    /// Like [`ExprF::map`], but borrows the layer and passes `&A` to `f`.
    ///
    /// Children are visited left to right; `f` may be `FnMut`.
    fn map_ref<B>(&self, mut f: impl FnMut(&A) -> B) -> ExprF<B> {
        match self {
            ExprF::LitF(n) => ExprF::LitF(*n),
            ExprF::AddF(a, b) => ExprF::AddF(f(a), f(b)),
            ExprF::MulF(a, b) => ExprF::MulF(f(a), f(b)),
        }
    }
}
// Fix point
/// Fixed point of [`ExprF`]: a node holding one heap-allocated layer whose
/// children are themselves `Fix` nodes, which ties the recursive knot.
#[derive(Debug, Clone)]
struct Fix(Box<ExprF<Fix>>);

impl Fix {
    /// Borrows the single layer stored in this node.
    fn unfix(&self) -> &ExprF<Fix> {
        let Fix(layer) = self;
        layer
    }
}
// cata: the universal fold
/// Catamorphism: the universal fold over a [`Fix`] tree.
///
/// All of the recursion lives here. `cata` borrows the node's single layer,
/// folds every child with itself, and hands the resulting layer of child
/// answers to `alg`. An algebra therefore only describes one layer of logic.
///
/// The algebra is a generic `F` rather than `&dyn Fn`, so named functions and
/// closures are monomorphised and can be inlined at every node. `F` is
/// `?Sized`, so a `&dyn Fn(ExprF<A>) -> A` argument is still accepted.
///
/// Recursion depth equals tree depth, so extremely deep trees can exhaust
/// the stack.
fn cata<A, F>(alg: &F, fix: &Fix) -> A
where
    F: ?Sized + Fn(ExprF<A>) -> A,
{
    alg(fix.unfix().map_ref(|child| cata(alg, child)))
}
// Algebras — NO recursion, just one layer of logic
/// Evaluation algebra: combines two already-evaluated child values.
///
/// Contains no recursion; [`cata`] supplies the children's results.
fn eval_alg(e: ExprF<i64>) -> i64 {
    match e {
        ExprF::AddF(lhs, rhs) => lhs + rhs,
        ExprF::MulF(lhs, rhs) => lhs * rhs,
        ExprF::LitF(value) => value,
    }
}
/// Pretty-printing algebra: wraps already-rendered children in parentheses.
///
/// Produces the same fully parenthesised format as the direct `show`.
fn show_alg(e: ExprF<String>) -> String {
    match e {
        ExprF::AddF(lhs, rhs) => format!("({} + {})", lhs, rhs),
        ExprF::MulF(lhs, rhs) => format!("({} * {})", lhs, rhs),
        ExprF::LitF(value) => value.to_string(),
    }
}
/// Depth algebra: literals are depth 0, operators sit one level above
/// their deeper child.
fn depth_alg(e: ExprF<usize>) -> usize {
    match e {
        ExprF::AddF(lhs, rhs) | ExprF::MulF(lhs, rhs) => std::cmp::max(lhs, rhs) + 1,
        ExprF::LitF(_) => 0,
    }
}
// Approach 3: Helpers to build Fix expressions
/// Builds a leaf [`Fix`] node holding the literal `n`.
fn lit(n: i64) -> Fix {
    let layer = ExprF::LitF(n);
    Fix(Box::new(layer))
}

/// Builds the [`Fix`] node for `a + b`.
fn add(a: Fix, b: Fix) -> Fix {
    let layer = ExprF::AddF(a, b);
    Fix(Box::new(layer))
}

/// Builds the [`Fix`] node for `a * b`.
fn mul(a: Fix, b: Fix) -> Fix {
    let layer = ExprF::MulF(a, b);
    Fix(Box::new(layer))
}
/// Runs both approaches on the same expression, `1 + 2 * 3`, and checks
/// that they agree.
fn main() {
    // Approach 1: hand-written recursion over `Expr`.
    let direct = Expr::Add(
        Box::new(Expr::Lit(1)),
        Box::new(Expr::Mul(Box::new(Expr::Lit(2)), Box::new(Expr::Lit(3)))),
    );
    assert_eq!(eval(&direct), 7);
    assert_eq!(show(&direct), "(1 + (2 * 3))");
    assert_eq!(depth(&direct), 2);

    // Approach 2: the same tree as `Fix` layers, folded by `cata`.
    let layered = add(lit(1), mul(lit(2), lit(3)));
    assert_eq!(cata(&eval_alg, &layered), 7);
    assert_eq!(cata(&show_alg, &layered), "(1 + (2 * 3))");
    assert_eq!(cata(&depth_alg, &layered), 2);

    // Extra algebras only spell out one layer; `cata` does the traversal.
    fn count_lits(layer: ExprF<usize>) -> usize {
        match layer {
            ExprF::LitF(_) => 1,
            ExprF::AddF(lhs, rhs) | ExprF::MulF(lhs, rhs) => lhs + rhs,
        }
    }
    fn count_ops(layer: ExprF<usize>) -> usize {
        match layer {
            ExprF::LitF(_) => 0,
            ExprF::AddF(lhs, rhs) | ExprF::MulF(lhs, rhs) => lhs + rhs + 1,
        }
    }
    assert_eq!(cata(&count_lits, &layered), 3);
    assert_eq!(cata(&count_ops, &layered), 2);

    println!("✓ All tests passed");
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_eval_cata() {
        let e = mul(add(lit(2), lit(3)), lit(4));
        assert_eq!(cata(&eval_alg, &e), 20);
    }

    #[test]
    fn test_show_cata() {
        let e = add(lit(1), lit(2));
        assert_eq!(cata(&show_alg, &e), "(1 + 2)");
    }

    /// Covers the `MulF` branch of `show_alg` and nested parenthesisation.
    #[test]
    fn test_show_cata_mul() {
        let e = mul(lit(2), add(lit(3), lit(4)));
        assert_eq!(cata(&show_alg, &e), "(2 * (3 + 4))");
    }

    #[test]
    fn test_depth_cata() {
        let e = add(add(lit(1), lit(2)), lit(3));
        assert_eq!(cata(&depth_alg, &e), 2);
    }

    #[test]
    fn test_single_lit() {
        assert_eq!(cata(&eval_alg, &lit(42)), 42);
        assert_eq!(cata(&depth_alg, &lit(42)), 0);
    }

    /// Negative literals flow through both evaluation and rendering.
    #[test]
    fn test_negative_literal() {
        let e = mul(lit(-3), lit(5));
        assert_eq!(cata(&eval_alg, &e), -15);
        assert_eq!(cata(&show_alg, &e), "(-3 * 5)");
    }

    /// The direct-recursion functions and the catamorphism must agree on the
    /// same tree; this is the core claim of the example.
    #[test]
    fn test_direct_matches_cata() {
        let direct = Expr::Mul(
            Box::new(Expr::Add(Box::new(Expr::Lit(2)), Box::new(Expr::Lit(3)))),
            Box::new(Expr::Lit(4)),
        );
        let layered = mul(add(lit(2), lit(3)), lit(4));
        assert_eq!(eval(&direct), 20);
        assert_eq!(show(&direct), "((2 + 3) * 4)");
        assert_eq!(depth(&direct), 2);
        assert_eq!(eval(&direct), cata(&eval_alg, &layered));
        assert_eq!(show(&direct), cata(&show_alg, &layered));
        assert_eq!(depth(&direct), cata(&depth_alg, &layered));
    }

    /// A user-defined closure algebra is accepted by `cata`.
    #[test]
    fn test_custom_closure_algebra() {
        let count_lits = |e: ExprF<usize>| match e {
            ExprF::LitF(_) => 1,
            ExprF::AddF(a, b) | ExprF::MulF(a, b) => a + b,
        };
        let count_ops = |e: ExprF<usize>| match e {
            ExprF::LitF(_) => 0,
            ExprF::AddF(a, b) | ExprF::MulF(a, b) => 1 + a + b,
        };
        let e = mul(add(lit(2), lit(3)), lit(4));
        assert_eq!(cata(&count_lits, &e), 3);
        assert_eq!(cata(&count_ops, &e), 2);
    }
}
(* Example 215: Recursion Schemes Intro — Separating Recursion from Logic *)
(* THE PROBLEM: Recursion is scattered through business logic.
Every function that processes a tree re-implements the same traversal.
Solution: Factor out the recursion pattern. *)
(* === A simple expression type === *)
(* Plain recursive expression tree: integer literals plus binary addition
   and multiplication. [Add] and [Mul] hold [expr] children directly, so
   every function over [expr] has to spell out its own recursion. *)
type expr =
| Lit of int
| Add of expr * expr
| Mul of expr * expr
(* Approach 1: Direct recursion — recursion mixed with logic *)
(* Evaluate an [expr] by hand-written structural recursion.
   Arithmetic is native [int], so overflow wraps silently. *)
let rec eval e =
  match e with
  | Lit value -> value
  | Add (lhs, rhs) -> eval lhs + eval rhs
  | Mul (lhs, rhs) -> eval lhs * eval rhs
(* Render an [expr] as a fully parenthesised infix string,
   e.g. "(1 + (2 * 3))". *)
let rec show e =
  let binary lhs op rhs = "(" ^ show lhs ^ op ^ show rhs ^ ")" in
  match e with
  | Lit value -> string_of_int value
  | Add (lhs, rhs) -> binary lhs " + " rhs
  | Mul (lhs, rhs) -> binary lhs " * " rhs
(* Height of an [expr] in operator nodes: a literal is 0, each operator
   adds one level above its deeper child. *)
let rec depth e =
  match e with
  | Lit _ -> 0
  | Add (lhs, rhs) | Mul (lhs, rhs) ->
      let deepest = max (depth lhs) (depth rhs) in
      deepest + 1
(* Notice: every function has the SAME recursive structure!
The only thing that differs is the "algebra" — what to do at each node. *)
(* Approach 2: Factor out recursion with a non-recursive functor *)
(* One layer of an expression, with the child positions abstracted as ['a].
   The type is not recursive: [fix] below plugs it back into itself, while
   an algebra sees ['a] as the children's already-computed results. *)
type 'a expr_f =
| LitF of int
| AddF of 'a * 'a
| MulF of 'a * 'a
(* The functor's map *)
(* Apply [f] to every child slot of one layer; literals are rebuilt as-is.
   OCaml leaves the evaluation order of constructor/tuple arguments
   unspecified (in practice often right to left), so the explicit [let]s
   guarantee [f] visits the left child before the right one. That keeps
   effectful [f]/algebras deterministic and matches the Rust version. *)
let map_expr_f (f : 'a -> 'b) : 'a expr_f -> 'b expr_f = function
  | LitF n -> LitF n
  | AddF (a, b) ->
      let a' = f a in
      let b' = f b in
      AddF (a', b')
  | MulF (a, b) ->
      let a' = f a in
      let b' = f b in
      MulF (a', b')
(* Fix point: tie the recursive knot *)
(* A [fix] wraps exactly one [expr_f] layer whose children are again [fix]
   values, so arbitrarily deep trees are built from the non-recursive layer. *)
type fix = Fix of fix expr_f
(* Peel the [Fix] wrapper off, exposing the single layer it holds. *)
let unfix = function
  | Fix layer -> layer
(* cata: the universal fold *)
(* Peel one layer, fold every child with [cata alg], then let [alg] combine
   the layer of child results. All of the recursion lives here. *)
let rec cata (alg : 'a expr_f -> 'a) (node : fix) : 'a =
  match node with
  | Fix layer -> layer |> map_expr_f (cata alg) |> alg
(* Now define eval, show, depth as ALGEBRAS — no recursion! *)
(* Evaluation algebra: combine two already-evaluated children. *)
let eval_alg layer =
  match layer with
  | LitF value -> value
  | AddF (lhs, rhs) -> lhs + rhs
  | MulF (lhs, rhs) -> lhs * rhs
(* Printing algebra: parenthesise two already-rendered children. *)
let show_alg layer =
  let binary lhs op rhs = "(" ^ lhs ^ op ^ rhs ^ ")" in
  match layer with
  | LitF value -> string_of_int value
  | AddF (lhs, rhs) -> binary lhs " + " rhs
  | MulF (lhs, rhs) -> binary lhs " * " rhs
(* Depth algebra: literals are 0, operators sit one above the deeper child. *)
let depth_alg layer =
  match layer with
  | LitF _ -> 0
  | AddF (lhs, rhs) | MulF (lhs, rhs) -> succ (max lhs rhs)
(* Approach 3: Compare side by side *)
(* The direct-recursion functions re-expressed as folds with an algebra. *)
let eval2 e = cata eval_alg e
let show2 e = cata show_alg e
let depth2 e = cata depth_alg e
(* Helper to build fix-point expressions *)
(* Smart constructors wrapping a single layer in [Fix]. *)
let lit = fun value -> Fix (LitF value)
let add = fun lhs rhs -> Fix (AddF (lhs, rhs))
let mul = fun lhs rhs -> Fix (MulF (lhs, rhs))
(* === Tests === *)
let () =
  (* Approach 1: direct recursion over the plain [expr] tree. *)
  let direct = Add (Lit 1, Mul (Lit 2, Lit 3)) in
  assert (eval direct = 7);
  assert (show direct = "(1 + (2 * 3))");
  assert (depth direct = 2);
  (* Approach 2: the same tree built from [Fix] layers, folded with [cata]. *)
  let layered = add (lit 1) (mul (lit 2) (lit 3)) in
  assert (eval2 layered = 7);
  assert (show2 layered = "(1 + (2 * 3))");
  assert (depth2 layered = 2);
  (* New algebras only describe one layer; [cata] does the traversal. *)
  let count_lits layer =
    match layer with
    | LitF _ -> 1
    | AddF (l, r) | MulF (l, r) -> l + r
  in
  assert (cata count_lits layered = 3);
  let count_ops layer =
    match layer with
    | LitF _ -> 0
    | AddF (l, r) | MulF (l, r) -> l + r + 1
  in
  assert (cata count_ops layered = 2);
  print_endline "✓ All tests passed"