1057 Advanced

1057-matrix-chain — Matrix Chain Multiplication

Functional Programming

Tutorial

The Problem

Matrix multiplication is associative, so (AB)C = A(BC), but the computational cost varies dramatically with the parenthesization. Take a 10×30 matrix A, a 30×5 matrix B, and a 5×60 matrix C: (AB)C costs 10×30×5 + 10×5×60 = 1,500 + 3,000 = 4,500 scalar multiplications, while A(BC) costs 30×5×60 + 10×30×60 = 9,000 + 18,000 = 27,000, six times as many. For long chains with very uneven dimensions, the gap between the best and worst ordering can grow to 10–100× or more.
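
The two orderings can be compared with a few lines of Rust. This is a minimal sketch for illustration only; the helper names `mul_cost`, `left_first`, and `right_first` are hypothetical and not part of the tutorial's source.

```rust
/// Scalar multiplications needed to multiply a `rows x inner` matrix by an
/// `inner x cols` matrix. (Hypothetical helper, for illustration only.)
fn mul_cost(rows: usize, inner: usize, cols: usize) -> usize {
    rows * inner * cols
}

/// Cost of (AB)C where A is a x b, B is b x c, and C is c x d.
fn left_first(a: usize, b: usize, c: usize, d: usize) -> usize {
    // AB yields an a x c matrix, which is then multiplied by C.
    mul_cost(a, b, c) + mul_cost(a, c, d)
}

/// Cost of A(BC) for the same shapes.
fn right_first(a: usize, b: usize, c: usize, d: usize) -> usize {
    // BC yields a b x d matrix, which A is then multiplied by.
    mul_cost(b, c, d) + mul_cost(a, b, d)
}
```

For the shapes in the example, `left_first(10, 30, 5, 60)` evaluates to 4500 and `right_first(10, 30, 5, 60)` to 27000.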

Matrix chain ordering is a classic interval DP problem and a fundamental optimization in scientific computing, neural network inference, and linear algebra libraries.

🎯 Learning Outcomes

  • Implement matrix chain DP with dp[i][j] = minimum cost for matrices i..j
  • Understand interval DP: fill by increasing chain length
  • Recover the optimal parenthesization using a split table
  • Recognize that matrix multiplication associativity enables optimization
  • Connect to BLAS/LAPACK and deep learning frameworks that optimize compute graphs

Code Example

    //! 1057: Matrix Chain Multiplication — optimal parenthesization via DP.
    
    pub type Dimension = (usize, usize);
    
    /// Minimum number of scalar multiplications to compute the product of a
    /// chain of matrices with the given `(rows, cols)` dimensions.
    pub fn matrix_chain(dims: &[Dimension]) -> usize {
        let n = dims.len();
        if n < 2 {
            return 0;
        }
        let mut m = vec![vec![0usize; n]; n];
        for len in 2..=n {
            for i in 0..=n - len {
                let j = i + len - 1;
                m[i][j] = usize::MAX;
                for k in i..j {
                    let cost = m[i][k] + m[k + 1][j] + dims[i].0 * dims[k].1 * dims[j].1;
                    if cost < m[i][j] {
                        m[i][j] = cost;
                    }
                }
            }
        }
        m[0][n - 1]
    }
    
    #[cfg(test)]
    mod tests {
        use super::*;
    
        #[test]
        fn example_from_source() {
            let dims = [(10, 100), (100, 5), (5, 50)];
            assert_eq!(matrix_chain(&dims), 7500);
        }
    
        #[test]
        fn empty_chain() {
            assert_eq!(matrix_chain(&[]), 0);
        }
    
        #[test]
        fn single_matrix() {
            assert_eq!(matrix_chain(&[(10, 20)]), 0);
        }
    
        #[test]
        fn classic_six_matrices() {
            let dims = [(30, 35), (35, 15), (15, 5), (5, 10), (10, 20), (20, 25)];
            assert_eq!(matrix_chain(&dims), 15125);
        }
    
        #[test]
        fn three_matrices() {
            let dims = [(10, 20), (20, 30), (30, 40)];
            assert_eq!(matrix_chain(&dims), 18000);
        }
    }

    Key Differences

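  • usize::MAX vs max_int: Rust uses usize::MAX as infinity; OCaml uses max_int. Adding to either sentinel overflows (a panic in Rust debug builds, silent wraparound in Rust release builds and in OCaml), so the sentinel must only be compared against, never added to. Both listings satisfy this because every shorter interval is filled before it is read.
  • Interval filling order: Both fill by increasing chain length. The outer loop fixes the length, the middle loop the start index, and the inner loop tries each split point.
  • Reconstruction: The listings here return only the minimum cost. To recover the parenthesization, both languages record the best split point in a separate table during the DP and decode it recursively into a string.
  • Applications: Any library that evaluates chained tensor products can benefit from this ordering, whether an OCaml ML library or a Rust framework such as candle.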
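
The sentinel overflow concern can also be sidestepped by using `Option` and checked arithmetic instead of `usize::MAX`. The following is a minimal sketch, assuming the same `(rows, cols)` input as `matrix_chain`; the name `matrix_chain_checked` is hypothetical, and returning `None` as soon as any candidate cost overflows is a deliberately conservative choice made here for illustration.

```rust
/// Like `matrix_chain`, but returns `None` if any intermediate cost would
/// overflow `usize`. (Hypothetical variant, not from the original source.)
pub fn matrix_chain_checked(dims: &[(usize, usize)]) -> Option<usize> {
    let n = dims.len();
    if n < 2 {
        return Some(0);
    }
    // m[i][j] holds the minimum cost for the sub-chain i..=j.
    let mut m = vec![vec![0usize; n]; n];
    for len in 2..=n {
        for i in 0..=n - len {
            let j = i + len - 1;
            // `None` plays the role of infinity, so no sentinel is ever added.
            let mut best: Option<usize> = None;
            for k in i..j {
                // Conservative: the first overflowing candidate aborts the
                // whole computation, even if another split would have fit.
                let product = dims[i].0.checked_mul(dims[k].1)?.checked_mul(dims[j].1)?;
                let cost = m[i][k].checked_add(m[k + 1][j])?.checked_add(product)?;
                best = Some(best.map_or(cost, |b| b.min(cost)));
            }
            // i < j guarantees at least one split, so `best` is `Some` here.
            m[i][j] = best?;
        }
    }
    Some(m[0][n - 1])
}
```

The trade-off is that callers must handle the `None` case explicitly, in exchange for never observing a wrapped or panicking addition.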

    OCaml Approach

    (* dims is a flat dimension array: matrix i has shape dims.(i) x dims.(i+1). *)
    let matrix_chain dims =
      let n = Array.length dims - 1 in
      (* Fewer than two matrices: nothing to multiply. *)
      if n < 2 then 0
      else begin
        let dp = Array.make_matrix n n 0 in
        for l = 2 to n do
          for i = 0 to n - l do
            let j = i + l - 1 in
            dp.(i).(j) <- max_int;
            for k = i to j - 1 do
              let cost = dp.(i).(k) + dp.(k+1).(j) + dims.(i) * dims.(k+1) * dims.(j+1) in
              if cost < dp.(i).(j) then dp.(i).(j) <- cost
            done
          done
        done;
        dp.(0).(n-1)
      end
    

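    The algorithm is identical to the Rust version; only the input encoding differs. Here dims is a flat array of n + 1 dimensions in which matrix i has shape dims.(i) × dims.(i+1), whereas the Rust version takes one (rows, cols) pair per matrix. Interval DP has a canonical structure: an outer loop over interval length, a loop over start positions, and an inner loop over split points.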
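
Both listings implement the same recurrence. Written for the flat dimension array taken by the OCaml function, where matrix $i$ has shape $p_i \times p_{i+1}$ (the symbol $p$ is introduced here for brevity and does not appear in the source), it reads:

```latex
m[i][j] =
\begin{cases}
0 & \text{if } i = j, \\
\min\limits_{i \le k < j} \bigl( m[i][k] + m[k+1][j] + p_i \, p_{k+1} \, p_{j+1} \bigr) & \text{if } i < j.
\end{cases}
```

In the Rust listing the last term is `dims[i].0 * dims[k].1 * dims[j].1`, which is the same product expressed through `(rows, cols)` pairs.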

    Deep Comparison

    Matrix Chain Multiplication — Comparison

    Core Insight

    Matrix chain multiplication is the canonical interval DP problem. The key is trying every split point k in range [i, j) and taking the minimum total cost. A separate split table enables reconstructing the optimal parenthesization.
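
The split-table idea can be sketched as follows, assuming the same `(rows, cols)` input as `matrix_chain`. The names `matrix_chain_with_splits` and `parenthesize` are hypothetical, and labeling the matrices A, B, C, ... is an assumption made purely for illustration.

```rust
/// Returns the minimum cost together with a table in which `split[i][j]` is
/// the best split point `k` for the sub-chain `i..=j`. (Hypothetical helper.)
pub fn matrix_chain_with_splits(dims: &[(usize, usize)]) -> (usize, Vec<Vec<usize>>) {
    let n = dims.len();
    let mut cost = vec![vec![0usize; n]; n];
    let mut split = vec![vec![0usize; n]; n];
    for len in 2..=n {
        for i in 0..=n - len {
            let j = i + len - 1;
            cost[i][j] = usize::MAX;
            for k in i..j {
                let c = cost[i][k] + cost[k + 1][j] + dims[i].0 * dims[k].1 * dims[j].1;
                if c < cost[i][j] {
                    cost[i][j] = c;
                    split[i][j] = k; // remember where the best split was
                }
            }
        }
    }
    let total = if n == 0 { 0 } else { cost[0][n - 1] };
    (total, split)
}

/// Decodes the split table into a string such as "((AB)C)".
/// Matrices are labeled 'A', 'B', ..., so this assumes at most 26 of them.
pub fn parenthesize(split: &[Vec<usize>], i: usize, j: usize) -> String {
    if i == j {
        return ((b'A' + i as u8) as char).to_string();
    }
    let k = split[i][j];
    format!("({}{})", parenthesize(split, i, k), parenthesize(split, k + 1, j))
}
```

For the chain `[(10, 100), (100, 5), (5, 50)]`, the returned cost is 7500 and `parenthesize(&split, 0, 2)` yields `"((AB)C)"`.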

    OCaml Approach

  • Buffer for building parenthesization string recursively
  • Printf.sprintf for formatting matrix names
  • max_int as initial sentinel
  • ref cells for tracking best in inner loop

    Rust Approach

  • format! macro for string building in recursive parenthesization
  • usize::MAX as sentinel
  • Nested function for recursive string building
  • HashMap with tuple keys for memoization
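
A top-down variant with a `HashMap` keyed by `(i, j)` tuples might look like the sketch below. It assumes the same `(rows, cols)` input as `matrix_chain`; the names `solve` and `matrix_chain_memo` are hypothetical and not taken from the original source.

```rust
use std::collections::HashMap;

/// Minimum cost for the sub-chain `i..=j`, cached in `memo`.
/// (Hypothetical sketch of the memoized approach.)
fn solve(
    dims: &[(usize, usize)],
    i: usize,
    j: usize,
    memo: &mut HashMap<(usize, usize), usize>,
) -> usize {
    if i == j {
        return 0; // a single matrix needs no multiplication
    }
    if let Some(&cached) = memo.get(&(i, j)) {
        return cached;
    }
    let mut best = usize::MAX;
    for k in i..j {
        let cost = solve(dims, i, k, memo)
            + solve(dims, k + 1, j, memo)
            + dims[i].0 * dims[k].1 * dims[j].1;
        best = best.min(cost);
    }
    memo.insert((i, j), best);
    best
}

/// Entry point: same contract as the bottom-up `matrix_chain`.
pub fn matrix_chain_memo(dims: &[(usize, usize)]) -> usize {
    if dims.len() < 2 {
        return 0;
    }
    let mut memo = HashMap::new();
    solve(dims, 0, dims.len() - 1, &mut memo)
}
```

Compared with the bottom-up table, this version only visits sub-chains reachable from the full chain, at the price of hashing overhead and recursion depth proportional to the chain length.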

    Comparison Table

    | Aspect | OCaml | Rust |
    |---|---|---|
    | String building | Buffer + Printf.sprintf | format!() macro |
    | Infinity sentinel | max_int | usize::MAX |
    | 2D table init | Array.init n (fun _ -> Array.make n 0) | vec![vec![0; n]; n] |
    | Split tracking | Parallel split array | Parallel split vec |
    | Recursion | Natural OCaml recursion | Inner fn with explicit params |

    Exercises

  • Add a memoized top-down implementation and verify that it produces the same answer as the bottom-up version.
  • Implement the reconstruction function format_chain(split: &[Vec<usize>], i: usize, j: usize, names: &[&str]) -> String that produces a parenthesized expression like "((A×B)×(C×D))".
  • Extend to weighted matrix chain where some multiplications have additional overhead (e.g., GPU memory transfer costs).

    Open Source Repos