🦀 Functional Rust

177: GADT Typed Expression Evaluator

Difficulty: ⭐⭐⭐ Level: Advanced

Build a complete typed expression tree where `eval` returns exactly the right Rust type for each node (integers for arithmetic, booleans for comparisons), enforced at compile time.

The Problem This Solves

Writing an evaluator for a mini-language hits a wall fast. If you use a flat `enum Expr` for all nodes, `eval` has to return a runtime sum type, something like `enum Value { Int(i64), Bool(bool) }`. Adding two expressions then requires a runtime check: "are both ints?" If not, you panic or propagate an error. The type system gives you nothing; it can't tell that `Add(Int(3), Bool(true))` is nonsense until someone runs it.

The deeper problem is that this mismatches how the language actually works. In any typed language, `3 + 4` has type `Int` and `3 == 4` has type `Bool`. That's a compile-time fact, not a runtime discovery. If our evaluator's type system can't encode it, we pay a permanent tax of defensive checks and potential runtime panics.

OCaml's GADTs solve this directly: `Add : int expr * int expr -> int expr` makes the compiler verify that both operands and the result are all `int expr`. No runtime check is needed, because the ill-typed expression can't be constructed.

Rust achieves the same through a trait-per-node approach: each node type implements `Expr` with an associated `type Value`, and generic bounds enforce that `Add<A, B>` only exists when both `A: Expr<Value = i64>` and `B: Expr<Value = i64>` hold.
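To make the cost concrete, here is a minimal sketch of the untyped design described above. The names `UntypedExpr` and `Value` and the error messages are illustrative assumptions, not part of the original example; the point is only that every operator has to re-check its operands at run time.

```rust
// Hypothetical untyped design: one enum for all nodes, one enum for all results.
#[derive(Debug, PartialEq)]
enum Value {
    Int(i64),
    Bool(bool),
}

enum UntypedExpr {
    Int(i64),
    Bool(bool),
    Add(Box<UntypedExpr>, Box<UntypedExpr>),
    Eq(Box<UntypedExpr>, Box<UntypedExpr>),
}

// Every operator must inspect its operands at run time and report mismatches.
fn eval(e: &UntypedExpr) -> Result<Value, String> {
    match e {
        UntypedExpr::Int(n) => Ok(Value::Int(*n)),
        UntypedExpr::Bool(b) => Ok(Value::Bool(*b)),
        UntypedExpr::Add(a, b) => match (eval(a)?, eval(b)?) {
            (Value::Int(x), Value::Int(y)) => Ok(Value::Int(x + y)),
            (l, r) => Err(format!("cannot add {:?} and {:?}", l, r)),
        },
        UntypedExpr::Eq(a, b) => match (eval(a)?, eval(b)?) {
            (Value::Int(x), Value::Int(y)) => Ok(Value::Bool(x == y)),
            (l, r) => Err(format!("cannot compare {:?} and {:?}", l, r)),
        },
    }
}

fn main() {
    let ok = UntypedExpr::Add(Box::new(UntypedExpr::Int(3)), Box::new(UntypedExpr::Int(4)));
    // This ill-typed tree compiles without complaint; the mistake only shows up in eval.
    let bad = UntypedExpr::Add(Box::new(UntypedExpr::Int(3)), Box::new(UntypedExpr::Bool(true)));
    println!("{:?}", eval(&ok));
    println!("{:?}", eval(&bad));
}
```

The ill-typed `bad` tree is accepted by the compiler and rejected only when `eval` runs, which is exactly the gap the typed design closes.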

The Intuition

A spreadsheet formula has a type: `=SUM(...)` produces a number, `=ISBLANK(...)` produces a boolean. You can't pass a boolean column into `SUM`; the spreadsheet rejects it before computing. That's what we want from our expression tree.

The Rust solution makes each node a separate struct: `Lit`, `BLit`, `Add`, `Eq`, `IfExpr`. Each implements a trait `Expr` with an associated `Value` type. `Add<A, B>` is only valid when both `A` and `B` have `Value = i64`. The compiler tracks these constraints through every generic instantiation. When you call `.eval()`, the return type is exactly the associated `Value` of that node type: statically known, no boxing needed.

How It Works in Rust

// The core trait: every expression node knows its result type
trait Expr {
 type Value;
 fn eval(&self) -> Self::Value;
}

// Integer literal: evaluates to i64
struct Lit(i64);
impl Expr for Lit {
 type Value = i64;
 fn eval(&self) -> i64 { self.0 }
}

// Boolean literal: evaluates to bool
struct BLit(bool);
impl Expr for BLit {
 type Value = bool;
 fn eval(&self) -> bool { self.0 }
}

// Add: only compiles when BOTH operands evaluate to i64
struct Add<A: Expr<Value = i64>, B: Expr<Value = i64>>(A, B);
impl<A: Expr<Value = i64>, B: Expr<Value = i64>> Expr for Add<A, B> {
 type Value = i64;
 fn eval(&self) -> i64 { self.0.eval() + self.1.eval() }
}

// Equality check: takes two i64 expressions, returns bool
struct Eq<A: Expr<Value = i64>, B: Expr<Value = i64>>(A, B);
impl<A: Expr<Value = i64>, B: Expr<Value = i64>> Expr for Eq<A, B> {
 type Value = bool;
 fn eval(&self) -> bool { self.0.eval() == self.1.eval() }
}

// Conditional: condition must be bool, both branches must have the same Value type
struct IfExpr<C: Expr<Value = bool>, T: Expr, F: Expr<Value = T::Value>>(C, T, F);
impl<C: Expr<Value = bool>, T: Expr, F: Expr<Value = T::Value>> Expr for IfExpr<C, T, F> {
 type Value = T::Value;
 fn eval(&self) -> Self::Value {
     if self.0.eval() { self.1.eval() } else { self.2.eval() }
 }
}

// Usage: fully static, no runtime type checks
fn main() {
 let e = IfExpr(BLit(true), Add(Lit(1), Lit(2)), Lit(99));
 let result: i64 = e.eval(); // compiler knows this is i64
 assert_eq!(result, 3); // condition is true, so the Add branch is taken

 // This fails at compile time because you can't add a bool to an int:
 // let bad = Add(Lit(5), BLit(true));
}

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Single type for all nodes | `type _ expr` unified GADT | Separate struct per node (or separate int/bool enums) |
| Type-safe eval | Return type `a` refined per GADT constructor | Associated `type Value` on `Expr` trait, constrained by bounds |
| Conditional | `If : bool expr * 'a expr * 'a expr -> 'a expr` in one constructor | `IfExpr<C, T, F>` generic struct with `F: Expr<Value = T::Value>` |
| Pair/product types | `('a * 'b) expr` naturally | `PairExpr<A, B>` with `Value = (A::Value, B::Value)` |
| Runtime flexibility | Harder, since GADTs are static | Can use `Box<dyn DynExpr>` for runtime-constructed trees (with boxing cost) |
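
The table mentions `PairExpr<A, B>`, which the Rust listing does not define. A minimal sketch of how it might look follows, mirroring OCaml's `Pair`, `Fst`, and `Snd` constructors. The `Fst` and `Snd` structs are assumptions introduced here for illustration, and the trait is trimmed to `eval` only so the block stands alone.

```rust
trait Expr {
    type Value;
    fn eval(&self) -> Self::Value;
}

struct Lit(i64);
impl Expr for Lit {
    type Value = i64;
    fn eval(&self) -> i64 { self.0 }
}

struct BLit(bool);
impl Expr for BLit {
    type Value = bool;
    fn eval(&self) -> bool { self.0 }
}

// Pair: the result type is the tuple of the operands' result types.
struct PairExpr<A: Expr, B: Expr>(A, B);
impl<A: Expr, B: Expr> Expr for PairExpr<A, B> {
    type Value = (A::Value, B::Value);
    fn eval(&self) -> Self::Value { (self.0.eval(), self.1.eval()) }
}

// Fst (hypothetical): only implemented when the inner expression yields a 2-tuple.
struct Fst<P>(P);
impl<X, Y, P: Expr<Value = (X, Y)>> Expr for Fst<P> {
    type Value = X;
    fn eval(&self) -> X { self.0.eval().0 }
}

// Snd (hypothetical): projects the second component.
struct Snd<P>(P);
impl<X, Y, P: Expr<Value = (X, Y)>> Expr for Snd<P> {
    type Value = Y;
    fn eval(&self) -> Y { self.0.eval().1 }
}

fn main() {
    let pair: (i64, bool) = PairExpr(Lit(1), BLit(true)).eval();
    let first: i64 = Fst(PairExpr(Lit(1), BLit(true))).eval();
    let second: bool = Snd(PairExpr(Lit(1), BLit(true))).eval();
    println!("{:?} {} {}", pair, first, second);
    // Fst(Lit(1)).eval() would not compile: Lit's Value is i64, not a tuple.
}
```

Unlike the other nodes, this sketch places the tuple constraint on the `impl` rather than on the struct, so `Fst(Lit(1))` can be constructed but has no `eval`. Moving the bound onto the struct would require extra type parameters for the tuple components.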
// Example 177: GADT Typed Expression Evaluator
// Only well-typed expressions can be constructed

use std::fmt;

// === Approach 1: Trait-based typed expression tree ===
// Each node type is a separate struct; the trait ensures type safety

trait Expr: fmt::Debug {
    type Value;
    fn eval(&self) -> Self::Value;
    fn to_expr_string(&self) -> String;
}

#[derive(Debug)]
struct Lit(i64);

#[derive(Debug)]
struct BLit(bool);

#[derive(Debug)]
struct Add<A: Expr<Value = i64>, B: Expr<Value = i64>>(A, B);

#[derive(Debug)]
struct Mul<A: Expr<Value = i64>, B: Expr<Value = i64>>(A, B);

#[derive(Debug)]
struct Eq<A: Expr<Value = i64>, B: Expr<Value = i64>>(A, B);

#[derive(Debug)]
struct And<A: Expr<Value = bool>, B: Expr<Value = bool>>(A, B);

#[derive(Debug)]
struct Not<A: Expr<Value = bool>>(A);

#[derive(Debug)]
struct IfExpr<C: Expr<Value = bool>, T: Expr, F: Expr<Value = T::Value>>(C, T, F);

impl Expr for Lit {
    type Value = i64;
    fn eval(&self) -> i64 { self.0 }
    fn to_expr_string(&self) -> String { self.0.to_string() }
}

impl Expr for BLit {
    type Value = bool;
    fn eval(&self) -> bool { self.0 }
    fn to_expr_string(&self) -> String { self.0.to_string() }
}

impl<A: Expr<Value = i64>, B: Expr<Value = i64>> Expr for Add<A, B> {
    type Value = i64;
    fn eval(&self) -> i64 { self.0.eval() + self.1.eval() }
    fn to_expr_string(&self) -> String {
        format!("({} + {})", self.0.to_expr_string(), self.1.to_expr_string())
    }
}

impl<A: Expr<Value = i64>, B: Expr<Value = i64>> Expr for Mul<A, B> {
    type Value = i64;
    fn eval(&self) -> i64 { self.0.eval() * self.1.eval() }
    fn to_expr_string(&self) -> String {
        format!("({} * {})", self.0.to_expr_string(), self.1.to_expr_string())
    }
}

impl<A: Expr<Value = i64>, B: Expr<Value = i64>> Expr for Eq<A, B> {
    type Value = bool;
    fn eval(&self) -> bool { self.0.eval() == self.1.eval() }
    fn to_expr_string(&self) -> String {
        format!("({} = {})", self.0.to_expr_string(), self.1.to_expr_string())
    }
}

impl<A: Expr<Value = bool>, B: Expr<Value = bool>> Expr for And<A, B> {
    type Value = bool;
    fn eval(&self) -> bool { self.0.eval() && self.1.eval() }
    fn to_expr_string(&self) -> String {
        format!("({} && {})", self.0.to_expr_string(), self.1.to_expr_string())
    }
}

impl<A: Expr<Value = bool>> Expr for Not<A> {
    type Value = bool;
    fn eval(&self) -> bool { !self.0.eval() }
    fn to_expr_string(&self) -> String {
        format!("not({})", self.0.to_expr_string())
    }
}

impl<C: Expr<Value = bool>, T: Expr, F: Expr<Value = T::Value>> Expr for IfExpr<C, T, F> {
    type Value = T::Value;
    fn eval(&self) -> T::Value { if self.0.eval() { self.1.eval() } else { self.2.eval() } }
    fn to_expr_string(&self) -> String {
        format!("if {} then {} else {}", self.0.to_expr_string(), self.1.to_expr_string(), self.2.to_expr_string())
    }
}

// === Approach 2: Boxed dynamic dispatch for runtime-built trees ===

trait DynExprI64: fmt::Debug {
    fn eval(&self) -> i64;
}

struct DynLit(i64);
impl fmt::Debug for DynLit { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.0) } }
impl DynExprI64 for DynLit { fn eval(&self) -> i64 { self.0 } }

struct DynAdd(Box<dyn DynExprI64>, Box<dyn DynExprI64>);
impl fmt::Debug for DynAdd { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "({:?} + {:?})", self.0, self.1) } }
impl DynExprI64 for DynAdd { fn eval(&self) -> i64 { self.0.eval() + self.1.eval() } }

// === Approach 3: Enum-based with optimization pass ===

#[derive(Debug, Clone)]
enum IntExpr {
    Lit(i64),
    Add(Box<IntExpr>, Box<IntExpr>),
    Mul(Box<IntExpr>, Box<IntExpr>),
    IfB(Box<BoolExpr>, Box<IntExpr>, Box<IntExpr>),
}

#[derive(Debug, Clone)]
enum BoolExpr {
    Lit(bool),
    Eq(Box<IntExpr>, Box<IntExpr>),
    And(Box<BoolExpr>, Box<BoolExpr>),
    Not(Box<BoolExpr>),
}

impl IntExpr {
    fn eval(&self) -> i64 {
        match self {
            IntExpr::Lit(n) => *n,
            IntExpr::Add(a, b) => a.eval() + b.eval(),
            IntExpr::Mul(a, b) => a.eval() * b.eval(),
            IntExpr::IfB(c, t, f) => if c.eval() { t.eval() } else { f.eval() },
        }
    }

    fn optimize(self) -> Self {
        match self {
            IntExpr::Add(a, b) => {
                let a = a.optimize();
                let b = b.optimize();
                match (&a, &b) {
                    (IntExpr::Lit(0), _) => b,
                    (_, IntExpr::Lit(0)) => a,
                    (IntExpr::Lit(x), IntExpr::Lit(y)) => IntExpr::Lit(x + y),
                    _ => IntExpr::Add(Box::new(a), Box::new(b)),
                }
            }
            IntExpr::Mul(a, b) => {
                let a = a.optimize();
                let b = b.optimize();
                match (&a, &b) {
                    (IntExpr::Lit(0), _) | (_, IntExpr::Lit(0)) => IntExpr::Lit(0),
                    (IntExpr::Lit(1), _) => b,
                    (_, IntExpr::Lit(1)) => a,
                    (IntExpr::Lit(x), IntExpr::Lit(y)) => IntExpr::Lit(x * y),
                    _ => IntExpr::Mul(Box::new(a), Box::new(b)),
                }
            }
            other => other,
        }
    }
}

impl BoolExpr {
    fn eval(&self) -> bool {
        match self {
            BoolExpr::Lit(b) => *b,
            BoolExpr::Eq(a, b) => a.eval() == b.eval(),
            BoolExpr::And(a, b) => a.eval() && b.eval(),
            BoolExpr::Not(a) => !a.eval(),
        }
    }
}

fn main() {
    // Approach 1: Static typed expressions
    let e = Add(Lit(1), Mul(Lit(3), Lit(4)));
    println!("{} = {}", e.to_expr_string(), e.eval());

    let cond = IfExpr(Eq(Lit(1), Lit(1)), Lit(10), Lit(20));
    println!("{} = {}", cond.to_expr_string(), cond.eval());

    // Approach 2: Dynamic
    let d = DynAdd(Box::new(DynLit(10)), Box::new(DynLit(32)));
    println!("{:?} = {}", d, d.eval());

    // Approach 3: Optimizable
    let expr = IntExpr::Add(Box::new(IntExpr::Lit(0)), Box::new(IntExpr::Lit(5)));
    let opt = expr.optimize();
    println!("optimized: {:?} = {}", opt, opt.eval());

    println!("✓ All examples running");
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_static_eval() {
        assert_eq!(Lit(42).eval(), 42);
        assert_eq!(Add(Lit(1), Lit(2)).eval(), 3);
        assert_eq!(Mul(Lit(3), Lit(4)).eval(), 12);
        assert_eq!(Eq(Lit(1), Lit(1)).eval(), true);
        assert_eq!(Eq(Lit(1), Lit(2)).eval(), false);
        assert_eq!(And(BLit(true), BLit(true)).eval(), true);
        assert_eq!(Not(BLit(true)).eval(), false);
    }

    #[test]
    fn test_if_expr() {
        assert_eq!(IfExpr(BLit(true), Lit(10), Lit(20)).eval(), 10);
        assert_eq!(IfExpr(BLit(false), Lit(10), Lit(20)).eval(), 20);
    }

    #[test]
    fn test_pretty_print() {
        assert_eq!(Add(Lit(1), Lit(2)).to_expr_string(), "(1 + 2)");
        assert_eq!(Not(BLit(true)).to_expr_string(), "not(true)");
    }

    #[test]
    fn test_dynamic() {
        let d = DynAdd(Box::new(DynLit(10)), Box::new(DynLit(32)));
        assert_eq!(d.eval(), 42);
    }

    #[test]
    fn test_optimize() {
        let e = IntExpr::Add(Box::new(IntExpr::Lit(0)), Box::new(IntExpr::Lit(5)));
        assert_eq!(e.optimize().eval(), 5);

        let e = IntExpr::Mul(Box::new(IntExpr::Lit(0)), Box::new(IntExpr::Lit(999)));
        assert_eq!(e.optimize().eval(), 0);
    }
}
(* Example 177: GADT Typed Expression Evaluator *)
(* A fully typed expression language where only well-typed expressions compile *)

(* Approach 1: Complete typed expression evaluator *)
type _ expr =
  | Lit    : int -> int expr
  | BLit   : bool -> bool expr
  | Add    : int expr * int expr -> int expr
  | Mul    : int expr * int expr -> int expr
  | Eq     : int expr * int expr -> bool expr
  | And    : bool expr * bool expr -> bool expr
  | Not    : bool expr -> bool expr
  | If     : bool expr * 'a expr * 'a expr -> 'a expr
  | Pair   : 'a expr * 'b expr -> ('a * 'b) expr
  | Fst    : ('a * 'b) expr -> 'a expr
  | Snd    : ('a * 'b) expr -> 'b expr

let rec eval : type a. a expr -> a = function
  | Lit n -> n
  | BLit b -> b
  | Add (a, b) -> eval a + eval b
  | Mul (a, b) -> eval a * eval b
  | Eq (a, b) -> eval a = eval b
  | And (a, b) -> eval a && eval b
  | Not a -> not (eval a)
  | If (c, t, f) -> if eval c then eval t else eval f
  | Pair (a, b) -> (eval a, eval b)
  | Fst p -> fst (eval p)
  | Snd p -> snd (eval p)

(* Approach 2: Pretty printer that preserves type info *)
let rec to_string : type a. a expr -> string = function
  | Lit n -> string_of_int n
  | BLit b -> string_of_bool b
  | Add (a, b) -> "(" ^ to_string a ^ " + " ^ to_string b ^ ")"
  | Mul (a, b) -> "(" ^ to_string a ^ " * " ^ to_string b ^ ")"
  | Eq (a, b) -> "(" ^ to_string a ^ " = " ^ to_string b ^ ")"
  | And (a, b) -> "(" ^ to_string a ^ " && " ^ to_string b ^ ")"
  | Not a -> "not(" ^ to_string a ^ ")"
  | If (c, t, f) -> "if " ^ to_string c ^ " then " ^ to_string t ^ " else " ^ to_string f
  | Pair (a, b) -> "(" ^ to_string a ^ ", " ^ to_string b ^ ")"
  | Fst p -> "fst(" ^ to_string p ^ ")"
  | Snd p -> "snd(" ^ to_string p ^ ")"

(* Approach 3: Constant folding optimizer *)
let rec optimize : type a. a expr -> a expr = function
  | Add (Lit 0, b) -> optimize b
  | Add (a, Lit 0) -> optimize a
  | Mul (Lit 0, _) -> Lit 0
  | Mul (_, Lit 0) -> Lit 0
  | Mul (Lit 1, b) -> optimize b
  | Mul (a, Lit 1) -> optimize a
  | Add (Lit a, Lit b) -> Lit (a + b)
  | Mul (Lit a, Lit b) -> Lit (a * b)
  | And (BLit true, b) -> optimize b
  | And (_, BLit false) -> BLit false
  | Not (BLit b) -> BLit (not b)
  | If (BLit true, t, _) -> optimize t
  | If (BLit false, _, f) -> optimize f
  | e -> e

let () =
  (* Test evaluation *)
  assert (eval (Lit 42) = 42);
  assert (eval (Add (Lit 1, Lit 2)) = 3);
  assert (eval (Mul (Lit 3, Lit 4)) = 12);
  assert (eval (Eq (Lit 1, Lit 1)) = true);
  assert (eval (Eq (Lit 1, Lit 2)) = false);
  assert (eval (And (BLit true, BLit true)) = true);
  assert (eval (Not (BLit true)) = false);
  assert (eval (If (BLit true, Lit 10, Lit 20)) = 10);
  assert (eval (Pair (Lit 1, BLit true)) = (1, true));
  assert (eval (Fst (Pair (Lit 1, BLit true))) = 1);
  assert (eval (Snd (Pair (Lit 1, BLit true))) = true);

  (* Test pretty printing *)
  assert (to_string (Add (Lit 1, Lit 2)) = "(1 + 2)");

  (* Test optimizer *)
  assert (eval (optimize (Add (Lit 0, Lit 5))) = 5);
  assert (eval (optimize (Mul (Lit 0, Lit 999))) = 0);
  assert (eval (optimize (If (BLit true, Lit 1, Lit 2))) = 1);

  print_endline "✓ All tests passed"

📊 Detailed Comparison

Comparison: Example 177, GADT Typed Expression Evaluator

Type Definition

OCaml

type _ expr =
| Lit  : int -> int expr
| BLit : bool -> bool expr
| Add  : int expr * int expr -> int expr
| Eq   : int expr * int expr -> bool expr
| If   : bool expr * 'a expr * 'a expr -> 'a expr
| Pair : 'a expr * 'b expr -> ('a * 'b) expr
| Fst  : ('a * 'b) expr -> 'a expr

Rust

trait Expr: fmt::Debug {
 type Value;
 fn eval(&self) -> Self::Value;
}

struct Lit(i64);
struct Add<A: Expr<Value = i64>, B: Expr<Value = i64>>(A, B);
struct Eq<A: Expr<Value = i64>, B: Expr<Value = i64>>(A, B);
struct IfExpr<C: Expr<Value = bool>, T: Expr, F: Expr<Value = T::Value>>(C, T, F);

Evaluation

OCaml

let rec eval : type a. a expr -> a = function
| Lit n -> n
| Add (a, b) -> eval a + eval b
| Eq (a, b) -> eval a = eval b
| If (c, t, f) -> if eval c then eval t else eval f
| Pair (a, b) -> (eval a, eval b)
| Fst p -> fst (eval p)

Rust

impl Expr for Lit {
 type Value = i64;
 fn eval(&self) -> i64 { self.0 }
}

impl<A: Expr<Value = i64>, B: Expr<Value = i64>> Expr for Add<A, B> {
 type Value = i64;
 fn eval(&self) -> i64 { self.0.eval() + self.1.eval() }
}

Constant Folding

OCaml

let rec optimize : type a. a expr -> a expr = function
| Add (Lit 0, b) -> optimize b
| Mul (Lit 0, _) -> Lit 0
| e -> e

Rust

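// Method on `impl IntExpr` (abbreviated to the Add case)
fn optimize(self) -> Self {
 match self {
     IntExpr::Add(a, b) => {
         // Optimize the children first; this unboxes them into plain IntExpr values
         let (a, b) = (a.optimize(), b.optimize());
         match (&a, &b) {
             (IntExpr::Lit(0), _) => b,
             _ => IntExpr::Add(Box::new(a), Box::new(b)),
         }
     }
     other => other,
 }
}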
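
The OCaml optimizer also folds `And` and `Not`, which the Rust listing leaves out. A hedged sketch of what an equivalent pass over `BoolExpr` could look like follows. The `optimize` method on `BoolExpr` is an assumption added here, and `IntExpr` is cut down to two variants to keep the block self-contained.

```rust
#[derive(Debug, Clone, PartialEq)]
enum IntExpr {
    Lit(i64),
    Add(Box<IntExpr>, Box<IntExpr>),
}

#[derive(Debug, Clone, PartialEq)]
enum BoolExpr {
    Lit(bool),
    Eq(Box<IntExpr>, Box<IntExpr>),
    And(Box<BoolExpr>, Box<BoolExpr>),
    Not(Box<BoolExpr>),
}

impl IntExpr {
    fn eval(&self) -> i64 {
        match self {
            IntExpr::Lit(n) => *n,
            IntExpr::Add(a, b) => a.eval() + b.eval(),
        }
    }
}

impl BoolExpr {
    fn eval(&self) -> bool {
        match self {
            BoolExpr::Lit(b) => *b,
            BoolExpr::Eq(a, b) => a.eval() == b.eval(),
            BoolExpr::And(a, b) => a.eval() && b.eval(),
            BoolExpr::Not(a) => !a.eval(),
        }
    }

    // Hypothetical folding pass: simplifies And/Not around literals, leaves Eq untouched.
    fn optimize(self) -> Self {
        match self {
            BoolExpr::And(a, b) => match (a.optimize(), b.optimize()) {
                // `true && x` and `x && true` reduce to `x`
                (BoolExpr::Lit(true), other) | (other, BoolExpr::Lit(true)) => other,
                // Expressions have no side effects, so a literal `false` decides the result
                (BoolExpr::Lit(false), _) | (_, BoolExpr::Lit(false)) => BoolExpr::Lit(false),
                (a, b) => BoolExpr::And(Box::new(a), Box::new(b)),
            },
            BoolExpr::Not(a) => match a.optimize() {
                BoolExpr::Lit(b) => BoolExpr::Lit(!b),
                other => BoolExpr::Not(Box::new(other)),
            },
            other => other,
        }
    }
}

fn main() {
    let eq = BoolExpr::Eq(Box::new(IntExpr::Lit(1)), Box::new(IntExpr::Lit(1)));
    let e = BoolExpr::And(Box::new(BoolExpr::Lit(true)), Box::new(eq));
    println!("{:?}", e.optimize());
}
```

Because the int/bool split lives in two separate enums, the pass still cannot build an ill-typed tree, although each enum needs its own `optimize` where OCaml gets by with a single polymorphic function.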