1003 Intermediate

Map and Fold on Trees

Functional Programming

Tutorial

The Problem

Implement map_tree and fold_tree as the two fundamental higher-order operations on binary trees, then derive size, depth, sum, preorder, and inorder from fold_tree alone — without writing any additional recursion. fold_tree is the tree catamorphism: it captures the complete recursive structure of the type, replacing each Leaf with a base value and each Node with a combining function. Understanding it reveals why higher-order functions eliminate boilerplate and why functional programmers reach for folds as their first tool when reducing a data structure.
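The phrase "replacing each Leaf with a base value and each Node with a combining function" can be taken literally. The sketch below repeats the tutorial's Tree and fold_tree so that it compiles on its own. It also adds three hypothetical helpers, node, sample, and rebuild. If Leaf is the base value and a Node-building closure is the combining function, the fold rebuilds a tree equal to its input.

```rust
// Sketch only: Tree and fold_tree mirror the tutorial's definitions;
// node, sample, and rebuild are hypothetical helpers for this demo.
#[derive(Debug, Clone, PartialEq)]
enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

fn fold_tree<T, U: Clone, F: Fn(T, U, U) -> U>(tree: Tree<T>, acc: U, f: &F) -> U {
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            let left = fold_tree(*l, acc.clone(), f);
            let right = fold_tree(*r, acc, f);
            f(v, left, right)
        }
    }
}

// Convenience constructor that boxes both children.
fn node<T>(v: T, l: Tree<T>, r: Tree<T>) -> Tree<T> {
    Tree::Node(v, Box::new(l), Box::new(r))
}

// The tutorial's sample tree: 4 with children 2 (which has 1 and 3) and 6.
fn sample() -> Tree<i32> {
    node(
        4,
        node(2, node(1, Tree::Leaf, Tree::Leaf), node(3, Tree::Leaf, Tree::Leaf)),
        node(6, Tree::Leaf, Tree::Leaf),
    )
}

// Leaf is replaced by Leaf and Node by node, so the fold rebuilds
// a tree equal to its input, following the type's recursive structure.
fn rebuild<T: Clone>(tree: Tree<T>) -> Tree<T> {
    fold_tree(tree, Tree::Leaf, &|v, l, r| node(v, l, r))
}

fn main() {
    println!("{}", rebuild(sample()) == sample()); // prints "true"
}
```

Any other choice of base value and combining function yields a different reduction of the same traversal.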

🎯 Learning Outcomes

  • What a catamorphism is and why fold_tree completely characterizes the Tree type's recursive structure
  • How to derive size, depth, sum, preorder, and inorder as one-liners using a single fold with no additional recursion
  • Why U: Clone is required in Rust's fold_tree but not in OCaml's version, and how the garbage collector changes the picture
  • How map_tree preserves the tree's shape while transforming values, making Tree a functor over its element type
  • Why closures in Rust must be passed as &F through recursive calls to avoid consuming the closure on the first use

Code Example

    pub fn map_tree<T, U, F: Fn(T) -> U>(tree: Tree<T>, f: &F) -> Tree<U> {
        match tree {
            Tree::Leaf => Tree::Leaf,
            Tree::Node(v, l, r) => Tree::Node(
                f(v),
                Box::new(map_tree(*l, f)),
                Box::new(map_tree(*r, f)),
            ),
        }
    }
    
    pub fn fold_tree<T, U: Clone, F: Fn(T, U, U) -> U>(tree: Tree<T>, acc: U, f: &F) -> U {
        match tree {
            Tree::Leaf => acc,
            Tree::Node(v, l, r) => {
                let left  = fold_tree(*l, acc.clone(), f);
                let right = fold_tree(*r, acc, f);
                f(v, left, right)
            }
        }
    }
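fold_tree passes a clone of acc to every empty subtree, so the base value reaches each Leaf position. The sketch below uses that to count empty positions with a hypothetical count_leaves helper. It sets acc = 1 and uses a combining function that ignores the node value. For any binary tree the count is one more than the number of nodes. Tree, map_tree, and fold_tree are repeated from above, and node and sample are hypothetical helpers.

```rust
// Sketch only: definitions repeated from the tutorial so this compiles alone.
#[derive(Debug, Clone, PartialEq)]
enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

fn map_tree<T, U, F: Fn(T) -> U>(tree: Tree<T>, f: &F) -> Tree<U> {
    match tree {
        Tree::Leaf => Tree::Leaf,
        Tree::Node(v, l, r) => Tree::Node(f(v), Box::new(map_tree(*l, f)), Box::new(map_tree(*r, f))),
    }
}

fn fold_tree<T, U: Clone, F: Fn(T, U, U) -> U>(tree: Tree<T>, acc: U, f: &F) -> U {
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            let left = fold_tree(*l, acc.clone(), f);
            let right = fold_tree(*r, acc, f);
            f(v, left, right)
        }
    }
}

fn node<T>(v: T, l: Tree<T>, r: Tree<T>) -> Tree<T> {
    Tree::Node(v, Box::new(l), Box::new(r))
}

fn sample() -> Tree<i32> {
    node(
        4,
        node(2, node(1, Tree::Leaf, Tree::Leaf), node(3, Tree::Leaf, Tree::Leaf)),
        node(6, Tree::Leaf, Tree::Leaf),
    )
}

// Hypothetical helper: every Leaf contributes 1; every Node adds its children's counts.
fn count_leaves<T>(tree: Tree<T>) -> usize {
    fold_tree(tree, 1usize, &|_, l, r| l + r)
}

fn main() {
    // 5 nodes, so 6 empty positions.
    println!("{}", count_leaves(sample())); // prints "6"
    // map_tree keeps the shape, so the Leaf count is unchanged.
    println!("{}", count_leaves(map_tree(sample(), &|v| v * 10))); // prints "6"
}
```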

    Key Differences

  • Clone requirement: Rust requires U: Clone because acc must be duplicated for the two independent subtree folds; OCaml shares the same value freely since the GC tracks all references
  • Closure ownership: Rust passes &F through every recursive call to prevent consuming the closure; OCaml closures are heap-allocated reference values and always shareable
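  • Derived operations: Both languages express size, depth, sum, preorder, and inorder as one-liners over fold_tree. The expressive power is the same. Rust adds explicit Clone bounds (for example T: Clone on preorder and inorder, so that Vec<T> can be cloned as an accumulator) and passes the closure by reference
  • Tree ownership: Rust's fold_tree consumes the tree (tree: Tree<T>), moving ownership down through the recursion. OCaml's fold leaves the original tree intact, because immutable values are shared by reference and the GC reclaims them once they are unreachable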
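Because fold_tree consumes its input, a caller that still needs the tree must clone it first. One alternative is sketched below as a hypothetical fold_tree_ref, which is not part of the tutorial's code. It borrows the tree and hands &T to the combining function. U: Clone is still required, because acc must still reach both subtrees. The size_ref, sum_ref, node, and sample helpers are also hypothetical.

```rust
// Sketch only: a borrowing fold, assumed here as an alternative design.
#[derive(Debug, Clone, PartialEq)]
enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

// Walks &Tree<T>; the combining function receives &T instead of T.
fn fold_tree_ref<T, U: Clone, F: Fn(&T, U, U) -> U>(tree: &Tree<T>, acc: U, f: &F) -> U {
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            // l and r are &Box<Tree<T>>; deref coercion yields &Tree<T>.
            let left = fold_tree_ref(l, acc.clone(), f);
            let right = fold_tree_ref(r, acc, f);
            f(v, left, right)
        }
    }
}

fn size_ref<T>(tree: &Tree<T>) -> usize {
    fold_tree_ref(tree, 0usize, &|_, l, r| 1 + l + r)
}

fn sum_ref(tree: &Tree<i32>) -> i32 {
    fold_tree_ref(tree, 0i32, &|v, l, r| *v + l + r)
}

fn node<T>(v: T, l: Tree<T>, r: Tree<T>) -> Tree<T> {
    Tree::Node(v, Box::new(l), Box::new(r))
}

fn sample() -> Tree<i32> {
    node(
        4,
        node(2, node(1, Tree::Leaf, Tree::Leaf), node(3, Tree::Leaf, Tree::Leaf)),
        node(6, Tree::Leaf, Tree::Leaf),
    )
}

fn main() {
    let t = sample();
    // Both folds borrow t, so it is still owned afterwards.
    println!("size = {}, sum = {}", size_ref(&t), sum_ref(&t)); // prints "size = 5, sum = 16"
    assert_eq!(t, sample());
}
```

The trade-off is that the combining function sees references, so derived operations such as preorder would need to clone values out of the tree themselves.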

    OCaml Approach

    OCaml's fold_tree f acc is a curried let rec function. It names two arguments, f and acc, and uses function to pattern-match on a third, the tree. On a Leaf it returns acc. On a Node (v, l, r) it computes f v (fold_tree f acc l) (fold_tree f acc r), calling the user's function with the node value and the results of folding both subtrees. The same acc is passed to both subtrees without copying, because OCaml values are immutable and the GC manages their lifetime. All five derived operations are single fold_tree applications, for example size t = fold_tree (fun _ l r -> 1 + l + r) 0 t, and none of them needs any additional recursion.

    Full Source

    #![allow(clippy::all)]
    //! Map and Fold on Trees
    //! See example.ml for OCaml reference
    
    #[derive(Debug, Clone, PartialEq)]
    pub enum Tree<T> {
        Leaf,
        Node(T, Box<Tree<T>>, Box<Tree<T>>),
    }
    
    /// Apply `f` to every node value, producing a new tree of the same shape.
    /// Mirrors OCaml's `let rec map_tree f = function | Leaf -> Leaf | Node(v,l,r) -> Node(f v, ...)`.
    pub fn map_tree<T, U, F: Fn(T) -> U>(tree: Tree<T>, f: &F) -> Tree<U> {
        match tree {
            Tree::Leaf => Tree::Leaf,
            Tree::Node(v, l, r) => Tree::Node(
                f(v),
                Box::new(map_tree(*l, f)),
                Box::new(map_tree(*r, f)),
            ),
        }
    }
    
    /// Structural fold: reduce a tree to a single value by combining each node's value
    /// with the results of folding its two subtrees.
    /// `f(v, left_result, right_result)` — both subtrees fold with the same initial `acc`.
    pub fn fold_tree<T, U, F>(tree: Tree<T>, acc: U, f: &F) -> U
    where
        F: Fn(T, U, U) -> U,
        U: Clone,
    {
        match tree {
            Tree::Leaf => acc,
            Tree::Node(v, l, r) => {
                let l_result = fold_tree(*l, acc.clone(), f);
                let r_result = fold_tree(*r, acc, f);
                f(v, l_result, r_result)
            }
        }
    }
    
    pub fn size<T>(tree: Tree<T>) -> usize {
        fold_tree(tree, 0usize, &|_, l, r| 1 + l + r)
    }
    
    pub fn depth<T>(tree: Tree<T>) -> usize {
        fold_tree(tree, 0usize, &|_, l, r| 1 + l.max(r))
    }
    
    pub fn sum(tree: Tree<i32>) -> i32 {
        fold_tree(tree, 0i32, &|v, l, r| v + l + r)
    }
    
    pub fn preorder<T: Clone>(tree: Tree<T>) -> Vec<T> {
        fold_tree(tree, vec![], &|v, l, r| {
            let mut result = vec![v];
            result.extend(l);
            result.extend(r);
            result
        })
    }
    
    pub fn inorder<T: Clone>(tree: Tree<T>) -> Vec<T> {
        fold_tree(tree, vec![], &|v, l, r| {
            let mut result = l;
            result.push(v);
            result.extend(r);
            result
        })
    }
    
    #[cfg(test)]
    mod tests {
        use super::*;
        use Tree::{Leaf, Node};
    
        //      4
        //     / \
        //    2   6
        //   / \
        //  1   3
        fn sample() -> Tree<i32> {
            Node(
                4,
                Box::new(Node(
                    2,
                    Box::new(Node(1, Box::new(Leaf), Box::new(Leaf))),
                    Box::new(Node(3, Box::new(Leaf), Box::new(Leaf))),
                )),
                Box::new(Node(6, Box::new(Leaf), Box::new(Leaf))),
            )
        }
    
        #[test]
        fn test_size() {
            assert_eq!(size(sample()), 5);
            assert_eq!(size(Leaf::<i32>), 0);
        }
    
        #[test]
        fn test_depth() {
            assert_eq!(depth(sample()), 3);
            assert_eq!(depth(Leaf::<i32>), 0);
        }
    
        #[test]
        fn test_sum() {
            assert_eq!(sum(sample()), 16); // 1+2+3+4+6
            assert_eq!(sum(Leaf), 0);
        }
    
        #[test]
        fn test_preorder() {
            assert_eq!(preorder(sample()), vec![4, 2, 1, 3, 6]);
        }
    
        #[test]
        fn test_inorder() {
            assert_eq!(inorder(sample()), vec![1, 2, 3, 4, 6]);
        }
    
        #[test]
        fn test_map_tree() {
            let doubled = map_tree(sample(), &|v| v * 2);
            assert_eq!(sum(doubled), 32); // 2+4+6+8+12
        }
    
        #[test]
        fn test_map_preserves_shape() {
            let t = map_tree(sample(), &|v| v.to_string());
            assert_eq!(preorder(t), vec!["4", "2", "1", "3", "6"]);
        }
    }
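The derived operations above all return numbers or vectors, but the accumulator can also be a tree. The sketch below defines a hypothetical mirror function, which is not part of the module above. It swaps left and right children at every node by choosing Leaf as the base value, so the inorder traversal of the result is the original's reversed. Tree, fold_tree, and inorder are repeated from the module, and node and sample are hypothetical helpers.

```rust
// Sketch only: module definitions repeated so the block compiles alone.
#[derive(Debug, Clone, PartialEq)]
enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

fn fold_tree<T, U: Clone, F: Fn(T, U, U) -> U>(tree: Tree<T>, acc: U, f: &F) -> U {
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            let left = fold_tree(*l, acc.clone(), f);
            let right = fold_tree(*r, acc, f);
            f(v, left, right)
        }
    }
}

fn inorder<T: Clone>(tree: Tree<T>) -> Vec<T> {
    fold_tree(tree, vec![], &|v, l, r| {
        let mut result = l;
        result.push(v);
        result.extend(r);
        result
    })
}

fn node<T>(v: T, l: Tree<T>, r: Tree<T>) -> Tree<T> {
    Tree::Node(v, Box::new(l), Box::new(r))
}

// Hypothetical: rebuild each node with its folded children swapped.
fn mirror<T: Clone>(tree: Tree<T>) -> Tree<T> {
    fold_tree(tree, Tree::Leaf, &|v, l, r| node(v, r, l))
}

fn sample() -> Tree<i32> {
    node(
        4,
        node(2, node(1, Tree::Leaf, Tree::Leaf), node(3, Tree::Leaf, Tree::Leaf)),
        node(6, Tree::Leaf, Tree::Leaf),
    )
}

fn main() {
    println!("{:?}", inorder(mirror(sample()))); // prints "[6, 4, 3, 2, 1]"
}
```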

    Deep Comparison

    OCaml vs Rust: Map and Fold on Trees

    Side-by-Side Code

    OCaml

    let rec map_tree f = function
      | Leaf           -> Leaf
      | Node (v, l, r) -> Node (f v, map_tree f l, map_tree f r)
    
    let rec fold_tree f acc = function
      | Leaf           -> acc
      | Node (v, l, r) -> f v (fold_tree f acc l) (fold_tree f acc r)
    
    (* Derived operations — zero additional recursion *)
    let size     t = fold_tree (fun _ l r -> 1 + l + r)    0  t
    let depth    t = fold_tree (fun _ l r -> 1 + max l r)  0  t
    let sum      t = fold_tree (fun v l r -> v + l + r)    0  t
    let preorder t = fold_tree (fun v l r -> [v] @ l @ r) [] t
    let inorder  t = fold_tree (fun v l r -> l @ [v] @ r) [] t
    

    Rust (idiomatic)

    pub fn map_tree<T, U, F: Fn(T) -> U>(tree: Tree<T>, f: &F) -> Tree<U> {
        match tree {
            Tree::Leaf => Tree::Leaf,
            Tree::Node(v, l, r) => Tree::Node(
                f(v),
                Box::new(map_tree(*l, f)),
                Box::new(map_tree(*r, f)),
            ),
        }
    }
    
    pub fn fold_tree<T, U: Clone, F: Fn(T, U, U) -> U>(tree: Tree<T>, acc: U, f: &F) -> U {
        match tree {
            Tree::Leaf => acc,
            Tree::Node(v, l, r) => {
                let left  = fold_tree(*l, acc.clone(), f);
                let right = fold_tree(*r, acc, f);
                f(v, left, right)
            }
        }
    }
    

    Rust (functional/recursive)

    // Derived from fold_tree — same pattern as OCaml, zero extra recursion
    pub fn tree_size<T>(t: Tree<T>) -> usize {
        fold_tree(t, 0usize, &|_, l, r| 1 + l + r)
    }
    
    pub fn tree_sum(t: Tree<i32>) -> i32 {
        fold_tree(t, 0, &|v, l, r| v + l + r)
    }
    

    Type Signatures

    Concept       | OCaml                                                          | Rust
    map           | val map_tree : ('a -> 'b) -> 'a tree -> 'b tree                | fn map_tree<T, U, F: Fn(T) -> U>(tree: Tree<T>, f: &F) -> Tree<U>
    fold          | val fold_tree : ('a -> 'b -> 'b -> 'b) -> 'b -> 'a tree -> 'b  | fn fold_tree<T, U: Clone, F: Fn(T, U, U) -> U>(tree: Tree<T>, acc: U, f: &F) -> U
    combining fn  | 'a -> 'b -> 'b -> 'b                                           | Fn(T, U, U) -> U
    accumulator   | 'b (shared freely)                                             | U: Clone (must be cloned for both branches)
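The accumulator row is where the two languages differ most visibly. The sketch below uses a hypothetical render helper with a heap-allocated String as the base value. fold_tree clones that String once per node (for the left subtree), so every Leaf becomes its own "." placeholder. The GC-based OCaml version avoids that work by sharing one value. Tree and fold_tree are repeated from the tutorial, and node and sample are hypothetical.

```rust
// Sketch only: definitions repeated from the tutorial so this compiles alone.
use std::fmt::Display;

#[derive(Debug, Clone, PartialEq)]
enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

fn fold_tree<T, U: Clone, F: Fn(T, U, U) -> U>(tree: Tree<T>, acc: U, f: &F) -> U {
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            // acc.clone() allocates a new String here when U = String.
            let left = fold_tree(*l, acc.clone(), f);
            let right = fold_tree(*r, acc, f);
            f(v, left, right)
        }
    }
}

// Hypothetical: print each node as "(left value right)", with "." for Leaf.
fn render<T: Display>(tree: Tree<T>) -> String {
    fold_tree(tree, String::from("."), &|v, l, r| format!("({} {} {})", l, v, r))
}

fn node<T>(v: T, l: Tree<T>, r: Tree<T>) -> Tree<T> {
    Tree::Node(v, Box::new(l), Box::new(r))
}

fn sample() -> Tree<i32> {
    node(
        4,
        node(2, node(1, Tree::Leaf, Tree::Leaf), node(3, Tree::Leaf, Tree::Leaf)),
        node(6, Tree::Leaf, Tree::Leaf),
    )
}

fn main() {
    println!("{}", render(sample())); // prints "(((. 1 .) 2 (. 3 .)) 4 (. 6 .))"
}
```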

    Key Insights

  • Clone requirement for fold: In OCaml, acc is a GC-managed value passed to both subtrees freely — no copying needed at the language level. In Rust, passing acc to the left subtree would consume it, leaving nothing for the right subtree. The bound U: Clone and explicit acc.clone() call make this sharing explicit.
  • Closure passed as &F: OCaml closures are heap values that can be shared by reference automatically. Rust closures are owned values. Passing f: &F through recursive calls avoids consuming the closure on the first node visited. Without the &, the first recursive call would move f and the second would fail to compile (unless F happens to be Copy).
  • Catamorphism: fold_tree is the tree catamorphism, so it captures the tree's recursive structure completely. Any structurally recursive function on Tree<T> (one that recurses only into a node's immediate subtrees) can be expressed as a single fold_tree call, sometimes with a tuple or Option as the accumulator. OCaml's size, depth, sum, preorder, and inorder are all one-liners, and the same applies in Rust.
  • Ownership through map: map_tree takes ownership of the source Tree<T> and produces a Tree<U>. This is the natural functional style: no shared mutable state, no aliasing. OCaml's GC reclaims the old tree once nothing refers to it. In Rust the call consumes the source tree, and each of its boxes is freed as the recursion moves the contents out, with no reference counting or GC involved.
  • Functor structure: map_tree makes Tree a functor over its element type. The functor is a concept from category theory that OCaml expresses through map conventions and Rust approximates through generic functions (there is no Functor trait in std).
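The functor claim can be spot-checked on a concrete tree. The sketch below tests the two functor laws with hypothetical identity_law and composition_law helpers, using the map_tree signature above. The identity law says that mapping with x -> x changes nothing. The composition law says that mapping f and then g equals mapping their composition once. These checks cover one sample tree and do not prove the laws.

```rust
// Sketch only: map_tree repeated from the tutorial; the law checks are hypothetical.
#[derive(Debug, Clone, PartialEq)]
enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

fn map_tree<T, U, F: Fn(T) -> U>(tree: Tree<T>, f: &F) -> Tree<U> {
    match tree {
        Tree::Leaf => Tree::Leaf,
        Tree::Node(v, l, r) => Tree::Node(f(v), Box::new(map_tree(*l, f)), Box::new(map_tree(*r, f))),
    }
}

// Identity law: map_tree with |x| x returns an equal tree.
fn identity_law<T: Clone + PartialEq>(t: Tree<T>) -> bool {
    map_tree(t.clone(), &|x| x) == t
}

// Composition law: two separate maps equal one map of the composed function.
fn composition_law(t: Tree<i32>) -> bool {
    let f = |x: i32| x + 1;
    let g = |x: i32| x * 3;
    let step_by_step = map_tree(map_tree(t.clone(), &f), &g);
    let fused = map_tree(t, &|x| g(f(x)));
    step_by_step == fused
}

fn node<T>(v: T, l: Tree<T>, r: Tree<T>) -> Tree<T> {
    Tree::Node(v, Box::new(l), Box::new(r))
}

fn sample() -> Tree<i32> {
    node(
        4,
        node(2, node(1, Tree::Leaf, Tree::Leaf), node(3, Tree::Leaf, Tree::Leaf)),
        node(6, Tree::Leaf, Tree::Leaf),
    )
}

fn main() {
    println!("{} {}", identity_law(sample()), composition_law(sample())); // prints "true true"
}
```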

    When to Use Each Style

    Use fold_tree when: you need any aggregation over the tree (sum, count, depth, traversal list). Derive it as a one-liner rather than writing a new recursive function each time.

    Use map_tree when: you are transforming values while preserving the tree's shape. It is the structural equivalent of List.map for lists.
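The sketch below shows an aggregation that the module does not define: a hypothetical max_value that uses Option<i32> as the accumulator, so that an empty tree has no maximum instead of a made-up sentinel. Tree and fold_tree are repeated from the tutorial, and node and sample are hypothetical helpers.

```rust
// Sketch only: definitions repeated so the block compiles alone.
#[derive(Debug, Clone, PartialEq)]
enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

fn fold_tree<T, U: Clone, F: Fn(T, U, U) -> U>(tree: Tree<T>, acc: U, f: &F) -> U {
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            let left = fold_tree(*l, acc.clone(), f);
            let right = fold_tree(*r, acc, f);
            f(v, left, right)
        }
    }
}

// Hypothetical: Leaf maps to None; a Node starts from its own value
// and folds in whichever child maxima exist.
fn max_value(tree: Tree<i32>) -> Option<i32> {
    fold_tree(tree, None::<i32>, &|v, l, r| Some(l.into_iter().chain(r).fold(v, i32::max)))
}

fn node<T>(v: T, l: Tree<T>, r: Tree<T>) -> Tree<T> {
    Tree::Node(v, Box::new(l), Box::new(r))
}

fn sample() -> Tree<i32> {
    node(
        4,
        node(2, node(1, Tree::Leaf, Tree::Leaf), node(3, Tree::Leaf, Tree::Leaf)),
        node(6, Tree::Leaf, Tree::Leaf),
    )
}

fn main() {
    println!("{:?} {:?}", max_value(sample()), max_value(Tree::Leaf)); // prints "Some(6) None"
}
```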

    Exercises

  • Implement filter_tree that removes nodes whose values fail a predicate, replacing each failing node (and therefore its whole subtree) with Leaf. Verify that the resulting tree never has more nodes than the original, and has strictly fewer whenever at least one node fails
  • Implement zip_tree that takes two trees of the same shape and combines them node-by-node into a Tree<(A, B)>, returning None if the shapes differ at any point
  • Add a flatten function implemented purely via fold_tree that collects all node values into a Vec<T> in preorder, then write a test confirming that its output matches the preorder function defined in this module