🦀 Functional Rust

723: Branchless Programming

Difficulty: 3 | Level: Expert

Eliminate branch mispredictions by computing results arithmetically instead of conditionally.

The Problem This Solves

Modern CPUs execute instructions speculatively. They guess the outcome of a branch and start executing the predicted path before the condition is resolved. When the guess is wrong (a branch misprediction), the CPU must flush the pipeline and restart on the correct path, costing roughly 10–20 clock cycles per mispredict. In a tight loop over unpredictably distributed data (min/max of random values, clamping arbitrary inputs), mispredictions can dominate runtime.

Branchless code replaces conditional jumps with arithmetic or bitwise operations. These operations follow the same instruction path whatever the data, so there is no branch to mispredict. The trade-off: branchless code computes both outcomes and combines them, whereas branchy code speculatively executes only the predicted one. When the branch is truly unpredictable (a 50/50 split), branchless usually wins. When one outcome dominates (say 99% of values are positive), the predictor is almost always right, and the branchy version is typically as fast or faster.
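The sketch below is a toy model, not a measurement. It makes the predictable-versus-unpredictable distinction concrete using a 2-bit saturating-counter predictor, the classic textbook scheme. Real CPUs use far more elaborate history-based predictors, so the numbers are illustrative only. `TwoBitPredictor`, `count_mispredicts` and `lcg_bits` are hypothetical helpers written for this sketch. The "random" outcomes come from a simple 64-bit LCG rather than real data.

```rust
/// Toy 2-bit saturating-counter branch predictor (textbook model only).
/// States 0 and 1 predict "not taken"; states 2 and 3 predict "taken".
struct TwoBitPredictor {
    state: u8,
}

impl TwoBitPredictor {
    fn new() -> Self {
        TwoBitPredictor { state: 2 } // start in "weakly taken"
    }

    fn predict(&self) -> bool {
        self.state >= 2
    }

    /// Move one step towards the actual outcome, saturating at 0 and 3.
    fn update(&mut self, taken: bool) {
        if taken {
            self.state = (self.state + 1).min(3);
        } else {
            self.state = self.state.saturating_sub(1);
        }
    }
}

/// Count how often the toy predictor guesses a sequence of outcomes wrongly.
fn count_mispredicts(outcomes: &[bool]) -> usize {
    let mut predictor = TwoBitPredictor::new();
    let mut misses = 0;
    for &taken in outcomes {
        if predictor.predict() != taken {
            misses += 1;
        }
        predictor.update(taken);
    }
    misses
}

/// Reproducible pseudo-random outcomes from a 64-bit LCG (Knuth's MMIX
/// constants); the top bit of each new state is used as the branch outcome.
fn lcg_bits(seed: u64, n: usize) -> Vec<bool> {
    let mut s = seed;
    (0..n)
        .map(|_| {
            s = s
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            s >> 63 == 1
        })
        .collect()
}

fn main() {
    let n = 100_000;
    // Skewed: the branch goes the rare way once every 100 iterations (99% taken).
    let skewed: Vec<bool> = (0..n).map(|i| i % 100 != 0).collect();
    // Unpredictable: roughly 50/50 with no pattern the counter can learn.
    let random = lcg_bits(42, n);

    println!("skewed: {} mispredicts out of {n}", count_mispredicts(&skewed));
    println!("random: {} mispredicts out of {n}", count_mispredicts(&random));
}
```

In this model the skewed sequence is mispredicted on about 1% of iterations, which is close to the best any predictor can do here. The random sequence is mispredicted on roughly half of them. That 50% miss rate on random data is the cost branchless code avoids.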

The Intuition

A conditional `if a < b { a } else { b }` can compile to a conditional jump (`jl`, `jg`) in machine code. A branchless version instead turns the comparison into a bitmask and combines both values arithmetically. There is no jump, no speculation and no pipeline flush.

The key insight: an arithmetic right shift of a signed integer copies the sign bit into every bit position. For `i64`, `(a - b) >> 63` is therefore `0` when `a >= b` and `-1` (all bits set, `0xFFFF_FFFF_FFFF_FFFF`) when `a < b`, provided `a - b` does not overflow. AND-ing a value with that mask either keeps it (mask = -1) or zeroes it (mask = 0). This is how branchless min/max works.

The more important insight: LLVM already does this for you. In optimised builds, `a.min(b)` on integers typically compiles to a `cmov` (conditional move) instruction on x86-64, so there is no branch and no misprediction risk. Write idiomatic Rust first and measure. Only reach for explicit branchless tricks if profiling reveals a hot branch that LLVM isn't optimising.
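Here is a minimal sketch of the sign-mask step on its own, including the overflow case the trick cannot handle. `sign_mask` is a hypothetical helper name used only here.

```rust
/// Arithmetic right shift by 63 copies the sign bit of an i64 into all 64 bits:
/// non-negative -> 0, negative -> -1 (all bits set).
fn sign_mask(x: i64) -> i64 {
    x >> 63
}

fn main() {
    for (a, b) in [(3i64, 5i64), (7, 2), (4, 4)] {
        let diff = a.wrapping_sub(b);
        // Print the mask as unsigned hex so the all-ones pattern is visible.
        println!("a={a} b={b} diff={diff} mask={:#x}", sign_mask(diff) as u64);
    }

    // Overflow breaks the trick: i64::MIN - 1 is mathematically negative, but
    // the wrapped difference is i64::MAX, so the mask wrongly reports a >= b.
    let wrapped = i64::MIN.wrapping_sub(1);
    println!("wrapped diff={wrapped} mask={}", sign_mask(wrapped));
}
```

For `(3, 5)` the mask prints as `0xffffffffffffffff`, and for `(7, 2)` and `(4, 4)` it prints as `0x0`. The last line shows why the subtraction-based min/max needs inputs whose difference fits in an `i64`.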

How It Works in Rust

// ── Explicit branchless (pedagogical) ────────────────────────────────────
// Assumes `a - b` does not overflow i64 (e.g. both inputs strictly between
// -2^62 and 2^62); a wrapped difference has the wrong sign.
#[inline(always)]
pub fn min_branchless(a: i64, b: i64) -> i64 {
    let diff = a.wrapping_sub(b);
    let mask = diff >> 63; // 0 or all-bits-set (0xFFFF..FFFF)
    b + (diff & mask)      // b + (a - b) if a < b, else b + 0
}

// ── LLVM-idiomatic (prefer this; typically compiles to CMOV) ─────────────
#[inline(always)]
pub fn min_idiomatic(a: i64, b: i64) -> i64 { a.min(b) }

// ── Branchless absolute value ────────────────────────────────────────────
// Like `i64::abs`, this overflows for i64::MIN (panics in debug builds).
#[inline(always)]
pub fn abs_branchless(x: i64) -> i64 {
    let mask = x >> 63;   // 0 if non-negative, -1 if negative
    (x + mask) ^ mask     // two's complement negate when negative
}

// ── Branchless select: choose a or b based on a 0/1 condition ────────────
/// `cond` must be exactly 0 or 1; any other value blends bits of `a` and `b`.
#[inline(always)]
pub fn select(cond: u64, a: i64, b: i64) -> i64 {
    let mask = (cond as i64).wrapping_neg(); // 0 → 0, 1 → -1
    (a & mask) | (b & !mask)
}

// ── When to profile before optimising ────────────────────────────────────
// Branchy: fast if the data is predictable (e.g. 99% positive values).
// LLVM may auto-vectorise or if-convert either loop, so measure both.
fn sum_positive_branchy(data: &[i64]) -> i64 {
    data.iter().filter(|&&x| x > 0).sum()
}
// Branchless: fast if the data is unpredictable (random signs).
// `-((x > 0) as i64)` is -1 (all ones) for positive x, else 0. The inner
// parentheses are required: unary minus binds tighter than `as`.
fn sum_positive_branchless(data: &[i64]) -> i64 {
    data.iter().map(|&x| x & -((x > 0) as i64)).sum()
}
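Each of the three snippets above carries a precondition. `min_branchless` needs `a - b` not to overflow, `abs_branchless` overflows for `i64::MIN`, and `select` needs `cond` to be exactly 0 or 1. The sketch below shows one way to lift each restriction while keeping the source free of `if`. The function names are hypothetical. Whether LLVM keeps them jump-free (typically via `setcc`/`cmov`) is worth confirming in the generated assembly for your target.

```rust
/// Branchless min that is correct for all i64 pairs: the mask comes from the
/// comparison itself (true -> -1, false -> 0), so no subtraction can overflow.
#[inline]
pub fn min_mask_safe(a: i64, b: i64) -> i64 {
    let mask = -((a < b) as i64);
    (a & mask) | (b & !mask)
}

/// Branchless |x| returned as u64, so i64::MIN (magnitude 2^63) is
/// representable. Wrapping ops keep debug builds from panicking.
#[inline]
pub fn abs_unsigned_branchless(x: i64) -> u64 {
    let mask = x >> 63; // 0 if non-negative, -1 if negative
    (x.wrapping_add(mask) ^ mask) as u64
}

/// Select driven by a `bool`: `bool as i64` is always exactly 0 or 1,
/// so the 0-or-1 precondition of `select` holds by construction.
#[inline]
pub fn select_bool(cond: bool, a: i64, b: i64) -> i64 {
    let mask = (cond as i64).wrapping_neg(); // false -> 0, true -> -1
    (a & mask) | (b & !mask)
}

fn main() {
    // Inputs that defeat the subtraction-based versions.
    println!("{}", min_mask_safe(i64::MAX, -1)); // a - b would overflow here
    println!("{}", abs_unsigned_branchless(i64::MIN));
    println!("{}", select_bool(3 > 2, 10, 20));
}
```

The comparison-based mask costs one extra instruction over the shift trick in exchange for correctness on every input. That is usually the better trade unless the input range is known to be small.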

What This Unlocks

Key Differences

Concept | OCaml | Rust
Integer min/max | `min a b` (may branch) | `a.min(b)` → `CMOV` (typically branchless in release builds)
Explicit bit mask | `land` / `asr` on `int` (or `Int64.logand` / `Int64.shift_right`) | `wrapping_sub`, `>>`, `&` on primitives
Absolute value | `abs x` | `x.abs()` (typically branchless after LLVM optimisation)
Constant-time select | Manual bitmask | Manual bitmask (no stdlib primitive)
Compiler optimisation | Reasonably good | LLVM often emits `CMOV` for simple if/else on integers

// 723. Branchless patterns: min/max without branches
//
// Demonstrates CMOV-friendly patterns, bit-twiddling tricks, and
// when to trust LLVM vs when to force branchless.

use std::time::Instant;

// ── Branchless integer min/max ───────────────────────────────────────────

/// Branchless min using an arithmetic right-shift mask.
/// On i64, (a - b) >> 63 is 0 or -1 (all bits set).
/// Assumes `a - b` does not overflow; otherwise the mask has the wrong sign
/// and the final addition can overflow (a panic in debug builds).
#[inline(always)]
pub fn min_branchless(a: i64, b: i64) -> i64 {
    let diff = a.wrapping_sub(b);
    let mask = diff >> 63; // arithmetic shift: 0 or -1
    b + (diff & mask)
}

/// Branchless max; same no-overflow assumption as `min_branchless`.
#[inline(always)]
pub fn max_branchless(a: i64, b: i64) -> i64 {
    let diff = a.wrapping_sub(b);
    let mask = diff >> 63;
    a - (diff & mask)
}

/// LLVM-friendly idiomatic version: typically compiles to CMOV on x86-64.
/// Prefer this unless profiling shows a specific branch-prediction issue.
#[inline(always)]
pub fn min_idiomatic(a: i64, b: i64) -> i64 { a.min(b) }

#[inline(always)]
pub fn max_idiomatic(a: i64, b: i64) -> i64 { a.max(b) }

// ── Branchless clamp ─────────────────────────────────────────────────────

/// Clamp `x` into `[lo, hi]`. Assumes `lo <= hi` and inherits the
/// no-overflow assumption of `min_branchless` / `max_branchless`.
#[inline(always)]
pub fn clamp_branchless(x: i64, lo: i64, hi: i64) -> i64 {
    min_branchless(hi, max_branchless(lo, x))
}

// ── Branchless absolute value ────────────────────────────────────────────

/// Like `i64::abs`, this overflows for i64::MIN (panics in debug builds).
#[inline(always)]
pub fn abs_branchless(x: i64) -> i64 {
    let mask = x >> 63; // 0 if non-negative, -1 (0xFFFF…) if negative
    (x + mask) ^ mask   // two's complement negate when negative
}

// ── Branchless select ────────────────────────────────────────────────────

/// Select `a` if `cond` is 1, else `b`. `cond` must be exactly 0 or 1.
#[inline(always)]
pub fn select(cond: u64, a: i64, b: i64) -> i64 {
    let mask = (cond as i64).wrapping_neg(); // 0 → 0, 1 → -1
    (a & mask) | (b & !mask)
}

// ── Saturating arithmetic ────────────────────────────────────────────────

/// Add with saturation (clamps at u32::MAX instead of overflowing);
/// uses an LLVM intrinsic via std.
#[inline(always)]
pub fn sat_add_u32(a: u32, b: u32) -> u32 {
    a.saturating_add(b)
}

#[inline(always)]
pub fn sat_sub_u32(a: u32, b: u32) -> u32 {
    a.saturating_sub(b)
}

// ── Bit-manipulation helpers ─────────────────────────────────────────────

/// Round `x` up to the nearest power of two (returns `x` if it already is
/// one; 0 and 1 both map to 1). `saturating_sub` replaces an `x <= 1` branch.
/// For x > 2^63 the shift amount reaches 64, which panics in debug builds;
/// std's `u64::checked_next_power_of_two` handles that case.
#[inline(always)]
pub fn next_power_of_two(x: u64) -> u64 {
    1u64 << (64 - x.saturating_sub(1).leading_zeros())
}

/// Count the number of set bits (population count). This maps to a single
/// POPCNT instruction when the target enables it (e.g. `-C target-cpu=native`).
#[inline(always)]
pub fn popcount(x: u64) -> u32 { x.count_ones() }

/// Isolate the lowest set bit.
#[inline(always)]
pub fn lowest_set_bit(x: u64) -> u64 { x & x.wrapping_neg() }

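/// Clear the lowest set bit. `wrapping_sub` keeps x = 0 from hitting the
/// debug-mode underflow panic that `x - 1` would cause (0 maps to 0).
#[inline(always)]
pub fn clear_lowest_bit(x: u64) -> u64 { x & x.wrapping_sub(1) }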

// ── Comparison without branch: bool → integer ────────────────────────────

/// Convert a comparison result to 0/1 without a branch.
/// LLVM typically emits `setcc` (single instruction).
#[inline(always)]
pub fn cmp_to_int(a: i32, b: i32) -> u32 {
    (a < b) as u32
}

// ── Network-order / byte-swap ────────────────────────────────────────────

/// Reverse the byte order unconditionally (maps to BSWAP on x86-64). On a
/// little-endian host this converts host ↔ big-endian; `u32::to_be` and
/// `u32::from_be` do that portably.
#[inline(always)]
pub fn bswap32(x: u32) -> u32 { x.swap_bytes() }

// ── main ─────────────────────────────────────────────────────────────────

fn main() {
    // Correctness checks
    println!("=== Correctness ===");
    println!("min_branchless(3, 5) = {}", min_branchless(3, 5));
    println!("min_branchless(7, 2) = {}", min_branchless(7, 2));
    println!("max_branchless(3, 5) = {}", max_branchless(3, 5));
    println!("abs_branchless(-42)  = {}", abs_branchless(-42));
    println!("clamp(-5, 0, 100)    = {}", clamp_branchless(-5, 0, 100));
    println!("clamp(50, 0, 100)    = {}", clamp_branchless(50, 0, 100));
    println!("clamp(150, 0, 100)   = {}", clamp_branchless(150, 0, 100));
    println!("select(1, 10, 20)    = {}", select(1, 10, 20));
    println!("select(0, 10, 20)    = {}", select(0, 10, 20));
    println!("sat_add(u32::MAX, 1) = {}", sat_add_u32(u32::MAX, 1));
    println!("lowest_set_bit(0b110)= 0b{:b}", lowest_set_bit(0b110));
    println!("next_pow2(100)       = {}", next_power_of_two(100));
    println!("popcount(0xFF)       = {}", popcount(0xFF));

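    // Performance comparison. Only meaningful in a release build
    // (`cargo run --release`); single-run timings are noisy, so treat
    // them as indicative only.
    println!("\n=== Performance (10M iters) ===");
    const N: usize = 10_000_000;
    // Deterministic scrambled values in -5_000..5_000, so abs() sees both signs.
    let data: Vec<i64> = (0..N as i64).map(|i| (i * 1_234_567 + 89) % 10_000 - 5_000).collect();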

    let t0 = Instant::now();
    let r1: i64 = data.iter().copied().fold(i64::MAX, min_idiomatic);
    let t1 = t0.elapsed();

    let t0 = Instant::now();
    let r2: i64 = data.iter().copied().fold(i64::MAX, min_branchless);
    let t2 = t0.elapsed();

    println!("  min_idiomatic:  {:?}  result={r1}", t1);
    println!("  min_branchless: {:?}  result={r2}", t2);

    // Absolute value reduction
    let t0 = Instant::now();
    let s1: i64 = data.iter().map(|&x| x.abs()).sum();
    let t3 = t0.elapsed();

    let t0 = Instant::now();
    let s2: i64 = data.iter().map(|&x| abs_branchless(x)).sum();
    let t4 = t0.elapsed();

    println!("  abs_std:        {:?}  result={s1}", t3);
    println!("  abs_branchless: {:?}  result={s2}", t4);
}

// ── Tests ────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn min_max() {
        assert_eq!(min_branchless(3, 5), 3);
        assert_eq!(min_branchless(7, 2), 2);
        assert_eq!(min_branchless(-1, 1), -1);
        assert_eq!(max_branchless(3, 5), 5);
        assert_eq!(max_branchless(-1, -5), -1);
    }

    #[test]
    fn abs_val() {
        assert_eq!(abs_branchless(0), 0);
        assert_eq!(abs_branchless(42), 42);
        assert_eq!(abs_branchless(-42), 42);
        assert_eq!(abs_branchless(i64::MIN + 1), i64::MAX);
    }

    #[test]
    fn clamp() {
        assert_eq!(clamp_branchless(-5, 0, 100), 0);
        assert_eq!(clamp_branchless(50, 0, 100), 50);
        assert_eq!(clamp_branchless(150, 0, 100), 100);
    }

    #[test]
    fn select_fn() {
        assert_eq!(select(1, 10, 20), 10);
        assert_eq!(select(0, 10, 20), 20);
    }

    #[test]
    fn saturating() {
        assert_eq!(sat_add_u32(u32::MAX, 1), u32::MAX);
        assert_eq!(sat_sub_u32(0, 1), 0);
        assert_eq!(sat_add_u32(5, 3), 8);
    }

    #[test]
    fn bit_tricks() {
        assert_eq!(lowest_set_bit(0b1010), 0b0010);
        assert_eq!(clear_lowest_bit(0b1010), 0b1000);
        assert_eq!(next_power_of_two(100), 128);
        assert_eq!(next_power_of_two(128), 128);
        assert_eq!(popcount(0xFF), 8);
    }

    #[test]
    fn agrees_with_std() {
        for a in -100i64..=100 {
            for b in -100i64..=100 {
                assert_eq!(min_branchless(a, b), a.min(b), "min({a},{b})");
                assert_eq!(max_branchless(a, b), a.max(b), "max({a},{b})");
                assert_eq!(abs_branchless(a), a.abs(), "abs({a})");
            }
        }
    }
}
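`sat_add_u32` and `sat_sub_u32` in the listing delegate to std, which is the right default. The hand-written sketch below shows what branchless saturation looks like at the bit level, turning the overflow flag into a mask. `sat_add_manual` and `sat_sub_manual` are hypothetical names. `main` only checks them against std on a handful of edge values.

```rust
/// Saturating u32 add without a branch: on overflow the flag becomes an
/// all-ones mask, and OR-ing it into the wrapped sum forces u32::MAX.
#[inline]
pub fn sat_add_manual(a: u32, b: u32) -> u32 {
    let (sum, overflowed) = a.overflowing_add(b);
    sum | (overflowed as u32).wrapping_neg()
}

/// Saturating u32 subtract: on underflow the mask is 0, forcing the result to 0;
/// otherwise the mask is all ones and the wrapped difference passes through.
#[inline]
pub fn sat_sub_manual(a: u32, b: u32) -> u32 {
    let (diff, underflowed) = a.overflowing_sub(b);
    diff & ((!underflowed) as u32).wrapping_neg()
}

fn main() {
    let samples = [0u32, 1, 2, 1_000, u32::MAX / 2, u32::MAX - 1, u32::MAX];
    for &a in &samples {
        for &b in &samples {
            assert_eq!(sat_add_manual(a, b), a.saturating_add(b));
            assert_eq!(sat_sub_manual(a, b), a.saturating_sub(b));
        }
    }
    println!("manual saturating ops agree with std on {} pairs", samples.len() * samples.len());
}
```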
(* OCaml: Branchless programming patterns *)

(* --- min/max --- *)
(* OCaml's `min` uses polymorphic comparison which can branch.
   For ints, the compiler may or may not emit CMOV. *)

let min_branch a b = if a < b then a else b
let max_branch a b = if a > b then a else b

(* Branchless integer min using arithmetic *)
let min_branchless (a : int) (b : int) =
  (* b + ((a - b) land ((a - b) asr 62)). OCaml's native int is 63 bits wide,
     so asr 62 smears the sign bit; assumes a - b does not overflow. *)
  let diff = a - b in
  b + (diff land (diff asr 62))

let max_branchless (a : int) (b : int) =
  let diff = a - b in
  a - (diff land (diff asr 62))

(* --- clamp --- *)
let clamp_branch lo hi x =
  if x < lo then lo else if x > hi then hi else x

let clamp_branchless lo hi x =
  min_branchless hi (max_branchless lo x)

(* --- absolute value --- *)
let abs_branchless (x : int) =
  let mask = x asr 62 in  (* all-ones if negative, all-zeros if non-negative *)
  (x + mask) lxor mask

(* --- select without branch --- *)
(* Select a if cond else b, where cond is 0 or 1 *)
let select cond a b =
  (* Equivalent arithmetic form: (a - b) * cond + b; the mask below avoids the multiply *)
  let mask = -cond in   (* 0 -> 0, 1 -> -1 (all-ones) *)
  (a land mask) lor (b land (lnot mask))

(* --- Benchmark --- *)
let time_it label f =
  let t0 = Sys.time () in
  let r = f () in
  Printf.printf "%s: %.6fs result=%d\n" label (Sys.time () -. t0) r;
  r

let () =
  let n = 10_000_000 in
  let data = Array.init n (fun i -> (i * 1234567 + 89) mod 1000) in

  let _r1 = time_it "min_branch" (fun () ->
    Array.fold_left (fun acc x -> min_branch acc x) max_int data) in
  let _r2 = time_it "min_branchless" (fun () ->
    Array.fold_left (fun acc x -> min_branchless acc x) max_int data) in

  (* Basic correctness *)
  assert (min_branchless 3 5 = 3);
  assert (min_branchless 7 2 = 2);
  assert (max_branchless 3 5 = 5);
  assert (abs_branchless (-42) = 42);
  assert (abs_branchless 42 = 42);
  assert (clamp_branchless 0 100 (-5) = 0);
  assert (clamp_branchless 0 100 150 = 100);
  assert (clamp_branchless 0 100 50 = 50);
  assert (select 1 10 20 = 10);
  assert (select 0 10 20 = 20);
  Printf.printf "All assertions passed.\n"
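The `lowest_set_bit` and `clear_lowest_bit` helpers in the Rust listing combine into a common idiom: visiting only the set bits of a mask instead of testing all 64 positions. Below is a minimal sketch with a hypothetical `set_bit_indices` helper. The loop still branches once per set bit, so it saves work on sparse masks rather than removing branches entirely.

```rust
/// Return the positions of the set bits of `x`, lowest first.
/// Each iteration isolates the lowest set bit, records its index, then clears it.
fn set_bit_indices(mut x: u64) -> Vec<u32> {
    let mut out = Vec::with_capacity(x.count_ones() as usize);
    while x != 0 {
        let low = x & x.wrapping_neg(); // isolate the lowest set bit
        out.push(low.trailing_zeros()); // its bit position
        x &= x - 1; // clear it (x != 0 here, so no underflow)
    }
    out
}

fn main() {
    println!("{:?}", set_bit_indices(0b1010_0110)); // [1, 2, 5, 7]
}
```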