🦀 Functional Rust

245: Cofree Comonad

Difficulty: 5 Level: Master

A tree where every node holds a label and its child subtrees: the most general comonad you can build over a functor.

The Problem This Solves

You're building a game state machine. Each state has a label (the current game state) and a list of possible next states. You want to annotate every reachable state with derived information, say, the minimum number of moves to win from there. Without the right abstraction, this requires either mutable global maps or deeply nested explicit recursion.

The Cofree comonad solves this: it's a tree structure where every node has a value and a collection of children, and `extend` lets you replace every node's value with a function of its entire subtree. Want every node annotated with "sum of values in my subtree"? One `extend` call. Want every node annotated with "depth of my subtree"? Another one.

More concretely: if you have a file system tree (each directory has a name and a list of child directories), `extend` on the Cofree comonad lets you annotate every directory with its total size in a single call. Note that `extend` applies the function to each subtree independently, so a function that walks its whole subtree is re-run at every node.
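The directory-size idea can be sketched as follows. This is a minimal illustration under stated assumptions, not the article's implementation: the `Dir` type, `total_bytes`, `sample_fs`, `annotate_sizes`, and every name and byte count are hypothetical, and `Rose` is restated locally so the block stands alone.

```rust
// Rose tree restated locally so this sketch is self-contained.
#[derive(Debug, Clone)]
struct Rose<A> {
    value: A,
    children: Vec<Rose<A>>,
}

impl<A> Rose<A> {
    // extend: relabel every node with `f` applied to the subtree rooted there.
    fn extend<B>(&self, f: &impl Fn(&Rose<A>) -> B) -> Rose<B> {
        Rose {
            value: f(self),
            children: self.children.iter().map(|c| c.extend(f)).collect(),
        }
    }
}

// Hypothetical directory label: a name plus the bytes of files directly inside.
#[derive(Debug, Clone)]
struct Dir {
    name: &'static str,
    own_bytes: u64,
}

// Total bytes in a directory and everything below it.
fn total_bytes(dir: &Rose<Dir>) -> u64 {
    dir.value.own_bytes + dir.children.iter().map(total_bytes).sum::<u64>()
}

// Made-up sample data: root -> [src -> [bin], docs].
fn sample_fs() -> Rose<Dir> {
    Rose {
        value: Dir { name: "root", own_bytes: 10 },
        children: vec![
            Rose {
                value: Dir { name: "src", own_bytes: 300 },
                children: vec![Rose {
                    value: Dir { name: "bin", own_bytes: 50 },
                    children: vec![],
                }],
            },
            Rose {
                value: Dir { name: "docs", own_bytes: 120 },
                children: vec![],
            },
        ],
    }
}

// Pair each directory name with the total size of its subtree.
fn annotate_sizes(fs: &Rose<Dir>) -> Rose<(&'static str, u64)> {
    fs.extend(&|sub| (sub.value.name, total_bytes(sub)))
}
```

With this sample data, `annotate_sizes(&sample_fs())` labels `root` with 480, `src` with 350, `bin` with 50, and `docs` with 120. Because `total_bytes` walks the whole subtree at every node, a large tree may be better served by one bottom-up fold that reuses child totals.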

The Intuition

The Cofree comonad over a functor F is the "most general" comonad you can build from F. For `F = Vec` (list of children), it's a rose tree: a tree where each node has any number of children.

Think of it this way: regular `fmap` gives each node a new label based on its own value. `extend` gives each node a new label based on its entire subtree: the node, its children, their children, all the way down.

`duplicate` is the most fundamental operation: it replaces every node's label with the subtree rooted at that node. After `duplicate`, the root's label is the whole original tree. The root's first child's label is the subtree rooted at that child. And so on. You now have a "tree of trees" in which every subtree is available as a label.

The relationship between `extend` and `duplicate`: each can be defined from the other. `duplicate` is `extend` applied to the identity function, and `extend f` is `duplicate` followed by `fmap f`.

Why "cofree"? The construction that sends F to `Cofree F` is the right adjoint of the forgetful functor from comonads to functors. Concretely, for any comonad W, every natural transformation from W to F factors uniquely through `Cofree F` as a comonad morphism. It is the dual of the free monad: the comonad that adds the least extra structure to F.
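The relationship between `extend` and `duplicate` can be made concrete with a small sketch. It assumes the rose-tree representation used in this article, restated locally so it stands alone; `extend_via_duplicate`, `sum`, and `sample` are hypothetical helper names introduced only for illustration.

```rust
#[derive(Debug, Clone, PartialEq)]
struct Rose<A> {
    value: A,
    children: Vec<Rose<A>>,
}

impl<A: Clone> Rose<A> {
    // fmap: transform each label independently.
    fn fmap<B>(&self, f: &impl Fn(A) -> B) -> Rose<B> {
        Rose {
            value: f(self.value.clone()),
            children: self.children.iter().map(|c| c.fmap(f)).collect(),
        }
    }

    // extend: relabel each node with `f` applied to the subtree rooted there.
    fn extend<B>(&self, f: &impl Fn(&Rose<A>) -> B) -> Rose<B> {
        Rose {
            value: f(self),
            children: self.children.iter().map(|c| c.extend(f)).collect(),
        }
    }

    // duplicate = extend id: each label becomes a copy of its own subtree.
    fn duplicate(&self) -> Rose<Rose<A>> {
        self.extend(&|sub| sub.clone())
    }

    // extend f = fmap f . duplicate (hypothetical name, for comparison only).
    fn extend_via_duplicate<B>(&self, f: &impl Fn(&Rose<A>) -> B) -> Rose<B> {
        self.duplicate().fmap(&|sub| f(&sub))
    }
}

// Sum of all labels in a subtree.
fn sum(t: &Rose<i32>) -> i32 {
    t.value + t.children.iter().map(sum).sum::<i32>()
}

// Small sample tree: 1 -> [2 -> [4, 5], 3].
fn sample() -> Rose<i32> {
    Rose {
        value: 1,
        children: vec![
            Rose {
                value: 2,
                children: vec![
                    Rose { value: 4, children: vec![] },
                    Rose { value: 5, children: vec![] },
                ],
            },
            Rose { value: 3, children: vec![] },
        ],
    }
}
```

On the sample tree, `sample().extend(&sum)` and `sample().extend_via_duplicate(&sum)` produce equal trees. The direct `extend` is the cheaper of the two in Rust, since going through `duplicate` clones every subtree first.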

How It Works in Rust

The rose tree as Cofree over Vec:
#[derive(Debug, Clone)] // `duplicate` below calls `self.clone()`, so `Rose` must be `Clone`
pub struct Rose<A> {
 pub value: A,
 pub children: Vec<Rose<A>>,
}
Comonad operations:
impl<A: Clone> Rose<A> {
 pub fn extract(&self) -> A { self.value.clone() }

 // fmap: transform each node's label independently
 pub fn fmap<B>(&self, f: &impl Fn(A) -> B) -> Rose<B> {
     Rose {
         value: f(self.value.clone()),
         children: self.children.iter().map(|c| c.fmap(f)).collect(),
     }
 }

 // extend: replace each node's label with f applied to the *subtree there*
 // f receives the whole subtree rooted at each node
 pub fn extend<B>(&self, f: &impl Fn(&Rose<A>) -> B) -> Rose<B> {
     Rose {
         value: f(self),            // f sees the entire subtree at this point
         children: self.children.iter().map(|c| c.extend(f)).collect(),
     }
 }

 // duplicate: label each node with its own subtree
 pub fn duplicate(&self) -> Rose<Rose<A>> {
     Rose {
         value: self.clone(),       // root's label = the whole tree
         children: self.children.iter().map(|c| c.duplicate()).collect(),
     }
 }
}
Annotating every node with its subtree sum:
let tree = Rose::node(1, vec![
 Rose::node(2, vec![Rose::leaf(4), Rose::leaf(5)]),
 Rose::node(3, vec![Rose::leaf(6), Rose::node(7, vec![Rose::leaf(8)])]),
]);

// Annotate each node with the sum of values in its subtree
let annotated: Rose<i32> = tree.extend(&|subtree| subtree.sum());
// annotated.extract() == 36 (sum of entire tree)
// annotated.children[0].extract() == 11 (sum of subtree 2 -> [4, 5])
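`extract` and `extend` are expected to satisfy the three comonad laws. The sketch below spot-checks them on the same tree shape; it is a check on one example with two particular functions, not a proof. The helper names (`sum`, `size`, `sample`, and the three `law_*` functions) are hypothetical, and `Rose` is restated so the block stands alone.

```rust
#[derive(Debug, Clone, PartialEq)]
struct Rose<A> {
    value: A,
    children: Vec<Rose<A>>,
}

impl<A: Clone> Rose<A> {
    fn leaf(value: A) -> Self {
        Rose { value, children: vec![] }
    }

    fn node(value: A, children: Vec<Rose<A>>) -> Self {
        Rose { value, children }
    }

    fn extract(&self) -> A {
        self.value.clone()
    }

    fn extend<B>(&self, f: &impl Fn(&Rose<A>) -> B) -> Rose<B> {
        Rose {
            value: f(self),
            children: self.children.iter().map(|c| c.extend(f)).collect(),
        }
    }
}

// Sum of all labels in a subtree.
fn sum(t: &Rose<i32>) -> i32 {
    t.value + t.children.iter().map(sum).sum::<i32>()
}

// Number of nodes in a subtree.
fn size(t: &Rose<i32>) -> i32 {
    1 + t.children.iter().map(size).sum::<i32>()
}

// 1 -> [2 -> [4, 5], 3 -> [6, 7 -> [8]]]
fn sample() -> Rose<i32> {
    Rose::node(1, vec![
        Rose::node(2, vec![Rose::leaf(4), Rose::leaf(5)]),
        Rose::node(3, vec![Rose::leaf(6), Rose::node(7, vec![Rose::leaf(8)])]),
    ])
}

// Law 1: extract . extend f = f
fn law_extract_after_extend(t: &Rose<i32>) -> bool {
    t.extend(&sum).extract() == sum(t)
}

// Law 2: extend extract = id
fn law_extend_extract(t: &Rose<i32>) -> bool {
    t.extend(&|s| s.extract()) == *t
}

// Law 3: extend f . extend g = extend (f . extend g), here with f = sum, g = size
fn law_associativity(t: &Rose<i32>) -> bool {
    let lhs = t.extend(&size).extend(&sum);
    let rhs = t.extend(&|s| sum(&s.extend(&size)));
    lhs == rhs
}
```

All three functions return `true` for `sample()`. A property-based test over randomly generated trees would give more confidence than a single fixed example.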

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
| --- | --- | --- |
| Cofree/Rose type | `type 'a rose = Rose of 'a * 'a rose list` | `struct Rose<A> { value: A, children: Vec<Rose<A>> }` |
| fmap | `let rec fmap f (Rose(v, cs)) = Rose(f v, List.map (fmap f) cs)` | Method with `&impl Fn(A) -> B` |
| extend | Recursive, same pattern as fmap but `f` receives whole subtree | Method; `f` receives `&Rose<A>` |
| Sharing subtrees | GC allows aliasing freely | `duplicate` requires `Clone`: full copy |
| Functor abstraction | Can parameterize over functor F via modules | `Rose<A>` is concrete; parameterizing over F needs GATs or trait objects |
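On the "Sharing subtrees" row: one way to get OCaml-style aliasing in Rust is to hold children behind `Rc`, so that `duplicate` hands out extra handles instead of deep copies. The sketch below is an assumption-laden variant, not part of the article's code; `SharedRose`, `duplicate_shared`, `leaf`, and `node` are hypothetical names.

```rust
use std::rc::Rc;

// Variant of the rose tree whose children are reference-counted,
// so subtrees can be aliased instead of cloned.
#[derive(Debug)]
struct SharedRose<A> {
    value: A,
    children: Vec<Rc<SharedRose<A>>>,
}

// Hypothetical constructors for the shared variant.
fn leaf<A>(value: A) -> Rc<SharedRose<A>> {
    Rc::new(SharedRose { value, children: vec![] })
}

fn node<A>(value: A, children: Vec<Rc<SharedRose<A>>>) -> Rc<SharedRose<A>> {
    Rc::new(SharedRose { value, children })
}

// duplicate without deep copies: each label is another handle to the
// existing subtree (Rc::clone only bumps a reference count).
fn duplicate_shared<A>(t: &Rc<SharedRose<A>>) -> SharedRose<Rc<SharedRose<A>>> {
    SharedRose {
        value: Rc::clone(t),
        children: t
            .children
            .iter()
            .map(|c| Rc::new(duplicate_shared(c)))
            .collect(),
    }
}
```

Here `duplicate_shared` needs no `Clone` bound on `A`, and each label points at the same allocation as the original subtree. The trade-off is reference-counting overhead and single-threaded use; `Arc` would be the analogous choice across threads.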
//! Cofree Comonad.
//!
//! Cofree f a = a × f (Cofree f a)
//!
//! For f = Vec (list of children), this gives a rose tree:
//!   Cofree Vec a = Rose Tree with `a` at each node.
//!
//! The Cofree comonad is the "most general" comonad for a functor f.
//! It annotates every node in an f-structure with a label.
//!
//! Comonad operations:
//!   extract:   get the root label
//!   extend f:  replace every node's label with f applied to the subtree there
//!   duplicate: replace every node's label with the subtree rooted there

/// Rose tree = Cofree over Vec.
/// Each node has a value and a list of child subtrees.
#[derive(Debug, Clone)]
pub struct Rose<A> {
    pub value: A,
    pub children: Vec<Rose<A>>,
}

impl<A: Clone + 'static> Rose<A> {
    pub fn leaf(value: A) -> Self {
        Rose { value, children: vec![] }
    }

    pub fn node(value: A, children: Vec<Rose<A>>) -> Self {
        Rose { value, children }
    }

    /// Comonad: extract = the root value.
    pub fn extract(&self) -> A {
        self.value.clone()
    }

    /// Functor: map over all node values.
    pub fn fmap<B: Clone + 'static>(&self, f: &impl Fn(A) -> B) -> Rose<B> {
        Rose {
            value: f(self.value.clone()),
            children: self.children.iter().map(|c| c.fmap(f)).collect(),
        }
    }

    /// Comonad: extend.
    /// Each node is replaced by f applied to the subtree rooted there.
    pub fn extend<B: Clone + 'static>(&self, f: &impl Fn(&Rose<A>) -> B) -> Rose<B> {
        Rose {
            value: f(self),
            children: self.children.iter().map(|c| c.extend(f)).collect(),
        }
    }

    /// Comonad: duplicate.
    /// Replace each node's value with the subtree rooted there.
    pub fn duplicate(&self) -> Rose<Rose<A>> {
        Rose {
            value: self.clone(),
            children: self.children.iter().map(|c| c.duplicate()).collect(),
        }
    }

    /// Fold: reduce the tree bottom-up.
    pub fn fold<B: Clone>(&self, f: impl Fn(A, Vec<B>) -> B + Clone) -> B {
        let child_results: Vec<B> = self.children.iter().map(|c| c.fold(f.clone())).collect();
        f(self.value.clone(), child_results)
    }

    /// Size: total number of nodes.
    pub fn size(&self) -> usize {
        self.fold(|_, child_sizes| 1 + child_sizes.iter().sum::<usize>())
    }

    /// Depth: maximum depth (root = 1).
    pub fn depth(&self) -> usize {
        self.fold(|_, child_depths| 1 + child_depths.iter().copied().max().unwrap_or(0))
    }

    /// Collect all values in pre-order (root first, then each child's subtree).
    pub fn collect_values(&self) -> Vec<A> {
        let mut result = vec![self.value.clone()];
        for child in &self.children {
            result.extend(child.collect_values());
        }
        result
    }
}

impl Rose<i32> {
    pub fn sum(&self) -> i32 {
        self.fold(|v, child_sums| v + child_sums.iter().sum::<i32>())
    }
}

fn main() {
    println!("=== Cofree Comonad (Rose Tree) ===\n");
    println!("Cofree f a = a × f (Cofree f a)");
    println!("For f = Vec: Rose tree, where each node has a value and a list of children.\n");

    // Build the tree from the OCaml example:
    //      1
    //    /   \
    //   2     3
    //  / \   / \
    // 4   5 6   7
    //           |
    //           8
    let t = Rose::node(1, vec![
        Rose::node(2, vec![
            Rose::leaf(4),
            Rose::leaf(5),
        ]),
        Rose::node(3, vec![
            Rose::leaf(6),
            Rose::node(7, vec![Rose::leaf(8)]),
        ]),
    ]);

    println!("Tree structure: 1 -> [2->[4,5], 3->[6, 7->[8]]]");
    println!("root   = {}", t.extract());
    println!("size   = {}", t.size());
    println!("depth  = {}", t.depth());
    println!("sum    = {}", t.sum());

    // fmap: double all values
    let doubled = t.fmap(&|n| n * 2);
    println!("root * 2 = {}", doubled.extract());
    println!("all doubled: {:?}", doubled.collect_values());

    // extend: annotate each node with its subtree sum
    let annotated = t.extend(&|subtree| subtree.sum());
    println!("\nAnnotated with subtree sums:");
    println!("  root subtree sum = {} (should be 36)", annotated.extract());
    println!("  values: {:?}", annotated.collect_values());

    // duplicate: each node labeled with the subtree rooted there
    let duped = t.duplicate();
    println!("\nAfter duplicate:");
    println!("  root.value.value = {}", duped.extract().extract());
    println!("  root.value.size  = {}", duped.extract().size());

    // Comonad law 1: extract . extend f = f
    let f = |subtree: &Rose<i32>| subtree.sum();
    let extended = t.extend(&f);
    assert_eq!(extended.extract(), f(&t), "Law 1: extract . extend f = f");
    println!("\nComonad law 1 (extract . extend f = f): {} = {} ✓",
        extended.extract(), f(&t));

    // Build a string tree (different type)
    let st = Rose::node("root", vec![
        Rose::leaf("left"),
        Rose::node("mid", vec![Rose::leaf("deep")]),
        Rose::leaf("right"),
    ]);
    let lengths = st.extend(&|subtree| subtree.extract().len());
    println!("\nString tree annotated with label lengths: {:?}", lengths.collect_values());
}

#[cfg(test)]
mod tests {
    use super::*;

    fn sample_tree() -> Rose<i32> {
        Rose::node(1, vec![
            Rose::node(2, vec![Rose::leaf(4), Rose::leaf(5)]),
            Rose::leaf(3),
        ])
    }

    #[test]
    fn test_extract() {
        assert_eq!(sample_tree().extract(), 1);
    }

    #[test]
    fn test_size() {
        assert_eq!(sample_tree().size(), 5);
    }

    #[test]
    fn test_depth() {
        assert_eq!(sample_tree().depth(), 3);
    }

    #[test]
    fn test_sum() {
        assert_eq!(sample_tree().sum(), 15); // 1+2+3+4+5
    }

    #[test]
    fn test_fmap() {
        let t = sample_tree();
        let doubled = t.fmap(&|n| n * 2);
        assert_eq!(doubled.extract(), 2);
        assert_eq!(doubled.sum(), 30);
    }

    #[test]
    fn test_extend_law1() {
        // extract . extend f = f
        let t = sample_tree();
        let f = |sub: &Rose<i32>| sub.sum();
        assert_eq!(t.extend(&f).extract(), f(&t));
    }

    #[test]
    fn test_leaf() {
        let l = Rose::leaf(42_i32);
        assert_eq!(l.size(), 1);
        assert_eq!(l.depth(), 1);
        assert_eq!(l.children.len(), 0);
    }
}
(* Cofree comonad: annotates every node in a structure with a label.
   Cofree f a = a * f (Cofree f a)
   For f = [], this gives rose trees (annotated with a-values at each node) *)

(* Cofree over list = rose tree *)
type 'a rose = Rose of 'a * 'a rose list

let leaf x       = Rose (x, [])
let node x children = Rose (x, children)

(* Comonad operations *)
let extract (Rose (a, _)) = a

let rec extend (Rose (a, children)) f =
  Rose (f (Rose (a, children)), List.map (fun child -> extend child f) children)

let duplicate t = extend t (fun x -> x)

(* Functor: map over annotations *)
let rec fmap f (Rose (a, children)) =
  Rose (f a, List.map (fmap f) children)

(* Fold: reduce the tree *)
let rec fold f (Rose (a, children)) =
  f a (List.map (fold f) children)

(* Size and depth *)
let size  = fold (fun _ cs -> 1 + List.fold_left ( + ) 0 cs)
let depth = fold (fun _ cs -> 1 + (List.fold_left max 0 cs))
let sum   = fold (fun a cs -> a + List.fold_left ( + ) 0 cs)

let () =
  let t = node 1 [
    node 2 [leaf 4; leaf 5];
    node 3 [leaf 6; node 7 [leaf 8]];
  ] in

  Printf.printf "root   = %d\n" (extract t);
  Printf.printf "size   = %d\n" (size t);
  Printf.printf "depth  = %d\n" (depth t);
  Printf.printf "sum    = %d\n" (sum t);

  let doubled = fmap (fun n -> n * 2) t in
  Printf.printf "root*2 = %d\n" (extract doubled);

  (* extend: annotate each node with its subtree sum *)
  let annotated = extend t sum in
  Printf.printf "subtree sums at root = %d\n" (extract annotated)