🦀 Functional Rust

601: Coproduct / Sum Types

Difficulty: 5 Level: Master

Model "one of these alternatives" using the categorical coproduct: Rust's `enum` is a coproduct with injection and elimination morphisms.

The Problem This Solves

You often have data that can be one of several distinct alternatives: a result is either `Ok` or `Err`; an AST node is a number, variable, or operation; a network packet is a control frame or a data frame. Representing this with a struct plus boolean flags, or with a nullable union type, loses type safety: the type doesn't tell you which case you're in, so the compiler can't ensure you handle all cases.

Category theory gives this a precise formulation: the coproduct `A + B` is a type inhabited by either an `A` or a `B`, with injection functions `inl: A → A+B` and `inr: B → A+B`. The universal property says: for any type `C` and functions `f: A → C` and `g: B → C`, there is a unique function `either(f, g): A+B → C`. In Rust, this unique function is written with `match`.

Understanding coproducts as a categorical structure means you can reason about them algebraically: `Either<A, Void>` is isomorphic to `A`; `Either<A, B>` is isomorphic to `Either<B, A>`; `Either<A, Either<B, C>>` is isomorphic to `Either<Either<A, B>, C>`. These are the same algebraic identities as for addition.
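The commutativity and associativity identities can be sketched as pairs of total functions, one in each direction. This is a minimal sketch, assuming a local `Either` type like the one defined later in this section; the function names `swap`, `assoc_left`, and `assoc_right` are illustrative and not part of the original example.

```rust
// Hypothetical two-variant coproduct, mirroring the Either used in this section.
#[derive(Debug, Clone, PartialEq)]
enum Either<A, B> {
    Left(A),
    Right(B),
}

// Commutativity: Either<A, B> ≅ Either<B, A>. Applying swap twice is the identity.
fn swap<A, B>(e: Either<A, B>) -> Either<B, A> {
    match e {
        Either::Left(a) => Either::Right(a),
        Either::Right(b) => Either::Left(b),
    }
}

// Associativity, one direction: A + (B + C) -> (A + B) + C.
fn assoc_left<A, B, C>(e: Either<A, Either<B, C>>) -> Either<Either<A, B>, C> {
    match e {
        Either::Left(a) => Either::Left(Either::Left(a)),
        Either::Right(Either::Left(b)) => Either::Left(Either::Right(b)),
        Either::Right(Either::Right(c)) => Either::Right(c),
    }
}

// Associativity, inverse direction: (A + B) + C -> A + (B + C).
fn assoc_right<A, B, C>(e: Either<Either<A, B>, C>) -> Either<A, Either<B, C>> {
    match e {
        Either::Left(Either::Left(a)) => Either::Left(a),
        Either::Left(Either::Right(b)) => Either::Right(Either::Left(b)),
        Either::Right(c) => Either::Right(Either::Right(c)),
    }
}

fn main() {
    let e: Either<i32, Either<&str, bool>> = Either::Right(Either::Left("b"));
    // Round-tripping through both directions returns the original value.
    assert_eq!(assoc_right(assoc_left(e.clone())), e);

    let s: Either<i32, &str> = Either::Left(1);
    assert_eq!(swap(swap(s.clone())), s);
}
```

Each pair of functions composes to the identity in both directions, which is what "isomorphic" means here: no information is lost by re-tagging or re-nesting.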

The Intuition

The coproduct `A + B` means "exactly one of A or B, tagged so you know which". Rust's `enum` is precisely this, with enum constructors as injections and `match` as the unique elimination morphism guaranteed by the universal property. The trade-off: coproducts make adding new interpreters easy but adding new variants hard (you must update all match arms).
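The trade-off can be sketched with a small, hypothetical `Shape` enum (not from the original example): each new interpreter is just another function containing a `match`, whereas a new variant would force every existing `match` to change.

```rust
// Hypothetical example type; the names here are illustrative only.
enum Shape {
    Circle(f64),    // radius
    Rect(f64, f64), // width, height
}

// First interpreter: compute the area.
fn area(s: &Shape) -> f64 {
    match s {
        Shape::Circle(r) => std::f64::consts::PI * r * r,
        Shape::Rect(w, h) => w * h,
    }
}

// Second interpreter, added without touching Shape or area.
fn describe(s: &Shape) -> String {
    match s {
        Shape::Circle(r) => format!("circle r={}", r),
        Shape::Rect(w, h) => format!("rect {}x{}", w, h),
    }
}

fn main() {
    let shapes = [Shape::Circle(1.5), Shape::Rect(2.0, 3.0)];
    for s in &shapes {
        println!("{} has area {:.2}", describe(s), area(s));
    }
}
```

Adding a third variant to `Shape` would make both `match` expressions non-exhaustive, and the compiler would reject them until each gains a new arm.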

How It Works in Rust

// Coproduct A + B: inhabited by either A or B
enum Either<A, B> {
 Left(A),   // injection inl: A → Either<A,B>
 Right(B),  // injection inr: B → Either<A,B>
}

impl<A, B> Either<A, B> {
 // The universal property: unique morphism from two functions
 // For any C, f: A → C, g: B → C, there is a unique Either<A,B> → C
 fn either<C, F, G>(self, f: F, g: G) -> C
 where F: FnOnce(A) -> C, G: FnOnce(B) -> C {
     match self {               // match IS the unique morphism
         Either::Left(a)  => f(a),
         Either::Right(b) => g(b),
     }
 }

 // Functor map over the left side only; the right side passes through unchanged
 fn map_left<C>(self, f: impl FnOnce(A) -> C) -> Either<C, B> {
     match self {
         Either::Left(a)  => Either::Left(f(a)),
         Either::Right(b) => Either::Right(b),
     }
 }
}

// Algebraic identity: Either<A, Void> ≅ A
enum Void {}  // uninhabited: no values of type Void exist

fn from_void_either<A>(e: Either<A, Void>) -> A {
 match e {
     Either::Left(a)  => a,
     Either::Right(v) => match v {},  // exhaustively handled: Void has no variants
 }
}
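
The function above gives only one direction of `Either<A, Void> ≅ A`. The opposite direction is the left injection. A minimal sketch, repeating the definitions so it stands alone; `to_void_either` is a hypothetical name, not from the original example.

```rust
#[derive(Debug, PartialEq)]
enum Either<A, B> {
    Left(A),
    #[allow(dead_code)] // Right can never be built when B is Void
    Right(B),
}

#[derive(Debug, PartialEq)]
enum Void {} // uninhabited: no values of type Void exist

// Direction 1: Either<A, Void> -> A (same as in the text above).
fn from_void_either<A>(e: Either<A, Void>) -> A {
    match e {
        Either::Left(a) => a,
        Either::Right(v) => match v {},
    }
}

// Direction 2: A -> Either<A, Void>, which is just the left injection.
fn to_void_either<A>(a: A) -> Either<A, Void> {
    Either::Left(a)
}

fn main() {
    // Both round trips are the identity, so the two types are isomorphic.
    assert_eq!(from_void_either(to_void_either(42)), 42);
    let e: Either<&str, Void> = Either::Left("x");
    assert_eq!(to_void_either(from_void_either(e)), Either::Left("x"));
}
```

This mirrors the arithmetic identity `a + 0 = a`, with `Void` playing the role of zero.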

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Coproduct A+B | `type t = L of A \| R of B` | `enum T { L(A), R(B) }` |
| Injection `inl` | `L : A -> t` constructor | `T::L(a)` |
| Injection `inr` | `R : B -> t` constructor | `T::R(b)` |
| Elimination | `match` | `match` |
| Universal morphism | Unique function from match | Same: `match` is the canonical eliminator |
| Uninhabited type | `type void = \|` (empty) | `enum Void {}` |

#[derive(Debug,Clone,PartialEq)]
enum Either<A,B> { Left(A), Right(B) }

impl<A,B> Either<A,B> {
    fn inl(a: A) -> Self { Either::Left(a) }
    fn inr(b: B) -> Self { Either::Right(b) }

    // Universal property: the unique morphism
    fn either<C>(self, f: impl FnOnce(A)->C, g: impl FnOnce(B)->C) -> C {
        match self { Either::Left(a)=>f(a), Either::Right(b)=>g(b) }
    }

    fn bimap<C,D>(self, f: impl FnOnce(A)->C, g: impl FnOnce(B)->D) -> Either<C,D> {
        match self { Either::Left(a)=>Either::Left(f(a)), Either::Right(b)=>Either::Right(g(b)) }
    }

    fn is_left(&self)  -> bool { matches!(self, Either::Left(_))  }
    fn is_right(&self) -> bool { matches!(self, Either::Right(_)) }

    fn left(self)  -> Option<A> { match self { Either::Left(a)=>Some(a),  _=>None } }
    fn right(self) -> Option<B> { match self { Either::Right(b)=>Some(b), _=>None } }
}

// Partition a Vec<Either<A,B>> into (Vec<A>, Vec<B>)
fn partition_either<A,B>(items: Vec<Either<A,B>>) -> (Vec<A>, Vec<B>) {
    let (mut lefts, mut rights) = (vec![], vec![]);
    for item in items {
        match item { Either::Left(a)=>lefts.push(a), Either::Right(b)=>rights.push(b) }
    }
    (lefts, rights)
}

fn main() {
    let xs: Vec<Either<i32,String>> = vec![
        Either::inl(1), Either::inr("hello".into()),
        Either::inl(42), Either::inr("world".into()),
    ];
    for e in &xs {
        let desc = e.clone().either(|n| format!("int:{}", n), |s| format!("str:{}", s));
        println!("{}", desc);
    }
    let (ints, strs) = partition_either(xs);
    println!("ints: {:?}  strs: {:?}", ints, strs);
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test] fn inl_left() { assert!(Either::<i32,&str>::inl(5).is_left()); }
    #[test] fn either_map() {
        let e: Either<i32,&str> = Either::inl(5);
        assert_eq!(e.either(|n|n*2,|_|0), 10);
    }
    #[test] fn partition() {
        let v = vec![Either::Left(1),Either::Right("a"),Either::Left(2)];
        let (l,r) = partition_either(v);
        assert_eq!(l, vec![1,2]); assert_eq!(r, vec!["a"]);
    }
}
(* Coproduct (sum type) in OCaml *)
type ('a,'b) either = Left of 'a | Right of 'b

(* Injection morphisms *)
let inl : 'a -> ('a,'b) either = fun a -> Left a
let inr : 'b -> ('a,'b) either = fun b -> Right b

(* Elimination: the universal property *)
let either : ('a -> 'c) -> ('b -> 'c) -> ('a,'b) either -> 'c =
  fun f g -> function Left a -> f a | Right b -> g b

(* Bifunctor map *)
let bimap f g = function Left a -> Left (f a) | Right b -> Right (g b)

let () =
  let xs : (int, string) either list = [Left 1; Right "hello"; Left 42; Right "world"] in
  List.iter (fun e ->
    let desc = either (fun n -> Printf.sprintf "int:%d" n)
                      (fun s -> Printf.sprintf "str:%s" s) e in
    Printf.printf "%s\n" desc
  ) xs;
  let doubled = List.map (bimap (fun n->n*2) String.uppercase_ascii) xs in
  List.iter (fun e -> Printf.printf "%s " (either string_of_int Fun.id e)) doubled;
  print_newline ()