// 723. Branchless patterns: min/max without branches
//
// Demonstrates CMOV-friendly patterns, bit-twiddling tricks, and
// when to trust LLVM vs when to force branchless.
use std::time::Instant;
// ── Branchless integer min/max ────────────────────────────────────────────────
/// Branchless minimum of two `i64` values, correct over the whole `i64` range.
///
/// The selection mask is derived from the comparison `a < b` (converted to
/// 0 or 1 and negated to 0 or -1) rather than from the sign of `a - b`.
/// That difference wraps when the operands are more than `i64::MAX` apart,
/// which flips its sign and would select the wrong operand.
///
/// Returns `a` when `a < b`, otherwise `b`.
#[inline(always)]
pub fn min_branchless(a: i64, b: i64) -> i64 {
    // 0 when a >= b, -1 (all bits set) when a < b.
    let mask = -i64::from(a < b);
    // XOR-select: b ^ (a ^ b) == a under a full mask, b ^ 0 == b otherwise.
    b ^ ((a ^ b) & mask)
}
/// Branchless maximum of two `i64` values, correct over the whole `i64` range.
///
/// Uses a mask built from the comparison `a < b` instead of the sign of
/// `a - b`, because that difference wraps (and changes sign) when the
/// operands are more than `i64::MAX` apart.
///
/// Returns `b` when `a < b`, otherwise `a`.
#[inline(always)]
pub fn max_branchless(a: i64, b: i64) -> i64 {
    // 0 when a >= b, -1 (all bits set) when a < b.
    let mask = -i64::from(a < b);
    // XOR-select: a ^ (a ^ b) == b under a full mask, a ^ 0 == a otherwise.
    a ^ ((a ^ b) & mask)
}
/// Minimum of `a` and `b` using the standard library's ordering helper.
///
/// This is the plain form to reach for first; the original note says it is
/// expected to lower to a conditional move on x86-64 (not verified here).
/// Prefer it unless profiling shows a specific branch-prediction issue.
#[inline(always)]
pub fn min_idiomatic(a: i64, b: i64) -> i64 {
    std::cmp::min(a, b)
}
/// Maximum of `a` and `b` using the standard library's ordering helper.
#[inline(always)]
pub fn max_idiomatic(a: i64, b: i64) -> i64 {
    std::cmp::max(a, b)
}
// ── Branchless clamp ──────────────────────────────────────────────────────────
/// Clamps `x` into `[lo, hi]` without data-dependent branches.
///
/// Computes `min(hi, max(lo, x))`. Both selections use masks built from
/// comparisons rather than from the sign of a subtraction, so the result is
/// correct for every `i64` input, including values more than `i64::MAX`
/// apart, where a difference would wrap. If `lo > hi` the result is `hi`.
#[inline(always)]
pub fn clamp_branchless(x: i64, lo: i64, hi: i64) -> i64 {
    // Step 1: raised = max(lo, x). Mask is -1 when x must be lifted to lo.
    let below = -i64::from(x < lo);
    let raised = x ^ ((x ^ lo) & below);

    // Step 2: min(hi, raised). Mask is -1 when the value must drop to hi.
    let above = -i64::from(raised > hi);
    raised ^ ((raised ^ hi) & above)
}
// ── Branchless absolute value ─────────────────────────────────────────────────
/// Absolute value computed with a sign mask instead of a comparison.
///
/// `x >> 63` is an arithmetic shift on `i64`, so the mask is 0 for
/// non-negative `x` and -1 (all bits set) for negative `x`. Adding the mask
/// and then XOR-ing with it is the identity for mask 0 and a two's-complement
/// negation for mask -1.
///
/// `i64::MIN` has no positive counterpart; the addition overflows for it.
#[inline(always)]
pub fn abs_branchless(x: i64) -> i64 {
    let sign = x >> 63;
    // For negative x this is x - 1; flipping every bit afterwards gives -x.
    let shifted = x + sign;
    shifted ^ sign
}
// ── Branchless select ─────────────────────────────────────────────────────────
/// Picks `a` when `cond` is 1 and `b` when `cond` is 0, using only bit
/// operations. `cond` must be exactly 0 or 1; any other value produces a
/// partial mask that mixes bits of both operands.
#[inline(always)]
pub fn select(cond: u64, a: i64, b: i64) -> i64 {
    // 0 - 0 = 0 (no bits set); 0 - 1 = -1 (all bits set).
    let mask = 0i64.wrapping_sub(cond as i64);
    // Where a mask bit is set the result takes a's bit, elsewhere b's bit.
    b ^ ((a ^ b) & mask)
}
// ── Saturating arithmetic ─────────────────────────────────────────────────────
/// Adds two `u32` values, clamping the result at `u32::MAX` instead of
/// overflowing. Delegates to the standard library's saturating addition.
#[inline(always)]
pub fn sat_add_u32(a: u32, b: u32) -> u32 {
    u32::saturating_add(a, b)
}
/// Subtracts `b` from `a`, clamping the result at 0 instead of underflowing.
/// Delegates to the standard library's saturating subtraction.
#[inline(always)]
pub fn sat_sub_u32(a: u32, b: u32) -> u32 {
    u32::saturating_sub(a, b)
}
// ── Bit-manipulation helpers ──────────────────────────────────────────────────
/// Rounds `x` up to the nearest power of two (`x` itself if it already is one).
///
/// Inputs 0 and 1 both yield 1. When the result would not fit in a `u64`
/// (that is, `x > 2^63`), the function returns 0, the same value the standard
/// library's `u64::next_power_of_two` wraps to in release builds. A plain
/// `1 << 64` would panic in debug builds and produce 1 in release builds.
#[inline(always)]
pub fn next_power_of_two(x: u64) -> u64 {
    if x <= 1 {
        1
    } else {
        // Number of significant bits in x - 1; reaches 64 only when x > 2^63.
        let shift = u64::BITS - (x - 1).leading_zeros();
        // checked_shl yields None for a shift of 64, which we map to 0.
        1u64.checked_shl(shift).unwrap_or(0)
    }
}
/// Returns the number of bits set to 1 in `x` (population count).
/// The original note says this maps to the POPCNT instruction; that depends
/// on the compilation target and is not verified here.
#[inline(always)]
pub fn popcount(x: u64) -> u32 {
    u64::count_ones(x)
}
/// Returns a value with only the least-significant set bit of `x` kept.
/// Yields 0 when `x` is 0.
#[inline(always)]
pub fn lowest_set_bit(x: u64) -> u64 {
    // !x + 1 is the two's-complement negation of x; it shares exactly the
    // lowest set bit with x.
    let negated = (!x).wrapping_add(1);
    x & negated
}
/// Clears the least-significant set bit of `x`.
///
/// Returns 0 for an input of 0. The decrement uses wrapping arithmetic so
/// that `x == 0` does not trigger an overflow panic in debug builds:
/// `0 & u64::MAX` is simply 0.
#[inline(always)]
pub fn clear_lowest_bit(x: u64) -> u64 {
    // x - 1 flips the lowest set bit and every bit below it; AND-ing with x
    // keeps only the bits above it.
    x & x.wrapping_sub(1)
}
// ── Comparison without branch: bool → integer ─────────────────────────────────
/// Returns 1 when `a < b` and 0 otherwise, by converting the comparison
/// result directly to an integer.
/// The original note says LLVM typically emits a single `setcc` for this;
/// that is not verified here.
#[inline(always)]
pub fn cmp_to_int(a: i32, b: i32) -> u32 {
    u32::from(a < b)
}
// ── Network-order / byte-swap ─────────────────────────────────────────────────
/// Reverses the byte order of `x` unconditionally.
/// The original note says this maps to the BSWAP instruction; that depends
/// on the compilation target and is not verified here.
#[inline(always)]
pub fn bswap32(x: u32) -> u32 {
    // Serialise least-significant byte first, then read the same bytes back
    // most-significant first: the net effect is a full byte reversal.
    u32::from_be_bytes(x.to_le_bytes())
}
// ── main ──────────────────────────────────────────────────────────────────────
fn main() {
// Correctness checks
println!("=== Correctness ===");
println!("min_branchless(3, 5) = {}", min_branchless(3, 5));
println!("min_branchless(7, 2) = {}", min_branchless(7, 2));
println!("max_branchless(3, 5) = {}", max_branchless(3, 5));
println!("abs_branchless(-42) = {}", abs_branchless(-42));
println!("clamp(-5, 0, 100) = {}", clamp_branchless(-5, 0, 100));
println!("clamp(50, 0, 100) = {}", clamp_branchless(50, 0, 100));
println!("clamp(150, 0, 100) = {}", clamp_branchless(150, 0, 100));
println!("select(1, 10, 20) = {}", select(1, 10, 20));
println!("select(0, 10, 20) = {}", select(0, 10, 20));
println!("sat_add(u32::MAX, 1) = {}", sat_add_u32(u32::MAX, 1));
println!("lowest_set_bit(0b110)= 0b{:b}", lowest_set_bit(0b110));
println!("next_pow2(100) = {}", next_power_of_two(100));
println!("popcount(0xFF) = {}", popcount(0xFF));
// Performance comparison
println!("\n=== Performance (10M iters) ===");
const N: usize = 10_000_000;
let data: Vec<i64> = (0..N as i64).map(|i| (i * 1_234_567 + 89) % 10_000).collect();
let t0 = Instant::now();
let r1: i64 = data.iter().copied().fold(i64::MAX, min_idiomatic);
let t1 = t0.elapsed();
let t0 = Instant::now();
let r2: i64 = data.iter().copied().fold(i64::MAX, min_branchless);
let t2 = t0.elapsed();
println!(" min_idiomatic: {:?} result={r1}", t1);
println!(" min_branchless: {:?} result={r2}", t2);
// Absolute value reduction
let t0 = Instant::now();
let s1: i64 = data.iter().map(|&x| x.abs()).sum();
let t3 = t0.elapsed();
let t0 = Instant::now();
let s2: i64 = data.iter().map(|&x| abs_branchless(x)).sum();
let t4 = t0.elapsed();
println!(" abs_std: {:?} result={s1}", t3);
println!(" abs_branchless: {:?} result={s2}", t4);
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot checks for the branchless min and max.
    #[test]
    fn min_max() {
        for (a, b, want) in [(3, 5, 3), (7, 2, 2), (-1, 1, -1)] {
            assert_eq!(min_branchless(a, b), want);
        }
        for (a, b, want) in [(3, 5, 5), (-1, -5, -1)] {
            assert_eq!(max_branchless(a, b), want);
        }
    }

    /// Absolute value on zero, positive, negative and near-minimum inputs.
    #[test]
    fn abs_val() {
        let cases = [(0, 0), (42, 42), (-42, 42), (i64::MIN + 1, i64::MAX)];
        for (input, want) in cases {
            assert_eq!(abs_branchless(input), want);
        }
    }

    /// Clamping below, inside and above the range 0..=100.
    #[test]
    fn clamp() {
        for (x, want) in [(-5, 0), (50, 50), (150, 100)] {
            assert_eq!(clamp_branchless(x, 0, 100), want);
        }
    }

    /// `select` returns the first value for 1 and the second for 0.
    #[test]
    fn select_fn() {
        assert_eq!(select(1, 10, 20), 10);
        assert_eq!(select(0, 10, 20), 20);
    }

    /// Saturating add/sub stop at the type's bounds.
    #[test]
    fn saturating() {
        assert_eq!(sat_add_u32(u32::MAX, 1), u32::MAX);
        assert_eq!(sat_sub_u32(0, 1), 0);
        assert_eq!(sat_add_u32(5, 3), 8);
    }

    /// Bit helpers on small known patterns.
    #[test]
    fn bit_tricks() {
        let pattern = 0b1010;
        assert_eq!(lowest_set_bit(pattern), 0b0010);
        assert_eq!(clear_lowest_bit(pattern), 0b1000);
        for x in [100, 128] {
            assert_eq!(next_power_of_two(x), 128);
        }
        assert_eq!(popcount(0xFF), 8);
    }

    /// Branchless helpers match the standard library on -100..=100.
    #[test]
    fn agrees_with_std() {
        let range = -100i64..=100;
        for a in range.clone() {
            for b in range.clone() {
                assert_eq!(min_branchless(a, b), std::cmp::min(a, b), "min({a},{b})");
                assert_eq!(max_branchless(a, b), std::cmp::max(a, b), "max({a},{b})");
                assert_eq!(abs_branchless(a), i64::abs(a), "abs({a})");
            }
        }
    }
}
(* OCaml: Branchless programming patterns *)
(* --- min/max --- *)
(* OCaml's `min` uses polymorphic comparison which can branch.
For ints, the compiler may or may not emit CMOV. *)
(* Comparison-based minimum: yields [a] when [a < b], otherwise [b]. *)
let min_branch a b =
  match a < b with
  | true -> a
  | false -> b
