🦀 Functional Rust

187: Free Monad with State

Difficulty: ⭐⭐⭐⭐⭐ Level: Expert

Encode mutable state as a free algebra: describe `Get` and `Put` operations as data, then thread the state through a separate interpreter, so the program itself performs no mutation.

The Problem This Solves

Mutable state is hard to test, hard to reason about, and hard to compose. If a function reads and writes a counter, testing it requires setting up real mutable state, running the function, and inspecting the side effects afterwards. You can't easily replay the run, record it, or run it in a different order. The state is implicit: it is woven into the execution itself instead of being an explicit part of the computation's description.

The free monad approach makes state explicit. Instead of actually mutating anything, your program describes what it wants to do. `get_state()` doesn't read state; it returns a data structure that says "I want to read state here." `put_state(42)` doesn't write state; it returns a data structure that says "I want to write 42 here." The actual state threading happens in a separate `run_state` interpreter, which walks the description and carries the current state value from one instruction to the next. This separation means you can inspect a program's structure before running it (up to the first `Get`, whose continuation stays an opaque closure until it receives a state), swap in a different interpreter (one that logs every read and write, stops after a given number of steps, or replays a recorded run), and test a program by checking the `(result, state)` pair the interpreter returns instead of observing side effects.
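To make "swap in a different interpreter" concrete, here is a minimal sketch of a second interpreter for the same programs. `run_state_logged` and the `Event` enum are hypothetical names introduced for illustration; the `StateF`, `Free`, `get_state`, `put_state`, and `bind` definitions repeat the encoding this article develops so the block compiles on its own. The program is unchanged; only the interpreter decides to record each `Get` and `Put`.

```rust
// Same encoding as the article: instructions as data, plus a free monad over them.
enum StateF<S, A> {
    Get(Box<dyn FnOnce(S) -> A>),
    Put(S, A),
}

enum Free<S, A> {
    Pure(A),
    Free(Box<StateF<S, Free<S, A>>>),
}

fn get_state<S: 'static>() -> Free<S, S> {
    Free::Free(Box::new(StateF::Get(Box::new(|s| Free::Pure(s)))))
}

fn put_state<S: 'static>(s: S) -> Free<S, ()> {
    Free::Free(Box::new(StateF::Put(s, Free::Pure(()))))
}

fn bind<S: 'static, A: 'static, B: 'static, F>(m: Free<S, A>, f: F) -> Free<S, B>
where
    F: FnOnce(A) -> Free<S, B> + 'static,
{
    match m {
        Free::Pure(x) => f(x),
        Free::Free(instr) => match *instr {
            StateF::Get(cont) => {
                Free::Free(Box::new(StateF::Get(Box::new(move |s| bind(cont(s), f)))))
            }
            StateF::Put(s, next) => Free::Free(Box::new(StateF::Put(s, bind(next, f)))),
        },
    }
}

// Hypothetical event type: what the logging interpreter records.
#[derive(Debug, PartialEq)]
enum Event<S> {
    Read(S),
    Write(S),
}

// Hypothetical second interpreter: threads state exactly like run_state,
// but also records every Get and Put it executes.
fn run_state_logged<S: Clone, A>(init: S, program: Free<S, A>) -> (A, S, Vec<Event<S>>) {
    let mut state = init;
    let mut current = program;
    let mut log = Vec::new();
    loop {
        match current {
            Free::Pure(x) => return (x, state, log),
            Free::Free(instr) => match *instr {
                StateF::Get(cont) => {
                    log.push(Event::Read(state.clone()));
                    current = cont(state.clone());
                }
                StateF::Put(new_state, next) => {
                    log.push(Event::Write(new_state.clone()));
                    state = new_state;
                    current = next;
                }
            },
        }
    }
}

fn main() {
    // Read the counter, write counter + 1, read it again.
    let program = bind(get_state::<i32>(), |n| {
        bind(put_state(n + 1), |_| get_state::<i32>())
    });
    let (result, final_state, log) = run_state_logged(5, program);
    // prints: result=6 state=6 log=[Read(5), Write(6), Read(6)]
    println!("result={} state={} log={:?}", result, final_state, log);
}
```

The log comes for free because the program never performed its own reads and writes; any interpreter that walks the same `Free` value can observe them.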

The Intuition

Imagine a chef who writes a recipe instead of cooking. The recipe says "read the current stock of flour" and "update the flour stock to X minus 200 g", where X is whatever the first step read. The recipe is just text; it doesn't touch the flour. A kitchen manager takes the recipe and runs it: for each step, the manager carries the current pantry state, reads or updates it as instructed, and passes the new state on to the next step. The chef and the manager are fully separated.

The free monad is the recipe format. `Free::Free(Box::new(StateF::Get(cont)))` is a recipe step that says "give me the current state, and `cont` will decide what comes next." `Free::Free(Box::new(StateF::Put(new_s, next)))` says "set the state to `new_s`, then continue with `next`." The recipe is a tree of these instructions: after a `Get`, which branch comes next depends on the state value received. `run_state` is the manager: it carries the state through the tree without the recipe knowing anything about how.
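The recipe analogy maps directly onto code. A minimal sketch, assuming stock is modeled as grams of flour in an `i32` state; `flour_recipe` and `describe_first_step` are hypothetical helpers, and the core definitions repeat the article's encoding so the block stands alone. Building the recipe touches no stock, and inspecting it reveals only its first step, because everything after a `Get` lives inside a closure that needs a state value.

```rust
enum StateF<S, A> {
    Get(Box<dyn FnOnce(S) -> A>),
    Put(S, A),
}

enum Free<S, A> {
    Pure(A),
    Free(Box<StateF<S, Free<S, A>>>),
}

fn get_state<S: 'static>() -> Free<S, S> {
    Free::Free(Box::new(StateF::Get(Box::new(|s| Free::Pure(s)))))
}

fn put_state<S: 'static>(s: S) -> Free<S, ()> {
    Free::Free(Box::new(StateF::Put(s, Free::Pure(()))))
}

fn bind<S: 'static, A: 'static, B: 'static, F>(m: Free<S, A>, f: F) -> Free<S, B>
where
    F: FnOnce(A) -> Free<S, B> + 'static,
{
    match m {
        Free::Pure(x) => f(x),
        Free::Free(instr) => match *instr {
            StateF::Get(cont) => {
                Free::Free(Box::new(StateF::Get(Box::new(move |s| bind(cont(s), f)))))
            }
            StateF::Put(s, next) => Free::Free(Box::new(StateF::Put(s, bind(next, f)))),
        },
    }
}

// The manager: thread the state through the recipe.
fn run_state<S: Clone, A>(init: S, program: Free<S, A>) -> (A, S) {
    let mut state = init;
    let mut current = program;
    loop {
        match current {
            Free::Pure(x) => return (x, state),
            Free::Free(instr) => match *instr {
                StateF::Get(cont) => current = cont(state.clone()),
                StateF::Put(new_state, next) => {
                    state = new_state;
                    current = next;
                }
            },
        }
    }
}

// Hypothetical: the chef's recipe. "Read the stock", then "set it to X - 200".
fn flour_recipe() -> Free<i32, ()> {
    bind(get_state::<i32>(), |x| put_state(x - 200))
}

// Hypothetical helper: look at a recipe's first step without running it.
fn describe_first_step<S, A>(program: &Free<S, A>) -> &'static str {
    match program {
        Free::Pure(_) => "done",
        Free::Free(instr) => match &**instr {
            StateF::Get(_) => "read the state",
            StateF::Put(_, _) => "write the state",
        },
    }
}

fn main() {
    let recipe = flour_recipe(); // the chef writes; nothing is read or written yet
    println!("first step: {}", describe_first_step(&recipe)); // prints: first step: read the state
    let ((), stock) = run_state(1000, recipe); // the manager runs it
    println!("stock after recipe: {}", stock); // prints: stock after recipe: 800
}
```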

How It Works in Rust

// The "instruction" type: describes what the program wants
enum StateF<S, A> {
    Get(Box<dyn FnOnce(S) -> A>), // "give me S, I'll produce A"
    Put(S, A),                    // "set state to S, then produce A"
}

// Free monad: either a plain value, or an instruction with a continuation
enum Free<S, A> {
    Pure(A),                          // computation complete, here's the value
    Free(Box<StateF<S, Free<S, A>>>), // one instruction, rest of computation in continuation
}

// Smart constructor: "I want to read state"
fn get_state<S: Clone + 'static>() -> Free<S, S> {
    // The closure is the continuation: receive S, wrap it in Pure
    Free::Free(Box::new(StateF::Get(Box::new(|s| Free::Pure(s)))))
}

// Smart constructor: "I want to write state"
fn put_state<S: 'static>(s: S) -> Free<S, ()> {
    Free::Free(Box::new(StateF::Put(s, Free::Pure(()))))
}

// Monadic bind: chain two computations.
// The 'static bounds are required because continuations are stored as
// Box<dyn FnOnce>, which is implicitly 'static.
fn bind<S, A, B, F>(m: Free<S, A>, f: F) -> Free<S, B>
where
    S: 'static,
    A: 'static,
    B: 'static,
    F: FnOnce(A) -> Free<S, B> + 'static,
{
    match m {
        Free::Pure(x) => f(x), // computation done, apply f to its result
        Free::Free(instr) => match *instr {
            // Push f down into the continuation: build up a larger description
            StateF::Get(cont) => Free::Free(Box::new(StateF::Get(Box::new(
                move |s| bind(cont(s), f),
            )))),
            StateF::Put(s, next) => Free::Free(Box::new(StateF::Put(s, bind(next, f)))),
        },
    }
}

// Interpreter: walk the description, thread real state
fn run_state<S: Clone, A>(init: S, program: Free<S, A>) -> (A, S) {
    let mut state = init;
    let mut current = program;
    loop {
        match current {
            Free::Pure(x) => return (x, state), // done
            Free::Free(instr) => match *instr {
                // pass a copy of the current state to the continuation
                StateF::Get(cont) => current = cont(state.clone()),
                StateF::Put(new_s, next) => {
                    state = new_s;
                    current = next;
                }
            },
        }
    }
}

// Example program: a pure description; nothing runs until run_state
fn main() {
    let program = bind(get_state::<i32>(), |n| {
        bind(put_state(n + 1), |_| get_state::<i32>())
    });
    let (result, final_state) = run_state(0, program);
    assert_eq!((result, final_state), (1, 1)); // result = 1, final_state = 1
}
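Because `get_state` and `put_state` are ordinary values, new operations can be derived from them with `bind` instead of being added to `StateF`. A minimal sketch: `modify` is a hypothetical helper, not part of the original listing, that reads the state, applies a function, and writes the result back. The core definitions are repeated so the block compiles on its own.

```rust
enum StateF<S, A> {
    Get(Box<dyn FnOnce(S) -> A>),
    Put(S, A),
}

enum Free<S, A> {
    Pure(A),
    Free(Box<StateF<S, Free<S, A>>>),
}

fn get_state<S: 'static>() -> Free<S, S> {
    Free::Free(Box::new(StateF::Get(Box::new(|s| Free::Pure(s)))))
}

fn put_state<S: 'static>(s: S) -> Free<S, ()> {
    Free::Free(Box::new(StateF::Put(s, Free::Pure(()))))
}

fn bind<S: 'static, A: 'static, B: 'static, F>(m: Free<S, A>, f: F) -> Free<S, B>
where
    F: FnOnce(A) -> Free<S, B> + 'static,
{
    match m {
        Free::Pure(x) => f(x),
        Free::Free(instr) => match *instr {
            StateF::Get(cont) => {
                Free::Free(Box::new(StateF::Get(Box::new(move |s| bind(cont(s), f)))))
            }
            StateF::Put(s, next) => Free::Free(Box::new(StateF::Put(s, bind(next, f)))),
        },
    }
}

fn run_state<S: Clone, A>(init: S, program: Free<S, A>) -> (A, S) {
    let mut state = init;
    let mut current = program;
    loop {
        match current {
            Free::Pure(x) => return (x, state),
            Free::Free(instr) => match *instr {
                StateF::Get(cont) => current = cont(state.clone()),
                StateF::Put(new_state, next) => {
                    state = new_state;
                    current = next;
                }
            },
        }
    }
}

// Hypothetical derived operation: read, transform, write back.
// Built from get_state and put_state, so StateF needs no new variant
// and any interpreter of StateF handles it unchanged.
fn modify<S: 'static, G>(g: G) -> Free<S, ()>
where
    G: FnOnce(S) -> S + 'static,
{
    bind(get_state::<S>(), move |s| put_state(g(s)))
}

fn main() {
    let program = bind(modify(|n: i32| n * 2), |_| {
        bind(modify(|n: i32| n + 3), |_| get_state::<i32>())
    });
    let (result, final_state) = run_state(5, program);
    println!("result={} state={}", result, final_state); // prints: result=13 state=13
}
```

The closures are annotated with `n: i32` so the compiler knows the state type while checking `n * 2`; without that, inference may not have resolved `S` yet at that point.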

What This Unlocks

Key Differences

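| Concept | OCaml | Rust |
|---|---|---|
| Instruction type | `('s, 'a) state_f`: an ordinary parameterized variant; no GADT is needed because the state type is fixed for the whole program | `enum StateF<S, A>`: the same shape, with the `Get` continuation stored as `Box<dyn FnOnce(S) -> A>` |
| Bind | Plain recursive function; closures are ordinary values with no boxing or lifetime bounds | Requires explicit `'static` bounds, because a `Box<dyn FnOnce>` trait object is implicitly `'static` |
| Interpreter | `let rec go s = function \| Pure x -> (x, s) \| ...`: tail-recursive, so it runs in constant stack space | Iterative `loop` with mutable `state` and `current`; Rust does not guarantee tail-call elimination, so a recursive interpreter could overflow the stack on long programs |
| Ergonomics | Binding operators (`let ( let* ) = bind`, available since OCaml 4.08) flatten the nesting | Deeply nested `bind` calls; consider a macro for real use |
| Applicability | A common pattern in typed functional programming | Less common in Rust; for most code, prefer plain `&mut` state or `RefCell`; free monads pay off mainly in DSL design |
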
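The ergonomics row suggests a macro. A minimal sketch of one way that could look: `state_do!` is a hypothetical `macro_rules!` macro (not an existing crate) that rewrites `let x = m; rest` into `bind(m, move |x| rest)` and a bare `m; rest` into `bind(m, move |_| rest)`, so the three-increment program from the listing reads top to bottom. The core definitions are repeated so the block compiles on its own.

```rust
enum StateF<S, A> {
    Get(Box<dyn FnOnce(S) -> A>),
    Put(S, A),
}

enum Free<S, A> {
    Pure(A),
    Free(Box<StateF<S, Free<S, A>>>),
}

fn get_state<S: 'static>() -> Free<S, S> {
    Free::Free(Box::new(StateF::Get(Box::new(|s| Free::Pure(s)))))
}

fn put_state<S: 'static>(s: S) -> Free<S, ()> {
    Free::Free(Box::new(StateF::Put(s, Free::Pure(()))))
}

fn bind<S: 'static, A: 'static, B: 'static, F>(m: Free<S, A>, f: F) -> Free<S, B>
where
    F: FnOnce(A) -> Free<S, B> + 'static,
{
    match m {
        Free::Pure(x) => f(x),
        Free::Free(instr) => match *instr {
            StateF::Get(cont) => {
                Free::Free(Box::new(StateF::Get(Box::new(move |s| bind(cont(s), f)))))
            }
            StateF::Put(s, next) => Free::Free(Box::new(StateF::Put(s, bind(next, f)))),
        },
    }
}

fn run_state<S: Clone, A>(init: S, program: Free<S, A>) -> (A, S) {
    let mut state = init;
    let mut current = program;
    loop {
        match current {
            Free::Pure(x) => return (x, state),
            Free::Free(instr) => match *instr {
                StateF::Get(cont) => current = cont(state.clone()),
                StateF::Put(new_state, next) => {
                    state = new_state;
                    current = next;
                }
            },
        }
    }
}

// Hypothetical do-notation macro. Rules are tried in order:
//   `let x = m; rest`  becomes  bind(m, move |x| rest)
//   `m; rest`          becomes  bind(m, move |_| rest)
//   `m`                is the final computation
macro_rules! state_do {
    (let $x:ident = $m:expr ; $($rest:tt)*) => {
        bind($m, move |$x| state_do!($($rest)*))
    };
    ($m:expr ; $($rest:tt)*) => {
        bind($m, move |_| state_do!($($rest)*))
    };
    ($e:expr) => {
        $e
    };
}

fn increment_three_times() -> Free<i32, i32> {
    state_do! {
        let n = get_state::<i32>();
        put_state(n + 1);
        let n = get_state::<i32>();
        put_state(n + 1);
        let n = get_state::<i32>();
        put_state(n + 1);
        get_state::<i32>()
    }
}

fn main() {
    let (result, final_state) = run_state(10, increment_three_times());
    println!("result={} state={}", result, final_state); // prints: result=13 state=13
}
```

The `let` rule comes first because it starts with a literal token, so it can be rejected without invoking the expression parser; the bare-expression rule is last so that `m; rest` is tried before treating the input as a final value.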
// Free Monad with State: State as a Free Algebra
//
// Encode mutable state purely using a free monad.
// Get and Put are instructions; the program never mutates anything,
// and the interpreter threads the state value through them.

// ── DSL: state instructions ─────────────────────────────────────────────

enum StateF<S, A> {
    Get(Box<dyn FnOnce(S) -> A>),
    Put(S, A),
}

// ── Free monad over StateF ──────────────────────────────────────────────

enum Free<S, A> {
    Pure(A),
    Free(Box<StateF<S, Free<S, A>>>),
}

// ── Smart constructors ──────────────────────────────────────────────────

fn get_state<S: Clone + 'static>() -> Free<S, S> {
    Free::Free(Box::new(StateF::Get(Box::new(|s| Free::Pure(s)))))
}

fn put_state<S: 'static>(s: S) -> Free<S, ()> {
    Free::Free(Box::new(StateF::Put(s, Free::Pure(()))))
}

// ── Monadic bind ────────────────────────────────────────────────────────

fn bind<S, A, B, F>(m: Free<S, A>, f: F) -> Free<S, B>
where
    S: 'static,
    A: 'static,
    B: 'static,
    F: FnOnce(A) -> Free<S, B> + 'static,
{
    match m {
        Free::Pure(x) => f(x),
        Free::Free(instr) => match *instr {
            StateF::Get(cont) => {
                Free::Free(Box::new(StateF::Get(Box::new(move |s| bind(cont(s), f)))))
            }
            StateF::Put(s, next) => {
                Free::Free(Box::new(StateF::Put(s, bind(next, f))))
            }
        },
    }
}

// ── Interpreter: thread state through the computation ───────────────────

fn run_state<S: Clone, A>(init: S, program: Free<S, A>) -> (A, S) {
    let mut state = init;
    let mut current = program;
    loop {
        match current {
            Free::Pure(x) => return (x, state),
            Free::Free(instr) => match *instr {
                StateF::Get(cont) => {
                    current = cont(state.clone());
                }
                StateF::Put(new_state, next) => {
                    state = new_state;
                    current = next;
                }
            },
        }
    }
}

// ── Example program: increment counter 3 times ──────────────────────────

fn increment_three_times() -> Free<i32, i32> {
    bind(get_state::<i32>(), |n| {
        bind(put_state(n + 1), |_| {
            bind(get_state::<i32>(), |n| {
                bind(put_state(n + 1), |_| {
                    bind(get_state::<i32>(), |n| {
                        bind(put_state(n + 1), |_| get_state::<i32>())
                    })
                })
            })
        })
    })
}

// ── Example: accumulate into state ──────────────────────────────────────

fn accumulate(values: Vec<i32>) -> Free<i32, i32> {
    if values.is_empty() {
        get_state::<i32>()
    } else {
        let head = values[0];
        let tail = values[1..].to_vec();
        bind(get_state::<i32>(), move |n| {
            bind(put_state(n + head), |_| accumulate(tail))
        })
    }
}

fn main() {
    // ── Increment 3 times starting from 0 ───────────────────────────────
    let program = increment_three_times();
    let (result, final_state) = run_state(0, program);
    println!("increment 3x: result={} final_state={}", result, final_state);
    assert_eq!(result, 3);
    assert_eq!(final_state, 3);

    // ── Start from non-zero initial state ───────────────────────────────
    let program2 = increment_three_times();
    let (result2, final2) = run_state(10, program2);
    println!("start=10, increment 3x: result={} final={}", result2, final2);

    // ── Accumulate a list ───────────────────────────────────────────────
    let acc_prog = accumulate(vec![1, 2, 3, 4, 5]);
    let (sum, _) = run_state(0, acc_prog);
    println!("sum of [1..5] = {}", sum);

    println!("State free monad works");
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_increment_three_times_from_zero() {
        let program = increment_three_times();
        let (result, final_state) = run_state(0, program);
        assert_eq!(result, 3);
        assert_eq!(final_state, 3);
    }

    #[test]
    fn test_increment_from_nonzero_initial_state() {
        let program = increment_three_times();
        let (result, final_state) = run_state(10, program);
        assert_eq!(result, 13);
        assert_eq!(final_state, 13);
    }

    #[test]
    fn test_accumulate_sum() {
        let program = accumulate(vec![1, 2, 3, 4, 5]);
        let (sum, final_state) = run_state(0, program);
        assert_eq!(sum, 15);
        assert_eq!(final_state, 15);
    }

    #[test]
    fn test_get_returns_initial_state() {
        let program: Free<i32, i32> = get_state::<i32>();
        let (result, state) = run_state(42, program);
        assert_eq!(result, 42);
        assert_eq!(state, 42);
    }
}
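The listing's `run_state` always runs a program to completion. Since the program is data, an interpreter can also stop partway and resume later, which is one way to replay specific steps. A minimal sketch, assuming a hypothetical `run_steps` that executes at most `fuel` instructions and returns either the finished result or the unfinished remainder (`Step` and `two_increments` are likewise hypothetical); the core definitions are repeated so the block compiles on its own.

```rust
enum StateF<S, A> {
    Get(Box<dyn FnOnce(S) -> A>),
    Put(S, A),
}

enum Free<S, A> {
    Pure(A),
    Free(Box<StateF<S, Free<S, A>>>),
}

fn get_state<S: 'static>() -> Free<S, S> {
    Free::Free(Box::new(StateF::Get(Box::new(|s| Free::Pure(s)))))
}

fn put_state<S: 'static>(s: S) -> Free<S, ()> {
    Free::Free(Box::new(StateF::Put(s, Free::Pure(()))))
}

fn bind<S: 'static, A: 'static, B: 'static, F>(m: Free<S, A>, f: F) -> Free<S, B>
where
    F: FnOnce(A) -> Free<S, B> + 'static,
{
    match m {
        Free::Pure(x) => f(x),
        Free::Free(instr) => match *instr {
            StateF::Get(cont) => {
                Free::Free(Box::new(StateF::Get(Box::new(move |s| bind(cont(s), f)))))
            }
            StateF::Put(s, next) => Free::Free(Box::new(StateF::Put(s, bind(next, f)))),
        },
    }
}

// Hypothetical result type: finished, or paused with the remaining program.
enum Step<S, A> {
    Done(A, S),
    Paused(S, Free<S, A>),
}

// Hypothetical interpreter: execute at most `fuel` instructions.
fn run_steps<S: Clone, A>(init: S, program: Free<S, A>, fuel: usize) -> Step<S, A> {
    let mut state = init;
    let mut current = program;
    for _ in 0..fuel {
        match current {
            Free::Pure(x) => return Step::Done(x, state),
            Free::Free(instr) => match *instr {
                StateF::Get(cont) => current = cont(state.clone()),
                StateF::Put(new_state, next) => {
                    state = new_state;
                    current = next;
                }
            },
        }
    }
    match current {
        Free::Pure(x) => Step::Done(x, state),
        unfinished => Step::Paused(state, unfinished),
    }
}

// Two increments: Get, Put, Get, Put, Get (five instructions).
fn two_increments() -> Free<i32, i32> {
    bind(get_state::<i32>(), |n| {
        bind(put_state(n + 1), |_| {
            bind(get_state::<i32>(), |m| bind(put_state(m + 1), |_| get_state::<i32>()))
        })
    })
}

fn main() {
    // Run only the first two instructions (one Get and one Put), then resume.
    match run_steps(0, two_increments(), 2) {
        Step::Paused(state, rest) => {
            println!("paused with state={}", state); // prints: paused with state=1
            if let Step::Done(result, final_state) = run_steps(state, rest, 100) {
                println!("resumed: result={} state={}", result, final_state); // prints: resumed: result=2 state=2
            }
        }
        Step::Done(..) => println!("finished early"),
    }
}
```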
(* State monad encoded as a free monad.
   No mutation: the state is threaded through the interpreter. *)

(* Both instructions share the state type 's, so a plain parameterized
   variant is enough; Rust's StateF<S, A> has the same shape. *)
type ('s, 'a) state_f =
  | Get of ('s -> 'a)
  | Put of 's * 'a

type ('s, 'a) free =
  | Pure of 'a
  | Free of ('s, ('s, 'a) free) state_f

let get ()   = Free (Get  (fun s -> Pure s))
let put s    = Free (Put  (s, Pure ()))

let rec bind m f = match m with
  | Pure x -> f x
  | Free (Get cont)    -> Free (Get  (fun s -> bind (cont s) f))
  | Free (Put (s, n))  -> Free (Put  (s, bind n f))

let run init program =
  let rec go s = function
    | Pure x          -> (x, s)
    | Free (Get cont) -> go s (cont s)
    | Free (Put (s', next)) -> go s' next
  in go init program

let () =
  (* Increment counter 3 times *)
  let program =
    bind (get ()) (fun n ->
    bind (put (n + 1)) (fun () ->
    bind (get ()) (fun n ->
    bind (put (n + 1)) (fun () ->
    bind (get ()) (fun n ->
    bind (put (n + 1)) (fun () ->
    bind (get ()) (fun final ->
    Pure final)))))))
  in
  let (result, final_state) = run 0 program in
  Printf.printf "result=%d state=%d\n" result final_state;
  assert (result = 3 && final_state = 3);
  Printf.printf "State free monad works\n"
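The OCaml `go` above is tail-recursive, so it needs no explicit loop; Rust does not guarantee tail-call elimination, which is why the Rust `run_state` is written as a `loop`. A minimal sketch of what that buys, assuming a hypothetical `count_up` program of 100,000 increments: for a right-nested program like this one, each continuation builds only the next few instructions when it is called, so neither the interpreter nor `bind` recurses deeply. The core definitions are repeated so the block compiles on its own.

```rust
enum StateF<S, A> {
    Get(Box<dyn FnOnce(S) -> A>),
    Put(S, A),
}

enum Free<S, A> {
    Pure(A),
    Free(Box<StateF<S, Free<S, A>>>),
}

fn get_state<S: 'static>() -> Free<S, S> {
    Free::Free(Box::new(StateF::Get(Box::new(|s| Free::Pure(s)))))
}

fn put_state<S: 'static>(s: S) -> Free<S, ()> {
    Free::Free(Box::new(StateF::Put(s, Free::Pure(()))))
}

fn bind<S: 'static, A: 'static, B: 'static, F>(m: Free<S, A>, f: F) -> Free<S, B>
where
    F: FnOnce(A) -> Free<S, B> + 'static,
{
    match m {
        Free::Pure(x) => f(x),
        Free::Free(instr) => match *instr {
            StateF::Get(cont) => {
                Free::Free(Box::new(StateF::Get(Box::new(move |s| bind(cont(s), f)))))
            }
            StateF::Put(s, next) => Free::Free(Box::new(StateF::Put(s, bind(next, f)))),
        },
    }
}

// Iterative interpreter: one loop iteration per instruction.
fn run_state<S: Clone, A>(init: S, program: Free<S, A>) -> (A, S) {
    let mut state = init;
    let mut current = program;
    loop {
        match current {
            Free::Pure(x) => return (x, state),
            Free::Free(instr) => match *instr {
                StateF::Get(cont) => current = cont(state.clone()),
                StateF::Put(new_state, next) => {
                    state = new_state;
                    current = next;
                }
            },
        }
    }
}

// Hypothetical long program: increment the counter k times, then read it.
// The rest of the program is only built when a continuation is called.
fn count_up(k: u64) -> Free<u64, u64> {
    if k == 0 {
        get_state::<u64>()
    } else {
        bind(get_state::<u64>(), move |n| {
            bind(put_state(n + 1), move |_| count_up(k - 1))
        })
    }
}

fn main() {
    let (result, final_state) = run_state(0, count_up(100_000));
    println!("result={} state={}", result, final_state); // prints: result=100000 state=100000
}
```

Left-nested chains such as `bind(bind(bind(m, f1), f2), f3)` behave differently: `bind` walks the inner structure recursively, so very deep left nesting can still use significant stack.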