🦀 Functional Rust

604: Monad Laws — Rust Deep Dive

Difficulty: 5 · Level: Master

Encode monadic bind as a Rust trait, spot-check the three monad laws with generic test functions, and see why the laws matter when you compose and refactor monadic pipelines.

The Problem This Solves

You've used `Option::and_then` and `Result::and_then`, and you've seen the three monad laws stated informally. At the Master level the harder question is whether you can express the monad pattern generically in Rust's type system.

That is a real challenge. A generic monad abstraction needs what type theory calls higher-kinded types: the ability to abstract over type constructors such as `Option<_>` or `Result<_, E>`, not just over concrete types. Haskell supports this directly, and OCaml reaches it through its module system (signatures with `type 'a t`, plus functors). Rust has no native support. Working around that gap shows a lot about how Rust's type system differs from those languages.

Beyond the type-system puzzle there is a practical concern. Once you implement a `Monad` trait, how do you know the implementation is correct? The compiler does not enforce the laws. In ordinary code the best you can do is check them programmatically, with concrete functions that test each law on specific values. Such checks catch violations, but they are spot checks, not proofs for all inputs.

This example addresses both points: it shows how far Rust's type system can be pushed toward a monad abstraction, and it gives runnable law checks for `Option` (the same pattern carries over to `Result`).

The Intuition

Think of the `Monad` trait as a contract for a production line: `unit` puts a raw value onto the line, and `bind` hands whatever is on the line to the next station, which either produces a new wrapped value or stops the line (for `Option`, by returning `None`). The three laws are quality guarantees for that contract.

Law 1 (Left Identity): starting the line with a value and immediately handing it to station `f` is the same as running `f` directly. The startup step is transparent.
unit(a).bind(f)  ==  f(a)
Law 2 (Right Identity): if the only station is "wrap in unit", the line produces exactly what it started with. The unit wrapper is transparent.
m.bind(unit)  ==  m
Law 3 (Associativity): grouping doesn't matter. Binding `f` and then `g` in sequence gives the same result as binding one combined station that runs `f` and then binds `g` onto its output. You can refactor pipelines into sub-pipelines freely.
m.bind(f).bind(g)  ==  m.bind(|x| f(x).bind(g))
Without Law 3, extracting a sub-pipeline into its own function might change the result, which would make monads unreliable as building blocks for abstractions.
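A minimal sketch of Law 3 in practice, using only the standard library's `Option::and_then` (no custom trait yet). The station functions `parse`, `halve`, `check_small`, and the extracted helper `halve_then_check` are hypothetical names chosen for illustration. The flat pipeline matches the left side of Law 3, and the version that moves the last two stations into one helper matches the right side; the asserts spot-check that they agree on a few inputs.

```rust
/// Station 1 (hypothetical): parse a string, or stop the line.
fn parse(s: &str) -> Option<i32> {
    s.parse::<i32>().ok()
}

/// Station 2 (hypothetical): halve even numbers; odd numbers stop the line.
fn halve(x: i32) -> Option<i32> {
    if x % 2 == 0 { Some(x / 2) } else { None }
}

/// Station 3 (hypothetical): reject values of 100 or more.
fn check_small(x: i32) -> Option<i32> {
    if x < 100 { Some(x) } else { None }
}

/// Extracted sub-pipeline: stations 2 and 3 grouped into one function.
/// This is the `|x| f(x).bind(g)` shape from Law 3.
fn halve_then_check(x: i32) -> Option<i32> {
    halve(x).and_then(check_small)
}

fn main() {
    for input in ["8", "7", "400", "abc"] {
        // Left side of Law 3: m.bind(f).bind(g)
        let flat = parse(input).and_then(halve).and_then(check_small);
        // Right side of Law 3: m.bind(|x| f(x).bind(g)), via the helper
        let regrouped = parse(input).and_then(halve_then_check);
        assert_eq!(flat, regrouped);
        println!("{input:>4} -> {flat:?}");
    }
}
```

Because `Option::and_then` satisfies associativity, the helper can be extracted without re-testing every caller's grouping.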

How It Works in Rust

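A `Monad` trait using associated types
Rust can't directly express "a type constructor `M<_>` where `M<A>` and `M<B>` are related." The workaround is a generic associated type (GAT, stable since Rust 1.65), `Wrapped<B>`, that names the type `bind` produces. This only approximates higher-kinded types: nothing in the trait forces `Wrapped<B>` to use the same constructor as `Self`, so each impl has to choose it consistently (for `Option<A>`, `Wrapped<B>` is `Option<B>`).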
```rust
trait Monad: Sized {
    type Inner;                        // the A in Option<A>
    type Wrapped<B>: Monad<Inner = B>; // the Option<B> produced by bind
    fn unit(a: Self::Inner) -> Self;
    fn bind<B>(self, f: impl FnOnce(Self::Inner) -> Self::Wrapped<B>) -> Self::Wrapped<B>;
}
```
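The smallest type that fits this trait is a wrapper that adds no effect at all. The sketch below repeats the trait so it compiles on its own and implements it for a hypothetical `Identity<A>` struct (not a standard-library type). Its `bind` simply unwraps and applies `f`, so all three laws hold by construction; the asserts spot-check them. It assumes Rust 1.65 or later for generic associated types.

```rust
// The trait from above, repeated so this sketch is self-contained.
trait Monad: Sized {
    type Inner;
    type Wrapped<B>: Monad<Inner = B>;
    fn unit(a: Self::Inner) -> Self;
    fn bind<B>(self, f: impl FnOnce(Self::Inner) -> Self::Wrapped<B>) -> Self::Wrapped<B>;
}

// Hypothetical wrapper: holds exactly one value, no failure, no side effects.
#[derive(Debug, PartialEq)]
struct Identity<A>(A);

impl<A> Monad for Identity<A> {
    type Inner = A;
    type Wrapped<B> = Identity<B>;
    // unit: just wrap the value.
    fn unit(a: A) -> Identity<A> {
        Identity(a)
    }
    // bind: unwrap and hand the value to the next station; nothing can stop the line.
    fn bind<B>(self, f: impl FnOnce(A) -> Identity<B>) -> Identity<B> {
        f(self.0)
    }
}

fn main() {
    let f = |x: i32| Identity(x * 10);
    let g = |x: i32| Identity(x + 1);
    // Law 1: unit(a).bind(f) == f(a)
    assert_eq!(Identity::unit(4).bind(f), f(4));
    // Law 2: m.bind(unit) == m
    assert_eq!(Identity(4).bind(Identity::unit), Identity(4));
    // Law 3: m.bind(f).bind(g) == m.bind(|x| f(x).bind(g))
    assert_eq!(Identity(4).bind(f).bind(g), Identity(4).bind(|x| f(x).bind(g)));
    println!("{:?}", Identity(4).bind(f).bind(g));
}
```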
Implementing for Option
```rust
impl<A> Monad for Option<A> {
    type Inner = A;
    type Wrapped<B> = Option<B>;
    fn unit(a: A) -> Option<A> { Some(a) }
    fn bind<B>(self, f: impl FnOnce(A) -> Option<B>) -> Option<B> {
        self.and_then(f) // bind is exactly and_then: no magic
    }
}
```
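To make "bind is and_then" concrete, here is a hand-written `bind_option` (a hypothetical helper, not part of the standard library) that spells out the `match` that `Option::and_then` performs, plus a check that the two agree on a present value, a value the next step rejects, and `None`.

```rust
/// Hypothetical helper: bind for Option written out by hand.
fn bind_option<A, B>(m: Option<A>, f: impl FnOnce(A) -> Option<B>) -> Option<B> {
    match m {
        Some(a) => f(a), // a value is present: feed it to the next step
        None => None,    // nothing to feed: short-circuit
    }
}

fn main() {
    let step = |x: i32| if x > 0 { Some(x * 2) } else { None };
    for m in [Some(5), Some(-3), None] {
        // Same result as the standard library's and_then on every input tried.
        assert_eq!(bind_option(m, step), m.and_then(step));
    }
    println!("bind_option agrees with and_then");
}
```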
Implementing for Result
```rust
// No `E: Clone` bound needed: `and_then` moves the error through unchanged.
impl<A, E> Monad for Result<A, E> {
    type Inner = A;
    type Wrapped<B> = Result<B, E>;
    fn unit(a: A) -> Result<A, E> { Ok(a) }
    fn bind<B>(self, f: impl FnOnce(A) -> Result<B, E>) -> Result<B, E> {
        self.and_then(f) // same: bind is exactly and_then
    }
}
```
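One payoff of the trait is that helpers can be written once against `Monad` instead of once per type. This sketch repeats the trait and both impls so it compiles on its own (without the `E: Clone` bound, which `and_then` does not need) and derives a map operation from `bind` and `unit` alone. The name `fmap` is borrowed from Haskell and is hypothetical here; the sketch assumes Rust 1.65 or later.

```rust
trait Monad: Sized {
    type Inner;
    type Wrapped<B>: Monad<Inner = B>;
    fn unit(a: Self::Inner) -> Self;
    fn bind<B>(self, f: impl FnOnce(Self::Inner) -> Self::Wrapped<B>) -> Self::Wrapped<B>;
}

impl<A> Monad for Option<A> {
    type Inner = A;
    type Wrapped<B> = Option<B>;
    fn unit(a: A) -> Option<A> { Some(a) }
    fn bind<B>(self, f: impl FnOnce(A) -> Option<B>) -> Option<B> { self.and_then(f) }
}

impl<A, E> Monad for Result<A, E> {
    type Inner = A;
    type Wrapped<B> = Result<B, E>;
    fn unit(a: A) -> Result<A, E> { Ok(a) }
    fn bind<B>(self, f: impl FnOnce(A) -> Result<B, E>) -> Result<B, E> { self.and_then(f) }
}

/// Hypothetical helper: map derived from bind + unit, written once for any Monad impl.
fn fmap<M: Monad, B>(m: M, f: impl FnOnce(M::Inner) -> B) -> M::Wrapped<B> {
    // The bound `Wrapped<B>: Monad<Inner = B>` is what lets us call `unit` on the output type.
    m.bind::<B>(|a| <M::Wrapped<B> as Monad>::unit(f(a)))
}

fn main() {
    // The same generic fmap works for Option...
    assert_eq!(fmap(Some(3_i32), |x: i32| x + 1), Some(4));
    assert_eq!(fmap(None::<i32>, |x: i32| x + 1), None);
    // ...and for Result, where an Err passes through untouched.
    assert_eq!(fmap(Ok::<i32, String>(3), |x: i32| x * 2), Ok(6));
    assert_eq!(fmap(Err::<i32, String>("bad".into()), |x: i32| x * 2), Err("bad".to_string()));
    println!("fmap works for both Option and Result");
}
```

The turbofish in `bind::<B>` only spells out the output type; it does not change behaviour.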
Law-checking functions for `Option`
```rust
// Law 1: unit(a).bind(f) == f(a)
fn left_identity<A: Clone, B: PartialEq>(a: A, f: impl Fn(A) -> Option<B> + Clone) -> bool {
    let left = Option::unit(a.clone()).bind(f.clone());
    let right = f(a);
    left == right
}

// Law 2: m.bind(unit) == m
fn right_identity<A: PartialEq + Clone>(m: Option<A>) -> bool {
    m.clone().bind(Option::unit) == m
}

// Law 3: m.bind(f).bind(g) == m.bind(|x| f(x).bind(g))
fn associativity<A: Clone, B: Clone, C: PartialEq>(
    m: Option<A>,
    f: impl Fn(A) -> Option<B> + Clone,
    g: impl Fn(B) -> Option<C> + Clone,
) -> bool {
    let left = m.clone().bind(f.clone()).bind(g.clone());
    let right = m.bind(move |x| f(x).bind(g.clone()));
    left == right
}
```
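The functions above are specialised to `Option`. The same three checks carry over to `Result` almost unchanged; here is a sketch written directly against the standard library's `Result::and_then`. The `_res` function names and the `String` error messages are assumptions for illustration. As with the `Option` versions, each call is a spot check on the inputs supplied, not a proof.

```rust
// Law 1 for Result: Ok(a).and_then(f) == f(a)
fn left_identity_res<A: Clone, B: PartialEq, E: PartialEq>(
    a: A,
    f: impl Fn(A) -> Result<B, E>,
) -> bool {
    Ok::<A, E>(a.clone()).and_then(&f) == f(a)
}

// Law 2 for Result: m.and_then(Ok) == m
fn right_identity_res<A: Clone + PartialEq, E: Clone + PartialEq>(m: Result<A, E>) -> bool {
    m.clone().and_then(Ok) == m
}

// Law 3 for Result: m.and_then(f).and_then(g) == m.and_then(|x| f(x).and_then(g))
fn associativity_res<A: Clone, B, C: PartialEq, E: Clone + PartialEq>(
    m: Result<A, E>,
    f: impl Fn(A) -> Result<B, E>,
    g: impl Fn(B) -> Result<C, E>,
) -> bool {
    let left = m.clone().and_then(&f).and_then(&g);
    let right = m.and_then(|x| f(x).and_then(&g));
    left == right
}

fn main() {
    // Hypothetical steps whose errors are plain Strings.
    let f = |x: i32| if x > 0 { Ok(x * 2) } else { Err(format!("{x} is not positive")) };
    let g = |x: i32| if x < 100 { Ok(x + 1) } else { Err(format!("{x} is too large")) };

    assert!(left_identity_res(5, f));
    assert!(left_identity_res(-5, f));
    assert!(right_identity_res(Ok::<i32, String>(5)));
    assert!(right_identity_res(Err::<i32, String>("boom".to_string())));
    assert!(associativity_res(Ok(5), f, g));
    assert!(associativity_res(Ok(60), f, g)); // f succeeds with 120, then g fails
    assert!(associativity_res(Err("early".to_string()), f, g));
    println!("all Result law checks passed");
}
```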
Running the checks
```rust
// Inside fn main(), with the trait, impls, and law functions above in scope.
let f = |x: i32| if x > 0 { Some(x * 2) } else { None };
let g = |x: i32| if x < 100 { Some(x + 1) } else { None };

assert!(left_identity(5, f));          // Some(5).bind(f) == f(5) == Some(10)
assert!(right_identity(Some(5)));      // Some(5).bind(Some) == Some(5)
assert!(right_identity(None::<i32>));  // None.bind(Some) == None
assert!(associativity(Some(5), f, g)); // both groupings give Some(11)
assert!(associativity(None, f, g));    // None short-circuits both sides
```
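Because the compiler accepts any function with the right signature as a unit, a law violation only shows up when you test for it. A sketch with a deliberately broken, hypothetical `bad_unit` that refuses to wrap `0`, and a hypothetical `right_identity_with` that takes the candidate unit as a parameter: the broken unit type-checks, passes at `Some(5)`, and fails at `Some(0)`. This is also why the choice of test inputs matters.

```rust
// Hypothetical, deliberately broken "unit": it refuses to wrap 0.
fn bad_unit(x: i32) -> Option<i32> {
    if x == 0 { None } else { Some(x) }
}

// Right identity check, parameterised over the candidate unit function.
fn right_identity_with(m: Option<i32>, unit: fn(i32) -> Option<i32>) -> bool {
    m.and_then(unit) == m
}

fn main() {
    // The real unit (Some) passes on every input tried.
    for m in [Some(0), Some(5), None] {
        assert!(right_identity_with(m, Some));
    }
    // The broken unit compiles, passes one input, and fails another.
    assert!(right_identity_with(Some(5), bad_unit));
    assert!(!right_identity_with(Some(0), bad_unit));
    println!("bad_unit violates right identity at Some(0)");
}
```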
Simulated do-notation with `?`
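```rust
// Rust's closest analogue to Haskell's do-notation or OCaml's let*
fn compute(s: &str) -> Option<i32> {
    let n = s.parse::<i32>().ok()?;                        // bind step 1
    let doubled = if n > 0 { Some(n * 2) } else { None }?; // bind step 2
    Some(doubled + 1)                                      // unit at the end
}
// Each `?` acts as one bind: on None it returns early from `compute`,
// otherwise the unwrapped value flows into the rest of the body.
// This is the do-notation pattern, limited to types that support `?`.
```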
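To see the correspondence directly, here is `compute` next to a hand-desugared `compute_bind` (a hypothetical name) in which each `?` becomes one `and_then` and the rest of the function body becomes the closure passed to it. This is an analogy, not how the compiler implements `?`, which uses an early `return` rather than closures.

```rust
// `compute` from above, written with `?`.
fn compute(s: &str) -> Option<i32> {
    let n = s.parse::<i32>().ok()?;
    let doubled = if n > 0 { Some(n * 2) } else { None }?;
    Some(doubled + 1)
}

// Hypothetical hand-desugared version: one and_then per `?`,
// with the remaining body nested inside each closure.
fn compute_bind(s: &str) -> Option<i32> {
    s.parse::<i32>().ok().and_then(|n| {
        let step2 = if n > 0 { Some(n * 2) } else { None };
        step2.and_then(|doubled| Some(doubled + 1)) // unit at the end
    })
}

fn main() {
    for s in ["5", "-1", "x"] {
        assert_eq!(compute(s), compute_bind(s));
        println!("{s:>2}: {:?}", compute(s));
    }
}
```

The nesting mirrors Law 3: the closures could be flattened into a chain of `and_then` calls without changing any result.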

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Generic monad | Module signature (`type 'a t`) plus functors | Trait with a generic associated type (limited HKT) |
| Higher-kinded types | Through the module system (`type 'a t` in signatures) | Workaround via `type Wrapped<B>` |
| `bind` name | `>>=` / `bind` (`Option.bind`) | `and_then` (stdlib) / `bind` (trait) |
| `return` / `unit` | An ordinary function, e.g. `let return x = Some x` | `unit` method, or `Some`/`Ok` directly |
| Law enforcement | Convention (no type-level proof) | Convention (spot-checked with `assert!` tests) |
| Do-notation | `let*` binding operators (OCaml 4.08+) | `?` operator (for `Option`/`Result`) |
Full Rust source
```rust
trait Monad: Sized {
    type Inner;
    type Wrapped<B>: Monad<Inner = B>;
    fn unit(a: Self::Inner) -> Self;
    fn bind<B>(self, f: impl FnOnce(Self::Inner) -> Self::Wrapped<B>) -> Self::Wrapped<B>;
}

impl<A> Monad for Option<A> {
    type Inner = A;
    type Wrapped<B> = Option<B>;
    fn unit(a: A) -> Option<A> { Some(a) }
    fn bind<B>(self, f: impl FnOnce(A) -> Option<B>) -> Option<B> { self.and_then(f) }
}

// No `E: Clone` bound needed: `and_then` moves the error through unchanged.
impl<A, E> Monad for Result<A, E> {
    type Inner = A;
    type Wrapped<B> = Result<B, E>;
    fn unit(a: A) -> Result<A, E> { Ok(a) }
    fn bind<B>(self, f: impl FnOnce(A) -> Result<B, E>) -> Result<B, E> { self.and_then(f) }
}

// Law verification for Option
fn left_identity<A: Clone, B: PartialEq>(a: A, f: impl Fn(A) -> Option<B> + Clone) -> bool {
    let left = Option::unit(a.clone()).bind(f.clone());
    let right = f(a);
    left == right
}

fn right_identity<A: PartialEq + Clone>(m: Option<A>) -> bool {
    let left = m.clone().bind(Option::unit);
    left == m
}

fn associativity<A: Clone, B: Clone, C: PartialEq>(
    m: Option<A>,
    f: impl Fn(A) -> Option<B> + Clone,
    g: impl Fn(B) -> Option<C> + Clone,
) -> bool {
    let left = m.clone().bind(f.clone()).bind(g.clone());
    let right = m.bind(move |x| f(x).bind(g.clone()));
    left == right
}

fn main() {
    let f = |x: i32| if x > 0 { Some(x * 2) } else { None };
    let g = |x: i32| if x < 100 { Some(x + 1) } else { None };

    println!("Left identity (5, f):        {}", left_identity(5, f));
    println!("Right identity (Some(5)):    {}", right_identity(Some(5)));
    // A bare `None` leaves A ambiguous, so name the type explicitly.
    println!("Right identity (None):       {}", right_identity(None::<i32>));
    println!("Associativity (Some(5),f,g): {}", associativity(Some(5), f, g));
    println!("Associativity (None,f,g):    {}", associativity(None, f, g));

    // Do-notation simulation via ? in functions
    fn compute(s: &str) -> Option<i32> {
        let n = s.parse::<i32>().ok()?;
        let doubled = if n > 0 { Some(n * 2) } else { None }?;
        Some(doubled + 1)
    }
    println!("compute('5')  = {:?}", compute("5"));
    println!("compute('-1') = {:?}", compute("-1"));
    println!("compute('x')  = {:?}", compute("x"));
}

#[cfg(test)]
mod tests {
    use super::*;
    fn f(x: i32) -> Option<i32> { if x > 0 { Some(x * 2) } else { None } }
    fn g(x: i32) -> Option<i32> { if x < 100 { Some(x + 1) } else { None } }
    #[test] fn test_left_id()  { assert!(left_identity(5, f)); }
    #[test] fn test_right_id() { assert!(right_identity(Some(5))); assert!(right_identity::<i32>(None)); }
    #[test] fn test_assoc()    { assert!(associativity(Some(5), f, g)); }
}
```
OCaml equivalent
```ocaml
(* Monad laws in OCaml *)
let (>>=) = Option.bind
let return x = Some x

(* Left identity: return a >>= f == f a *)
let left_identity a f = (return a >>= f) = f a

(* Right identity: m >>= return == m *)
let right_identity m = (m >>= return) = m

(* Associativity: (m >>= f) >>= g == m >>= (fun x -> f x >>= g) *)
let associativity m f g =
  ((m >>= f) >>= g) = (m >>= fun x -> f x >>= g)

let () =
  let f x = if x > 0 then Some (x*2) else None in
  let g x = if x < 100 then Some (x+1) else None in
  Printf.printf "left_identity(5,f): %b\n"     (left_identity 5 f);
  Printf.printf "right_identity(Some 5): %b\n" (right_identity (Some 5));
  Printf.printf "right_identity(None): %b\n"   (right_identity None);
  Printf.printf "associativity: %b\n"          (associativity (Some 5) f g)
```