🦀 Functional Rust

837: Closest Pair of Points — Divide and Conquer

Difficulty: 5  Level: Master

Find the two closest points in a set of n points in O(n log n): the classical divide-and-conquer geometry algorithm.

The Problem This Solves

Given n points in the plane, find the pair with the smallest Euclidean distance. The brute-force approach checks all n(n-1)/2 pairs, which is O(n²). The divide-and-conquer algorithm achieves O(n log n), the same asymptotic bound as comparison sorting, through a non-trivial strip-merging argument.

Closest pair appears in clustering (the first merge in single-linkage clustering joins the closest pair), collision detection (finding nearby objects), computational biology (protein structure analysis), and as the canonical example of a geometric divide-and-conquer algorithm. Understanding it is essential for anyone working in computational geometry or teaching algorithm design.

The implementations below return the minimum distance; recording the pair that achieves it only requires carrying the two points alongside the running minimum. Degenerate inputs (coincident points, collinear configurations) need no special handling.
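As a reference point for the O(n²) baseline, here is a minimal sketch of a brute-force search that also reports which pair is closest. The function name `closest_pair_brute` and the tuple-based return type are illustrative assumptions, not part of the original code.

```rust
/// Squared Euclidean distance; sufficient for comparisons.
fn dist_sq(a: (f64, f64), b: (f64, f64)) -> f64 {
    (a.0 - b.0).powi(2) + (a.1 - b.1).powi(2)
}

/// Hypothetical O(n²) baseline: returns the closest pair and its distance,
/// or `None` when there are fewer than two points.
fn closest_pair_brute(pts: &[(f64, f64)]) -> Option<((f64, f64), (f64, f64), f64)> {
    let mut best: Option<((f64, f64), (f64, f64), f64)> = None;
    for i in 0..pts.len() {
        for j in i + 1..pts.len() {
            let d2 = dist_sq(pts[i], pts[j]);
            // Keep the pair if it is the first one seen or strictly closer.
            if best.map_or(true, |(_, _, b)| d2 < b) {
                best = Some((pts[i], pts[j], d2));
            }
        }
    }
    // Take the square root once, at the end.
    best.map(|(a, b, d2)| (a, b, d2.sqrt()))
}

fn main() {
    let pts = [
        (2.0, 3.0), (12.0, 30.0), (40.0, 50.0),
        (5.0, 1.0), (12.0, 10.0), (3.0, 4.0),
    ];
    if let Some((a, b, d)) = closest_pair_brute(&pts) {
        println!("{a:?} and {b:?} are {d:.6} apart");
    }
}
```

A baseline like this is mainly useful as an oracle when testing a faster implementation on small inputs.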

The Intuition

1. Sort points by x-coordinate.
2. Split into left and right halves at the median x.
3. Recurse on each half; let δ = min(left_min, right_min).
4. The closest pair might cross the split line, but only if both points lie within δ of it. This "strip" has width 2δ.
5. In the strip, process points in y-order and check each one against at most 7 others. The geometric argument: points on the same side of the split are at least δ apart, so each δ×δ half of a 2δ×δ rectangle (width 2δ, height δ) holds at most 4 of them, 8 in total. Each point therefore has at most 7 candidates above it within δ vertical distance.

The 7-neighbour bound is the key insight: it makes the strip scan linear. Without it, checking all pairs in the strip would be quadratic.

Total cost: O(n log n) for the initial sort plus O(n log n) from the recurrence T(n) = 2T(n/2) + O(n). The O(n) term assumes the strip is already in y-order, for example because each recursive call hands back its points sorted by y and the two halves are merged as in merge sort. Re-sorting the strip inside every call, as the code below does for simplicity, gives T(n) = 2T(n/2) + O(n log n) = O(n log² n).
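The recurrence T(n) = 2T(n/2) + O(n) only holds if the strip can be produced in y-order without a fresh sort in each call. The following is a minimal sketch of one way to do that, assuming NaN-free coordinates: every call returns its slice sorted by y, and the two halves are merged in linear time. The names `rec_merge` and `closest_pair_merge` and the scratch buffer `buf` are illustrative, not taken from the original code.

```rust
/// Squared Euclidean distance between two points.
fn dist_sq(a: (f64, f64), b: (f64, f64)) -> f64 {
    (a.0 - b.0).powi(2) + (a.1 - b.1).powi(2)
}

/// Hypothetical helper: on entry `pts` is sorted by x, on exit by y.
/// Returns the smallest squared distance within `pts`.
fn rec_merge(pts: &mut [(f64, f64)], buf: &mut Vec<(f64, f64)>) -> f64 {
    let n = pts.len();
    if n <= 3 {
        let mut best = f64::INFINITY;
        for i in 0..n {
            for j in i + 1..n {
                best = best.min(dist_sq(pts[i], pts[j]));
            }
        }
        // Constant-size sort establishes the y-order invariant.
        pts.sort_by(|a, b| a.1.total_cmp(&b.1));
        return best;
    }

    let mid = n / 2;
    let mid_x = pts[mid].0; // read before the halves are reordered by y
    let (left, right) = pts.split_at_mut(mid);
    let d_left = rec_merge(left, buf);
    let d_right = rec_merge(right, buf);
    let mut best = d_left.min(d_right);

    // Merge the two y-sorted halves in O(n), as in merge sort.
    buf.clear();
    let (mut i, mut j) = (0, 0);
    while i < left.len() && j < right.len() {
        if left[i].1 <= right[j].1 {
            buf.push(left[i]);
            i += 1;
        } else {
            buf.push(right[j]);
            j += 1;
        }
    }
    buf.extend_from_slice(&left[i..]);
    buf.extend_from_slice(&right[j..]);
    pts.copy_from_slice(buf.as_slice());

    // The merged slice is in y-order, so filtering keeps the strip in y-order.
    let strip: Vec<(f64, f64)> = pts
        .iter()
        .copied()
        .filter(|p| (p.0 - mid_x).powi(2) < best)
        .collect();
    for i in 0..strip.len() {
        for j in i + 1..strip.len() {
            if (strip[j].1 - strip[i].1).powi(2) >= best {
                break; // every later point is even further away vertically
            }
            best = best.min(dist_sq(strip[i], strip[j]));
        }
    }
    best
}

/// Hypothetical entry point: sorts a copy by x once, then recurses.
fn closest_pair_merge(points: &[(f64, f64)]) -> f64 {
    let mut pts = points.to_vec();
    pts.sort_by(|a, b| a.0.total_cmp(&b.0));
    let mut buf = Vec::with_capacity(pts.len());
    rec_merge(&mut pts, &mut buf).sqrt()
}

fn main() {
    let pts = [
        (2.0, 3.0), (12.0, 30.0), (40.0, 50.0),
        (5.0, 1.0), (12.0, 10.0), (3.0, 4.0),
    ];
    println!("closest distance: {:.6}", closest_pair_merge(&pts));
}
```

The trade-off is that the recursion now mutates its input and needs a scratch buffer, which is why simpler presentations often accept the extra log factor and sort the strip directly.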

How It Works in Rust

fn dist_sq(a: (f64, f64), b: (f64, f64)) -> f64 {
 (a.0 - b.0).powi(2) + (a.1 - b.1).powi(2)
}

fn closest_pair(pts: &mut [(f64, f64)]) -> f64 {
 pts.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); // sort by x
 closest_rec(pts).sqrt()
}

fn closest_rec(pts: &[(f64, f64)]) -> f64 {
 let n = pts.len();
 if n <= 3 {
     // Base case: brute force among ≤3 points
     let mut min_d = f64::INFINITY;
     for i in 0..n {
         for j in i+1..n { min_d = min_d.min(dist_sq(pts[i], pts[j])); }
     }
     return min_d;
 }

 let mid = n / 2;
 let mid_x = pts[mid].0;

 // Recurse on left and right halves
 let d_left  = closest_rec(&pts[..mid]);
 let d_right = closest_rec(&pts[mid..]);
 let mut d = d_left.min(d_right); // δ² — compare squared to avoid sqrt

 // Collect strip: points within δ of the split line
 let delta = d.sqrt();
 let mut strip: Vec<(f64, f64)> = pts.iter()
     .filter(|&&p| (p.0 - mid_x).abs() < delta)
     .copied()
     .collect();

 // Sort strip by y for the 7-point check
 strip.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

 // Check each point against at most 7 above it within δ vertical distance
 for i in 0..strip.len() {
     let mut j = i + 1;
     while j < strip.len() && (strip[j].1 - strip[i].1).powi(2) < d {
         d = d.min(dist_sq(strip[i], strip[j]));
         j += 1;
     }
 }
 d
}
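
Working with squared distances (`dist_sq`) avoids `sqrt` in the inner loop: squaring preserves the ordering of non-negative distances, so comparisons are unaffected. The recursion takes one `sqrt` per call to size the strip, and the final `sqrt` happens once at the top level.

`f64` implements only `PartialOrd`, not `Ord`, because NaN is unordered with respect to every value (`f64::NAN != f64::NAN`) and `Ord` requires a total order. `sort_by` therefore needs an explicit comparator. `partial_cmp(...).unwrap()` panics if either operand is NaN, so it is safe only when the coordinates are guaranteed NaN-free; `f64::total_cmp` (stable since Rust 1.62) is a panic-free alternative.

The strip filter `(p.0 - mid_x).abs() < delta` uses `delta = d.sqrt()`, computed once from the best distance found in the two halves. Strip membership is fixed at that point; what tightens as `d` improves is the vertical cutoff `(strip[j].1 - strip[i].1).powi(2) < d` in the scan loop.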
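
To make the float-ordering point concrete, here is a minimal sketch of sorting by x with `f64::total_cmp`, paired with an up-front NaN check. The helper names `sort_by_x` and `nan_free` are hypothetical; rejecting NaN input before running the algorithm is one possible policy, not something the original code does.

```rust
/// Sort by x using the IEEE 754 total order; never panics, even on NaN.
fn sort_by_x(pts: &mut [(f64, f64)]) {
    pts.sort_by(|a, b| a.0.total_cmp(&b.0));
}

/// Hypothetical validation step: true if no coordinate is NaN.
fn nan_free(pts: &[(f64, f64)]) -> bool {
    pts.iter().all(|p| !p.0.is_nan() && !p.1.is_nan())
}

fn main() {
    let mut pts = vec![(3.0, 1.0), (f64::NAN, 0.0), (1.0, 2.0), (2.0, 5.0)];
    println!("NaN-free input: {}", nan_free(&pts));
    // A comparator built on partial_cmp(...).unwrap() could panic on this input.
    sort_by_x(&mut pts);
    println!("{pts:?}");
}
```

Sorting without a panic does not make distances to a NaN point meaningful, so a check like `nan_free` is still worth running before the search.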

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Mutable sort | `Array.sort` in place | `pts.sort_by(...)`: in place, with `partial_cmp` for `f64` |
| Float ordering | Polymorphic `compare` works but is slower on floats than `Float.compare` | `partial_cmp(...).unwrap()`: explicit, panics on NaN |
| Strip filter | `List.filter` creates a new list | `.filter().copied().collect()`: same cost, owns data |
| Squared distance | Same trick | `dist_sq` avoids `sqrt` in the inner loop |
| Recursion on slices | `Array.sub` (copies each half) or explicit index bounds | `&pts[..mid]`, `&pts[mid..]`: zero-copy slice borrows |
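
Full Rust Implementation

//! Closest Pair of Points by divide and conquer.
//!
//! Split by median x, recurse, then check the strip of points within δ of the split.
//! At most 8 points fit in any 2δ×δ box, so the strip scan itself is O(n).
//! Sorting the strip in every call makes this version O(n log² n) overall.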

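// `Point` is `pub` so that it is at least as visible as `closest_pair`,
// which takes it in its public signature.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Point {
    pub x: f64,
    pub y: f64,
}

impl Point {
    pub fn new(x: f64, y: f64) -> Self { Point { x, y } }

    /// Euclidean distance to `other`.
    pub fn dist(&self, other: &Point) -> f64 {
        let dx = self.x - other.x;
        let dy = self.y - other.y;
        (dx * dx + dy * dy).sqrt()
    }
}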

/// Brute-force for small n.
fn brute_force(pts: &[Point]) -> f64 {
    let n = pts.len();
    let mut best = f64::INFINITY;
    for i in 0..n {
        for j in i + 1..n {
            best = best.min(pts[i].dist(&pts[j]));
        }
    }
    best
}

/// Check strip: points within delta of mid_x. Sorts `strip` by y, then scans it.
fn strip_closest(strip: &mut [Point], delta: f64) -> f64 {
    strip.sort_by(|a, b| a.y.partial_cmp(&b.y).unwrap());
    let mut best = delta;
    let n = strip.len();
    for i in 0..n {
        let mut j = i + 1;
        // At most 7 comparisons per point (geometric packing)
        while j < n && strip[j].y - strip[i].y < best {
            best = best.min(strip[i].dist(&strip[j]));
            j += 1;
        }
    }
    best
}

/// Recursive divide-and-conquer. pts_x must be sorted by x.
fn closest_rec(pts_x: &[Point]) -> f64 {
    let n = pts_x.len();
    if n <= 3 {
        return brute_force(pts_x);
    }

    let mid = n / 2;
    let mid_x = pts_x[mid].x;

    let dl = closest_rec(&pts_x[..mid]);
    let dr = closest_rec(&pts_x[mid..]);
    let delta = dl.min(dr);

    // Collect strip: points within delta of the dividing line
    let mut strip: Vec<Point> = pts_x
        .iter()
        .filter(|p| (p.x - mid_x).abs() < delta)
        .copied()
        .collect();

    // strip_closest starts from delta, so its result is already
    // min(delta, best distance found in the strip).
    strip_closest(&mut strip, delta)
}

/// Public API: find minimum distance among all pairs of points.
pub fn closest_pair(points: &[Point]) -> f64 {
    if points.len() < 2 { return f64::INFINITY; }
    let mut sorted = points.to_vec();
    sorted.sort_by(|a, b| a.x.partial_cmp(&b.x).unwrap());
    closest_rec(&sorted)
}

fn main() {
    let points = vec![
        Point::new(2.0, 3.0),  Point::new(12.0, 30.0),
        Point::new(40.0, 50.0), Point::new(5.0, 1.0),
        Point::new(12.0, 10.0), Point::new(3.0, 4.0),
    ];

    let d = closest_pair(&points);
    let bf = brute_force(&points);
    println!("Closest pair (D&C):    {:.6}", d);
    println!("Closest pair (brute):  {:.6}", bf);
    println!("Match: {}", (d - bf).abs() < 1e-9);

    // Edge case: two points
    let two = vec![Point::new(0.0, 0.0), Point::new(3.0, 4.0)];
    println!("Two points distance: {} (expected 5.0)", closest_pair(&two));
}

#[cfg(test)]
mod tests {
    use super::*;

    fn matches_brute(pts: Vec<Point>) {
        let d = closest_pair(&pts);
        let bf = brute_force(&pts);
        assert!((d - bf).abs() < 1e-9,
            "d&c={d} brute={bf} for {pts:?}");
    }

    #[test]
    fn test_basic() {
        matches_brute(vec![
            Point::new(2.0, 3.0), Point::new(12.0, 30.0),
            Point::new(40.0, 50.0), Point::new(5.0, 1.0),
            Point::new(12.0, 10.0), Point::new(3.0, 4.0),
        ]);
    }

    #[test]
    fn test_two_points() {
        let d = closest_pair(&[Point::new(0.0, 0.0), Point::new(3.0, 4.0)]);
        assert!((d - 5.0).abs() < 1e-9);
    }

    #[test]
    fn test_collinear() {
        let pts: Vec<Point> = (0..10).map(|i| Point::new(i as f64, 0.0)).collect();
        matches_brute(pts);
    }

    #[test]
    fn test_grid() {
        let pts: Vec<Point> = (0..5)
            .flat_map(|i| (0..5).map(move |j| Point::new(i as f64, j as f64)))
            .collect();
        matches_brute(pts);
    }

    #[test]
    fn test_clustered() {
        let mut pts = vec![
            Point::new(0.0, 0.0), Point::new(100.0, 100.0), Point::new(200.0, 200.0),
        ];
        pts.push(Point::new(0.1, 0.0)); // Very close to (0,0)
        let d = closest_pair(&pts);
        assert!((d - 0.1).abs() < 1e-9);
    }

    #[test]
    fn test_strip_case() {
        // Points near the dividing line
        let pts = vec![
            Point::new(0.0, 0.0), Point::new(1.0, 0.0),
            Point::new(0.5, 0.3), Point::new(0.5, -0.3),
        ];
        matches_brute(pts);
    }
}

OCaml Version

(* Closest Pair of Points by divide and conquer, in OCaml.
   Sorting the strip in every call makes this version O(n log² n). *)

type point = { x: float; y: float }

let dist a b =
  let dx = a.x -. b.x and dy = a.y -. b.y in
  sqrt (dx *. dx +. dy *. dy)

(* Brute force for n ≤ 3 *)
let brute_force pts =
  let n = Array.length pts in
  let best = ref infinity in
  let pair = ref (pts.(0), pts.(0)) in
  for i = 0 to n - 1 do
    for j = i + 1 to n - 1 do
      let d = dist pts.(i) pts.(j) in
      if d < !best then begin best := d; pair := (pts.(i), pts.(j)) end
    done
  done;
  (!best, !pair)

(* Strip scan: check points within delta of the dividing line, sorted by y *)
let strip_closest (strip : point list) (delta : float) : float =
  let arr = Array.of_list (List.sort (fun a b -> compare a.y b.y) strip) in
  let n = Array.length arr in
  let best = ref delta in
  for i = 0 to n - 1 do
    let j = ref (i + 1) in
    while !j < n && arr.(!j).y -. arr.(i).y < !best do
      let d = dist arr.(i) arr.(!j) in
      if d < !best then best := d;
      incr j
    done
  done;
  !best

(* Divide and conquer *)
let rec closest_pair_rec (pts_x : point array) : float =
  let n = Array.length pts_x in
  if n <= 3 then fst (brute_force pts_x)
  else begin
    let mid = n / 2 in
    let mid_x = pts_x.(mid).x in
    let left = Array.sub pts_x 0 mid in
    let right = Array.sub pts_x mid (n - mid) in
    let dl = closest_pair_rec left in
    let dr = closest_pair_rec right in
    let delta = min dl dr in
    (* Build strip: points within delta of the dividing line *)
    let strip = Array.to_list pts_x
      |> List.filter (fun p -> abs_float (p.x -. mid_x) < delta) in
    min delta (strip_closest strip delta)
  end

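let closest_pair (points : point list) : float =
  match points with
  (* Fewer than two points: no pair exists. The guard also keeps
     brute_force from reading pts.(0) on an empty array. *)
  | [] | [_] -> infinity
  | _ ->
    let pts = Array.of_list (List.sort (fun a b -> compare a.x b.x) points) in
    closest_pair_rec pts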

let () =
  let points = [
    {x=2.0;y=3.0}; {x=12.0;y=30.0}; {x=40.0;y=50.0};
    {x=5.0;y=1.0};  {x=12.0;y=10.0}; {x=3.0;y=4.0};
  ] in
  let d = closest_pair points in
  Printf.printf "Closest pair distance: %.4f\n" d;
  (* Brute force verification *)
  let arr = Array.of_list points in
  let (bf, _) = brute_force arr in
  Printf.printf "Brute force distance:  %.4f\n" bf;
  Printf.printf "Match: %b\n" (abs_float (d -. bf) < 1e-9)