🦀 Functional Rust

976: Matrix Multiply

Difficulty: Intermediate
Category: Linear Algebra / Algorithms
Concept: Naive O(n³) matrix multiplication, a cache-friendly transpose variant, and Strassen's 7-multiply 2x2 algorithm
Key Insight: OCaml's list-of-lists is elegant but slow (poor cache locality); both languages implement the same triple loop for arrays/Vecs — Rust's `Vec<Vec<f64>>` and OCaml's `float array array` have essentially the same memory layout (an outer array pointing to separately allocated rows of unboxed floats) and comparable performance.
// 976: Matrix Multiplication
// Naive O(n³) and Strassen 2x2 demo
// OCaml: list-of-lists (functional) + arrays; Rust: Vec<Vec<f64>>

// Approach 1: Vec<Vec<f64>> naive multiply

/// Multiplies two dense row-major matrices with the naive triple loop.
///
/// `a` is an `n × k` matrix and `b` is a `k × m` matrix, each given as a slice of rows;
/// the result is the `n × m` product `a · b`, computed in O(n·k·m) time.
///
/// Empty inputs are accepted: if `a` has no rows the result is empty, and if `b` has
/// no rows (which requires every row of `a` to be empty) the result is `n` empty rows.
///
/// # Panics
///
/// Panics with "dimension mismatch" if any row of `a` does not have exactly `b.len()`
/// entries, or if the rows of `b` are not all the same length.
pub fn mat_multiply(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n = a.len();
    let k = b.len();
    // Take the column count from `b`'s first row if there is one; indexing `b[0]`
    // directly would panic on an empty `b`.
    let m = b.first().map_or(0, Vec::len);
    // Validate every row, not just the first: a longer row of `a` would otherwise be
    // silently truncated, and a shorter one would panic mid-loop with an index error.
    assert!(a.iter().all(|row| row.len() == k), "dimension mismatch");
    assert!(b.iter().all(|row| row.len() == m), "dimension mismatch");

    let mut result = vec![vec![0.0f64; m]; n];
    for i in 0..n {
        for j in 0..m {
            for l in 0..k {
                result[i][j] += a[i][l] * b[l][j];
            }
        }
    }
    result
}

/// Returns the transpose of a row-major matrix: row `j` of the result is column `j`
/// of `m`.
///
/// The column count is taken from the first row, so an empty input (or one whose
/// first row is empty) yields an empty result. A later row shorter than the first
/// panics on indexing; extra entries in longer rows are ignored.
pub fn transpose(m: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let width = m.first().map_or(0, Vec::len);
    let mut columns = Vec::with_capacity(width);
    for col in 0..width {
        // Gather column `col` top to bottom; it becomes one row of the output.
        let mut column = Vec::with_capacity(m.len());
        for row in m {
            column.push(row[col]);
        }
        columns.push(column);
    }
    columns
}

// Approach 2: Dot-product style (cache-friendly via transpose)

/// Multiplies `a` (`n × k`) by `b` (`k × m`) as a grid of dot products.
///
/// `b` is transposed once up front, so every output element is the dot product of a
/// row of `a` with a row of `bᵀ`. Both operands are then read sequentially instead of
/// striding down the columns of `b`.
///
/// Empty inputs are accepted: if `a` has no rows the result is empty, and if `b` has
/// no rows (which requires every row of `a` to be empty) the result is `n` empty rows.
///
/// # Panics
///
/// Panics with "dimension mismatch" if any row of `a` does not have exactly `b.len()`
/// entries, or if the rows of `b` are not all the same length.
pub fn mat_multiply_transposed(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n = a.len();
    let k = b.len();
    // `b[0]` would panic on an empty `b`; treat that case as zero columns instead.
    let m = b.first().map_or(0, Vec::len);
    // Without these checks the `zip` below silently truncates mismatched rows and
    // returns a wrong product instead of failing.
    assert!(a.iter().all(|row| row.len() == k), "dimension mismatch");
    assert!(b.iter().all(|row| row.len() == m), "dimension mismatch");
    let bt = transpose(b);

    let mut result = vec![vec![0.0f64; m]; n];
    for i in 0..n {
        for j in 0..m {
            result[i][j] = a[i].iter().zip(&bt[j]).map(|(x, y)| x * y).sum();
        }
    }
    result
}

