// 975: Sparse Matrix
// Only store non-zero elements using HashMap<(usize,usize), f64>
// OCaml uses custom Hashtbl.Make; Rust uses std HashMap with tuple keys
use std::collections::HashMap;
/// A `rows` x `cols` matrix of `f64` that stores only its non-zero elements.
///
/// Entries live in a `HashMap` keyed by `(row, col)`. [`SparseMatrix::set`] and
/// [`SparseMatrix::add`] never leave an entry equal to `0.0` in the map, so
/// [`SparseMatrix::nnz`] is exactly the number of stored entries.
#[derive(Debug, Clone, PartialEq)]
pub struct SparseMatrix {
    rows: usize,
    cols: usize,
    data: HashMap<(usize, usize), f64>,
}

impl SparseMatrix {
    /// Creates an empty `rows` x `cols` matrix; every element reads as `0.0`.
    #[must_use]
    pub fn new(rows: usize, cols: usize) -> Self {
        SparseMatrix {
            rows,
            cols,
            data: HashMap::new(),
        }
    }

    /// Sets element `(r, c)` to `v`. Writing `0.0` removes any stored entry.
    ///
    /// # Panics
    /// Panics with `"index out of bounds"` if `r >= rows` or `c >= cols`.
    pub fn set(&mut self, r: usize, c: usize, v: f64) {
        assert!(r < self.rows && c < self.cols, "index out of bounds");
        if v == 0.0 {
            self.data.remove(&(r, c));
        } else {
            self.data.insert((r, c), v);
        }
    }

    /// Returns element `(r, c)`, or `0.0` when no entry is stored.
    ///
    /// Unlike [`SparseMatrix::set`], this performs no bounds check: an
    /// out-of-range position simply has no entry and reads as `0.0`.
    #[must_use]
    pub fn get(&self, r: usize, c: usize) -> f64 {
        self.data.get(&(r, c)).copied().unwrap_or(0.0)
    }

    /// Number of stored (non-zero) elements.
    #[must_use]
    pub fn nnz(&self) -> usize {
        self.data.len()
    }

    /// Number of rows.
    #[must_use]
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Number of columns.
    #[must_use]
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Matrix-vector multiply: `result[i] = sum_j self[i, j] * v[j]`.
    ///
    /// Only stored entries are visited.
    ///
    /// # Panics
    /// Panics with `"vector length mismatch"` if `v.len() != cols`.
    #[must_use]
    pub fn matvec(&self, v: &[f64]) -> Vec<f64> {
        assert_eq!(v.len(), self.cols, "vector length mismatch");
        let mut result = vec![0.0f64; self.rows];
        for (&(r, c), &val) in &self.data {
            result[r] += val * v[c];
        }
        result
    }

    /// Returns a new `cols` x `rows` matrix with every entry `(r, c)` moved to `(c, r)`.
    #[must_use]
    pub fn transpose(&self) -> SparseMatrix {
        SparseMatrix {
            rows: self.cols,
            cols: self.rows,
            data: self.data.iter().map(|(&(r, c), &v)| ((c, r), v)).collect(),
        }
    }

    /// Element-wise sum `self + other`, returned as a new matrix.
    ///
    /// Positions whose sum is exactly `0.0` are dropped to keep the result sparse.
    ///
    /// # Panics
    /// Panics if the two matrices do not have the same dimensions.
    #[must_use]
    pub fn add(&self, other: &SparseMatrix) -> SparseMatrix {
        use std::collections::hash_map::Entry;

        assert_eq!(self.rows, other.rows, "row count mismatch");
        assert_eq!(self.cols, other.cols, "column count mismatch");

        // One bulk copy of `self`, instead of re-inserting entry by entry.
        let mut data = self.data.clone();
        for (&key, &v) in &other.data {
            match data.entry(key) {
                Entry::Occupied(mut slot) => {
                    let sum = *slot.get() + v;
                    if sum == 0.0 {
                        // Exact cancellation: drop the entry to preserve sparsity.
                        slot.remove();
                    } else {
                        *slot.get_mut() = sum;
                    }
                }
                Entry::Vacant(slot) => {
                    // Never store an explicit zero.
                    if v != 0.0 {
                        slot.insert(v);
                    }
                }
            }
        }

        SparseMatrix {
            rows: self.rows,
            cols: self.cols,
            data,
        }
    }

