🦀 Functional Rust

831: Miller-Rabin Probabilistic Primality Test

Difficulty: 4 · Level: Advanced · Deterministic primality testing for all 64-bit integers using 12 fixed witnesses. In its randomized form, Miller-Rabin is also the primality test behind RSA key generation in major cryptographic libraries.

The Problem This Solves

Trial division checks primality in O(√n) time. For a 64-bit prime near 2^63, that is about 3 billion divisions, which is unacceptably slow. A sieve is fast but needs O(n) memory, which is impractical for numbers near 2^63. Miller-Rabin needs only O(k log n) modular multiplications, where k is the number of witnesses, so its cost grows with the number of digits of n rather than with n itself.

In the probabilistic version, a randomly chosen base is fooled by a given odd composite with probability at most 1/4. So k independent random bases are all fooled with probability at most (1/4)^k; for k = 12 that is (1/4)^12 ≈ 6 × 10^-8. The deterministic version is stronger still. With the fixed witness set `{2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}`, the test is provably correct for all n < 3.18 × 10^23, which covers every 64-bit integer (2^64 ≈ 1.8 × 10^19) with certainty.

RSA primes are hundreds of digits long, far outside the range the known fixed witness sets cover, so RSA implementations run Miller-Rabin with randomly chosen bases instead. That randomized form is what makes RSA's prime generation work: generate a random odd number of the target size, run a few Miller-Rabin rounds, and repeat until one passes. By the prime number theorem, roughly 1 in ln(n) integers near n is prime. Near 2^2048 that is about 1 in 1420, so drawing only odd candidates takes roughly 710 tries on average.
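The generate-and-test loop described above can be sketched in a few lines of Rust. This is a sketch under stated assumptions.

- It targets primes of at most 64 bits, so the deterministic 12-witness test applies. That is far smaller than RSA sizes.
- `xorshift64` and `random_prime` are hypothetical helpers for illustration.
- `xorshift64` is a toy stand-in for the cryptographically secure generator that real key generation requires.

```rust
// Sketch of random prime generation: draw a random odd candidate of the
// target size, test it, and repeat until one passes.
// ASSUMPTION: `xorshift64` is a hypothetical toy PRNG used only to keep the
// example std-only. It is NOT cryptographically secure.

fn mulmod(a: u64, b: u64, m: u64) -> u64 {
    (a as u128 * b as u128 % m as u128) as u64
}

fn pow_mod(mut base: u64, mut exp: u64, m: u64) -> u64 {
    let mut result = 1u64;
    base %= m;
    while exp > 0 {
        if exp & 1 == 1 {
            result = mulmod(result, base, m);
        }
        base = mulmod(base, base, m);
        exp >>= 1;
    }
    result
}

/// Deterministic Miller-Rabin for u64 with the 12 witnesses 2..=37.
fn is_prime(n: u64) -> bool {
    const WITNESSES: [u64; 12] = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37];
    if n < 2 {
        return false;
    }
    // Trial division by the witnesses themselves; afterwards n > 37.
    for &p in &WITNESSES {
        if n % p == 0 {
            return n == p;
        }
    }
    let s = (n - 1).trailing_zeros();
    let d = (n - 1) >> s;
    WITNESSES.iter().all(|&a| {
        let mut x = pow_mod(a, d, n);
        if x == 1 || x == n - 1 {
            return true;
        }
        for _ in 1..s {
            x = mulmod(x, x, n);
            if x == n - 1 {
                return true;
            }
        }
        false
    })
}

/// Hypothetical toy PRNG (xorshift64); `state` must start nonzero.
fn xorshift64(state: &mut u64) -> u64 {
    let mut x = *state;
    x ^= x << 13;
    x ^= x >> 7;
    x ^= x << 17;
    *state = x;
    x
}

/// Returns a random `bits`-bit prime and the number of odd candidates tried.
fn random_prime(bits: u32, state: &mut u64) -> (u64, u32) {
    assert!((2..=64).contains(&bits), "bits must be in 2..=64");
    let mut tries = 0;
    loop {
        tries += 1;
        let mut c = xorshift64(state);
        if bits < 64 {
            c &= (1u64 << bits) - 1; // keep only the low `bits` bits
        }
        c |= 1u64 << (bits - 1); // set the top bit so c has exactly `bits` bits
        c |= 1; // make the candidate odd
        if is_prime(c) {
            return (c, tries);
        }
    }
}

fn main() {
    let mut state = 0x9E37_79B9_7F4A_7C15u64; // arbitrary nonzero seed
    let (p, tries) = random_prime(32, &mut state);
    println!("found 32-bit prime {p} after {tries} odd candidates");
}
```

For 32-bit candidates, the prime number theorem suggests about ln(2^32) / 2 ≈ 11 odd candidates per prime on average. The exact count depends on the seed.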

The Intuition

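Factor `n - 1 = 2^s × d` where d is odd. By Fermat's little theorem, if n is prime then `a^(n-1) ≡ 1 (mod n)` for every a not divisible by n. Miller-Rabin checks something stronger.

Consider the sequence `a^d, a^(2d), a^(4d), …, a^(2^s × d)` mod n. Each term is the square of the previous one, and for prime n the last term is `a^(n-1) ≡ 1`. Modulo a prime, the only square roots of 1 are 1 and -1 (= n - 1). So the sequence must either start at 1 or hit n - 1 somewhere before the end.

For any odd composite n, at least 3/4 of the bases 1 ≤ a < n violate this condition. Every additional independent witness therefore cuts the chance of a false "prime" verdict by another factor of 4. `trailing_zeros()` is the idiomatic way to factor the powers of 2 out of `n - 1`; it replaces a loop that repeatedly tests `d % 2 == 0` and halves d.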
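To make the sequence concrete, this sketch computes it for a prime and for the Carmichael number 561. The helpers `mr_sequence` and `passes_round` are hypothetical names; `mulmod` and `pow_mod` match the functions used in this example.

```rust
// Sketch: compute the sequence a^d, a^(2d), ..., a^(2^s * d) mod n that a
// Miller-Rabin round inspects. `mr_sequence` and `passes_round` are
// hypothetical helper names used only for illustration.

fn mulmod(a: u64, b: u64, m: u64) -> u64 {
    (a as u128 * b as u128 % m as u128) as u64
}

fn pow_mod(mut base: u64, mut exp: u64, m: u64) -> u64 {
    let mut result = 1u64;
    base %= m;
    while exp > 0 {
        if exp & 1 == 1 {
            result = mulmod(result, base, m);
        }
        base = mulmod(base, base, m);
        exp >>= 1;
    }
    result
}

/// Returns (s, d, seq) with n - 1 = 2^s * d, d odd, and seq[i] = a^(2^i * d) mod n.
fn mr_sequence(n: u64, a: u64) -> (u32, u64, Vec<u64>) {
    assert!(n > 2 && n % 2 == 1, "n must be odd and greater than 2");
    let s = (n - 1).trailing_zeros(); // number of factors of 2 in n - 1
    let d = (n - 1) >> s; // odd part of n - 1
    let mut x = pow_mod(a, d, n);
    let mut seq = vec![x];
    for _ in 0..s {
        x = mulmod(x, x, n); // each term is the square of the previous one
        seq.push(x);
    }
    (s, d, seq)
}

/// The Miller-Rabin condition: the sequence starts at 1, or n - 1 appears
/// somewhere before the last term (the last term is a^(n-1) mod n).
fn passes_round(n: u64, seq: &[u64]) -> bool {
    seq[0] == 1 || seq[..seq.len() - 1].contains(&(n - 1))
}

fn main() {
    for (n, a) in [(13u64, 2u64), (561, 2)] {
        let (s, d, seq) = mr_sequence(n, a);
        let fermat = *seq.last().unwrap() == 1; // a^(n-1) == 1 (mod n)?
        println!(
            "n = {n}, a = {a}: n - 1 = 2^{s} * {d}, seq = {seq:?}, Fermat passes: {fermat}, Miller-Rabin passes: {}",
            passes_round(n, &seq)
        );
    }
}
```

For n = 13 and base 2, the sequence is [8, 12, 1]; it hits 12 = n - 1 before the end, as a prime must. For 561 and base 2, it is [263, 166, 67, 1, 1]. The last term is 1, so the Fermat test is fooled. However, 1 is reached from 67, a square root of 1 other than ±1, which cannot happen modulo a prime, so the round rejects 561.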

How It Works in Rust

fn mulmod(a: u64, b: u64, m: u64) -> u64 {
 (a as u128 * b as u128 % m as u128) as u64  // u128 prevents overflow
}

fn pow_mod(mut base: u64, mut exp: u64, m: u64) -> u64 {
 let mut result = 1u64;
 base %= m;
 while exp > 0 {
     if exp & 1 == 1 { result = mulmod(result, base, m); }
     base = mulmod(base, base, m);
     exp >>= 1;
 }
 result
}

fn miller_witness(n: u64, d: u64, s: u32, a: u64) -> bool {
 let mut x = pow_mod(a, d, n);
 if x == 1 || x == n - 1 { return true; }   // Passed trivially
 for _ in 1..s {
     x = mulmod(x, x, n);
     if x == n - 1 { return true; }           // Hit -1 in the sequence
 }
 false  // Failed: a proves n composite (this verdict is never wrong)
}

pub fn is_prime(n: u64) -> bool {
 match n {
     0 | 1 => false,
     2 | 3 | 5 | 7 => true,
     _ if n % 2 == 0 || n % 3 == 0 => false,
     _ => {
         // Factor n-1 = 2^s * d with d odd
         let mut d = n - 1;
         let s = d.trailing_zeros();    // Count factors of 2
         d >>= s;

         // 12 witnesses: deterministic for all n < 3.18 × 10^23 (covers all of u64)
         const WITNESSES: &[u64] = &[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37];
         WITNESSES.iter().all(|&a| a >= n || miller_witness(n, d, s, a))
     }
 }
}
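The `a >= n` guard skips witnesses that are not smaller than n. It matters for primes n from 11 to 37, since 2, 3, 5 and 7 are already handled by the `match`. With a = n, `a^d mod n` is 0, so the round would wrongly report a prime such as 11 as composite. Skipping those bases is safe because base 2 alone is already deterministic for every n < 2047.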
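To see why the guard is needed, this sketch runs one round for the prime 11 with base 2 and with base 11. It copies `mulmod`, `pow_mod` and `miller_witness` from the example so that it is self-contained.

```rust
// Sketch: why witnesses a >= n are skipped. For a prime n in the witness
// list (11, 13, ..., 37), the base a = n reduces to 0 mod n, so the round
// would wrongly report the prime as composite.

fn mulmod(a: u64, b: u64, m: u64) -> u64 {
    (a as u128 * b as u128 % m as u128) as u64
}

fn pow_mod(mut base: u64, mut exp: u64, m: u64) -> u64 {
    let mut result = 1u64;
    base %= m;
    while exp > 0 {
        if exp & 1 == 1 {
            result = mulmod(result, base, m);
        }
        base = mulmod(base, base, m);
        exp >>= 1;
    }
    result
}

/// One Miller-Rabin round: false means `a` proves n composite.
fn miller_witness(n: u64, d: u64, s: u32, a: u64) -> bool {
    let mut x = pow_mod(a, d, n);
    if x == 1 || x == n - 1 {
        return true;
    }
    for _ in 1..s {
        x = mulmod(x, x, n);
        if x == n - 1 {
            return true;
        }
    }
    false
}

fn main() {
    // n = 11: n - 1 = 10 = 2^1 * 5, so s = 1 and d = 5.
    let (n, d, s) = (11u64, 5u64, 1u32);
    println!("base 2:  {}", miller_witness(n, d, s, 2)); // true: 2^5 = 32 ≡ 10 = n - 1
    println!("base 11: {}", miller_witness(n, d, s, 11)); // false: 11 ≡ 0, so x = 0
}
```

Without the guard, `is_prime(11)` would return `false` because of the base-11 round.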

What This Unlocks

Key Differences

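| Concept | OCaml | Rust |
|---|---|---|
| Overflow-safe multiply | `Int64` (no wider than the 63-bit native `int` on 64-bit platforms, so products overflow above about 3 × 10^9) or Zarith | `(a as u128 * b as u128 % m as u128) as u64` |
| Factor 2^s out of n - 1 | Manual loop dividing by 2 | `d.trailing_zeros()`, which maps to hardware bit-counting instructions |
| Witness iteration | `List.for_all (fun a -> ...)` | `WITNESSES.iter().all(...)` with a `&a` closure pattern |
| Early exit on failure | `List.for_all` stops at the first `false`; the inner `for` loop has no `break`, so it tracks a `ref` flag | `all()` short-circuits on `false`; the inner loop `return`s early |
| Guard for small witnesses | Explicit `if a >= n then true` check | `a >= n` tested first inside the `all()` closure |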
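Languages without a wider integer type have a classic alternative to `u128` widening: double-and-add multiplication, the additive counterpart of square-and-multiply. This is a swapped-in technique, not what the example uses. `add_mod`, `mulmod_add` and `mulmod_u128` are hypothetical names, and the sketch checks the two approaches against each other.

```rust
// Sketch: a * b mod m without any type wider than u64, using
// double-and-add (the additive analogue of square-and-multiply).
// `add_mod`, `mulmod_add` and `mulmod_u128` are hypothetical names.

/// (x + y) mod m for x, y < m, without overflowing u64.
fn add_mod(x: u64, y: u64, m: u64) -> u64 {
    // x + y >= m exactly when x >= m - y (and m - y > 0 because y < m).
    if x >= m - y {
        x - (m - y)
    } else {
        x + y
    }
}

/// a * b mod m using only u64 arithmetic; m must be nonzero.
fn mulmod_add(a: u64, mut b: u64, m: u64) -> u64 {
    let mut a = a % m; // invariant: a < m and result < m
    let mut result = 0;
    while b > 0 {
        if b & 1 == 1 {
            result = add_mod(result, a, m); // add the current power-of-two multiple
        }
        a = add_mod(a, a, m); // double: a <- 2a mod m
        b >>= 1;
    }
    result
}

/// Reference version using u128 widening, as in the main example.
fn mulmod_u128(a: u64, b: u64, m: u64) -> u64 {
    (a as u128 * b as u128 % m as u128) as u64
}

fn main() {
    let (a, b, m) = (u64::MAX - 1, u64::MAX - 2, u64::MAX - 58);
    // A plain u64 product overflows for operands this large.
    assert!(a.checked_mul(b).is_none());
    println!("double-and-add: {}", mulmod_add(a, b, m));
    println!("u128 widening:  {}", mulmod_u128(a, b, m));
}
```

The trade-off is speed. `mulmod_add` performs up to 64 modular additions and 64 doublings per product, whereas the `u128` version needs one widening multiply plus one 128-bit remainder.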
//! Miller-Rabin Probabilistic Primality Test.
//!
//! Deterministic for all u64 using witness set {2,3,5,7,11,13,17,19,23,29,31,37}.

fn mulmod(a: u64, b: u64, m: u64) -> u64 {
    (a as u128 * b as u128 % m as u128) as u64
}

fn pow_mod(mut base: u64, mut exp: u64, m: u64) -> u64 {
    let mut result = 1u64;
    base %= m;
    while exp > 0 {
        if exp & 1 == 1 { result = mulmod(result, base, m); }
        base = mulmod(base, base, m);
        exp >>= 1;
    }
    result
}

/// Run one Miller-Rabin round with base `a`, where n - 1 = 2^s * d and d is odd.
/// Returns `false` if `a` proves n composite, `true` if n passes (probably prime).
fn miller_witness(n: u64, d: u64, s: u32, a: u64) -> bool {
    let mut x = pow_mod(a, d, n);
    if x == 1 || x == n - 1 {
        return true;
    }
    for _ in 1..s {
        x = mulmod(x, x, n);
        if x == n - 1 {
            return true;
        }
    }
    false
}

/// Deterministic Miller-Rabin primality test for all u64.
pub fn is_prime(n: u64) -> bool {
    match n {
        0 | 1 => false,
        2 | 3 | 5 | 7 => true,
        _ if n % 2 == 0 || n % 3 == 0 => false,
        _ => {
            // Factor n-1 = 2^s * d where d is odd
            let mut d = n - 1;
            let s = d.trailing_zeros();
            d >>= s;

            // Deterministic witness set for all n < 3.18 × 10^23 (covers all of u64)
            const WITNESSES: &[u64] = &[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37];
            WITNESSES.iter().all(|&a| {
                a >= n || miller_witness(n, d, s, a)
            })
        }
    }
}

/// Generate primes up to limit using Miller-Rabin (educational comparison).
fn primes_up_to(limit: u64) -> Vec<u64> {
    (2..=limit).filter(|&n| is_prime(n)).collect()
}

fn main() {
    let primes    = [2u64, 3, 5, 17, 97, 997, 1_000_003, 999_999_937];
    let composites = [4u64, 9, 15, 100, 1001, 1_000_000];

    for &n in &primes {
        println!("is_prime({n}) = {} (expected true)", is_prime(n));
    }
    for &n in &composites {
        println!("is_prime({n}) = {} (expected false)", is_prime(n));
    }

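    // Carmichael number 561 = 3×11×17 fools the Fermat test; here the n % 3 check
    // rejects it before any Miller-Rabin round runs.
    println!("is_prime(561) = {} (Carmichael, expected false)", is_prime(561));
    // 1105 = 5×13×17 is a Carmichael number that reaches the Miller-Rabin rounds.
    println!("is_prime(1105) = {} (Carmichael, expected false)", is_prime(1105));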

    // Large semiprime: p * q ≈ 10^18 still fits in u64
    let p: u64 = 1_000_000_007;
    let q: u64 = 1_000_000_009;
    println!("is_prime(p={p}) = {}", is_prime(p));
    println!("is_prime(p*q={}) = {}", p * q, is_prime(p * q));

    let small_primes = primes_up_to(50);
    println!("Primes ≤ 50: {small_primes:?}");
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_small_primes() {
        let expected = [2u64, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47];
        for &p in &expected {
            assert!(is_prime(p), "{p} should be prime");
        }
    }

    #[test]
    fn test_composites() {
        for &n in &[0u64, 1, 4, 6, 8, 9, 10, 15, 25, 100] {
            assert!(!is_prime(n), "{n} should be composite");
        }
    }

    #[test]
    fn test_carmichael() {
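        // 561, 1105, 1729 are Carmichael numbers. 561 is divisible by 3 and is
        // rejected before Miller-Rabin runs; 1105 and 1729 exercise the rounds.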
        assert!(!is_prime(561));
        assert!(!is_prime(1105));
        assert!(!is_prime(1729));
    }

    #[test]
    fn test_large_prime() {
        assert!(is_prime(999_999_937));
        assert!(is_prime(1_000_000_007));
        assert!(is_prime(1_000_000_009));
    }

    #[test]
    fn test_semiprime_not_prime() {
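        assert!(!is_prime(1_000_000_007 * 3)); // caught by the n % 3 check
        assert!(!is_prime(1_000_000_007 * 1_000_000_009)); // goes through Miller-Rabin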
    }

    #[test]
    fn test_matches_sieve() {
        // Verify Miller-Rabin matches sieve for n ≤ 1000
        let mut sieve = vec![true; 1001];
        sieve[0] = false;
        sieve[1] = false;
        for i in 2..=31 {
            if sieve[i] {
                let mut j = i * i;
                while j <= 1000 { sieve[j] = false; j += i; }
            }
        }
        for n in 0..=1000u64 {
            assert_eq!(is_prime(n), sieve[n as usize],
                "mismatch at n={n}");
        }
    }
}
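The tests above never exercise a strong pseudoprime, that is, a composite that passes some bases. This sketch, built around a hypothetical `mr_round` helper, shows two classic cases:

- 2047 = 23 × 89 is the smallest strong pseudoprime to base 2.
- 3,215,031,751 = 151 × 751 × 28,351 is the smallest strong pseudoprime to bases 2, 3, 5 and 7.

Cases like these are why the deterministic set needs more than a handful of bases.

```rust
// Sketch: strong pseudoprimes, composites that pass some Miller-Rabin bases.
// `mr_round` and `BASES` are hypothetical names for this illustration.

const BASES: [u64; 12] = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37];

fn mulmod(a: u64, b: u64, m: u64) -> u64 {
    (a as u128 * b as u128 % m as u128) as u64
}

fn pow_mod(mut base: u64, mut exp: u64, m: u64) -> u64 {
    let mut result = 1u64;
    base %= m;
    while exp > 0 {
        if exp & 1 == 1 {
            result = mulmod(result, base, m);
        }
        base = mulmod(base, base, m);
        exp >>= 1;
    }
    result
}

/// One round with base a; n must be odd and greater than a.
/// Returns true if n passes (a "strong liar" when n is composite).
fn mr_round(n: u64, a: u64) -> bool {
    let s = (n - 1).trailing_zeros();
    let d = (n - 1) >> s;
    let mut x = pow_mod(a, d, n);
    if x == 1 || x == n - 1 {
        return true;
    }
    for _ in 1..s {
        x = mulmod(x, x, n);
        if x == n - 1 {
            return true;
        }
    }
    false
}

fn main() {
    // 2047 = 23 * 89; 3_215_031_751 = 151 * 751 * 28_351
    for n in [2047u64, 3_215_031_751] {
        let liars: Vec<u64> = BASES.iter().copied().filter(|&a| mr_round(n, a)).collect();
        let all_pass = BASES.iter().all(|&a| mr_round(n, a));
        println!("{n}: passes bases {liars:?}; passes all 12: {all_pass}");
    }
}
```

Base 3 already exposes 2047, and base 11 exposes 3,215,031,751, so the full 12-witness test rejects both.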
(* Miller-Rabin Primality Test in OCaml *)

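(* Modular multiplication via Int64. Caution: on 64-bit platforms OCaml's
   native int is already 63 bits, so Int64 does not widen the product; it
   overflows once (n-1)^2 exceeds Int64.max_int, i.e. for n above about
   3.04e9. Use Zarith for the full range. *)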
let mulmod a b m =
  Int64.(to_int (rem (mul (of_int a) (of_int b)) (of_int m)))

(* Fast modular exponentiation *)
let rec pow_mod base exp m =
  if exp = 0 then 1
  else if exp land 1 = 1 then mulmod base (pow_mod base (exp-1) m) m
  else
    let half = pow_mod base (exp/2) m in
    mulmod half half m

(* One Miller-Rabin round with base a *)
(* Returns false if a proves n composite, true if n passes (probably prime) *)
let miller_witness (n : int) (d : int) (s : int) (a : int) : bool =
  let x = ref (pow_mod a d n) in
  if !x = 1 || !x = n - 1 then true
  else begin
    let composite = ref true in
    for _ = 1 to s - 1 do
      x := mulmod !x !x n;
      if !x = n - 1 then composite := false
    done;
    not !composite
  end

(* Deterministic Miller-Rabin: the 12 witnesses cover n < 3.18*10^23,
   but mulmod limits this version to n below about 3.04*10^9 *)
let is_prime (n : int) : bool =
  if n < 2 then false
  else if n = 2 || n = 3 || n = 5 || n = 7 then true
  else if n mod 2 = 0 || n mod 3 = 0 then false
  else begin
    (* Factor n-1 = 2^s * d, d odd *)
    let d = ref (n - 1) and s = ref 0 in
    while !d mod 2 = 0 do d := !d / 2; incr s done;
    (* These 12 witnesses are deterministic for n < 3.18*10^23 *)
    let witnesses = [2; 3; 5; 7; 11; 13; 17; 19; 23; 29; 31; 37] in
    List.for_all (fun a ->
      if a >= n then true  (* skip bases >= n; base 2 alone is deterministic for n < 2047 *)
      else miller_witness n !d !s a
    ) witnesses
  end

let () =
  let primes   = [2; 3; 5; 17; 97; 997; 1_000_003; 999_999_937] in
  let composites = [4; 9; 15; 100; 1001; 1_000_000] in
  List.iter (fun n ->
    Printf.printf "is_prime(%d) = %b  (expected true)\n" n (is_prime n)
  ) primes;
  List.iter (fun n ->
    Printf.printf "is_prime(%d) = %b  (expected false)\n" n (is_prime n)
  ) composites;
  (* Carmichael number 561 = 3*11*17 fools the Fermat test; here the n mod 3 check rejects it *)
  Printf.printf "is_prime(561) = %b  (Carmichael, expected false)\n" (is_prime 561)
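The OCaml `mulmod` multiplies two residues in `Int64`, a signed 64-bit type whose arithmetic wraps on overflow. To see where that breaks, this Rust sketch finds the largest r with r × r ≤ `i64::MAX`, using `i64`, which has the same range as OCaml's `Int64`. The helper `isqrt_i64_max` is a hypothetical name.

```rust
// Sketch: find where a signed 64-bit product of two residues stops fitting.
// i64 has the same range as OCaml's Int64; wrapping_mul mirrors its
// wrap-around behaviour.

/// Largest r with r * r <= i64::MAX, by binary search on checked_mul.
fn isqrt_i64_max() -> i64 {
    let (mut lo, mut hi) = (0i64, 4_000_000_000i64); // 4e9 squared overflows i64
    while lo < hi {
        let mid = lo + (hi - lo + 1) / 2; // round up so the loop always progresses
        if mid.checked_mul(mid).is_some() {
            lo = mid;
        } else {
            hi = mid - 1;
        }
    }
    lo
}

fn main() {
    let r = isqrt_i64_max();
    let safe_n = r + 1; // residues are at most n - 1 = r
    println!("an Int64-based mulmod is exact for n <= {safe_n}");
    // One step past the limit, the wrapped product turns negative.
    println!("wrapped product: {}", (r + 1).wrapping_mul(r + 1));
}
```

The search returns 3,037,000,499. The OCaml version is therefore exact only for n up to 3,037,000,500 (about 3.04 × 10^9). Its own test values stay below that, but larger inputs need Zarith or a double-and-add multiply.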