🦀 Functional Rust

727: SIMD Portable Concepts with std::simd

Difficulty: 5 | Level: Master

Write vectorised code once with `std::simd`: compile to AVX2, NEON, or SSE2 without per-platform `#[cfg]` guards.

The Problem This Solves

SIMD (Single Instruction, Multiple Data) lets a single CPU instruction operate on a vector of values simultaneously: add 8 floats in one instruction instead of 8 separate ones. The performance gains on data-parallel workloads (audio processing, image transforms, dot products, string search) are typically 4–16× over scalar code. The problem has always been portability: AVX2 on x86, NEON on ARM, SVE on ARM64, the V extension on RISC-V. Writing hand-optimised SIMD meant `#[cfg(target_arch = "x86_64")]` everywhere, duplicated implementations, and fragile maintenance.

`std::simd` (the `portable_simd` nightly feature) solves this with architecture-independent vector types: `f32x8` (8 float32 lanes), `i32x4` (4 int32 lanes), `u8x16` (16 byte lanes). You write lane-wise arithmetic on these types; `a + b` on `f32x8` adds all 8 lanes in parallel. The compiler lowers these operations to the best native instruction set available: AVX2 on a Skylake server, NEON on an Apple Silicon Mac, SSE2 on an older x86.

The mental-model shift is from "operate on element N, then N+1, then N+2" to "operate on a chunk of N elements simultaneously." This is the key insight: SIMD code is inherently data-parallel. If your algorithm has loop-carried dependencies, SIMD won't help. If it's element-wise or reducible, SIMD is a multiplier.
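That last distinction is worth seeing in code. A minimal stable-Rust sketch (the function names `scale` and `prefix_sum` are illustrative, not from the listing below): the first loop is element-wise and vectorises cleanly; the second carries a dependency from one iteration to the next and stays serial no matter how wide the registers are.

```rust
// Element-wise vs loop-carried: only the first loop is SIMD-friendly.
fn scale(data: &mut [f32], k: f32) {
    // Each iteration is independent of the others, so the compiler
    // can lower this to lane-wise multiplies (e.g. VMULPS on AVX2).
    for x in data.iter_mut() {
        *x *= k;
    }
}

fn prefix_sum(data: &mut [f32]) {
    // data[i] depends on data[i - 1]: a loop-carried dependency,
    // so each step must wait for the previous one. SIMD cannot help here
    // without restructuring the algorithm.
    for i in 1..data.len() {
        data[i] += data[i - 1];
    }
}

fn main() {
    let mut a = [1.0f32, 2.0, 3.0, 4.0];
    scale(&mut a, 2.0);
    println!("{a:?}"); // [2.0, 4.0, 6.0, 8.0]

    let mut b = [1.0f32, 1.0, 1.0, 1.0];
    prefix_sum(&mut b);
    println!("{b:?}"); // [1.0, 2.0, 3.0, 4.0]
}
```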

The Intuition

Imagine a bank teller processing one transaction at a time versus a teller window that processes 8 transactions simultaneously because they're all the same type. That's SIMD. The `f32x8` type is a register that holds 8 floats at once. `a + b` where both are `f32x8` adds all 8 pairs in a single CPU instruction. Reductions like `.reduce_sum()` combine all 8 lanes back to a scalar. The scalar fallback in this example mirrors the SIMD API exactly: `[f32; 8]` with manual loops. This makes the structure of the algorithm clear even without nightly Rust, and lets you verify correctness before enabling vectorisation.
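The chunk-wise mental model also works on stable Rust without any SIMD types: `chunks_exact(8)` yields 8-element blocks plus a `remainder()` tail, and an array of per-lane accumulators mirrors what an `f32x8` register does in hardware. A sketch (the name `sum_chunked` is illustrative):

```rust
// Chunk-wise summation: 8 independent lane accumulators, then a
// horizontal reduction at the end. The inner loop has no cross-lane
// dependencies, so the compiler can often auto-vectorise it.
fn sum_chunked(data: &[f32]) -> f32 {
    let mut lanes = [0.0f32; 8]; // one accumulator per "lane"
    let chunks = data.chunks_exact(8);
    let tail = chunks.remainder(); // elements left over after full chunks
    for chunk in chunks {
        for i in 0..8 {
            lanes[i] += chunk[i]; // lane-wise add across chunks
        }
    }
    // Horizontal reduction of the 8 lanes, plus the scalar tail.
    lanes.iter().sum::<f32>() + tail.iter().sum::<f32>()
}

fn main() {
    let data: Vec<f32> = (1..=10).map(|i| i as f32).collect();
    println!("{}", sum_chunked(&data)); // 55
}
```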

How It Works in Rust

// On nightly, enable with: #![feature(portable_simd)]
// use std::simd::{f32x8, SimdFloat};

// Stable scalar simulation, structurally identical to the SIMD version.
#[derive(Clone, Copy)]
pub struct F32x8([f32; 8]);

impl F32x8 {
 pub fn splat(v: f32) -> Self { Self([v; 8]) }
 pub fn from_array(a: [f32; 8]) -> Self { Self(a) }

 // Lane-wise add: compiles to VADDPS ymm on AVX2
 pub fn add(self, rhs: Self) -> Self {
     let mut r = [0.0f32; 8];
     for i in 0..8 { r[i] = self.0[i] + rhs.0[i]; }
     Self(r)
 }

 // Lane-wise multiply: compiles to VMULPS ymm on AVX2
 pub fn mul(self, rhs: Self) -> Self {
     let mut r = [0.0f32; 8];
     for i in 0..8 { r[i] = self.0[i] * rhs.0[i]; }
     Self(r)
 }

 pub fn reduce_sum(self) -> f32 { self.0.iter().sum() }
}

// Dot product over slices: processes 8 elements per iteration.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
 assert_eq!(a.len(), b.len());
 let mut acc = F32x8::splat(0.0);
 let chunks = a.len() / 8;
 for i in 0..chunks {
     let va = F32x8::from_array(a[i*8..i*8+8].try_into().unwrap());
     let vb = F32x8::from_array(b[i*8..i*8+8].try_into().unwrap());
     acc = acc.add(va.mul(vb));
 }
 // Scalar tail for lengths not divisible by 8.
 let mut sum = acc.reduce_sum();
 for i in (chunks * 8)..a.len() {
     sum += a[i] * b[i];
 }
 sum
}
On nightly, replace `F32x8` with `std::simd::f32x8` and the explicit method calls with SIMD operator overloads; the structure stays identical.

What This Unlocks
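One concrete example of the kind of kernel this style unlocks: L2-normalising a vector by accumulating squared values 8 lanes at a time, then dividing element-wise. This is a self-contained, stable-Rust sketch; the names `norm_l2` and `normalize` are illustrative and do not appear in the listing below.

```rust
// L2 normalisation via chunk-wise square-accumulate: the same
// lane-accumulator pattern the full listing uses for dot products.
fn norm_l2(data: &[f32]) -> f32 {
    let mut lanes = [0.0f32; 8];
    let chunks = data.chunks_exact(8);
    let tail = chunks.remainder();
    for chunk in chunks {
        for i in 0..8 {
            lanes[i] += chunk[i] * chunk[i]; // lane-wise square-accumulate
        }
    }
    let sum = lanes.iter().sum::<f32>() + tail.iter().map(|x| x * x).sum::<f32>();
    sum.sqrt()
}

fn normalize(data: &mut [f32]) {
    let n = norm_l2(data);
    if n > 0.0 {
        // Element-wise divide: independent per element, vectorises well.
        for x in data.iter_mut() {
            *x /= n;
        }
    }
}

fn main() {
    let mut v = vec![3.0f32, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
    normalize(&mut v);
    println!("norm after normalize = {:.1}", norm_l2(&v)); // 1.0
}
```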

Key Differences

| Concept | OCaml | Rust |
| --- | --- | --- |
| SIMD access | Via C stubs or `owl-base` | `std::simd` (nightly) or `wide` crate |
| Vector types | Not in stdlib | `f32x4`, `f32x8`, `i32x16`, etc. |
| Lane operations | Not available | `a + b` on SIMD types (element-wise) |
| Horizontal reduction | Not available | `.reduce_sum()`, `.reduce_max()` |
| Conditional selection | Not available | `Mask<i32, 4>`, `.select()` |
| Platform portability | N/A | One type, best native instructions |
// 727. SIMD concepts with std::simd (portable_simd)
//
// This file demonstrates the portable_simd API using stable-compatible
// scalar implementations that mirror the SIMD mental model exactly.
// On nightly Rust, replace the scalar impls with std::simd types.
//
// To enable on nightly: add `#![feature(portable_simd)]` and use
// `use std::simd::*;`
//
// The scalar versions below are structurally identical to the SIMD versions:
// just swap `[f32; LANES]` for `f32xN` and loops for SIMD ops.

use std::time::Instant;

// ── LANES constant (matches f32x8) ──────────────────────────────────────────

const LANES: usize = 8;

// ── Portable scalar "SIMD" types (mirrors std::simd API) ───────────────────

/// Simulates `std::simd::f32x8`: 8 f32 lanes.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct F32x8([f32; LANES]);

impl F32x8 {
    pub fn splat(v: f32) -> Self { Self([v; LANES]) }
    pub fn from_array(a: [f32; LANES]) -> Self { Self(a) }
    pub fn to_array(self) -> [f32; LANES] { self.0 }

    /// Lane-wise addition: compiles to VADDPS ymm on AVX2.
    pub fn add(self, rhs: Self) -> Self {
        let mut r = [0.0f32; LANES];
        for i in 0..LANES { r[i] = self.0[i] + rhs.0[i]; }
        Self(r)
    }

    /// Lane-wise multiplication: compiles to VMULPS ymm on AVX2.
    pub fn mul(self, rhs: Self) -> Self {
        let mut r = [0.0f32; LANES];
        for i in 0..LANES { r[i] = self.0[i] * rhs.0[i]; }
        Self(r)
    }

    /// Fused multiply-add, self * a + b: compiles to VFMADD213PS on AVX2+FMA.
    pub fn mul_add(self, a: Self, b: Self) -> Self {
        let mut r = [0.0f32; LANES];
        for i in 0..LANES { r[i] = self.0[i].mul_add(a.0[i], b.0[i]); }
        Self(r)
    }

    /// Horizontal sum reduction: compiles to `vhaddps` or a tree reduction.
    pub fn reduce_sum(self) -> f32 {
        self.0.iter().copied().sum()
    }

    /// Lane-wise max: compiles to VMAXPS ymm.
    pub fn max(self, rhs: Self) -> Self {
        let mut r = [0.0f32; LANES];
        for i in 0..LANES { r[i] = self.0[i].max(rhs.0[i]); }
        Self(r)
    }

    /// Lane-wise min: compiles to VMINPS ymm.
    pub fn min(self, rhs: Self) -> Self {
        let mut r = [0.0f32; LANES];
        for i in 0..LANES { r[i] = self.0[i].min(rhs.0[i]); }
        Self(r)
    }

    /// Mask select: choose `on_true[i]` where `mask[i]` is true, else `on_false[i]`.
    /// Maps to `VBLENDVPS` on AVX or `VPBLENDMD` on AVX-512.
    pub fn select(mask: &MaskF32x8, on_true: Self, on_false: Self) -> Self {
        let mut r = [0.0f32; LANES];
        for i in 0..LANES {
            r[i] = if mask.0[i] { on_true.0[i] } else { on_false.0[i] };
        }
        Self(r)
    }

    /// Compare greater-than, producing a mask.
    pub fn gt(self, rhs: Self) -> MaskF32x8 {
        let mut m = [false; LANES];
        for i in 0..LANES { m[i] = self.0[i] > rhs.0[i]; }
        MaskF32x8(m)
    }
}

/// A boolean mask over 8 f32 lanes. Maps to `std::simd::Mask<i32, 8>`.
#[derive(Clone, Copy, Debug)]
pub struct MaskF32x8([bool; LANES]);

impl MaskF32x8 {
    pub fn any(self) -> bool { self.0.iter().any(|&b| b) }
    pub fn all(self) -> bool { self.0.iter().all(|&b| b) }
}

// ── Vectorised algorithms ───────────────────────────────────────────────────

/// Dot product using 8-wide SIMD accumulation.
/// Processes 8 elements per loop iteration.
pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    let n = a.len();
    let full_chunks = n / LANES;

    let mut acc = F32x8::splat(0.0);
    for i in 0..full_chunks {
        let off = i * LANES;
        let va = F32x8::from_array(a[off..off+LANES].try_into().unwrap());
        let vb = F32x8::from_array(b[off..off+LANES].try_into().unwrap());
        // acc = va * vb + acc  (FMA)
        acc = va.mul_add(vb, acc);
    }

    // Handle remaining elements (scalar tail)
    let mut result = acc.reduce_sum();
    for i in (full_chunks * LANES)..n {
        result += a[i] * b[i];
    }
    result
}

/// Element-wise clamp using SIMD min/max.
pub fn clamp_simd(data: &mut [f32], lo: f32, hi: f32) {
    let vlo = F32x8::splat(lo);
    let vhi = F32x8::splat(hi);
    let n = data.len();
    let full = n / LANES;

    for i in 0..full {
        let off = i * LANES;
        let v = F32x8::from_array(data[off..off+LANES].try_into().unwrap());
        let clamped = v.max(vlo).min(vhi);
        data[off..off+LANES].copy_from_slice(&clamped.to_array());
    }
    // Scalar tail
    for v in &mut data[full * LANES..] {
        *v = v.clamp(lo, hi);
    }
}

/// ReLU activation: max(x, 0) โ€” common in neural networks.
pub fn relu_simd(data: &mut [f32]) {
    clamp_simd(data, 0.0, f32::INFINITY);
}

/// Horizontal sum of a large slice using 8-wide vectors.
pub fn sum_simd(data: &[f32]) -> f32 {
    let full = data.len() / LANES;
    let mut acc = F32x8::splat(0.0);
    for i in 0..full {
        let off = i * LANES;
        let v = F32x8::from_array(data[off..off+LANES].try_into().unwrap());
        acc = acc.add(v);
    }
    let mut result = acc.reduce_sum();
    for &v in &data[full * LANES..] { result += v; }
    result
}

// ── Scalar reference implementations ────────────────────────────────────────

pub fn dot_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(&x, &y)| x * y).sum()
}

// ── main ────────────────────────────────────────────────────────────────────

fn main() {
    // --- F32x8 lane operations ---
    let a = F32x8::from_array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    let b = F32x8::from_array([8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]);

    println!("=== F32x8 lane operations ===");
    println!("  a + b = {:?}", a.add(b).to_array());
    println!("  a * b = {:?}", a.mul(b).to_array());
    println!("  sum(a) = {}", a.reduce_sum());
    println!("  max(a,b) = {:?}", a.max(b).to_array());
    println!("  a > 4.5? all={} any={}", a.gt(F32x8::splat(4.5)).all(), a.gt(F32x8::splat(4.5)).any());

    // --- Dot product ---
    println!("\n=== Dot product (1M floats) ===");
    const N: usize = 1_000_000;
    let va: Vec<f32> = (0..N).map(|i| (i % 100) as f32 * 0.1).collect();
    let vb: Vec<f32> = (0..N).map(|i| ((N - i) % 100) as f32 * 0.1).collect();

    let t0 = Instant::now();
    let d_simd = dot_product_simd(&va, &vb);
    let t1 = t0.elapsed();

    let t0 = Instant::now();
    let d_scalar = dot_scalar(&va, &vb);
    let t2 = t0.elapsed();

    println!("  SIMD:   {:?}  result={d_simd:.2}", t1);
    println!("  scalar: {:?}  result={d_scalar:.2}", t2);
    assert!((d_simd - d_scalar).abs() < 10.0, "results must agree");

    // --- ReLU ---
    println!("\n=== ReLU activation ===");
    let mut data = vec![-3.0, -1.0, 0.0, 1.5, -0.5, 2.0, -2.0, 0.5];
    println!("  before: {:?}", data);
    relu_simd(&mut data);
    println!("  after:  {:?}", data);

    // --- Select ---
    println!("\n=== Conditional select ===");
    let ones  = F32x8::splat(1.0);
    let zeros = F32x8::splat(0.0);
    let mask  = a.gt(F32x8::splat(4.0));
    let sel   = F32x8::select(&mask, ones, zeros);
    println!("  a>4.0 → {:?}", sel.to_array());
}

// ── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lane_add() {
        let a = F32x8::splat(2.0);
        let b = F32x8::splat(3.0);
        assert_eq!(a.add(b).to_array(), [5.0; 8]);
    }

    #[test]
    fn reduce_sum() {
        let a = F32x8::from_array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        assert_eq!(a.reduce_sum(), 36.0);
    }

    #[test]
    fn dot_product_matches_scalar() {
        let a: Vec<f32> = (0..64).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..64).map(|i| (64 - i) as f32).collect();
        let d_simd = dot_product_simd(&a, &b);
        let d_scalar = dot_scalar(&a, &b);
        assert!((d_simd - d_scalar).abs() < 0.01);
    }

    #[test]
    fn relu_zeroes_negatives() {
        let mut v = vec![-2.0f32, -1.0, 0.0, 1.0, 2.0, -3.0, 4.0, -0.5];
        relu_simd(&mut v);
        assert_eq!(v, [0.0, 0.0, 0.0, 1.0, 2.0, 0.0, 4.0, 0.0]);
    }

    #[test]
    fn clamp_bounds() {
        let mut v = vec![-5.0f32, 0.0, 5.0, 10.0, 15.0, 3.0, -1.0, 8.0];
        clamp_simd(&mut v, 0.0, 10.0);
        for &x in &v { assert!(x >= 0.0 && x <= 10.0); }
    }

    #[test]
    fn sum_simd_correct() {
        let data: Vec<f32> = (1..=16).map(|i| i as f32).collect();
        let s = sum_simd(&data);
        assert_eq!(s, 136.0); // 1+2+…+16
    }

    #[test]
    fn mask_select() {
        let a = F32x8::from_array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let threshold = F32x8::splat(4.5);
        let mask = a.gt(threshold);
        let selected = F32x8::select(&mask, F32x8::splat(1.0), F32x8::splat(0.0));
        assert_eq!(&selected.to_array()[..4], &[0.0, 0.0, 0.0, 0.0]);
        assert_eq!(&selected.to_array()[4..], &[1.0, 1.0, 1.0, 1.0]);
    }
}
(* OCaml: SIMD concepts via scalar simulation and Bigarray
   OCaml doesn't have SIMD intrinsics in the standard library.
   We demonstrate the vectorised mental model using Bigarray and
   show what the equivalent SIMD code achieves. *)

open Bigarray

type f32vec = (float, float32_elt, c_layout) Array1.t

let make_vec n = Array1.create float32 c_layout n

(* --- Simulated SIMD operations (scalar fallback) --- *)

(* Vectorised addition: element-wise a[i] + b[i] *)
let vec_add (a : f32vec) (b : f32vec) : f32vec =
  let n = Array1.dim a in
  let c = make_vec n in
  for i = 0 to n - 1 do
    c.{i} <- a.{i} +. b.{i}
  done;
  c

(* Vectorised multiply-accumulate (FMA pattern) *)
let vec_fma (a : f32vec) (b : f32vec) (c : f32vec) : f32vec =
  let n = Array1.dim a in
  let result = make_vec n in
  for i = 0 to n - 1 do
    result.{i} <- a.{i} *. b.{i} +. c.{i}
  done;
  result

(* Horizontal sum (reduction) *)
let vec_sum (a : f32vec) : float =
  let acc = ref 0.0 in
  for i = 0 to Array1.dim a - 1 do
    acc := !acc +. a.{i}
  done;
  !acc

(* Dot product via vec_fma pattern *)
let dot_product (a : f32vec) (b : f32vec) : float =
  let n = Array1.dim a in
  let acc = ref 0.0 in
  for i = 0 to n - 1 do
    acc := !acc +. a.{i} *. b.{i}
  done;
  !acc

(* Conditional select: select a[i] if mask[i] > 0, else b[i] *)
let vec_select (mask : f32vec) (a : f32vec) (b : f32vec) : f32vec =
  let n = Array1.dim a in
  let r = make_vec n in
  for i = 0 to n - 1 do
    r.{i} <- if mask.{i} > 0.0 then a.{i} else b.{i}
  done;
  r

(* --- Demo --- *)

let fill_vec n f =
  let v = make_vec n in
  for i = 0 to n - 1 do v.{i} <- f i done;
  v

let () =
  let n = 8 in
  let a = fill_vec n (fun i -> float_of_int (i + 1)) in
  let b = fill_vec n (fun i -> float_of_int (n - i)) in

  Printf.printf "a = [%s]\n"
    (String.concat "; " (List.init n (fun i -> Printf.sprintf "%.0f" a.{i})));
  Printf.printf "b = [%s]\n"
    (String.concat "; " (List.init n (fun i -> Printf.sprintf "%.0f" b.{i})));

  let s = vec_add a b in
  Printf.printf "a+b = [%s]\n"
    (String.concat "; " (List.init n (fun i -> Printf.sprintf "%.0f" s.{i})));

  Printf.printf "dot(a,b) = %.1f\n" (dot_product a b);
  Printf.printf "sum(a)   = %.1f\n" (vec_sum a);

  let mask = fill_vec n (fun i -> if i < 4 then 1.0 else -1.0) in
  let ones  = fill_vec n (fun _ -> 1.0) in
  let zeros = fill_vec n (fun _ -> 0.0) in
  let sel = vec_select mask ones zeros in
  Printf.printf "select = [%s]\n"
    (String.concat "; " (List.init n (fun i -> Printf.sprintf "%.0f" sel.{i})));

  (* Note: in real high-perf OCaml, you'd call C SIMD kernels via FFI.
     Owl/owl-base wraps BLAS/LAPACK which uses SIMD internally. *)
  Printf.printf "OCaml SIMD note: use owl-base or C FFI for native SIMD.\n"