🦀 Functional Rust

598: Finally Tagless Style

Difficulty: 4
Level: Advanced

Define a language as a trait, then run the same expression through any number of interpreters (evaluator, pretty-printer, operation counter) without changing the expression.

The Problem This Solves

Say you have arithmetic expressions and you want to do three things with them: evaluate them to a number, print them as a human-readable string, and count how many operations they contain. The classic approach is to build an `enum Expr { Lit(i64), Add(Box<Expr>, Box<Expr>), ... }` and then write three separate functions that pattern-match on it.

This works until you want to add a new kind of expression. Add `Sub` to the enum, and suddenly _every_ match expression in your codebase is incomplete. The compiler will catch it (as long as no match hides the gap behind a wildcard `_` arm), but you still have to touch every interpreter. In a small project that's fine. In a large codebase with many passes (evaluation, optimization, type-checking, code generation, pretty-printing), it becomes a maintenance burden.

The other direction is also repetitive: adding a new _interpretation_ of existing expressions. Maybe you want to count operations instead of evaluating them. With an AST enum, you write a new function from scratch, re-matching on every case, and nothing is shared with the existing passes.

The "finally tagless" style (the name comes from the paper "Finally Tagless, Partially Evaluated" by Carette, Kiselyov, and Shan) flips the model. Instead of a data type with constructors, you define a trait with methods. The expression is a generic function that calls those methods. Each interpreter is a struct that implements the trait with a different concrete type for the result. Adding an interpreter? New `impl`. Adding an expression form? New method in the trait, with a compile error for every impl that needs updating, which is the right behavior. This pattern exists to solve exactly that pain.
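For contrast, a minimal sketch of the enum-based approach described above might look like this (the function names `eval`, `pretty`, `count`, and `sample` are illustrative, not from any library). Every interpreter re-matches all four variants, so adding a `Sub` variant would mean editing all three functions, and the expression itself is a heap-allocated tree of `Box`es.

```rust
// The "classic" encoding: the expression is data (an AST enum).
enum Expr {
    Lit(i64),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
    Neg(Box<Expr>),
}

// Interpreter 1: evaluate. Each function matches on every variant.
fn eval(e: &Expr) -> i64 {
    match e {
        Expr::Lit(n) => *n,
        Expr::Add(l, r) => eval(l) + eval(r),
        Expr::Mul(l, r) => eval(l) * eval(r),
        Expr::Neg(x) => -eval(x),
    }
}

// Interpreter 2: pretty-print, with the same parenthesization as the text.
fn pretty(e: &Expr) -> String {
    match e {
        Expr::Lit(n) => n.to_string(),
        Expr::Add(l, r) => format!("({}+{})", pretty(l), pretty(r)),
        Expr::Mul(l, r) => format!("({}*{})", pretty(l), pretty(r)),
        Expr::Neg(x) => format!("(-{})", pretty(x)),
    }
}

// Interpreter 3: count operations (literals count as zero).
fn count(e: &Expr) -> usize {
    match e {
        Expr::Lit(_) => 0,
        Expr::Add(l, r) | Expr::Mul(l, r) => 1 + count(l) + count(r),
        Expr::Neg(x) => 1 + count(x),
    }
}

// 3*4 + (-2), built explicitly as a tree of boxed nodes.
fn sample() -> Expr {
    Expr::Add(
        Box::new(Expr::Mul(Box::new(Expr::Lit(3)), Box::new(Expr::Lit(4)))),
        Box::new(Expr::Neg(Box::new(Expr::Lit(2)))),
    )
}

fn main() {
    let e = sample();
    println!("{} {} {}", eval(&e), pretty(&e), count(&e)); // 10 ((3*4)+(-2)) 3
}
```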

The Intuition

Think of a universal remote control. You press "volume up." The remote doesn't know if you're controlling a TV, a soundbar, or an amp. It just calls `volume_up()` on whatever device is connected. The _command_ is defined once; the _behavior_ depends on the device. Tagless final works the same way. The "remote control" is the generic function: it calls `lit()`, `add()`, `neg()` without knowing what they return. Each "device" is an interpreter, and the buttons every device responds to are the methods of a single trait:
trait Expr<R> {
    fn lit(n: i64) -> R;     // "make a literal number"
    fn add(l: R, r: R) -> R; // "add two things"
    fn mul(l: R, r: R) -> R; // "multiply two things"
    fn neg(x: R) -> R;       // "negate something"
}
The type parameter `R` is what the interpreter produces. Different interpreters pick different `R`: `i64` for an evaluator, `String` for a pretty-printer, `usize` for an operation counter. The program is written _once_ using the trait methods:
// 3*4 + (-2)
fn program<R, E: Expr<R>>() -> R {
    E::add(
        E::mul(E::lit(3), E::lit(4)),
        E::neg(E::lit(2)),
    )
}
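When you call `program::<i64, EvalInterp>()`, Rust substitutes `i64` for `R` and `EvalInterp` for `E` and generates a specialized copy of `program` (monomorphization). Every call is statically dispatched, so the optimizer can inline it all down to plain arithmetic. No AST is ever built.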
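Because a tagless program is an ordinary generic function, it can take runtime arguments and be assembled from smaller generic helpers. The sketch below re-declares the trait and two interpreters so it compiles on its own; `scale` and `linear` are hypothetical helpers, not part of the original example.

```rust
// The language from the text, re-declared so this sketch compiles alone.
trait Expr<R> {
    fn lit(n: i64) -> R;
    fn add(l: R, r: R) -> R;
    fn mul(l: R, r: R) -> R;
    fn neg(x: R) -> R;
}

struct EvalInterp;
impl Expr<i64> for EvalInterp {
    fn lit(n: i64) -> i64 { n }
    fn add(l: i64, r: i64) -> i64 { l + r }
    fn mul(l: i64, r: i64) -> i64 { l * r }
    fn neg(x: i64) -> i64 { -x }
}

struct PrintInterp;
impl Expr<String> for PrintInterp {
    fn lit(n: i64) -> String { n.to_string() }
    fn add(l: String, r: String) -> String { format!("({}+{})", l, r) }
    fn mul(l: String, r: String) -> String { format!("({}*{})", l, r) }
    fn neg(x: String) -> String { format!("(-{})", x) }
}

// Hypothetical helper: builds `k * x` in whatever interpreter E is chosen.
fn scale<R, E: Expr<R>>(k: i64, x: R) -> R {
    E::mul(E::lit(k), x)
}

// Hypothetical program with a runtime input n: 2*n + (-1).
fn linear<R, E: Expr<R>>(n: i64) -> R {
    E::add(scale::<R, E>(2, E::lit(n)), E::neg(E::lit(1)))
}

fn main() {
    println!("{}", linear::<i64, EvalInterp>(5));     // 9
    println!("{}", linear::<String, PrintInterp>(5)); // ((2*5)+(-1))
}
```

The helper is written once and reused by every interpreter, just like `program` itself.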

How It Works in Rust

Step 1: Define the language as a trait
trait Expr<R> {
    fn lit(n: i64) -> R;
    fn add(l: R, r: R) -> R;
    fn mul(l: R, r: R) -> R;
    fn neg(x: R) -> R;
}
This is your entire language definition: no `enum`, no `Box`, no heap-allocated tree.

Step 2: Three interpreters
// Interpreter 1: evaluate to i64
struct EvalInterp;
impl Expr<i64> for EvalInterp {
    fn lit(n: i64) -> i64 { n }
    fn add(l: i64, r: i64) -> i64 { l + r }
    fn mul(l: i64, r: i64) -> i64 { l * r }
    fn neg(x: i64) -> i64 { -x }
}

// Interpreter 2: pretty-print to String
struct PrintInterp;
impl Expr<String> for PrintInterp {
    fn lit(n: i64) -> String { format!("{}", n) }
    fn add(l: String, r: String) -> String { format!("({}+{})", l, r) }
    fn mul(l: String, r: String) -> String { format!("({}*{})", l, r) }
    fn neg(x: String) -> String { format!("(-{})", x) }
}

// Interpreter 3: count operations (literals don't count)
struct CountInterp;
impl Expr<usize> for CountInterp {
    fn lit(_: i64) -> usize { 0 }                     // literals aren't "operations"
    fn add(l: usize, r: usize) -> usize { 1 + l + r } // add itself is 1 op
    fn mul(l: usize, r: usize) -> usize { 1 + l + r }
    fn neg(x: usize) -> usize { 1 + x }
}
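Following the same recipe as the three interpreters above, here is a sketch of a fourth one. It assumes a hypothetical `CheckedInterp` that evaluates with overflow detection by choosing `R = Option<i64>`; `program` is unchanged, and `big` is an illustrative program that overflows `i64`.

```rust
// The language from the text, re-declared so this sketch compiles alone.
trait Expr<R> {
    fn lit(n: i64) -> R;
    fn add(l: R, r: R) -> R;
    fn mul(l: R, r: R) -> R;
    fn neg(x: R) -> R;
}

// Hypothetical interpreter: evaluation that yields None on i64 overflow.
struct CheckedInterp;
impl Expr<Option<i64>> for CheckedInterp {
    fn lit(n: i64) -> Option<i64> { Some(n) }
    // `?` propagates a None produced by a sub-expression.
    fn add(l: Option<i64>, r: Option<i64>) -> Option<i64> { l?.checked_add(r?) }
    fn mul(l: Option<i64>, r: Option<i64>) -> Option<i64> { l?.checked_mul(r?) }
    fn neg(x: Option<i64>) -> Option<i64> { x?.checked_neg() }
}

// The unchanged program from the text: 3*4 + (-2)
fn program<R, E: Expr<R>>() -> R {
    E::add(
        E::mul(E::lit(3), E::lit(4)),
        E::neg(E::lit(2)),
    )
}

// Illustrative program that overflows: i64::MAX * 2
fn big<R, E: Expr<R>>() -> R {
    E::mul(E::lit(i64::MAX), E::lit(2))
}

fn main() {
    println!("{:?}", program::<Option<i64>, CheckedInterp>()); // Some(10)
    println!("{:?}", big::<Option<i64>, CheckedInterp>());     // None
}
```

With `EvalInterp`, the same `big` program would panic in a debug build (and wrap around in a default release build); the checked interpreter reports `None` instead, and neither the trait nor `program` had to change.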
Step 3: One expression, three runs
fn program<R, E: Expr<R>>() -> R {
    E::add(
        E::mul(E::lit(3), E::lit(4)), // 3*4 = 12
        E::neg(E::lit(2)),            // -2
    )
}
// 3*4 + (-2) = 10

fn main() {
    println!("{}", program::<i64, EvalInterp>());     // => 10
    println!("{}", program::<String, PrintInterp>()); // => ((3*4)+(-2))
    println!("{}", program::<usize, CountInterp>());  // => 3 (mul, add, neg)
}
Adding a fourth interpreter: create a new struct, write `impl Expr<YourType> for YourStruct`, and you're done. The program function doesn't change, and neither do the other interpreters.

Adding a new operation (say, `sub`): add `fn sub(l: R, r: R) -> R` to the trait. Every existing `impl` now fails to compile, which is exactly right: you want each interpreter to consciously handle subtraction.
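To make the compile-time argument concrete, here is a sketch of the language extended with `sub`, assuming the three interpreters from the text. Each `impl` has to gain a `sub` body before the crate builds again; `program_sub` is a hypothetical program that uses the new operation.

```rust
// The language from the text, extended with a new `sub` operation.
trait Expr<R> {
    fn lit(n: i64) -> R;
    fn add(l: R, r: R) -> R;
    fn mul(l: R, r: R) -> R;
    fn neg(x: R) -> R;
    fn sub(l: R, r: R) -> R; // new: every impl below must now provide it
}

struct EvalInterp;
impl Expr<i64> for EvalInterp {
    fn lit(n: i64) -> i64 { n }
    fn add(l: i64, r: i64) -> i64 { l + r }
    fn mul(l: i64, r: i64) -> i64 { l * r }
    fn neg(x: i64) -> i64 { -x }
    fn sub(l: i64, r: i64) -> i64 { l - r }
}

struct PrintInterp;
impl Expr<String> for PrintInterp {
    fn lit(n: i64) -> String { n.to_string() }
    fn add(l: String, r: String) -> String { format!("({}+{})", l, r) }
    fn mul(l: String, r: String) -> String { format!("({}*{})", l, r) }
    fn neg(x: String) -> String { format!("(-{})", x) }
    fn sub(l: String, r: String) -> String { format!("({}-{})", l, r) }
}

struct CountInterp;
impl Expr<usize> for CountInterp {
    fn lit(_: i64) -> usize { 0 }
    fn add(l: usize, r: usize) -> usize { 1 + l + r }
    fn mul(l: usize, r: usize) -> usize { 1 + l + r }
    fn neg(x: usize) -> usize { 1 + x }
    fn sub(l: usize, r: usize) -> usize { 1 + l + r } // sub is 1 op
}

// Hypothetical program using the new operation: 3*4 - 2
fn program_sub<R, E: Expr<R>>() -> R {
    E::sub(E::mul(E::lit(3), E::lit(4)), E::lit(2))
}

fn main() {
    println!("{}", program_sub::<i64, EvalInterp>());     // 10
    println!("{}", program_sub::<String, PrintInterp>()); // ((3*4)-2)
    println!("{}", program_sub::<usize, CountInterp>());  // 2
}
```

Deleting any one of the `sub` bodies makes rustc reject that `impl` with error E0046 (missing trait item). If you would rather not touch existing interpreters, the trait could instead give `sub` a default body built from `add` and `neg`, at the cost of losing that reminder.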

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Language definition | Module type with `type 'a repr` (HKT) | Trait with type parameter `R` |
| Interpreter | Module satisfying the module type | Struct implementing `Expr<R>` |
| Multiple interpreters | Different modules | Different `impl Expr<R>` for different `R` |
| Extension | New module | New struct + impl |
| Type safety | Phantom types in HKT | Generics: the compiler enforces it |
| Runtime cost | Low: calls go through the module argument | Zero dispatch overhead: monomorphised by rustc |
// Finally tagless: the language is a trait
trait Expr<R> {
    fn lit(n: i64)     -> R;
    fn add(l: R, r: R) -> R;
    fn mul(l: R, r: R) -> R;
    fn neg(x: R)       -> R;
}

// Interpreter 1: evaluate
struct EvalInterp;
impl Expr<i64> for EvalInterp {
    fn lit(n: i64)       -> i64 { n }
    fn add(l: i64, r: i64) -> i64 { l + r }
    fn mul(l: i64, r: i64) -> i64 { l * r }
    fn neg(x: i64)         -> i64 { -x }
}

// Interpreter 2: pretty print
struct PrintInterp;
impl Expr<String> for PrintInterp {
    fn lit(n: i64)           -> String { format!("{}", n) }
    fn add(l: String, r: String) -> String { format!("({}+{})", l, r) }
    fn mul(l: String, r: String) -> String { format!("({}*{})", l, r) }
    fn neg(x: String)            -> String { format!("(-{})", x) }
}

// Interpreter 3: count operations
struct CountInterp;
impl Expr<usize> for CountInterp {
    fn lit(_: i64)           -> usize { 0 }
    fn add(l: usize, r: usize) -> usize { 1 + l + r }
    fn mul(l: usize, r: usize) -> usize { 1 + l + r }
    fn neg(x: usize)           -> usize { 1 + x }
}

// Same expression, multiple interpretations: 3*4 + (-2)
fn program<R, E: Expr<R>>() -> R {
    E::add(
        E::mul(E::lit(3), E::lit(4)),
        E::neg(E::lit(2)),
    )
}

fn main() {
    println!("eval  = {}", program::<i64,    EvalInterp>());
    println!("print = {}", program::<String, PrintInterp>());
    println!("ops   = {}", program::<usize,  CountInterp>());
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test] fn eval()  { assert_eq!(program::<i64,EvalInterp>(), 10); }
    #[test] fn print() { assert_eq!(program::<String, PrintInterp>(), "((3*4)+(-2))"); }
    #[test] fn count() { assert_eq!(program::<usize,CountInterp>(), 3); }
}
(* Finally tagless in OCaml via module functors *)
module type EXPR = sig
  type 'a repr
  val lit : int -> int repr
  val add : int repr -> int repr -> int repr
  val mul : int repr -> int repr -> int repr
end

(* Evaluator *)
module Eval = struct
  type 'a repr = 'a
  let lit n   = n
  let add l r = l + r
  let mul l r = l * r
end

(* Printer *)
module Print = struct
  type 'a repr = string
  let lit n   = string_of_int n
  let add l r = Printf.sprintf "(%s+%s)" l r
  let mul l r = Printf.sprintf "(%s*%s)" l r
end

(* Expression: 3 * 4 + 2, written once against the EXPR signature *)
module Prog (E : EXPR) = struct
  let run = E.add (E.mul (E.lit 3) (E.lit 4)) (E.lit 2)
end

let () =
  let module PE = Prog (Eval) in
  let module PP = Prog (Print) in
  Printf.printf "eval: %d\n" PE.run;
  Printf.printf "print: %s\n" PP.run