(* Comparison-based maximum: yields [a] when [a > b], otherwise [b]. *)
let max_branch a b =
  match a > b with
  | true -> a
  | false -> b
(* Branchless integer min, correct over the whole [int] range.
   The mask comes from the comparison [a < b] (0 or 1, turned into 0 or -1)
   rather than from the sign of [a - b]: that difference wraps when the
   operands are far apart, which flips its sign and picks the wrong value.
   This also removes the hard-coded shift by 62, which assumed 63-bit ints.
   Returns [a] when [a < b], otherwise [b]. *)
let min_branchless (a : int) (b : int) =
  (* 0 when a >= b, -1 (all bits set) when a < b *)
  let mask = 0 - Bool.to_int (a < b) in
  (* XOR-select: b lxor (a lxor b) = a under a full mask, b otherwise *)
  b lxor ((a lxor b) land mask)
(* Branchless integer max, correct over the whole [int] range.
   Uses a mask built from the comparison [a < b] instead of the sign of
   [a - b], because that difference wraps when the operands are far apart.
   Returns [b] when [a < b], otherwise [a]. *)
let max_branchless (a : int) (b : int) =
  (* 0 when a >= b, -1 (all bits set) when a < b *)
  let mask = 0 - Bool.to_int (a < b) in
  (* XOR-select: a lxor (a lxor b) = b under a full mask, a otherwise *)
  a lxor ((a lxor b) land mask)