// Approach 3: Strassen 2x2 (demonstrates the 7-multiply algorithm)

/// Multiplies two 2×2 matrices with Strassen's scheme: 7 scalar multiplications
/// instead of the usual 8.
///
/// Only the base step is shown here; a full Strassen implementation applies the same
/// formulas recursively to `n/2 × n/2` blocks.
pub fn strassen_2x2(a: &[[f64; 2]; 2], b: &[[f64; 2]; 2]) -> [[f64; 2]; 2] {
    let [[a11, a12], [a21, a22]] = *a;
    let [[b11, b12], [b21, b22]] = *b;

    // The seven Strassen products.
    let p1 = (a11 + a22) * (b11 + b22);
    let p2 = (a21 + a22) * b11;
    let p3 = a11 * (b12 - b22);
    let p4 = a22 * (b21 - b11);
    let p5 = (a11 + a12) * b22;
    let p6 = (a21 - a11) * (b11 + b12);
    let p7 = (a12 - a22) * (b21 + b22);

    // Recombine the products into the four output entries (evaluated left to right).
    let c11 = p1 + p4 - p5 + p7;
    let c12 = p3 + p5;
    let c21 = p2 + p4;
    let c22 = p1 - p2 + p3 + p6;

    [[c11, c12], [c21, c22]]
}

/// Demo driver: multiplies the same pair of 2x2 matrices with each approach and
/// prints the results.
fn main() {
    // Vec-of-Vec operands for the two general-purpose multipliers.
    let lhs = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let rhs = vec![vec![5.0, 6.0], vec![7.0, 8.0]];

    let naive = mat_multiply(&lhs, &rhs);
    println!("naive: {:?}", naive);

    let via_transpose = mat_multiply_transposed(&lhs, &rhs);
    println!("transposed: {:?}", via_transpose);

    // The same operands as fixed-size arrays for the 2x2 Strassen demo.
    let lhs_fixed = [[1.0, 2.0], [3.0, 4.0]];
    let rhs_fixed = [[5.0, 6.0], [7.0, 8.0]];
    println!("strassen: {:?}", strassen_2x2(&lhs_fixed, &rhs_fixed));
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_2x2_multiply() {
        let lhs = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let rhs = vec![vec![5.0, 6.0], vec![7.0, 8.0]];
        let expected = vec![vec![19.0, 22.0], vec![43.0, 50.0]];
        assert_eq!(mat_multiply(&lhs, &rhs), expected);
    }

    #[test]
    fn test_non_square() {
        // (2 x 3) * (3 x 2) -> 2 x 2
        let lhs = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let rhs = vec![vec![7.0, 8.0], vec![9.0, 10.0], vec![11.0, 12.0]];
        let expected = vec![vec![58.0, 64.0], vec![139.0, 154.0]];
        assert_eq!(mat_multiply(&lhs, &rhs), expected);
    }

    #[test]
    fn test_transposed_matches_naive() {
        let lhs = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let rhs = vec![vec![7.0, 8.0], vec![9.0, 10.0], vec![11.0, 12.0]];
        assert_eq!(mat_multiply_transposed(&lhs, &rhs), mat_multiply(&lhs, &rhs));
    }

    #[test]
    fn test_strassen_2x2() {
        let product = strassen_2x2(&[[1.0, 2.0], [3.0, 4.0]], &[[5.0, 6.0], [7.0, 8.0]]);
        assert_eq!(product, [[19.0, 22.0], [43.0, 50.0]]);
    }

    #[test]
    fn test_identity() {
        let a = vec![vec![3.0, 4.0], vec![5.0, 6.0]];
        let eye = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        assert_eq!(mat_multiply(&a, &eye), a);
    }
}
(* 976: Matrix Multiplication *)
(* Naive O(n^3) multiplication. Note: Strassen is O(n^2.807) *)
(* OCaml: list of lists (functional style) and array of arrays *)

(* Approach 1: List-of-lists (functional, educational) *)

