Map and Fold on Trees
Tutorial
The Problem
Implement map_tree and fold_tree as the two fundamental higher-order operations on binary trees, then derive size, depth, sum, preorder, and inorder from fold_tree alone — without writing any additional recursion. fold_tree is the tree catamorphism: it captures the complete recursive structure of the type, replacing each Leaf with a base value and each Node with a combining function. Understanding it reveals why higher-order functions eliminate boilerplate and why functional programmers reach for folds as their first tool when reducing a data structure.
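To make the "replace each constructor" reading concrete, here is a small sketch in which the base value is Leaf itself and the combining function is the Node constructor, so the fold rebuilds its input unchanged. Tree and fold_tree repeat the definitions from Full Source so the block compiles on its own; rebuild and leaf are hypothetical helpers introduced only for this illustration.

```rust
#[derive(Debug, Clone, PartialEq)]
pub enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

// Same fold as in this module: `acc` stands in for every Leaf, `f` for every Node.
pub fn fold_tree<T, U, F>(tree: Tree<T>, acc: U, f: &F) -> U
where
    F: Fn(T, U, U) -> U,
    U: Clone,
{
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            let l_result = fold_tree(*l, acc.clone(), f);
            let r_result = fold_tree(*r, acc, f);
            f(v, l_result, r_result)
        }
    }
}

// Hypothetical helper: a single-node tree.
pub fn leaf<T>(v: T) -> Tree<T> {
    Tree::Node(v, Box::new(Tree::Leaf), Box::new(Tree::Leaf))
}

// Hypothetical helper: replace Leaf with Leaf and Node with Node (the identity fold).
// T: Clone is needed because U = Tree<T> must satisfy fold_tree's U: Clone bound.
pub fn rebuild<T: Clone>(tree: Tree<T>) -> Tree<T> {
    fold_tree(tree, Tree::Leaf, &|v, l, r| Tree::Node(v, Box::new(l), Box::new(r)))
}
```

Every other fold in this tutorial follows the same pattern with the two constructors swapped for something more useful: 0 and 1 + l + r for size, an empty vector and concatenation for preorder.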
🎯 Learning Outcomes
- How fold_tree completely characterizes the Tree type's recursive structure
- How to write size, depth, sum, preorder, and inorder as one-liners using a single fold with no additional recursion
- Why U: Clone is required in Rust's fold_tree but not in OCaml's version, and how the garbage collector changes the picture
- How map_tree preserves the tree's shape while transforming values, making Tree a functor over its element type
- Why Rust passes &F through recursive calls to avoid consuming the closure on its first use

Code Example
```rust
pub fn map_tree<T, U, F: Fn(T) -> U>(tree: Tree<T>, f: &F) -> Tree<U> {
    match tree {
        Tree::Leaf => Tree::Leaf,
        Tree::Node(v, l, r) => Tree::Node(
            f(v),
            Box::new(map_tree(*l, f)),
            Box::new(map_tree(*r, f)),
        ),
    }
}

pub fn fold_tree<T, U: Clone, F: Fn(T, U, U) -> U>(tree: Tree<T>, acc: U, f: &F) -> U {
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            let left = fold_tree(*l, acc.clone(), f);
            let right = fold_tree(*r, acc, f);
            f(v, left, right)
        }
    }
}
```

Key Differences
- Rust needs U: Clone because acc must be duplicated for the two independent subtree folds; OCaml passes the same value to both, since the GC tracks all references.
- Rust threads &F through every recursive call so the closure is not consumed; OCaml closures are heap-allocated values that can always be shared.
- Both languages define size, depth, sum, preorder, and inorder as one-liners over fold_tree; the expressive power is identical and only the syntax differs.
- Rust's fold_tree consumes the tree (tree: Tree<T>), transferring ownership down through the recursion; OCaml's fold leaves the input tree intact and reusable, and the GC reclaims it once it becomes unreachable.

OCaml Approach
OCaml's fold_tree f acc is a curried let rec function: it takes f and acc explicitly, and function pattern-matches on the tree, an implicit third argument. On a Leaf it returns acc; on a Node (v, l, r) it computes f v (fold_tree f acc l) (fold_tree f acc r), calling the user's function with the node value and the results of folding both subtrees. The same acc is passed to both subtrees without copying, because the GC manages value lifetimes. All five derived operations are single fold_tree applications, for example size t = fold_tree (fun _ l r -> 1 + l + r) 0 t, so none of them needs recursion of its own.
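The same no-extra-recursion pattern carries over to Rust and reaches beyond the five operations above. As a sketch, here are two more operations written as single fold_tree calls: mirror (swap the children of every node) and max_value (largest element, None for an empty tree). Both names, and the leaf helper, are hypothetical additions rather than part of this module; Tree and fold_tree repeat the Full Source definitions so the block stands alone.

```rust
#[derive(Debug, Clone, PartialEq)]
pub enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

pub fn fold_tree<T, U, F>(tree: Tree<T>, acc: U, f: &F) -> U
where
    F: Fn(T, U, U) -> U,
    U: Clone,
{
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            let l_result = fold_tree(*l, acc.clone(), f);
            let r_result = fold_tree(*r, acc, f);
            f(v, l_result, r_result)
        }
    }
}

// Hypothetical helper: a single-node tree.
pub fn leaf<T>(v: T) -> Tree<T> {
    Tree::Node(v, Box::new(Tree::Leaf), Box::new(Tree::Leaf))
}

// Hypothetical: rebuild each node with its two folded subtrees swapped.
pub fn mirror<T: Clone>(tree: Tree<T>) -> Tree<T> {
    fold_tree(tree, Tree::Leaf, &|v, l, r| Tree::Node(v, Box::new(r), Box::new(l)))
}

// Hypothetical: largest value, or None for an empty tree.
// Option<i32> orders None below every Some(_), so Ord::max keeps the largest present value.
pub fn max_value(tree: Tree<i32>) -> Option<i32> {
    fold_tree(tree, None, &|v: i32, l: Option<i32>, r: Option<i32>| {
        Some(v).max(l).max(r)
    })
}
```

Using None as the base value for max_value avoids inventing a fake minimum: a base of 0 would wrongly report 0 for a tree containing only negative numbers.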
Full Source
```rust
#![allow(clippy::all)]
//! Map and Fold on Trees
//! See example.ml for OCaml reference

#[derive(Debug, Clone, PartialEq)]
pub enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

/// Apply `f` to every node value, producing a new tree of the same shape.
/// Mirrors OCaml's `let rec map_tree f = function | Leaf -> Leaf | Node(v,l,r) -> Node(f v, ...)`.
pub fn map_tree<T, U, F: Fn(T) -> U>(tree: Tree<T>, f: &F) -> Tree<U> {
    match tree {
        Tree::Leaf => Tree::Leaf,
        Tree::Node(v, l, r) => Tree::Node(
            f(v),
            Box::new(map_tree(*l, f)),
            Box::new(map_tree(*r, f)),
        ),
    }
}

/// Structural fold: reduce a tree to a single value by combining each node's value
/// with the results of folding its two subtrees.
/// `f(v, left_result, right_result)`: both subtrees fold with the same initial `acc`.
pub fn fold_tree<T, U, F>(tree: Tree<T>, acc: U, f: &F) -> U
where
    F: Fn(T, U, U) -> U,
    U: Clone,
{
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            // The left fold gets a clone of `acc`; the right fold takes the original.
            let l_result = fold_tree(*l, acc.clone(), f);
            let r_result = fold_tree(*r, acc, f);
            f(v, l_result, r_result)
        }
    }
}

pub fn size<T>(tree: Tree<T>) -> usize {
    fold_tree(tree, 0usize, &|_, l, r| 1 + l + r)
}

pub fn depth<T>(tree: Tree<T>) -> usize {
    fold_tree(tree, 0usize, &|_, l, r| 1 + l.max(r))
}

pub fn sum(tree: Tree<i32>) -> i32 {
    fold_tree(tree, 0i32, &|v, l, r| v + l + r)
}

// T: Clone because fold_tree requires U = Vec<T> to be Clone,
// and Vec<T> is Clone only when T is.
pub fn preorder<T: Clone>(tree: Tree<T>) -> Vec<T> {
    fold_tree(tree, vec![], &|v, l, r| {
        let mut result = vec![v];
        result.extend(l);
        result.extend(r);
        result
    })
}

pub fn inorder<T: Clone>(tree: Tree<T>) -> Vec<T> {
    fold_tree(tree, vec![], &|v, l, r| {
        let mut result = l;
        result.push(v);
        result.extend(r);
        result
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use Tree::{Leaf, Node};

    //     4
    //    / \
    //   2   6
    //  / \
    // 1   3
    fn sample() -> Tree<i32> {
        Node(
            4,
            Box::new(Node(
                2,
                Box::new(Node(1, Box::new(Leaf), Box::new(Leaf))),
                Box::new(Node(3, Box::new(Leaf), Box::new(Leaf))),
            )),
            Box::new(Node(6, Box::new(Leaf), Box::new(Leaf))),
        )
    }

    #[test]
    fn test_size() {
        assert_eq!(size(sample()), 5);
        assert_eq!(size(Leaf::<i32>), 0);
    }

    #[test]
    fn test_depth() {
        assert_eq!(depth(sample()), 3);
        assert_eq!(depth(Leaf::<i32>), 0);
    }

    #[test]
    fn test_sum() {
        assert_eq!(sum(sample()), 16); // 1+2+3+4+6
        assert_eq!(sum(Leaf), 0);
    }

    #[test]
    fn test_preorder() {
        assert_eq!(preorder(sample()), vec![4, 2, 1, 3, 6]);
    }

    #[test]
    fn test_inorder() {
        assert_eq!(inorder(sample()), vec![1, 2, 3, 4, 6]);
    }

    #[test]
    fn test_map_tree() {
        let doubled = map_tree(sample(), &|v| v * 2);
        assert_eq!(sum(doubled), 32); // 2+4+6+8+12
    }

    #[test]
    fn test_map_preserves_shape() {
        let t = map_tree(sample(), &|v| v.to_string());
        assert_eq!(preorder(t), vec!["4", "2", "1", "3", "6"]);
    }
}
```
Deep Comparison
OCaml vs Rust: Map and Fold on Trees
Side-by-Side Code
OCaml
```ocaml
(* Assumed type definition, matching the Rust enum *)
type 'a tree = Leaf | Node of 'a * 'a tree * 'a tree

let rec map_tree f = function
  | Leaf -> Leaf
  | Node (v, l, r) -> Node (f v, map_tree f l, map_tree f r)

let rec fold_tree f acc = function
  | Leaf -> acc
  | Node (v, l, r) -> f v (fold_tree f acc l) (fold_tree f acc r)

(* Derived operations: zero additional recursion *)
let size t = fold_tree (fun _ l r -> 1 + l + r) 0 t
let depth t = fold_tree (fun _ l r -> 1 + max l r) 0 t
let sum t = fold_tree (fun v l r -> v + l + r) 0 t
let preorder t = fold_tree (fun v l r -> [v] @ l @ r) [] t
let inorder t = fold_tree (fun v l r -> l @ [v] @ r) [] t
```
Rust (idiomatic)
```rust
pub fn map_tree<T, U, F: Fn(T) -> U>(tree: Tree<T>, f: &F) -> Tree<U> {
    match tree {
        Tree::Leaf => Tree::Leaf,
        Tree::Node(v, l, r) => Tree::Node(
            f(v),
            Box::new(map_tree(*l, f)),
            Box::new(map_tree(*r, f)),
        ),
    }
}

pub fn fold_tree<T, U: Clone, F: Fn(T, U, U) -> U>(tree: Tree<T>, acc: U, f: &F) -> U {
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            let left = fold_tree(*l, acc.clone(), f);
            let right = fold_tree(*r, acc, f);
            f(v, left, right)
        }
    }
}
```
Rust (functional/recursive)
```rust
// Derived from fold_tree: same pattern as OCaml, zero extra recursion
pub fn tree_size<T>(t: Tree<T>) -> usize {
    fold_tree(t, 0usize, &|_, l, r| 1 + l + r)
}

pub fn tree_sum(t: Tree<i32>) -> i32 {
    fold_tree(t, 0, &|v, l, r| v + l + r)
}
```
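Because these signatures take Tree<T> by value, each fold consumes its tree, so computing both the size and the sum of one tree requires a clone. One possible alternative, sketched here under the assumption that the combining function only needs to read node values, is a borrowing fold that hands f a &T. The names fold_tree_ref, tree_size_ref, and tree_sum_ref are hypothetical and not part of this module.

```rust
#[derive(Debug, Clone, PartialEq)]
pub enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

// Hypothetical borrowing fold: the tree is only read, so it stays usable afterwards.
pub fn fold_tree_ref<T, U, F>(tree: &Tree<T>, acc: U, f: &F) -> U
where
    F: Fn(&T, U, U) -> U,
    U: Clone,
{
    match tree {
        Tree::Leaf => acc,
        // Matching on &Tree<T> binds v: &T and l, r: &Box<Tree<T>>;
        // deref coercion turns &Box<Tree<T>> into &Tree<T> at each recursive call.
        Tree::Node(v, l, r) => {
            let l_result = fold_tree_ref(l, acc.clone(), f);
            let r_result = fold_tree_ref(r, acc, f);
            f(v, l_result, r_result)
        }
    }
}

pub fn tree_size_ref<T>(t: &Tree<T>) -> usize {
    fold_tree_ref(t, 0usize, &|_, l, r| 1 + l + r)
}

pub fn tree_sum_ref(t: &Tree<i32>) -> i32 {
    // v is &i32 here, so it is dereferenced before adding.
    fold_tree_ref(t, 0i32, &|v, l, r| *v + l + r)
}
```

The consuming fold_tree remains the better fit when the result takes ownership of node values, as map_tree does; the borrowing form suits read-only aggregations such as size and sum.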
Type Signatures
| Concept | OCaml | Rust |
|---|---|---|
| map | val map_tree : ('a -> 'b) -> 'a tree -> 'b tree | fn map_tree<T, U, F: Fn(T) -> U>(tree: Tree<T>, f: &F) -> Tree<U> |
| fold | val fold_tree : ('a -> 'b -> 'b -> 'b) -> 'b -> 'a tree -> 'b | fn fold_tree<T, U: Clone, F: Fn(T, U, U) -> U>(tree: Tree<T>, acc: U, f: &F) -> U |
| combining fn | 'a -> 'b -> 'b -> 'b | Fn(T, U, U) -> U |
| accumulator | 'b (passed to both subtrees without copying) | U: Clone (cloned once so each subtree fold gets its own value) |
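The U: Clone bound in the accumulator row comes from reusing a single acc value for every Leaf. One way around it, shown as a sketch rather than as this module's API, is to take the base case as a closure and call it once per Leaf. The hypothetical fold_tree_with below drops the bound, so a preorder built on it no longer needs T: Clone; Token is a made-up element type that deliberately does not implement Clone.

```rust
// No Clone derive: neither Tree<Token> nor Vec<Token> implements Clone.
#[derive(Debug, PartialEq)]
pub enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

// Hypothetical element type that deliberately does not implement Clone.
#[derive(Debug, PartialEq)]
pub struct Token(pub u32);

// Hypothetical variant of fold_tree: `base` builds a fresh result for each Leaf,
// so U needs no Clone bound.
pub fn fold_tree_with<T, U, B, F>(tree: Tree<T>, base: &B, f: &F) -> U
where
    B: Fn() -> U,
    F: Fn(T, U, U) -> U,
{
    match tree {
        Tree::Leaf => base(),
        Tree::Node(v, l, r) => {
            let l_result = fold_tree_with(*l, base, f);
            let r_result = fold_tree_with(*r, base, f);
            f(v, l_result, r_result)
        }
    }
}

// Preorder without a T: Clone bound: node values are moved into the output vector.
pub fn preorder_owned<T>(tree: Tree<T>) -> Vec<T> {
    fold_tree_with(tree, &|| Vec::new(), &|v: T, l: Vec<T>, r: Vec<T>| {
        let mut out = vec![v];
        out.extend(l);
        out.extend(r);
        out
    })
}
```

base runs once per Leaf, so the cost is comparable to cloning acc; what changes is which result and element types the fold accepts.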
Key Insights
- **Clone for the accumulator:** In OCaml, acc is a GC-managed value passed to both subtrees freely, with no copying at the language level. In Rust, passing acc to the left subtree would move it, leaving nothing for the right subtree. The U: Clone bound and the explicit acc.clone() call make this sharing explicit.
- **Passing &F:** OCaml closures are heap values that can be shared by reference automatically. Rust closures are owned values. Passing f: &F through recursive calls avoids consuming the closure at the first recursive call; with f: F instead, that first call would move f and the second call would fail to compile.
- **Catamorphism:** fold_tree is the tree catamorphism: it completely characterizes the tree's recursive structure. Any function on Tree<T> whose result at a node depends only on the node's value and the results for its two subtrees can be written as a single fold_tree call. OCaml's size, depth, sum, preorder, and inorder are all one-liners, and the same holds in Rust.
- **Ownership transfer:** Rust's map_tree takes ownership of the source Tree<T> and produces a Tree<U>. This is the natural functional style: no shared mutable state and no aliasing. OCaml's GC reclaims the old tree once it becomes unreachable; Rust frees each old Box as map_tree moves its contents out, with no collector involved.
- **Functor:** map_tree makes Tree a functor over its element type, a concept from category theory that OCaml expresses through map conventions and Rust approximates with generic functions (there is no Functor trait in std).

When to Use Each Style
**Use fold_tree when:** You need any aggregation over the tree (sum, count, depth, traversal list) — derive it as a one-liner rather than writing a new recursive function each time.
**Use map_tree when:** Transforming values while preserving the tree's shape — it is the structural equivalent of List.map for lists.
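The functor claim rests on two laws: mapping the identity function changes nothing, and mapping f and then g equals a single map of their composition. The sketch below spot-checks both laws on one sample tree, and adds a hypothetical shape helper that maps every value to () so two trees can be compared by structure alone. sample, shape, and functor_laws_hold are illustrative names, not part of this module, and a check on one tree is evidence rather than proof.

```rust
#[derive(Debug, Clone, PartialEq)]
pub enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

pub fn map_tree<T, U, F: Fn(T) -> U>(tree: Tree<T>, f: &F) -> Tree<U> {
    match tree {
        Tree::Leaf => Tree::Leaf,
        Tree::Node(v, l, r) => Tree::Node(
            f(v),
            Box::new(map_tree(*l, f)),
            Box::new(map_tree(*r, f)),
        ),
    }
}

// Hypothetical helper: forget the values, keep only the structure.
pub fn shape<T>(tree: Tree<T>) -> Tree<()> {
    map_tree(tree, &|_| ())
}

// Hypothetical sample: 2 with children 1 and 3.
pub fn sample() -> Tree<i32> {
    let leaf = |v| Tree::Node(v, Box::new(Tree::Leaf), Box::new(Tree::Leaf));
    Tree::Node(2, Box::new(leaf(1)), Box::new(leaf(3)))
}

// Spot-check both functor laws on the sample tree.
pub fn functor_laws_hold() -> bool {
    let f = |x: i32| x + 1;
    let g = |x: i32| x * 10;
    // Identity: mapping x -> x returns an equal tree.
    let identity = map_tree(sample(), &|x| x) == sample();
    // Composition: mapping f then g equals one map of g(f(x)).
    let composition = map_tree(map_tree(sample(), &f), &g) == map_tree(sample(), &|x| g(f(x)));
    identity && composition
}
```

The laws hold in general because map_tree emits exactly one Node for each Node and one Leaf for each Leaf, touching only the stored values; shape(map_tree(t, &f)) == shape(t) is the same fact stated for any f, including one that changes the element type.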
Exercises
1. Write filter_tree, which replaces every node whose value fails a predicate with Leaf (discarding that node's subtrees), and verify that the result never has more nodes than the original, and strictly fewer whenever at least one node fails.
2. Write zip_tree, which takes two trees and combines them node by node into a Tree<(A, B)>, returning None if their shapes differ at any point.
3. Write a flatten function, implemented purely via fold_tree, that collects all node values into a Vec<T> in preorder, then write a test confirming its output matches the preorder function defined in this module.