🦀 Functional Rust

790: Matrix Chain Multiplication

Difficulty: 4
Level: Advanced

Find the optimal parenthesization that minimizes the number of scalar multiplications when multiplying a chain of matrices.

The Problem This Solves

Matrix multiplication is associative: `(A × B) × C = A × (B × C)`. The parenthesization doesn't change the result, but it can change the cost dramatically: with a bad order, the number of scalar multiplications can be orders of magnitude larger than with the best one. This matters any time you compose linear transformations. 3D rendering pipelines, neural network weight matrices, compiler-generated loop transformations, and robotics kinematics chains all multiply sequences of matrices, and on a hot path the extra work from a poor order is paid on every call.

Matrix chain multiplication is a classic interval DP problem: you try every possible split point `k` of the chain, solve both sub-chains recursively, and keep the split that minimizes the total cost. Bottom-up DP fills a triangular table in O(n³) time and O(n²) space, then backtracks through a `split` table to reconstruct the optimal parenthesization.
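To see that the order changes the cost but not the result, here is a minimal sketch that multiplies the three-matrix chain 10×30, 30×5, 5×60 (the small example used in the listings later in the article) both ways with the schoolbook algorithm, counting every scalar multiplication. The entries are arbitrary small integers so the equality check is exact; `Matrix`, `mul_counting`, and `sample` are hypothetical names chosen for this illustration.

```rust
// Sketch: schoolbook multiplication that counts scalar multiplications.
// `Matrix`, `mul_counting` and `sample` are hypothetical names for illustration.
type Matrix = Vec<Vec<i64>>;

/// Multiplies an a×b matrix by a b×c matrix and adds a*b*c to `count`.
fn mul_counting(x: &Matrix, y: &Matrix, count: &mut usize) -> Matrix {
    let (a, b, c) = (x.len(), y.len(), y[0].len());
    assert_eq!(x[0].len(), b, "inner dimensions must match");
    let mut out = vec![vec![0; c]; a];
    for r in 0..a {
        for col in 0..c {
            for m in 0..b {
                out[r][col] += x[r][m] * y[m][col];
                *count += 1; // one scalar multiplication
            }
        }
    }
    out
}

/// A rows×cols matrix of small deterministic integers in -3..=3.
fn sample(rows: usize, cols: usize, seed: usize) -> Matrix {
    (0..rows)
        .map(|r| (0..cols).map(|c| ((r + 2 * c + seed) % 7) as i64 - 3).collect())
        .collect()
}

fn main() {
    // Shapes 10×30, 30×5, 5×60
    let (m1, m2, m3) = (sample(10, 30, 1), sample(30, 5, 2), sample(5, 60, 3));

    let mut left_count = 0;
    let m12 = mul_counting(&m1, &m2, &mut left_count);
    let left = mul_counting(&m12, &m3, &mut left_count);

    let mut right_count = 0;
    let m23 = mul_counting(&m2, &m3, &mut right_count);
    let right = mul_counting(&m1, &m23, &mut right_count);

    assert_eq!(left, right); // associativity: identical result either way
    println!("(M1 × M2) × M3: {left_count} scalar multiplications"); // 4500
    println!("M1 × (M2 × M3): {right_count} scalar multiplications"); // 27000
}
```

Both products are equal, but the first order performs 10·30·5 + 10·5·60 = 4,500 multiplications and the second 30·5·60 + 10·30·60 = 27,000.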

The Intuition

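Given matrices M₁, M₂, …, Mₙ with compatible dimensions, you want to place the parentheses so that the total number of scalar multiplications is as small as possible. With the standard (schoolbook) algorithm, multiplying an `a×b` matrix by a `b×c` matrix costs `a·b·c` scalar multiplications. All shapes fit in one array `dims` of length n + 1: matrix `i` (0-indexed in the code) is `dims[i] × dims[i+1]`.

The DP state `dp[i][j]` is the minimum cost to multiply matrices `i` through `j`. A single matrix costs nothing, so `dp[i][i] = 0`. For a longer chain, try every split point `k` with `i ≤ k < j`: the left part yields a `dims[i] × dims[k+1]` matrix, the right part a `dims[k+1] × dims[j+1]` matrix, and multiplying those two costs `dims[i] * dims[k+1] * dims[j+1]`. So `dp[i][j] = min over i ≤ k < j of (dp[i][k] + dp[k+1][j] + dims[i] * dims[k+1] * dims[j+1])`. There are O(n²) states and each tries O(n) splits, giving O(n³) time and O(n²) space.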
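The recurrence can also be evaluated top-down, caching each `(i, j)` pair the first time it is solved. This memoised form is a swapped-in technique shown for comparison, not the article's bottom-up method; `min_cost` and `min_cost_memo` are hypothetical names. It mirrors the formula line for line and has the same O(n³) time bound, at the price of recursion depth proportional to n and hashing overhead.

```rust
use std::collections::HashMap;

/// Minimum scalar multiplications for matrices i..=j (0-indexed), where
/// matrix t has shape dims[t] × dims[t + 1]. Results are cached in `memo`.
fn min_cost_memo(
    dims: &[usize],
    i: usize,
    j: usize,
    memo: &mut HashMap<(usize, usize), usize>,
) -> usize {
    if i == j {
        return 0; // a single matrix needs no multiplication
    }
    if let Some(&cached) = memo.get(&(i, j)) {
        return cached;
    }
    let mut best = usize::MAX;
    for k in i..j {
        // left sub-chain + right sub-chain + the final multiplication
        let cost = min_cost_memo(dims, i, k, memo)
            + min_cost_memo(dims, k + 1, j, memo)
            + dims[i] * dims[k + 1] * dims[j + 1];
        best = best.min(cost);
    }
    memo.insert((i, j), best);
    best
}

/// Convenience wrapper over the whole chain (hypothetical helper).
fn min_cost(dims: &[usize]) -> usize {
    assert!(dims.len() >= 2, "need at least one matrix");
    min_cost_memo(dims, 0, dims.len() - 2, &mut HashMap::new())
}

fn main() {
    println!("{}", min_cost(&[30, 35, 15, 5, 10, 20, 25])); // 15125
    println!("{}", min_cost(&[10, 30, 5, 60])); // 4500
}
```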

How It Works in Rust

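fn matrix_chain(dims: &[usize]) -> (usize, Vec<Vec<usize>>) {
 // Matrix i is dims[i] × dims[i + 1]; an empty `dims` would underflow below
 assert!(dims.len() >= 2, "need at least one matrix (two dimension values)");
 let n = dims.len() - 1;  // number of matrices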
 let mut dp    = vec![vec![0usize; n]; n];  // min cost
 let mut split = vec![vec![0usize; n]; n];  // optimal split point

 // Fill by chain length l = 2..=n
 for l in 2..=n {
     for i in 0..=(n - l) {
         let j = i + l - 1;
         dp[i][j] = usize::MAX;
         for k in i..j {
             // Cost = left sub-chain + right sub-chain + this multiplication
             let cost = dp[i][k]
                 .saturating_add(dp[k + 1][j])
                 .saturating_add(dims[i] * dims[k + 1] * dims[j + 1]);
             if cost < dp[i][j] {
                 dp[i][j] = cost;
                 split[i][j] = k;  // remember where to split
             }
         }
     }
 }
 (dp[0][n - 1], split)
}

// Reconstruct the optimal parenthesization from the split table
fn parenthesize(split: &[Vec<usize>], i: usize, j: usize) -> String {
 if i == j {
     format!("M{}", i + 1)
 } else {
     let k = split[i][j];
     format!("({} × {})", parenthesize(split, i, k), parenthesize(split, k + 1, j))
 }
}
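Key Rust details: `usize::MAX` is the "infinity" starting value for `dp[i][j]` before any split is tried. Because chains are filled in order of increasing length, `dp[i][k]` and `dp[k + 1][j]` are always finished before they are read, so the sentinel is only ever compared against, never added to; `saturating_add` is a defensive guard that clamps an oversized sum at `usize::MAX` instead of overflowing. It does not cover the product `dims[i] * dims[k + 1] * dims[j + 1]`, which can still overflow for very large dimensions (a panic in debug builds, silent wrapping in release builds). The `split` table is a separate `Vec<Vec<usize>>` that mirrors the `dp` table, and reconstruction is a natural recursive descent over it.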
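The saturating arithmetic in `matrix_chain` guards only the additions, and once a sum clamps at `usize::MAX` the comparison can no longer tell costs apart. A minimal sketch of a stricter variant, assuming you would rather get `None` than a clamped or wrapped number: every split cost is computed with `checked_mul` and `checked_add`, and `None` acts as infinity, so an overflowing split is simply skipped. `matrix_chain_checked` and `split_cost` are hypothetical names, not part of the original code.

```rust
/// Cost of splitting matrices i..=j at k, or None if it does not fit in usize.
fn split_cost(dims: &[usize], dp: &[Vec<Option<usize>>], i: usize, k: usize, j: usize) -> Option<usize> {
    let join = dims[i].checked_mul(dims[k + 1])?.checked_mul(dims[j + 1])?;
    dp[i][k]?.checked_add(dp[k + 1][j]?)?.checked_add(join)
}

/// Bottom-up DP where None means "every order for this sub-chain overflows".
fn matrix_chain_checked(dims: &[usize]) -> Option<(usize, Vec<Vec<usize>>)> {
    assert!(dims.len() >= 2, "need at least one matrix");
    let n = dims.len() - 1;
    let mut dp: Vec<Vec<Option<usize>>> = vec![vec![Some(0); n]; n];
    let mut split = vec![vec![0usize; n]; n];
    for l in 2..=n {
        for i in 0..=(n - l) {
            let j = i + l - 1;
            dp[i][j] = None; // "infinity" until some split fits
            for k in i..j {
                if let Some(cost) = split_cost(dims, &dp, i, k, j) {
                    if dp[i][j].map_or(true, |best| cost < best) {
                        dp[i][j] = Some(cost);
                        split[i][j] = k;
                    }
                }
            }
        }
    }
    dp[0][n - 1].map(|cost| (cost, split))
}

fn main() {
    println!("{:?}", matrix_chain_checked(&[10, 30, 5, 60]).map(|(c, _)| c)); // Some(4500)
    // The single product here needs (MAX/2)·3·(MAX/2) multiplications: overflow.
    println!("{:?}", matrix_chain_checked(&[usize::MAX / 2, 3, usize::MAX / 2]).map(|(c, _)| c)); // None
}
```

Skipping only the overflowing split, rather than aborting, matters because a whole chain can be much cheaper than one of its sub-chains: for 1×100, 100×1, 1×100 the last two matrices alone cost 10,000, while `(M1 × M2) × M3` costs 200.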


Key Differences

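| Concept | OCaml | Rust |
|---|---|---|
| 2D mutable table | `Array.make_matrix n n 0` | `vec![vec![0usize; n]; n]` |
| "Infinity" sentinel | `max_int`; additions use plain `+` (OCaml `int` arithmetic wraps silently on overflow) | `usize::MAX`; additions use `saturating_add` |
| Recursive reconstruction | `let rec parenthesize` with an `if i = j` base case | Recursive `fn parenthesize` with an `if i == j` base case |
| Per-split cost | Immutable `let cost = ... in`, then an in-place update of `dp.(i).(j)` | Immutable `let cost = ...`, then an in-place update of `dp[i][j]` |
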
// Matrix Chain Multiplication — bottom-up DP O(n³)
// Matrix i has rows = dims[i] and cols = dims[i + 1]

fn matrix_chain(dims: &[usize]) -> (usize, Vec<Vec<usize>>) {
    assert!(dims.len() >= 2, "need at least one matrix (two dimension values)");
    let n = dims.len() - 1;
    let mut dp    = vec![vec![0usize; n]; n];
    let mut split = vec![vec![0usize; n]; n];

    // l = chain length
    for l in 2..=n {
        for i in 0..=(n - l) {
            let j = i + l - 1;
            dp[i][j] = usize::MAX;
            for k in i..j {
                let cost = dp[i][k]
                    .saturating_add(dp[k + 1][j])
                    .saturating_add(dims[i] * dims[k + 1] * dims[j + 1]);
                if cost < dp[i][j] {
                    dp[i][j] = cost;
                    split[i][j] = k;
                }
            }
        }
    }
    (dp[0][n - 1], split)
}

fn parenthesize(split: &[Vec<usize>], i: usize, j: usize) -> String {
    if i == j {
        format!("M{}", i + 1)
    } else {
        let k = split[i][j];
        format!("({} × {})", parenthesize(split, i, k), parenthesize(split, k + 1, j))
    }
}

fn main() {
    // 6 matrices: dims describe 7 boundary values
    let dims = vec![30, 35, 15, 5, 10, 20, 25];
    let n = dims.len() - 1;
    println!("Number of matrices: {n}");
    let (cost, split) = matrix_chain(&dims);
    println!("Minimum scalar multiplications: {cost}");
    println!("Optimal parenthesization: {}", parenthesize(&split, 0, n - 1));

    // Classic 3-matrix example
    let dims2 = vec![10, 30, 5, 60];
    let (c2, s2) = matrix_chain(&dims2);
    println!("\n3-matrix (10×30, 30×5, 5×60):");
    println!("Min cost: {c2}, Order: {}", parenthesize(&s2, 0, 2));
}
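The Rust program prints the optimal order but never multiplies anything. A minimal sketch of carrying out the product in that order, assuming dense `Vec<Vec<i64>>` matrices and the schoolbook product; `Matrix`, `matmul`, and `eval_chain` are hypothetical helpers, and `matrix_chain` is repeated (with plain `+` for brevity) only so the block compiles on its own.

```rust
// Sketch: evaluate a matrix chain in the order recorded by a split table.
type Matrix = Vec<Vec<i64>>;

/// Schoolbook product of an a×b and a b×c matrix.
fn matmul(x: &Matrix, y: &Matrix) -> Matrix {
    let (a, b, c) = (x.len(), y.len(), y[0].len());
    assert_eq!(x[0].len(), b, "inner dimensions must match");
    let mut out = vec![vec![0; c]; a];
    for r in 0..a {
        for m in 0..b {
            for col in 0..c {
                out[r][col] += x[r][m] * y[m][col];
            }
        }
    }
    out
}

/// Bottom-up DP as in the article, repeated here with plain `+`.
fn matrix_chain(dims: &[usize]) -> (usize, Vec<Vec<usize>>) {
    assert!(dims.len() >= 2, "need at least one matrix");
    let n = dims.len() - 1;
    let mut dp = vec![vec![0usize; n]; n];
    let mut split = vec![vec![0usize; n]; n];
    for l in 2..=n {
        for i in 0..=(n - l) {
            let j = i + l - 1;
            dp[i][j] = usize::MAX;
            for k in i..j {
                let cost = dp[i][k] + dp[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1];
                if cost < dp[i][j] {
                    dp[i][j] = cost;
                    split[i][j] = k;
                }
            }
        }
    }
    (dp[0][n - 1], split)
}

/// Multiplies mats[i..=j], splitting at split[i][j] just as parenthesize() prints it.
fn eval_chain(mats: &[Matrix], split: &[Vec<usize>], i: usize, j: usize) -> Matrix {
    if i == j {
        return mats[i].clone();
    }
    let k = split[i][j];
    matmul(&eval_chain(mats, split, i, k), &eval_chain(mats, split, k + 1, j))
}

fn main() {
    let dims = [10, 30, 5, 60];
    // All-ones matrices whose shapes come from consecutive pairs in `dims`.
    let mats: Vec<Matrix> = dims.windows(2).map(|w| vec![vec![1; w[1]]; w[0]]).collect();
    let (cost, split) = matrix_chain(&dims);
    let product = eval_chain(&mats, &split, 0, mats.len() - 1);
    println!("cost {cost}, result is {}×{}", product.len(), product[0].len());
}
```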

(* Matrix Chain Multiplication: bottom-up DP, O(n³)
   dims.(i) = rows of matrix i, dims.(i+1) = cols of matrix i
   e.g., for 3 matrices: dims = [|p0; p1; p2; p3|] *)

let matrix_chain dims =
  let n = Array.length dims - 1 in
  (* dp.(i).(j) = min cost to multiply matrices i..j (0-indexed) *)
  let dp    = Array.make_matrix n n 0 in
  let split = Array.make_matrix n n 0 in
  (* chain length l from 2 to n *)
  for l = 2 to n do
    for i = 0 to n - l do
      let j = i + l - 1 in
      dp.(i).(j) <- max_int;
      for k = i to j - 1 do
        let cost = dp.(i).(k) + dp.(k+1).(j)
                   + dims.(i) * dims.(k+1) * dims.(j+1) in
        if cost < dp.(i).(j) then begin
          dp.(i).(j)    <- cost;
          split.(i).(j) <- k
        end
      done
    done
  done;
  (dp.(0).(n-1), split)

(* Reconstruct optimal parenthesization as a string *)
let rec parenthesize split i j =
  if i = j then
    Printf.sprintf "M%d" (i + 1)
  else
    let k = split.(i).(j) in
    Printf.sprintf "(%s × %s)"
      (parenthesize split i k)
      (parenthesize split (k+1) j)

let () =
  (* Example: 6 matrices with dims 30×35, 35×15, 15×5, 5×10, 10×20, 20×25 *)
  let dims = [| 30; 35; 15; 5; 10; 20; 25 |] in
  let n    = Array.length dims - 1 in
  Printf.printf "Number of matrices: %d\n" n;
  let (cost, split) = matrix_chain dims in
  Printf.printf "Minimum multiplications: %d\n" cost;
  Printf.printf "Optimal order: %s\n" (parenthesize split 0 (n-1));

  (* Small example *)
  let dims2  = [| 10; 30; 5; 60 |] in
  let (c2, s2) = matrix_chain dims2 in
  Printf.printf "\n3-matrix example (10×30, 30×5, 5×60):\n";
  Printf.printf "Min cost: %d, Order: %s\n" c2 (parenthesize s2 0 2)
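Because both listings report only the minimum, a brute-force enumeration is a handy cross-check on small inputs. The sketch below lists every parenthesization with its cost; `all_orders` is a hypothetical helper, not part of the original. For n matrices there are C(n−1) orders, the (n−1)th Catalan number, which grows exponentially, so this is for testing only. For the six-matrix example it should find 42 orders whose cheapest cost matches the DP's answer.

```rust
/// Every parenthesization of matrices i..=j with its scalar-multiplication cost.
/// Matrix t has shape dims[t] × dims[t + 1]. Exponential: only for small chains.
fn all_orders(dims: &[usize], i: usize, j: usize) -> Vec<(String, usize)> {
    if i == j {
        return vec![(format!("M{}", i + 1), 0)];
    }
    let mut out = Vec::new();
    for k in i..j {
        // Cost of the final multiply joining the two halves.
        let join = dims[i] * dims[k + 1] * dims[j + 1];
        for (left, lc) in all_orders(dims, i, k) {
            for (right, rc) in all_orders(dims, k + 1, j) {
                out.push((format!("({left} × {right})"), lc + rc + join));
            }
        }
    }
    out
}

fn main() {
    let small = [10, 30, 5, 60];
    for (order, cost) in all_orders(&small, 0, small.len() - 2) {
        println!("{order}: {cost}");
    }
    let big = [30, 35, 15, 5, 10, 20, 25];
    let orders = all_orders(&big, 0, big.len() - 2);
    let best = orders.iter().map(|(_, c)| *c).min().unwrap();
    println!("{} orders, cheapest = {best}", orders.len()); // 42 orders, cheapest = 15125
}
```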