🦀 Functional Rust

612: Zygomorphism

Difficulty: 5  Level: Master
Run two dependent folds in a single traversal, for when fold B needs fold A's result at each step.

The Problem This Solves

Sometimes you need two values from a single fold, where one depends on the other. Computing the average of a list requires both the sum and the count. Computing variance requires both the mean and the sum of squared deviations. Running them as two separate folds traverses the structure twice and misses the chance to share intermediate results. The naive fix is to fold a tuple `(sum, count)`. That works, but it's informal.

The problem goes deeper: what if fold B needs not just fold A's final result, but A's intermediate results at each step? A zygomorphism formalizes this: at each node, algebra B receives its own result so far together with algebra A's result for the same sub-structure. This pattern appears in compilers (type inference that depends on a size computation done simultaneously), dynamic programming (a DP fold that depends on an auxiliary DP table built in the same pass), and statistics (streaming computation of mean and variance together).
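For reference, the informal tuple fold described above might look like this in Rust. It is a minimal sketch: the name `average` is ours, and returning `None` for an empty slice is an assumption about how to handle a zero count.

```rust
/// Average via one left fold over a (sum, count) pair.
/// The two components are independent: neither needs the other mid-fold.
fn average(xs: &[f64]) -> Option<f64> {
    let (sum, count) = xs
        .iter()
        .fold((0.0_f64, 0usize), |(sum, count), &x| (sum + x, count + 1));
    // Assumed convention: an empty input has no average.
    if count == 0 {
        None
    } else {
        Some(sum / count as f64)
    }
}

fn main() {
    let data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
    println!("{:?}", average(&data)); // prints Some(5.0)
}
```

Because the sum never looks at the count (or vice versa) during the fold, this is two independent folds sharing one traversal, not yet a zygomorphism.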

The Intuition

A zygomorphism runs two algebras simultaneously. Algebra A computes one value; algebra B computes another while having access to A's result at each step. You get both outputs from one traversal of the structure. The trade-off is a more complex algebra signature in exchange for single-pass efficiency when the two computations share sub-results.
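A small list example where B genuinely needs A's intermediate result is counting "record highs": elements strictly greater than every element before them. In this sketch, A tracks the running maximum of the prefix and B bumps its count whenever the current element beats that maximum. It assumes a left-to-right fold, and the names `zygo_left` and `count_records` are hypothetical.

```rust
/// Left-to-right zygomorphism over a slice.
/// `alg_b` receives A's result for the prefix *before* the current element.
fn zygo_left<T, A, B>(
    xs: &[T],
    init_a: A,
    init_b: B,
    alg_a: impl Fn(&T, &A) -> A,
    alg_b: impl Fn(&T, &A, B) -> B,
) -> (A, B) {
    let mut a = init_a;
    let mut b = init_b;
    for x in xs {
        b = alg_b(x, &a, b); // B consults A's result for the prefix
        a = alg_a(x, &a);    // then A absorbs the current element
    }
    (a, b)
}

/// Counts elements strictly greater than everything before them.
fn count_records(xs: &[i32]) -> usize {
    let (_, records) = zygo_left(
        xs,
        None::<i32>, // A: running maximum (None for the empty prefix)
        0usize,      // B: number of records seen so far
        |&x, run_max: &Option<i32>| Some(run_max.map_or(x, |m| m.max(x))),
        |&x, run_max: &Option<i32>, n: usize| match run_max {
            Some(m) if x <= *m => n, // not a new record
            _ => n + 1,              // first element, or beats the prefix max
        },
    );
    records
}

fn main() {
    let xs = [3, 1, 4, 1, 5, 9, 2, 6];
    println!("records = {}", count_records(&xs)); // prints "records = 4"
}
```

B could not be computed on its own without re-deriving the prefix maximum, which is exactly the dependency a zygomorphism packages up.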

How It Works in Rust

// Zygomorphism over a list: fold A and fold B together
// alg_a: F<A> → A        (the "helper" fold)
// alg_b: F<(A, B)> → B   (the "main" fold, with access to A's result at each step)

fn zygo_list<A, B, AlgA, AlgB>(
 list: &[f64],
 init_a: A,
 init_b: B,
 alg_a: AlgA,   // A fold: computes running A value
 alg_b: AlgB,   // B fold: uses A's result at each step
) -> (A, B)
where
 A: Clone,
 AlgA: Fn(f64, A) -> A,
 AlgB: Fn(f64, A, B) -> B,  // B receives: current element, A's result, B's result
{
 let mut acc_a = init_a;
 let mut acc_b = init_b;
 for &x in list {
     let new_a = alg_a(x, acc_a.clone());
     acc_b = alg_b(x, acc_a, acc_b);  // B sees A's result BEFORE updating A
     acc_a = new_a;
 }
 (acc_a, acc_b)
}

fn main() {
    let data = vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];

    // Compute count and sum simultaneously
    let (count, sum) = zygo_list(
        &data,
        0.0_f64, 0.0_f64,
        |_, c| c + 1.0,     // alg_a: count
        |x, _c, s| s + x,   // alg_b: sum (doesn't need the count, but could)
    );
    let mean = sum / count;

    // Variance from the mean computed above. A is held constant (identity
    // algebra), so this is a second traversal: it shows how B reads A's value,
    // but it does not compute mean and variance in a single pass.
    let (_, sq_dev) = zygo_list(
        &data,
        mean,  // pre-computed mean as "A"
        0.0,
        |_, m| m,                            // alg_a: identity (mean is fixed)
        |x, m, acc| acc + (x - m).powi(2),   // alg_b: sum of squared deviations
    );
    println!("variance = {}", sq_dev / count);
}
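The variance example above holds the mean fixed, so it still relies on a mean from an earlier pass. A genuinely single-pass version lets A carry a running `(count, mean)` and lets B accumulate squared deviations from A's previous mean; this is Welford's online update rewritten as a zygomorphism. The sketch repeats `zygo_list` so it is self-contained; `mean_variance_one_pass` is a hypothetical name, and returning `None` for empty input is our assumption.

```rust
// Same shape as zygo_list: alg_b sees A's value *before* A absorbs x.
fn zygo_list<A, B, AlgA, AlgB>(
    list: &[f64],
    init_a: A,
    init_b: B,
    alg_a: AlgA,
    alg_b: AlgB,
) -> (A, B)
where
    A: Clone,
    AlgA: Fn(f64, A) -> A,
    AlgB: Fn(f64, A, B) -> B,
{
    let mut acc_a = init_a;
    let mut acc_b = init_b;
    for &x in list {
        let new_a = alg_a(x, acc_a.clone());
        acc_b = alg_b(x, acc_a, acc_b);
        acc_a = new_a;
    }
    (acc_a, acc_b)
}

/// One-pass mean and population variance (Welford's update as a zygo).
/// A = (count, running mean); B = running sum of squared deviations (M2).
fn mean_variance_one_pass(data: &[f64]) -> Option<(f64, f64)> {
    if data.is_empty() {
        return None;
    }
    let ((n, mean), m2) = zygo_list(
        data,
        (0.0_f64, 0.0_f64),
        0.0_f64,
        // alg_a: bump the count and move the mean toward x
        |x, (n, mean): (f64, f64)| {
            let n1 = n + 1.0;
            (n1, mean + (x - mean) / n1)
        },
        // alg_b: uses A's *previous* (n, mean). With delta = x - mean,
        // Welford's increment delta * (x - new_mean) equals delta^2 * n / (n + 1).
        |x, (n, mean): (f64, f64), m2: f64| {
            let delta = x - mean;
            m2 + delta * delta * n / (n + 1.0)
        },
    );
    Some((mean, m2 / n))
}

fn main() {
    let data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
    if let Some((mean, var)) = mean_variance_one_pass(&data) {
        println!("mean = {}, variance = {}", mean, var);
    }
}
```

Here B cannot run without A: each squared-deviation step depends on the count and mean of the prefix, which is the dependency the zygomorphism threads through a single traversal.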

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Zygomorphism | `zygo alg_a alg_b` (two algebra args) | Function with two algebra closures |
| vs. two catas | Two traversals | One traversal, shared intermediate state |
| Algebra B signature | `F<(A, B)> → B` | `Fn(element, A, B) -> B` |
| Efficiency | Single traversal | Single traversal |
| Classic example | Average (sum + count) | Mean + variance together |
| Generalizes | Paramorphism (A = original sub-structure) | Same: para is a special zygo |
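The last table row can be made concrete. A paramorphism is the zygomorphism whose helper algebra A simply rebuilds the structure, so at each step algebra B sees the original remaining sub-list. This sketch assumes a right-to-left fold over a slice (matching how a list paramorphism sees the tail) and uses the hypothetical names `zygo_right` and `count_inversions`; copying the suffix at every step makes it O(n²), so it is for illustration only.

```rust
/// Right-to-left zygomorphism over a slice: at each element, `alg_b`
/// sees A's result for the suffix to the right of that element.
fn zygo_right<T, A, B>(
    xs: &[T],
    init_a: A,
    init_b: B,
    alg_a: impl Fn(&T, &A) -> A,
    alg_b: impl Fn(&T, &A, B) -> B,
) -> (A, B) {
    let mut a = init_a;
    let mut b = init_b;
    for x in xs.iter().rev() {
        b = alg_b(x, &a, b);
        a = alg_a(x, &a);
    }
    (a, b)
}

/// Paramorphism as a special zygo: A rebuilds the suffix itself,
/// so B can inspect the original remaining elements.
fn count_inversions(xs: &[i32]) -> usize {
    let (_suffix, inversions) = zygo_right(
        xs,
        Vec::new(), // A: the original suffix to the right of the current element
        0usize,     // B: inversions counted so far
        // alg_a: the rebuilding algebra, i.e. cons x onto the suffix
        |&x, suffix: &Vec<i32>| {
            let mut v = Vec::with_capacity(suffix.len() + 1);
            v.push(x);
            v.extend_from_slice(suffix);
            v
        },
        // alg_b: count later elements smaller than x, read from A's suffix
        |&x, suffix: &Vec<i32>, acc: usize| acc + suffix.iter().filter(|&&y| y < x).count(),
    );
    inversions
}

fn main() {
    println!("{}", count_inversions(&[3, 1, 2])); // prints 2
}
```

After the whole fold, A's result equals the input slice, which is the sense in which A is "the original sub-structure".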
// Zygomorphism: two algebras computed simultaneously in one pass

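// Generic pair fold over slices. `step` sees both accumulators, so R2 can
// depend on R1 at every step (the zygo pattern); it also allows R1 to depend
// on R2, which is more general than a strict zygo.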
fn zygo<A, R1, R2>(
    xs: &[A],
    init1: R1,
    init2: R2,
    step: impl Fn(R1, R2, &A) -> (R1, R2),
) -> (R1, R2) {
    xs.iter().fold((init1, init2), |(r1,r2), a| step(r1, r2, a))
}

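// Mean and population variance in one pass, using the naive E[x²] - E[x]²
// formula (not Welford's algorithm; it can lose precision through cancellation
// when the mean is large relative to the spread). Empty input yields NaN.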
fn mean_variance(xs: &[f64]) -> (f64, f64) {
    let n = xs.len() as f64;
    let (sum, sum_sq) = zygo(xs, 0.0_f64, 0.0_f64, |s, sq, &x| (s+x, sq+x*x));
    let mean = sum / n;
    let variance = sum_sq / n - mean*mean;
    (mean, variance)
}

// Count even and odd simultaneously
fn count_even_odd(xs: &[i32]) -> (usize, usize) {
    zygo(xs, 0usize, 0usize, |evens, odds, &x| {
        if x % 2 == 0 { (evens+1, odds) } else { (evens, odds+1) }
    })
}

// Min and max simultaneously
fn min_max(xs: &[i32]) -> Option<(i32, i32)> {
    if xs.is_empty() { return None; }
    let (mn, mx) = zygo(&xs[1..], xs[0], xs[0], |mn, mx, &x| (mn.min(x), mx.max(x)));
    Some((mn, mx))
}

// Zip with index and running sum
fn indexed_running_sum(xs: &[i32]) -> Vec<(usize, i32)> {
    let (result, _) = zygo(xs, Vec::new(), 0i32, |mut v, sum, &x| {
        let new_sum = sum + x;
        v.push((v.len(), new_sum));
        (v, new_sum)
    });
    result
}

fn main() {
    let xs = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
    let (mean, var) = mean_variance(&xs);
    println!("mean={:.2} variance={:.2} stddev={:.2}", mean, var, var.sqrt());

    let nums = [1,2,3,4,5,6,7,8,9,10];
    let (evens, odds) = count_even_odd(&nums);
    println!("evens={} odds={}", evens, odds);

    println!("min_max: {:?}", min_max(&nums));
    println!("indexed: {:?}", indexed_running_sum(&[1,2,3,4]));
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test] fn test_mean_var() {
        let (m,v) = mean_variance(&[2.0,2.0,2.0]);
        assert!((m-2.0).abs() < 1e-10);
        assert!(v.abs() < 1e-10);
    }
    #[test] fn test_even_odd() { assert_eq!(count_even_odd(&[1,2,3,4,5]), (2,3)); }
    #[test] fn test_min_max()  { assert_eq!(min_max(&[3,1,4,1,5]), Some((1,5))); }
}
(* Zygomorphism in OCaml *)
(* Compute average and variance in one pass *)

(* f is the "helper" algebra; g also receives f's accumulator as it was
   before the current element, so g can use f's intermediate results *)
let zygo f g init_f init_g xs =
  List.fold_left (fun (acc_f, acc_g) x ->
    let new_f = f acc_f x in
    let new_g = g acc_g acc_f x in
    (new_f, new_g)
  ) (init_f, init_g) xs

(* Mean and population variance from a single fold over sum and sum of squares *)
let mean_and_ssq xs =
  let n = float_of_int (List.length xs) in
  let (sum, ssq) = List.fold_left (fun (s,q) x -> (s+.x, q+.x*.x)) (0.,0.) xs in
  (sum /. n, ssq /. n -. (sum/.n) *. (sum/.n))

let () =
  let xs = [2.;4.;4.;4.;5.;5.;7.;9.] in
  let (mean, variance) = mean_and_ssq xs in
  Printf.printf "mean=%.2f variance=%.2f stddev=%.2f\n" mean variance (sqrt variance)