(* --- clamp --- *)
(* Comparison-based clamp: [lo] when [x < lo], else [hi] when [x > hi],
   else [x] unchanged. The lower bound is tested first. *)
let clamp_branch lo hi x =
  match () with
  | () when x < lo -> lo
  | () when x > hi -> hi
  | () -> x
(* Clamp [x] into [lo, hi] using bit masks only.
   Computes min hi (max lo x). Both selections use masks built from
   comparisons rather than from the sign of a subtraction, so the result is
   correct for every [int] input, including values so far apart that a
   difference would wrap. If [lo > hi] the result is [hi]. *)
let clamp_branchless (lo : int) (hi : int) (x : int) =
  (* Step 1: raised = max lo x; mask is -1 when x must be lifted to lo *)
  let below = 0 - Bool.to_int (x < lo) in
  let raised = x lxor ((x lxor lo) land below) in
  (* Step 2: min hi raised; mask is -1 when the value must drop to hi *)
  let above = 0 - Bool.to_int (raised > hi) in
  raised lxor ((raised lxor hi) land above)
(* --- absolute value --- *)
(* Absolute value via a sign mask.
   The shift amount is [Sys.int_size - 1] so the sign bit is replicated
   whatever the platform's [int] width is; a hard-coded 62 is only right
   when ints are 63 bits wide.
   Adding the mask and XOR-ing with it is the identity for mask 0 and a
   two's-complement negation for mask -1. [min_int] has no positive
   counterpart: the addition wraps and [min_int] is returned. *)
