🦀 Functional Rust

244: Comonad Laws

Difficulty: 5 · Level: Master

Verify the three comonad laws on the Stream comonad, and understand what they mean.

The Problem This Solves

Laws aren't just mathematical vanity. When you implement a comonad (or use a library that claims to provide one), the laws are contracts that let you refactor safely. If `extend extract = id`, you can remove a no-op `extend` without changing behavior. If associativity holds, you can split a complex `extend` into two simpler ones, or merge two into one, without fear. The classic use case: you compute a moving average over a data stream and chain several `extend` calls. Are you applying them in the right order? Can you fuse two passes into one? The comonad laws tell you exactly which substitutions are safe. The Stream comonad is an ideal testbed because it is infinite, it has a clear notion of "current position", and every comonad law has a direct intuitive meaning in terms of time-series data.

The Intuition

The Stream comonad is an infinite lazy list with a "current position", its head. `extract` reads that head, and the tail is the rest of the stream from the next position on.

The key insight: `map` applies `f` to each element in isolation. `extend` applies `f` to each suffix, so at every position `f` sees the element there together with all of its context (everything after it). A 3-element moving average needs the current element and the next two; that is context, and `extend` provides it.

The three laws, in plain English (`.` is function composition):

1. Left identity, `extract(extend f s) = f(s)`: if you extend with `f` and then immediately extract, you get the same result as calling `f` on the original stream. The current position is preserved.
2. Right identity, `extend extract s = s`: if you extend with the trivial function "just read your own head", you get back the same stream unchanged. The identity computation does nothing.
3. Associativity, `extend f (extend g s) = extend (f . extend g) s`: running `extend g` and then `extend f` gives the same stream as a single `extend` whose function first runs `extend g` on its local suffix and then applies `f`. You can fuse two passes into one, or split one into two.

`duplicate` is the other canonical comonad operation, and `extend` can be derived from it: `duplicate s` produces a stream of streams in which position `i` holds the suffix starting at `i`, and `extend f = map f . duplicate`. Conversely, `duplicate = extend id`.
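The identity `extend f = map f . duplicate` can be checked on a finite prefix. The Rust listing has no `map`, so this minimal sketch adds a hypothetical `map` and an illustrative `extend_via_duplicate` (neither name comes from the listing) to a self-contained copy of the Stream core, then compares the two routes with a function that reads context:

```rust
use std::rc::Rc;

#[derive(Clone)]
struct Stream<A: Clone> {
    head: A,
    tail: Rc<dyn Fn() -> Stream<A>>,
}

impl<A: Clone + 'static> Stream<A> {
    fn new(head: A, tail: impl Fn() -> Stream<A> + 'static) -> Self {
        Stream { head, tail: Rc::new(tail) }
    }

    fn tail(&self) -> Stream<A> {
        (self.tail)()
    }

    // Hypothetical functor map: f sees one element, never its neighbours.
    fn map<B: Clone + 'static>(&self, f: Rc<dyn Fn(&A) -> B>) -> Stream<B> {
        let head_val = f(&self.head);
        let tail_fn = self.tail.clone();
        Stream::new(head_val, move || tail_fn().map(f.clone()))
    }

    // duplicate: position i holds the suffix starting at i.
    fn duplicate(&self) -> Stream<Stream<A>> {
        let tail_fn = self.tail.clone();
        Stream::new(self.clone(), move || tail_fn().duplicate())
    }

    // extend as a primitive: f sees the whole suffix at each position.
    fn extend<B: Clone + 'static>(&self, f: Rc<dyn Fn(&Stream<A>) -> B>) -> Stream<B> {
        let head_val = f(self);
        let tail_fn = self.tail.clone();
        Stream::new(head_val, move || tail_fn().extend(f.clone()))
    }

    // extend derived from the other two: extend f = map f . duplicate.
    fn extend_via_duplicate<B: Clone + 'static>(&self, f: Rc<dyn Fn(&Stream<A>) -> B>) -> Stream<B> {
        self.duplicate().map(f)
    }

    // First n elements; forces only the tails it needs.
    fn take(&self, n: usize) -> Vec<A> {
        let mut out = Vec::with_capacity(n);
        let mut cur = self.clone();
        for i in 0..n {
            if i > 0 {
                cur = cur.tail();
            }
            out.push(cur.head.clone());
        }
        out
    }
}

fn from(n: i64) -> Stream<i64> {
    Stream::new(n, move || from(n + 1))
}

fn main() {
    // A context-reading function: current element plus the next one.
    let pair_sum: Rc<dyn Fn(&Stream<i64>) -> i64> =
        Rc::new(|st: &Stream<i64>| st.head + st.tail().head);

    let direct = from(1).extend(pair_sum.clone()).take(5);
    let derived = from(1).extend_via_duplicate(pair_sum).take(5);
    println!("{:?}", direct); // [3, 5, 7, 9, 11]
    println!("{:?}", derived); // [3, 5, 7, 9, 11]
}
```

The two routes agree because position `i` of `duplicate s` is exactly the suffix that `extend` hands to `f` at position `i`.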

How It Works in Rust

The Stream is defined as a recursive struct with a lazy tail:
pub struct Stream<A: Clone> {
 pub head: A,
 tail: Rc<dyn Fn() -> Stream<A>>,  // lazy: only computed when accessed
}
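Laziness here has a cost worth knowing: the tail is a plain closure, not a memoized thunk, so every call to `tail()` re-runs it and rebuilds the next cell. The OCaml `unit -> 'a stream` representation in the listing behaves the same way. A sketch of that behaviour, using a hypothetical `counted_from` that counts tail evaluations through a `Cell`:

```rust
use std::cell::Cell;
use std::rc::Rc;

#[derive(Clone)]
struct Stream<A: Clone> {
    head: A,
    tail: Rc<dyn Fn() -> Stream<A>>,
}

impl<A: Clone + 'static> Stream<A> {
    fn new(head: A, tail: impl Fn() -> Stream<A> + 'static) -> Self {
        Stream { head, tail: Rc::new(tail) }
    }

    fn tail(&self) -> Stream<A> {
        (self.tail)()
    }
}

// Hypothetical helper: naturals from `n`, bumping `calls` each time a tail closure runs.
fn counted_from(n: i64, calls: Rc<Cell<usize>>) -> Stream<i64> {
    Stream::new(n, move || {
        calls.set(calls.get() + 1);
        counted_from(n + 1, calls.clone())
    })
}

fn main() {
    let calls = Rc::new(Cell::new(0));
    let s = counted_from(1, calls.clone());
    println!("after construction: {}", calls.get()); // 0: nothing forced yet

    let second = s.tail().head;
    println!("second = {second}, calls = {}", calls.get()); // second = 2, calls = 1

    // Same suffix again: recomputed from scratch, not cached.
    let again = s.tail().head;
    println!("again = {again}, calls = {}", calls.get()); // again = 2, calls = 2
}
```

If repeated traversal matters, a memoizing tail (for example one that caches its result after the first call) would avoid the recomputation; the listing does not do this.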
`extend` computes `f` on the current suffix right away and defers every later position until the result's tail is forced:
pub fn extend<B: Clone + 'static>(&self, f: Rc<dyn Fn(&Stream<A>) -> B>) -> Stream<B> {
 let head_val = f(self);           // apply f to the current suffix
 let tail_fn = self.tail.clone();  // share the tail thunk: an Rc bump, no evaluation
 // Extend the tail only when the result's tail is forced
 Stream::new(head_val, move || tail_fn().extend(f.clone()))
}
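To see how much work `extend` does up front, this sketch counts applications of `f` with a `Cell` (the counter and the `doubled_with_counter` helper are illustrative additions). It uses an `extend` that shares the tail thunk instead of forcing it, and a `take` that forces only the tails it returns:

```rust
use std::cell::Cell;
use std::rc::Rc;

#[derive(Clone)]
struct Stream<A: Clone> {
    head: A,
    tail: Rc<dyn Fn() -> Stream<A>>,
}

impl<A: Clone + 'static> Stream<A> {
    fn new(head: A, tail: impl Fn() -> Stream<A> + 'static) -> Self {
        Stream { head, tail: Rc::new(tail) }
    }

    fn tail(&self) -> Stream<A> {
        (self.tail)()
    }

    // f runs once now (for the head); later positions run only when forced.
    fn extend<B: Clone + 'static>(&self, f: Rc<dyn Fn(&Stream<A>) -> B>) -> Stream<B> {
        let head_val = f(self);
        let tail_fn = self.tail.clone();
        Stream::new(head_val, move || tail_fn().extend(f.clone()))
    }

    // First n elements; forces only the tails it needs.
    fn take(&self, n: usize) -> Vec<A> {
        let mut out = Vec::with_capacity(n);
        let mut cur = self.clone();
        for i in 0..n {
            if i > 0 {
                cur = cur.tail();
            }
            out.push(cur.head.clone());
        }
        out
    }
}

fn from(n: i64) -> Stream<i64> {
    Stream::new(n, move || from(n + 1))
}

// Hypothetical helper: extend `s` with "double the head", counting every call of f.
fn doubled_with_counter(s: &Stream<i64>, evals: Rc<Cell<usize>>) -> Stream<i64> {
    s.extend(Rc::new(move |st: &Stream<i64>| {
        evals.set(evals.get() + 1);
        st.head * 2
    }))
}

fn main() {
    let evals = Rc::new(Cell::new(0));
    let doubled = doubled_with_counter(&from(1), evals.clone());
    println!("after extend: {}", evals.get()); // 1: only the head position
    println!("{:?}", doubled.take(4)); // [2, 4, 6, 8]
    println!("after take(4): {}", evals.get()); // 4
}
```

Calling `take(4)` a second time costs three more calls of `f`: the head value is stored in the cell, but the tail closures are re-run rather than cached.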
Moving average (it needs a 3-element context, so it is only possible with `extend`, not `map`):
let avg = s.extend(Rc::new(|st: &Stream<i64>| {
 let t = st.tail();
 let tt = t.tail();
 (st.head + t.head + tt.head) / 3  // st has access to its own context
}));
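The window width is just a parameter of the function handed to `extend`. A sketch with a hypothetical `window_avg(k)` that builds the averaging function for any width `k`; with `k = 1` it reads only the head, so it behaves like `extract` and, by the second law, leaves the stream unchanged:

```rust
use std::rc::Rc;

#[derive(Clone)]
struct Stream<A: Clone> {
    head: A,
    tail: Rc<dyn Fn() -> Stream<A>>,
}

impl<A: Clone + 'static> Stream<A> {
    fn new(head: A, tail: impl Fn() -> Stream<A> + 'static) -> Self {
        Stream { head, tail: Rc::new(tail) }
    }

    fn tail(&self) -> Stream<A> {
        (self.tail)()
    }

    fn extend<B: Clone + 'static>(&self, f: Rc<dyn Fn(&Stream<A>) -> B>) -> Stream<B> {
        let head_val = f(self);
        let tail_fn = self.tail.clone();
        Stream::new(head_val, move || tail_fn().extend(f.clone()))
    }

    fn take(&self, n: usize) -> Vec<A> {
        let mut out = Vec::with_capacity(n);
        let mut cur = self.clone();
        for i in 0..n {
            if i > 0 {
                cur = cur.tail();
            }
            out.push(cur.head.clone());
        }
        out
    }
}

fn from(n: i64) -> Stream<i64> {
    Stream::new(n, move || from(n + 1))
}

// Hypothetical helper: average of the current element and the next k - 1.
fn window_avg(k: usize) -> Rc<dyn Fn(&Stream<i64>) -> i64> {
    assert!(k > 0, "window must be non-empty");
    Rc::new(move |st: &Stream<i64>| {
        let mut cur = st.clone();
        let mut sum = cur.head;
        for _ in 1..k {
            cur = cur.tail();
            sum += cur.head;
        }
        sum / k as i64 // integer division, as in the 3-element version
    })
}

fn main() {
    let s = from(1);
    println!("{:?}", s.extend(window_avg(3)).take(4)); // [2, 3, 4, 5]
    println!("{:?}", s.extend(window_avg(2)).take(4)); // [1, 2, 3, 4]
    println!("{:?}", s.extend(window_avg(1)).take(4)); // [1, 2, 3, 4]
}
```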
Law verification (Law 2 compares only the first `n` elements, since the streams are infinite):
// Law 1: extract(extend f s) == f(s)
fn check_law1(s: &Stream<i64>, f: Rc<dyn Fn(&Stream<i64>) -> i64>) -> bool {
 s.extend(f.clone()).extract() == f(s)
}

// Law 2: extend extract == id (first N elements equal)
fn check_law2(s: &Stream<i64>, n: usize) -> bool {
 s.extend(Rc::new(|st: &Stream<i64>| st.extract())).take(n) == s.take(n)
}
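The `check_law3` in the full listing uses an `f` and a `g` that read only the head, where associativity reduces to fusing two maps. A stronger check uses functions that read context. In this sketch (the helpers `squares_from` and `law3_sides` are illustrative names), `g` adds each element to its successor and `f` takes the forward difference, so the two-pass and fused pipelines must agree on the squares 0, 1, 4, 9, and so on:

```rust
use std::rc::Rc;

#[derive(Clone)]
struct Stream<A: Clone> {
    head: A,
    tail: Rc<dyn Fn() -> Stream<A>>,
}

impl<A: Clone + 'static> Stream<A> {
    fn new(head: A, tail: impl Fn() -> Stream<A> + 'static) -> Self {
        Stream { head, tail: Rc::new(tail) }
    }

    fn tail(&self) -> Stream<A> {
        (self.tail)()
    }

    fn extend<B: Clone + 'static>(&self, f: Rc<dyn Fn(&Stream<A>) -> B>) -> Stream<B> {
        let head_val = f(self);
        let tail_fn = self.tail.clone();
        Stream::new(head_val, move || tail_fn().extend(f.clone()))
    }

    fn take(&self, n: usize) -> Vec<A> {
        let mut out = Vec::with_capacity(n);
        let mut cur = self.clone();
        for i in 0..n {
            if i > 0 {
                cur = cur.tail();
            }
            out.push(cur.head.clone());
        }
        out
    }
}

// Hypothetical helper: the squares n*n, (n+1)*(n+1), ...
fn squares_from(n: i64) -> Stream<i64> {
    Stream::new(n * n, move || squares_from(n + 1))
}

// Both sides of law 3 on the first `n` elements, with context-reading f and g.
fn law3_sides(s: &Stream<i64>, n: usize) -> (Vec<i64>, Vec<i64>) {
    // g: current element plus the next one
    let g: Rc<dyn Fn(&Stream<i64>) -> i64> = Rc::new(|st: &Stream<i64>| st.head + st.tail().head);
    // f: forward difference, next element minus current
    let f: Rc<dyn Fn(&Stream<i64>) -> i64> = Rc::new(|st: &Stream<i64>| st.tail().head - st.head);

    // LHS: two passes, extend f (extend g s)
    let lhs = s.extend(g.clone()).extend(f.clone());

    // RHS: one fused pass, extend (f . extend g) s
    let fused: Rc<dyn Fn(&Stream<i64>) -> i64> =
        Rc::new(move |st: &Stream<i64>| f(&st.extend(g.clone())));
    let rhs = s.extend(fused);

    (lhs.take(n), rhs.take(n))
}

fn main() {
    let (lhs, rhs) = law3_sides(&squares_from(0), 5);
    println!("two passes: {:?}", lhs); // [4, 8, 12, 16, 20]
    println!("fused:      {:?}", rhs); // [4, 8, 12, 16, 20]
}
```

For squares, `g` yields 2i² + 2i + 1 at position `i`, and its forward difference is 4i + 4, which is what both sides print.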

What This Unlocks

Key Differences

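| Concept | OCaml | Rust |
|---|---|---|
| Infinite stream | `type 'a stream = Cons of 'a * (unit -> 'a stream)` | `struct Stream<A> { head: A, tail: Rc<dyn Fn() -> Stream<A>> }` |
| Lazy tail | Closure `unit -> 'a stream` | `Rc<dyn Fn() -> Stream<A>>`; shared ownership makes `Stream` cheaply `Clone` (a `Box<dyn Fn>` is not `Clone`) |
| `extend` | Plain recursive function; higher-kinded abstraction goes through modules | Method taking `Rc<dyn Fn>`, so `f` can be cloned into each lazy tail |
| Sharing tails | The GC handles aliasing | `Rc` for reference counting; `Arc` for threads |
| Comonad abstraction | Module signature plus functor (OCaml has no typeclasses) | Implement methods directly; stable Rust has no higher-kinded types |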
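The table's "`Arc` for threads" can be made concrete. A minimal sketch, assuming a hypothetical `SyncStream` type that does not appear in the listing: swap `Rc` for `Arc` and require every thunk to be `Send + Sync`, after which an extended stream can be moved into a spawned thread. Nothing is memoized here either; each thread that walks the stream re-runs the thunks it forces.

```rust
use std::sync::Arc;
use std::thread;

// Hypothetical thread-safe stream: Arc replaces Rc, and every thunk is Send + Sync.
#[derive(Clone)]
struct SyncStream<A: Clone> {
    head: A,
    tail: Arc<dyn Fn() -> SyncStream<A> + Send + Sync>,
}

impl<A: Clone + Send + Sync + 'static> SyncStream<A> {
    fn new(head: A, tail: impl Fn() -> SyncStream<A> + Send + Sync + 'static) -> Self {
        SyncStream { head, tail: Arc::new(tail) }
    }

    fn tail(&self) -> SyncStream<A> {
        (self.tail)()
    }

    // Same shape as the Rc version; f must also be shareable across threads.
    fn extend<B: Clone + Send + Sync + 'static>(
        &self,
        f: Arc<dyn Fn(&SyncStream<A>) -> B + Send + Sync>,
    ) -> SyncStream<B> {
        let head_val = f(self);
        let tail_fn = self.tail.clone();
        SyncStream::new(head_val, move || tail_fn().extend(f.clone()))
    }

    fn take(&self, n: usize) -> Vec<A> {
        let mut out = Vec::with_capacity(n);
        let mut cur = self.clone();
        for i in 0..n {
            if i > 0 {
                cur = cur.tail();
            }
            out.push(cur.head.clone());
        }
        out
    }
}

fn from(n: i64) -> SyncStream<i64> {
    SyncStream::new(n, move || from(n + 1))
}

fn main() {
    let doubled = from(1).extend(Arc::new(|st: &SyncStream<i64>| st.head * 2));
    // The stream (head plus Arc'd thunks) is Send, so it can move to another thread.
    let handle = thread::spawn(move || doubled.take(4));
    println!("{:?}", handle.join().unwrap()); // [2, 4, 6, 8]
}
```

The only changes from the `Rc` version are the pointer type and the `Send + Sync` bounds; the comonad operations themselves are unchanged.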
/// Comonad Laws β€” verified on the Stream comonad.
///
/// Three comonad laws (dual of monad laws):
///
///   1. extract . extend f = f                      (left identity)
///   2. extend extract     = id                     (right identity)
///   3. extend f . extend g = extend (f . extend g) (associativity)
///
/// The Stream comonad is a natural setting:
///   Stream<A> = infinite list with a "current head"
///   extract = head (read current value)
///   extend f = apply f to every suffix of the stream
///
/// We use `Rc` to share stream tails without cloning.

use std::rc::Rc;

// ── Stream ────────────────────────────────────────────────────────────────────

/// An infinite lazy stream.
/// The tail is wrapped in Rc<dyn Fn()> to allow sharing.
#[derive(Clone)]
pub struct Stream<A: Clone> {
    pub head: A,
    tail: Rc<dyn Fn() -> Stream<A>>,
}

impl<A: Clone + 'static> Stream<A> {
    pub fn new(head: A, tail: impl Fn() -> Stream<A> + 'static) -> Self {
        Stream { head, tail: Rc::new(tail) }
    }

    pub fn tail(&self) -> Stream<A> {
        (self.tail)()
    }

    /// Comonad: extract = read the head.
    pub fn extract(&self) -> A {
        self.head.clone()
    }

    /// Comonad: extend.
    /// Apply `f` to every suffix (tail) of the stream.
    pub fn extend<B: Clone + 'static>(&self, f: Rc<dyn Fn(&Stream<A>) -> B>) -> Stream<B> {
        let head_val = f(self);
        // Share the tail thunk; the source tail is forced only when ours is.
        let tail_fn = self.tail.clone();
        Stream::new(head_val, move || tail_fn().extend(f.clone()))
    }

    /// duplicate: Stream<A> -> Stream<Stream<A>>
    /// Each position holds the stream starting there.
    pub fn duplicate(&self) -> Stream<Stream<A>> {
        // Share the tail thunk instead of forcing it eagerly.
        let tail_fn = self.tail.clone();
        Stream::new(self.clone(), move || tail_fn().duplicate())
    }

    /// Take the first `n` elements.
    pub fn take(&self, n: usize) -> Vec<A> {
        let mut result = Vec::with_capacity(n);
        let mut cur = self.clone();
        for _ in 0..n {
            result.push(cur.head.clone());
            cur = cur.tail();
        }
        result
    }
}

/// Infinite stream of natural numbers starting from `n`.
fn from(n: i64) -> Stream<i64> {
    Stream::new(n, move || from(n + 1))
}

/// Repeat a value forever.
fn repeat<A: Clone + 'static>(a: A) -> Stream<A> {
    let a2 = a.clone();
    Stream::new(a, move || repeat(a2.clone()))
}

// ── Law verification ──────────────────────────────────────────────────────────

/// Law 1: extract . extend f = f
/// i.e., (extend f s).extract() == f(s)
fn check_law1<A, B>(s: &Stream<A>, f: Rc<dyn Fn(&Stream<A>) -> B>) -> bool
where
    A: Clone + PartialEq + 'static,
    B: Clone + PartialEq + 'static,
{
    let extended = s.extend(f.clone());
    extended.extract() == f(s)
}

/// Law 2: extend extract = id (same head, checked elementwise for N elements)
fn check_law2(s: &Stream<i64>, n: usize) -> bool {
    let extended = s.extend(Rc::new(|st: &Stream<i64>| st.extract()));
    extended.take(n) == s.take(n)
}

/// Law 3: extend f . extend g = extend (f . extend g)
/// Checked on the first `n` elements for one fixed pair f, g (both read only the head).
fn check_law3(s: &Stream<i64>, n: usize) -> bool {
    let g: Rc<dyn Fn(&Stream<i64>) -> i64> = Rc::new(|st: &Stream<i64>| st.head * 2);
    let f: Rc<dyn Fn(&Stream<i64>) -> i64> = Rc::new(|st: &Stream<i64>| st.head + 10);

    // LHS: extend f (extend g s)
    let g2 = g.clone();
    let f2 = f.clone();
    let extended_g = s.extend(g2);
    let lhs = extended_g.extend(f2);

    // RHS: extend (f . extend g) s
    let g3 = g.clone();
    let f3 = f.clone();
    let rhs = s.extend(Rc::new(move |st: &Stream<i64>| {
        f3(&st.extend(g3.clone()))
    }));

    lhs.take(n) == rhs.take(n)
}

fn main() {
    println!("=== Comonad Laws Verified on Stream ===\n");
    println!("Stream comonad: extract = head, extend = apply f to every suffix.\n");

    let s = from(1);
    println!("Stream from 1: {:?}", s.take(8));

    // extract
    println!("extract = {}", s.extract());

    // extend: double each element
    let doubled = s.extend(Rc::new(|st: &Stream<i64>| st.head * 2));
    println!("extend (*2): {:?}", doubled.take(5));

    // extend: moving average of 3 consecutive elements
    let avg = s.extend(Rc::new(|st: &Stream<i64>| {
        let t = st.tail();
        let tt = t.tail();
        (st.head + t.head + tt.head) / 3
    }));
    println!("3-element moving average: {:?}", avg.take(6));

    // Verify Law 1
    let f: Rc<dyn Fn(&Stream<i64>) -> i64> = Rc::new(|st: &Stream<i64>| st.head * 3);
    let law1 = check_law1(&s, f);
    println!("\nLaw 1 (extract . extend f = f):       {}", if law1 { "✓" } else { "✗" });

    // Verify Law 2
    let law2 = check_law2(&s, 10);
    println!("Law 2 (extend extract = id):           {}", if law2 { "✓" } else { "✗" });

    // Verify Law 3
    let law3 = check_law3(&s, 5);
    println!("Law 3 (extend f . extend g = extend…): {}", if law3 { "✓" } else { "✗" });

    // duplicate
    let dup = s.duplicate();
    println!("\nduplicate: head of head = {}, head of (head.tail) = {}",
        dup.head.head, dup.head.tail().head);

    // Repeat stream
    let rs = repeat(42_i32);
    println!("\nrepeat(42): {:?}", rs.take(5));

    println!();
    println!("Key insight: extend is NOT the same as map.");
    println!("extend f gives each element FULL CONTEXT (its suffix),");
    println!("map gives only the element itself.");
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stream_extract() {
        let s = from(5);
        assert_eq!(s.extract(), 5);
    }

    #[test]
    fn test_stream_take() {
        let s = from(1);
        assert_eq!(s.take(5), vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_law1_extract_extend() {
        let s = from(10);
        let f: Rc<dyn Fn(&Stream<i64>) -> i64> = Rc::new(|st: &Stream<i64>| st.head + 100);
        assert!(check_law1(&s, f));
    }

    #[test]
    fn test_law2_extend_extract_id() {
        let s = from(0);
        assert!(check_law2(&s, 8));
    }

    #[test]
    fn test_law3_associativity() {
        let s = from(1);
        assert!(check_law3(&s, 5));
    }

    #[test]
    fn test_extend_double() {
        let s = from(1);
        let doubled = s.extend(Rc::new(|st: &Stream<i64>| st.head * 2));
        assert_eq!(doubled.take(4), vec![2, 4, 6, 8]);
    }
}
(* Comonad laws (dual of monad laws):
   1. extract . extend f = f
   2. extend extract = id
   3. extend f . extend g = extend (f . extend g)
   
   Verified on the Stream comonad (infinite list with focus) *)

(* Lazy stream *)
type 'a stream = Cons of 'a * (unit -> 'a stream)

let rec from n = Cons (n, fun () -> from (n + 1))
let head (Cons (x, _)) = x
let tail (Cons (_, f)) = f ()

let take n s =
  let rec go n s acc =
    if n = 0 then List.rev acc
    else let Cons (x, f) = s in go (n-1) (f ()) (x :: acc)
  in go n s []

(* Stream is a comonad *)
let extract = head

let rec extend s f =
  Cons (f s, fun () -> extend (tail s) f)

let duplicate s = extend s (fun x -> x)

(* Verify law 1: extract . extend f = f *)
let check_law1 f s =
  let lhs = extract (extend s f) in
  let rhs = f s in
  lhs = rhs

(* Verify law 2: extend extract = id (on first element) *)
let check_law2 s =
  let s' = extend s extract in
  head s = head s'

let () =
  let s = from 1 in

  let f s = head s * 2 in
  assert (check_law1 f s);
  Printf.printf "Law 1 (extract . extend f = f): holds\n";

  assert (check_law2 s);
  Printf.printf "Law 2 (extend extract = id): holds\n";

  (* extend: compute moving average *)
  let avg s = (head s + head (tail s) + head (tail (tail s))) / 3 in
  let avgs = extend s avg in
  Printf.printf "Moving avg (first 5): [%s]\n"
    (take 5 avgs |> List.map string_of_int |> String.concat ";")
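Each law in the listings is checked with one fixed function on one stream, and the OCaml version checks law 2 on the first element only. A cheap way to gain more confidence is to sweep several starting points and several context-reading functions. A sketch in Rust; the sample functions and the `sweep` helper are illustrative choices, and a pass is evidence rather than proof, since only finite prefixes are compared:

```rust
use std::rc::Rc;

#[derive(Clone)]
struct Stream<A: Clone> {
    head: A,
    tail: Rc<dyn Fn() -> Stream<A>>,
}

impl<A: Clone + 'static> Stream<A> {
    fn new(head: A, tail: impl Fn() -> Stream<A> + 'static) -> Self {
        Stream { head, tail: Rc::new(tail) }
    }
    fn tail(&self) -> Stream<A> {
        (self.tail)()
    }
    fn extract(&self) -> A {
        self.head.clone()
    }
    fn extend<B: Clone + 'static>(&self, f: Rc<dyn Fn(&Stream<A>) -> B>) -> Stream<B> {
        let head_val = f(self);
        let tail_fn = self.tail.clone();
        Stream::new(head_val, move || tail_fn().extend(f.clone()))
    }
    fn take(&self, n: usize) -> Vec<A> {
        let mut out = Vec::with_capacity(n);
        let mut cur = self.clone();
        for i in 0..n {
            if i > 0 {
                cur = cur.tail();
            }
            out.push(cur.head.clone());
        }
        out
    }
}

fn from(n: i64) -> Stream<i64> {
    Stream::new(n, move || from(n + 1))
}

type Ctx = Rc<dyn Fn(&Stream<i64>) -> i64>;

// Hypothetical sample of functions: head-only, pair sum, 3-window sum.
fn samples() -> Vec<Ctx> {
    let triple: Ctx = Rc::new(|st: &Stream<i64>| st.head * 3);
    let pair: Ctx = Rc::new(|st: &Stream<i64>| st.head + st.tail().head);
    let window: Ctx = Rc::new(|st: &Stream<i64>| {
        let t = st.tail();
        st.head + t.head + t.tail().head
    });
    vec![triple, pair, window]
}

// All three laws, for every start in `starts` and every pair (f, g) of samples.
fn sweep(starts: std::ops::Range<i64>, n: usize) -> bool {
    let fs = samples();
    let extract: Ctx = Rc::new(|st: &Stream<i64>| st.extract());
    for k in starts {
        let s = from(k);
        // Law 2: extend extract = id (first n elements)
        if s.extend(extract.clone()).take(n) != s.take(n) {
            return false;
        }
        for f in fs.iter().cloned() {
            // Law 1: extract (extend f s) = f s
            if s.extend(f.clone()).extract() != f(&s) {
                return false;
            }
            for g in fs.iter().cloned() {
                // Law 3: extend f (extend g s) = extend (f . extend g) s
                let lhs = s.extend(g.clone()).extend(f.clone());
                let f2 = f.clone();
                let fused: Ctx = Rc::new(move |st: &Stream<i64>| f2(&st.extend(g.clone())));
                if lhs.take(n) != s.extend(fused).take(n) {
                    return false;
                }
            }
        }
    }
    true
}

fn main() {
    println!("laws hold on sampled prefixes: {}", sweep(0..10, 6)); // true
}
```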