/// Cofree Comonad.
///
/// `Cofree f a = a × f (Cofree f a)`
///
/// For `f = Vec` (list of children), this gives a rose tree:
/// `Cofree Vec a` is a rose tree with an `a` at each node.
///
/// The Cofree comonad is the "most general" comonad for a functor `f`:
/// it annotates every node in an `f`-structure with a label.
///
/// Comonad operations:
/// - [`extract`](Rose::extract): get the root label.
/// - [`extend`](Rose::extend): replace every node's label with `f` applied to the subtree there.
/// - [`duplicate`](Rose::duplicate): replace every node's label with the subtree rooted there.
///
/// `Rose` is Cofree over `Vec`: each node has a value and a list of child subtrees.
/// All traversals are recursive, so extremely deep trees can exhaust the stack.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Rose<A> {
    /// The label stored at this node.
    pub value: A,
    /// The child subtrees, left to right.
    pub children: Vec<Rose<A>>,
}

/// Operations that never need to copy node values, so they place no bounds on `A`.
impl<A> Rose<A> {
    /// Creates a node with no children.
    pub fn leaf(value: A) -> Self {
        Rose { value, children: Vec::new() }
    }

    /// Creates a node with the given value and child subtrees.
    pub fn node(value: A, children: Vec<Rose<A>>) -> Self {
        Rose { value, children }
    }

    /// Comonad `extend`: builds a tree of the same shape in which each node's
    /// value is `f` applied to the subtree rooted at that node.
    ///
    /// Satisfies comonad law 1: `self.extend(f).extract() == f(self)`.
    pub fn extend<B>(&self, f: &impl Fn(&Rose<A>) -> B) -> Rose<B> {
        Rose {
            value: f(self),
            children: self.children.iter().map(|c| c.extend(f)).collect(),
        }
    }

    /// Total number of nodes in the tree (a leaf has size 1).
    ///
    /// Walks the tree directly instead of via [`fold`](Rose::fold) so that no
    /// node value is cloned (important for `Rose<Rose<_>>` produced by `duplicate`).
    pub fn size(&self) -> usize {
        1 + self.children.iter().map(Self::size).sum::<usize>()
    }

    /// Number of nodes on the longest root-to-leaf path (a leaf has depth 1).
    pub fn depth(&self) -> usize {
        1 + self.children.iter().map(Self::depth).max().unwrap_or(0)
    }
}

/// Operations that hand out owned copies of node values.
impl<A: Clone> Rose<A> {
    /// Comonad `extract`: returns a copy of the root value.
    pub fn extract(&self) -> A {
        self.value.clone()
    }

    /// Functor `map`: applies `f` to every node value, preserving the shape.
    pub fn fmap<B>(&self, f: &impl Fn(A) -> B) -> Rose<B> {
        Rose {
            value: f(self.value.clone()),
            children: self.children.iter().map(|c| c.fmap(f)).collect(),
        }
    }

    /// Comonad `duplicate`: replaces each node's value with a copy of the subtree
    /// rooted there (equivalent to `self.extend(&|t| t.clone())`).
    ///
    /// Every subtree is copied, so the result holds O(n · depth) nodes in total.
    pub fn duplicate(&self) -> Rose<Rose<A>> {
        Rose {
            value: self.clone(),
            children: self.children.iter().map(Self::duplicate).collect(),
        }
    }

    /// Bottom-up fold: at every node, calls `f(value, child_results)`, where
    /// `child_results` holds the folded results of the children, left to right.
    pub fn fold<B>(&self, f: impl Fn(A, Vec<B>) -> B) -> B {
        // Recurse with a borrowed `f` so the closure need not be `Clone`.
        self.fold_ref(&f)
    }

    /// Recursive worker for [`fold`](Rose::fold) that shares one borrowed closure.
    fn fold_ref<B, F: Fn(A, Vec<B>) -> B>(&self, f: &F) -> B {
        let child_results: Vec<B> = self.children.iter().map(|c| c.fold_ref(f)).collect();
        f(self.value.clone(), child_results)
    }

    /// All node values in pre-order (each node before its children, children
    /// left to right). Runs in O(n) using a single output buffer.
    pub fn collect_values(&self) -> Vec<A> {
        let mut out = Vec::new();
        self.push_values(&mut out);
        out
    }

    /// Appends this subtree's values to `out` in pre-order.
    fn push_values(&self, out: &mut Vec<A>) {
        out.push(self.value.clone());
        for child in &self.children {
            child.push_values(out);
        }
    }
}