let abs_branchless (x : int) =
  (* all-ones if x is negative, all-zeros otherwise *)
  let mask = x asr (Sys.int_size - 1) in
  (x + mask) lxor mask
(* --- select without branch --- *)
(* Pick [a] when [cond] is 1 and [b] when [cond] is 0, using bit operations
   only. [cond] must be exactly 0 or 1; any other value gives a partial mask
   that mixes bits of both operands. *)
let select cond a b =
  (* 0 - 0 = 0 (no bits set); 0 - 1 = -1 (all bits set) *)
  let mask = 0 - cond in
  (* where a mask bit is set the result takes a's bit, elsewhere b's bit *)
  b lxor ((a lxor b) land mask)
(* --- Benchmark --- *)
(* Run [f ()], print [label] with the processor time consumed (as reported
   by [Sys.time]) and the integer result, then return that result. *)
let time_it label f =
  let started = Sys.time () in
  let result = f () in
  let elapsed = Sys.time () -. started in
  Printf.printf "%s: %.6fs result=%d\n" label elapsed result;
  result
(* Program entry: time both min reductions over ten million generated values,
   then run the basic correctness assertions. *)
let () =
  let n = 10_000_000 in
  (* deterministic values in 0..999 *)
  let data = Array.init n (fun i -> (i * 1234567 + 89) mod 1000) in
  ignore
    (time_it "min_branch" (fun () ->
         Array.fold_left min_branch max_int data));
  ignore
    (time_it "min_branchless" (fun () ->
         Array.fold_left min_branchless max_int data));
  (* Basic correctness *)
  assert (min_branchless 3 5 = 3);
  assert (min_branchless 7 2 = 2);
  assert (max_branchless 3 5 = 5);
  assert (abs_branchless (-42) = 42);
  assert (abs_branchless 42 = 42);
  assert (clamp_branchless 0 100 (-5) = 0);
  assert (clamp_branchless 0 100 150 = 100);
  assert (clamp_branchless 0 100 50 = 50);
  assert (select 1 10 20 = 10);
  assert (select 0 10 20 = 20);
  Printf.printf "All assertions passed.\n"