🦀 Functional Rust

248: Kan Extensions

Difficulty: 5 | Level: Master

The most universal constructions in category theory: every concept is a Kan extension.

The Problem This Solves

You have a list `[1, 2, 3]` and you apply `.map(f).map(g)`. With an eager list `map` (such as OCaml's `List.map`), that's two passes over the data plus an intermediate list. The Yoneda lemma tells you these can be fused into one pass: `.map(|x| g(f(x)))`. Kan extensions generalize the construction behind this fusion. The Yoneda representation is the special case of a right Kan extension along the identity functor, and it works for any functor, not just lists. The codensity monad (the right Kan extension of a functor along itself) gives you the analogous trick for monadic bind chains.

A long chain of left-nested binds, `((m >>= f) >>= g) >>= h`, can cost O(n²) when each bind re-traverses the intermediate structure that the previous binds just built. The textbook case is a free monad. Representing the computation as a codensity continuation reassociates every bind to the right. Each bind then only composes closures in O(1), and the structure is traversed when the computation is finally run. This matters in practice whenever you build query engines, effect systems, or parser combinators where operations are composed before being run.
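Here is a minimal sketch of the fusion idea. It assumes a hypothetical `Fused` wrapper that is not part of the original code. `Fused` stores the source list plus one pending function, and each `map` only composes closures. Strictly speaking, pairing data with a pending function is the dual trick, Coyoneda (`Lan Id F`). Rust can express it directly because it needs no `forall r`.

```rust
/// Hypothetical wrapper: a source Vec<B> plus a pending function B -> A.
/// This mirrors Coyoneda (Lan Id Vec); the existential element type is
/// simply exposed here as the type parameter `B`.
struct Fused<B, A> {
    source: Vec<B>,
    pending: Box<dyn Fn(B) -> A>,
}

impl<B: 'static> Fused<B, B> {
    /// Start with the identity as the pending function.
    fn new(source: Vec<B>) -> Self {
        Fused { source, pending: Box::new(|x| x) }
    }
}

impl<B: 'static, A: 'static> Fused<B, A> {
    /// `map` composes closures; the data is not touched here.
    fn map<C: 'static>(self, g: impl Fn(A) -> C + 'static) -> Fused<B, C> {
        let f = self.pending;
        Fused {
            source: self.source,
            pending: Box::new(move |x| g(f(x))),
        }
    }

    /// The single pass: apply the composed function to every element once.
    fn run(self) -> Vec<A> {
        self.source.into_iter().map(self.pending).collect()
    }
}

fn main() {
    let fused = Fused::new(vec![1, 2, 3]).map(|x| x * 2).map(|x| x + 1).run();

    // Two eager passes with an intermediate Vec, for comparison.
    let two_pass: Vec<i32> = vec![1, 2, 3]
        .into_iter()
        .map(|x| x * 2)
        .collect::<Vec<_>>()
        .into_iter()
        .map(|x| x + 1)
        .collect();

    assert_eq!(fused, two_pass);
    println!("{:?}", fused); // [3, 5, 7]
}
```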

The Intuition

Kan extensions answer the question: "I have a functor F from category C to E, and another functor K from C to D. I want to extend F along K to get a functor from D to E. What's the best approximation?" There are two dual answers: the right Kan extension `Ran_K F` and the left Kan extension `Lan_K F`.

The Yoneda lemma as a special case: when `K = Id` (the identity functor), `Ran Id F ≅ F`. In types, `forall r. (a -> r) -> F(r)` is naturally isomorphic to `F(a)`. This means any functor `F` can be represented as a continuation, and mapping over the continuation fuses automatically. Each `map` just precomposes a function, and the data is traversed once.

The codensity monad: `Ran F F` is a monad whenever `F` is a functor. For `F = Vec`, the codensity monad represents list computations as continuations. Multiple `flat_map` calls compose as continuation composition, costing O(1) per bind. The work of producing the final list happens once, when `lower()` runs.

Think of it like `StringBuilder` vs string concatenation. Naive string concatenation is O(n²) because each `+` copies everything accumulated so far. `StringBuilder` buffers all the operations and concatenates once. Codensity is the monadic version of this pattern.
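Codensity's bind can move binds to the right because the list monad satisfies the associativity law. The sketch below checks that the left-nested and right-nested forms agree. It uses a hypothetical `bind` helper on plain `Vec`, not `Codensity`. Only the left-nested form builds the full intermediate `Vec` before the second function runs.

```rust
/// Hypothetical helper: the list monad's bind on Vec (flat_map, then collect).
fn bind<A, B>(m: Vec<A>, f: impl Fn(A) -> Vec<B>) -> Vec<B> {
    m.into_iter().flat_map(f).collect()
}

fn main() {
    let f = |x: i64| vec![x, x * 10];
    let g = |y: i64| if y > 5 { vec![y] } else { vec![] };

    // Left-nested: bind(xs, f) is fully built as [1, 10, 2, 20, 3, 30], then g consumes it.
    let left = bind(bind(vec![1, 2, 3], f), g);

    // Right-nested: each element goes through f and then g before the next element starts.
    let right = bind(vec![1, 2, 3], |x| bind(f(x), g));

    assert_eq!(left, right); // monad associativity law
    println!("{:?}", left); // [10, 20, 30]
}
```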

How It Works in Rust

The codensity monad for `Vec`:
pub struct Codensity<A: 'static> {
 // Represents: forall r. (A -> Vec<r>) -> Vec<r>
 // Stored as a one-shot closure (FnOnce) specialized to R = A
 run: Box<dyn FnOnce(&dyn Fn(A) -> Vec<A>) -> Vec<A>>,
}

impl<A: Clone + 'static> Codensity<A> {
 // Lift: wrap a Vec<A> as a lazy continuation
 pub fn lift(xs: Vec<A>) -> Self {
     Codensity { run: Box::new(move |k| xs.into_iter().flat_map(k).collect()) }
 }

 // Lower: run the continuation with the list monad's return (|x| vec![x]) to get the Vec back
 pub fn lower(self) -> Vec<A> {
     (self.run)(&|x| vec![x])
 }
}

// Bind composes continuations: no intermediate list is built at bind time
pub fn bind_codensity<A: Clone + 'static>(
 m: Codensity<A>,
 f: impl Fn(A) -> Codensity<A> + 'static,
) -> Codensity<A> {
 Codensity {
     run: Box::new(move |k| {
         (m.run)(&|a| {
             let inner = f(a);
             (inner.run)(k)  // compose: k is passed through, not applied twice
         })
     }),
 }
}
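The sketch below makes the O(1)-per-bind claim concrete by building a left-nested chain in a loop. So that it compiles on its own, it restates this section's `Codensity` without the unused `Clone` bound. Each `bind_codensity` call only boxes a closure. All the list work happens inside the final `lower()`, and there each inner `lift` still collects a small `Vec`.

```rust
/// Codensity for Vec, specialised to R = A (same shape as in this section).
struct Codensity<A: 'static> {
    run: Box<dyn FnOnce(&dyn Fn(A) -> Vec<A>) -> Vec<A>>,
}

impl<A: 'static> Codensity<A> {
    fn lift(xs: Vec<A>) -> Self {
        Codensity { run: Box::new(move |k| xs.into_iter().flat_map(k).collect()) }
    }

    fn lower(self) -> Vec<A> {
        (self.run)(&|x| vec![x])
    }
}

fn bind_codensity<A: 'static>(
    m: Codensity<A>,
    f: impl Fn(A) -> Codensity<A> + 'static,
) -> Codensity<A> {
    Codensity {
        run: Box::new(move |k| {
            (m.run)(&|a: A| {
                let inner = f(a);
                (inner.run)(k)
            })
        }),
    }
}

fn main() {
    let step = |x: u32| Codensity::lift(vec![x, x + 1]);

    // Build ((lift [1] >>= step) >>= step) >>= step, left-nested, in a loop.
    // Each bind_codensity call only boxes a closure; no Vec is touched yet.
    let mut m = Codensity::lift(vec![1_u32]);
    for _ in 0..3 {
        m = bind_codensity(m, step);
    }
    let via_codensity = m.lower();

    // The same chain done directly on Vec: one full intermediate Vec per step.
    let mut direct = vec![1_u32];
    for _ in 0..3 {
        direct = direct.into_iter().flat_map(|x| vec![x, x + 1]).collect();
    }

    assert_eq!(via_codensity, direct);
    println!("{:?}", via_codensity); // [1, 2, 2, 3, 2, 3, 3, 4]
}
```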
Yoneda round-trip (right Kan extension along identity):
// to_ran converts Vec<A> to its Yoneda representation
fn to_ran<A: Clone>(xs: Vec<A>) -> impl Fn(&dyn Fn(A) -> A) -> Vec<A> {
 move |k| xs.iter().cloned().map(k).collect()
}

// from_ran recovers Vec<A> by applying the identity function
fn from_ran<A: Clone>(ran: impl Fn(&dyn Fn(A) -> A) -> Vec<A>) -> Vec<A> {
 ran(&|x| x)
}

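fn main() {
    // One pass with the composite |x| x * 2 + 1: the same result as
    // .map(|x| x * 2).map(|x| x + 1), with the two steps composed by hand
    let fused = to_ran(vec![1, 2, 3, 4, 5])(&|x| x * 2 + 1);
    assert_eq!(fused, vec![3, 5, 7, 9, 11]);
    assert_eq!(from_ran(to_ran(vec![1, 2, 3])), vec![1, 2, 3]);
}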
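The `to_ran` above fixes `r = a`, so it cannot express `forall r`, and its fusion is done by hand. A trait with a generic method is one workaround that keeps `R` free, because `run` must work for every `R` the caller chooses. The minimal sketch below uses hypothetical names (`Yoneda`, `FromVec`, `Mapped`, `map_yoneda`, `lower`) that are not part of the original code. In it, `map` precomposes the continuation, so two maps cost one traversal.

```rust
use std::marker::PhantomData;

/// `forall r. (A -> r) -> Vec<r>`, encoded as a trait with a generic method.
trait Yoneda<A> {
    fn run<R, K: Fn(A) -> R>(&self, k: K) -> Vec<R>;
}

/// to_ran without fixing R: wrap a Vec.
struct FromVec<A>(Vec<A>);

impl<A: Clone> Yoneda<A> for FromVec<A> {
    fn run<R, K: Fn(A) -> R>(&self, k: K) -> Vec<R> {
        self.0.iter().cloned().map(k).collect()
    }
}

/// A pending map: stores the inner value and the function, traverses nothing.
struct Mapped<A, Y, F> {
    inner: Y,
    f: F,
    _a: PhantomData<A>,
}

impl<A, B, Y, F> Yoneda<B> for Mapped<A, Y, F>
where
    Y: Yoneda<A>,
    F: Fn(A) -> B,
{
    fn run<R, K: Fn(B) -> R>(&self, k: K) -> Vec<R> {
        // Precompose f into the continuation, then hand it to the inner value.
        self.inner.run(|a| k((self.f)(a)))
    }
}

fn map_yoneda<A, B, Y: Yoneda<A>, F: Fn(A) -> B>(y: Y, f: F) -> Mapped<A, Y, F> {
    Mapped { inner: y, f, _a: PhantomData }
}

/// from_ran: instantiate R = A with the identity.
fn lower<A, Y: Yoneda<A>>(y: &Y) -> Vec<A> {
    y.run(|a| a)
}

fn main() {
    let ys = map_yoneda(map_yoneda(FromVec(vec![1, 2, 3]), |x: i32| x * 2), |x: i32| x + 1);
    let out: Vec<i32> = lower(&ys);
    assert_eq!(out, vec![3, 5, 7]);

    // R is genuinely free: run the same value at a different result type.
    let strings = ys.run(|x: i32| x.to_string());
    assert_eq!(strings, vec!["3", "5", "7"]);
}
```

The trade-off is that a trait with a generic method is not object-safe, so these values cannot be boxed as `dyn Yoneda<A>`.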

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Codensity type | `type 'a codensity = { run: 'r. ('a -> 'r list) -> 'r list }` | `struct Codensity<A> { run: Box<dyn FnOnce(...)> }`; no rank-2 types, so R is specialized to A |
| Rank-2 polymorphism | First-class via polymorphic record fields (`run : 'r. ...`) | Higher-ranked binders exist only over lifetimes (`for<'a>`); binders over types are unstable. Workaround: specialize R or use trait objects |
| Yoneda `forall r` | Handled by the type system natively | Must fix R or use `Box<dyn Fn>` trait objects |
| Bind fusion | Automatic via continuation composition | Same structure; requires `FnOnce` + `Box` because each continuation runs once |
| Left Kan (Lan) | GADT with an existential, specialized per functor since OCaml has no higher-kinded type variables, e.g. `type 'a lan = Lan : ('b -> 'a) * 'b list -> 'a lan` | Needs existential types; achievable with trait objects |
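The last row says Lan needs existential types. Take `Lan Id Vec (a) = ∃b. (b -> a) × Vec<b>`, also known as Coyoneda. Here a trait object can hide `b`. The minimal sketch below uses hypothetical names (`LanVec`, `Pair`, `lan`). It packs a `Vec<B>` with a function `B -> A` and erases `B`, so two values with different hidden element types share one type.

```rust
/// Lan Id Vec (a) = ∃b. (b -> a) × Vec<b>; the trait object hides b.
trait LanVec<A> {
    /// Consume the packed pair and apply the function (lowering to Vec<A>).
    fn lower(self: Box<Self>) -> Vec<A>;
}

/// One concrete witness: a Vec<B> and a function B -> A.
struct Pair<B, A> {
    source: Vec<B>,
    f: Box<dyn Fn(B) -> A>,
}

impl<B, A> LanVec<A> for Pair<B, A> {
    fn lower(self: Box<Self>) -> Vec<A> {
        let Pair { source, f } = *self;
        source.into_iter().map(f).collect()
    }
}

/// Pack a (Vec<B>, B -> A) pair, erasing B from the resulting type.
fn lan<B: 'static, A: 'static>(
    source: Vec<B>,
    f: impl Fn(B) -> A + 'static,
) -> Box<dyn LanVec<A>> {
    Box::new(Pair { source, f: Box::new(f) })
}

fn main() {
    // Same type, different hidden element types (i32 and bool).
    let items: Vec<Box<dyn LanVec<String>>> = vec![
        lan(vec![1, 2], |n: i32| format!("#{n}")),
        lan(vec![true, false], |b: bool| b.to_string()),
    ];
    let out: Vec<Vec<String>> = items.into_iter().map(|l| l.lower()).collect();
    assert_eq!(out, vec![vec!["#1", "#2"], vec!["true", "false"]]);
}
```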
/// Kan Extensions: the most universal constructions in category theory.
///
/// "All concepts are Kan extensions." (Saunders Mac Lane)
///
/// Given functors K: C -> D and F: C -> E, the Kan extensions extend F along K:
///
///   Right Kan extension: Ran K F
///     Ran_K F (d) = ∀c. D(d, Kc) -> Fc   (natural in d)
///
///   Left Kan extension: Lan K F
///     Lan_K F (d) = ∃c. D(Kc, d) × Fc   (natural in d)
///
/// ── Codensity Monad ──────────────────────────────────────────────────────
/// The right Kan extension of F along itself: Ran_F F
///   Ran_F F (a) = ∀r. (a -> F r) -> F r
///
/// When F = [], this is a monad on lists that can improve the asymptotic cost
/// of left-nested binds (from quadratic to linear): the "codensity optimization."
///
/// The Yoneda lemma is Ran_Id F(a) ≅ F(a) (right Kan ext along identity).

// ── Codensity Monad (Ran F F for F = Vec) ────────────────────────────────

/// Codensity<A>: a list computation producing values of type A, represented
/// as a continuation: `∀r. (A -> Vec<r>) -> Vec<r>`.
///
/// Rust closures cannot be polymorphic in r, so r is fixed to A (see `run`).
/// The key: bind composes continuations instead of building intermediate lists.
pub struct Codensity<A: 'static> {
    /// The continuation: given k: A -> Vec<A>, produce a Vec<A>.
    /// Stored as a one-shot `Box<dyn FnOnce>`, specialised to R = A (homogeneous).
    run: Box<dyn FnOnce(&dyn Fn(A) -> Vec<A>) -> Vec<A>>,
}

impl<A: Clone + 'static> Codensity<A> {
    /// Lift: wrap a Vec<A> as a Codensity.
    pub fn lift(xs: Vec<A>) -> Self {
        Codensity {
            run: Box::new(move |k| xs.into_iter().flat_map(|x| k(x)).collect()),
        }
    }

    /// return: put a single value in codensity context.
    pub fn return_(a: A) -> Self {
        Codensity {
            run: Box::new(move |k| k(a)),
        }
    }

    /// Lower: run the continuation with the list monad's return (`|x| vec![x]`),
    /// recovering Vec<A>.
    pub fn lower(self) -> Vec<A> {
        (self.run)(&|x| vec![x])
    }
}

/// Bind for Codensity: this is where the magic happens.
/// Instead of building intermediate lists, we compose continuations.
pub fn bind_codensity<A: Clone + 'static>(
    m: Codensity<A>,
    f: impl Fn(A) -> Codensity<A> + 'static,
) -> Codensity<A> {
    Codensity {
        run: Box::new(move |k| {
            (m.run)(&|a: A| {
                let inner = f(a);
                (inner.run)(k)
            })
        }),
    }
}

// ── Deferred bind chain (for comparison) ─────────────────────────────────

/// A simpler alternative: queue the binds as boxed steps and run them in order.
/// Unlike Codensity, each step materialises the full intermediate list, and the
/// element type is fixed to i64 (no type changes between steps).
pub struct CodenseChain {
    steps: Vec<Box<dyn FnOnce(Vec<i64>) -> Vec<i64>>>,
    initial: Vec<i64>,
}

impl CodenseChain {
    pub fn new(initial: Vec<i64>) -> Self {
        CodenseChain { steps: vec![], initial }
    }

    pub fn bind(mut self, f: impl Fn(i64) -> Vec<i64> + 'static) -> Self {
        self.steps.push(Box::new(move |xs| xs.into_iter().flat_map(|x| f(x)).collect()));
        self
    }

    pub fn run(self) -> Vec<i64> {
        let mut current = self.initial;
        for step in self.steps {
            current = step(current);
        }
        current
    }
}

// ── Right Kan Extension demonstration ────────────────────────────────────

/// Ran_Id F ≅ F (Yoneda lemma, special case).
/// Ran_Id Vec (a) = ∀r. (a -> r) -> Vec<r> ≅ Vec<a>
///
/// to_ran:   Vec<a> -> (∀r. (a->r) -> Vec<r>)   [toYoneda]
/// from_ran: (∀r. (a->r) -> Vec<r>) -> Vec<a>    [fromYoneda, via id]
/// Without rank-2 types, the Rust versions below fix r = a.
fn to_ran<A: Clone + 'static>(xs: Vec<A>) -> impl Fn(&dyn Fn(A) -> A) -> Vec<A> {
    move |k| xs.iter().cloned().map(k).collect()
}

fn from_ran<A: Clone>(ran: impl Fn(&dyn Fn(A) -> A) -> Vec<A>) -> Vec<A> {
    ran(&|x| x) // apply the identity
}

fn main() {
    println!("=== Kan Extensions ===\n");
    println!("Ran K F (d) = ∀c. D(d, Kc) -> F(c)   [right Kan extension]");
    println!("Lan K F (d) = ∃c. D(Kc, d) × F(c)    [left Kan extension]\n");
    println!("Yoneda lemma = Ran_Id F ≅ F");
    println!("Codensity monad = Ran_F F\n");

    // ── Codensity monad ──────────────────────────────────────────────────
    println!("── Codensity Monad (Ran Vec Vec) ──\n");

    let program = bind_codensity(
        Codensity::lift(vec![1_i64, 2, 3]),
        |x| bind_codensity(
            Codensity::lift(vec![10_i64, 20]),
            move |y| Codensity::return_(x + y),
        ),
    );

    let result = program.lower();
    println!("Codensity bind [1,2,3] >>= (\\x -> [10,20] >>= (\\y -> [x+y])):");
    println!("  Result: {:?}", result);

    // Direct computation for verification
    let direct: Vec<i64> = vec![1, 2, 3].into_iter()
        .flat_map(|x| vec![10_i64, 20].into_iter().map(move |y| x + y))
        .collect();
    assert_eq!(result, direct);
    println!("  Direct: {:?}", direct);
    println!("  Equal: ✓\n");

    // CodenseChain for chaining
    println!("── CodenseChain ──\n");
    let chain_result = CodenseChain::new(vec![1, 2, 3])
        .bind(|x| vec![x, x * 10])
        .bind(|x| if x > 5 { vec![x] } else { vec![] })
        .run();
    println!("chain [1,2,3] >>= [x,x*10] >>= filter(>5): {:?}", chain_result);

    // ── Yoneda = Ran_Id ──────────────────────────────────────────────────
    println!("\n── Yoneda = Ran_Id F (Right Kan ext along Identity) ──\n");

    let original = vec![1_i64, 2, 3, 4, 5];
    let ran = to_ran(original.clone());

    // Apply a transformation inside the Ran representation
    let doubled = ran(&|x| x * 2);
    println!("to_ran [1..5] applied to (*2): {:?}", doubled);

    // Round-trip via identity
    let recovered = from_ran(to_ran(original.clone()));
    assert_eq!(recovered, original);
    println!("from_ran(to_ran(xs)) = xs: {:?} ✓", recovered);

    // "Fusion" via Ran: one pass with the composite x * 2 + 1 (composed by hand)
    let fused: Vec<i64> = to_ran(original.clone())(&|x| x * 2 + 1);
    let direct2: Vec<i64> = original.iter().map(|&x| x * 2 + 1).collect();
    assert_eq!(fused, direct2);
    println!("Ran fusion: {:?} = {:?} ✓", fused, direct2);

    println!();
    println!("Key insight: a Kan extension is the 'best approximation' to extending a functor");
    println!("along another functor when no exact extension exists.");
    println!("Limits, colimits, and adjunctions can all be expressed as Kan extensions.");
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_codensity_lift_lower_roundtrip() {
        let xs = vec![1_i64, 2, 3];
        assert_eq!(Codensity::lift(xs.clone()).lower(), xs);
    }

    #[test]
    fn test_codensity_return() {
        assert_eq!(Codensity::return_(42_i64).lower(), vec![42]);
    }

    #[test]
    fn test_codensity_bind() {
        let result = bind_codensity(
            Codensity::lift(vec![1_i64, 2]),
            |x| Codensity::lift(vec![x * 10, x * 100]),
        ).lower();
        assert_eq!(result, vec![10, 100, 20, 200]);
    }

    #[test]
    fn test_ran_id_roundtrip() {
        let xs = vec![7_i64, 8, 9];
        assert_eq!(from_ran(to_ran(xs.clone())), xs);
    }

    #[test]
    fn test_ran_map() {
        let ran = to_ran(vec![1_i64, 2, 3]);
        let result = ran(&|x| x + 100);
        assert_eq!(result, vec![101, 102, 103]);
    }
}
(* Kan extensions: the most general way to extend functors.
   Right Kan extension: Ran K F a = forall b. (a -> K b) -> F b
   Left Kan extension:  Lan K F a = exists b. (K b -> a) * F b
   Every concept in category theory is a Kan extension! *)

(* Right Kan extension along the identity functor *)
(* Ran Id F a = forall b. (a -> b) -> F b
   This is the Yoneda lemma! *)

(* Codensity monad: Ran F F
   type 'a codensity = { run : 'b. ('a -> F 'b) -> F 'b } *)
type 'a codensity = { run : 'b. ('a -> 'b list) -> 'b list }

let return_c x = { run = (fun k -> k x) }

let bind_c m f = { run = (fun k -> m.run (fun a -> (f a).run k)) }

(* Lift a list computation *)
let lift_c lst = { run = (fun k -> List.concat_map k lst) }

(* Lower back to list *)
let lower_c c = c.run (fun x -> [x])

let () =
  (* Codensity improves asymptotic complexity of left-nested binds *)
  let program =
    bind_c (lift_c [1; 2; 3]) (fun x ->
    bind_c (lift_c [10; 20])  (fun y ->
    return_c (x + y)))
  in
  let result = lower_c program in
  Printf.printf "codensity: [%s]\n"
    (result |> List.map string_of_int |> String.concat ";");

  (* Verify same as direct computation *)
  let direct = List.concat_map (fun x ->
    List.map (fun y -> x + y) [10; 20]) [1; 2; 3] in
  assert (result = direct);
  Printf.printf "Kan extension (codensity) = direct computation\n"