impl Rose<i32> {
    /// Sum of all node values (overflow panics in debug builds, as with `i32` `+`).
    pub fn sum(&self) -> i32 {
        self.fold(|v, child_sums| v + child_sums.iter().sum::<i32>())
    }
}
/// Demo: builds a sample rose tree, prints the results of the Functor and
/// Comonad operations on it, and asserts comonad law 1.
fn main() {
    println!("=== Cofree Comonad (Rose Tree) ===\n");
    println!("Cofree f a = a × f (Cofree f a)");
    println!("For f = Vec: Rose tree — each node has a value and a list of children.\n");

    // Build the tree from the OCaml example:
    //         1
    //       /   \
    //      2     3
    //     / \   / \
    //    4   5 6   7
    //              |
    //              8
    let t = Rose::node(1, vec![
        Rose::node(2, vec![
            Rose::leaf(4),
            Rose::leaf(5),
        ]),
        Rose::node(3, vec![
            Rose::leaf(6),
            Rose::node(7, vec![Rose::leaf(8)]),
        ]),
    ]);
    println!("Tree structure: 1 -> [2->[4,5], 3->[6, 7->[8]]]");
    println!("root = {}", t.extract());
    println!("size = {}", t.size());
    println!("depth = {}", t.depth());
    println!("sum = {}", t.sum());

    // fmap: double all values
    let doubled = t.fmap(&|n| n * 2);
    println!("root * 2 = {}", doubled.extract());
    println!("all doubled: {:?}", doubled.collect_values());

    // extend: annotate each node with its subtree sum
    let annotated = t.extend(&|subtree| subtree.sum());
    println!("\nAnnotated with subtree sums:");
    println!(" root subtree sum = {} (should be 36)", annotated.extract());
    println!(" values: {:?}", annotated.collect_values());

    // duplicate: each node labeled with the subtree rooted there
    let duped = t.duplicate();
    println!("\nAfter duplicate:");
    println!(" root.value.value = {}", duped.extract().extract());
    println!(" root.value.size = {}", duped.extract().size());

    // Comonad law 1: extract . extend f = f
    let f = |subtree: &Rose<i32>| subtree.sum();
    let extended = t.extend(&f);
    assert_eq!(extended.extract(), f(&t), "Law 1: extract . extend f = f");
    println!("\nComonad law 1 (extract . extend f = f): {} = {} ✓",
        extended.extract(), f(&t));

    // Build a string tree (different type)
    let st = Rose::node("root", vec![
        Rose::leaf("left"),
        Rose::node("mid", vec![Rose::leaf("deep")]),
        Rose::leaf("right"),
    ]);
    let lengths = st.extend(&|subtree| subtree.extract().len());
    println!("\nString tree annotated with label lengths: {:?}", lengths.collect_values());
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tree `1 -> [2 -> [4, 5], 3]`.
    fn sample_tree() -> Rose<i32> {
        Rose::node(1, vec![
            Rose::node(2, vec![Rose::leaf(4), Rose::leaf(5)]),
            Rose::leaf(3),
        ])
    }

    /// Pre-order list of `(value, child count)` pairs. This encoding determines
    /// the tree uniquely, so comparing encodings compares whole trees without
    /// requiring `PartialEq` on `Rose`.
    fn encode<A: Clone>(t: &Rose<A>) -> Vec<(A, usize)> {
        let mut out = vec![(t.value.clone(), t.children.len())];
        for child in &t.children {
            out.extend(encode(child));
        }
        out
    }

    #[test]
    fn test_extract() {
        assert_eq!(sample_tree().extract(), 1);
    }

    #[test]
    fn test_size() {
        assert_eq!(sample_tree().size(), 5);
    }

    #[test]
    fn test_depth() {
        assert_eq!(sample_tree().depth(), 3);
    }

    #[test]
    fn test_sum() {
        assert_eq!(sample_tree().sum(), 15); // 1+2+3+4+5
    }

    #[test]
    fn test_fmap() {
        let t = sample_tree();
        let doubled = t.fmap(&|n| n * 2);
        assert_eq!(doubled.extract(), 2);
        assert_eq!(doubled.sum(), 30);
    }

    #[test]
    fn test_collect_values_preorder() {
        assert_eq!(sample_tree().collect_values(), vec![1, 2, 4, 5, 3]);
    }

    #[test]
    fn test_extend_subtree_sums() {
        let sums = sample_tree().extend(&|sub: &Rose<i32>| sub.sum());
        assert_eq!(sums.collect_values(), vec![15, 11, 4, 5, 3]);
    }

    #[test]
    fn test_extend_law1() {
        // extract . extend f = f
        let t = sample_tree();
        let f = |sub: &Rose<i32>| sub.sum();
        assert_eq!(t.extend(&f).extract(), f(&t));
    }

    #[test]
    fn test_extend_law2() {
        // extend extract = id
        let t = sample_tree();
        assert_eq!(encode(&t.extend(&|sub: &Rose<i32>| sub.extract())), encode(&t));
    }

    #[test]
    fn test_extend_law3() {
        // extend g . extend f = extend (g . extend f)
        let t = sample_tree();
        let f = |sub: &Rose<i32>| sub.sum();
        let g = |sub: &Rose<i32>| sub.size() as i32 * 100 + sub.extract();
        let lhs = t.extend(&f).extend(&g);
        let rhs = t.extend(&|sub: &Rose<i32>| g(&sub.extend(&f)));
        assert_eq!(encode(&lhs), encode(&rhs));
    }

    #[test]
    fn test_duplicate() {
        let t = sample_tree();
        let d = t.duplicate();
        assert_eq!(encode(&d.extract()), encode(&t));
        assert_eq!(encode(&d.children[0].value), encode(&t.children[0]));
        // fmap extract . duplicate = id
        assert_eq!(encode(&d.fmap(&|sub: Rose<i32>| sub.extract())), encode(&t));
    }

    #[test]
    fn test_leaf() {
        let l = Rose::leaf(42_i32);
        assert_eq!(l.size(), 1);
        assert_eq!(l.depth(), 1);
        assert_eq!(l.children.len(), 0);
    }
}
(* Cofree comonad: decorates every node of a structure with a label.
   Cofree f a = a * f (Cofree f a)
   Choosing f = list yields rose trees that carry an 'a at every node. *)

(* Cofree over list = rose tree: a label paired with a list of subtrees. *)
type 'a rose = Rose of 'a * 'a rose list

(* [leaf label] is a childless node carrying [label]. *)
let leaf label = Rose (label, [])

(* [node label kids] is a node carrying [label] whose subtrees are [kids]. *)
let node label kids = Rose (label, kids)
(* Comonad operations *)

(* extract: the label stored at the root. *)
let extract (Rose (a, _)) = a

(* extend: relabel every node with [f] applied to the subtree rooted there.
   The [as t] pattern hands [f] the matched node itself instead of
   allocating a structurally equal copy at every level. *)
let rec extend (Rose (_, children) as t) f =
  Rose (f t, List.map (fun child -> extend child f) children)

(* duplicate: relabel every node with the subtree rooted there. *)
let duplicate t = extend t (fun x -> x)
(* Functor: apply [f] to every label, keeping the shape of the tree. *)
let rec fmap f = function
  | Rose (label, kids) -> Rose (f label, List.map (fun kid -> fmap f kid) kids)

(* Fold: combine each node's label with the already-folded results of its
   children, working bottom-up. *)
let rec fold f = function
  | Rose (label, kids) -> f label (List.map (fun kid -> fold f kid) kids)
(* Size, depth and sum of a tree.
   [size] and [depth] are eta-expanded (they take [t] explicitly). Written as
   partial applications of [fold], the value restriction would give them a
   weak type ('_weak1 rose -> int) that is fixed by the first use, so they could
   not be applied to trees with different label types. *)
let size t = fold (fun _ cs -> 1 + List.fold_left ( + ) 0 cs) t
let depth t = fold (fun _ cs -> 1 + List.fold_left max 0 cs) t
let sum t = fold (fun a cs -> a + List.fold_left ( + ) 0 cs) t
let () =
  let tree =
    node 1 [
      node 2 [leaf 4; leaf 5];
      node 3 [leaf 6; node 7 [leaf 8]];
    ]
  in
  Printf.printf "root = %d\n" (extract tree);
  Printf.printf "size = %d\n" (size tree);
  Printf.printf "depth = %d\n" (depth tree);
  Printf.printf "sum = %d\n" (sum tree);
  let doubled_root = extract (fmap (fun n -> n * 2) tree) in
  Printf.printf "root*2 = %d\n" doubled_root;
  (* extend: label each node with the sum of its subtree *)
  let subtree_sums = extend tree sum in
  Printf.printf "subtree sums at root = %d\n" (extract subtree_sums)