🦀 Functional Rust

840: Divide and Conquer — Generic Recursive Framework

Difficulty: 3 Level: Intermediate

Split, recurse, combine: the algorithmic pattern behind merge sort, binary search, FFT, and closest-pair. Parameterize it with closures to make it reusable.

The Problem This Solves

Divide and conquer is the algorithmic pattern that often turns O(n²) or O(n³) brute-force solutions into O(n log n) or O(n log² n) ones. Merge sort, FFT, Strassen matrix multiplication, closest pair of points, and many geometric algorithms all decompose the same way: split the problem, solve each part recursively, combine the results. Recognizing this pattern is often the difference between a working solution and an asymptotically better one.

The generic D&C framework is parameterized by a base-case test and solver, a split step, a recursive solver, and a combine step (`base_case`, `split`, `solve_sub`, and `combine` in the full listing below), which makes the pattern explicit and reusable. In practice you implement each algorithm concretely (merge sort, binary search), but understanding the abstract structure helps you apply D&C to new problems and reason about complexity with the master theorem.

A D&C algorithm that makes a recursive calls on inputs of size n/b and spends f(n) on splitting and combining satisfies the recurrence T(n) = a×T(n/b) + f(n), which the master theorem solves. Merge sort (a=2, b=2, f=O(n)) gives O(n log n); binary search (a=1, b=2, f=O(1)) gives O(log n). This is the recurrence to write down when designing a new D&C algorithm.
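To make the master theorem concrete, here is a minimal sketch that classifies a recurrence of the form T(n) = a×T(n/b) + Θ(n^d). The `Growth` enum and the `master_theorem` function are hypothetical names introduced for illustration, and the sketch only covers polynomial driving functions f(n) = Θ(n^d), not the general theorem.

```rust
/// Which part of the recursion tree dominates T(n) = a*T(n/b) + Theta(n^d).
/// (Hypothetical helper type, not part of the original listing.)
#[derive(Debug, PartialEq)]
enum Growth {
    /// Theta(n^e) with e = log_b(a): the leaves dominate (Case 1).
    Leaves(f64),
    /// Theta(n^d * log n): every level costs about the same (Case 2).
    Balanced(f64),
    /// Theta(n^d): the work at the root dominates (Case 3).
    Root(f64),
}

/// Classifies the recurrence by comparing d with the critical exponent log_b(a).
/// Assumes a >= 1, b > 1, d >= 0.
fn master_theorem(a: f64, b: f64, d: f64) -> Growth {
    let critical = a.ln() / b.ln(); // log_b(a)
    const EPS: f64 = 1e-9; // tolerance for floating-point comparison
    if (d - critical).abs() < EPS {
        Growth::Balanced(d)
    } else if d < critical {
        Growth::Leaves(critical)
    } else {
        Growth::Root(d)
    }
}
```

Under these assumptions, `master_theorem(2.0, 2.0, 1.0)` reports the balanced case with exponent 1, matching merge sort's O(n log n), and `master_theorem(1.0, 2.0, 0.0)` reports the balanced case with exponent 0, matching binary search's O(log n).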

The Intuition

Every D&C algorithm has the same skeleton: split the input at a midpoint (or pivot), recurse on each piece, then merge the results. The interesting work happens in either the split or the merge. For merge sort, the split is trivial (cut in half) and the merge is the O(n) work. For binary search, the split is a single O(1) comparison and the merge is trivial (just pass along the result from one half). For FFT, the split (even and odd indices) and the combine are both O(n); the speedup comes from the symmetry of the roots of unity, which lets the two half-size transforms be reused for every output.

In Rust, slices make D&C natural: `&xs[..mid]` and `&xs[mid..]` give the two halves with zero copying. When the halves must be mutated, `split_at_mut` hands out two non-overlapping `&mut` halves, so the borrow checker can verify they don't alias.
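The slice-splitting idea can be sketched in a few lines. The two functions below, `slice_max` and `double_in_place`, are hypothetical examples rather than part of the original listing: the first recurses over shared sub-slices, the second over disjoint mutable halves obtained from `split_at_mut`.

```rust
/// Largest element of a slice via D&C over borrowed halves (nothing is copied).
fn slice_max(xs: &[i64]) -> Option<i64> {
    match xs {
        [] => None,
        [x] => Some(*x),
        _ => {
            // len >= 2 here, so both halves are non-empty.
            let (left, right) = xs.split_at(xs.len() / 2);
            // `Option<i64>` is ordered, so `max` picks the larger `Some`.
            slice_max(left).max(slice_max(right))
        }
    }
}

/// Doubles every element in place by recursing on two disjoint `&mut` halves.
fn double_in_place(xs: &mut [i64]) {
    if xs.len() <= 1 {
        // Base case: zero or one element.
        for x in xs.iter_mut() {
            *x *= 2;
        }
        return;
    }
    let mid = xs.len() / 2;
    // `split_at_mut` returns two non-overlapping mutable slices.
    let (left, right) = xs.split_at_mut(mid);
    double_in_place(left);
    double_in_place(right);
}
```

Neither function allocates: every recursive call works on a view into the caller's data.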

How It Works in Rust

// Merge sort: O(n log n), with a trivial split and an O(n) merge
fn merge_sort<T: Ord + Clone>(xs: &[T]) -> Vec<T> {
    if xs.len() <= 1 { return xs.to_vec(); }
    let mid = xs.len() / 2;
    let left  = merge_sort(&xs[..mid]);   // Recurse on the left half
    let right = merge_sort(&xs[mid..]);   // Recurse on the right half
    merge(left, right)                    // Combine: O(n)
}

// `Clone` is required by `clone()` and `extend_from_slice` below
fn merge<T: Ord + Clone>(a: Vec<T>, b: Vec<T>) -> Vec<T> {
    let (mut i, mut j) = (0, 0);
    let mut result = Vec::with_capacity(a.len() + b.len());
    while i < a.len() && j < b.len() {
        // `<=` takes from `a` on ties, which keeps the sort stable
        if a[i] <= b[j] { result.push(a[i].clone()); i += 1; }
        else             { result.push(b[j].clone()); j += 1; }
    }
    // At most one of these two tails is non-empty
    result.extend_from_slice(&a[i..]);
    result.extend_from_slice(&b[j..]);
    result
}

// Binary search: O(log n), with an O(1) split and a trivial merge
fn binary_search<T: Ord>(arr: &[T], target: &T) -> Option<usize> {
    // Search the half-open range [lo, hi); `arr` must be sorted ascending
    let (mut lo, mut hi) = (0usize, arr.len());
    while lo < hi {
        let mid = lo + (hi - lo) / 2;  // Avoids overflow vs (lo + hi) / 2
        match arr[mid].cmp(target) {
            std::cmp::Ordering::Equal   => return Some(mid),
            std::cmp::Ordering::Less    => lo = mid + 1,
            std::cmp::Ordering::Greater => hi = mid,
        }
    }
    None
}

// Master theorem quick reference:
// T(n) = 2T(n/2) + O(n)   → O(n log n)   [merge sort]
// T(n) = 1T(n/2) + O(1)   → O(log n)     [binary search]
// T(n) = 2T(n/2) + O(n²)  → O(n²)        [f dominates: Case 3]
`lo + (hi - lo) / 2` is the standard overflow-safe midpoint. `(lo + hi) / 2` can overflow when `lo + hi` exceeds the integer type's maximum, in Rust as in many other languages. Rust panics on the overflow in debug builds, which makes the bug visible, but wraps silently in release builds by default.
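The difference is easiest to see with a deliberately small integer type. The sketch below uses `u8` so that the overflow happens with ordinary-looking values; `naive_mid` and `safe_mid` are hypothetical names, and `checked_add` stands in for the plain `+` so the overflow shows up as `None` instead of a panic or a wrapped value.

```rust
/// Midpoint via (lo + hi) / 2, reporting overflow as `None`.
fn naive_mid(lo: u8, hi: u8) -> Option<u8> {
    lo.checked_add(hi).map(|sum| sum / 2)
}

/// Midpoint via lo + (hi - lo) / 2; cannot overflow as long as lo <= hi.
fn safe_mid(lo: u8, hi: u8) -> u8 {
    debug_assert!(lo <= hi);
    lo + (hi - lo) / 2
}
```

With `lo = 200` and `hi = 250`, the sum 450 does not fit in a `u8`, so `naive_mid` yields `None`, while `safe_mid` computes 200 + 25 = 225. The same reasoning applies to `usize` indices, just at a much larger scale.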

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Slice splitting | `Array.sub arr 0 mid` (copies) | `&arr[..mid]` (zero-copy borrow) |
| Generic type | `'a array` with `compare` | `<T: Ord + Clone>` (explicit bounds) |
| Higher-order D&C | Functions as arguments | `Fn(...)` trait objects or generics |
| Merge sort result | Returns new list/array | Returns `Vec<T>` (always allocates) |
| Overflow-safe mid | `lo + (hi - lo) / 2` | Same (explicit, not compiler magic) |

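The "Higher-order D&C" row can be illustrated with generic closure parameters. The sketch below is a hypothetical variant, `divide_and_conquer_rec`, in which the framework itself performs the recursion, so the caller supplies only the base case, the split, and the combine step. `framework_max` is likewise an illustrative name; it represents subproblems as half-open index ranges so that nothing is copied.

```rust
/// Recursive D&C driver (hypothetical variant): solves `problem` directly if
/// `base_case` returns `Some`, otherwise splits, recurses, and combines.
fn divide_and_conquer_rec<T, R, B, S, C>(problem: T, base_case: &B, split: &S, combine: &C) -> R
where
    B: Fn(&T) -> Option<R>,
    S: Fn(T) -> Vec<T>,
    C: Fn(Vec<R>) -> R,
{
    if let Some(result) = base_case(&problem) {
        return result;
    }
    let sub_results: Vec<R> = split(problem)
        .into_iter()
        .map(|sub| divide_and_conquer_rec(sub, base_case, split, combine))
        .collect();
    combine(sub_results)
}

/// Maximum of a slice expressed through the driver above.
/// A subproblem is a half-open index range `(lo, hi)` into `xs`.
fn framework_max(xs: &[i64]) -> Option<i64> {
    if xs.is_empty() {
        return None;
    }
    let result = divide_and_conquer_rec(
        (0usize, xs.len()),
        // Base case: a range holding exactly one element.
        &|&(lo, hi): &(usize, usize)| if hi - lo == 1 { Some(xs[lo]) } else { None },
        // Split: cut the range at its midpoint; both halves are non-empty.
        &|(lo, hi): (usize, usize)| {
            let mid = lo + (hi - lo) / 2;
            vec![(lo, mid), (mid, hi)]
        },
        // Combine: the larger of the two sub-results.
        &|results: Vec<i64>| results.into_iter().max().expect("split yields two parts"),
    );
    Some(result)
}
```

Generic parameters are used here so each closure is monomorphized and can be inlined; `&dyn Fn(...)` trait objects would also work at the cost of dynamic dispatch.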
//! Divide and Conquer: Generic Recursive Framework.
//!
//! Demonstrated with: merge sort, binary search, max subarray.
//! The pattern: split → recurse → combine.

/// Merge sort: O(n log n) time, O(n) space.
fn merge_sort<T: Ord + Clone>(xs: &[T]) -> Vec<T> {
    if xs.len() <= 1 {
        return xs.to_vec();
    }
    let mid = xs.len() / 2;
    let left = merge_sort(&xs[..mid]);
    let right = merge_sort(&xs[mid..]);
    merge(left, right)
}

// `Clone` is required by `clone()` and `extend_from_slice` below.
fn merge<T: Ord + Clone>(a: Vec<T>, b: Vec<T>) -> Vec<T> {
    let (mut i, mut j) = (0, 0);
    let mut result = Vec::with_capacity(a.len() + b.len());
    while i < a.len() && j < b.len() {
        if a[i] <= b[j] {
            result.push(a[i].clone());
            i += 1;
        } else {
            result.push(b[j].clone());
            j += 1;
        }
    }
    result.extend_from_slice(&a[i..]);
    result.extend_from_slice(&b[j..]);
    result
}

/// Binary search: O(log n). Returns the index of target.
fn binary_search<T: Ord>(arr: &[T], target: &T) -> Option<usize> {
    let (mut lo, mut hi) = (0usize, arr.len());
    while lo < hi {
        let mid = lo + (hi - lo) / 2;
        match arr[mid].cmp(target) {
            std::cmp::Ordering::Equal => return Some(mid),
            std::cmp::Ordering::Less => lo = mid + 1,
            std::cmp::Ordering::Greater => hi = mid,
        }
    }
    None
}

/// Maximum crossing subarray sum (for divide step in max subarray).
fn max_crossing(arr: &[i64], lo: usize, mid: usize, hi: usize) -> i64 {
    let mut left_sum = i64::MIN;
    let mut s = 0i64;
    for i in (lo..=mid).rev() {
        s += arr[i];
        if s > left_sum { left_sum = s; }
    }
    let mut right_sum = i64::MIN;
    s = 0;
    for i in mid + 1..=hi {
        s += arr[i];
        if s > right_sum { right_sum = s; }
    }
    left_sum + right_sum
}

/// Maximum subarray sum via D&C: O(n log n).
fn max_subarray(arr: &[i64], lo: usize, hi: usize) -> i64 {
    if lo == hi { return arr[lo]; }
    let mid = lo + (hi - lo) / 2;
    let left_max = max_subarray(arr, lo, mid);
    let right_max = max_subarray(arr, mid + 1, hi);
    let cross_max = max_crossing(arr, lo, mid, hi);
    left_max.max(right_max).max(cross_max)
}

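/// Generic D&C framework as a higher-order function.
///
/// `base_case` returns `Some(result)` when the problem can be solved directly.
/// Otherwise `split` breaks it into subproblems, `solve_sub` solves each one
/// (for example by calling back into this function), and `combine` merges
/// the sub-results.
fn divide_and_conquer<T, R, F, C>(
    problem: T,
    base_case: impl Fn(&T) -> Option<R>,
    split: impl Fn(T) -> Vec<T>,
    solve_sub: &F,
    combine: C,
) -> R
where
    F: Fn(T) -> R,
    C: Fn(Vec<R>) -> R,
{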
    if let Some(result) = base_case(&problem) {
        return result;
    }
    let subproblems = split(problem);
    let sub_results: Vec<R> = subproblems.into_iter().map(|p| solve_sub(p)).collect();
    combine(sub_results)
}

fn main() {
    // Merge sort
    let xs = vec![5i32, 3, 8, 1, 9, 2, 7, 4, 6];
    let sorted = merge_sort(&xs);
    println!("merge_sort({xs:?}) = {sorted:?}");

    // Binary search
    let arr = vec![1i32, 3, 5, 7, 9, 11, 13];
    println!("binary_search(7) = {:?}", binary_search(&arr, &7));
    println!("binary_search(6) = {:?}", binary_search(&arr, &6));

    // Max subarray
    let nums = [-2i64, 1, -3, 4, -1, 2, 1, -5, 4];
    let n = nums.len();
    println!("max_subarray({nums:?}) = {} (expected 6)", max_subarray(&nums, 0, n - 1));
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_merge_sort() {
        assert_eq!(merge_sort(&[5, 3, 8, 1, 9, 2, 7, 4, 6]), vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn test_merge_sort_empty() {
        assert_eq!(merge_sort::<i32>(&[]), vec![]);
    }

    #[test]
    fn test_merge_sort_single() {
        assert_eq!(merge_sort(&[42i32]), vec![42]);
    }

    #[test]
    fn test_merge_sort_strings() {
        assert_eq!(merge_sort(&["banana", "apple", "cherry"]), vec!["apple", "banana", "cherry"]);
    }

    #[test]
    fn test_binary_search_found() {
        let arr = vec![1, 3, 5, 7, 9, 11, 13];
        assert_eq!(binary_search(&arr, &7), Some(3));
        assert_eq!(binary_search(&arr, &1), Some(0));
        assert_eq!(binary_search(&arr, &13), Some(6));
    }

    #[test]
    fn test_binary_search_not_found() {
        let arr = vec![1, 3, 5, 7, 9];
        assert_eq!(binary_search(&arr, &6), None);
        assert_eq!(binary_search(&arr, &0), None);
        assert_eq!(binary_search(&arr, &10), None);
    }

    #[test]
    fn test_max_subarray() {
        let nums = [-2i64, 1, -3, 4, -1, 2, 1, -5, 4];
        let n = nums.len();
        assert_eq!(max_subarray(&nums, 0, n - 1), 6); // [4,-1,2,1]
    }

    #[test]
    fn test_max_subarray_all_negative() {
        let nums = [-3i64, -1, -2];
        let n = nums.len();
        assert_eq!(max_subarray(&nums, 0, n - 1), -1);
    }
}
(* Divide and Conquer Framework in OCaml *)

(* Merge sort — the canonical D&C algorithm *)
let rec merge_sort (xs : 'a list) : 'a list =
  match xs with
  | [] | [_] -> xs
  | _ ->
    let n = List.length xs in
    let mid = n / 2 in
    let left = List.filteri (fun i _ -> i < mid) xs in
    let right = List.filteri (fun i _ -> i >= mid) xs in
    let sl = merge_sort left in
    let sr = merge_sort right in
    (* Merge two sorted lists *)
    let rec merge a b = match (a, b) with
      | ([], b) -> b
      | (a, []) -> a
      | (x::xs, y::ys) ->
        if x <= y then x :: merge xs b
        else y :: merge a ys
    in
    merge sl sr

(* Binary search: returns Some index or None *)
let binary_search (arr : 'a array) (target : 'a) : int option =
  let rec go lo hi =
    if lo > hi then None
    else
      let mid = (lo + hi) / 2 in
      if arr.(mid) = target then Some mid
      else if arr.(mid) < target then go (mid + 1) hi
      else go lo (mid - 1)
  in
  go 0 (Array.length arr - 1)

(* Maximum subarray sum via D&C (Kadane's is O(n) but D&C shows the pattern) *)
let max_crossing_sum (arr : int array) (lo : int) (mid : int) (hi : int) : int =
  let left_sum = ref min_int and s = ref 0 in
  for i = mid downto lo do
    s := !s + arr.(i);
    if !s > !left_sum then left_sum := !s
  done;
  let right_sum = ref min_int in
  s := 0;
  for i = mid + 1 to hi do
    s := !s + arr.(i);
    if !s > !right_sum then right_sum := !s
  done;
  !left_sum + !right_sum

let rec max_subarray (arr : int array) (lo : int) (hi : int) : int =
  if lo = hi then arr.(lo)
  else
    let mid = (lo + hi) / 2 in
    let left_max = max_subarray arr lo mid in
    let right_max = max_subarray arr (mid + 1) hi in
    let cross_max = max_crossing_sum arr lo mid hi in
    max left_max (max right_max cross_max)

let () =
  let xs = [5; 3; 8; 1; 9; 2; 7; 4; 6] in
  let sorted = merge_sort xs in
  Printf.printf "merge_sort %s = %s\n"
    (String.concat "," (List.map string_of_int xs))
    (String.concat "," (List.map string_of_int sorted));

  let arr = [| 1; 3; 5; 7; 9; 11; 13 |] in
  Printf.printf "binary_search(7) = %s\n"
    (match binary_search arr 7 with Some i -> string_of_int i | None -> "None");
  Printf.printf "binary_search(6) = %s\n"
    (match binary_search arr 6 with Some i -> string_of_int i | None -> "None");

  let nums = [| -2; 1; -3; 4; -1; 2; 1; -5; 4 |] in
  Printf.printf "max_subarray([-2,1,-3,4,-1,2,1,-5,4]) = %d (expected 6)\n"
    (max_subarray nums 0 (Array.length nums - 1))
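
The OCaml comment above notes that Kadane's algorithm solves maximum subarray in O(n). For comparison with the O(n log n) D&C version, here is a minimal Rust sketch of that linear scan. The function name `kadane` is an illustrative choice, and returning `None` for an empty slice is an assumption that the D&C version above does not make (it expects a non-empty range).

```rust
/// Maximum subarray sum in one pass (Kadane's algorithm).
/// Returns `None` for an empty slice.
fn kadane(arr: &[i64]) -> Option<i64> {
    let (&first, rest) = arr.split_first()?;
    let mut best = first;    // best sum seen so far
    let mut current = first; // best sum of a subarray ending at the current element
    for &x in rest {
        // Either extend the running subarray or start a new one at `x`.
        current = x.max(current + x);
        best = best.max(current);
    }
    Some(best)
}
```

On the test input `[-2, 1, -3, 4, -1, 2, 1, -5, 4]` this yields `Some(6)`, the same value the D&C version computes for the subarray `[4, -1, 2, 1]`.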