    /// Returns all stored entries as `((row, col), value)`, sorted by position.
    ///
    /// The sort makes the output deterministic despite `HashMap` iteration order.
    #[must_use]
    pub fn entries(&self) -> Vec<((usize, usize), f64)> {
        let mut out: Vec<_> = self.data.iter().map(|(&k, &v)| (k, v)).collect();
        // Keys are unique, so an unstable sort gives the same order as a stable one.
        out.sort_unstable_by_key(|&(k, _)| k);
        out
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shared 4x4 fixture with six non-zero entries:
    //
    //     [1 0 2 0]
    //     [0 3 0 0]
    //     [4 0 0 5]
    //     [0 0 0 6]
    fn make_matrix() -> SparseMatrix {
        let mut m = SparseMatrix::new(4, 4);
        m.set(0, 0, 1.0);
        m.set(0, 2, 2.0);
        m.set(1, 1, 3.0);
        m.set(2, 0, 4.0);
        m.set(2, 3, 5.0);
        m.set(3, 3, 6.0);
        m
    }

    #[test]
    fn test_get_set() {
        let m = make_matrix();
        assert_eq!(m.nnz(), 6);
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(0, 1), 0.0); // zero element
        assert_eq!(m.get(1, 1), 3.0);
    }

    #[test]
    fn test_set_zero_removes() {
        let mut m = make_matrix();
        m.set(1, 1, 0.0);
        assert_eq!(m.nnz(), 5);
        assert_eq!(m.get(1, 1), 0.0);
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn test_set_out_of_bounds() {
        let mut m = SparseMatrix::new(2, 2);
        m.set(0, 2, 1.0);
    }

    #[test]
    fn test_matvec() {
        let mut m = make_matrix();
        m.set(1, 1, 0.0); // remove entry
        let v = vec![1.0, 0.0, 1.0, 0.0];
        let result = m.matvec(&v);
        assert_eq!(result.len(), 4);
        assert_eq!(result[0], 3.0); // 1*1 + 2*1
        assert_eq!(result[1], 0.0);
        assert_eq!(result[2], 4.0); // 4*1
        assert_eq!(result[3], 0.0); // 6*0
    }

    #[test]
    fn test_transpose() {
        let m = make_matrix();
        let mt = m.transpose();
        assert_eq!(mt.get(0, 0), 1.0);
        assert_eq!(mt.get(2, 0), 2.0);
        assert_eq!(mt.get(0, 2), 4.0);
        assert_eq!(mt.get(3, 2), 5.0);
        assert_eq!(mt.get(3, 3), 6.0);
        assert_eq!(mt.nnz(), 6);
    }

    #[test]
    fn test_transpose_non_square() {
        let mut m = SparseMatrix::new(2, 3);
        m.set(0, 2, 7.0);
        let mt = m.transpose();
        assert_eq!((mt.rows(), mt.cols()), (3, 2));
        assert_eq!(mt.get(2, 0), 7.0);
        assert_eq!(mt.nnz(), 1);
    }

    #[test]
    fn test_add() {
        let m1 = make_matrix();
        let mut m2 = SparseMatrix::new(4, 4);
        m2.set(0, 0, 1.0);
        m2.set(1, 1, -3.0); // cancels out
        let sum = m1.add(&m2);
        assert_eq!(sum.get(0, 0), 2.0); // 1+1
        assert_eq!(sum.get(1, 1), 0.0); // 3+(-3)=0, removed
        // m1 has 6 entries; (0,0) merges to 2.0 and (1,1) cancels -> 5 non-zero
        assert_eq!(sum.nnz(), 5);
    }

    #[test]
    #[should_panic]
    fn test_add_dimension_mismatch() {
        let a = SparseMatrix::new(2, 2);
        let b = SparseMatrix::new(2, 3);
        let _ = a.add(&b);
    }

    #[test]
    fn test_entries_sorted() {
        let m = make_matrix();
        let keys: Vec<_> = m.entries().into_iter().map(|(k, _)| k).collect();
        assert_eq!(keys, vec![(0, 0), (0, 2), (1, 1), (2, 0), (2, 3), (3, 3)]);
    }
}
(* 975: Sparse Matrix *)
(* Only store non-zero elements using a hash map *)
(* Hash-table key: a (row, col) position. Two keys are equal when both
   coordinates match; the hash is the generic [Hashtbl.hash] of the pair. *)
module IntPair = struct
  type t = int * int

  let equal ((r1, c1) : t) ((r2, c2) : t) =
    c1 = c2 && r1 = r2

