🦀 Functional Rust

138: Type Witnesses / GADT Encoding

Difficulty: ⭐⭐⭐ Level: Advanced

Use phantom type parameters as "witnesses" to prove to the compiler that a value has a specific type, enabling type-safe expression trees and typed heterogeneous maps.

The Problem This Solves

Suppose you're building an expression evaluator. You have `IntLit`, `BoolLit`, `Add` (for ints), `Eq` (for ints, returns bool), and `If` (condition must be bool, branches must match). In an untyped `enum Expr`, nothing stops you from writing `Add(BoolLit(true), IntLit(5))` or `If(IntLit(1), ...)`: both are valid enum values, so type errors are caught only at eval time, as panics.

What you want is an `IntExpr` that can only be constructed from int sub-expressions, and a `BoolExpr` that can only be constructed from bool sub-expressions. In Haskell and OCaml, GADTs (Generalized Algebraic Data Types) solve this natively: the constructors carry type information. Rust doesn't have GADTs, but you can simulate the key properties using phantom type parameters.

The same pattern applies to typed maps: a `TypedMap` where each key carries a type parameter, so `map.get(&age_key)` returns `Option<&i32>` and `map.get(&name_key)` returns `Option<&String>`, without any unsafe casts.
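To make the failure mode of the untyped `enum Expr` concrete, here is a minimal sketch. The `Value` type and this `eval` function are illustrative assumptions, not part of the original example, and the evaluator returns a `Result` instead of panicking so that the runtime failure is easy to observe.

```rust
// Untyped AST: nothing ties a node to the type it evaluates to.
enum Expr {
    IntLit(i32),
    BoolLit(bool),
    Add(Box<Expr>, Box<Expr>),
    If(Box<Expr>, Box<Expr>, Box<Expr>),
}

// Hypothetical runtime value; its tag is only checked during evaluation.
#[derive(Debug, PartialEq)]
enum Value {
    Int(i32),
    Bool(bool),
}

fn eval(e: &Expr) -> Result<Value, String> {
    match e {
        Expr::IntLit(n) => Ok(Value::Int(*n)),
        Expr::BoolLit(b) => Ok(Value::Bool(*b)),
        Expr::Add(a, b) => match (eval(a)?, eval(b)?) {
            (Value::Int(x), Value::Int(y)) => Ok(Value::Int(x + y)),
            (l, r) => Err(format!("cannot add {:?} and {:?}", l, r)),
        },
        Expr::If(c, t, f) => match eval(c)? {
            Value::Bool(true) => eval(t),
            Value::Bool(false) => eval(f),
            other => Err(format!("condition is not a bool: {:?}", other)),
        },
    }
}

fn main() {
    // Ill-typed, yet it compiles: the mistake surfaces only when evaluated.
    let bad = Expr::Add(Box::new(Expr::BoolLit(true)), Box::new(Expr::IntLit(5)));
    assert!(eval(&bad).is_err());

    // A non-bool condition is accepted by the compiler as well.
    let bad_if = Expr::If(
        Box::new(Expr::IntLit(1)),
        Box::new(Expr::IntLit(2)),
        Box::new(Expr::IntLit(3)),
    );
    assert!(eval(&bad_if).is_err());

    let good = Expr::Add(Box::new(Expr::IntLit(2)), Box::new(Expr::IntLit(3)));
    assert_eq!(eval(&good), Ok(Value::Int(5)));
}
```

Every well-typed program pays for these checks at runtime too, which is the cost the phantom-type approach moves to compile time.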

The Intuition

A type witness is a phantom type parameter that "witnesses" (proves) some type-level fact about a value. `struct TypedExpr<T>` wraps an untyped `Expr` with a phantom `T`. The constructors enforce the right `T`: `fn int_lit(n: i32) -> TypedExpr<i32>` can only return an int expression. `fn eq(a: TypedExpr<i32>, b: TypedExpr<i32>) -> TypedExpr<bool>` takes two int expressions and produces a bool expression. The types flow through the tree.

The compiler now rejects `add(bool_lit(true), int_lit(5))`, because `add` requires two `TypedExpr<i32>` and `bool_lit` returns `TypedExpr<bool>`. Type errors in your expression language become compile-time errors in your Rust host code.

How It Works in Rust

use std::marker::PhantomData;

// The untyped core, used for evaluation
enum Expr {
 IntLit(i32),
 BoolLit(bool),
 Add(Box<Expr>, Box<Expr>),
 Eq(Box<Expr>, Box<Expr>),
 If(Box<Expr>, Box<Expr>, Box<Expr>),
}

// Typed wrapper: phantom T witnesses the expression's type
struct TypedExpr<T> {
 inner: Expr,
 _type: PhantomData<T>,  // T is not stored, just proves a type fact
}

// Typed constructors: the return type IS the type witness
fn int_lit(n: i32) -> TypedExpr<i32> {
 TypedExpr { inner: Expr::IntLit(n), _type: PhantomData }
}

fn bool_lit(b: bool) -> TypedExpr<bool> {
 TypedExpr { inner: Expr::BoolLit(b), _type: PhantomData }
}

// add: only accepts int expressions, produces int expression
fn add(a: TypedExpr<i32>, b: TypedExpr<i32>) -> TypedExpr<i32> {
 TypedExpr { inner: Expr::Add(Box::new(a.inner), Box::new(b.inner)), _type: PhantomData }
}

// eq: int expressions in, bool expression out; the signature captures the type relationship
fn eq(a: TypedExpr<i32>, b: TypedExpr<i32>) -> TypedExpr<bool> {
 TypedExpr { inner: Expr::Eq(Box::new(a.inner), Box::new(b.inner)), _type: PhantomData }
}

// if_then_else: condition must be bool, branches must have matching type T
fn if_then_else<T>(cond: TypedExpr<bool>, t: TypedExpr<T>, f: TypedExpr<T>) -> TypedExpr<T> {
 TypedExpr { inner: Expr::If(Box::new(cond.inner), Box::new(t.inner), Box::new(f.inner)), _type: PhantomData }
}

Usage:

// This compiles: a valid expression
let e = if_then_else(
 eq(int_lit(1), int_lit(1)),  // bool condition ✓
 int_lit(42),                  // int branch ✓
 int_lit(0),                   // int branch ✓ (matches)
);

// This does NOT compile: the type witness catches the error.
// let bad = add(bool_lit(true), int_lit(5));
// error: expected TypedExpr<i32>, found TypedExpr<bool>
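
The remark that `T` is not stored can be checked directly. The following is a small sketch, assuming trimmed copies of `Expr` and `TypedExpr<T>` (two variants only, to keep it short): `PhantomData<T>` is zero-sized, so the wrapper should occupy as much memory as the untyped `Expr`, whatever `T` is.

```rust
use std::marker::PhantomData;
use std::mem::size_of;

// Trimmed copy of the untyped core, for illustration only.
#[allow(dead_code)]
enum Expr {
    IntLit(i32),
    Add(Box<Expr>, Box<Expr>),
}

#[allow(dead_code)]
struct TypedExpr<T> {
    inner: Expr,
    _type: PhantomData<T>,
}

fn main() {
    // PhantomData<T> itself takes no space.
    assert_eq!(size_of::<PhantomData<String>>(), 0);
    // The witness adds nothing to the wrapper's size, for any T.
    assert_eq!(size_of::<TypedExpr<i32>>(), size_of::<Expr>());
    assert_eq!(size_of::<TypedExpr<bool>>(), size_of::<Expr>());
}
```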

Typed map (keys carry type information):

use std::any::Any;
use std::marker::PhantomData;

struct TypedKey<T: 'static> {
 name: String,
 _type: PhantomData<T>,  // witness: this key maps to values of type T
}

impl<T: 'static> TypedKey<T> {
 // The name identifies the slot; T fixes the type of the value stored there
 fn new(name: &str) -> Self {
     TypedKey { name: name.to_string(), _type: PhantomData }
 }
}

struct TypedMap { entries: Vec<(String, Box<dyn Any>)> }

impl TypedMap {
 fn insert<T: 'static>(&mut self, key: &TypedKey<T>, value: T) {
     self.entries.push((key.name.clone(), Box::new(value)));
 }

 fn get<T: 'static>(&self, key: &TypedKey<T>) -> Option<&T> {
     self.entries.iter()
         .find(|(name, _)| name == &key.name)
         .and_then(|(_, val)| val.downcast_ref::<T>())  // type-safe downcast
 }
}

let age_key:  TypedKey<i32>    = TypedKey::new("age");
let name_key: TypedKey<String> = TypedKey::new("name");

let mut map = TypedMap { entries: Vec::new() };
map.insert(&age_key, 30);
map.insert(&name_key, "Alice".to_string());

map.get(&age_key);   // → Option<&i32>: no cast needed, the type is witnessed
map.get(&name_key);  // → Option<&String>: a different type from the same map
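
One caveat of looking entries up by name alone: two keys with the same name but different `T` are not kept apart, so `get` can land on an entry of the other type, fail the downcast, and return `None`. A possible variation, which is not part of the original design, is to include `TypeId::of::<T>()` in the lookup key. The sketch below assumes a `HashMap`-backed `TypedMap` and a `TypedKey::new` constructor.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::marker::PhantomData;

struct TypedKey<T: 'static> {
    name: &'static str,
    _type: PhantomData<T>,
}

impl<T: 'static> TypedKey<T> {
    fn new(name: &'static str) -> Self {
        TypedKey { name, _type: PhantomData }
    }
}

#[derive(Default)]
struct TypedMap {
    // The value type is part of the key, so ("id", i32) and ("id", String)
    // occupy separate slots.
    entries: HashMap<(TypeId, &'static str), Box<dyn Any>>,
}

impl TypedMap {
    fn insert<T: 'static>(&mut self, key: &TypedKey<T>, value: T) {
        self.entries
            .insert((TypeId::of::<T>(), key.name), Box::new(value));
    }

    fn get<T: 'static>(&self, key: &TypedKey<T>) -> Option<&T> {
        self.entries
            .get(&(TypeId::of::<T>(), key.name))
            .and_then(|v| v.downcast_ref::<T>())
    }
}

fn main() {
    // Same name, different witnessed types.
    let id_num: TypedKey<i32> = TypedKey::new("id");
    let id_text: TypedKey<String> = TypedKey::new("id");

    let mut map = TypedMap::default();
    map.insert(&id_num, 7);
    map.insert(&id_text, "seven".to_string());

    // Both values survive and come back at their own types.
    assert_eq!(map.get(&id_num), Some(&7));
    assert_eq!(map.get(&id_text).map(String::as_str), Some("seven"));
}
```

Whether this is desirable depends on the use case: it makes the key's type part of its identity, which the name-only version deliberately does not.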

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| GADT | Native: `type _ expr = IntLit : int -> int expr \| ...`; constructors refine the type index | Simulated: `struct TypedExpr<T>` with typed constructor functions |
| Type-safe eval | `let rec eval : type a. a expr -> a`; return type inferred from GADT index | Must eval untyped core; typed wrapper ensures only valid trees are built |
| Heterogeneous map | Module-based or `'a Hashtbl.t` with GADT keys | `Box<dyn Any>` + `downcast_ref` gated by `TypedKey<T>` phantom |
| Type flow | GADT refinement flows automatically | Phantom type propagates through typed constructor signatures |

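The "Type-safe eval" row can be pushed a little further on the Rust side. As a sketch under assumptions (the trait name `EvalAs` and the `wrap` helper are hypothetical), a helper trait implemented for each witness type lets a single generic `eval` return `T` directly. The untyped core is still what gets evaluated, so the `unreachable!` arms remain; they are only ruled out by the typed constructors.

```rust
use std::marker::PhantomData;

enum Expr {
    IntLit(i32),
    BoolLit(bool),
    Add(Box<Expr>, Box<Expr>),
    Eq(Box<Expr>, Box<Expr>),
    If(Box<Expr>, Box<Expr>, Box<Expr>),
}

struct TypedExpr<T> {
    inner: Expr,
    _type: PhantomData<T>,
}

// Hypothetical helper trait: each witness type knows how to evaluate an
// untyped tree that is supposed to produce it.
trait EvalAs: Sized {
    fn eval_raw(e: &Expr) -> Self;
}

impl EvalAs for i32 {
    fn eval_raw(e: &Expr) -> i32 {
        match e {
            Expr::IntLit(n) => *n,
            Expr::Add(a, b) => i32::eval_raw(a) + i32::eval_raw(b),
            Expr::If(c, t, f) => {
                if bool::eval_raw(c) { i32::eval_raw(t) } else { i32::eval_raw(f) }
            }
            // Not constructible through the typed functions below.
            _ => unreachable!("not an i32 expression"),
        }
    }
}

impl EvalAs for bool {
    fn eval_raw(e: &Expr) -> bool {
        match e {
            Expr::BoolLit(b) => *b,
            Expr::Eq(a, b) => i32::eval_raw(a) == i32::eval_raw(b),
            Expr::If(c, t, f) => {
                if bool::eval_raw(c) { bool::eval_raw(t) } else { bool::eval_raw(f) }
            }
            _ => unreachable!("not a bool expression"),
        }
    }
}

impl<T: EvalAs> TypedExpr<T> {
    // One generic entry point: the witness T selects the evaluator.
    fn eval(&self) -> T {
        T::eval_raw(&self.inner)
    }
}

// Private helper; it must stay private, since it can attach any T.
fn wrap<T>(inner: Expr) -> TypedExpr<T> {
    TypedExpr { inner, _type: PhantomData }
}

fn int_lit(n: i32) -> TypedExpr<i32> { wrap(Expr::IntLit(n)) }
fn bool_lit(b: bool) -> TypedExpr<bool> { wrap(Expr::BoolLit(b)) }
fn add(a: TypedExpr<i32>, b: TypedExpr<i32>) -> TypedExpr<i32> {
    wrap(Expr::Add(Box::new(a.inner), Box::new(b.inner)))
}
fn eq(a: TypedExpr<i32>, b: TypedExpr<i32>) -> TypedExpr<bool> {
    wrap(Expr::Eq(Box::new(a.inner), Box::new(b.inner)))
}
fn if_then_else<T>(c: TypedExpr<bool>, t: TypedExpr<T>, f: TypedExpr<T>) -> TypedExpr<T> {
    wrap(Expr::If(Box::new(c.inner), Box::new(t.inner), Box::new(f.inner)))
}

fn main() {
    let cond = eq(add(int_lit(1), int_lit(2)), int_lit(3));
    let n: i32 = if_then_else(cond, int_lit(100), int_lit(0)).eval();
    assert_eq!(n, 100);
    assert!(bool_lit(true).eval());
    assert!(!eq(int_lit(5), int_lit(6)).eval());
}
```

The trade-off is that the trait mentions `Expr` in its signature, so the trait and the untyped core need the same visibility; keeping one concrete `eval` per result type avoids that coupling.
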
//! Example 138: Type Witnesses / GADT Encoding
//!
//! Simulates OCaml's GADTs in Rust using phantom type parameters.
//! A phantom `T` on `TypedExpr<T>` *witnesses* the expression's result type
//! at compile time, so the compiler rejects ill-typed trees before runtime.
//!
//! Two demonstrations:
//!   1. A typed expression tree (GADT-style via smart constructors + PhantomData)
//!   2. A typed heterogeneous map (each key witnesses its value's type)

use std::any::Any;
use std::collections::HashMap;
use std::marker::PhantomData;

// ── Approach 1: GADT-style typed expression tree ───────────────────────────
//
// OCaml: `type _ expr = IntLit : int -> int expr | Add : int expr * int expr -> int expr | ...`
// Rust:  wrap a `RawExpr` in `TypedExpr<T>`; smart constructors set T correctly.
//
// The `unreachable!` branches in eval are dead code by construction:
// the phantom-type invariant guarantees only valid combinations reach `eval`.

/// Untyped inner AST: private, never exposed directly.
enum RawExpr {
    IntLit(i32),
    BoolLit(bool),
    Add(Box<RawExpr>, Box<RawExpr>),
    Eq(Box<RawExpr>, Box<RawExpr>),
    If(Box<RawExpr>, Box<RawExpr>, Box<RawExpr>),
}

/// A typed expression: `T` is the *witness* for the result type.
///
/// Users can only construct values through the smart constructors below,
/// which enforce the correct `T`, exactly what OCaml GADTs guarantee via
/// constructor type indices.
pub struct TypedExpr<T> {
    raw: RawExpr,
    _marker: PhantomData<T>,
}

// ── Smart constructors (the only way to build `TypedExpr<T>`) ──────────────

/// An integer literal expression.
pub fn int_lit(n: i32) -> TypedExpr<i32> {
    TypedExpr {
        raw: RawExpr::IntLit(n),
        _marker: PhantomData,
    }
}

/// A boolean literal expression.
pub fn bool_lit(b: bool) -> TypedExpr<bool> {
    TypedExpr {
        raw: RawExpr::BoolLit(b),
        _marker: PhantomData,
    }
}

/// Addition of two integer expressions; both arguments *must* be `i32`.
pub fn add(a: TypedExpr<i32>, b: TypedExpr<i32>) -> TypedExpr<i32> {
    TypedExpr {
        raw: RawExpr::Add(Box::new(a.raw), Box::new(b.raw)),
        _marker: PhantomData,
    }
}

/// Equality test on two integer expressions; the result is `bool`.
pub fn eq_expr(a: TypedExpr<i32>, b: TypedExpr<i32>) -> TypedExpr<bool> {
    TypedExpr {
        raw: RawExpr::Eq(Box::new(a.raw), Box::new(b.raw)),
        _marker: PhantomData,
    }
}

/// Conditional: condition must be `bool`, branches must share the same type `T`.
pub fn if_expr<T>(
    cond: TypedExpr<bool>,
    then_branch: TypedExpr<T>,
    else_branch: TypedExpr<T>,
) -> TypedExpr<T> {
    TypedExpr {
        raw: RawExpr::If(
            Box::new(cond.raw),
            Box::new(then_branch.raw),
            Box::new(else_branch.raw),
        ),
        _marker: PhantomData,
    }
}

// ── Evaluation ─────────────────────────────────────────────────────────────
//
// We provide concrete `eval` impls for each supported result type rather than
// a generic trait-bounded impl, so the internal `RawExpr` type stays private
// and we avoid the `private_bounds` lint.

fn eval_i32(raw: &RawExpr) -> i32 {
    match raw {
        RawExpr::IntLit(n) => *n,
        RawExpr::Add(a, b) => eval_i32(a) + eval_i32(b),
        RawExpr::If(cond, t, f) => {
            if eval_bool(cond) {
                eval_i32(t)
            } else {
                eval_i32(f)
            }
        }
        _ => unreachable!("type witness invariant: not an i32 expression"),
    }
}

fn eval_bool(raw: &RawExpr) -> bool {
    match raw {
        RawExpr::BoolLit(b) => *b,
        RawExpr::Eq(a, b) => eval_i32(a) == eval_i32(b),
        RawExpr::If(cond, t, f) => {
            if eval_bool(cond) {
                eval_bool(t)
            } else {
                eval_bool(f)
            }
        }
        _ => unreachable!("type witness invariant: not a bool expression"),
    }
}

impl TypedExpr<i32> {
    /// Evaluate, returning the witnessed `i32`.
    pub fn eval(&self) -> i32 {
        eval_i32(&self.raw)
    }
}

impl TypedExpr<bool> {
    /// Evaluate, returning the witnessed `bool`.
    pub fn eval(&self) -> bool {
        eval_bool(&self.raw)
    }
}

// ── Approach 2: Typed heterogeneous map ────────────────────────────────────
//
// A `TypedKey<T>` witnesses that the value stored under this key has type `T`.
// The caller sees `Option<&T>` from `get`, with no visible downcasting.

/// A key whose phantom `T` witnesses the type of its associated value.
pub struct TypedKey<T: 'static> {
    name: &'static str,
    _marker: PhantomData<T>,
}

impl<T: 'static> TypedKey<T> {
    /// Create a new key with a unique string name.
    /// Keys with different names are always distinct, even for the same `T`.
    pub const fn new(name: &'static str) -> Self {
        TypedKey {
            name,
            _marker: PhantomData,
        }
    }
}

/// A heterogeneous map: each value may have a different type,
/// determined at compile time by the `TypedKey<T>` used to access it.
#[derive(Default)]
pub struct TypedMap {
    inner: HashMap<&'static str, Box<dyn Any>>,
}

impl TypedMap {
    pub fn new() -> Self {
        Self::default()
    }

    /// Insert a value; the key's phantom type ensures `value: T`.
    pub fn insert<T: Any>(&mut self, key: &TypedKey<T>, value: T) {
        self.inner.insert(key.name, Box::new(value));
    }

    /// Retrieve a reference; returns `Option<&T>`, with the type determined by the key.
    pub fn get<T: Any>(&self, key: &TypedKey<T>) -> Option<&T> {
        self.inner.get(key.name)?.downcast_ref::<T>()
    }

    /// Remove and return the value, if present.
    pub fn remove<T: Any>(&mut self, key: &TypedKey<T>) -> Option<T> {
        self.inner
            .remove(key.name)
            .and_then(|v| v.downcast::<T>().ok())
            .map(|b| *b)
    }
}

// ── Tests ──────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    // ── Expression tree ────────────────────────────────────────────────────

    #[test]
    fn test_int_literal() {
        assert_eq!(int_lit(42).eval(), 42);
        assert_eq!(int_lit(0).eval(), 0);
        assert_eq!(int_lit(-7).eval(), -7);
    }

    #[test]
    fn test_bool_literal() {
        assert!(bool_lit(true).eval());
        assert!(!bool_lit(false).eval());
    }

    #[test]
    fn test_add() {
        assert_eq!(add(int_lit(3), int_lit(4)).eval(), 7);
        assert_eq!(add(int_lit(0), int_lit(0)).eval(), 0);
        // Nested: (1 + 2) + (3 + 4)
        assert_eq!(
            add(add(int_lit(1), int_lit(2)), add(int_lit(3), int_lit(4))).eval(),
            10
        );
    }

    #[test]
    fn test_eq_expr() {
        assert!(eq_expr(int_lit(5), int_lit(5)).eval());
        assert!(!eq_expr(int_lit(5), int_lit(6)).eval());
        assert!(eq_expr(add(int_lit(1), int_lit(2)), int_lit(3)).eval());
    }

    #[test]
    fn test_if_int_branches() {
        assert_eq!(if_expr(bool_lit(true), int_lit(10), int_lit(20)).eval(), 10);
        assert_eq!(
            if_expr(bool_lit(false), int_lit(10), int_lit(20)).eval(),
            20
        );
    }

    #[test]
    fn test_if_bool_branches() {
        assert!(!if_expr(bool_lit(true), bool_lit(false), bool_lit(true)).eval());
        assert!(if_expr(bool_lit(false), bool_lit(false), bool_lit(true)).eval());
    }

    #[test]
    fn test_complex_expression() {
        // if (1 + 2 == 3) then 100 else 0  →  100
        let cond = eq_expr(add(int_lit(1), int_lit(2)), int_lit(3));
        let expr = if_expr(cond, int_lit(100), int_lit(0));
        assert_eq!(expr.eval(), 100);
    }

    // ── Typed map ──────────────────────────────────────────────────────────

    #[test]
    fn test_typed_map_insert_and_get() {
        static AGE: TypedKey<i32> = TypedKey::new("age");
        static NAME: TypedKey<String> = TypedKey::new("name");

        let mut map = TypedMap::new();
        map.insert(&AGE, 30);
        map.insert(&NAME, "Alice".to_string());

        assert_eq!(map.get(&AGE), Some(&30));
        assert_eq!(map.get(&NAME), Some(&"Alice".to_string()));
    }

    #[test]
    fn test_typed_map_missing_key() {
        static KEY: TypedKey<i32> = TypedKey::new("missing_key_138");
        let map = TypedMap::new();
        assert_eq!(map.get(&KEY), None);
    }

    #[test]
    fn test_typed_map_overwrite() {
        static KEY: TypedKey<i32> = TypedKey::new("counter_138");
        let mut map = TypedMap::new();
        map.insert(&KEY, 1);
        map.insert(&KEY, 2);
        assert_eq!(map.get(&KEY), Some(&2));
    }

    #[test]
    fn test_typed_map_remove() {
        static KEY: TypedKey<String> = TypedKey::new("greeting_138");
        let mut map = TypedMap::new();
        map.insert(&KEY, "hello".to_string());
        assert_eq!(map.remove(&KEY), Some("hello".to_string()));
        assert_eq!(map.get(&KEY), None);
    }
}
(* Example 138: Type Witnesses / GADT Encoding *)

(* Approach 1: GADT expression tree with type safety *)
type _ expr =
  | IntLit : int -> int expr
  | BoolLit : bool -> bool expr
  | Add : int expr * int expr -> int expr
  | Eq : int expr * int expr -> bool expr
  | If : bool expr * 'a expr * 'a expr -> 'a expr
  | Pair : 'a expr * 'b expr -> ('a * 'b) expr
  | Fst : ('a * 'b) expr -> 'a expr

let rec eval : type a. a expr -> a = function
  | IntLit n -> n
  | BoolLit b -> b
  | Add (a, b) -> eval a + eval b
  | Eq (a, b) -> eval a = eval b
  | If (cond, t, f) -> if eval cond then eval t else eval f
  | Pair (a, b) -> (eval a, eval b)
  | Fst p -> fst (eval p)

(* Approach 2: Type witness for safe casting *)
type (_, _) eq = Refl : ('a, 'a) eq

let cast : type a b. (a, b) eq -> a -> b = fun Refl x -> x

(* Approach 3: Typed keys for heterogeneous map *)
type _ key =
  | IntKey : string -> int key
  | StringKey : string -> string key
  | BoolKey : string -> bool key

type binding = Binding : 'a key * 'a -> binding

let get_int (IntKey _ as k) bindings =
  List.find_map (fun (Binding (k', v)) ->
    match k, k' with
    | IntKey a, IntKey b when a = b -> Some v
    | _ -> None
  ) bindings

(* Tests *)
let () =
  let e = If (Eq (IntLit 1, IntLit 1), IntLit 42, IntLit 0) in
  assert (eval e = 42);

  let e2 = Add (IntLit 10, IntLit 32) in
  assert (eval e2 = 42);

  let e3 = Fst (Pair (IntLit 1, BoolLit true)) in
  assert (eval e3 = 1);

  let x = cast Refl 42 in
  assert (x = 42);

  Printf.printf "✓ All tests passed\n"
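
The `Refl` witness from Approach 2 has no counterpart in the Rust file above. One way to approximate it in safe Rust, offered as a sketch and not as an established equivalent, is a trait with a single reflexive blanket impl: the bound `A: TypeEq<B>` can only be satisfied when `A` and `B` are the same type. The names `TypeEq` and `cast` are assumptions chosen to mirror the OCaml.

```rust
// Hypothetical trait: the only impl is reflexive, so `A: TypeEq<B>` holds
// only when A and B are the same type (compare `Refl : ('a, 'a) eq`).
trait TypeEq<T> {
    fn cast(self) -> T;
}

impl<T> TypeEq<T> for T {
    fn cast(self) -> T {
        self
    }
}

// Mirrors `cast : (a, b) eq -> a -> b`; here the witness is the trait bound
// instead of a value passed at runtime.
fn cast<A, B>(x: A) -> B
where
    A: TypeEq<B>,
{
    <A as TypeEq<B>>::cast(x)
}

fn main() {
    let x: i32 = cast(42_i32);
    assert_eq!(x, 42);

    let s: String = cast(String::from("ok"));
    assert_eq!(s, "ok");

    // Does not compile: there is no impl of TypeEq<bool> for i32.
    // let b: bool = cast(42_i32);
}
```

Unlike the OCaml `eq` value, this witness cannot be stored in a data structure or pattern-matched to refine types locally; it only constrains generic signatures.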

📊 Detailed Comparison

Comparison: Type Witnesses / GADT Encoding

OCaml

๐Ÿช Show OCaml equivalent
type _ expr =
| IntLit : int -> int expr
| Add : int expr * int expr -> int expr
| Eq : int expr * int expr -> bool expr
| If : bool expr * 'a expr * 'a expr -> 'a expr

let rec eval : type a. a expr -> a = function
| IntLit n -> n
| Add (a, b) -> eval a + eval b
| Eq (a, b) -> eval a = eval b
| If (cond, t, f) -> if eval cond then eval t else eval f

Rust

// Typed wrapper over untyped core
struct TypedExpr<T> { inner: Expr, _type: PhantomData<T> }

fn int_lit(n: i32) -> TypedExpr<i32> { /* ... */ }
fn add(a: TypedExpr<i32>, b: TypedExpr<i32>) -> TypedExpr<i32> { /* ... */ }
fn eq(a: TypedExpr<i32>, b: TypedExpr<i32>) -> TypedExpr<bool> { /* ... */ }

// Rejected at compile time: add(int_lit(1), bool_lit(true))