🦀 Functional Rust

720: Cache-Friendly Iteration and Data Access Patterns

Difficulty: 3 · Level: Expert

Write iteration patterns the CPU's prefetcher can follow — sequential, predictable, and branch-minimal.

The Problem This Solves

Modern CPUs execute instructions in ~1 ns, but fetching from main memory takes 60–200 ns. The L1/L2/L3 caches bridge this gap — but only if your access pattern gives the hardware prefetcher a chance to load data before you need it. Unpredictable access patterns (random indices, pointer chasing through linked structures, column-major access on row-major matrices) defeat the prefetcher and cause the CPU to stall waiting for memory.

The problem is invisible in code review. A 1024×1024 matrix sum looks identical whether you iterate row-major or column-major, but the performance difference can be 10–50× on real hardware. Rust's zero-cost iterator model doesn't fix cache behaviour — sequential iterators are fast because they're sequential, not because they're idiomatic. Understanding why row-major access is fast (and column-major is slow) is essential for any performance-critical Rust code.

The Intuition

A CPU cache line is 64 bytes. When you load one element, the hardware fetches the surrounding 64 bytes speculatively, betting you'll want the neighbours next. This bet pays off when you iterate sequentially — the next 15 `f32`s are already in cache. It fails catastrophically when you jump by `cols * sizeof(f32)` between accesses — each jump lands in a cold cache line, costing a full memory round-trip.

Rule of thumb: access memory in the order it was laid out. For a row-major matrix (`data[r * cols + c]`), iterate row by row, not column by column. For a struct, process the same field across all instances (SoA layout) rather than all fields of one instance at a time (AoS layout).
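The AoS-versus-SoA distinction is easy to sketch in a few lines; the `ParticleAos`/`ParticlesSoa` names below are illustrative, not from any real crate:

```rust
// AoS: all fields of one particle sit together. Summing just `x`
// touches a fresh cache line per particle but uses only 4 of its 64 bytes.
#[allow(dead_code)]
struct ParticleAos {
    x: f32,
    y: f32,
    z: f32,
    mass: f32,
}

// SoA: the same field across all particles is contiguous, so a scan
// over `x` is a pure sequential read the prefetcher can follow.
struct ParticlesSoa {
    x: Vec<f32>,
    y: Vec<f32>,
    z: Vec<f32>,
    mass: Vec<f32>,
}

fn sum_x_aos(ps: &[ParticleAos]) -> f32 {
    ps.iter().map(|p| p.x).sum()
}

fn sum_x_soa(ps: &ParticlesSoa) -> f32 {
    ps.x.iter().sum()
}

fn main() {
    let aos: Vec<ParticleAos> = (0..4)
        .map(|i| ParticleAos { x: i as f32, y: 0.0, z: 0.0, mass: 1.0 })
        .collect();
    let soa = ParticlesSoa {
        x: (0..4).map(|i| i as f32).collect(),
        y: vec![0.0; 4],
        z: vec![0.0; 4],
        mass: vec![1.0; 4],
    };
    // Same result either way; only the memory traffic differs.
    assert_eq!(sum_x_aos(&aos), sum_x_soa(&soa));
    println!("sum of x = {}", sum_x_soa(&soa));
}
```

Summing `x` over the SoA layout reads one contiguous `Vec<f32>`; the AoS version drags `y`, `z`, and `mass` through the cache alongside every `x` it touches.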

How It Works in Rust

pub struct Matrix {
    data: Vec<f32>,  // row-major: [row0col0, row0col1, ..., row1col0, ...]
    rows: usize,
    cols: usize,
}

// ── Fast: sequential access — prefetcher loads ahead ─────────────────────
pub fn sum_row_major(m: &Matrix) -> f64 {
    // Single linear scan — every cache line is used 100%
    m.data.iter().map(|&v| v as f64).sum()
}

// ── Slow: stride of `cols` between accesses — cold cache lines ───────────
pub fn sum_col_major(m: &Matrix) -> f64 {
    let mut acc = 0.0f64;
    for c in 0..m.cols {
        for r in 0..m.rows {
            // Each access jumps m.cols * 4 bytes — cache miss per element
            acc += m.get(r, c) as f64;
        }
    }
    acc
}

// ── Tiled transpose: access both matrices in cache-friendly blocks ───────
pub fn transpose_tiled(src: &Matrix, dst: &mut Matrix, tile: usize) {
    for row_tile in (0..src.rows).step_by(tile) {
        for col_tile in (0..src.cols).step_by(tile) {
            for r in row_tile..(row_tile + tile).min(src.rows) {
                for c in col_tile..(col_tile + tile).min(src.cols) {
                    // Both reads and writes stay within a tile that fits in L1 cache
                    dst.set(c, r, src.get(r, c));
                }
            }
        }
    }
}
Tiling works by choosing a block dimension (e.g., 64, so a 64×64 block of `f32`s is 16 KB) small enough that the working set fits in L1 cache. You complete all work on that block before moving to the next, so every cache line loaded is used fully before eviction.
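The arithmetic behind that choice can be written down as a quick sanity check; the 32 KB L1 data-cache size used here is a typical figure, not a guarantee for any particular CPU:

```rust
// Bytes touched by one TILE×TILE f32 block of `src` plus the matching
// block of `dst` — both must stay resident for the tile to pay off.
const fn tile_working_set_bytes(tile: usize) -> usize {
    2 * tile * tile * std::mem::size_of::<f32>()
}

fn main() {
    const L1_BYTES: usize = 32 * 1024; // typical per-core L1d size (assumption)

    // 32×32 tiles: 2 * 32 * 32 * 4 = 8192 bytes; comfortably inside L1.
    assert!(tile_working_set_bytes(32) <= L1_BYTES);

    // 128×128 tiles: 2 * 128 * 128 * 4 = 131072 bytes; spills past L1.
    assert!(tile_working_set_bytes(128) > L1_BYTES);

    println!("32-tile working set: {} bytes", tile_working_set_bytes(32));
}
```

In practice you benchmark a few tile sizes around this estimate, since associativity and other resident data shrink the usable cache.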

What This Unlocks

Once access is sequential and predictable, the prefetcher hides memory latency, every loaded cache line is fully used, and LLVM can auto-vectorise the simple loop bodies that result; throughput becomes bounded by memory bandwidth rather than memory latency.

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Memory layout control | GC controls object placement | `Vec<T>` is always contiguous; layout fully predictable |
| Row-major iteration | `Array` row-by-row (default) | Iterator over flat slice — optimal by default |
| Column-major penalty | Same (GC can't fix it) | Same — it's a hardware limitation, not a language one |
| Tiled algorithms | Implement manually | Implement manually; `std::slice::chunks` helps |
| SIMD opportunity | Limited | Auto-vectorised by LLVM when the loop body is simple |

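The `std::slice::chunks` hint in the table is worth making concrete: on a flat row-major buffer, `chunks_exact(cols)` yields each row as a contiguous slice, so per-row work stays cache-sequential with no index arithmetic. A minimal sketch (`row_maxima` is a hypothetical helper, not part of the listing below):

```rust
// Per-row maxima over a flat row-major buffer. Each `chunks_exact(cols)`
// item is one contiguous row, so the whole scan is a single forward pass.
fn row_maxima(data: &[f32], cols: usize) -> Vec<f32> {
    data.chunks_exact(cols)
        .map(|row| row.iter().copied().fold(f32::MIN, f32::max))
        .collect()
}

fn main() {
    // 2×3 row-major matrix: rows [1 5 2] and [4 0 3]
    let data = [1.0f32, 5.0, 2.0, 4.0, 0.0, 3.0];
    println!("{:?}", row_maxima(&data, 3)); // one maximum per row
}
```

`chunks_exact` is preferred over `chunks` here because the compiler knows every slice has exactly `cols` elements, which helps bounds-check elimination and vectorisation.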
// 720. Cache-friendly iteration and data access patterns
//
// Demonstrates row-major vs column-major access, tiled transposition,
// and how iterator chaining produces cache-sequential, auto-vectorisable code.

use std::time::Instant;

// ── Flat row-major matrix ────────────────────────────────────────────────────

pub struct Matrix {
    data: Vec<f32>,
    rows: usize,
    cols: usize,
}

impl Matrix {
    pub fn new(rows: usize, cols: usize, init: f32) -> Self {
        Self { data: vec![init; rows * cols], rows, cols }
    }

    pub fn from_fn(rows: usize, cols: usize, f: impl Fn(usize, usize) -> f32) -> Self {
        let mut data = Vec::with_capacity(rows * cols);
        for r in 0..rows {
            for c in 0..cols {
                data.push(f(r, c));
            }
        }
        Self { data, rows, cols }
    }

    #[inline(always)]
    pub fn get(&self, r: usize, c: usize) -> f32 {
        self.data[r * self.cols + c]
    }

    #[inline(always)]
    pub fn set(&mut self, r: usize, c: usize, v: f32) {
        self.data[r * self.cols + c] = v;
    }

    pub fn as_slice(&self) -> &[f32] { &self.data }
}

// ── Access-pattern benchmarks ────────────────────────────────────────────────

/// Row-major sum: sequential access — every element is adjacent in memory.
/// The CPU's hardware prefetcher predicts and loads the next cache line.
pub fn sum_row_major(m: &Matrix) -> f64 {
    m.data.iter().map(|&v| v as f64).sum()
}

/// Column-major sum: stride of `cols` f32s between accesses.
/// For a 1024-column matrix, each access is 4096 bytes apart — thrashes cache.
pub fn sum_col_major(m: &Matrix) -> f64 {
    let mut acc = 0.0f64;
    for c in 0..m.cols {
        for r in 0..m.rows {
            acc += m.get(r, c) as f64;
        }
    }
    acc
}

// ── Transposition ────────────────────────────────────────────────────────────

/// Naive O(n²) transpose: reads row-major from `src`, writes column-major into `dst`.
/// The write side causes cache thrashing for large N.
pub fn transpose_naive(src: &Matrix) -> Matrix {
    let mut dst = Matrix::new(src.cols, src.rows, 0.0);
    for r in 0..src.rows {
        for c in 0..src.cols {
            dst.set(c, r, src.get(r, c));
        }
    }
    dst
}

/// Tiled transpose: process TILE×TILE blocks.
/// Both reads and writes stay within a working set that fits in L1/L2 cache.
pub fn transpose_tiled<const TILE: usize>(src: &Matrix) -> Matrix {
    let mut dst = Matrix::new(src.cols, src.rows, 0.0);
    for r_tile in (0..src.rows).step_by(TILE) {
        for c_tile in (0..src.cols).step_by(TILE) {
            let r_end = (r_tile + TILE).min(src.rows);
            let c_end = (c_tile + TILE).min(src.cols);
            for r in r_tile..r_end {
                for c in c_tile..c_end {
                    dst.set(c, r, src.get(r, c));
                }
            }
        }
    }
    dst
}

// ── Iterator patterns that compile to cache-friendly loops ───────────────────

/// Dot product of two slices: sequential load of both — maximises bandwidth.
pub fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
}

/// Prefix sum (exclusive scan) — sequential write, auto-vectorisable.
pub fn prefix_sum(data: &[f32]) -> Vec<f32> {
    let mut out = Vec::with_capacity(data.len() + 1);
    out.push(0.0);
    let mut acc = 0.0f32;
    for &v in data {
        acc += v;
        out.push(acc);
    }
    out
}

/// Gather: non-sequential read (indices can be random) — cache unfriendly.
pub fn gather(data: &[f32], indices: &[usize]) -> Vec<f32> {
    indices.iter().map(|&i| data[i]).collect()
}

// ── main ─────────────────────────────────────────────────────────────────────

fn benchmark<F: Fn() -> f64>(label: &str, iters: u32, f: F) {
    let t0 = Instant::now();
    let mut sink = 0.0f64;
    for _ in 0..iters { sink += f(); }
    let elapsed = t0.elapsed();
    println!("  {label}: {:?}  (result={:.0})", elapsed, sink / iters as f64);
}

fn main() {
    const N: usize = 1024;
    const ITERS: u32 = 3;

    let m = Matrix::from_fn(N, N, |r, c| (r * N + c) as f32);

    println!("=== Sum ({N}×{N} matrix, {ITERS} iterations) ===");
    benchmark("row-major", ITERS, || sum_row_major(&m));
    benchmark("col-major", ITERS, || sum_col_major(&m));

    println!("\n=== Transpose ({N}×{N}) ===");
    let t0 = Instant::now();
    for _ in 0..ITERS { let _ = transpose_naive(&m); }
    println!("  naive:  {:?}", t0.elapsed());

    let t0 = Instant::now();
    for _ in 0..ITERS { let _ = transpose_tiled::<32>(&m); }
    println!("  tiled32: {:?}", t0.elapsed());

    // Verify transpose correctness
    let t = transpose_tiled::<32>(&m);
    assert_eq!(t.get(0, 0), m.get(0, 0));
    assert_eq!(t.get(1, 0), m.get(0, 1));
    println!("\n  Transpose verified ✓");

    println!("\n=== Iterator patterns ===");
    let a: Vec<f32> = (0..1_000_000).map(|i| i as f32).collect();
    let b: Vec<f32> = a.iter().map(|&v| v * 0.5).collect();
    let t0 = Instant::now();
    let d = dot(&a, &b);
    println!("  dot(1M): {:?}  result={d:.2e}", t0.elapsed());

    let small: Vec<f32> = (0..10).map(|i| i as f32).collect();
    println!("  prefix_sum([0..9]): {:?}", prefix_sum(&small));
}

// ── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    fn small_matrix() -> Matrix {
        Matrix::from_fn(4, 4, |r, c| (r * 4 + c) as f32)
    }

    #[test]
    fn row_col_sum_equal() {
        let m = small_matrix();
        let row = sum_row_major(&m);
        let col = sum_col_major(&m);
        assert!((row - col).abs() < 1e-3, "{row} != {col}");
    }

    #[test]
    fn transpose_correctness() {
        let m = small_matrix();
        let t_naive = transpose_naive(&m);
        let t_tiled = transpose_tiled::<2>(&m);
        for r in 0..4 {
            for c in 0..4 {
                assert_eq!(t_naive.get(r, c), t_tiled.get(r, c));
                assert_eq!(t_naive.get(r, c), m.get(c, r));
            }
        }
    }

    #[test]
    fn dot_product() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        assert_eq!(dot(&a, &b), 32.0); // 4+10+18
    }

    #[test]
    fn prefix_sum_correct() {
        let ps = prefix_sum(&[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(ps, vec![0.0, 1.0, 3.0, 6.0, 10.0]);
    }
}
(* OCaml: Cache-friendly iteration patterns *)

(* OCaml's `float array` is unboxed (flat), but `float array array`
   is an array of pointers to individual row arrays — two hops per access. *)

(* --- Flat 2D matrix (row-major, cache-friendly) --- *)

type matrix = { data: float array; rows: int; cols: int }

let make_matrix rows cols init =
  { data = Array.init (rows * cols) init; rows; cols }

let get m r c = m.data.(r * m.cols + c)
let set m r c v = m.data.(r * m.cols + c) <- v

(* Row-major sum: sequential access — cache friendly *)
let sum_row_major m =
  Array.fold_left (+.) 0.0 m.data

(* Column-major sum: stride = cols elements between accesses — cache unfriendly *)
let sum_col_major m =
  let acc = ref 0.0 in
  for c = 0 to m.cols - 1 do
    for r = 0 to m.rows - 1 do
      acc := !acc +. get m r c
    done
  done;
  !acc

(* Transpose into a new matrix — writes in row-major order into dst *)
let transpose_naive m =
  let t = make_matrix m.cols m.rows (fun _ -> 0.0) in
  for r = 0 to m.rows - 1 do
    for c = 0 to m.cols - 1 do
      set t c r (get m r c)
    done
  done;
  t

(* Tiled transpose — better cache behaviour for large matrices *)
let transpose_tiled ?(tile=32) m =
  let t = make_matrix m.cols m.rows (fun _ -> 0.0) in
  let r = ref 0 in
  while !r < m.rows do
    let c = ref 0 in
    while !c < m.cols do
      for rr = !r to min (!r + tile - 1) (m.rows - 1) do
        for cc = !c to min (!c + tile - 1) (m.cols - 1) do
          set t cc rr (get m rr cc)
        done
      done;
      c := !c + tile
    done;
    r := !r + tile
  done;
  t

let time_it label f =
  let t0 = Sys.time () in
  let r = f () in
  Printf.printf "%s: %.6fs\n" label (Sys.time () -. t0);
  r

let () =
  let n = 1024 in
  let m = make_matrix n n (fun i -> float_of_int i) in

  let _sr = time_it "Row-major sum" (fun () -> sum_row_major m) in
  let _sc = time_it "Col-major sum" (fun () -> sum_col_major m) in
  let _ = time_it "Naive transpose" (fun () -> transpose_naive m) in
  let _ = time_it "Tiled transpose (tile=32)" (fun () -> transpose_tiled m) in
  Printf.printf "Matrix %dx%d: %d elements\n" n n (n * n)