🦀 Functional Rust

060: State Monad

Difficulty: 3 | Level: Advanced

Thread mutable state through pure functions without making everything `mut`.

The Problem This Solves

You've written a function that needs to count something: lines parsed, IDs assigned, stack depth reached. The natural Rust instinct is `&mut counter`, which is fine for one function. But what happens when you have a chain of 10 functions, and all of them need to read or update that counter?

Option 1: pass `&mut state` to every function. Every signature gains an extra parameter, and every call site has to thread it through. The deeper the call stack, the worse it gets: a function that logically takes one argument now takes three because it happens to call helpers that need state.

Option 2: put it in a global behind a `Mutex`. Now you've introduced shared mutable state and potential deadlocks, and the function is no longer pure or testable in isolation.

Option 3: the functional approach. Make state explicit in the return type, so every function that needs state returns both its result and the new state: `fn tick(state: i32) -> (i32, i32)`. This is already valid Rust! The State monad is just a clean wrapper around this pattern, so you don't have to thread the state through every call by hand.
// Without the State monad, state is threaded manually through every call:
fn count_three(state: i32) -> ((i32, i32, i32), i32) {
 let a = state;       let state = state + 1;
 let b = state;       let state = state + 1;
 let c = state;       let state = state + 1;
 ((a, b, c), state)
}
// Works! But imagine 15 steps where state is used 5 times each...
// Every intermediate variable is named `state`, rebinding it each time.
// The business logic drowns in plumbing.
The State monad wraps `fn(S) -> (A, S)` in a struct and gives you `and_then` to chain these functions while passing state automatically. This exists to solve exactly that pain.

The Intuition

Imagine every step of your computation is a machine that takes the current state, does something, and outputs both a result and the updated state. You chain these machines together: one machine's output state becomes the next machine's input state.
state₀ → [Machine A] → (resultA, state₁)
state₁ → [Machine B] → (resultB, state₂)
state₂ → [Machine C] → (resultC, state₃)
The State monad is just a way to connect these machines without manually writing `state₁`, `state₂`, `state₃` every time. You describe the chain, then "run" it by feeding in the initial state at the end.
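The "describe, then run" split shows up clearly in a minimal sketch. It assumes a cut-down `State` type with just `new`, `run`, `pure`, and `and_then`, the same shape as the type in the next section. As a simplification, it writes each machine (`tick`) directly as `n -> (n, n + 1)` rather than building it from `get` and `put`.

```rust
// Minimal State type: a boxed "machine" from S to (result, new S).
struct State<S, A> {
    run: Box<dyn FnOnce(S) -> (A, S)>,
}

impl<S: 'static, A: 'static> State<S, A> {
    fn new(f: impl FnOnce(S) -> (A, S) + 'static) -> Self {
        State { run: Box::new(f) }
    }

    // Feed in the initial state; only now does any step execute.
    fn run(self, s: S) -> (A, S) {
        (self.run)(s)
    }

    // A machine that returns `a` and passes the state through untouched.
    fn pure(a: A) -> Self {
        State::new(move |s| (a, s))
    }

    // Connect two machines: self's output state becomes the next machine's input.
    fn and_then<B: 'static>(self, f: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
        State::new(move |s| {
            let (a, s2) = self.run(s);
            f(a).run(s2)
        })
    }
}

// Machines A, B, C: result is the current counter, output state is counter + 1.
fn tick() -> State<i32, i32> {
    State::new(|n: i32| (n, n + 1))
}

// Describe the chain state0 -> A -> state1 -> B -> state2 -> C -> state3.
fn count3() -> State<i32, (i32, i32, i32)> {
    tick().and_then(|a| {
        tick().and_then(move |b| tick().and_then(move |c| State::pure((a, b, c))))
    })
}

fn main() {
    // Building the chain executes nothing; `program` is just a boxed closure.
    let program = count3();
    // Running it from state0 = 5 pushes the state through A, B, and C in order.
    let (results, final_state) = program.run(5);
    println!("{:?} {}", results, final_state); // prints "(5, 6, 7) 8"
}
```

Each step is an `FnOnce`, so a built program can run at most once. Calling `count3()` again builds a fresh copy that can be run from a different initial state.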

How It Works in Rust

// The core type: a wrapper around a state-threading function
struct State<S, A> {
    run: Box<dyn FnOnce(S) -> (A, S)>,
}

impl<S: 'static, A: 'static> State<S, A> {
    fn new(f: impl FnOnce(S) -> (A, S) + 'static) -> Self {
        State { run: Box::new(f) }
    }

    // Execute the whole chain with an initial state
    fn run(self, s: S) -> (A, S) {
        (self.run)(s)
    }

    // Wrap a plain value; the state passes through unchanged
    fn pure(a: A) -> Self {
        State::new(move |s| (a, s))
    }

    // Chain: run self, pass result + new state to f, run f
    fn and_then<B: 'static>(self, f: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
        State::new(move |s| {
            let (a, s2) = self.run(s);   // run first step, get result + updated state
            f(a).run(s2)                 // feed result + state into next step
        })
    }

    // Transform the result without touching the state (used by `tick` and `count3` below)
    fn map<B: 'static>(self, f: impl FnOnce(A) -> B + 'static) -> State<S, B> {
        State::new(move |s| {
            let (a, s2) = self.run(s);
            (f(a), s2)
        })
    }
}
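`map`, which `tick` and `count3` below rely on, does not have to be a separate primitive. It can be derived from `and_then` and `pure`. Here is a sketch of that derivation, with the core type repeated so the block compiles on its own:

```rust
struct State<S, A> {
    run: Box<dyn FnOnce(S) -> (A, S)>,
}

impl<S: 'static, A: 'static> State<S, A> {
    fn new(f: impl FnOnce(S) -> (A, S) + 'static) -> Self {
        State { run: Box::new(f) }
    }

    fn run(self, s: S) -> (A, S) {
        (self.run)(s)
    }

    fn pure(a: A) -> Self {
        State::new(move |s| (a, s))
    }

    fn and_then<B: 'static>(self, f: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
        State::new(move |s| {
            let (a, s2) = self.run(s);
            f(a).run(s2)
        })
    }

    // `map` derived from the two core operations: run self, then wrap f(a) with `pure`.
    // `and_then` threads the state; `pure` leaves it unchanged.
    fn map<B: 'static>(self, f: impl FnOnce(A) -> B + 'static) -> State<S, B> {
        self.and_then(move |a| State::pure(f(a)))
    }
}

fn main() {
    // Result = old counter * 10, state = counter + 1.
    let (value, state) = State::new(|n: i32| (n, n + 1)).map(|x| x * 10).run(4);
    println!("{} {}", value, state); // prints "40 5"
}
```

Defining `map` directly saves the extra `pure` step, and with it one `Box` allocation per call.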
// Primitives for working with state:

fn get<S: Clone + 'static>() -> State<S, S> {
 // Copy the state into the result β€” state unchanged
 State::new(|s: S| (s.clone(), s))
}

fn put<S: 'static>(new_s: S) -> State<S, ()> {
 // Replace state with new_s β€” result is ()
 State::new(move |_| ((), new_s))
}

fn modify<S: 'static>(f: impl FnOnce(S) -> S + 'static) -> State<S, ()> {
 // Apply f to state β€” result is ()
 State::new(move |s| ((), f(s)))
}
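The three primitives are not independent: `modify` can be rebuilt from `get` and `put`. The sketch below does this under a hypothetical name, `modify_via_get_put`. It also adds an equally hypothetical `gets` helper that reads a view of the state without replacing it. The minimal `State` type is repeated so the block stands alone.

```rust
struct State<S, A> {
    run: Box<dyn FnOnce(S) -> (A, S)>,
}

impl<S: 'static, A: 'static> State<S, A> {
    fn new(f: impl FnOnce(S) -> (A, S) + 'static) -> Self {
        State { run: Box::new(f) }
    }

    fn run(self, s: S) -> (A, S) {
        (self.run)(s)
    }

    fn and_then<B: 'static>(self, f: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
        State::new(move |s| {
            let (a, s2) = self.run(s);
            f(a).run(s2)
        })
    }
}

fn get<S: Clone + 'static>() -> State<S, S> {
    State::new(|s: S| (s.clone(), s))
}

fn put<S: 'static>(new_s: S) -> State<S, ()> {
    State::new(move |_| ((), new_s))
}

// Hypothetical: `modify` rebuilt from `get` and `put` (read the state, overwrite it with f(state)).
// Unlike the primitive `modify`, this needs `S: Clone` because `get` clones the state.
fn modify_via_get_put<S: Clone + 'static>(f: impl FnOnce(S) -> S + 'static) -> State<S, ()> {
    get::<S>().and_then(move |s| put(f(s)))
}

// Hypothetical helper: return a view of the state without changing it (no Clone needed).
fn gets<S: 'static, A: 'static>(f: impl FnOnce(&S) -> A + 'static) -> State<S, A> {
    State::new(move |s: S| (f(&s), s))
}

fn main() {
    // Double the counter, then ask whether it is now greater than 5.
    let program = modify_via_get_put(|n: i32| n * 2).and_then(|()| gets(|n: &i32| *n > 5));
    let (is_big, counter) = program.run(4);
    println!("{} {}", is_big, counter); // prints "true 8"
}
```

The derived version needs `S: Clone` because `get` hands out a copy of the state. That is one reason to keep `modify` as its own primitive.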
// A counter step: read current value, increment state, return old value
fn tick() -> State<i32, i32> {
 get::<i32>().and_then(|n| {
     put(n + 1).map(move |()| n)  // increment state, return old n
 })
}

// Chain three ticks β€” state threads automatically
fn count3() -> State<i32, (i32, i32, i32)> {
 tick().and_then(|a|
     tick().and_then(move |b|
         tick().map(move |c| (a, b, c))))
}

// Run it:
let ((a, b, c), final_state) = count3().run(0);
// a=0, b=1, c=2, final_state=3
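`count3` is just an ordinary value built from `tick`, so chains can also be built programmatically. That is the composability the State monad buys. The sketch below assumes a hypothetical `replicate` combinator (not part of the original code) and the same minimal `State` type, with `tick` written directly via `State::new`:

```rust
struct State<S, A> {
    run: Box<dyn FnOnce(S) -> (A, S)>,
}

impl<S: 'static, A: 'static> State<S, A> {
    fn new(f: impl FnOnce(S) -> (A, S) + 'static) -> Self {
        State { run: Box::new(f) }
    }

    fn run(self, s: S) -> (A, S) {
        (self.run)(s)
    }

    fn pure(a: A) -> Self {
        State::new(move |s| (a, s))
    }

    fn and_then<B: 'static>(self, f: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
        State::new(move |s| {
            let (a, s2) = self.run(s);
            f(a).run(s2)
        })
    }
}

fn tick() -> State<i32, i32> {
    State::new(|n: i32| (n, n + 1))
}

// Hypothetical combinator: chain `n` copies of the action produced by `make`,
// collecting every result in order.
fn replicate<S: 'static, A: 'static>(n: usize, make: fn() -> State<S, A>) -> State<S, Vec<A>> {
    // Start from an action that returns an empty Vec and leaves the state alone.
    let mut acc: State<S, Vec<A>> = State::pure(Vec::new());
    for _ in 0..n {
        // Each pass appends one more step; the state flows through all of them at run time.
        acc = acc.and_then(move |mut xs: Vec<A>| {
            make().and_then(move |x| {
                xs.push(x);
                State::pure(xs)
            })
        });
    }
    acc
}

fn main() {
    let (results, final_state) = replicate(5, tick).run(10);
    println!("{:?} {}", results, final_state); // prints "[10, 11, 12, 13, 14] 15"
}
```

Each iteration wraps the previous chain in another `and_then`. That means one `Box` per step and one level of nested calls per step at run time: fine for small `n`, but a concrete instance of the allocation cost listed in the comparison table.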
// Stack operations β€” state is Vec<i32>
fn push(x: i32) -> State<Vec<i32>, ()> {
 modify(move |mut stack: Vec<i32>| { stack.push(x); stack })
}

fn pop() -> State<Vec<i32>, Option<i32>> {
 State::new(|mut stack: Vec<i32>| {
     let val = stack.pop();
     (val, stack)
 })
}

// Chain push/pop operations:
let ops = push(1)
 .and_then(|()| push(2))
 .and_then(|()| push(3))
 .and_then(|()| pop());

let (top, remaining_stack) = ops.run(vec![]);
// top = Some(3), remaining_stack = [1, 2]
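`push` and `pop` are ordinary values, so compound stack operations can be built from them. The sketch below adds a hypothetical `add_top` operation (not part of the original example) that pops two values and pushes their sum, in the spirit of a tiny stack machine. To keep the block short, `push` is written with `State::new` directly instead of `modify`.

```rust
struct State<S, A> {
    run: Box<dyn FnOnce(S) -> (A, S)>,
}

impl<S: 'static, A: 'static> State<S, A> {
    fn new(f: impl FnOnce(S) -> (A, S) + 'static) -> Self {
        State { run: Box::new(f) }
    }

    fn run(self, s: S) -> (A, S) {
        (self.run)(s)
    }

    fn pure(a: A) -> Self {
        State::new(move |s| (a, s))
    }

    fn and_then<B: 'static>(self, f: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
        State::new(move |s| {
            let (a, s2) = self.run(s);
            f(a).run(s2)
        })
    }
}

fn push(x: i32) -> State<Vec<i32>, ()> {
    State::new(move |mut stack: Vec<i32>| {
        stack.push(x);
        ((), stack)
    })
}

fn pop() -> State<Vec<i32>, Option<i32>> {
    State::new(|mut stack: Vec<i32>| {
        let val = stack.pop();
        (val, stack)
    })
}

// Hypothetical compound operation: pop the top two values and push their sum.
// Returns Some(sum), or None (stack left as it was) if fewer than two values exist.
fn add_top() -> State<Vec<i32>, Option<i32>> {
    pop().and_then(|top| {
        pop().and_then(move |below| match (top, below) {
            (Some(x), Some(y)) => push(x + y).and_then(move |()| State::pure(Some(x + y))),
            // Only one value was present: put it back so the stack is unchanged.
            (Some(x), None) => push(x).and_then(|()| State::pure(None)),
            // Empty stack: popping changed nothing.
            _ => State::pure(None),
        })
    })
}

fn main() {
    let program = push(1)
        .and_then(|()| push(2))
        .and_then(|()| push(3))
        .and_then(|()| add_top());
    let (sum, stack) = program.run(vec![]);
    println!("{:?} {:?}", sum, stack); // prints "Some(5) [1, 5]"
}
```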
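Important: this design needs `Box<dyn FnOnce>` because every closure has its own anonymous type, so a struct that can hold any step has to store it as a boxed trait object. A `Box<dyn Trait>` with no explicit lifetime defaults to `Box<dyn Trait + 'static>`. That default is where the `'static` bounds come from, both on captured values and on `S` and `A` themselves. This is the main ergonomic cost compared with OCaml's approach.

The honest truth: for most Rust code, explicit `&mut state` is cleaner. The State monad shines when you're building a library of composable stateful operations that users can sequence however they like, such as parser combinators or game engine scripting.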
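For comparison, here is a sketch of the explicit `&mut` style applied to the counter. `tick` and `count3` reuse the example's names but take `&mut i32` instead of returning a `State`, so this is not a drop-in replacement.

```rust
// Read the counter, bump it in place, return the old value.
fn tick(counter: &mut i32) -> i32 {
    let old = *counter;
    *counter += 1;
    old
}

// Passing `counter` on reborrows the same `&mut i32` for each call.
fn count3(counter: &mut i32) -> (i32, i32, i32) {
    let a = tick(counter);
    let b = tick(counter);
    let c = tick(counter);
    (a, b, c)
}

fn main() {
    let mut counter = 0;
    let result = count3(&mut counter);
    println!("{:?} {}", result, counter); // prints "(0, 1, 2) 3"
}
```

There is no boxing and no `'static` bound, and the borrow checker still guarantees that only one caller mutates the counter at a time.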

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Type | `type ('s, 'a) state = State of ('s -> 'a * 's)` | `struct State<S, A> { run: Box<dyn FnOnce(S) -> (A, S)> }` |
| Closure storage | First-class values, no boxing needed | Requires `Box<dyn FnOnce>` (heap allocation) |
| Lifetime bounds | No `'static` requirement | Captured values in boxed closures need `'static` |
| Idiomatic? | Yes: monadic state threading is natural | Rarely: Rust usually prefers explicit `&mut state` |
| Performance | GC handles allocation | Each `and_then` allocates a new `Box`, which can be expensive in hot paths |
| When to choose | Almost always, over explicit threading | Only when composability outweighs performance (parsers, DSLs) |
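The per-step `Box` in the performance row comes from storing steps as `dyn FnOnce`, not from the pattern itself. The sketch below assumes chains are written out statically in code. Under that assumption, a hypothetical free function `bind` can compose plain `impl FnOnce(S) -> (A, S)` values with no boxing and no `'static` bounds:

```rust
// A "state action" here is any FnOnce(S) -> (A, S); nothing is boxed.
// Hypothetical `bind`: run `m`, pass its result to `f`, run the action `f` returns.
fn bind<S, A, B, M, F, N>(m: M, f: F) -> impl FnOnce(S) -> (B, S)
where
    M: FnOnce(S) -> (A, S),
    F: FnOnce(A) -> N,
    N: FnOnce(S) -> (B, S),
{
    move |s: S| {
        let (a, s2) = m(s);
        f(a)(s2)
    }
}

// A plain function as a state action: result = old counter, new state = counter + 1.
fn tick(n: i32) -> (i32, i32) {
    (n, n + 1)
}

// Every closure keeps its own concrete type, so no heap allocation is needed.
fn count3() -> impl FnOnce(i32) -> ((i32, i32, i32), i32) {
    bind(tick, |a: i32| {
        bind(tick, move |b: i32| {
            bind(tick, move |c: i32| move |s: i32| ((a, b, c), s))
        })
    })
}

fn main() {
    let (result, final_state) = count3()(0);
    println!("{:?} {}", result, final_state); // prints "(0, 1, 2) 3"
}
```

The trade-off is that every chain has its own unnameable type. A chain whose length is only known at run time, or a collection of different programs, still needs boxing.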
// Example 060: State Monad
// Thread state through computations without explicit passing

// State monad: S -> (A, S)
struct State<S, A> {
    run: Box<dyn FnOnce(S) -> (A, S)>,
}

impl<S: 'static, A: 'static> State<S, A> {
    fn new(f: impl FnOnce(S) -> (A, S) + 'static) -> Self {
        State { run: Box::new(f) }
    }

    fn run(self, s: S) -> (A, S) {
        (self.run)(s)
    }

    fn pure(a: A) -> Self {
        State::new(move |s| (a, s))
    }

    fn and_then<B: 'static>(self, f: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
        State::new(move |s| {
            let (a, s2) = self.run(s);
            f(a).run(s2)
        })
    }

    fn map<B: 'static>(self, f: impl FnOnce(A) -> B + 'static) -> State<S, B> {
        State::new(move |s| {
            let (a, s2) = self.run(s);
            (f(a), s2)
        })
    }
}

fn get<S: Clone + 'static>() -> State<S, S> {
    State::new(|s: S| (s.clone(), s))
}

fn put<S: 'static>(new_s: S) -> State<S, ()> {
    State::new(move |_| ((), new_s))
}

fn modify<S: 'static>(f: impl FnOnce(S) -> S + 'static) -> State<S, ()> {
    State::new(move |s| ((), f(s)))
}

// Approach 1: Counter
fn tick() -> State<i32, i32> {
    get::<i32>().and_then(|n| put(n + 1).map(move |()| n))
}

fn count3() -> State<i32, (i32, i32, i32)> {
    tick().and_then(|a|
        tick().and_then(move |b|
            tick().map(move |c| (a, b, c))))
}

// Approach 2: Explicit state threading (no State monad; idiomatic Rust)
fn count3_explicit(state: i32) -> ((i32, i32, i32), i32) {
    let a = state;
    let state = state + 1;
    let b = state;
    let state = state + 1;
    let c = state;
    let state = state + 1;
    ((a, b, c), state)
}

// Approach 3: Stack operations
fn push(x: i32) -> State<Vec<i32>, ()> {
    modify(move |mut stack: Vec<i32>| { stack.push(x); stack })
}

fn pop() -> State<Vec<i32>, Option<i32>> {
    State::new(|mut stack: Vec<i32>| {
        let val = stack.pop();
        (val, stack)
    })
}

fn main() {
    let ((a, b, c), final_state) = count3().run(0);
    println!("count3: ({}, {}, {}), state={}", a, b, c, final_state);

    let ((a, b, c), state) = count3_explicit(0);
    println!("explicit: ({}, {}, {}), state={}", a, b, c, state);

    let stack_ops = push(1)
        .and_then(|()| push(2))
        .and_then(|()| push(3))
        .and_then(|()| pop())
        .and_then(|a| pop().map(move |b| (a, b)));

    let ((a, b), final_stack) = stack_ops.run(vec![]);
    println!("stack: a={:?}, b={:?}, remaining={:?}", a, b, final_stack);
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_counter() {
        let (result, state) = count3().run(0);
        assert_eq!(result, (0, 1, 2));
        assert_eq!(state, 3);
    }

    #[test]
    fn test_counter_nonzero_start() {
        let (result, state) = count3().run(10);
        assert_eq!(result, (10, 11, 12));
        assert_eq!(state, 13);
    }

    #[test]
    fn test_explicit_same_as_monadic() {
        let (r1, s1) = count3().run(0);
        let (r2, s2) = count3_explicit(0);
        assert_eq!(r1, r2);
        assert_eq!(s1, s2);
    }

    #[test]
    fn test_stack_push_pop() {
        let ops = push(10)
            .and_then(|()| push(20))
            .and_then(|()| pop());
        let (val, stack) = ops.run(vec![]);
        assert_eq!(val, Some(20));
        assert_eq!(stack, vec![10]);
    }

    #[test]
    fn test_stack_pop_empty() {
        let (val, stack) = pop().run(vec![]);
        assert_eq!(val, None);
        assert_eq!(stack, Vec::<i32>::new());
    }

    #[test]
    fn test_pure() {
        let (val, state) = State::<i32, _>::pure(42).run(0);
        assert_eq!(val, 42);
        assert_eq!(state, 0);
    }
}
(* Example 060: State Monad *)
(* Thread state through computations without explicit passing *)

(* State monad: ('s, 'a) state = State of ('s -> 'a * 's) *)
type ('s, 'a) state = State of ('s -> 'a * 's)

let run_state (State f) s = f s
let return_ x = State (fun s -> (x, s))
let bind m f = State (fun s ->
  let (a, s') = run_state m s in
  run_state (f a) s')
let ( >>= ) = bind

let get = State (fun s -> (s, s))
let put s = State (fun _ -> ((), s))
let modify f = State (fun s -> ((), f s))

(* Approach 1: Counter *)
let tick = get >>= fun n -> put (n + 1) >>= fun () -> return_ n

let count3 =
  tick >>= fun a ->
  tick >>= fun b ->
  tick >>= fun c ->
  return_ (a, b, c)

(* Approach 2: Stack operations *)
let push x = modify (fun stack -> x :: stack)
let pop = get >>= fun stack ->
  match stack with
  | [] -> return_ None
  | x :: rest -> put rest >>= fun () -> return_ (Some x)

let stack_ops =
  push 1 >>= fun () ->
  push 2 >>= fun () ->
  push 3 >>= fun () ->
  pop >>= fun a ->
  pop >>= fun b ->
  return_ (a, b)

(* Approach 3: Label generator *)
let fresh_label prefix =
  get >>= fun n ->
  put (n + 1) >>= fun () ->
  return_ (Printf.sprintf "%s_%d" prefix n)

let gen_labels =
  fresh_label "var" >>= fun a ->
  fresh_label "tmp" >>= fun b ->
  fresh_label "var" >>= fun c ->
  return_ [a; b; c]

let () =
  let (result, final_state) = run_state count3 0 in
  assert (result = (0, 1, 2));
  assert (final_state = 3);

  let ((a, b), final_stack) = run_state stack_ops [] in
  assert (a = Some 3);
  assert (b = Some 2);
  assert (final_stack = [1]);

  let (labels, _) = run_state gen_labels 0 in
  assert (labels = ["var_0"; "tmp_1"; "var_2"]);

  Printf.printf "✓ All tests passed\n"
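The OCaml listing's third approach, the label generator, has no counterpart in the Rust listing. A possible port is sketched below. It assumes the same boxed `State` type and uses a `u32` counter where OCaml uses `int`.

```rust
struct State<S, A> {
    run: Box<dyn FnOnce(S) -> (A, S)>,
}

impl<S: 'static, A: 'static> State<S, A> {
    fn new(f: impl FnOnce(S) -> (A, S) + 'static) -> Self {
        State { run: Box::new(f) }
    }

    fn run(self, s: S) -> (A, S) {
        (self.run)(s)
    }

    fn pure(a: A) -> Self {
        State::new(move |s| (a, s))
    }

    fn and_then<B: 'static>(self, f: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
        State::new(move |s| {
            let (a, s2) = self.run(s);
            f(a).run(s2)
        })
    }
}

fn get<S: Clone + 'static>() -> State<S, S> {
    State::new(|s: S| (s.clone(), s))
}

fn put<S: 'static>(new_s: S) -> State<S, ()> {
    State::new(move |_| ((), new_s))
}

// Port of OCaml's fresh_label: read the counter, bump it, return "prefix_n".
fn fresh_label(prefix: &str) -> State<u32, String> {
    // Own the prefix so the boxed closures can be 'static.
    let prefix = prefix.to_owned();
    get::<u32>().and_then(move |n| {
        put(n + 1).and_then(move |()| State::pure(format!("{}_{}", prefix, n)))
    })
}

// Port of OCaml's gen_labels: three labels drawn from one shared counter.
fn gen_labels() -> State<u32, Vec<String>> {
    fresh_label("var").and_then(|a| {
        fresh_label("tmp").and_then(move |b| {
            fresh_label("var").and_then(move |c| State::pure(vec![a, b, c]))
        })
    })
}

fn main() {
    let (labels, next) = gen_labels().run(0);
    println!("{:?} {}", labels, next); // prints ["var_0", "tmp_1", "var_2"] 3
}
```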

📊 Detailed Comparison

Comparison: State Monad

State Type

OCaml:

πŸͺ Show OCaml equivalent
type ('s, 'a) state = State of ('s -> 'a * 's)
let run_state (State f) s = f s

Rust:

struct State<S, A> {
 run: Box<dyn FnOnce(S) -> (A, S)>,
}

Bind / and_then

OCaml:

πŸͺ Show OCaml equivalent
let bind m f = State (fun s ->
  let (a, s') = run_state m s in
  run_state (f a) s')

Rust:

fn and_then<B: 'static>(self, f: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
 State::new(move |s| {
     let (a, s2) = self.run(s);
     f(a).run(s2)
 })
}

Counter Example

OCaml:

πŸͺ Show OCaml equivalent
let tick = get >>= fun n -> put (n + 1) >>= fun () -> return_ n
let (result, _) = run_state (tick >>= fun a -> tick >>= fun b -> return_ (a, b)) 0
(* result = (0, 1) *)

Rust (idiomatic, no monad needed):

fn count3_explicit(state: i32) -> ((i32, i32, i32), i32) {
 let a = state; let state = state + 1;
 let b = state; let state = state + 1;
 let c = state; let state = state + 1;
 ((a, b, c), state)
}