🦀 Functional Rust

066: Tree Map and Fold

Difficulty: 2
Level: Intermediate

Lift `map` and `fold` from lists to binary trees: once you have `fold_tree`, everything else is a one-liner.

The Problem This Solves

You've used `map` and `fold` on lists. Trees are everywhere in real programs (ASTs, file systems, expression trees, decision trees), and they need the same operations. The naive approach is to write a separate recursive function for each thing you want to compute: one for size, one for depth, one for sum, one for in-order traversal. By the third one you notice they all share the same structural pattern.

`fold_tree` captures that pattern once. You pass a function that says "given the current node's value and the results from the left and right subtrees, combine them." Size, depth, sum, and the traversals then become one-liners that differ only in the combining function, with no explicit recursion of their own. This is the catamorphism pattern applied to binary trees. Example 080 generalizes the same idea; this example makes it concrete on a familiar data structure first.
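For contrast, here is a minimal sketch of the naive approach: two hand-written recursive functions over the same `Tree` type used below. The names `size_rec` and `sum_rec` are illustrative only. The two functions differ only in the base value and the line that combines the results. That shared shape is the duplication `fold_tree` removes.

```rust
// Same Tree shape as in the example below.
#[derive(Debug)]
enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

impl<T> Tree<T> {
    fn node(v: T, l: Tree<T>, r: Tree<T>) -> Self {
        Tree::Node(v, Box::new(l), Box::new(r))
    }
}

// Hand-written recursion #1: count the nodes.
fn size_rec<T>(t: &Tree<T>) -> usize {
    match t {
        Tree::Leaf => 0,                                      // base value
        Tree::Node(_, l, r) => 1 + size_rec(l) + size_rec(r), // combine step
    }
}

// Hand-written recursion #2: identical traversal, different base value and combine step.
fn sum_rec(t: &Tree<i32>) -> i32 {
    match t {
        Tree::Leaf => 0,
        Tree::Node(v, l, r) => v + sum_rec(l) + sum_rec(r),
    }
}
```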

The Intuition

A binary tree is either a `Leaf` (empty) or a `Node` containing a value, a left subtree, and a right subtree.

`map_tree` transforms every node's value while keeping the tree structure: the shape is preserved and only the labels change. If you double every number in the tree, you get a tree with the same branching structure, with every number doubled.

`fold_tree` collapses the tree into a single value. It works bottom-up. It folds the left subtree (getting one result), folds the right subtree (getting another), then combines those results with the current node's value. The combining function `f(value, left_result, right_result)` is called once per `Node`:

  • Size: `f(_, l, r) = 1 + l + r`, counting this node (1) plus the left count plus the right count.
  • Depth: `f(_, l, r) = 1 + max(l, r)`, the current level (1) plus the deeper of the two sides.
  • Sum: `f(v, l, r) = v + l + r`, the current value plus the sums of both sides.

Each `Leaf` returns the initial accumulator, which is 0 for size, depth, and sum alike. It plays the same role as the initial value in a list fold.
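A small sketch to check two claims from the paragraph on the sample tree used in the tests below. First, `fold_tree` invokes the combining function exactly once per `Node` and never for a `Leaf`. Second, `map_tree` keeps the shape. The helper `depth_counting_calls` is hypothetical and exists only for this illustration. It uses `std::cell::Cell` so that an `Fn` closure can update a counter through a shared reference. `map_tree` and `fold_tree` repeat the example's definitions.

```rust
use std::cell::Cell;

#[derive(Debug, Clone, PartialEq)]
enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

impl<T> Tree<T> {
    fn node(v: T, l: Tree<T>, r: Tree<T>) -> Self {
        Tree::Node(v, Box::new(l), Box::new(r))
    }
}

// Same definitions as in the example.
fn map_tree<T, U>(tree: &Tree<T>, f: &impl Fn(&T) -> U) -> Tree<U> {
    match tree {
        Tree::Leaf => Tree::Leaf,
        Tree::Node(v, l, r) => Tree::node(f(v), map_tree(l, f), map_tree(r, f)),
    }
}

fn fold_tree<T, A: Clone>(tree: &Tree<T>, acc: A, f: &impl Fn(&T, A, A) -> A) -> A {
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            let left = fold_tree(l, acc.clone(), f);
            let right = fold_tree(r, acc, f);
            f(v, left, right)
        }
    }
}

// Hypothetical helper: computes the depth and also counts how many times
// the combining function ran. Returns (depth, calls).
fn depth_counting_calls<T>(t: &Tree<T>) -> (usize, usize) {
    let calls = Cell::new(0);
    let depth = fold_tree(t, 0, &|_, l: usize, r: usize| {
        calls.set(calls.get() + 1); // runs once per Node, never for a Leaf
        1 + l.max(r)
    });
    (depth, calls.get())
}
```

On the sample tree this gives a depth of 3 and 5 calls, one per node. The doubled tree produced by `map_tree` gives the same pair, because only the labels changed.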

How It Works in Rust

#[derive(Debug, Clone, PartialEq)]
pub enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

impl<T> Tree<T> {
    // Convenience constructor so callers don't write Box::new everywhere.
    pub fn node(v: T, l: Tree<T>, r: Tree<T>) -> Self {
        Tree::Node(v, Box::new(l), Box::new(r))
    }
}

pub fn map_tree<T, U>(tree: &Tree<T>, f: &impl Fn(&T) -> U) -> Tree<U> {
    match tree {
        Tree::Leaf => Tree::Leaf,
        Tree::Node(v, l, r) => Tree::node(f(v), map_tree(l, f), map_tree(r, f)),
    }
}

pub fn fold_tree<T, A: Clone>(
    tree: &Tree<T>,
    acc: A,                     // base value returned for every Leaf
    f: &impl Fn(&T, A, A) -> A, // combine: (node_value, left_result, right_result)
) -> A {
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            let left = fold_tree(l, acc.clone(), f); // left subtree gets a clone
            let right = fold_tree(r, acc, f);        // right subtree takes the original
            f(v, left, right)
        }
    }
}

// Every derived operation is a fold one-liner:
pub fn size<T>(t: &Tree<T>) -> usize  { fold_tree(t, 0, &|_, l, r| 1 + l + r) }
pub fn depth<T>(t: &Tree<T>) -> usize { fold_tree(t, 0, &|_, l, r| 1 + l.max(r)) }
pub fn sum(t: &Tree<i32>) -> i32      { fold_tree(t, 0, &|v, l, r| v + l + r) }
The `acc.clone()` in `fold_tree` is necessary because the base value is needed in both subtrees. The left call receives a clone, and the right call takes the original by move. This is the main Rust-vs-OCaml difference: OCaml's GC lets both recursive calls share the same `acc` value implicitly.
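If the `Clone` bound is unwanted, one common alternative is to pass the base case as a function instead of a value, so every `Leaf` constructs a fresh result. This is a sketch of that alternative, not the example's API. `fold_tree_with`, `Count`, and `size_no_clone` are hypothetical names, and `Count` deliberately does not implement `Clone`.

```rust
#[derive(Debug)]
enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

impl<T> Tree<T> {
    fn node(v: T, l: Tree<T>, r: Tree<T>) -> Self {
        Tree::Node(v, Box::new(l), Box::new(r))
    }
}

// Hypothetical variant of fold_tree: the base case is a function, so each Leaf
// builds its own value and no `A: Clone` bound is required.
fn fold_tree_with<T, A>(
    tree: &Tree<T>,
    leaf: &impl Fn() -> A,
    combine: &impl Fn(&T, A, A) -> A,
) -> A {
    match tree {
        Tree::Leaf => leaf(),
        Tree::Node(v, l, r) => {
            let left = fold_tree_with(l, leaf, combine);
            let right = fold_tree_with(r, leaf, combine);
            combine(v, left, right)
        }
    }
}

// A deliberately non-Clone result type, to show the bound is really gone.
struct Count(usize);

fn size_no_clone<T>(t: &Tree<T>) -> usize {
    let Count(n) = fold_tree_with(t, &|| Count(0), &|_, l: Count, r: Count| Count(1 + l.0 + r.0));
    n
}
```

The trade-off is where the base value gets produced. Here `leaf` runs once per `Leaf`, which is n + 1 calls for a tree with n nodes. The `Clone` version instead clones the base value once per `Node`.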

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Tree type | `type 'a tree = Leaf \| Node of 'a * 'a tree * 'a tree` | `enum Tree<T> { Leaf, Node(T, Box<Tree<T>>, Box<Tree<T>>) }` |
| Box | Not needed: the GC manages recursive types | Required: `Box<Tree<T>>` gives the enum a known size |
| Fold accumulator sharing | GC shares the initial `acc` value freely | Must `.clone()`: a copy for the left subtree, the original for the right |
| Currying | `fold_tree f acc tree`, naturally curried over three arguments | Closure passed as `&impl Fn(&T, A, A) -> A`; no partial application, all arguments supplied at once |
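In OCaml, `fold_tree (fun _ l r -> 1 + l + r) 0` is already a function waiting for a tree. Rust has no implicit partial application. A function can, however, return a closure that has already captured the base value and the combining function. The sketch below follows that approach. `folder` is a hypothetical helper, and `fold_tree` repeats the example's definition so the block compiles on its own.

```rust
#[derive(Debug)]
enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

impl<T> Tree<T> {
    fn node(v: T, l: Tree<T>, r: Tree<T>) -> Self {
        Tree::Node(v, Box::new(l), Box::new(r))
    }
}

// Same fold as in the example.
fn fold_tree<T, A: Clone>(tree: &Tree<T>, acc: A, f: &impl Fn(&T, A, A) -> A) -> A {
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            let left = fold_tree(l, acc.clone(), f);
            let right = fold_tree(r, acc, f);
            f(v, left, right)
        }
    }
}

// Hypothetical helper emulating partial application: fix the base value and the
// combining function now, supply the tree later.
fn folder<T, A, F>(acc: A, f: F) -> impl Fn(&Tree<T>) -> A
where
    A: Clone,
    F: Fn(&T, A, A) -> A,
{
    // The returned closure owns `acc` and `f` and clones the base value per call.
    move |t: &Tree<T>| fold_tree(t, acc.clone(), &f)
}
```

Usage would look like `let size = folder(0usize, |_: &i32, l: usize, r: usize| 1 + l + r);` followed by `size(&tree)`. The annotations on the closure parameters are needed so that the element type is known when `folder` is called.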
//! Map and Fold on Trees
//!
//! Lifting map and fold from lists to binary trees. Once you define
//! `fold_tree`, you can express size, depth, sum, and traversals
//! without any explicit recursion: the fold does it all.

#[derive(Debug, Clone, PartialEq)]
pub enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

impl<T> Tree<T> {
    pub fn node(v: T, l: Tree<T>, r: Tree<T>) -> Self {
        Tree::Node(v, Box::new(l), Box::new(r))
    }
}

/// Map a function over every node value, producing a new tree.
pub fn map_tree<T, U>(tree: &Tree<T>, f: &impl Fn(&T) -> U) -> Tree<U> {
    match tree {
        Tree::Leaf => Tree::Leaf,
        Tree::Node(v, l, r) => Tree::node(f(v), map_tree(l, f), map_tree(r, f)),
    }
}

/// Fold (catamorphism) on a tree. The function `f` receives the node value
/// and the results of folding the left and right subtrees.
pub fn fold_tree<T, A>(
    tree: &Tree<T>,
    acc: A,
    f: &impl Fn(&T, A, A) -> A,
) -> A
where
    A: Clone,
{
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            let left = fold_tree(l, acc.clone(), f);
            let right = fold_tree(r, acc, f);
            f(v, left, right)
        }
    }
}

/// All derived via fold: no explicit recursion needed.
pub fn size<T>(t: &Tree<T>) -> usize {
    fold_tree(t, 0, &|_, l, r| 1 + l + r)
}

pub fn depth<T>(t: &Tree<T>) -> usize {
    fold_tree(t, 0, &|_, l, r| 1 + l.max(r))
}

pub fn sum(t: &Tree<i32>) -> i32 {
    fold_tree(t, 0, &|v, l, r| v + l + r)
}

pub fn preorder<T: Clone>(t: &Tree<T>) -> Vec<T> {
    fold_tree(t, vec![], &|v, l, r| {
        let mut result = vec![v.clone()];
        result.extend(l);
        result.extend(r);
        result
    })
}

pub fn inorder<T: Clone>(t: &Tree<T>) -> Vec<T> {
    fold_tree(t, vec![], &|v, l, r| {
        let mut result = l;
        result.push(v.clone());
        result.extend(r);
        result
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use Tree::*;

    fn sample() -> Tree<i32> {
        //      4
        //     / \
        //    2   6
        //   / \
        //  1   3
        Tree::node(4, Tree::node(2, Tree::node(1, Leaf, Leaf), Tree::node(3, Leaf, Leaf)), Tree::node(6, Leaf, Leaf))
    }

    #[test]
    fn test_size() {
        assert_eq!(size(&sample()), 5);
        assert_eq!(size::<i32>(&Leaf), 0);
    }

    #[test]
    fn test_depth() {
        assert_eq!(depth(&sample()), 3);
        assert_eq!(depth::<i32>(&Leaf), 0);
    }

    #[test]
    fn test_sum() {
        assert_eq!(sum(&sample()), 16);
        assert_eq!(sum(&Leaf), 0);
    }

    #[test]
    fn test_preorder() {
        assert_eq!(preorder(&sample()), vec![4, 2, 1, 3, 6]);
    }

    #[test]
    fn test_inorder() {
        assert_eq!(inorder(&sample()), vec![1, 2, 3, 4, 6]);
    }

    #[test]
    fn test_map_tree() {
        let doubled = map_tree(&sample(), &|v| v * 2);
        assert_eq!(sum(&doubled), 32);
        assert_eq!(preorder(&doubled), vec![8, 4, 2, 6, 12]);
    }

    #[test]
    fn test_single_node() {
        let t = Tree::node(42, Leaf, Leaf);
        assert_eq!(size(&t), 1);
        assert_eq!(sum(&t), 42);
        assert_eq!(preorder(&t), vec![42]);
    }
}

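fn main() {
    // `sample()` lives in the test module, so build the same tree here.
    let t = Tree::node(
        4,
        Tree::node(2, Tree::node(1, Tree::Leaf, Tree::Leaf), Tree::node(3, Tree::Leaf, Tree::Leaf)),
        Tree::node(6, Tree::Leaf, Tree::Leaf),
    );
    println!("{}", size(&t));                 // 5
    println!("{}", size::<i32>(&Tree::Leaf)); // 0
    println!("{}", depth(&t));                // 3
}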
type 'a tree =
  | Leaf
  | Node of 'a * 'a tree * 'a tree

let rec map_tree f = function
  | Leaf           -> Leaf
  | Node (v, l, r) -> Node (f v, map_tree f l, map_tree f r)

let rec fold_tree f acc = function
  | Leaf           -> acc
  | Node (v, l, r) -> f v (fold_tree f acc l) (fold_tree f acc r)

let size     t = fold_tree (fun _ l r -> 1 + l + r)    0  t
let depth    t = fold_tree (fun _ l r -> 1 + max l r)  0  t
let sum      t = fold_tree (fun v l r -> v + l + r)    0  t
let preorder t = fold_tree (fun v l r -> [v] @ l @ r) [] t
let inorder  t = fold_tree (fun v l r -> l @ [v] @ r) [] t

let t =
  Node (4, Node (2, Node (1, Leaf, Leaf), Node (3, Leaf, Leaf)),
           Node (6, Leaf, Leaf))

let () =
  assert (size t = 5);
  assert (depth t = 3);
  assert (sum t = 16);
  assert (preorder t = [4; 2; 1; 3; 6]);
  assert (inorder t = [1; 2; 3; 4; 6]);
  let t2 = map_tree (fun v -> v * 2) t in
  assert (sum t2 = 32);
  print_endline "All assertions passed."

📊 Detailed Comparison

Map and Fold on Trees: OCaml vs Rust

The Core Insight

Once you can fold a tree, you can express almost any tree computation as a one-liner. This example shows how the catamorphism pattern (replacing constructors with functions) works identically in both languages. Rust's ownership model adds friction around accumulator cloning and closure references.
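"Replacing constructors with functions" can be made concrete. `fold_tree` substitutes the base value for every `Leaf` and the combining function for every `Node`. Substituting the constructors themselves rebuilds an equal tree. Substituting "apply `g`, then build a `Node`" recovers `map_tree`. `copy_via_fold` and `map_via_fold` are illustrative names; `fold_tree` and `map_tree` mirror the example's definitions.

```rust
#[derive(Debug, Clone, PartialEq)]
enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

impl<T> Tree<T> {
    fn node(v: T, l: Tree<T>, r: Tree<T>) -> Self {
        Tree::Node(v, Box::new(l), Box::new(r))
    }
}

fn map_tree<T, U>(tree: &Tree<T>, f: &impl Fn(&T) -> U) -> Tree<U> {
    match tree {
        Tree::Leaf => Tree::Leaf,
        Tree::Node(v, l, r) => Tree::node(f(v), map_tree(l, f), map_tree(r, f)),
    }
}

fn fold_tree<T, A: Clone>(tree: &Tree<T>, acc: A, f: &impl Fn(&T, A, A) -> A) -> A {
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            let left = fold_tree(l, acc.clone(), f);
            let right = fold_tree(r, acc, f);
            f(v, left, right)
        }
    }
}

// Leaf -> Leaf, Node -> Node: the fold rebuilds an equal tree.
fn copy_via_fold<T: Clone>(t: &Tree<T>) -> Tree<T> {
    fold_tree(t, Tree::Leaf, &|v: &T, l, r| Tree::node(v.clone(), l, r))
}

// Leaf -> Leaf, Node -> "apply g to the value, then Node": the fold becomes map_tree.
fn map_via_fold<T, U: Clone>(t: &Tree<T>, g: &impl Fn(&T) -> U) -> Tree<U> {
    fold_tree(t, Tree::Leaf, &|v: &T, l, r| Tree::node(g(v), l, r))
}
```

Here the accumulator type `A` is itself a tree, so the `Clone` bound falls on `Tree<T>`. Cloning the base value only ever copies a `Leaf`, which is cheap.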

OCaml Approach

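OCaml's `fold_tree` takes a function `f : 'a -> 'b -> 'b -> 'b` and a base value, recursing over the tree structure. Thanks to currying, each derived operation is just `fold_tree` applied to a combining function and a base value, as in `fold_tree (fun _ l r -> 1 + l + r) 0 t`. The code keeps the explicit `t` parameter. A point-free `let size = fold_tree (fun _ l r -> 1 + l + r) 0` would only be weakly polymorphic because of OCaml's value restriction. The GC handles all intermediate lists created by `preorder` and `inorder`, so `[v] @ l @ r` can allocate freely. Pattern matching with the `function` keyword keeps the code terse.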

Rust Approach

Rust's `fold_tree` needs `A: Clone` because the base value must be passed to both subtrees, and ownership can't be in two places at once. Closures are passed as `&impl Fn(...)` references so that both recursive calls can share the same closure without moving it. The `vec!` macro and the `extend` method replace OCaml's `@` list append. The code is slightly more verbose, but every allocation is explicit.
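A common alternative API shape lets callers pass closures without a leading `&`. The public function takes `f: F` by value and delegates to an inner helper that borrows it. `fold_tree_by_value` and `go` are hypothetical names. Taking `impl Fn` by value and recursing with `&f` directly would not work well. Each level would instantiate the function at a new type (`F`, then `&F`, then `&&F`, and so on) until rustc hits its recursion limit. That is why the recursion below goes through a single `&F`.

```rust
#[derive(Debug)]
enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

impl<T> Tree<T> {
    fn node(v: T, l: Tree<T>, r: Tree<T>) -> Self {
        Tree::Node(v, Box::new(l), Box::new(r))
    }
}

// Hypothetical by-value wrapper: the caller hands over the closure once...
fn fold_tree_by_value<T, A, F>(tree: &Tree<T>, acc: A, f: F) -> A
where
    A: Clone,
    F: Fn(&T, A, A) -> A,
{
    // ...and the recursion borrows it, so every level uses the same type `&F`.
    fn go<T, A: Clone, F: Fn(&T, A, A) -> A>(tree: &Tree<T>, acc: A, f: &F) -> A {
        match tree {
            Tree::Leaf => acc,
            Tree::Node(v, l, r) => {
                let left = go(l, acc.clone(), f);
                let right = go(r, acc, f);
                f(v, left, right)
            }
        }
    }
    go(tree, acc, &f)
}
```

Usage then reads `fold_tree_by_value(&t, 0, |_, l: usize, r: usize| 1 + l + r)`. Behavior is the same as the example's `fold_tree`; only the call site changes.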

Side-by-Side

| Concept | OCaml | Rust |
|---|---|---|
| Fold signature | `('a -> 'b -> 'b -> 'b) -> 'b -> 'a tree -> 'b` | `(&Tree<T>, A, &impl Fn(&T, A, A) -> A) -> A` |
| Accumulator | Passed freely (GC) | Requires `Clone` bound |
| List append | `@` operator | `extend()` method |
| Closure passing | Implicit currying | `&impl Fn(...)` reference |
| Derived operations | One-liners via fold | One-liners via fold |
| Memory | GC handles intermediates | Explicit `Vec` allocation |

What Rust Learners Should Notice

  • The `Clone` bound on the accumulator is the price of ownership: both subtrees need their own copy of the base case
  • `&impl Fn(...)` borrows the closure instead of taking ownership, so `fold_tree` can hand the same closure to both recursive calls
  • Rust's `vec![]` + `extend` is the idiomatic way to build up collections, replacing OCaml's `@` list concatenation
  • The catamorphism pattern is universal: once you define fold for any data type, you unlock compositional programming
  • Intermediate Vecs in preorder/inorder are allocated on the heap; in performance-critical code, you'd use a mutable accumulator instead
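The mutable-accumulator alternative from the last point can be sketched as follows. `inorder_into` and `inorder_fast` are hypothetical names. Instead of building and concatenating a `Vec` per node, the recursion pushes into one `Vec` that the caller owns. The trade-off is that this version uses explicit recursion rather than `fold_tree`.

```rust
#[derive(Debug)]
enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

impl<T> Tree<T> {
    fn node(v: T, l: Tree<T>, r: Tree<T>) -> Self {
        Tree::Node(v, Box::new(l), Box::new(r))
    }
}

// Appends the values of `t` to `out` in in-order (left, node, right).
// One Vec is shared by the whole traversal, so no per-node Vecs are built.
fn inorder_into<T: Clone>(t: &Tree<T>, out: &mut Vec<T>) {
    if let Tree::Node(v, l, r) = t {
        inorder_into(l, out);
        out.push(v.clone());
        inorder_into(r, out);
    }
    // A Leaf contributes nothing.
}

fn inorder_fast<T: Clone>(t: &Tree<T>) -> Vec<T> {
    let mut out = Vec::new();
    inorder_into(t, &mut out);
    out
}
```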

Further Reading

  • [The Rust Book β€” Closures](https://doc.rust-lang.org/book/ch13-01-closures.html)
  • [OCaml Beyond Lists](https://cs3110.github.io/textbook/chapters/hop/beyond_lists.html)