/// 733: Profile-Guided Patterns — black_box, cold/hot paths, SoA vs AoS
use std::hint::black_box;
// ── black_box usage ────────────────────────────────────────────────────────────
/// Sums `i * i` for every `i` in the half-open range `0..n`.
///
/// Kept out-of-line with `#[inline(never)]` so it appears as its own frame in
/// profiler output. Benchmarks should pass `n` through `black_box`; otherwise the
/// optimizer may constant-fold the entire call.
///
/// Overflow follows ordinary `u64` arithmetic: it panics when overflow checks are
/// enabled (e.g. debug builds) and wraps otherwise.
#[inline(never)]
fn sum_squares(n: u64) -> u64 {
    (0..n).fold(0, |total, i| total + i * i)
}
// ── Hot / Cold path annotation ─────────────────────────────────────────────────
/// Mark rare error-handling code as cold to keep it out of the hot path.
#[cold]
#[inline(never)]
fn handle_overflow(a: u64, b: u64) -> u64 {
eprintln!("overflow: {} + {}", a, b);
u64::MAX
}
/// Adds `a + b`, returning `u64::MAX` (via `handle_overflow`) when the sum overflows.
fn checked_add_hot(a: u64, b: u64) -> u64 {
    match a.checked_add(b) {
        // Common case: stays on the fast path.
        Some(total) => total,
        // `handle_overflow` is `#[cold]`, so the optimizer treats this arm as unlikely.
        None => handle_overflow(a, b),
    }
}
// ── Struct-of-Arrays (SoA) vs Array-of-Structs (AoS) ───────────────────────────
/// Array-of-Structs layout: the three coordinates of one point are stored together.
///
/// Summing only `x` over a slice of these still pulls each point's `y` and `z`
/// through the cache, which wastes bandwidth compared with [`PointsSoA`].
// In this file `y` and `z` are only ever written, never read. They exist to give the
// struct its realistic three-field footprint, so the dead-code lint is silenced
// deliberately.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq)]
struct PointAoS {
    x: f32,
    y: f32,
    z: f32,
}
/// Struct-of-Arrays layout: each coordinate lives in its own contiguous `Vec`.
///
/// A pass over a single coordinate (see [`PointsSoA::sum_x`]) streams through one
/// dense array, so every fetched cache line is fully used.
struct PointsSoA {
    x: Vec<f32>,
    y: Vec<f32>,
    z: Vec<f32>,
}

impl PointsSoA {
    /// Builds `n` points, where point `i` is `(i, 2i, 3i)` converted to `f32`.
    fn new(n: usize) -> Self {
        // Each axis is the point index scaled by a per-axis factor.
        let axis = |factor: f32| -> Vec<f32> { (0..n).map(|i| i as f32 * factor).collect() };
        PointsSoA {
            x: axis(1.0),
            y: axis(2.0),
            z: axis(3.0),
        }
    }

