060: State Monad
Difficulty: 3 Level: Advanced
Thread mutable state through pure functions without making everything `mut`.
The Problem This Solves
You've written a function that needs to count something: lines parsed, IDs assigned, stack depth reached. The natural Rust instinct is `&mut counter`. That is fine for one function. But what happens when you have a chain of 10 pure functions, and all of them need to read or update that counter?
Option 1: pass `&mut state` to every function. Every signature gains an extra parameter. Every call site needs to thread it through. The deeper the call stack, the worse it gets. A function that logically takes one argument now takes three because it happens to call helpers that need state.
Option 2: put it in a global with `Mutex`. Now you've introduced shared mutable state and potential deadlocks, and the function is no longer pure or testable in isolation.
Option 3: the functional approach, which makes state explicit in the return type. Every function that needs state returns both its result and the new state: `fn tick(state: i32) -> (i32, i32)`. This is already valid Rust! The State monad is just a clean wrapper around this pattern so you don't have to manually thread the state through every call.
// Without the State monad, state is threaded manually through every call:
fn count_three(state: i32) -> ((i32, i32, i32), i32) {
let a = state; let state = state + 1;
let b = state; let state = state + 1;
let c = state; let state = state + 1;
((a, b, c), state)
}
// Works! But imagine 15 steps where state is used 5 times each...
// Every intermediate variable is named `state`, rebinding it each time.
// The business logic drowns in plumbing.
The State monad wraps `fn(S) -> (A, S)` in a struct and gives you `and_then` to chain these functions while passing state automatically, which removes exactly that plumbing.
The Intuition
Imagine every step of your computation is a machine that takes the current state, does something, and outputs both a result and the updated state. You chain these machines together: one machine's output state becomes the next machine's input state.
state₀ → [Machine A] → (resultA, state₁)
state₁ → [Machine B] → (resultB, state₂)
state₂ → [Machine C] → (resultC, state₃)
The State monad is just a way to connect these machines without manually writing `state₀`, `state₁`, `state₂` every time. You describe the chain, then "run" it by feeding in the initial state at the end.
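To make the machine picture concrete before any wrapper type appears, here is a minimal sketch using plain functions. The names `machine_a`, `machine_b`, `chain`, and `run_pipeline` are hypothetical and exist only for this illustration; they are not part of the State type defined later.

```rust
// Two hypothetical "machines": each takes the current state and returns
// (result, next state). Neither one mutates anything.
fn machine_a(state: i32) -> (&'static str, i32) {
    ("A ran", state + 1)
}

fn machine_b(state: i32) -> (&'static str, i32) {
    ("B ran", state * 10)
}

// Hypothetical helper: wire two machines together so that the first
// machine's output state becomes the second machine's input state.
fn chain<S, A, B>(
    first: impl FnOnce(S) -> (A, S),
    second: impl FnOnce(S) -> (B, S),
) -> impl FnOnce(S) -> ((A, B), S) {
    move |s0| {
        let (a, s1) = first(s0);
        let (b, s2) = second(s1);
        ((a, b), s2)
    }
}

// Describe the chain first, then "run" it by supplying the initial state.
fn run_pipeline(initial: i32) -> ((&'static str, &'static str), i32) {
    let pipeline = chain(machine_a, machine_b);
    pipeline(initial)
}
```

Calling `run_pipeline(1)` feeds 1 into `machine_a`, which hands 2 to `machine_b`, so the final state is 20. Nothing runs until the initial state is supplied, which is the same "describe, then run" shape the State monad packages up.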
Jargon decoded:
- State monad: a wrapper around functions of type `S -> (A, S)` with a `bind` operation (`and_then` in the Rust code) that threads state automatically
- `get`: returns the current state as the result (state is unchanged)
- `put(s)`: replaces the current state with `s` (result is `()`)
- `modify(f)`: applies `f` to the current state (like `get` followed by `put`; result is `()`)
- `run(initial_state)`: executes the whole chain, starting from `initial_state`
How It Works in Rust
// The core type: a wrapper around a state-threading function
struct State<S, A> {
run: Box<dyn FnOnce(S) -> (A, S)>,
}
impl<S: 'static, A: 'static> State<S, A> {
fn new(f: impl FnOnce(S) -> (A, S) + 'static) -> Self {
State { run: Box::new(f) }
}
// Execute the whole chain with an initial state
fn run(self, s: S) -> (A, S) {
(self.run)(s)
}
// Wrap a plain value: state passes through unchanged
fn pure(a: A) -> Self {
State::new(move |s| (a, s))
}
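// Chain: run self, pass result + new state to f, run f
fn and_then<B: 'static>(self, f: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
State::new(move |s| {
let (a, s2) = self.run(s); // run first step, get result + updated state
f(a).run(s2) // feed result + state into next step
})
}
// Transform the result only; the state passes through untouched.
// (`tick` and `count3` below rely on this.)
fn map<B: 'static>(self, f: impl FnOnce(A) -> B + 'static) -> State<S, B> {
State::new(move |s| {
let (a, s2) = self.run(s);
(f(a), s2)
})
}
}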
// Primitives for working with state:
fn get<S: Clone + 'static>() -> State<S, S> {
// Copy the state into the result; state is unchanged
State::new(|s: S| (s.clone(), s))
}
fn put<S: 'static>(new_s: S) -> State<S, ()> {
// Replace state with new_s; result is ()
State::new(move |_| ((), new_s))
}
fn modify<S: 'static>(f: impl FnOnce(S) -> S + 'static) -> State<S, ()> {
// Apply f to state; result is ()
State::new(move |s| ((), f(s)))
}
// A counter step: read current value, increment state, return old value
fn tick() -> State<i32, i32> {
get::<i32>().and_then(|n| {
put(n + 1).map(move |()| n) // increment state, return old n
})
}
// Chain three ticks: state threads automatically
fn count3() -> State<i32, (i32, i32, i32)> {
tick().and_then(|a|
tick().and_then(move |b|
tick().map(move |c| (a, b, c))))
}
// Run it:
let ((a, b, c), final_state) = count3().run(0);
// a=0, b=1, c=2, final_state=3
// Stack operations: state is Vec<i32>
fn push(x: i32) -> State<Vec<i32>, ()> {
modify(move |mut stack: Vec<i32>| { stack.push(x); stack })
}
fn pop() -> State<Vec<i32>, Option<i32>> {
State::new(|mut stack: Vec<i32>| {
let val = stack.pop();
(val, stack)
})
}
// Chain push/pop operations:
let ops = push(1)
.and_then(|()| push(2))
.and_then(|()| push(3))
.and_then(|()| pop());
let (top, remaining_stack) = ops.run(vec![]);
// top = Some(3), remaining_stack = [1, 2]
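Important: this implementation stores each step as a `Box<dyn FnOnce>` because every closure has its own anonymous type and the struct field needs a single concrete type. The `'static` bounds follow from that choice: a `Box<dyn Trait>` without an explicit lifetime defaults to `'static`, so captured values cannot borrow from shorter-lived data. The heap allocation and the `'static` restriction are the main ergonomic costs vs. OCaml's approach.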
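If the `'static` restriction gets in the way, one possible variation is to give the boxed closure an explicit lifetime parameter. The sketch below is a separate illustration, not the type used in the rest of this example: `BorrowingState` and `add_len` are hypothetical names, and only `new` and `run` are shown.

```rust
// Hypothetical variant of `State` whose boxed closure may borrow data that
// lives at least as long as 'a, instead of requiring 'static captures.
struct BorrowingState<'a, S, A> {
    run: Box<dyn FnOnce(S) -> (A, S) + 'a>,
}

impl<'a, S: 'a, A: 'a> BorrowingState<'a, S, A> {
    fn new(f: impl FnOnce(S) -> (A, S) + 'a) -> Self {
        BorrowingState { run: Box::new(f) }
    }

    // Execute the step with an initial state.
    fn run(self, s: S) -> (A, S) {
        (self.run)(s)
    }
}

// Hypothetical step: add the length of a borrowed string to a running total.
// The closure captures `text` by reference, which the 'static version of
// `State` would not accept for a non-'static borrow.
fn add_len<'a>(text: &'a str) -> BorrowingState<'a, usize, ()> {
    BorrowingState::new(move |total: usize| ((), total + text.len()))
}
```

The trade-off is that the lifetime parameter then appears in every signature that mentions the type, so this swaps one ergonomic cost for another.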
The honest truth: For most Rust code, explicit `&mut state` is cleaner. The State monad shines when you're building a library of composable stateful operations that users can sequence however they like, such as parser combinators or game engine scripting.
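For comparison, here is a minimal sketch of the same three-tick counter written with an explicit `&mut` borrow. The names `tick_mut` and `count3_mut` are hypothetical, chosen to avoid clashing with the monadic `tick` and `count3`.

```rust
// Hypothetical `&mut` counterpart of `tick`: return the current value,
// then increment the counter in place.
fn tick_mut(counter: &mut i32) -> i32 {
    let current = *counter;
    *counter += 1;
    current
}

// Hypothetical `&mut` counterpart of `count3`. Tuple fields are evaluated
// left to right, so the three ticks run in order.
fn count3_mut(counter: &mut i32) -> (i32, i32, i32) {
    (tick_mut(counter), tick_mut(counter), tick_mut(counter))
}
```

There is no boxing and there are no nested closures here, but every helper must take the `&mut i32` parameter, which is the plumbing cost described in Option 1.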
What This Unlocks
- Incremental ID generation: A stateful counter wrapped in State can be composed into any sequence of operations that need unique IDs without passing `&mut next_id` everywhere.
- Parser combinators: A parser is naturally `State<&str, ParseResult>`. Each parser step consumes some input (returning the remaining input as the new state) and returns a result. The State monad lets you compose parsers without manually tracking the remaining input.
- Game scripting: A sequence of game actions (move, attack, pick up item) that all read/update game state can be composed as `State<GameState, ()>` operations, enabling AI behavior trees to be written as pure data.
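The ID-generation use case can be sketched in Rust by porting the label generator from the OCaml version of this example. The `State` definition is repeated so the block compiles on its own. The Rust `fresh_label` and `gen_labels` are an assumed translation, not part of the original Rust listing, and the choice of `u32` for the counter is arbitrary.

```rust
// Same `State` shape as in this example, repeated so the sketch is self-contained.
struct State<S, A> {
    run: Box<dyn FnOnce(S) -> (A, S)>,
}

impl<S: 'static, A: 'static> State<S, A> {
    fn new(f: impl FnOnce(S) -> (A, S) + 'static) -> Self {
        State { run: Box::new(f) }
    }
    fn run(self, s: S) -> (A, S) {
        (self.run)(s)
    }
    fn and_then<B: 'static>(self, f: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
        State::new(move |s| {
            let (a, s2) = self.run(s);
            f(a).run(s2)
        })
    }
    fn map<B: 'static>(self, f: impl FnOnce(A) -> B + 'static) -> State<S, B> {
        State::new(move |s| {
            let (a, s2) = self.run(s);
            (f(a), s2)
        })
    }
}

// Assumed port of the OCaml `fresh_label`: build "<prefix>_<n>" from the
// current counter and return the incremented counter as the new state.
fn fresh_label(prefix: &'static str) -> State<u32, String> {
    State::new(move |n: u32| (format!("{}_{}", prefix, n), n + 1))
}

// Three labels in sequence; the counter is never mentioned at the call sites.
fn gen_labels() -> State<u32, Vec<String>> {
    fresh_label("var").and_then(|a| {
        fresh_label("tmp").and_then(move |b| {
            fresh_label("var").map(move |c| vec![a, b, c])
        })
    })
}
```

Running `gen_labels().run(0)` yields the labels `var_0`, `tmp_1`, `var_2` and a final counter of 3, matching the assertion in the OCaml listing.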
Key Differences
| Concept | OCaml | Rust |
|---|---|---|
| Type | `type ('s, 'a) state = State of ('s -> 'a * 's)` | `struct State<S, A> { run: Box<dyn FnOnce(S) -> (A, S)> }` |
| Closure storage | First-class values, no explicit boxing needed | Boxed as `Box<dyn FnOnce>` so that differently typed closures fit one field |
| Lifetime bounds | No `'static` requirement | Captured values need `'static` under the default `Box<dyn ...>` bound |
| Idiomatic? | Yes, monadic state threading is natural | Rarely; Rust usually prefers explicit `&mut state` |
| Performance | GC handles allocation | Each `and_then` allocates a new `Box`, which can be expensive in hot paths |
| When to choose | Almost always over explicit threading | Only when composability > performance (parsers, DSLs) |
// Example 060: State Monad
// Thread state through computations without explicit passing
// State monad: S -> (A, S)
struct State<S, A> {
run: Box<dyn FnOnce(S) -> (A, S)>,
}
impl<S: 'static, A: 'static> State<S, A> {
fn new(f: impl FnOnce(S) -> (A, S) + 'static) -> Self {
State { run: Box::new(f) }
}
fn run(self, s: S) -> (A, S) {
(self.run)(s)
}
fn pure(a: A) -> Self {
State::new(move |s| (a, s))
}
fn and_then<B: 'static>(self, f: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
State::new(move |s| {
let (a, s2) = self.run(s);
f(a).run(s2)
})
}
fn map<B: 'static>(self, f: impl FnOnce(A) -> B + 'static) -> State<S, B> {
State::new(move |s| {
let (a, s2) = self.run(s);
(f(a), s2)
})
}
}
fn get<S: Clone + 'static>() -> State<S, S> {
State::new(|s: S| (s.clone(), s))
}
fn put<S: 'static>(new_s: S) -> State<S, ()> {
State::new(move |_| ((), new_s))
}
fn modify<S: 'static>(f: impl FnOnce(S) -> S + 'static) -> State<S, ()> {
State::new(move |s| ((), f(s)))
}
// Approach 1: Counter
fn tick() -> State<i32, i32> {
get::<i32>().and_then(|n| put(n + 1).map(move |()| n))
}
fn count3() -> State<i32, (i32, i32, i32)> {
tick().and_then(|a|
tick().and_then(move |b|
tick().map(move |c| (a, b, c))))
}
// Approach 2: Explicit state threading (no State monad, idiomatic Rust)
fn count3_explicit(state: i32) -> ((i32, i32, i32), i32) {
let a = state;
let state = state + 1;
let b = state;
let state = state + 1;
let c = state;
let state = state + 1;
((a, b, c), state)
}
// Approach 3: Stack operations
fn push(x: i32) -> State<Vec<i32>, ()> {
modify(move |mut stack: Vec<i32>| { stack.push(x); stack })
}
fn pop() -> State<Vec<i32>, Option<i32>> {
State::new(|mut stack: Vec<i32>| {
let val = stack.pop();
(val, stack)
})
}
fn main() {
let ((a, b, c), final_state) = count3().run(0);
println!("count3: ({}, {}, {}), state={}", a, b, c, final_state);
let ((a, b, c), state) = count3_explicit(0);
println!("explicit: ({}, {}, {}), state={}", a, b, c, state);
let stack_ops = push(1)
.and_then(|()| push(2))
.and_then(|()| push(3))
.and_then(|()| pop())
.and_then(|a| pop().map(move |b| (a, b)));
let ((a, b), final_stack) = stack_ops.run(vec![]);
println!("stack: a={:?}, b={:?}, remaining={:?}", a, b, final_stack);
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_counter() {
let (result, state) = count3().run(0);
assert_eq!(result, (0, 1, 2));
assert_eq!(state, 3);
}
#[test]
fn test_counter_nonzero_start() {
let (result, state) = count3().run(10);
assert_eq!(result, (10, 11, 12));
assert_eq!(state, 13);
}
#[test]
fn test_explicit_same_as_monadic() {
let (r1, s1) = count3().run(0);
let (r2, s2) = count3_explicit(0);
assert_eq!(r1, r2);
assert_eq!(s1, s2);
}
#[test]
fn test_stack_push_pop() {
let ops = push(10)
.and_then(|()| push(20))
.and_then(|()| pop());
let (val, stack) = ops.run(vec![]);
assert_eq!(val, Some(20));
assert_eq!(stack, vec![10]);
}
#[test]
fn test_stack_pop_empty() {
let (val, stack) = pop().run(vec![]);
assert_eq!(val, None);
assert_eq!(stack, Vec::<i32>::new());
}
#[test]
fn test_pure() {
let (val, state) = State::<i32, _>::pure(42).run(0);
assert_eq!(val, 42);
assert_eq!(state, 0);
}
}
(* Example 060: State Monad *)
(* Thread state through computations without explicit passing *)
(* State monad: 'a state = State of (state -> 'a * state) *)
type ('s, 'a) state = State of ('s -> 'a * 's)
let run_state (State f) s = f s
let return_ x = State (fun s -> (x, s))
let bind m f = State (fun s ->
let (a, s') = run_state m s in
run_state (f a) s')
let ( >>= ) = bind
let get = State (fun s -> (s, s))
let put s = State (fun _ -> ((), s))
let modify f = State (fun s -> ((), f s))
(* Approach 1: Counter *)
let tick = get >>= fun n -> put (n + 1) >>= fun () -> return_ n
let count3 =
tick >>= fun a ->
tick >>= fun b ->
tick >>= fun c ->
return_ (a, b, c)
(* Approach 2: Stack operations *)
let push x = modify (fun stack -> x :: stack)
let pop = get >>= fun stack ->
match stack with
| [] -> return_ None
| x :: rest -> put rest >>= fun () -> return_ (Some x)
let stack_ops =
push 1 >>= fun () ->
push 2 >>= fun () ->
push 3 >>= fun () ->
pop >>= fun a ->
pop >>= fun b ->
return_ (a, b)
(* Approach 3: Label generator *)
let fresh_label prefix =
get >>= fun n ->
put (n + 1) >>= fun () ->
return_ (Printf.sprintf "%s_%d" prefix n)
let gen_labels =
fresh_label "var" >>= fun a ->
fresh_label "tmp" >>= fun b ->
fresh_label "var" >>= fun c ->
return_ [a; b; c]
let () =
let (result, final_state) = run_state count3 0 in
assert (result = (0, 1, 2));
assert (final_state = 3);
let ((a, b), final_stack) = run_state stack_ops [] in
assert (a = Some 3);
assert (b = Some 2);
assert (final_stack = [1]);
let (labels, _) = run_state gen_labels 0 in
assert (labels = ["var_0"; "tmp_1"; "var_2"]);
Printf.printf "✓ All tests passed\n"
Detailed Comparison
Comparison: State Monad
State Type
OCaml:
type ('s, 'a) state = State of ('s -> 'a * 's)
let run_state (State f) s = f s
Rust:
struct State<S, A> {
run: Box<dyn FnOnce(S) -> (A, S)>,
}
Bind / and_then
OCaml:
let bind m f = State (fun s ->
let (a, s') = run_state m s in
run_state (f a) s')
Rust:
fn and_then<B: 'static>(self, f: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
State::new(move |s| {
let (a, s2) = self.run(s);
f(a).run(s2)
})
}
Counter Example
OCaml:
let tick = get >>= fun n -> put (n + 1) >>= fun () -> return_ n
let (result, _) = run_state (tick >>= fun a -> tick >>= fun b -> return_ (a, b)) 0
(* result = (0, 1) *)
Rust (idiomatic, no monad needed):
fn count3_explicit(state: i32) -> ((i32, i32, i32), i32) {
let a = state; let state = state + 1;
let b = state; let state = state + 1;
let c = state; let state = state + 1;
((a, b, c), state)
}