(* Number of rows in a list-of-lists matrix. *)
let mat_rows = List.length
(* Number of columns, read from the first row; an empty matrix has 0 columns. *)
let mat_cols = function
  | [] -> 0
  | first_row :: _ -> List.length first_row

(* Sum of pairwise products of two float lists, accumulated left to right.
   Raises [Invalid_argument] (from [List.fold_left2]) if the lengths differ. *)
let dot_product xs ys =
  let accumulate total x y = total +. x *. y in
  List.fold_left2 accumulate 0.0 xs ys

(* Transpose a list-of-lists matrix: row [c] of the result is column [c] of [m].
   Each step peels the head off every row, so the whole matrix is walked once,
   O(rows * cols). Using [List.nth row c] per column instead would re-walk each row
   for every column, O(rows * cols^2).
   The column count comes from the first row: an empty matrix, or one whose first
   row is empty, transposes to []. Extra trailing entries in longer later rows are
   ignored, and a later row that is shorter than the first raises [Failure]. *)
let rec transpose_list m =
  match m with
  | [] | [] :: _ -> []
  | _ -> List.map List.hd m :: transpose_list (List.map List.tl m)

(* Multiply two list-of-lists matrices: entry (i, j) of the product is the dot
   product of row i of [a] with column j of [b]. The columns are obtained by
   transposing [b] once up front. *)
let multiply_lists a b =
  let columns_of_b = transpose_list b in
  let row_times_b row = List.map (dot_product row) columns_of_b in
  List.map row_times_b a

(* Approach 2: Array-of-arrays (imperative, practical) *)

(* Naive O(n*k*m) product of an n x k matrix [a] by a k x m matrix [b], both stored
   as arrays of row arrays. Each entry is accumulated in a local ref, then stored.
   Empty inputs are accepted: an empty [a] gives an empty result, and an empty [b]
   (which requires every row of [a] to be empty) gives n empty rows.
   Raises [Invalid_argument "dimension mismatch"] if some row of [a] does not have
   exactly [Array.length b] entries or the rows of [b] differ in length; without the
   check, extra columns of [a] would be silently ignored. *)
let mat_multiply a b =
  let n = Array.length a in
  let k = Array.length b in
  (* [b.(0)] would raise on an empty [b]; treat that case as zero columns. *)
  let m = if k = 0 then 0 else Array.length b.(0) in
  if Array.exists (fun row -> Array.length row <> k) a
     || Array.exists (fun row -> Array.length row <> m) b
  then invalid_arg "dimension mismatch";
  let result = Array.make_matrix n m 0.0 in
  for i = 0 to n - 1 do
    for j = 0 to m - 1 do
      let s = ref 0.0 in
      for l = 0 to k - 1 do
        s := !s +. a.(i).(l) *. b.(l).(j)
      done;
      result.(i).(j) <- !s
    done
  done;
  result

(* Strassen's algorithm multiplies 2x2 (block) matrices with 7 multiplications
   instead of 8; applied recursively to n/2 x n/2 blocks it runs in O(n^2.807).
   A full implementation also has to pad the input to a power-of-two size.
   This function performs only the 2x2 base step on scalars: it reads entries
   (0,0), (0,1), (1,0), (1,1) of [a] and [b] and returns a fresh 2x2 array. *)
let strassen_2x2 a b =
  let a_top = a.(0) and a_bot = a.(1) in
  let b_top = b.(0) and b_bot = b.(1) in
  let a11 = a_top.(0) and a12 = a_top.(1) in
  let a21 = a_bot.(0) and a22 = a_bot.(1) in
  let b11 = b_top.(0) and b12 = b_top.(1) in
  let b21 = b_bot.(0) and b22 = b_bot.(1) in
  (* The seven Strassen products. *)
  let p1 = (a11 +. a22) *. (b11 +. b22)
  and p2 = (a21 +. a22) *. b11
  and p3 = a11 *. (b12 -. b22)
  and p4 = a22 *. (b21 -. b11)
  and p5 = (a11 +. a12) *. b22
  and p6 = (a21 -. a11) *. (b11 +. b12)
  and p7 = (a12 -. a22) *. (b21 +. b22) in
  (* Recombine them into the four output entries (evaluated left to right). *)
  let c11 = p1 +. p4 -. p5 +. p7
  and c12 = p3 +. p5
  and c21 = p2 +. p4
  and c22 = p1 -. p2 +. p3 +. p6 in
  [| [| c11; c12 |]; [| c21; c22 |] |]

