State Monad
Tutorial
The Problem
Without the State monad, threading state through a sequence of functions means passing the state explicitly as an argument and returning it alongside the result: fn step(input: T, state: S) -> (R, S). This is noisy and error-prone, because every call site must pass the latest state forward and it is easy to reuse a stale one. The State monad encapsulates the threading: State<S, A> represents a computation S -> (A, S) that reads the state and returns a possibly updated one. Computations compose without explicit state passing, because bind handles the threading. The pattern appears in compiler passes (threading symbol tables), game state machines, configuration accumulation, and embedded DSLs. It makes stateful computation composable and testable while remaining purely functional.
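To make the noise of explicit threading concrete, here is a minimal sketch of the manual style described above. The names next_label and three_labels are hypothetical and chosen only for illustration; they are not part of the tutorial's source.

```rust
// Hypothetical helper: builds a label from the counter and returns the
// label together with the next counter value, in the shape (R, S).
fn next_label(prefix: &str, counter: u32) -> (String, u32) {
    (format!("{}_{}", prefix, counter), counter + 1)
}

// Every call must receive the latest counter and hand the updated one on.
// Shadowing `counter` keeps this readable, but nothing stops a caller from
// accidentally passing an older value.
fn three_labels(counter: u32) -> (Vec<String>, u32) {
    let (a, counter) = next_label("var", counter);
    let (b, counter) = next_label("tmp", counter);
    let (c, counter) = next_label("var", counter);
    (vec![a, b, c], counter)
}
```

The State monad version of the same program appears later as gen_labels, where the counter never shows up at the call sites.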
🎯 Learning Outcomes
- State<S, A> as a wrapper around FnOnce(S) -> (A, S)
- get() returning the current state, put(s) replacing the state, modify(f) transforming the state
- state.bind(|a| next_state) threading state through both computations
- run_state(computation, initial) to execute the computation and get (result, final_state)
- FnOnce vs Fn, chosen based on how the computation is used (run once, or run repeatedly)
Code Example
struct State<S, A> {
    run: Box<dyn FnOnce(S) -> (A, S)>,
}
Key Differences
| Aspect | Rust | OCaml |
|---|---|---|
| Type | Box<dyn FnOnce(S) -> (A, S)> | State of ('s -> 'a * 's) |
| FnOnce vs Fn | Must choose based on use | fun s -> ... (always callable repeatedly, like Fn) |
| Bind implementation | Complex with boxing | Clean algebraic unwrap |
| get | Requires S: Clone | No copy needed (immutable values are shared) |
| Thread safety | Send + Sync bounds needed to share across threads | Not applicable (single-threaded) |
| 'static bound | Required for boxed closures | Not required |
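The FnOnce row has a practical consequence that a short sketch may help show: a computation backed by FnOnce is consumed when it runs, so a program that needs to run twice is usually rebuilt from a function. The code below is a cut-down stand-in for the tutorial's State type; run_with and append_name are hypothetical names used only here.

```rust
// Minimal stand-in for the tutorial's State type (same shape, fewer methods).
struct State<S, A> {
    run: Box<dyn FnOnce(S) -> (A, S)>,
}

impl<S: 'static, A: 'static> State<S, A> {
    fn new(f: impl FnOnce(S) -> (A, S) + 'static) -> Self {
        State { run: Box::new(f) }
    }

    // Takes `self` by value: the boxed FnOnce can be called only once.
    fn run_with(self, s: S) -> (A, S) {
        (self.run)(s)
    }
}

// Hypothetical builder: returns a fresh computation on every call.
fn append_name(name: String) -> State<Vec<String>, usize> {
    // `name` is moved out of the closure when it is pushed, so this closure
    // implements FnOnce but not Fn.
    State::new(move |mut log: Vec<String>| {
        log.push(name);
        let len = log.len();
        (len, log)
    })
}
```

Calling run_with twice on the same value would be rejected by the compiler as a use after move. Switching the field to Box<dyn Fn(S) -> (A, S)> would allow repeated runs, at the cost of cloning captured values such as name inside the closure.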
OCaml Approach
OCaml represents State as type ('s, 'a) state = State of ('s -> 'a * 's). Running a computation just unwraps the constructor: let run_state (State f) s = f s. Monadic bind is let bind (State f) k = State (fun s -> let (a, s') = f s in let State g = k a in g s'). The primitives are let get = State (fun s -> (s, s)) and let put s = State (fun _ -> ((), s)). OCaml's algebraic types make the State monad clean and readable, and the ppx_let extension provides let%bind syntax for threading state.
Full Source
//! State monad in Rust: thread state through computations without explicit passing.
//!
//! The [`State`] type wraps a function `S -> (A, S)`. Combinators like
//! [`State::bind`], [`get`], [`put`], and [`modify`] let you compose stateful
//! computations, then [`run_state`] executes the pipeline against an initial state.

/// A stateful computation that, given an input state of type `S`, produces a
/// value of type `A` together with the next state.
///
/// Wraps a boxed closure so that arbitrary captures are allowed.
pub struct State<S, A> {
    run: Box<dyn FnOnce(S) -> (A, S)>,
}

impl<S: 'static, A: 'static> State<S, A> {
    /// Builds a `State` computation from a closure.
    pub fn new<F>(f: F) -> Self
    where
        F: FnOnce(S) -> (A, S) + 'static,
    {
        State { run: Box::new(f) }
    }

    /// Lifts a pure value into the state monad without touching the state.
    pub fn pure(a: A) -> Self {
        State::new(move |s| (a, s))
    }

    /// Monadic bind: runs `self`, feeds its result into `f`, and continues with
    /// the returned computation using the threaded state.
    pub fn bind<B, F>(self, f: F) -> State<S, B>
    where
        B: 'static,
        F: FnOnce(A) -> State<S, B> + 'static,
    {
        State::new(move |s| {
            let (a, s1) = (self.run)(s);
            (f(a).run)(s1)
        })
    }

    /// Maps a pure function over the result, leaving state threading intact.
    pub fn map<B, F>(self, f: F) -> State<S, B>
    where
        B: 'static,
        F: FnOnce(A) -> B + 'static,
    {
        self.bind(move |a| State::pure(f(a)))
    }
}

/// Executes a stateful computation starting from `s`, returning the final
/// value and the final state.
pub fn run_state<S: 'static, A: 'static>(m: State<S, A>, s: S) -> (A, S) {
    (m.run)(s)
}

/// Reads the current state as the produced value; leaves the state unchanged.
pub fn get<S: Clone + 'static>() -> State<S, S> {
    State::new(|s: S| (s.clone(), s))
}

/// Replaces the current state with `s`; produces the unit value.
pub fn put<S: 'static>(s: S) -> State<S, ()> {
    State::new(move |_| ((), s))
}

/// Applies `f` to the current state, storing the result as the new state.
pub fn modify<S: 'static, F>(f: F) -> State<S, ()>
where
    F: FnOnce(S) -> S + 'static,
{
    State::new(move |s| ((), f(s)))
}

/// Increments a `u32` counter and returns the pre-increment value.
pub fn tick() -> State<u32, u32> {
    get::<u32>().bind(|n| put(n + 1).bind(move |()| State::pure(n)))
}

/// Runs `tick` three times, collecting the three observed counter values.
pub fn count3() -> State<u32, (u32, u32, u32)> {
    tick().bind(|a| {
        tick().bind(move |b| tick().bind(move |c| State::pure((a, b, c))))
    })
}

/// Pushes `x` onto the stack held in the state.
pub fn push<T: 'static>(x: T) -> State<Vec<T>, ()> {
    modify(move |mut stack: Vec<T>| {
        stack.push(x);
        stack
    })
}

/// Pops the top element from the stack held in the state, or `None` if empty.
pub fn pop<T: 'static>() -> State<Vec<T>, Option<T>> {
    State::new(|mut stack: Vec<T>| {
        let top = stack.pop();
        (top, stack)
    })
}

/// Example stack program: pushes 1, 2, 3, then pops twice.
pub fn stack_ops() -> State<Vec<i32>, (Option<i32>, Option<i32>)> {
    push(1).bind(|()| {
        push(2).bind(|()| {
            push(3).bind(|()| {
                pop::<i32>().bind(|a| pop::<i32>().bind(move |b| State::pure((a, b))))
            })
        })
    })
}

/// Generates a fresh label of the form `"{prefix}_{n}"` and increments the
/// counter.
pub fn fresh_label(prefix: &'static str) -> State<u32, String> {
    get::<u32>().bind(move |n| {
        put(n + 1).bind(move |()| State::pure(format!("{}_{}", prefix, n)))
    })
}

/// Generates three labels in sequence using a shared counter.
pub fn gen_labels() -> State<u32, Vec<String>> {
    fresh_label("var").bind(|a| {
        fresh_label("tmp").bind(move |b| {
            fresh_label("var").bind(move |c| State::pure(vec![a, b, c]))
        })
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pure_does_not_touch_state() {
        let (v, s) = run_state(State::pure(42), 7u32);
        assert_eq!(v, 42);
        assert_eq!(s, 7);
    }

    #[test]
    fn get_observes_current_state() {
        let (v, s) = run_state(get::<u32>(), 99);
        assert_eq!(v, 99);
        assert_eq!(s, 99);
    }

    #[test]
    fn put_replaces_state() {
        let (_, s) = run_state(put(5u32), 0);
        assert_eq!(s, 5);
    }

    #[test]
    fn modify_applies_function() {
        let (_, s) = run_state(modify(|n: u32| n * 2), 21);
        assert_eq!(s, 42);
    }

    #[test]
    fn map_transforms_result() {
        let prog = get::<u32>().map(|n| n + 1);
        let (v, s) = run_state(prog, 10);
        assert_eq!(v, 11);
        assert_eq!(s, 10);
    }

    #[test]
    fn tick_returns_pre_increment_value() {
        let (v, s) = run_state(tick(), 0);
        assert_eq!(v, 0);
        assert_eq!(s, 1);
    }

    #[test]
    fn count3_threads_counter() {
        let (result, final_state) = run_state(count3(), 0);
        assert_eq!(result, (0, 1, 2));
        assert_eq!(final_state, 3);
    }

    #[test]
    fn stack_ops_push_then_pop() {
        let ((a, b), final_stack) = run_state(stack_ops(), Vec::<i32>::new());
        assert_eq!(a, Some(3));
        assert_eq!(b, Some(2));
        assert_eq!(final_stack, vec![1]);
    }

    #[test]
    fn pop_empty_yields_none() {
        let (v, s) = run_state(pop::<i32>(), Vec::<i32>::new());
        assert_eq!(v, None);
        assert_eq!(s, Vec::<i32>::new());
    }

    #[test]
    fn gen_labels_uses_shared_counter() {
        let (labels, final_state) = run_state(gen_labels(), 0);
        assert_eq!(
            labels,
            vec![
                "var_0".to_string(),
                "tmp_1".to_string(),
                "var_2".to_string()
            ]
        );
        assert_eq!(final_state, 3);
    }

    #[test]
    fn bind_left_identity() {
        let f = |n: u32| State::pure(n + 1);
        let (v, s) = run_state(State::pure(10u32).bind(f), 0u32);
        assert_eq!(v, 11);
        assert_eq!(s, 0);
    }

    #[test]
    fn bind_right_identity() {
        let prog = get::<u32>().bind(State::pure);
        let (v, s) = run_state(prog, 7);
        assert_eq!(v, 7);
        assert_eq!(s, 7);
    }
}
Deep Comparison
Comparison: State Monad
State Type
OCaml:
type ('s, 'a) state = State of ('s -> 'a * 's)
let run_state (State f) s = f s
Rust:
struct State<S, A> {
    run: Box<dyn FnOnce(S) -> (A, S)>,
}
Bind / and_then
OCaml:
let bind m f = State (fun s ->
  let (a, s') = run_state m s in
  run_state (f a) s')
Rust:
fn and_then<B: 'static>(self, f: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
    State::new(move |s| {
        // `run` is a field holding a boxed closure, so the call needs parentheses.
        let (a, s2) = (self.run)(s);
        (f(a).run)(s2)
    })
}
Counter Example
OCaml:
let tick = get >>= fun n -> put (n + 1) >>= fun () -> return_ n
let (result, _) = run_state (tick >>= fun a -> tick >>= fun b -> return_ (a, b)) 0
(* result = (0, 1) *)
Rust (idiomatic, no monad needed):
fn count3_explicit(state: i32) -> ((i32, i32, i32), i32) {
    let a = state;
    let state = state + 1;
    let b = state;
    let state = state + 1;
    let c = state;
    let state = state + 1;
    ((a, b, c), state)
}
Exercises
1. Implement State<S, A> and use it to compose get, a transform, and put into a single computation.
2. Implement push as a State<Vec<T>, ()> computation and pop as a State<Vec<T>, Option<T>> computation.
3. Implement modify(f: S -> S) -> State<S, ()> using get and put, and verify that it behaves the same as State::new(|s| ((), f(s))).
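As a starting point for the modify exercise, the following sketch shows one possible way to derive it from get and put and compare it with the direct definition. It repeats a trimmed copy of the tutorial's State type so that it stands alone; modify_via_get_put is a hypothetical name, and this is only one of several reasonable solutions.

```rust
struct State<S, A> {
    run: Box<dyn FnOnce(S) -> (A, S)>,
}

impl<S: 'static, A: 'static> State<S, A> {
    fn new(f: impl FnOnce(S) -> (A, S) + 'static) -> Self {
        State { run: Box::new(f) }
    }

    fn bind<B: 'static>(self, f: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
        State::new(move |s| {
            let (a, s1) = (self.run)(s);
            (f(a).run)(s1)
        })
    }
}

fn run_state<S: 'static, A: 'static>(m: State<S, A>, s: S) -> (A, S) {
    (m.run)(s)
}

fn get<S: Clone + 'static>() -> State<S, S> {
    State::new(|s: S| (s.clone(), s))
}

fn put<S: 'static>(s: S) -> State<S, ()> {
    State::new(move |_| ((), s))
}

// Direct definition, matching the tutorial's source.
fn modify<S: 'static>(f: impl FnOnce(S) -> S + 'static) -> State<S, ()> {
    State::new(move |s| ((), f(s)))
}

// Hypothetical name: the same operation derived from get and put.
// It inherits the `S: Clone` bound from `get`, which the direct version avoids.
fn modify_via_get_put<S: Clone + 'static>(f: impl FnOnce(S) -> S + 'static) -> State<S, ()> {
    get::<S>().bind(move |s| put(f(s)))
}
```

Running both versions against the same initial state and comparing the final states is enough to check the equivalence for a given f. The extra Clone bound is the main design difference: the derived version copies the state once, while the direct version moves it.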