🦀 Functional Rust

1057: Matrix Chain Multiplication

Difficulty: Advanced
Category: Dynamic Programming
Concept: Find the optimal parenthesization of a matrix chain to minimize the number of scalar multiplications.
Key Insight: This is interval DP — `dp[i][j]` is the minimum cost to multiply matrices i through j, found by trying every split point k with i ≤ k < j.
// 1057: Matrix Chain Multiplication — Optimal Parenthesization

use std::collections::HashMap;

// Approach 1: Bottom-up DP
/// Minimum number of scalar multiplications needed to multiply a chain of matrices.
///
/// `dims` describes a chain of `dims.len() - 1` matrices, where matrix `i` (0-based) has
/// shape `dims[i] x dims[i + 1]`. Chains with zero or one matrix (`dims.len() < 3`) cost 0.
///
/// Bottom-up interval DP: `dp[i][j]` is the cheapest way to multiply matrices `i..=j`.
/// Intervals are filled in order of increasing length, so both halves of every split
/// `(i..=k, k+1..=j)` are already final when they are read.
///
/// Products are computed in `usize`; extremely large dimensions can overflow (a panic in
/// debug builds).
fn matrix_chain_dp(dims: &[usize]) -> usize {
    // Zero or one matrix: nothing to multiply. This also guards the `len() - 1` and
    // `n - 1` subtractions below against underflow on empty or one-element input.
    if dims.len() < 3 {
        return 0;
    }
    let n = dims.len() - 1;
    let mut dp = vec![vec![0usize; n]; n];
    for len in 2..=n {
        for i in 0..=(n - len) {
            let j = i + len - 1;
            // Splitting after matrix k costs both sub-products plus the final
            // (dims[i] x dims[k+1]) * (dims[k+1] x dims[j+1]) multiplication.
            let best = (i..j)
                .map(|k| dp[i][k] + dp[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1])
                .min()
                .expect("i < j, so there is at least one split point");
            dp[i][j] = best;
        }
    }
    dp[0][n - 1]
}

// Approach 2: With parenthesization tracking
/// Minimum multiplication cost of a matrix chain together with one optimal parenthesization.
///
/// `dims` describes `dims.len() - 1` matrices, where matrix `i` (0-based) has shape
/// `dims[i] x dims[i + 1]`. Matrices are named `A1..An` in the returned string, e.g.
/// `"((A1*A2)*A3)"`. When several split points tie, the leftmost one is kept.
/// A single matrix yields `(0, "A1")`; an empty chain (`dims.len() < 2`) yields `(0, "")`.
fn matrix_chain_parens(dims: &[usize]) -> (usize, String) {
    // No matrices: guard the `len() - 1` / `n - 1` subtractions against underflow.
    if dims.len() < 2 {
        return (0, String::new());
    }
    let n = dims.len() - 1;
    let mut dp = vec![vec![0usize; n]; n];
    // split[i][j] is the k at which the optimal product of matrices i..=j is divided.
    let mut split = vec![vec![0usize; n]; n];
    for l in 2..=n {
        for i in 0..=(n - l) {
            let j = i + l - 1;
            dp[i][j] = usize::MAX;
            for k in i..j {
                let cost = dp[i][k] + dp[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1];
                // Strict `<` keeps the leftmost split point on ties.
                if cost < dp[i][j] {
                    dp[i][j] = cost;
                    split[i][j] = k;
                }
            }
        }
    }

    /// Appends the parenthesization of matrices `i..=j` to `out`.
    ///
    /// Writing into one shared buffer avoids allocating and re-copying an intermediate
    /// `String` at every level of the recursion.
    fn build(i: usize, j: usize, split: &[Vec<usize>], out: &mut String) {
        if i == j {
            out.push_str(&format!("A{}", i + 1));
        } else {
            let k = split[i][j];
            out.push('(');
            build(i, k, split, out);
            out.push('*');
            build(k + 1, j, split, out);
            out.push(')');
        }
    }

    let mut out = String::new();
    build(0, n - 1, &split, &mut out);
    (dp[0][n - 1], out)
}

// Approach 3: Recursive with memoization
/// Minimum multiplication cost of a matrix chain, computed top-down with memoization.
///
/// `dims` describes `dims.len() - 1` matrices, where matrix `i` (0-based) has shape
/// `dims[i] x dims[i + 1]`. Chains with zero or one matrix cost 0. Produces the same
/// result as the bottom-up table, but only visits intervals reachable from `(0, n - 1)`.
fn matrix_chain_memo(dims: &[usize]) -> usize {
    /// Cheapest cost of multiplying matrices `i..=j`, caching each solved interval.
    fn solve(
        i: usize,
        j: usize,
        dims: &[usize],
        cache: &mut HashMap<(usize, usize), usize>,
    ) -> usize {
        if i == j {
            return 0;
        }
        if let Some(&v) = cache.get(&(i, j)) {
            return v;
        }
        let best = (i..j)
            .map(|k| {
                solve(i, k, dims, cache)
                    + solve(k + 1, j, dims, cache)
                    + dims[i] * dims[k + 1] * dims[j + 1]
            })
            .min()
            .expect("i < j, so there is at least one split point");
        cache.insert((i, j), best);
        best
    }

    // No matrices: nothing to multiply; also guards `len() - 2` against underflow.
    if dims.len() < 2 {
        return 0;
    }
    let mut cache = HashMap::new();
    solve(0, dims.len() - 2, dims, &mut cache)
}

/// Demo: solves a six-matrix chain and prints the cost and chosen parenthesization.
fn main() {
    // Matrix shapes: 30x35, 35x15, 15x5, 5x10, 10x20, 20x25.
    let dims: [usize; 7] = [30, 35, 15, 5, 10, 20, 25];
    let result = matrix_chain_parens(&dims);
    println!("Min cost: {}, Parenthesization: {}", result.0, result.1);
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Six matrices: 30x35, 35x15, 15x5, 5x10, 10x20, 20x25.
    const CHAIN6: [usize; 7] = [30, 35, 15, 5, 10, 20, 25];
    /// Three matrices: 10x20, 20x30, 30x40.
    const CHAIN3: [usize; 4] = [10, 20, 30, 40];

    #[test]
    fn test_matrix_chain_dp() {
        assert_eq!(matrix_chain_dp(&CHAIN6), 15125);
        assert_eq!(matrix_chain_dp(&CHAIN3), 18000);
        // A single matrix needs no multiplication.
        assert_eq!(matrix_chain_dp(&[10, 20]), 0);
    }

    #[test]
    fn test_matrix_chain_parens() {
        let (cost, parens) = matrix_chain_parens(&CHAIN6);
        assert_eq!(cost, 15125);
        // Pin the exact split choices, not just that some string was produced.
        assert_eq!(parens, "((A1*(A2*A3))*((A4*A5)*A6))");
        assert_eq!(matrix_chain_parens(&CHAIN3), (18000, "((A1*A2)*A3)".to_string()));
        assert_eq!(matrix_chain_parens(&[10, 20]), (0, "A1".to_string()));
    }

    #[test]
    fn test_matrix_chain_memo() {
        assert_eq!(matrix_chain_memo(&CHAIN6), 15125);
        assert_eq!(matrix_chain_memo(&CHAIN3), 18000);
        assert_eq!(matrix_chain_memo(&[10, 20]), 0);
    }

    #[test]
    fn test_approaches_agree() {
        for dims in [&CHAIN6[..], &CHAIN3[..], &[5, 4, 3, 2, 1][..]] {
            let expected = matrix_chain_dp(dims);
            assert_eq!(matrix_chain_memo(dims), expected);
            assert_eq!(matrix_chain_parens(dims).0, expected);
        }
    }
}
(* 1057: Matrix Chain Multiplication — Optimal Parenthesization *)

(* Approach 1: Bottom-up DP *)
(* [matrix_chain_dp dims] is the minimum number of scalar multiplications needed to
   multiply a chain of [Array.length dims - 1] matrices, where matrix i (0-based) has
   shape dims.(i) x dims.(i + 1). Chains with zero or one matrix cost 0.
   dp.(i).(j) holds the cheapest cost for matrices i..j; intervals are filled in order
   of increasing length so every sub-interval is final before it is read. *)
let matrix_chain_dp dims =
  let n = Array.length dims - 1 in
  (* Zero or one matrix: nothing to multiply. Also avoids Array.init with a negative
     size (empty dims) and the out-of-bounds read dp.(0).(-1) (one-element dims). *)
  if n < 2 then 0
  else begin
    let dp = Array.init n (fun _ -> Array.make n 0) in
    (* chain length l = 2..n *)
    for l = 2 to n do
      for i = 0 to n - l do
        let j = i + l - 1 in
        dp.(i).(j) <- max_int;
        for k = i to j - 1 do
          let cost = dp.(i).(k) + dp.(k + 1).(j)
                     + dims.(i) * dims.(k + 1) * dims.(j + 1) in
          if cost < dp.(i).(j) then
            dp.(i).(j) <- cost
        done
      done
    done;
    dp.(0).(n - 1)
  end

(* Approach 2: With parenthesization tracking *)
(* [matrix_chain_parens dims] returns (cost, parens): the minimum multiplication cost
   for the chain described by [dims] (matrix i is dims.(i) x dims.(i + 1)) and one
   optimal parenthesization, with matrices named A1..An, e.g. "((A1*A2)*A3)".
   On ties the leftmost split point is kept. A single matrix gives (0, "A1");
   an empty chain (Array.length dims < 2) gives (0, ""). *)
let matrix_chain_parens dims =
  let n = Array.length dims - 1 in
  (* No matrices: avoid Array.init with a negative size and the out-of-bounds
     split lookup that build 0 (-1) would perform. *)
  if n < 1 then (0, "")
  else begin
    let dp = Array.init n (fun _ -> Array.make n 0) in
    (* split.(i).(j) is the k at which the optimal product of matrices i..j is divided. *)
    let split = Array.init n (fun _ -> Array.make n 0) in
    for l = 2 to n do
      for i = 0 to n - l do
        let j = i + l - 1 in
        dp.(i).(j) <- max_int;
        for k = i to j - 1 do
          let cost = dp.(i).(k) + dp.(k + 1).(j)
                     + dims.(i) * dims.(k + 1) * dims.(j + 1) in
          (* Strict < keeps the leftmost split point on ties. *)
          if cost < dp.(i).(j) then begin
            dp.(i).(j) <- cost;
            split.(i).(j) <- k
          end
        done
      done
    done;
    let buf = Buffer.create 32 in
    (* Appends the parenthesization of matrices i..j to buf. *)
    let rec build i j =
      if i = j then
        Buffer.add_string buf (Printf.sprintf "A%d" (i + 1))
      else begin
        Buffer.add_char buf '(';
        build i split.(i).(j);
        Buffer.add_char buf '*';
        build (split.(i).(j) + 1) j;
        Buffer.add_char buf ')'
      end
    in
    build 0 (n - 1);
    (dp.(0).(n - 1), Buffer.contents buf)
  end

(* Approach 3: Recursive with memoization *)
(* [matrix_chain_memo dims] is the minimum multiplication cost of the chain described
   by [dims] (matrix i is dims.(i) x dims.(i + 1)), computed top-down with a Hashtbl
   cache keyed by the interval (i, j). Chains with zero or one matrix cost 0. *)
let matrix_chain_memo dims =
  let n = Array.length dims - 1 in
  let cache = Hashtbl.create 64 in
  (* Cheapest cost of multiplying matrices i..j; each solved interval is cached once. *)
  let rec solve i j =
    if i = j then 0
    else
      match Hashtbl.find_opt cache (i, j) with
      | Some v -> v
      | None ->
        let best = ref max_int in
        for k = i to j - 1 do
          let cost = solve i k + solve (k + 1) j
                     + dims.(i) * dims.(k + 1) * dims.(j + 1) in
          if cost < !best then best := cost
        done;
        Hashtbl.add cache (i, j) !best;
        !best
  in
  (* With no matrices, solve 0 (n - 1) would run an empty split loop and
     return the max_int sentinel instead of 0. *)
  if n < 1 then 0 else solve 0 (n - 1)

(* Self-check driver: runs all three approaches on two example chains, asserts the
   expected costs and parenthesization, then prints a success message. *)
let () =
  (* dims: A1=30x35, A2=35x15, A3=15x5, A4=5x10, A5=10x20, A6=20x25 *)
  let dims = [|30; 35; 15; 5; 10; 20; 25|] in
  assert (matrix_chain_dp dims = 15125);
  assert (matrix_chain_memo dims = 15125);
  let (cost, parens) = matrix_chain_parens dims in
  assert (cost = 15125);
  (* Pin the exact split choices rather than only checking the string is non-empty. *)
  assert (parens = "((A1*(A2*A3))*((A4*A5)*A6))");

  (* dims2: A1=10x20, A2=20x30, A3=30x40 *)
  let dims2 = [|10; 20; 30; 40|] in
  assert (matrix_chain_dp dims2 = 18000);
  assert (matrix_chain_memo dims2 = 18000);
  assert (matrix_chain_parens dims2 = (18000, "((A1*A2)*A3)"));

  Printf.printf "✓ All tests passed\n"

📊 Detailed Comparison

Matrix Chain Multiplication โ€” Comparison

Core Insight

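Matrix chain multiplication is the canonical interval DP problem. For each interval of matrices i..j, try every split point k with i ≤ k < j and keep the minimum total cost. With matrix i having shape p_i × p_{i+1} (p = `dims`):

$$dp[i][j] = \min_{i \le k < j}\bigl(dp[i][k] + dp[k+1][j] + p_i\,p_{k+1}\,p_{j+1}\bigr), \qquad dp[i][i] = 0.$$

A separate `split` table records the best k for each interval, which enables reconstructing the optimal parenthesization.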

OCaml Approach

  • `Buffer` for building parenthesization string recursively
  • `Printf.sprintf` for formatting matrix names
  • `max_int` as initial sentinel
  • `ref` cells for tracking best in inner loop

Rust Approach

  • `format!` macro for string building in recursive parenthesization
  • `usize::MAX` as sentinel
  • Nested function for recursive string building
  • `HashMap` with tuple keys for memoization

Comparison Table

| Aspect | OCaml | Rust |
|---|---|---|
| String building | `Buffer` + `Printf.sprintf` | `format!()` macro |
| Infinity sentinel | `max_int` | `usize::MAX` |
| 2D table init | `Array.init n (fun _ -> Array.make n 0)` | `vec![vec![0; n]; n]` |
| Split tracking | Parallel `split` array | Parallel `split` vec |
| Recursion | Natural OCaml recursion | Inner `fn` with explicit params |