(* Self-check: runs each multiplier on small matrices with known products. The
   first failing [assert] aborts the program; otherwise a success line is printed. *)
let () =
  (* List multiplication *)
  let a = [[1.0; 2.0]; [3.0; 4.0]] in
  let b = [[5.0; 6.0]; [7.0; 8.0]] in
  assert (multiply_lists a b = [[19.0; 22.0]; [43.0; 50.0]]);

  (* Array multiplication *)
  let a_arr = [| [| 1.0; 2.0 |]; [| 3.0; 4.0 |] |] in
  let b_arr = [| [| 5.0; 6.0 |]; [| 7.0; 8.0 |] |] in
  let expected_2x2 = [| [| 19.0; 22.0 |]; [| 43.0; 50.0 |] |] in
  assert (mat_multiply a_arr b_arr = expected_2x2);

  (* Strassen 2x2 must agree with the naive product *)
  assert (strassen_2x2 a_arr b_arr = expected_2x2);

  (* Non-square: 2x3 * 3x2 *)
  let m23 = [| [| 1.0; 2.0; 3.0 |]; [| 4.0; 5.0; 6.0 |] |] in
  let m32 = [| [| 7.0; 8.0 |]; [| 9.0; 10.0 |]; [| 11.0; 12.0 |] |] in
  assert (mat_multiply m23 m32 = [| [| 58.0; 64.0 |]; [| 139.0; 154.0 |] |]);

  (* The check mark was stored mis-encoded ("โœ“"); print the intended U+2713. *)
  Printf.printf "✓ All tests passed\n"

📊 Detailed Comparison

Matrix Multiply — Comparison

Core Insight

Matrix multiplication is O(n³) with the naive algorithm and O(n^2.807) with Strassen. The triple loop `for i, j, k: C[i][j] += A[i][k] * B[k][j]` is identical in both languages. OCaml's functional list-of-lists approach is readable but slow; the array approach matches Rust's performance. Transposing B before multiplying improves cache locality, because each column of B is then read as a contiguous row.

OCaml Approach

  • `[[1.0;2.0];[3.0;4.0]]` — list of float lists (functional, poor cache locality)
  • `List.init cols (fun c -> List.map (fun row -> List.nth row c) m)` — transpose
  • `List.fold_left2` for dot product
  • `Array.make_matrix n m 0.0` for imperative approach
  • `for i = 0 to n-1 do ... done` — triple nested imperative loop
  • `!s +. a.(i).(l) *. b.(l).(j)` — float arithmetic uses dot-suffixed operators (`+.`, `*.`)

Rust Approach

  • `Vec<Vec<f64>>` — row-major; each row is a separate heap allocation rather than one contiguous 2D block
  • `vec![vec![0.0f64; m]; n]` — initialize result matrix
  • `result[i][j] += a[i][l] * b[l][j]` — clean triple loop
  • Transpose via double loop (same algorithm, no magic)
  • `.iter().zip(&bt[j]).map(|(x,y)| x*y).sum()` — functional dot product
  • `[[f64; 2]; 2]` for fixed-size Strassen (stored inline, no heap allocation)

Comparison Table

| Aspect | OCaml | Rust |
|---|---|---|
| Functional matrix | `float list list` | `Vec<Vec<f64>>` |
| Init result | `Array.make_matrix n m 0.0` | `vec![vec![0.0; m]; n]` |
| Element access | `a.(i).(l)` | `a[i][l]` |
| Float arithmetic | `+.`, `*.` (explicit float operators) | `+`, `*` (same operators as for integers) |
| Dot product | `List.fold_left2` | `.zip().map().sum()` |
| Transpose | `List.init cols (fun c -> List.map ...)` | Double loop |
| Fixed 2x2 | `float array array` (2 x 2 array literal) | `[[f64; 2]; 2]` (stack-allocated) |
| Strassen | 7 muls, same formula | 7 muls, same formula |