🦀 Functional Rust

975: Sparse Matrix

Difficulty: Intermediate
Category: Data Structures / Linear Algebra
Concept: Matrix storing only non-zero elements via a hash map from (row, col) → value
Key Insight: The OCaml version builds a dedicated table with the `Hashtbl.Make` functor, supplying `equal` and `hash` for `(int * int)` keys (the polymorphic `Hashtbl` could also hash tuples structurally); Rust's `HashMap<(usize, usize), f64>` works directly because tuples implement `Hash` and `Eq` automatically when their components do — one fewer module boundary to cross
// 975: Sparse Matrix
// Only store non-zero elements using HashMap<(usize,usize), f64>
// OCaml uses custom Hashtbl.Make; Rust uses std HashMap with tuple keys

use std::collections::HashMap;

/// A sparse `rows x cols` matrix of `f64` values that stores only its
/// non-zero entries.
///
/// Entries live in a hash map keyed by `(row, col)`; any coordinate without
/// a stored entry reads as `0.0`.
#[derive(Debug, Clone, PartialEq)]
pub struct SparseMatrix {
    rows: usize,
    cols: usize,
    data: HashMap<(usize, usize), f64>,
}

impl SparseMatrix {
    /// Creates an all-zero matrix with the given dimensions.
    pub fn new(rows: usize, cols: usize) -> Self {
        SparseMatrix {
            rows,
            cols,
            data: HashMap::new(),
        }
    }

    /// Writes `v` at `(r, c)`.
    ///
    /// Writing `0.0` removes the stored entry so that [`nnz`](Self::nnz)
    /// keeps counting only non-zero elements.
    ///
    /// # Panics
    ///
    /// Panics with `"index out of bounds"` if `r >= rows` or `c >= cols`.
    pub fn set(&mut self, r: usize, c: usize, v: f64) {
        assert!(r < self.rows && c < self.cols, "index out of bounds");
        if v == 0.0 {
            self.data.remove(&(r, c));
        } else {
            self.data.insert((r, c), v);
        }
    }

    /// Returns the value at `(r, c)`, or `0.0` when no entry is stored there.
    ///
    /// Unlike [`set`](Self::set), this performs no bounds check: coordinates
    /// outside the matrix simply read as `0.0`.
    pub fn get(&self, r: usize, c: usize) -> f64 {
        self.data.get(&(r, c)).copied().unwrap_or(0.0)
    }

    /// Number of stored (non-zero) elements.
    pub fn nnz(&self) -> usize {
        self.data.len()
    }

    /// Number of rows.
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Number of columns.
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Matrix-vector multiply: `result[i] = sum_j mat[i, j] * v[j]`.
    ///
    /// Only stored entries are visited, so the work is proportional to
    /// [`nnz`](Self::nnz) rather than `rows * cols`.
    ///
    /// # Panics
    ///
    /// Panics with `"vector length mismatch"` if `v.len() != cols`.
    pub fn matvec(&self, v: &[f64]) -> Vec<f64> {
        assert_eq!(v.len(), self.cols, "vector length mismatch");
        let mut result = vec![0.0f64; self.rows];
        for (&(r, c), &val) in &self.data {
            result[r] += val * v[c];
        }
        result
    }

    /// Transpose: returns a new matrix with rows and columns swapped.
    pub fn transpose(&self) -> SparseMatrix {
        SparseMatrix {
            rows: self.cols,
            cols: self.rows,
            // Collecting lets the map size itself once from the iterator.
            data: self
                .data
                .iter()
                .map(|(&(r, c), &v)| ((c, r), v))
                .collect(),
        }
    }

    /// Element-wise addition: returns a new matrix, leaving both operands
    /// untouched. Entries whose sum is exactly `0.0` are not stored.
    ///
    /// # Panics
    ///
    /// Panics if the two matrices do not have identical dimensions.
    pub fn add(&self, other: &SparseMatrix) -> SparseMatrix {
        use std::collections::hash_map::Entry;

        assert_eq!(self.rows, other.rows);
        assert_eq!(self.cols, other.cols);

        // Start from a copy of `self`, then fold `other` in with one map
        // lookup per entry.
        let mut result = self.clone();
        for (&key, &v) in &other.data {
            match result.data.entry(key) {
                Entry::Occupied(mut slot) => {
                    let sum = *slot.get() + v;
                    if sum == 0.0 {
                        // Cancellation: drop the entry to stay sparse.
                        slot.remove();
                    } else {
                        *slot.get_mut() = sum;
                    }
                }
                Entry::Vacant(slot) => {
                    if v != 0.0 {
                        slot.insert(v);
                    }
                }
            }
        }
        result
    }

    /// Returns all stored entries as `((row, col), value)` pairs, sorted by
    /// `(row, col)` so the output is deterministic regardless of hash order.
    pub fn entries(&self) -> Vec<((usize, usize), f64)> {
        let mut v: Vec<_> = self.data.iter().map(|(&k, &v)| (k, v)).collect();
        // Keys are unique, so an unstable sort yields the same order.
        v.sort_unstable_by_key(|&(k, _)| k);
        v
    }
}


#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 4x4 fixture shared by the tests (six non-zero entries).
    fn make_matrix() -> SparseMatrix {
        let mut m = SparseMatrix::new(4, 4);
        m.set(0, 0, 1.0);
        m.set(0, 2, 2.0);
        m.set(1, 1, 3.0);
        m.set(2, 0, 4.0);
        m.set(2, 3, 5.0);
        m.set(3, 3, 6.0);
        m
    }

    #[test]
    fn test_get_set() {
        let m = make_matrix();
        assert_eq!(m.nnz(), 6);
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(0, 1), 0.0); // zero element
        assert_eq!(m.get(1, 1), 3.0);
    }

    #[test]
    fn test_set_zero_removes() {
        let mut m = make_matrix();
        m.set(1, 1, 0.0);
        assert_eq!(m.nnz(), 5);
        assert_eq!(m.get(1, 1), 0.0);
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn test_set_out_of_bounds_panics() {
        let mut m = make_matrix();
        m.set(4, 0, 1.0); // row 4 does not exist in a 4x4 matrix
    }

    #[test]
    fn test_dimensions() {
        let m = SparseMatrix::new(3, 5);
        assert_eq!(m.rows(), 3);
        assert_eq!(m.cols(), 5);
        assert_eq!(m.nnz(), 0);

        let t = m.transpose();
        assert_eq!(t.rows(), 5);
        assert_eq!(t.cols(), 3);
    }

    #[test]
    fn test_matvec() {
        let mut m = make_matrix();
        m.set(1, 1, 0.0); // remove entry
        let v = vec![1.0, 0.0, 1.0, 0.0];
        let result = m.matvec(&v);
        assert_eq!(result.len(), 4);
        assert_eq!(result[0], 3.0); // 1*1 + 2*1
        assert_eq!(result[1], 0.0);
        assert_eq!(result[2], 4.0); // 4*1 + 5*0
        assert_eq!(result[3], 0.0); // 6*0
    }

    #[test]
    fn test_transpose() {
        let m = make_matrix();
        let mt = m.transpose();
        assert_eq!(mt.get(0, 0), 1.0);
        assert_eq!(mt.get(2, 0), 2.0);
        assert_eq!(mt.get(0, 2), 4.0);
        assert_eq!(mt.get(3, 2), 5.0);
        assert_eq!(mt.get(3, 3), 6.0);
        assert_eq!(mt.nnz(), 6);
    }

    #[test]
    fn test_add() {
        let m1 = make_matrix();
        let mut m2 = SparseMatrix::new(4, 4);
        m2.set(0, 0, 1.0);
        m2.set(1, 1, -3.0); // cancels out

        let sum = m1.add(&m2);
        assert_eq!(sum.get(0, 0), 2.0); // 1+1
        assert_eq!(sum.get(1, 1), 0.0); // 3+(-3)=0, removed
        // m1 had 6 entries, m2 adds (0,0) which merges, (1,1) cancels → 5 non-zero
        assert_eq!(sum.nnz(), 5);
        // The operands themselves are left untouched.
        assert_eq!(m1.nnz(), 6);
        assert_eq!(m2.nnz(), 2);
    }

    #[test]
    fn test_entries_sorted() {
        let m = make_matrix();
        assert_eq!(
            m.entries(),
            vec![
                ((0, 0), 1.0),
                ((0, 2), 2.0),
                ((1, 1), 3.0),
                ((2, 0), 4.0),
                ((2, 3), 5.0),
                ((3, 3), 6.0),
            ]
        );
    }
}
(* 975: Sparse Matrix *)
(* Only store non-zero elements using a hash map *)

(* Hashed-key module for (row, col) coordinates, in the shape expected by
   Hashtbl.Make: a key type plus matching equality and hash functions. *)
module IntPair = struct
  type t = int * int

  (* Two coordinates are equal when both components are equal. *)
  let equal a b = fst a = fst b && snd a = snd b

  (* Hash the pair with the generic structural hash. *)
  let hash ((_, _) as key) = Hashtbl.hash key
end

(* Hash table specialised to (row, col) keys, built by applying the
   Hashtbl.Make functor to IntPair. *)
module PairHash = Hashtbl.Make(IntPair)

(* Sparse matrix: logical dimensions plus a table of the stored entries.
   Coordinates with no binding in [data] read as 0.0 (see [get]). *)
type sparse_matrix = {
  rows: int;               (* number of rows *)
  cols: int;               (* number of columns *)
  data: float PairHash.t;  (* (row, col) -> stored value *)
}

(* Build an empty rows-by-cols matrix; the backing table starts with an
   initial size of 16. *)
let create rows cols =
  let data = PairHash.create 16 in
  { rows; cols; data }

(* Store [v] at (r, c). Writing 0.0 removes the binding so the table only
   ever holds non-zero values. Raises Failure "index out of bounds" when the
   coordinate lies outside the matrix. *)
let set m r c v =
  let in_bounds = r >= 0 && r < m.rows && c >= 0 && c < m.cols in
  if not in_bounds then failwith "index out of bounds";
  let key = (r, c) in
  if v <> 0.0 then PairHash.replace m.data key v
  else PairHash.remove m.data key

(* Read the value at (r, c); coordinates without a stored binding yield 0.0.
   No bounds check is performed. *)
let get m r c =
  match PairHash.find_opt m.data (r, c) with
  | Some value -> value
  | None -> 0.0

(* Number of stored (non-zero) entries. *)
let nnz m =
  let table = m.data in
  PairHash.length table

(* Matrix-vector product: out[i] = sum over j of m[i,j] * v[j].
   Only stored entries are visited. Asserts that [v] has m.cols elements. *)
let matvec m v =
  assert (Array.length v = m.cols);
  let out = Array.make m.rows 0.0 in
  let accumulate (row, col) x =
    out.(row) <- out.(row) +. x *. v.(col)
  in
  PairHash.iter accumulate m.data;
  out

(* Transpose: a fresh matrix with swapped dimensions in which every stored
   entry (r, c) appears at (c, r). The input is not modified. *)
let transpose m =
  let flipped = create m.cols m.rows in
  let put (r, c) x = PairHash.replace flipped.data (c, r) x in
  PairHash.iter put m.data;
  flipped

(* Element-wise sum of two matrices of identical dimensions (asserted).
   Returns a new matrix; entries whose sum is exactly 0.0 are dropped so the
   result stays sparse. *)
let add m1 m2 =
  assert (m1.rows = m2.rows && m1.cols = m2.cols);
  let out = create m1.rows m1.cols in
  (* Seed the result with every entry of m1. *)
  PairHash.iter (fun key x -> PairHash.replace out.data key x) m1.data;
  (* Fold each entry of m2 into what is already there. *)
  let merge key x =
    let previous =
      match PairHash.find_opt out.data key with
      | Some p -> p
      | None -> 0.0
    in
    let total = previous +. x in
    if total <> 0.0 then PairHash.replace out.data key total
    else PairHash.remove out.data key
  in
  PairHash.iter merge m2.data;
  out

(* Self-test driver: builds a 4x4 fixture and checks get/set, zero removal,
   matvec, transpose and add with plain assertions, then prints a success
   line. *)
let () =
  let m = create 4 4 in
  set m 0 0 1.0;
  set m 0 2 2.0;
  set m 1 1 3.0;
  set m 2 0 4.0;
  set m 2 3 5.0;
  set m 3 3 6.0;

  assert (nnz m = 6);
  assert (get m 0 0 = 1.0);
  assert (get m 0 1 = 0.0);  (* zero element *)
  assert (get m 1 1 = 3.0);

  (* Setting to zero removes entry *)
  set m 1 1 0.0;
  assert (nnz m = 5);
  assert (get m 1 1 = 0.0);

  (* Matrix-vector multiply *)
  let v = [| 1.0; 0.0; 1.0; 0.0 |] in
  let result = matvec m v in
  assert (result.(0) = 3.0);  (* 1*1 + 2*1 *)
  assert (result.(1) = 0.0);  (* 3 was removed *)
  assert (result.(2) = 4.0);  (* 4*1 *)

  (* Transpose *)
  let mt = transpose m in
  assert (get mt 0 0 = 1.0);
  assert (get mt 2 0 = 2.0);
  assert (get mt 0 2 = 4.0);
  assert (get mt 3 2 = 5.0);
  assert (get mt 3 3 = 6.0);
  assert (nnz mt = 5);

  (* Addition: overlapping entries are summed, cancelling entries dropped *)
  let m2 = create 4 4 in
  set m2 0 0 1.0;
  set m2 2 0 (-4.0);  (* cancels the 4.0 stored at (2, 0) *)
  let sum = add m m2 in
  assert (get sum 0 0 = 2.0);  (* 1 + 1 *)
  assert (get sum 2 0 = 0.0);  (* 4 + (-4) = 0, removed *)
  assert (get sum 0 2 = 2.0);  (* only present in m *)
  assert (nnz sum = 4);
  assert (nnz m = 5);          (* operands are left untouched *)

  Printf.printf "✓ All tests passed\n"

📊 Detailed Comparison

Sparse Matrix — Comparison

Core Insight

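A sparse matrix stores only non-zero entries, saving memory when most values are 0. Both languages use a hash map from `(row, col)` pairs to floats. The OCaml version builds a dedicated hashtable module with `Hashtbl.Make`, supplying explicit `equal` and `hash` functions for the `(int * int)` key type (the polymorphic `Hashtbl` can also hash tuples structurally, but the functor yields a table specialised to the key type). Rust's `HashMap<(usize, usize), f64>` works out of the box — tuples implement `Hash` and `Eq` automatically when their components do.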

OCaml Approach

  • `module IntPair = struct type t = int * int; let equal ...; let hash ... end`
  • `module PairHash = Hashtbl.Make(IntPair)` — functor application for typed hashtable
  • `PairHash.find_opt m.data (r,c) |> Option.value ~default:0.0`
  • `PairHash.remove` when setting to 0 (keep sparsity invariant)
  • `PairHash.iter` for matvec and transpose iteration
  • Floats compared with `= 0.0` (works for exact zero)

Rust Approach

  • `HashMap<(usize, usize), f64>` — tuple key, hash derived automatically
  • `.unwrap_or(&0.0)` for zero default
  • `.remove(&(r, c))` when setting to 0.0
  • `for (&(r, c), &val) in &self.data` โ€” destructuring in for loop
  • `.entry((r,c)).or_insert(0.0)` for accumulate-or-init pattern
  • Same float-zero comparison: `v == 0.0`

Comparison Table

| Aspect | OCaml | Rust |
| --- | --- | --- |
| Tuple key hash | `Hashtbl.Make(IntPair)` functor | `HashMap<(usize,usize), f64>` (auto-Hash) |
| Default zero | `Option.value ~default:0.0` | `.unwrap_or(&0.0)` |
| Remove zero | `PairHash.remove m.data key` | `data.remove(&key)` |
| Iteration | `PairHash.iter (fun (r,c) v -> ...)` | `for (&(r,c), &v) in &data` |
| Accumulate | `existing +. v; replace` | `.entry(k).or_insert(0.0)` then `*e += v` |
| nnz | `PairHash.length` | `data.len()` |
| Index check | `failwith` | `assert!` |