🦀 Functional Rust

828: Fast Modular Exponentiation

Difficulty: 3 · Level: Intermediate

Compute a^b mod m in O(log b) by repeated squaring — the primitive operation behind RSA, Diffie-Hellman, Miller-Rabin, and modular inverses via Fermat's little theorem.

The Problem This Solves

Naïve exponentiation multiplies `a` by itself `b` times: O(b) multiplications. For cryptographic exponents (an RSA private exponent is around 2048 bits long, so b ≈ 2^2048), this is completely infeasible. Fast exponentiation reduces the cost to O(log b) multiplications by exploiting the binary representation of the exponent: square when the current bit is 0, square-and-multiply when it is 1. This single algorithm appears everywhere in cryptography:

- RSA encryption is `m^e mod n`, and RSA decryption is `c^d mod n`.
- Diffie-Hellman key exchange is `g^a mod p`.
- Miller-Rabin primality tests compute `a^d mod n`.
- Fermat's little theorem gives modular inverses as `a^(p-2) mod p` for prime p.

If you understand one algorithm that enables modern cryptography, this is it. The matrix variant extends the idea: `matrix^n mod m` in O(log n) matrix multiplications computes Fibonacci numbers, linear recurrences, and graph path counts in logarithmic time.

The Intuition

Binary representation of b: 10 in binary is 1010. Read the bits right-to-left: always square, and multiply when the bit is 1. For example, `a^10 = a^(8+2) = a^8 × a^2`.

Each step does three things:

1. `result = result * base` if the low bit is 1.
2. `base = base * base`.
3. `b >>= 1`.

Total: one squaring per bit of b (⌊log₂ b⌋ + 1; the last one is never used) and one multiplication per set bit.

Overflow: with a 64-bit modulus, intermediate products need 128 bits. `(a as u128 * b as u128) % m as u128` is the idiomatic Rust pattern, and no external crate is needed. On x86-64 the 64×64→128-bit multiply is a single instruction. The 128-bit remainder, however, is typically lowered to a compiler-runtime routine.

How It Works in Rust

// Iterative: O(log exp) multiplications, O(1) extra space
fn pow_mod(base: u64, exp: u64, m: u64) -> u64 {
    if m == 1 {
        return 0; // every integer is ≡ 0 (mod 1)
    }
    // Widen to u128 so the product of two residues never overflows.
    let mul = |x: u64, y: u64| (x as u128 * y as u128 % m as u128) as u64;
    let (mut result, mut square, mut e) = (1u64, base % m, exp);
    while e > 0 {
        if e & 1 == 1 {
            result = mul(result, square); // low bit set: fold this power of base in
        }
        square = mul(square, square); // base^(2^k) -> base^(2^(k+1))
        e >>= 1; // move to the next bit
    }
    result
}

// Recursive: mirrors OCaml's natural style — easier to verify correct
fn pow_mod_rec(base: u64, exp: u64, m: u64) -> u64 {
    // Widen to u128 so the product of two values below 2^64 cannot overflow.
    let mul = |x: u64, y: u64| (x as u128 * y as u128 % m as u128) as u64;
    if m == 1 {
        return 0;
    }
    match exp {
        0 => 1 % m,
        // Odd: a^n = a × a^(n-1)
        odd if odd & 1 == 1 => mul(base, pow_mod_rec(base, odd - 1, m)),
        // Even: a^n = (a^(n/2))²
        even => {
            let half = pow_mod_rec(base, even / 2, m);
            mul(half, half)
        }
    }
}

// Matrix exponentiation: O(k³ log n) for k×k matrices
// Fast Fibonacci: [F(n+1), F(n)] = [[1,1],[1,0]]^n × [1, 0]
//
// Self-contained version: raises [[1,1],[1,0]] to the n-th power with the same
// square-and-multiply loop as `pow_mod`. It then reads F(n) from the off-diagonal
// entry, because [[1,1],[1,0]]^n = [[F(n+1), F(n)], [F(n), F(n-1)]].
// Returns F(n) mod m. Panics if m == 0 and n > 0.
fn fib_fast(n: u64, m: u64) -> u64 {
    type Mat = [[u64; 2]; 2];
    let modulus = m as u128;
    // 2×2 product mod m. Each u128 product is reduced before the two are added, so the
    // sum stays below 2·m and cannot overflow even when m is close to u64::MAX.
    let mul = |x: &Mat, y: &Mat| -> Mat {
        let cell = |i: usize, j: usize| -> u64 {
            let left = x[i][0] as u128 * y[0][j] as u128 % modulus;
            let right = x[i][1] as u128 * y[1][j] as u128 % modulus;
            ((left + right) % modulus) as u64
        };
        [[cell(0, 0), cell(0, 1)], [cell(1, 0), cell(1, 1)]]
    };
    let mut result: Mat = [[1, 0], [0, 1]];
    let mut base: Mat = [[1, 1], [1, 0]];
    let mut exp = n;
    while exp > 0 {
        if exp & 1 == 1 {
            result = mul(&result, &base);
        }
        base = mul(&base, &base);
        exp >>= 1;
    }
    // For n == 0 this is the identity's off-diagonal 0; otherwise it has been reduced mod m.
    result[0][1]
}
In the recursive version, a `match` on `exp` with a guard that tests the low bit (`& 1 == 1`) lays out the zero, odd, and even cases side by side. This reads more clearly than a chain of `if/else`, with no runtime cost over the equivalent branches.

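What This Unlocks

- Public-key cryptography: RSA (`m^e mod n`, `c^d mod n`) and Diffie-Hellman (`g^a mod p`).
- Primality testing: the `a^d mod n` step of Miller-Rabin.
- Modular inverses modulo a prime p via Fermat's little theorem: `a^(p-2) mod p`.
- Fibonacci numbers, other linear recurrences, and graph path counts in O(log n) via matrix powers.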

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Iterative version | Explicit `ref` variables | `let mut` — cleaner syntax |
| Recursive version | `let rec pow_mod a e m = match e with` | `fn pow_mod_rec(...) { match exp { ... } }` |
| Bit test | `e land 1 = 1` | `exp & 1 == 1` |
| Right shift | `e lsr 1` | `exp >>= 1` |
| Overflow-safe product | `Int64` arithmetic (exact only for moduli below about 3·10⁹), or Zarith for arbitrary precision | `(a as u128 * b as u128) % m as u128` — built into the language |
//! Fast Modular Exponentiation — binary (square-and-multiply) method.
//! Computes base^exp mod m in O(log exp) multiplications.

/// Computes `base^exp mod m` by scanning the exponent's bits from least to most significant.
///
/// Uses O(log exp) multiplications. All arithmetic is carried out in `u128`, so any `u64`
/// modulus works without intermediate overflow. Returns 0 when `m == 1`.
///
/// # Panics
/// Panics if `m == 0` (remainder by zero).
fn pow_mod(base: u64, exp: u64, m: u64) -> u64 {
    if m == 1 {
        return 0;
    }
    let modulus = m as u128;
    let mut acc: u128 = 1;
    // `square` holds base^(2^k) mod m for the bit currently being examined.
    let mut square = base as u128 % modulus;
    let mut bits = exp;
    while bits != 0 {
        if bits % 2 == 1 {
            acc = acc * square % modulus;
        }
        square = square * square % modulus;
        bits /= 2;
    }
    // acc < m <= u64::MAX, so the narrowing cast is lossless.
    acc as u64
}

/// Recursive square-and-multiply, following the same case split as the OCaml version.
///
/// `exp == 0` yields `1 % m`. An odd exponent peels off one factor of `base`, and an even
/// exponent squares the result for `exp / 2`, so the recursion depth is roughly
/// 2·log₂(exp). Returns 0 when `m == 1`.
///
/// # Panics
/// Panics if `m == 0` (remainder by zero).
fn pow_mod_rec(base: u64, exp: u64, m: u64) -> u64 {
    // Product of two u64 values reduced mod m; the u128 widening rules out overflow.
    fn mul_mod(x: u64, y: u64, m: u64) -> u64 {
        (u128::from(x) * u128::from(y) % u128::from(m)) as u64
    }

    if m == 1 {
        0
    } else if exp == 0 {
        1 % m
    } else if exp % 2 == 1 {
        // a^n = a · a^(n-1)
        mul_mod(base, pow_mod_rec(base, exp - 1, m), m)
    } else {
        // a^n = (a^(n/2))²
        let half = pow_mod_rec(base, exp / 2, m);
        mul_mod(half, half, m)
    }
}

/// 2×2 matrix of `u64` entries, used for matrix exponentiation (e.g. fast Fibonacci).
///
/// [`Matrix::mul`] and [`Matrix::pow`] return matrices whose entries are reduced modulo
/// the modulus they are given.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Matrix {
    a: [[u64; 2]; 2],
}

impl Matrix {
    /// The 2×2 identity matrix (entries not reduced by any modulus).
    fn identity() -> Self {
        Matrix { a: [[1, 0], [0, 1]] }
    }

    /// Returns `self × rhs` with every entry reduced modulo `m`.
    ///
    /// Each scalar product is formed in `u128` and reduced before it is added. The running
    /// sum also stays in `u128`, so nothing overflows for any `u64` modulus. (Adding the
    /// reduced terms in `u64` would overflow once `m > 2^63`.)
    ///
    /// # Panics
    /// Panics if `m == 0`.
    fn mul(self, rhs: Self, m: u64) -> Self {
        let modulus = u128::from(m);
        let mut res = [[0u64; 2]; 2];
        for (i, row) in res.iter_mut().enumerate() {
            for (j, cell) in row.iter_mut().enumerate() {
                let sum = (0..2).fold(0u128, |acc, k| {
                    let prod = u128::from(self.a[i][k]) * u128::from(rhs.a[k][j]);
                    // acc < m and prod % m < m, so the sum is below 2^65.
                    (acc + prod % modulus) % modulus
                });
                // sum < m <= u64::MAX, so the narrowing cast is lossless.
                *cell = sum as u64;
            }
        }
        Matrix { a: res }
    }

    /// Raises `self` to the power `exp` modulo `m` by square-and-multiply, using
    /// O(log exp) matrix multiplications.
    ///
    /// The accumulator starts as the identity reduced mod `m`: when `m == 1` it is the
    /// zero matrix, so `exp == 0` does not return unreduced 1s.
    ///
    /// # Panics
    /// Panics if `m == 0` and `exp > 0`.
    fn pow(self, mut exp: u64, m: u64) -> Self {
        let mut result = if m == 1 {
            Matrix { a: [[0; 2]; 2] }
        } else {
            Self::identity()
        };
        let mut base = self;
        while exp > 0 {
            if exp & 1 == 1 {
                result = result.mul(base, m);
            }
            base = base.mul(base, m);
            exp >>= 1;
        }
        result
    }
}

/// Fibonacci number F(n) modulo `p`, computed with O(log n) 2×2 matrix multiplications.
///
/// Uses [[1,1],[1,0]]^n = [[F(n+1), F(n)], [F(n), F(n-1)]] and reads F(n) from the
/// off-diagonal entry of the n-th power. Taking the n-th power, rather than the
/// (n-1)-th power and its top-left entry, has two benefits:
///
/// - `n == 0` needs no special case (the identity's off-diagonal entry is 0).
/// - `p == 1` returns 0, matching `pow_mod`.
///
/// # Panics
/// Panics if `p == 0` and `n > 0`.
fn fib_mod(n: u64, p: u64) -> u64 {
    let q = Matrix { a: [[1, 1], [1, 0]] };
    q.pow(n, p).a[0][1]
}

/// Demo driver: prints sample modular powers and matrix-based Fibonacci values.
fn main() {
    const P: u64 = 1_000_000_007;

    // Scalar modular exponentiation, iterative then recursive.
    println!("2^10 mod 1000 = {} (expected 24)", pow_mod(2, 10, 1000));
    println!("3^100 mod 1_000_000_007 = {}", pow_mod(3, 100, P));
    println!("2^1000 mod 1009 = {}", pow_mod(2, 1000, 1009));
    println!("recursive: 2^10 mod 1000 = {}", pow_mod_rec(2, 10, 1000));

    // Fibonacci numbers via powers of [[1,1],[1,0]].
    println!("\nFibonacci via matrix exponentiation:");
    println!("fib(10) = {} (expected 55)", fib_mod(10, u64::MAX));
    println!("fib(10) mod 100 = {} (expected 55)", fib_mod(10, 100));
    println!("fib(100) mod 1_000_000_007 = {}", fib_mod(100, P));
    println!("fib(1_000_000) mod 1_000_000_007 = {}", fib_mod(1_000_000, P));
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Prime modulus shared by several tests.
    const P: u64 = 1_000_000_007;

    #[test]
    fn test_pow_mod_basic() {
        // (base, exp, modulus, expected)
        let cases: [(u64, u64, u64, u64); 3] = [
            (2, 10, 1000, 24), // 1024 mod 1000
            (2, 0, 7, 1),      // anything^0 = 1
            (0, 5, 7, 0),      // 0^anything = 0
        ];
        for (base, exp, m, expected) in cases {
            assert_eq!(pow_mod(base, exp, m), expected);
        }
    }

    #[test]
    fn test_pow_mod_prime_modulus() {
        // Fermat's little theorem: a^(p-1) ≡ 1 (mod p) when p does not divide a.
        for a in [3u64, 7] {
            assert_eq!(pow_mod(a, P - 1, P), 1);
        }
    }

    #[test]
    fn test_pow_mod_m1() {
        assert_eq!(pow_mod(100, 100, 1), 0);
    }

    #[test]
    fn test_recursive_matches_iterative() {
        let m = 1009u64;
        let grid = (0..10u64).flat_map(|base| (0..15u64).map(move |exp| (base, exp)));
        for (base, exp) in grid {
            assert_eq!(
                pow_mod(base, exp, m),
                pow_mod_rec(base, exp, m),
                "mismatch at {base}^{exp} mod {m}"
            );
        }
    }

    #[test]
    fn test_fib_mod_basic() {
        // F(0)=0, F(1)=1, F(10)=55
        for (n, expected) in [(0u64, 0u64), (1, 1), (10, 55)] {
            assert_eq!(fib_mod(n, u64::MAX), expected);
        }
    }

    #[test]
    fn test_fib_mod_large() {
        // Known: F(100) mod 10^9+7 = 687995182
        assert_eq!(fib_mod(100, P), 687_995_182);
    }

    #[test]
    fn test_large_exponent() {
        // 2^62 < 2^63 - 1, so the power comes back unreduced.
        let m = (1u64 << 63) - 1;
        assert_eq!(pow_mod(2, 62, m), 1u64 << 62);
    }
}
(* Fast Modular Exponentiation in OCaml *)

(* Recursive binary method — elegant and clear.
   Computes base^exp mod m, assuming exp >= 0 and m >= 1.

   The base is reduced with [base mod m] once, up front (as the iterative pow_mod
   does), so every product below multiplies two values smaller than m in absolute
   value. Without this, a large base overflowed Int64 in the odd-exponent product.
   The Int64 products are exact only while m is below about 3.0e9 (2^31.5). *)
let pow_mod_rec (base : int) (exp : int) (m : int) : int =
  let mul_mod x y = Int64.(to_int (rem (mul (of_int x) (of_int y)) (of_int m))) in
  let base = base mod m in
  let rec go e =
    if e = 0 then 1 mod m
    else if e land 1 = 1 then
      (* Odd exponent: a^n = a * a^(n-1) *)
      mul_mod base (go (e - 1))
    else
      (* Even exponent: a^n = (a^(n/2))^2 *)
      let half = go (e / 2) in
      mul_mod half half
  in
  go exp

(* Iterative binary method — avoids recursion entirely.
   Computes base^exp mod m, assuming m >= 1, by scanning the bits of exp from least
   to most significant. A non-positive exp leaves the accumulator untouched.

   The accumulator starts at [1 mod m] rather than 1, so exp = 0 with m = 1 returns 0,
   as pow_mod_rec does. The Int64 products are exact only while m is below about
   3.0e9. *)
let pow_mod (base : int) (exp : int) (m : int) : int =
  let result = ref (1 mod m) in
  let base = ref (base mod m) in
  let exp = ref exp in
  while !exp > 0 do
    if !exp land 1 = 1 then
      result := Int64.(to_int (rem (mul (of_int !result) (of_int !base)) (of_int m)));
    (* base holds base^(2^k) mod m for the bit being examined *)
    base := Int64.(to_int (rem (mul (of_int !base) (of_int !base)) (of_int m)));
    exp := !exp lsr 1
  done;
  !result

(* Matrix exponentiation for Fibonacci — shows generality of the method *)
(* M^n via binary exponentiation where M is a 2x2 matrix *)

(* A 2x2 matrix laid out as [[a, b]; [c, d]]. *)
type matrix = { a: int; b: int; c: int; d: int }

(* Product m1 * m2 with every entry reduced mod [modp].
   Each scalar product goes through Int64 and is exact only while modp is below
   about 3.0e9. The sum of two reduced terms is at most 2 * (modp - 1), which fits
   in a native int. *)
let mat_mul m1 m2 modp =
  let mm a b = Int64.(to_int (rem (mul (of_int a) (of_int b)) (of_int modp))) in
  { a = (mm m1.a m2.a + mm m1.b m2.c) mod modp;
    b = (mm m1.a m2.b + mm m1.b m2.d) mod modp;
    c = (mm m1.c m2.a + mm m1.d m2.c) mod modp;
    d = (mm m1.c m2.b + mm m1.d m2.d) mod modp }

(* The 2x2 identity matrix, with entries not reduced by any modulus. *)
let mat_id = { a=1; b=0; c=0; d=1 }

(* m^exp mod modp, assuming exp >= 0, using O(log exp) calls to mat_mul.
   exp = 0 returns the identity reduced mod modp: the zero matrix when modp = 1,
   consistent with [1 mod m] in pow_mod_rec. *)
let rec mat_pow m exp modp =
  if exp = 0 then { mat_id with a = 1 mod modp; d = 1 mod modp }
  else if exp land 1 = 1 then mat_mul m (mat_pow m (exp-1) modp) modp
  else let half = mat_pow m (exp/2) modp in mat_mul half half modp

(* F(n) mod p using matrix exponentiation in O(log n).
   Since [[1,1],[1,0]]^n = [[F(n+1), F(n)], [F(n), F(n-1)]], F(n) is the [b] entry
   of the n-th power. Reading it this way has two benefits:
   - n = 0 needs no special case (the identity's b entry is 0);
   - p = 1 gives 0, whereas reading [a] of the (n-1)-th power returned 1 for n = 1.
   Assumes n >= 0 and p >= 1. *)
let fib_mod n p =
  (mat_pow { a=1; b=1; c=1; d=0 } n p).b

(* Demo: print sample modular powers and Fibonacci values. *)
let () =
  let p = 1_000_000_007 in
  let r_small = pow_mod 2 10 1000 in
  Printf.printf "2^10 mod 1000 = %d  (expected 24)\n" r_small;
  let r_big = pow_mod 3 100 p in
  Printf.printf "3^100 mod 1_000_000_007 = %d\n" r_big;
  let r_1009 = pow_mod 2 1000 1009 in
  Printf.printf "2^1000 mod 1009 = %d\n" r_1009;
  let r_rec = pow_mod_rec 2 10 1000 in
  Printf.printf "recursive: 2^10 mod 1000 = %d\n" r_rec;
  let f10 = fib_mod 10 100 in
  Printf.printf "fib(10) mod 100 = %d  (expected 55)\n" f10;
  let f100 = fib_mod 100 p in
  Printf.printf "fib(100) mod 1_000_000_007 = %d\n" f100