  let hash (key : t) = Hashtbl.hash key
end

(* Hash table specialised to (row, col) keys. *)
module PairHash = Hashtbl.Make (IntPair)
(* A [rows] x [cols] float matrix; only non-zero entries are kept in [data],
   keyed by (row, col). *)
type sparse_matrix = {
  rows : int;
  cols : int;
  data : float PairHash.t;
}

(* [create rows cols] builds an empty matrix: every element reads as 0.0.
   The backing table starts with an initial size of 16 buckets. *)
let create rows cols =
  let data = PairHash.create 16 in
  { rows = rows; cols = cols; data = data }
(* [set m r c v] writes [v] at (r, c). Writing 0.0 deletes the entry instead,
   so the table only holds non-zero values.
   Raises [Failure "index out of bounds"] when (r, c) is outside the matrix. *)
let set m r c v =
  let in_bounds = 0 <= r && r < m.rows && 0 <= c && c < m.cols in
  if not in_bounds then failwith "index out of bounds";
  if v <> 0.0 then PairHash.replace m.data (r, c) v
  else PairHash.remove m.data (r, c)
(* [get m r c] returns the value stored at (r, c), or 0.0 when nothing is
   stored there. No bounds check is performed here. *)
let get m r c =
  match PairHash.find_opt m.data (r, c) with
  | Some value -> value
  | None -> 0.0
(* Number of stored (non-zero) entries. *)
let nnz { data; _ } = PairHash.length data
(* [matvec m v] returns a fresh array of length [m.rows] holding m * v,
   i.e. result.(i) = sum over j of m[i, j] *. v.(j).
   Only the stored entries are visited. The length of [v] is checked with
   [assert]. *)
let matvec m v =
  assert (m.cols = Array.length v);
  let acc = Array.make m.rows 0.0 in
  let accumulate (row, col) coeff =
    acc.(row) <- acc.(row) +. (coeff *. v.(col))
  in
  PairHash.iter accumulate m.data;
  acc
(* [transpose m] returns a new [m.cols] x [m.rows] matrix in which every
   stored entry (r, c) of [m] appears at (c, r). [m] is not modified. *)
let transpose m =
  let result = create m.cols m.rows in
  let flip (r, c) value = PairHash.replace result.data (c, r) value in
  PairHash.iter flip m.data;
  result
(* [add m1 m2] returns a new matrix holding the element-wise sum m1 + m2.
   Positions whose sum is exactly 0.0 are removed to keep the result sparse.
   Equal dimensions are checked with [assert]. *)
let add m1 m2 =
  assert (m1.rows = m2.rows && m1.cols = m2.cols);
  let result = create m1.rows m1.cols in
  let tbl = result.data in
  PairHash.iter (PairHash.replace tbl) m1.data;
  let merge key delta =
    let total =
      match PairHash.find_opt tbl key with
      | Some existing -> existing +. delta
      | None -> 0.0 +. delta
    in
    if total = 0.0 then PairHash.remove tbl key
    else PairHash.replace tbl key total
  in
  PairHash.iter merge m2.data;
  result
(* Self-test driver: builds a 4x4 matrix and checks get/set, zero removal,
   matvec, transpose and add, then prints a success line. *)
let () =
  let m = create 4 4 in
  set m 0 0 1.0;
  set m 0 2 2.0;
  set m 1 1 3.0;
  set m 2 0 4.0;
  set m 2 3 5.0;
  set m 3 3 6.0;
  assert (nnz m = 6);
  assert (get m 0 0 = 1.0);
  assert (get m 0 1 = 0.0); (* zero element *)
  assert (get m 1 1 = 3.0);
  (* Setting to zero removes entry *)
  set m 1 1 0.0;
  assert (nnz m = 5);
  assert (get m 1 1 = 0.0);
  (* Matrix-vector multiply *)
  let v = [| 1.0; 0.0; 1.0; 0.0 |] in
  let result = matvec m v in
  assert (result.(0) = 3.0); (* 1*1 + 2*1 *)
  assert (result.(1) = 0.0); (* 3 was removed *)
  assert (result.(2) = 4.0); (* 4*1 *)
  assert (result.(3) = 0.0); (* 6*0 *)
  (* Transpose *)
  let mt = transpose m in
  assert (get mt 0 0 = 1.0);
  assert (get mt 2 0 = 2.0);
  assert (get mt 0 2 = 4.0);
  assert (get mt 3 2 = 5.0);
  assert (get mt 3 3 = 6.0);
  assert (nnz mt = 5);
  (* Addition: (0,0) merges to 2.0, (2,0) cancels to 0.0 and is dropped *)
  let m2 = create 4 4 in
  set m2 0 0 1.0;
  set m2 2 0 (-4.0);
  let sum = add m m2 in
  assert (get sum 0 0 = 2.0);
  assert (get sum 2 0 = 0.0);
  assert (nnz sum = 4);
  assert (nnz m = 5); (* inputs are left unchanged *)
  Printf.printf "✓ All tests passed\n"