187: Free Monad with State
Difficulty: ⭐⭐⭐⭐⭐ Level: Expert

Encode mutable state as a free algebra: describe `Get` and `Put` operations as data, then thread state through a pure interpreter with no actual mutation.

The Problem This Solves
Mutable state is hard to test, hard to reason about, and hard to compose. If a function reads and writes a counter, testing it requires setting up real mutable state, running the function, and inspecting side effects. You can't easily replay it, record it, or run it in a different order. The state is implicit: it is woven into the execution itself rather than being an explicit part of the computation's description.

The free monad approach makes state explicit: instead of actually mutating anything, your program describes what it wants to do. `get_state()` doesn't read state; it returns a data structure that says "I want to read state here." `put_state(42)` doesn't write state; it returns a data structure that says "I want to write 42 here." The actual state threading happens in a separate `run_state` interpreter that walks the description and carries the state value from one step to the next.

This separation means you can inspect the program structure before running it (up to the next `Get`, whose continuation is an opaque closure), swap out the interpreter (record every step in a log, count writes without applying them), and test programs by checking the description rather than observing side effects.

The Intuition
Imagine a chef who writes a recipe instead of cooking. The recipe says "read current stock of flour" and "update flour stock to X - 200g". The recipe is just text; it doesn't touch the flour. A kitchen manager takes the recipe and runs it: for each step, the manager carries the current pantry state, reads or updates it as instructed, and passes the new state to the next step. The chef and the manager are fully separated.

The free monad is the recipe format. `Free::Free(StateF::Get(...))` is a recipe step that says "give me the current state." `Free::Free(StateF::Put(new_s, next))` says "set state to `new_s`, then continue." The recipe is a tree of these instructions. `run_state` is the manager: it carries state through the tree without the recipe knowing anything about how.

How It Works in Rust
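Before the full encoding, the recipe and manager split can be sketched without any monad machinery. This is an illustration only: `Step`, `run_recipe`, and the flour quantities are hypothetical names and values that do not appear in the code that follows. The recipe is a plain list of steps, and a separate function plays the manager.

```rust
// Hypothetical names for illustration; not part of the free monad code.
#[derive(Debug, Clone, PartialEq)]
enum Step {
    ReadStock,     // "read current stock of flour"
    UseFlour(u32), // "update flour stock to X - grams"
}

// The manager: carries the flour stock from step to step.
// Returns every stock level that was read, plus the final stock.
fn run_recipe(recipe: &[Step], initial_flour: u32) -> (Vec<u32>, u32) {
    let mut flour = initial_flour;
    let mut readings = Vec::new();
    for step in recipe {
        match step {
            Step::ReadStock => readings.push(flour),
            Step::UseFlour(grams) => flour = flour.saturating_sub(*grams),
        }
    }
    (readings, flour)
}

fn main() {
    // The recipe is plain data: building it touches no flour.
    let recipe = vec![Step::ReadStock, Step::UseFlour(200), Step::ReadStock];
    let (readings, left) = run_recipe(&recipe, 1000);
    println!("readings={:?} left={}", readings, left);
}
```

A flat list like this cannot let a later step depend on an earlier reading. The continuation stored inside `Get` is what adds that ability in the free monad encoding.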
// The "instruction" type: describes what the program wants
enum StateF<S, A> {
    Get(Box<dyn FnOnce(S) -> A>), // "give me S, I'll produce A"
    Put(S, A),                    // "set state to S, then produce A"
}

// Free monad: either a plain value, or an instruction with a continuation
enum Free<S, A> {
    Pure(A),                          // computation complete, here's the value
    Free(Box<StateF<S, Free<S, A>>>), // one instruction, rest of computation in continuation
}

// Smart constructor: "I want to read state"
fn get_state<S: Clone + 'static>() -> Free<S, S> {
    // The continuation receives S and wraps it in Pure
    Free::Free(Box::new(StateF::Get(Box::new(|s| Free::Pure(s)))))
}

// Smart constructor: "I want to write state"
fn put_state<S: 'static>(s: S) -> Free<S, ()> {
    Free::Free(Box::new(StateF::Put(s, Free::Pure(()))))
}

// Monadic bind: chain two computations
// (the `...` stands for the `'static` bounds on S, A, and B shown in the full listing)
fn bind<S, A, B, F>(m: Free<S, A>, f: F) -> Free<S, B>
where F: FnOnce(A) -> Free<S, B> + 'static, ... {
    match m {
        Free::Pure(x) => f(x), // computation done, apply f to its result
        Free::Free(instr) => match *instr {
            // Push f down into the continuation: build up a larger description
            StateF::Get(cont) => Free::Free(Box::new(StateF::Get(
                Box::new(move |s| bind(cont(s), f))
            ))),
            StateF::Put(s, next) => Free::Free(Box::new(StateF::Put(s, bind(next, f)))),
        },
    }
}

// Interpreter: walk the description, thread real state
fn run_state<S: Clone, A>(init: S, program: Free<S, A>) -> (A, S) {
    let mut state = init;
    let mut current = program;
    loop {
        match current {
            Free::Pure(x) => return (x, state), // done
            Free::Free(instr) => match *instr {
                StateF::Get(cont) => current = cont(state.clone()), // pass current state to cont
                StateF::Put(new_s, next) => { state = new_s; current = next; }
            },
        }
    }
}

// Example program: pure description, no actual mutation
let program = bind(get_state::<i32>(), |n|
    bind(put_state(n + 1), |_|
        get_state::<i32>()
    )
);
let (result, final_state) = run_state(0, program);
// result = 1, final_state = 1
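Since `bind` can chain any two descriptions, larger building blocks can be derived from the two primitives. The sketch below shows one such derivation, assuming a read-modify-write step is wanted. `modify_state` and `increment_twice` are hypothetical names that do not appear in the original code, and the definitions from above are repeated only so the block compiles on its own.

```rust
enum StateF<S, A> {
    Get(Box<dyn FnOnce(S) -> A>),
    Put(S, A),
}

enum Free<S, A> {
    Pure(A),
    Free(Box<StateF<S, Free<S, A>>>),
}

fn get_state<S: Clone + 'static>() -> Free<S, S> {
    Free::Free(Box::new(StateF::Get(Box::new(|s| Free::Pure(s)))))
}

fn put_state<S: 'static>(s: S) -> Free<S, ()> {
    Free::Free(Box::new(StateF::Put(s, Free::Pure(()))))
}

fn bind<S, A, B, F>(m: Free<S, A>, f: F) -> Free<S, B>
where
    S: 'static,
    A: 'static,
    B: 'static,
    F: FnOnce(A) -> Free<S, B> + 'static,
{
    match m {
        Free::Pure(x) => f(x),
        Free::Free(instr) => match *instr {
            StateF::Get(cont) => {
                Free::Free(Box::new(StateF::Get(Box::new(move |s| bind(cont(s), f)))))
            }
            StateF::Put(s, next) => Free::Free(Box::new(StateF::Put(s, bind(next, f)))),
        },
    }
}

fn run_state<S: Clone, A>(init: S, program: Free<S, A>) -> (A, S) {
    let mut state = init;
    let mut current = program;
    loop {
        match current {
            Free::Pure(x) => return (x, state),
            Free::Free(instr) => match *instr {
                StateF::Get(cont) => current = cont(state.clone()),
                StateF::Put(new_s, next) => {
                    state = new_s;
                    current = next;
                }
            },
        }
    }
}

// Hypothetical helper: read the state, apply `f`, write the result back.
// It is still only a description; nothing runs until an interpreter walks it.
fn modify_state<S, F>(f: F) -> Free<S, ()>
where
    S: Clone + 'static,
    F: FnOnce(S) -> S + 'static,
{
    bind(get_state::<S>(), move |s| put_state(f(s)))
}

// Two increments followed by a final read, written with the helper.
fn increment_twice() -> Free<i32, i32> {
    bind(modify_state(|n: i32| n + 1), |_| {
        bind(modify_state(|n: i32| n + 1), |_| get_state::<i32>())
    })
}

fn main() {
    let (result, final_state) = run_state(5, increment_twice());
    println!("result={} final_state={}", result, final_state);
}
```

Starting from 5, both the returned value and the final state are 7. The helper halves the nesting depth compared with spelling out each `get_state`/`put_state` pair.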
What This Unlocks
- Testable stateful logic: inspect or replay the free monad tree without running the interpreter; assert on the structure of the computation, not its effects.
- Alternative interpreters: `run_state` threads real state; a logging interpreter could record every `Get`/`Put`; a dry-run interpreter could count state mutations without applying them.
- Pure functional state: enables reasoning with equational laws such as `get >> put s >> get ≡ put s >> return s`. Both sides return `s` and leave the state at `s` under `run_state`, and because the description is data rather than execution, one side can be rewritten into the other before anything runs.
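The alternative interpreters mentioned in the list above might look like the following sketch. `run_logged`, `count_puts`, and `read_bump_read` are hypothetical names, not part of the original code; the type definitions and `bind` are repeated so the block is self-contained.

```rust
enum StateF<S, A> {
    Get(Box<dyn FnOnce(S) -> A>),
    Put(S, A),
}
enum Free<S, A> {
    Pure(A),
    Free(Box<StateF<S, Free<S, A>>>),
}

fn get_state<S: Clone + 'static>() -> Free<S, S> {
    Free::Free(Box::new(StateF::Get(Box::new(|s| Free::Pure(s)))))
}
fn put_state<S: 'static>(s: S) -> Free<S, ()> {
    Free::Free(Box::new(StateF::Put(s, Free::Pure(()))))
}

fn bind<S, A, B, F>(m: Free<S, A>, f: F) -> Free<S, B>
where
    S: 'static,
    A: 'static,
    B: 'static,
    F: FnOnce(A) -> Free<S, B> + 'static,
{
    match m {
        Free::Pure(x) => f(x),
        Free::Free(instr) => match *instr {
            StateF::Get(cont) => {
                Free::Free(Box::new(StateF::Get(Box::new(move |s| bind(cont(s), f)))))
            }
            StateF::Put(s, next) => Free::Free(Box::new(StateF::Put(s, bind(next, f)))),
        },
    }
}

// Hypothetical logging interpreter: threads state like `run_state`,
// and also records each instruction it executes.
fn run_logged<S, A>(init: S, program: Free<S, A>) -> (A, S, Vec<String>)
where
    S: Clone + std::fmt::Debug,
{
    let (mut state, mut log, mut current) = (init, Vec::new(), program);
    loop {
        match current {
            Free::Pure(x) => return (x, state, log),
            Free::Free(instr) => match *instr {
                StateF::Get(cont) => {
                    log.push(format!("Get -> {:?}", state));
                    current = cont(state.clone());
                }
                StateF::Put(new_state, next) => {
                    log.push(format!("Put {:?}", new_state));
                    state = new_state;
                    current = next;
                }
            },
        }
    }
}

// Hypothetical dry-run interpreter: counts `Put` instructions but never
// applies them, so every `Get` is answered with the initial state.
fn count_puts<S: Clone, A>(init: S, program: Free<S, A>) -> usize {
    let mut puts = 0;
    let mut current = program;
    loop {
        match current {
            Free::Pure(_) => return puts,
            Free::Free(instr) => match *instr {
                StateF::Get(cont) => current = cont(init.clone()),
                StateF::Put(_ignored, next) => {
                    puts += 1;
                    current = next;
                }
            },
        }
    }
}

// Read, write n + 1, read again.
fn read_bump_read() -> Free<i32, i32> {
    bind(get_state::<i32>(), |n| bind(put_state(n + 1), |_| get_state::<i32>()))
}

fn main() {
    let (result, state, log) = run_logged(0, read_bump_read());
    println!("result={} state={} log={:?}", result, state, log);
    println!("puts={}", count_puts(0, read_bump_read()));
}
```

The program value is consumed by whichever interpreter runs it, which is why `read_bump_read` is a function that rebuilds the description for each run.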
Key Differences
| Concept | OCaml | Rust |
|---|---|---|
| GADT instructions | `Get` and `Put` are declared in GADT syntax, where each constructor states its own result type | `enum StateF<S, A>`: an ordinary generic enum with a single `A` type per instance |
| Bind | Recursive, naturally polymorphic with locally abstract types | Requires explicit `'static` bounds due to `Box<dyn FnOnce>` |
| Interpreter | `let rec go s = function \| Pure x -> (x, s) \| ...`: idiomatic recursion | Iterative loop with mutable `state` and `current` to avoid stack overflow |
| Ergonomics | Cleaner with `let*` syntax (monadic sugar) | Deeply nested `bind` calls; consider macros for real use |
| Applicability | First-class pattern in OCaml FP ecosystems | Less common in Rust; prefer `async/await` or `RefCell` for most cases; free monads are for DSL design |
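The Ergonomics row suggests macros for taming nested `bind` calls. One possible shape, offered as an assumption rather than an established API, is a small `macro_rules!` macro that rewrites `let pat = expr;` lines into `bind` calls. The macro name `free_do!` is hypothetical; the other definitions are repeated from the Rust code so the block compiles on its own.

```rust
enum StateF<S, A> {
    Get(Box<dyn FnOnce(S) -> A>),
    Put(S, A),
}
enum Free<S, A> {
    Pure(A),
    Free(Box<StateF<S, Free<S, A>>>),
}

fn get_state<S: Clone + 'static>() -> Free<S, S> {
    Free::Free(Box::new(StateF::Get(Box::new(|s| Free::Pure(s)))))
}
fn put_state<S: 'static>(s: S) -> Free<S, ()> {
    Free::Free(Box::new(StateF::Put(s, Free::Pure(()))))
}

fn bind<S, A, B, F>(m: Free<S, A>, f: F) -> Free<S, B>
where
    S: 'static,
    A: 'static,
    B: 'static,
    F: FnOnce(A) -> Free<S, B> + 'static,
{
    match m {
        Free::Pure(x) => f(x),
        Free::Free(instr) => match *instr {
            StateF::Get(cont) => {
                Free::Free(Box::new(StateF::Get(Box::new(move |s| bind(cont(s), f)))))
            }
            StateF::Put(s, next) => Free::Free(Box::new(StateF::Put(s, bind(next, f)))),
        },
    }
}

fn run_state<S: Clone, A>(init: S, program: Free<S, A>) -> (A, S) {
    let mut state = init;
    let mut current = program;
    loop {
        match current {
            Free::Pure(x) => return (x, state),
            Free::Free(instr) => match *instr {
                StateF::Get(cont) => current = cont(state.clone()),
                StateF::Put(new_s, next) => {
                    state = new_s;
                    current = next;
                }
            },
        }
    }
}

// Hypothetical do-notation macro. Each `let pat = expr;` line becomes
// `bind(expr, move |pat| <rest>)`; the last line is the final computation.
macro_rules! free_do {
    (let $p:pat = $e:expr; $($rest:tt)+) => {
        bind($e, move |$p| free_do!($($rest)+))
    };
    ($e:expr) => { $e };
}

// Same program as the nested `increment_three_times`, written flat.
fn increment_three_times() -> Free<i32, i32> {
    free_do! {
        let n = get_state::<i32>();
        let _ = put_state(n + 1);
        let n = get_state::<i32>();
        let _ = put_state(n + 1);
        let n = get_state::<i32>();
        let _ = put_state(n + 1);
        get_state::<i32>()
    }
}

fn main() {
    let (result, final_state) = run_state(0, increment_three_times());
    println!("result={} final_state={}", result, final_state);
}
```

The macro only changes how the program is written. It expands to the same nested `bind` calls, so the resulting description and its behaviour under `run_state` are unchanged.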
// Free Monad with State: State as a Free Algebra
//
// Encode mutable state purely using a free monad.
// No mutation: the state is threaded through the interpreter.
// Get and Put are instructions; the interpreter threads the state value.

// ── DSL: state instructions ──────────────────────────────────────────────

enum StateF<S, A> {
    Get(Box<dyn FnOnce(S) -> A>),
    Put(S, A),
}

// ── Free monad over StateF ───────────────────────────────────────────────

enum Free<S, A> {
    Pure(A),
    Free(Box<StateF<S, Free<S, A>>>),
}

// ── Smart constructors ───────────────────────────────────────────────────

fn get_state<S: Clone + 'static>() -> Free<S, S> {
    Free::Free(Box::new(StateF::Get(Box::new(|s| Free::Pure(s)))))
}

fn put_state<S: 'static>(s: S) -> Free<S, ()> {
    Free::Free(Box::new(StateF::Put(s, Free::Pure(()))))
}

// ── Monadic bind ─────────────────────────────────────────────────────────

fn bind<S, A, B, F>(m: Free<S, A>, f: F) -> Free<S, B>
where
    S: 'static,
    A: 'static,
    B: 'static,
    F: FnOnce(A) -> Free<S, B> + 'static,
{
    match m {
        Free::Pure(x) => f(x),
        Free::Free(instr) => match *instr {
            StateF::Get(cont) => {
                Free::Free(Box::new(StateF::Get(Box::new(move |s| bind(cont(s), f)))))
            }
            StateF::Put(s, next) => {
                Free::Free(Box::new(StateF::Put(s, bind(next, f))))
            }
        },
    }
}

// ── Interpreter: thread state through the computation ───────────────────

fn run_state<S: Clone, A>(init: S, program: Free<S, A>) -> (A, S) {
    let mut state = init;
    let mut current = program;
    loop {
        match current {
            Free::Pure(x) => return (x, state),
            Free::Free(instr) => match *instr {
                StateF::Get(cont) => {
                    current = cont(state.clone());
                }
                StateF::Put(new_state, next) => {
                    state = new_state;
                    current = next;
                }
            },
        }
    }
}

// ── Example program: increment counter 3 times ──────────────────────────

fn increment_three_times() -> Free<i32, i32> {
    bind(get_state::<i32>(), |n| {
        bind(put_state(n + 1), |_| {
            bind(get_state::<i32>(), |n| {
                bind(put_state(n + 1), |_| {
                    bind(get_state::<i32>(), |n| {
                        bind(put_state(n + 1), |_| get_state::<i32>())
                    })
                })
            })
        })
    })
}

// ── Example: accumulate into state ───────────────────────────────────────

fn accumulate(values: Vec<i32>) -> Free<i32, i32> {
    if values.is_empty() {
        get_state::<i32>()
    } else {
        let head = values[0];
        let tail = values[1..].to_vec();
        bind(get_state::<i32>(), move |n| {
            bind(put_state(n + head), |_| accumulate(tail))
        })
    }
}

fn main() {
    // ── Increment 3 times starting from 0 ────────────────────────────────
    let program = increment_three_times();
    let (result, final_state) = run_state(0, program);
    println!("increment 3x: result={} final_state={}", result, final_state);
    assert_eq!(result, 3);
    assert_eq!(final_state, 3);

    // ── Start from non-zero initial state ────────────────────────────────
    let program2 = increment_three_times();
    let (result2, final2) = run_state(10, program2);
    println!("start=10, increment 3x: result={} final={}", result2, final2);

    // ── Accumulate a list ────────────────────────────────────────────────
    let acc_prog = accumulate(vec![1, 2, 3, 4, 5]);
    let (sum, _) = run_state(0, acc_prog);
    println!("sum of [1..5] = {}", sum);

    println!("State free monad works");
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_increment_three_times_from_zero() {
        let program = increment_three_times();
        let (result, final_state) = run_state(0, program);
        assert_eq!(result, 3);
        assert_eq!(final_state, 3);
    }

    #[test]
    fn test_increment_from_nonzero_initial_state() {
        let program = increment_three_times();
        let (result, final_state) = run_state(10, program);
        assert_eq!(result, 13);
        assert_eq!(final_state, 13);
    }

    #[test]
    fn test_accumulate_sum() {
        let program = accumulate(vec![1, 2, 3, 4, 5]);
        let (sum, final_state) = run_state(0, program);
        assert_eq!(sum, 15);
        assert_eq!(final_state, 15);
    }

    #[test]
    fn test_get_returns_initial_state() {
        let program: Free<i32, i32> = get_state::<i32>();
        let (result, state) = run_state(42, program);
        assert_eq!(result, 42);
        assert_eq!(state, 42);
    }
}
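
(* State monad encoded as a free monad.
   No mutation: the state is threaded through the interpreter. *)

(* Both the state type 's and the result type 'a are type parameters.
   With 's as a parameter rather than a variable local to each constructor,
   the interpreter can pass its own state to the continuation held by Get. *)
type ('s, 'a) state_f =
  | Get : ('s -> 'a) -> ('s, 'a) state_f
  | Put : 's * 'a -> ('s, 'a) state_f

type ('s, 'a) free =
  | Pure : 'a -> ('s, 'a) free
  | Free : ('s, ('s, 'a) free) state_f -> ('s, 'a) free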

let get () = Free (Get (fun s -> Pure s))
let put s = Free (Put (s, Pure ()))

let rec bind m f = match m with
  | Pure x -> f x
  | Free (Get cont) -> Free (Get (fun s -> bind (cont s) f))
  | Free (Put (s, n)) -> Free (Put (s, bind n f))

let run init program =
  let rec go s = function
    | Pure x -> (x, s)
    | Free (Get cont) -> go s (cont s)
    | Free (Put (s', next)) -> go s' next
  in go init program

let () =
  (* Increment counter 3 times *)
  let program =
    bind (get ()) (fun n ->
    bind (put (n + 1)) (fun () ->
    bind (get ()) (fun n ->
    bind (put (n + 1)) (fun () ->
    bind (get ()) (fun n ->
    bind (put (n + 1)) (fun () ->
    bind (get ()) (fun final ->
    Pure final)))))))
  in
  let (result, final_state) = run 0 program in
  Printf.printf "result=%d state=%d\n" result final_state;
  assert (result = 3 && final_state = 3);
  Printf.printf "State free monad works\n"