859 Expert

State Monad

Functional Programming

Tutorial

The Problem

Without the State monad, threading state through a sequence of functions means passing the state explicitly as an argument and returning it alongside the result: fn step(input: T, state: S) -> (R, S). This is noisy and error-prone, because every call site must forward the latest state by hand. The State monad encapsulates the threading: State<S, A> represents a computation S -> (A, S) that can read and replace the state. Computations compose without explicit state passing, because bind does the threading.

The pattern appears in compiler passes (threading symbol tables), game state machines, configuration accumulation, and embedded DSLs. It keeps stateful computation composable and testable while remaining purely functional.
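To make the problem concrete, here is a minimal sketch of explicit state threading for a label generator. The names fresh_label_explicit and gen_labels_explicit are illustrative and are not part of the tutorial's source.

```rust
/// Explicit threading: takes the counter and hands the next counter back.
/// (Illustrative helper, not part of the tutorial's source.)
fn fresh_label_explicit(prefix: &str, counter: u32) -> (String, u32) {
    (format!("{}_{}", prefix, counter), counter + 1)
}

/// Generates three labels, forwarding the counter by hand at every step.
fn gen_labels_explicit(counter: u32) -> (Vec<String>, u32) {
    // Each call must receive the state produced by the previous call.
    // Passing `counter` instead of `c1` would still compile, but it would
    // silently reuse a stale state.
    let (a, c1) = fresh_label_explicit("var", counter);
    let (b, c2) = fresh_label_explicit("tmp", c1);
    let (c, c3) = fresh_label_explicit("var", c2);
    (vec![a, b, c], c3)
}
```

The intermediate names c1, c2, and c3 are exactly the bookkeeping that the State monad is meant to hide.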

🎯 Learning Outcomes

  • Understand State<S, A> as a wrapper around FnOnce(S) -> (A, S)
  • Implement get() returning the current state, put(s) replacing the state, and modify(f) transforming the state
  • Implement monadic bind: state.bind(|a| next_state), which threads the state through both computations
  • Use run_state(computation, initial) to execute the computation and get (result, final_state)
  • Recognize the tension with Rust's ownership: FnOnce lets closures move captured values but makes each computation single-use, while Fn allows reruns at the cost of cloning

Code Example

    struct State<S, A> {
        run: Box<dyn FnOnce(S) -> (A, S)>,
    }
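As a rough illustration of how this struct is used, the following compact sketch repeats just enough of the machinery (new, bind, get, put) to build and run one small computation. The double_counter function is a hypothetical example, not part of the tutorial's source.

```rust
struct State<S, A> {
    run: Box<dyn FnOnce(S) -> (A, S)>,
}

impl<S: 'static, A: 'static> State<S, A> {
    fn new(f: impl FnOnce(S) -> (A, S) + 'static) -> Self {
        State { run: Box::new(f) }
    }

    /// Runs `self`, then the computation chosen by `f`, threading the state.
    fn bind<B: 'static>(self, f: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
        State::new(move |s| {
            let (a, s1) = (self.run)(s);
            (f(a).run)(s1)
        })
    }
}

/// Produces the current state as the value and leaves the state unchanged.
fn get<S: Clone + 'static>() -> State<S, S> {
    State::new(|s: S| (s.clone(), s))
}

/// Replaces the state with `s` and produces the unit value.
fn put<S: 'static>(s: S) -> State<S, ()> {
    State::new(move |_| ((), s))
}

/// Hypothetical example: doubles the counter and returns its previous value.
fn double_counter() -> State<u32, u32> {
    get::<u32>().bind(|old| put(old * 2).bind(move |()| State::new(move |s| (old, s))))
}
```

Nothing happens when double_counter() is called; the state is only touched once the boxed closure is invoked with an initial value, for example (double_counter().run)(21).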

    Key Differences

    Aspect              | Rust                                         | OCaml
    Type                | Box<dyn FnOnce(S) -> (A, S)>                 | State of ('s -> 'a * 's)
    FnOnce vs Fn        | Must choose based on use                     | fun s -> ... (always reusable, like Fn)
    Bind implementation | Complex with boxing                          | Clean algebraic unwrap
    get                 | Requires S: Clone                            | No copy needed (the immutable value is shared)
    Thread safety       | Extra Send bound needed to cross threads     | No such bounds in the type
    'static bound       | Required for boxed closures as written       | Not required
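The FnOnce vs Fn row can be illustrated with a hypothetical re-runnable variant. The sketch below assumes an Rc<dyn Fn> representation instead of Box<dyn FnOnce>; the names StateFn, tick_fn, and two_ticks are invented for this illustration and do not appear in the tutorial's source.

```rust
use std::rc::Rc;

/// Hypothetical re-runnable variant: `Fn` behind an `Rc` instead of
/// `FnOnce` behind a `Box`.
struct StateFn<S, A> {
    run: Rc<dyn Fn(S) -> (A, S)>,
}

impl<S: 'static, A: 'static> StateFn<S, A> {
    fn new(f: impl Fn(S) -> (A, S) + 'static) -> Self {
        StateFn { run: Rc::new(f) }
    }

    /// Borrows `self`, so the same computation can be executed again.
    fn run_state(&self, s: S) -> (A, S) {
        (self.run)(s)
    }

    /// `f` must be `Fn` here, because the composed computation may run
    /// any number of times.
    fn bind<B: 'static>(&self, f: impl Fn(A) -> StateFn<S, B> + 'static) -> StateFn<S, B> {
        let first = Rc::clone(&self.run);
        StateFn::new(move |s| {
            let (a, s1) = first(s);
            f(a).run_state(s1)
        })
    }
}

/// Returns the counter value and increments the counter.
fn tick_fn() -> StateFn<u32, u32> {
    StateFn::new(|n: u32| (n, n + 1))
}

/// Ticks twice and collects both observed values.
fn two_ticks() -> StateFn<u32, (u32, u32)> {
    tick_fn().bind(|a| tick_fn().bind(move |b| StateFn::new(move |s| ((a, b), s))))
}
```

The trade-off is that every captured value must survive repeated calls, so a counterpart of put(s) or pure(a) in this style would need S: Clone or A: Clone, whereas the FnOnce version can simply move the value out.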

    OCaml Approach

    OCaml represents State as type ('s, 'a) state = State of ('s -> 'a * 's). Running a computation just unwraps the constructor: let run_state (State f) s = f s. Monadic bind is let bind (State f) k = State (fun s -> let (a, s') = f s in let State g = k a in g s'). The primitives are let get = State (fun s -> (s, s)) and let put s = State (fun _ -> ((), s)). Pattern matching on the single-constructor variant keeps these definitions short and readable. The ppx_let extension provides let%bind syntax, which avoids nested closures when sequencing steps.

    Full Source

    //! State monad in Rust — thread state through computations without explicit passing.
    //!
    //! The [`State`] type wraps a function `S -> (A, S)`. Combinators like
    //! [`State::bind`], [`get`], [`put`], and [`modify`] let you compose stateful
    //! computations, then [`run_state`] executes the pipeline against an initial state.
    
    /// A stateful computation that, given an input state of type `S`, produces a
    /// value of type `A` together with the next state.
    ///
    /// Wraps a boxed closure so that arbitrary captures are allowed.
    pub struct State<S, A> {
        run: Box<dyn FnOnce(S) -> (A, S)>,
    }
    
    impl<S: 'static, A: 'static> State<S, A> {
        /// Builds a `State` computation from a closure.
        pub fn new<F>(f: F) -> Self
        where
            F: FnOnce(S) -> (A, S) + 'static,
        {
            State { run: Box::new(f) }
        }
    
        /// Lifts a pure value into the state monad without touching the state.
        pub fn pure(a: A) -> Self {
            State::new(move |s| (a, s))
        }
    
        /// Monadic bind: runs `self`, feeds its result into `f`, and continues with
        /// the returned computation using the threaded state.
        pub fn bind<B, F>(self, f: F) -> State<S, B>
        where
            B: 'static,
            F: FnOnce(A) -> State<S, B> + 'static,
        {
            State::new(move |s| {
                let (a, s1) = (self.run)(s);
                (f(a).run)(s1)
            })
        }
    
        /// Maps a pure function over the result, leaving state threading intact.
        pub fn map<B, F>(self, f: F) -> State<S, B>
        where
            B: 'static,
            F: FnOnce(A) -> B + 'static,
        {
            self.bind(move |a| State::pure(f(a)))
        }
    }
    
    /// Executes a stateful computation starting from `s`, returning the final
    /// value and the final state.
    pub fn run_state<S: 'static, A: 'static>(m: State<S, A>, s: S) -> (A, S) {
        (m.run)(s)
    }
    
    /// Reads the current state as the produced value; leaves the state unchanged.
    pub fn get<S: Clone + 'static>() -> State<S, S> {
        State::new(|s: S| (s.clone(), s))
    }
    
    /// Replaces the current state with `s`; produces the unit value.
    pub fn put<S: 'static>(s: S) -> State<S, ()> {
        State::new(move |_| ((), s))
    }
    
    /// Applies `f` to the current state, storing the result as the new state.
    pub fn modify<S: 'static, F>(f: F) -> State<S, ()>
    where
        F: FnOnce(S) -> S + 'static,
    {
        State::new(move |s| ((), f(s)))
    }
    
    /// Increments a `u32` counter and returns the pre-increment value.
    pub fn tick() -> State<u32, u32> {
        get::<u32>().bind(|n| put(n + 1).bind(move |()| State::pure(n)))
    }
    
    /// Runs `tick` three times, collecting the three observed counter values.
    pub fn count3() -> State<u32, (u32, u32, u32)> {
        tick().bind(|a| {
            tick().bind(move |b| tick().bind(move |c| State::pure((a, b, c))))
        })
    }
    
    /// Pushes `x` onto the stack held in the state.
    pub fn push<T: 'static>(x: T) -> State<Vec<T>, ()> {
        modify(move |mut stack: Vec<T>| {
            stack.push(x);
            stack
        })
    }
    
    /// Pops the top element from the stack held in the state, or `None` if empty.
    pub fn pop<T: 'static>() -> State<Vec<T>, Option<T>> {
        State::new(|mut stack: Vec<T>| {
            let top = stack.pop();
            (top, stack)
        })
    }
    
    /// Example stack program: pushes 1, 2, 3, then pops twice.
    pub fn stack_ops() -> State<Vec<i32>, (Option<i32>, Option<i32>)> {
        push(1).bind(|()| {
            push(2).bind(|()| {
                push(3).bind(|()| {
                    pop::<i32>().bind(|a| pop::<i32>().bind(move |b| State::pure((a, b))))
                })
            })
        })
    }
    
    /// Generates a fresh label of the form `"{prefix}_{n}"` and increments the
    /// counter.
    pub fn fresh_label(prefix: &'static str) -> State<u32, String> {
        get::<u32>().bind(move |n| put(n + 1).bind(move |()| State::pure(format!("{}_{}", prefix, n))))
    }
    
    /// Generates three labels in sequence using a shared counter.
    pub fn gen_labels() -> State<u32, Vec<String>> {
        fresh_label("var").bind(|a| {
            fresh_label("tmp").bind(move |b| {
                fresh_label("var").bind(move |c| State::pure(vec![a, b, c]))
            })
        })
    }
    
    #[cfg(test)]
    mod tests {
        use super::*;
    
        #[test]
        fn pure_does_not_touch_state() {
            let (v, s) = run_state(State::pure(42), 7u32);
            assert_eq!(v, 42);
            assert_eq!(s, 7);
        }
    
        #[test]
        fn get_observes_current_state() {
            let (v, s) = run_state(get::<u32>(), 99);
            assert_eq!(v, 99);
            assert_eq!(s, 99);
        }
    
        #[test]
        fn put_replaces_state() {
            let (_, s) = run_state(put(5u32), 0);
            assert_eq!(s, 5);
        }
    
        #[test]
        fn modify_applies_function() {
            let (_, s) = run_state(modify(|n: u32| n * 2), 21);
            assert_eq!(s, 42);
        }
    
        #[test]
        fn map_transforms_result() {
            let prog = get::<u32>().map(|n| n + 1);
            let (v, s) = run_state(prog, 10);
            assert_eq!(v, 11);
            assert_eq!(s, 10);
        }
    
        #[test]
        fn tick_returns_pre_increment_value() {
            let (v, s) = run_state(tick(), 0);
            assert_eq!(v, 0);
            assert_eq!(s, 1);
        }
    
        #[test]
        fn count3_threads_counter() {
            let (result, final_state) = run_state(count3(), 0);
            assert_eq!(result, (0, 1, 2));
            assert_eq!(final_state, 3);
        }
    
        #[test]
        fn stack_ops_push_then_pop() {
            let ((a, b), final_stack) = run_state(stack_ops(), Vec::<i32>::new());
            assert_eq!(a, Some(3));
            assert_eq!(b, Some(2));
            assert_eq!(final_stack, vec![1]);
        }
    
        #[test]
        fn pop_empty_yields_none() {
            let (v, s) = run_state(pop::<i32>(), Vec::<i32>::new());
            assert_eq!(v, None);
            assert_eq!(s, Vec::<i32>::new());
        }
    
        #[test]
        fn gen_labels_uses_shared_counter() {
            let (labels, final_state) = run_state(gen_labels(), 0);
            assert_eq!(
                labels,
                vec![
                    "var_0".to_string(),
                    "tmp_1".to_string(),
                    "var_2".to_string()
                ]
            );
            assert_eq!(final_state, 3);
        }
    
        #[test]
        fn bind_left_identity() {
            let f = |n: u32| State::pure(n + 1);
            let (v, s) = run_state(State::pure(10u32).bind(f), 0u32);
            assert_eq!(v, 11);
            assert_eq!(s, 0);
        }
    
        #[test]
        fn bind_right_identity() {
            let prog = get::<u32>().bind(State::pure);
            let (v, s) = run_state(prog, 7);
            assert_eq!(v, 7);
            assert_eq!(s, 7);
        }
    }

    Deep Comparison

    Comparison: State Monad

    State Type

    OCaml:

    type ('s, 'a) state = State of ('s -> 'a * 's)
    let run_state (State f) s = f s
    

    Rust:

    struct State<S, A> {
        run: Box<dyn FnOnce(S) -> (A, S)>,
    }
    

    Bind / and_then

    OCaml:

    let bind m f = State (fun s ->
      let (a, s') = run_state m s in
      run_state (f a) s')
    

    Rust:

    fn and_then<B: 'static>(self, f: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
        State::new(move |s| {
            // `run` is a field holding a boxed closure, so the call needs parentheses.
            let (a, s2) = (self.run)(s);
            (f(a).run)(s2)
        })
    }
    

    Counter Example

    OCaml:

    let tick = get >>= fun n -> put (n + 1) >>= fun () -> return_ n
    let (result, _) = run_state (tick >>= fun a -> tick >>= fun b -> return_ (a, b)) 0
    (* result = (0, 1) *)
    

    Rust (idiomatic — no monad needed):

    fn count3_explicit(state: i32) -> ((i32, i32, i32), i32) {
        let a = state; let state = state + 1;
        let b = state; let state = state + 1;
        let c = state; let state = state + 1;
        ((a, b, c), state)
    }
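Another common Rust style, sketched here as an assumption rather than as part of the tutorial's source, is to borrow the state mutably instead of passing it in and out by value. The names tick_mut and count3_mut are illustrative.

```rust
/// Illustrative helper: returns the current counter value, then increments it.
fn tick_mut(counter: &mut u32) -> u32 {
    let n = *counter;
    *counter += 1;
    n
}

/// Ticks three times; the borrow checker guarantees exclusive access to the
/// counter for the duration of each call.
fn count3_mut(counter: &mut u32) -> (u32, u32, u32) {
    let a = tick_mut(counter);
    let b = tick_mut(counter);
    let c = tick_mut(counter);
    (a, b, c)
}
```

This version is no longer purely functional, but the &mut borrow keeps the mutation local and visible in the signature, which is why explicit state monads are less common in everyday Rust than in OCaml or Haskell.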
    

    Exercises

  • Implement monadic bind for State<S, A> and use it to compose get, a transform, and put into a single computation.
  • Implement a stack using the State monad: push as a State<Vec<T>, ()> computation and pop as a State<Vec<T>, Option<T>> computation.
  • Use the State monad to implement a simple counter that increments and returns the new count at each step.
  • Compare the State monad approach with explicit state threading: implement the same computation both ways.
  • Implement modify(f: S -> S) -> State<S, ()> using get and put and verify it equals State::new(|s| ((), f(s))).