    /// Sums the `x` coordinate of every point. Only the `x` array is read.
    fn sum_x(&self) -> f32 {
        self.x.iter().copied().sum()
    }
}
fn aos_sum_x(points: &[PointAoS]) -> f32 {
// Loads all 3 floats even though we only want x โ cache waste
points.iter().map(|p| p.x).sum()
}
// ── Measurement discipline ─────────────────────────────────────────────────────
/// Calls `f` once and returns its result paired with the elapsed time in nanoseconds.
///
/// Timing uses `std::time::Instant`, and only the call to `f` falls between the two
/// clock reads. If the optimizer could otherwise elide the work, wrap the inputs and
/// outputs in `black_box` inside `f`.
fn measure_ns<F, R>(f: F) -> (R, u128)
where
    F: FnOnce() -> R,
{
    let start = std::time::Instant::now();
    let value = f();
    (value, start.elapsed().as_nanos())
}
/// Demo driver: times `sum_squares`, exercises the hot/cold overflow paths, and
/// compares summing `x` over SoA vs AoS layouts of the same size.
fn main() {
    // Point count shared by both layouts so the SoA/AoS timings are comparable.
    const N: usize = 10_000;

    // `black_box` on the argument prevents constant folding. On the result, it keeps
    // the optimizer from treating the computation as unused.
    let (r, ns) = measure_ns(|| black_box(sum_squares(black_box(10_000))));
    println!("sum_squares(10_000) = {} in {}ns", r, ns);

    // Hot path vs. cold path. Opaque operands make the overflow check happen at run
    // time instead of being resolved at compile time.
    println!("checked_add(10, 20) = {}", checked_add_hot(black_box(10), black_box(20)));
    println!(
        "checked_add(MAX, 1) = {}",
        checked_add_hot(black_box(u64::MAX), black_box(1))
    );

    // SoA vs AoS. Inputs and results go through `black_box` so neither traversal can
    // be specialized or discarded by the optimizer.
    let soa = PointsSoA::new(N);
    let (sum, ns_soa) = measure_ns(|| black_box(black_box(&soa).sum_x()));
    println!("SoA sum_x = {:.0} in {}ns", sum, ns_soa);

    let aos: Vec<PointAoS> = (0..N)
        .map(|i| PointAoS { x: i as f32, y: i as f32 * 2.0, z: i as f32 * 3.0 })
        .collect();
    let (sum_a, ns_aos) = measure_ns(|| black_box(aos_sum_x(black_box(&aos))));
    println!("AoS sum_x = {:.0} in {}ns", sum_a, ns_aos);
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sum_squares_correct() {
        // 0² + 1² + 2² + 3² = 14
        assert_eq!(sum_squares(4), 14);
        assert_eq!(sum_squares(0), 0);
        // The range is half-open, so `n` itself is never squared.
        assert_eq!(sum_squares(1), 0);
    }

    #[test]
    fn checked_add_no_overflow() {
        assert_eq!(checked_add_hot(3, 4), 7);
        // Landing exactly on u64::MAX is a valid sum, not an overflow.
        assert_eq!(checked_add_hot(u64::MAX - 1, 1), u64::MAX);
    }

    #[test]
    fn checked_add_overflow_returns_max() {
        assert_eq!(checked_add_hot(u64::MAX, 1), u64::MAX);
        assert_eq!(checked_add_hot(u64::MAX, u64::MAX), u64::MAX);
    }

    #[test]
    fn soa_sum_x_correct() {
        let soa = PointsSoA::new(5); // x = [0,1,2,3,4]
        assert_eq!(soa.sum_x(), 10.0);
        // An empty layout sums to zero.
        assert_eq!(PointsSoA::new(0).sum_x(), 0.0);
    }

    #[test]
    fn aos_sum_x_correct() {
        let aos: Vec<PointAoS> = (0..5u32)
            .map(|i| PointAoS { x: i as f32, y: 0.0, z: 0.0 })
            .collect();
        assert_eq!(aos_sum_x(&aos), 10.0);
        assert_eq!(aos_sum_x(&[]), 0.0);
    }

    #[test]
    fn soa_and_aos_agree() {
        // Both layouts hold the same points and sum `x` in the same order, so the
        // f32 results must match exactly.
        let n = 100;
        let soa = PointsSoA::new(n);
        let aos: Vec<PointAoS> = (0..n)
            .map(|i| PointAoS { x: i as f32, y: i as f32 * 2.0, z: i as f32 * 3.0 })
            .collect();
        assert_eq!(soa.sum_x(), aos_sum_x(&aos));
    }
}
(* 733: Profile-Guided Patterns — OCaml *)
(* [black_box x] returns [x] unchanged, routed through [Sys.opaque_identity] so the
   optimizer cannot see through it. This is the OCaml counterpart of Rust's
   [std::hint::black_box]. *)
let black_box (x : 'a) : 'a = Sys.opaque_identity x
(* Array-of-Structs (AoS): one record per point, holding all three coordinates.
   A pass over a single component still visits every record, which gives poor
   cache locality for component-wise operations. *)
type point_aos = {
  x : float;
  y : float;
  z : float;
}
(* Struct-of-Arrays (SoA): one float array per component, so a pass over one
   component reads a single contiguous array. [len] is the number of points to
   process; presumably it equals the length of each array (not enforced here). *)
type points_soa = {
  xs : float array;
  ys : float array;
  zs : float array;
  len : int;
}
(* Sums the [x] field of every point, left to right, starting from 0.0. *)
let sum_x_aos (points : point_aos array) =
  Array.fold_left (fun total { x; _ } -> total +. x) 0.0 points
(* Sums xs.(0) .. xs.(len - 1), left to right, starting from 0.0.
   Only the [xs] array is read, which suits the hardware prefetcher. *)
let sum_x_soa (pts : points_soa) =
  let rec go i acc =
    if i >= pts.len then acc
    else go (i + 1) (acc +. pts.xs.(i))
  in
  go 0 0.0
(* Builds the same n points in both layouts (point i = (i, 2i, 3i)) and prints
   the sum of x computed from each. *)
let () =
  let n = 1000 in
  (* [axis k] is [| 0; k; 2k; ... |] of length [n], as floats. *)
  let axis k = Array.init n (fun i -> float_of_int (i * k)) in
  let aos =
    Array.init n (fun i ->
        let coord k = float_of_int (i * k) in
        { x = coord 1; y = coord 2; z = coord 3 })
  in
  let soa = { xs = axis 1; ys = axis 2; zs = axis 3; len = n } in
  let aos_total = black_box (sum_x_aos (black_box aos)) in
  let soa_total = black_box (sum_x_soa (black_box soa)) in
  Printf.printf "AoS sum_x = %.0f\n" aos_total;
  Printf.printf "SoA sum_x = %.0f\n" soa_total