🦀 Functional Rust

829: Chinese Remainder Theorem

Difficulty: 4 | Level: Advanced

Reconstruct a unique integer from its residues modulo several moduli: the mathematical foundation of the RSA speedup, arbitrary-precision arithmetic, and range-query tricks.

The Problem This Solves

The Chinese Remainder Theorem answers this question: given a system of congruences x ≡ r₁ (mod m₁), x ≡ r₂ (mod m₂), …, x ≡ rₖ (mod mₖ), find x. For pairwise coprime moduli, a solution always exists and is unique mod M = m₁ × m₂ × … × mₖ.

This is far more than a mathematical curiosity. RSA with CRT splits private-key decryption into two exponentiations with half-size moduli and exponents (mod p and mod q separately) and then combines the results. With schoolbook arithmetic this is roughly 4× faster, which is why the PKCS #1 private-key format stores the CRT exponents and coefficient alongside d. In competitive programming, CRT appears when you need to find when several periodic events coincide, or to reconstruct a value from partial observations. In computer arithmetic, multiplication via the number-theoretic transform (NTT) can run the transform modulo several primes and use CRT to recombine the exact coefficients.

The implementation here also handles non-coprime moduli: when `gcd(m₁, m₂) > 1`, a solution exists only if `r₁ ≡ r₂ (mod gcd(m₁, m₂))`, and the result is unique mod `lcm(m₁, m₂)` rather than mod `m₁ × m₂`.
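To make the RSA claim concrete, here is a minimal sketch of CRT decryption, assuming the classic textbook key p = 61, q = 53, e = 17, d = 2753 (far too small to be secure). The names `mod_pow`, `decrypt_plain`, and `decrypt_crt` are hypothetical. Because p is prime, `q⁻¹ mod p` is computed with Fermat's little theorem instead of the extended GCD used later in this entry, and the two residues are recombined with Garner's formula, the two-modulus special case of CRT. Real libraries use constant-time big-integer arithmetic and precompute dP, dQ and q⁻¹ mod p.

```rust
/// Square-and-multiply modular exponentiation.
/// Assumes (m - 1)^2 fits in u64, which holds for this toy key.
fn mod_pow(mut base: u64, mut exp: u64, m: u64) -> u64 {
    let mut result = 1 % m;
    base %= m;
    while exp > 0 {
        if exp & 1 == 1 {
            result = result * base % m;
        }
        base = base * base % m;
        exp >>= 1;
    }
    result
}

// Hypothetical toy key (textbook example); insecure, illustration only.
const P: u64 = 61;
const Q: u64 = 53;
const N: u64 = P * Q; // 3233
const E: u64 = 17;
const D: u64 = 2753;

/// Plain decryption: one exponentiation modulo n.
fn decrypt_plain(c: u64) -> u64 {
    mod_pow(c, D, N)
}

/// CRT decryption: two half-size exponentiations, then recombination.
fn decrypt_crt(c: u64) -> u64 {
    let d_p = D % (P - 1); // exponent reduced mod p - 1
    let d_q = D % (Q - 1); // exponent reduced mod q - 1
    // q^{-1} mod p via Fermat's little theorem (valid because p is prime).
    let q_inv = mod_pow(Q, P - 2, P);
    let m_p = mod_pow(c, d_p, P); // plaintext mod p
    let m_q = mod_pow(c, d_q, Q); // plaintext mod q
    // Garner: m = m_q + q * h with h = q_inv * (m_p - m_q) mod p.
    let h = q_inv * ((m_p + P - m_q % P) % P) % P;
    m_q + Q * h
}

fn main() {
    let m = 65;
    let c = mod_pow(m, E, N);
    assert_eq!(decrypt_plain(c), m);
    assert_eq!(decrypt_crt(c), m);
    // Both decryption paths agree on every ciphertext in [0, n).
    for ct in 0..N {
        assert_eq!(decrypt_crt(ct), decrypt_plain(ct));
    }
    println!("c = {c}, recovered m = {}", decrypt_crt(c));
}
```

The exhaustive loop checks that the two half-size exponentiations plus recombination reproduce the single exponentiation modulo n for every ciphertext, which is exactly what CRT guarantees for n = p × q.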

The Intuition

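For two congruences x ≡ a₁ (mod m₁) and x ≡ a₂ (mod m₂), write x = a₁ + m₁ × t and require `a₁ + m₁ × t ≡ a₂ (mod m₂)`. Solving for t gives `m₁ × t ≡ (a₂ - a₁) (mod m₂)`. This is a linear congruence, solvable iff `g = gcd(m₁, m₂)` divides `(a₂ - a₁)`. The extended Euclidean algorithm returns p with `m₁ × p ≡ g (mod m₂)`, so `t = ((a₂ - a₁) / g) × p mod (m₂ / g)`, and x = a₁ + m₁ × t is the solution mod lcm(m₁, m₂). Folding this two-congruence step over the whole list with `fold`, starting from x ≡ 0 (mod 1), solves the full system. OCaml has no built-in 128-bit integer (its native `int` is 63 bits on 64-bit platforms, and `Zarith` provides arbitrary precision); Rust has had native `i128` and `u128` since 1.26, so no external library is needed as long as intermediate products fit in 128 bits.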
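The same steps, written out for the example system used throughout this entry (x ≡ 2 (mod 3), x ≡ 3 (mod 5), x ≡ 2 (mod 7)). Here p and q denote the Bézout coefficients from the extended Euclidean algorithm; dividing their identity by g shows why p inverts m₁/g:

```latex
\begin{aligned}
m_1 p + m_2 q &= g = \gcd(m_1, m_2)
  &&\Rightarrow\ \tfrac{m_1}{g}\, p \equiv 1 \pmod{m_2/g} \\
t &\equiv \tfrac{a_2 - a_1}{g}\, p \pmod{m_2/g}
  &&\Rightarrow\ x = (a_1 + m_1 t) \bmod \operatorname{lcm}(m_1, m_2) \\[4pt]
\text{Step 1: } 3t &\equiv 3 - 2 = 1 \pmod{5}
  &&\Rightarrow\ t \equiv 2,\quad x = 2 + 3 \cdot 2 = 8 \pmod{15} \\
\text{Step 2: } 15t &\equiv 2 - 8 = -6 \pmod{7}
  &&\Rightarrow\ t \equiv -6 \equiv 1 \ (\text{since } 15 \equiv 1),\quad x = 8 + 15 = 23 \pmod{105}
\end{aligned}
```

Each step only needs gcd(m₁, m₂) | (a₂ - a₁); here every gcd is 1, so both steps succeed and x = 23 satisfies all three congruences.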

How It Works in Rust

// Extended GCD: returns (g, x, y) where a*x + b*y = g
fn extended_gcd(a: i128, b: i128) -> (i128, i128, i128) {
    if b == 0 {
        (a, 1, 0)
    } else {
        let (g, x, y) = extended_gcd(b, a % b);
        (g, y, x - (a / b) * y) // Standard back-substitution
    }
}

// Combine x ≡ a1 (mod m1) and x ≡ a2 (mod m2)
// Returns Some((remainder, lcm)) or None if incompatible
fn crt_combine(a1: i128, m1: i128, a2: i128, m2: i128) -> Option<(i128, i128)> {
    let (g, p, _) = extended_gcd(m1, m2);
    if (a2 - a1) % g != 0 {
        return None; // No solution if g ∤ (a2 - a1)
    }
    let lcm = m1 / g * m2;
    let m2g = m2 / g; // Effective modulus for t
    let diff = ((a2 - a1) / g) % m2g;
    let x = (a1 + m1 * ((diff * p % m2g + m2g) % m2g)) % lcm;
    Some(((x + lcm) % lcm, lcm)) // Normalize to [0, lcm)
}

// Solve a full system via pairwise folding
fn crt(congruences: &[(i128, i128)]) -> Option<(i128, i128)> {
    // try_fold: stops and returns None on the first incompatible pair
    congruences.iter().try_fold((0i128, 1i128), |(r, m), &(a, mi)| {
        crt_combine(r, m, a, mi)
    })
}
// Example: x ≡ 2 (mod 3), x ≡ 3 (mod 5), x ≡ 2 (mod 7) → x = 23 (mod 105)

`try_fold` is the idiomatic Rust fold that short-circuits: it stops at the first `None`, so an incompatible pair ends the computation without a hand-written `for` loop that propagates `None` with `?`.
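A small sketch comparing the two styles, assuming the `extended_gcd` and `crt_combine` logic shown above (copied so the block stands alone). `crt_loop` and `crt_counting` are hypothetical names: the first is the explicit-loop equivalent using `?`, the second is the `try_fold` version instrumented to count how many combine steps actually run.

```rust
/// Extended GCD: returns (g, x, y) with a*x + b*y = g.
fn extended_gcd(a: i128, b: i128) -> (i128, i128, i128) {
    if b == 0 {
        (a, 1, 0)
    } else {
        let (g, x, y) = extended_gcd(b, a % b);
        (g, y, x - (a / b) * y)
    }
}

/// Same two-congruence combine step as in the text.
fn crt_combine(a1: i128, m1: i128, a2: i128, m2: i128) -> Option<(i128, i128)> {
    let (g, p, _) = extended_gcd(m1, m2);
    if (a2 - a1) % g != 0 {
        return None;
    }
    let lcm = m1 / g * m2;
    let m2g = m2 / g;
    let diff = ((a2 - a1) / g) % m2g;
    let t = (diff * p % m2g + m2g) % m2g;
    let x = (a1 + m1 * t) % lcm;
    Some(((x + lcm) % lcm, lcm))
}

/// Explicit loop: `?` returns None from the function at the first failure.
fn crt_loop(congruences: &[(i128, i128)]) -> Option<(i128, i128)> {
    let mut acc = (0i128, 1i128);
    for &(a, mi) in congruences {
        acc = crt_combine(acc.0, acc.1, a, mi)?;
    }
    Some(acc)
}

/// `try_fold` version that also reports how many combine calls ran.
fn crt_counting(congruences: &[(i128, i128)]) -> (Option<(i128, i128)>, usize) {
    let mut calls = 0;
    let result = congruences.iter().try_fold((0i128, 1i128), |(r, m), &(a, mi)| {
        calls += 1;
        crt_combine(r, m, a, mi)
    });
    (result, calls)
}

fn main() {
    assert_eq!(crt_loop(&[(2, 3), (3, 5), (2, 7)]), Some((23, 105)));

    // (1 mod 4) and (6 mod 10) are incompatible, so (2 mod 3) is never visited.
    let bad = [(1, 4), (6, 10), (2, 3)];
    let (result, calls) = crt_counting(&bad);
    assert_eq!(result, None);
    assert_eq!(calls, 2);
    assert_eq!(crt_loop(&bad), None);
    println!("try_fold stopped after {calls} combine calls");
}
```

Both functions return the same results; the counter shows that `try_fold` really stops early rather than threading `None` through the remaining elements.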

What This Unlocks

Key Differences

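| Concept | OCaml | Rust |
|---|---|---|
| Wide integer arithmetic | No 128-bit type in the standard library; native `int` is 63-bit on 64-bit platforms; `Zarith` for arbitrary precision | `i128` / `u128`, native since Rust 1.26 |
| Extended GCD | `let rec extended_gcd a b = ...`, recursive | Same recursive structure |
| Fold with early exit | `List.fold_left` over an `option` accumulator (still visits every element) or an exception | `iter().try_fold(...)`, stops at the first `None` |
| Normalizing a modular result | `((x mod m) + m) mod m` | `((x % m) + m) % m`, or `x.rem_euclid(m)` |
| Non-coprime moduli | `if (a2 - a1) mod g <> 0 then None` | `if (a2 - a1) % g != 0 { return None; }` |
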
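The normalization step in the table can be checked directly. In both languages the remainder operator truncates toward zero, so a negative dividend yields a negative remainder. This sketch uses a hypothetical helper `normalize` that mirrors the add-then-reduce step in `crt_combine`, and compares it with the standard-library `i128::rem_euclid`, assuming a positive modulus.

```rust
/// Add-then-reduce normalization, as used in `crt_combine`.
/// Assumes m > 0; the result lies in [0, m).
fn normalize(x: i128, m: i128) -> i128 {
    (x % m + m) % m
}

fn main() {
    // `%` truncates toward zero, so the remainder takes the sign of x.
    assert_eq!(-7i128 % 3, -1);
    assert_eq!(normalize(-7, 3), 2);

    // For m > 0, `rem_euclid` yields the same non-negative residue.
    for x in -20i128..=20 {
        for m in 1i128..=6 {
            assert_eq!(normalize(x, m), x.rem_euclid(m));
        }
    }
    println!("normalize(-7, 3) = {}", normalize(-7, 3));
}
```

`rem_euclid` avoids the extra addition, but the explicit form translates one-to-one into OCaml, which has no equivalent function in its standard `Int` module.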
//! Chinese Remainder Theorem (CRT).
//!
//! Solves: x ≡ aᵢ (mod mᵢ) for all i.
//! Works for non-coprime moduli; returns None when no solution exists.

/// Extended GCD: returns (g, x, y) where a*x + b*y = g.
fn extended_gcd(a: i128, b: i128) -> (i128, i128, i128) {
    if b == 0 {
        (a, 1, 0)
    } else {
        let (g, x, y) = extended_gcd(b, a % b);
        (g, y, x - (a / b) * y)
    }
}

/// Combine two congruences: x ≡ a1 (mod m1) and x ≡ a2 (mod m2).
/// Returns Some((remainder, lcm)) or None if incompatible.
fn crt_combine(a1: i128, m1: i128, a2: i128, m2: i128) -> Option<(i128, i128)> {
    let (g, p, _) = extended_gcd(m1, m2);
    if (a2 - a1) % g != 0 {
        return None; // No solution
    }
    let lcm = m1 / g * m2;
    let m2g = m2 / g;
    let diff = ((a2 - a1) / g) % m2g;
    let x = (a1 + m1 * ((diff * p % m2g + m2g) % m2g)) % lcm;
    let x = (x + lcm) % lcm;
    Some((x, lcm))
}

/// Solve a system of congruences.
/// Input: slice of (remainder, modulus) pairs.
fn crt(congruences: &[(i128, i128)]) -> Option<(i128, i128)> {
    congruences.iter().try_fold((0i128, 1i128), |(r, m), &(a, mi)| {
        crt_combine(r, m, a, mi)
    })
}

fn main() {
    // Classic: x ≡ 2 (mod 3), x ≡ 3 (mod 5), x ≡ 2 (mod 7) → x = 23 mod 105
    let system = [(2, 3), (3, 5), (2, 7)];
    match crt(&system) {
        Some((x, m)) => println!("x ≡ 2(3), x ≡ 3(5), x ≡ 2(7): x = {x} (mod {m})"),
        None => println!("No solution"),
    }

    // Non-coprime moduli with solution
    match crt(&[(0, 4), (6, 10)]) {
        Some((x, m)) => println!("x ≡ 0(4), x ≡ 6(10): x = {x} (mod {m})"),
        None => println!("No solution"),
    }

    // Non-coprime moduli with no solution
    match crt(&[(1, 4), (6, 10)]) {
        Some((x, m)) => println!("x ≡ 1(4), x ≡ 6(10): x = {x} (mod {m})"),
        None => println!("x ≡ 1(4), x ≡ 6(10): No solution"),
    }

    // Two-modulus example
    match crt(&[(1, 5), (2, 7)]) {
        Some((x, m)) => println!("x ≡ 1(5), x ≡ 2(7): x = {x} (mod {m})"),
        None => println!("No solution"),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_classic_three_congruences() {
        // x ≡ 2 (mod 3), x ≡ 3 (mod 5), x ≡ 2 (mod 7)
        let (x, m) = crt(&[(2, 3), (3, 5), (2, 7)]).unwrap();
        assert_eq!(x, 23);
        assert_eq!(m, 105); // 3*5*7
        assert_eq!(x % 3, 2);
        assert_eq!(x % 5, 3);
        assert_eq!(x % 7, 2);
    }

    #[test]
    fn test_two_coprime() {
        let (x, m) = crt(&[(1, 5), (2, 7)]).unwrap();
        assert_eq!(m, 35);
        assert_eq!(x % 5, 1);
        assert_eq!(x % 7, 2);
    }

    #[test]
    fn test_non_coprime_has_solution() {
        let result = crt(&[(0, 4), (6, 10)]);
        assert!(result.is_some());
        let (x, _m) = result.unwrap();
        assert_eq!(x % 4, 0);
        assert_eq!(x % 10, 6);
    }

    #[test]
    fn test_non_coprime_no_solution() {
        // gcd(4,10)=2, but 6-1=5 is not divisible by 2
        assert!(crt(&[(1, 4), (6, 10)]).is_none());
    }

    #[test]
    fn test_single_congruence() {
        let (x, m) = crt(&[(3, 7)]).unwrap();
        assert_eq!(x, 3);
        assert_eq!(m, 7);
    }

    #[test]
    fn test_solution_uniqueness() {
        // Solution should be unique mod M
        let (x, m) = crt(&[(2, 3), (3, 5)]).unwrap();
        // Verify no other solution in [0, m)
        let others: Vec<i128> = (0..m).filter(|&t| t != x && t % 3 == 2 && t % 5 == 3).collect();
        assert!(others.is_empty(), "Multiple solutions found: {:?}", others);
    }

    #[test]
    fn test_consistency_verification() {
        let system = [(2, 3), (3, 5), (2, 7)];
        let (x, _m) = crt(&system).unwrap();
        for (a, mi) in system {
            assert_eq!(x % mi, a, "x={x} fails x ≡ {a} (mod {mi})");
        }
    }
}
(* Chinese Remainder Theorem in OCaml *)

(* Extended GCD: returns (g, x, y) where a*x + b*y = g *)
let rec extended_gcd (a : int) (b : int) : int * int * int =
  if b = 0 then (a, 1, 0)
  else
    let (g, x, y) = extended_gcd b (a mod b) in
    (g, y, x - (a / b) * y)

(* Combine two congruences: x ≡ a1 (mod m1), x ≡ a2 (mod m2) *)
(* Returns Some (remainder, modulus) or None if no solution *)
let crt_combine (a1 : int) (m1 : int) (a2 : int) (m2 : int)
    : (int * int) option =
  let (g, p, _) = extended_gcd m1 m2 in
  if (a2 - a1) mod g <> 0 then None
  else begin
    let lcm = m1 / g * m2 in
    (* x = a1 + m1 * ((a2 - a1) / g * p mod (m2/g)) *)
    let m2g = m2 / g in
    let diff = ((a2 - a1) / g) mod m2g in
    let x = (a1 + m1 * ((diff * p mod m2g + m2g) mod m2g)) mod lcm in
    let x = (x + lcm) mod lcm in
    Some (x, lcm)
  end

(* Solve a system of congruences: x ≡ a_i (mod m_i) for all i *)
let crt (congruences : (int * int) list) : (int * int) option =
  List.fold_left (fun acc (a, m) ->
    match acc with
    | None -> None
    | Some (r, lcm) -> crt_combine r lcm a m
  ) (Some (0, 1)) congruences

let () =
  (* x ≡ 2 (mod 3), x ≡ 3 (mod 5), x ≡ 2 (mod 7) → x = 23 *)
  let system = [(2, 3); (3, 5); (2, 7)] in
  (match crt system with
   | Some (x, m) -> Printf.printf "x ≡ 2(mod 3), x ≡ 3(mod 5), x ≡ 2(mod 7): x = %d (mod %d)\n" x m
   | None -> Printf.printf "No solution\n");

  (* x ≡ 0 (mod 4), x ≡ 6 (mod 10) — non-coprime moduli, has solution *)
  (match crt [(0, 4); (6, 10)] with
   | Some (x, m) -> Printf.printf "x ≡ 0(mod 4), x ≡ 6(mod 10): x = %d (mod %d)\n" x m
   | None -> Printf.printf "No solution\n");

  (* x ≡ 1 (mod 4), x ≡ 6 (mod 10) — no solution since 6-1=5 not div by gcd(4,10)=2 *)
  (match crt [(1, 4); (6, 10)] with
   | Some (x, m) -> Printf.printf "x ≡ 1(mod 4), x ≡ 6(mod 10): x = %d (mod %d)\n" x m
   | None -> Printf.printf "x ≡ 1(mod 4), x ≡ 6(mod 10): No solution\n")