🦀 Functional Rust

827: Modular Arithmetic

Difficulty: 4 · Level: Advanced. Build a type-safe `ModInt` wrapper that enforces its range invariant and overloads the arithmetic operators — the foundation for modular number theory in Rust.

The Problem This Solves

Every cryptographic primitive, every competitive-programming problem with "answer mod 10^9+7", every hash function — they all do arithmetic in ℤ_m. The challenge is engineering correctness: forgetting a `% m` here, overflowing a multiply there, getting a negative result from subtraction. These bugs are silent (values stay in the right range until they don't) and hard to track down. The solution is a newtype wrapper `ModInt` that enforces `v ∈ [0, m)` by construction and overloads `+`, `-`, `*` so that results automatically stay in range. This is idiomatic Rust: encode invariants in the type, so the constructor and operators uphold them everywhere instead of relying on your test suite to catch violations. Modular inverse is the key derived operation. When m is prime, Fermat's little theorem gives `a^(-1) ≡ a^(p-2) (mod p)`, computable with fast exponentiation. For general m, use the extended Euclidean algorithm. Both are shown here.

The Intuition

Modular arithmetic "wraps around" at m. Addition and subtraction need only a single `% m`, plus a correction for negative subtraction results. For moduli above 2^63 they also need a wider intermediate, so the sum cannot overflow. Multiplication requires more care. Two values in [0, m) with m ≈ 10^9 have a product up to (10^9)² ≈ 10^18. That still fits under `u64::MAX` (≈ 1.8 × 10^19), so it is safe for m ≤ 10^9+7. It is not safe for arbitrary u64 moduli, where the product can approach 2^128. The safe idiom is to widen to `u128` for the multiply, then reduce mod m. Fermat inverse: for p prime and a not divisible by p, Fermat's little theorem gives `a^(p-1) ≡ 1 (mod p)`. Hence `a × a^(p-2) ≡ 1 (mod p)`, so `a^(-1) = a^(p-2) mod p`. For general m, the extended Euclidean algorithm finds x such that `a×x + m×y = gcd(a, m)`. When that gcd is 1, x mod m is the inverse.

How It Works in Rust

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct ModInt { v: u64, m: u64 }

impl ModInt {
    fn new(v: i64, m: u64) -> Self {
        // Normalise into [0, m): i128 holds every i64 value and every u64
        // modulus exactly, and rem_euclid never returns a negative remainder.
        assert!(m > 0, "modulus must be positive");
        let v = i128::from(v).rem_euclid(i128::from(m)) as u64;
        ModInt { v, m }
    }

    // Modular exponentiation: base^exp mod m in O(log exp)
    fn pow(self, mut exp: u64) -> Self {
        let (mut base, mut result) = (self, ModInt::new(1, self.m));
        while exp > 0 {
            if exp & 1 == 1 { result = result * base; }
            base = base * base;
            exp >>= 1;
        }
        result
    }

    // Fermat inverse: O(log m), requires m prime
    fn inv_fermat(self) -> Self {
        assert!(self.v != 0, "no inverse for 0");
        self.pow(self.m - 2)  // a^(p-2) mod p
    }

    // Extended Euclidean inverse: works for any coprime (v, m) with m <= i64::MAX
    // (extended_gcd works on i64; a larger modulus would wrap negative).
    fn inv(self) -> Option<Self> {
        assert!(self.m <= i64::MAX as u64, "inv: modulus must not exceed i64::MAX");
        let (g, x, _) = extended_gcd(self.v as i64, self.m as i64);
        if g != 1 { None } else { Some(ModInt::new(x, self.m)) }
    }
}

impl Mul for ModInt {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self {
        // Both operands must live in the same ring Z_m.
        debug_assert_eq!(self.m, rhs.m, "ModInt moduli differ");
        // u128 widening prevents overflow: (u64::MAX)^2 > u64::MAX
        let v = (self.v as u128 * rhs.v as u128 % self.m as u128) as u64;
        ModInt { v, m: self.m }
    }
}
// Add and Sub follow the same pattern: (a + b) mod m, and (a + m - b) mod m so subtraction never goes negative.
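Implementing the `std::ops` traits (`impl Add for ModInt`, `impl Sub for ModInt`, `impl Mul for ModInt`) makes `+`, `-`, `*` on `ModInt` read like regular arithmetic in calling code. OCaml has no equivalent: its `+` is fixed to `int`, so OCaml code uses named functions or freshly defined operators such as `+%`.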

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Wrapper type | Record `{v: int; m: int}` or module | `struct ModInt { v: u64, m: u64 }` |
| Operator overload | Cannot overload `+`; use functions | `impl Add for ModInt` — full overloading |
| u128 multiply | `Int64` or Zarith for big values | `(a as u128 * b as u128) % m as u128` |
| Negative normalization | `((v mod m) + m) mod m` | Same idea: normalize in signed arithmetic, then cast to `u64` |
| Fermat inverse | `pow_mod a (m-2) m` function call | `self.pow(self.m - 2)` method |
//! Modular Arithmetic: add, sub, mul, inverse, pow.
//!
//! `ModInt` wraps a value in [0, modulus) and implements the arithmetic operators.

use std::ops::{Add, Sub, Mul};

/// The prime 10^9 + 7, the usual modulus in competitive programming.
const MOD: u64 = 10u64.pow(9) + 7;

/// A residue class modulo `m`.
///
/// Invariant: `v < m`, established by [`ModInt::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct ModInt {
    /// Canonical representative in `[0, m)`.
    v: u64,
    /// The modulus; binary operators read it from the left-hand operand.
    m: u64,
}

impl ModInt {
    /// Creates the residue of `v` modulo `m`, normalised into `[0, m)`.
    ///
    /// Negative inputs wrap around, e.g. `ModInt::new(-1, 7)` has value `6`.
    /// Any `u64` modulus is supported, including ones above `i64::MAX`.
    ///
    /// # Panics
    /// Panics if `m == 0`.
    fn new(v: i64, m: u64) -> Self {
        assert!(m > 0, "modulus must be positive");
        // i128 represents every i64 value and every u64 modulus exactly, and
        // `rem_euclid` with a positive divisor always lands in [0, m), so the
        // cast back to u64 is lossless. (Casting m to i64 would wrap for large m.)
        let v = i128::from(v).rem_euclid(i128::from(m)) as u64;
        ModInt { v, m }
    }

    /// Modular exponentiation: `self^exp mod m` by square-and-multiply.
    ///
    /// Uses O(log exp) multiplications; `x.pow(0)` is `1 mod m`.
    fn pow(self, mut exp: u64) -> Self {
        let mut base = self;
        let mut result = ModInt::new(1, self.m);
        while exp > 0 {
            // Fold in the current power of two when this bit of `exp` is set.
            if exp & 1 == 1 {
                result = result * base;
            }
            base = base * base;
            exp >>= 1;
        }
        result
    }

    /// Modular inverse via Fermat's little theorem: `self^(m-2) mod m`.
    ///
    /// Correct only when `m` is prime; primality is not checked. For a general
    /// modulus use [`ModInt::inv`].
    ///
    /// # Panics
    /// Panics if `self` is zero.
    fn inv_fermat(self) -> Self {
        assert!(self.v != 0, "no inverse for 0");
        self.pow(self.m - 2)
    }

    /// Modular inverse via the extended Euclidean algorithm.
    ///
    /// Returns `None` when `gcd(v, m) != 1`, i.e. when no inverse exists.
    ///
    /// # Panics
    /// Panics if `m > i64::MAX`: `extended_gcd` works on `i64`, and casting a
    /// larger modulus would wrap to a negative number and give a wrong answer.
    fn inv(self) -> Option<Self> {
        assert!(self.m <= i64::MAX as u64, "inv: modulus must not exceed i64::MAX");
        // v < m <= i64::MAX, so both casts are lossless.
        let (g, x, _) = extended_gcd(self.v as i64, self.m as i64);
        if g == 1 {
            Some(ModInt::new(x, self.m))
        } else {
            None
        }
    }
}

impl Add for ModInt {
    type Output = Self;

    /// Returns `(self + rhs) mod m`.
    ///
    /// Both operands must share the same modulus (checked in debug builds).
    fn add(self, rhs: Self) -> Self {
        debug_assert_eq!(self.m, rhs.m, "ModInt moduli differ");
        // The sum of two values below m can exceed u64::MAX when m > 2^63,
        // so form it in u128 before reducing.
        let v = ((self.v as u128 + rhs.v as u128) % self.m as u128) as u64;
        ModInt { v, m: self.m }
    }
}

impl Sub for ModInt {
    type Output = Self;

    /// Returns `(self - rhs) mod m`, always in `[0, m)`.
    ///
    /// Both operands must share the same modulus (checked in debug builds).
    fn sub(self, rhs: Self) -> Self {
        debug_assert_eq!(self.m, rhs.m, "ModInt moduli differ");
        // Adding m first keeps the difference non-negative; doing it in u128
        // means `self.v + self.m` cannot overflow even for m near u64::MAX.
        let v = ((self.v as u128 + self.m as u128 - rhs.v as u128) % self.m as u128) as u64;
        ModInt { v, m: self.m }
    }
}

impl Mul for ModInt {
    type Output = Self;

    /// Returns `(self * rhs) mod m`.
    ///
    /// Both operands must share the same modulus (checked in debug builds).
    fn mul(self, rhs: Self) -> Self {
        debug_assert_eq!(self.m, rhs.m, "ModInt moduli differ");
        // Widen to u128: the product of two values below m can need up to 128 bits.
        let v = (self.v as u128 * rhs.v as u128 % self.m as u128) as u64;
        ModInt { v, m: self.m }
    }
}

impl std::fmt::Display for ModInt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.v)
    }
}

/// Extended Euclidean algorithm.
///
/// Returns `(g, x, y)` with `a*x + b*y == g`, where `g` is the gcd of `a` and `b`.
fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) {
    // Descend through the Euclidean remainder sequence, remembering each quotient.
    let mut quotients = Vec::new();
    let (mut a, mut b) = (a, b);
    while b != 0 {
        quotients.push(a / b);
        let r = a % b;
        a = b;
        b = r;
    }
    // Unwind from the base case (g, 1, 0): each step maps (x, y) -> (y, x - q*y).
    let (x, y) = quotients
        .iter()
        .rev()
        .fold((1_i64, 0_i64), |(x, y), &q| (y, x - q * y));
    (a, x, y)
}

/// Standalone modular inverse using the extended Euclidean algorithm.
///
/// Returns `Some(x)` with `x` in `[0, m)` and `a * x ≡ 1 (mod m)`, or `None`
/// when `gcd(a, m) != 1` or `m == 0` (there is no value in `[0, 0)` to return).
///
/// # Panics
/// Panics if `m > i64::MAX`, because `extended_gcd` works on `i64`.
fn mod_inv(a: u64, m: u64) -> Option<u64> {
    if m == 0 {
        return None;
    }
    assert!(m <= i64::MAX as u64, "mod_inv: modulus must not exceed i64::MAX");
    // Reduce first: `a as i64` would wrap for a > i64::MAX and corrupt the gcd.
    // After reduction a < m <= i64::MAX, so both casts are lossless.
    let (g, x, _) = extended_gcd((a % m) as i64, m as i64);
    if g == 1 {
        Some(x.rem_euclid(m as i64) as u64)
    } else {
        None
    }
}

fn main() {
    // --- Small modulus: arithmetic on 3 and 5 in Z_7 ---
    let lhs = ModInt::new(3, 7);
    let rhs = ModInt::new(5, 7);
    let (sum, difference, product) = (lhs + rhs, lhs - rhs, lhs * rhs);
    println!("3 + 5 mod 7 = {}", sum); // 1
    println!("3 - 5 mod 7 = {}", difference); // 5
    println!("3 * 5 mod 7 = {}", product); // 1

    // --- The usual competitive-programming prime ---
    println!("\n--- MOD = {} ---", MOD);
    let wrapped = ModInt::new(999_999_999, MOD) + ModInt::new(9, MOD);
    println!("999_999_999 + 9 mod MOD = {}", wrapped); // 1
    let power = ModInt::new(2, MOD).pow(10);
    println!("2^10 mod MOD = {}", power); // 1024

    // --- Fermat inverse: 3 * inv(3) ≡ 1 (mod 7) ---
    let three = ModInt::new(3, 7);
    let three_inv = three.inv_fermat();
    println!("\ninv(3) mod 7 = {} (expected 5)", three_inv);
    println!("3 * 5 mod 7 = {}", three * three_inv); // 1

    // --- Extended Euclidean inverse for a general modulus ---
    let (coprime, not_coprime) = (mod_inv(3, 7), mod_inv(2, 4));
    println!("mod_inv(3, 7) = {:?}", coprime); // Some(5)
    println!("mod_inv(2, 4) = {:?}", not_coprime); // None (gcd = 2)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a residue; keeps the individual tests short.
    fn residue(v: i64, m: u64) -> ModInt {
        ModInt::new(v, m)
    }

    #[test]
    fn test_add_wrap() {
        // 6 + 5 = 11, and 11 mod 7 = 4.
        let sum = residue(6, 7) + residue(5, 7);
        assert_eq!(sum.v, 4);
    }

    #[test]
    fn test_sub_negative() {
        // 3 - 7 = -4, which is 6 mod 10.
        let diff = residue(3, 10) - residue(7, 10);
        assert_eq!(diff.v, 6);
    }

    #[test]
    fn test_mul_overflow_safe() {
        // Cross-check a large product against an exact u128 computation.
        let big = residue(999_999_999, MOD);
        let expected = (999_999_999u128 * 999_999_999 % MOD as u128) as u64;
        assert_eq!((big * big).v, expected);
    }

    #[test]
    fn test_pow_2_10() {
        let two = residue(2, MOD);
        assert_eq!(two.pow(10).v, 1024);
    }

    #[test]
    fn test_inv_fermat() {
        // 3 * 5 = 15 ≡ 1 (mod 7), so inv(3) = 5.
        let inv = residue(3, 7).inv_fermat();
        assert_eq!(inv.v, 5);
    }

    #[test]
    fn test_inv_verify() {
        // Every non-zero residue mod the prime 7 times its inverse must be 1.
        for a in 1i64..7 {
            let x = residue(a, 7);
            assert_eq!((x * x.inv_fermat()).v, 1, "inv({a}) mod 7 failed");
        }
    }

    #[test]
    fn test_mod_inv_no_inverse() {
        // gcd(2, 4) = 2 and gcd(6, 9) = 3, so neither has an inverse.
        for &(a, m) in &[(2u64, 4u64), (6, 9)] {
            assert_eq!(mod_inv(a, m), None);
        }
    }

    #[test]
    fn test_negative_input() {
        // -1 must normalise to m - 1.
        assert_eq!(residue(-1, 7).v, 6);
    }
}
(* Modular Arithmetic in OCaml *)

(* Addition in Z_m: reduces a + b modulo m.
   Inputs are assumed to lie in [0, m) already; otherwise the result can be
   negative, because OCaml's [mod] takes the sign of the dividend. *)
let mod_add a b m =
  let sum = a + b in
  sum mod m

(* Subtraction in Z_m, returning a value in [0, m) for m > 0.
   [mod] may yield a negative remainder, so m is added back before the
   final reduction. *)
let mod_sub a b m =
  let r = (a - b) mod m in
  (r + m) mod m

(* Modular multiplication: returns a * b mod m in [0, m), for m > 0.
   Operands are first normalised into [0, m), so negative inputs are accepted.
   OCaml's native int is already 63 bits on 64-bit platforms, so widening to
   Int64 gains only one bit. It is used only when (m-1)^2 provably fits;
   otherwise an overflow-free double-and-add loop computes the product. *)
let mod_mul a b m =
  let norm x = let r = x mod m in if r < 0 then r + m else r in
  let a = norm a and b = norm b in
  if Int64.of_int m <= 3_037_000_499L then
    (* 3_037_000_499 <= sqrt(Int64.max_int), so the 64-bit product cannot overflow *)
    Int64.(to_int (rem (mul (of_int a) (of_int b)) (of_int m)))
  else begin
    (* (x + y) mod m for x, y in [0, m), never forming a value >= m *)
    let add x y = if x >= m - y then x - (m - y) else x + y in
    let rec go acc a b =
      if b = 0 then acc
      else
        let acc = if b land 1 = 1 then add acc a else acc in
        go acc (add a a) (b lsr 1)
    in
    go 0 a b
  end

(* Fast modular exponentiation: a^exp mod m by repeated squaring,
   using O(log exp) multiplications.
   Raises [Invalid_argument] on a negative exponent, which would otherwise
   recurse forever (odd negative exponents keep decreasing). *)
let rec pow_mod (a : int) (exp : int) (m : int) : int =
  if exp < 0 then invalid_arg "pow_mod: negative exponent"
  else if exp = 0 then 1 mod m (* 1 mod m is 0 when m = 1 *)
  else if exp mod 2 = 0 then
    let half = pow_mod a (exp / 2) m in
    mod_mul half half m
  else
    mod_mul a (pow_mod a (exp - 1) m) m

(* Modular inverse via Fermat's little theorem: a^(p-2) mod p.
   Requires p prime (not checked). Raises [Invalid_argument] when
   a ≡ 0 (mod p), which has no inverse, instead of silently returning 0. *)
let mod_inv_fermat (a : int) (p : int) : int =
  let a = ((a mod p) + p) mod p in
  if a = 0 then invalid_arg "mod_inv_fermat: 0 has no inverse"
  else pow_mod a (p - 2) p

(* Extended Euclidean algorithm.
   Returns (g, x, y) with a*x + b*y = g, where g is the gcd of a and b
   (computed with OCaml's truncating [/] and [mod]). *)
let rec extended_gcd (a : int) (b : int) : int * int * int =
  match b with
  | 0 -> (a, 1, 0)
  | _ ->
    let q = a / b and r = a mod b in
    let (g, x', y') = extended_gcd b r in
    (g, y', x' - q * y')

(* Modular inverse via the extended Euclidean algorithm, for any modulus m > 0.
   Returns [Some x] with x in [0, m) when gcd(a, m) = 1, otherwise [None]. *)
let mod_inv (a : int) (m : int) : int option =
  let normalise v = ((v mod m) + m) mod m in
  match extended_gcd (normalise a) m with
  | (1, x, _) -> Some (normalise x)
  | _ -> None

(* Demo: exercise each operation and print the results. *)
let () =
  let m = 1_000_000_007 in
  Printf.printf "mod_add(999_999_999, 9, %d) = %d\n" m (mod_add 999_999_999 9 m);
  Printf.printf "mod_sub(3, 7, 10) = %d  (expected 6)\n" (mod_sub 3 7 10);
  Printf.printf "mod_mul(123456, 789012, %d) = %d\n" m (mod_mul 123456 789012 m);
  Printf.printf "pow_mod(2, 10, %d) = %d  (expected 1024)\n" m (pow_mod 2 10 m);
  Printf.printf "mod_inv_fermat(3, 7) = %d  (expected 5, since 3*5=15≡1 mod 7)\n"
    (mod_inv_fermat 3 7);
  (* Compute the inverse once and reuse it for both the report and the check. *)
  match mod_inv 3 7 with
  | Some inv ->
    Printf.printf "mod_inv(3, 7) = %d\n" inv;
    (* Verify: 3 * inv ≡ 1 mod 7 *)
    Printf.printf "3 * %d mod 7 = %d  (should be 1)\n" inv (mod_mul 3 inv 7)
  | None -> Printf.printf "mod_inv(3, 7): no inverse\n"