🦀 Functional Rust

802. Kruskal's MST with Union-Find

Difficulty: 4 | Level: Advanced

Build a minimum spanning tree by sorting edges and merging components with a Union-Find data structure, in O(E log E) time.
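The following is a rough breakdown of where the O(E log E) figure comes from, for V vertices and E edges. It assumes a connected simple graph and the Union-Find variant used in this example (path halving plus union by rank):

```latex
\begin{aligned}
T(V, E) &= \underbrace{O(E \log E)}_{\text{sort edges}}
         + \underbrace{O(V)}_{\text{init Union-Find}}
         + \underbrace{O\big(E\,\alpha(V)\big)}_{2E\ \text{finds},\ \le V-1\ \text{merges}} \\
        &= O(E \log E) \qquad \text{(connected: } E \ge V - 1\text{; } \alpha \text{ grows far more slowly than } \log) \\
        &= O(E \log V) \qquad \text{(simple graph: } E < V^2 \Rightarrow \log E < 2 \log V)
\end{aligned}
```

The sort dominates. The Union-Find work is almost linear in E.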

The Problem This Solves

Kruskal's algorithm is a common choice of MST algorithm for sparse graphs. It is also the canonical application of the Union-Find (Disjoint Set Union) data structure, one of the most practically useful data structures in competitive programming and systems engineering. Union-Find also appears in incremental connectivity queries (edges are only ever added), network partitioning, and image segmentation (connected components). Kruskal's is simply its most famous use case.

The algorithm's appeal is its conceptual clarity: sort all edges by weight, then greedily add each edge as long as it doesn't create a cycle. "Doesn't create a cycle" means "connects two different components", which is exactly what Union-Find checks in near-constant amortised time.
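The same-component test works on its own as a cycle detector. Below is a minimal sketch that assumes an undirected graph given as `(u, v)` pairs over vertices `0..n`. `Dsu` and `has_cycle` are hypothetical names chosen for this illustration, not part of the article's code.

```rust
/// Minimal Union-Find (path halving + union by rank), for illustration only.
struct Dsu {
    parent: Vec<usize>,
    rank: Vec<usize>,
}

impl Dsu {
    fn new(n: usize) -> Self {
        Dsu { parent: (0..n).collect(), rank: vec![0; n] }
    }

    fn find(&mut self, mut v: usize) -> usize {
        while self.parent[v] != v {
            self.parent[v] = self.parent[self.parent[v]]; // path halving
            v = self.parent[v];
        }
        v
    }

    /// Merges the components of `u` and `v`; returns false if they were already joined.
    fn union(&mut self, u: usize, v: usize) -> bool {
        let (pu, pv) = (self.find(u), self.find(v));
        if pu == pv {
            return false;
        }
        if self.rank[pu] < self.rank[pv] {
            self.parent[pu] = pv;
        } else if self.rank[pu] > self.rank[pv] {
            self.parent[pv] = pu;
        } else {
            self.parent[pv] = pu;
            self.rank[pu] += 1;
        }
        true
    }
}

/// Hypothetical helper: true if the undirected graph on `n` vertices has a cycle.
/// The first edge whose endpoints are already connected closes a cycle.
fn has_cycle(n: usize, edges: &[(usize, usize)]) -> bool {
    let mut dsu = Dsu::new(n);
    edges.iter().any(|&(u, v)| !dsu.union(u, v))
}

fn main() {
    println!("{}", has_cycle(4, &[(0, 1), (1, 2), (2, 3)])); // false: a simple path
    println!("{}", has_cycle(3, &[(0, 1), (1, 2), (2, 0)])); // true: a triangle
}
```

Kruskal's loop runs the same test, but it visits edges in weight order and skips cycle-closing edges instead of reporting them.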

The Intuition

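Sort all edges by weight and walk through them cheapest first. For each edge `(u, v)`, check whether u and v are already in the same component (`find(u) == find(v)`). If they are, adding the edge would create a cycle, so skip it. If not, merge the two components (`union(u, v)`) and add the edge to the MST. A spanning tree on V vertices has exactly V-1 edges, so the loop can stop once the MST reaches that size. If the graph is disconnected, the MST never reaches V-1 edges, and the result is a minimum spanning forest. With path compression (or path halving) plus union by rank, each Union-Find operation costs O(α(V)) amortised. Here α is the inverse Ackermann function, which is effectively constant (at most 4 for any realistic input size). OCaml implements Union-Find with mutable `int array`s; Rust uses a `struct UnionFind { parent: Vec<usize>, rank: Vec<usize> }`.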
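The next sketch shows the accept/skip decisions and the early stop in action by recording what happens to each edge it examines. `kruskal_trace` is a hypothetical helper, not part of the article's code. It uses the same `UnionFind` design (path halving, union by rank) and assumes vertices are numbered `0..n`.

```rust
use std::cmp::Ordering;

struct UnionFind {
    parent: Vec<usize>,
    rank: Vec<usize>,
}

impl UnionFind {
    fn new(n: usize) -> Self {
        UnionFind { parent: (0..n).collect(), rank: vec![0; n] }
    }

    fn find(&mut self, mut v: usize) -> usize {
        while self.parent[v] != v {
            self.parent[v] = self.parent[self.parent[v]]; // path halving
            v = self.parent[v];
        }
        v
    }

    fn union(&mut self, u: usize, v: usize) -> bool {
        let (pu, pv) = (self.find(u), self.find(v));
        if pu == pv {
            return false;
        }
        match self.rank[pu].cmp(&self.rank[pv]) {
            Ordering::Less => self.parent[pu] = pv,
            Ordering::Greater => self.parent[pv] = pu,
            Ordering::Equal => {
                self.parent[pv] = pu;
                self.rank[pu] += 1;
            }
        }
        true
    }
}

/// Hypothetical helper: runs Kruskal's loop and returns (w, u, v, accepted)
/// for every edge it examines, stopping once the tree has n - 1 edges.
fn kruskal_trace(n: usize, edges: &mut [(i64, usize, usize)]) -> Vec<(i64, usize, usize, bool)> {
    edges.sort_unstable_by_key(|&(w, _, _)| w);
    let mut uf = UnionFind::new(n);
    let mut accepted = 0;
    let mut trace = Vec::new();
    for &(w, u, v) in edges.iter() {
        if accepted == n.saturating_sub(1) {
            break; // spanning tree complete; remaining edges are never examined
        }
        let ok = uf.union(u, v); // false: same component, edge would close a cycle
        if ok {
            accepted += 1;
        }
        trace.push((w, u, v, ok));
    }
    trace
}

fn main() {
    // A triangle 0-1-2 plus vertex 3, reachable from 2 and 1.
    let mut edges: Vec<(i64, usize, usize)> =
        vec![(1, 0, 1), (2, 1, 2), (3, 0, 2), (4, 2, 3), (5, 1, 3)];
    for (w, u, v, ok) in kruskal_trace(4, &mut edges) {
        println!("w={w} {u}-{v}: {}", if ok { "accept" } else { "skip (cycle)" });
    }
}
```

Here edges 0-1 and 1-2 are accepted. Edge 0-2 is skipped because it would close the triangle, and edge 2-3 completes the tree. The weight-5 edge 1-3 is never examined.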

How It Works in Rust

use std::cmp::Ordering::{Equal, Greater, Less};

struct UnionFind {
    parent: Vec<usize>,
    rank:   Vec<usize>,
}

impl UnionFind {
    fn new(n: usize) -> Self {
        // Every vertex starts as its own singleton component.
        UnionFind { parent: (0..n).collect(), rank: vec![0; n] }
    }

    fn find(&mut self, mut v: usize) -> usize {
        while self.parent[v] != v {
            self.parent[v] = self.parent[self.parent[v]]; // path halving
            v = self.parent[v];
        }
        v
    }

    fn union(&mut self, u: usize, v: usize) -> bool {
        let (pu, pv) = (self.find(u), self.find(v));
        if pu == pv { return false; }    // already same component
        // Union by rank: the lower-rank root hangs under the higher-rank root
        match self.rank[pu].cmp(&self.rank[pv]) {
            Less    => self.parent[pu] = pv,
            Greater => self.parent[pv] = pu,
            Equal   => { self.parent[pv] = pu; self.rank[pu] += 1; }
        }
        true
    }
}

// O(E log E), dominated by the sort; Union-Find ops are O(α(V)) amortised, effectively O(1)
fn kruskal(n: usize, edges: &mut [(i64, usize, usize)]) -> (i64, Vec<(usize, usize, i64)>) {
    edges.sort_unstable_by_key(|&(w, _, _)| w);
    let mut uf = UnionFind::new(n);
    let mut mst = Vec::new();
    let mut total = 0i64;
    for &(w, u, v) in edges.iter() {
        if mst.len() == n.saturating_sub(1) { break; } // spanning tree complete
        if uf.union(u, v) {   // returns true if they were in different components
            total += w;
            mst.push((u, v, w));
        }
    }
    (total, mst)
}
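
Path halving (`parent[v] = parent[parent[v]]`) is a simpler, single-pass alternative to full path compression. Combined with union by rank, it has the same O(α(V)) amortised bound. `sort_unstable_by_key` is typically faster than `sort_by_key` and sorts in place without allocating. Stability doesn't matter here: it only decides which of several equal-weight edges is tried first, and any such order still yields an MST with the same total weight.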
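Path halving and full path compression differ in how they rewrite the parent array. The sketch below uses hypothetical free functions `find_halving` and `find_full` that operate on a bare `parent` slice, and runs both on the same chain 4 → 3 → 2 → 1 → 0.

```rust
/// Path halving (as in the article): a single pass that re-points each
/// node it visits at its grandparent, jumping two levels at a time.
fn find_halving(parent: &mut [usize], mut v: usize) -> usize {
    while parent[v] != v {
        parent[v] = parent[parent[v]];
        v = parent[v];
    }
    v
}

/// Full path compression, written iteratively: one pass to find the root,
/// a second pass to point every node on the path directly at it.
fn find_full(parent: &mut [usize], v: usize) -> usize {
    let mut root = v;
    while parent[root] != root {
        root = parent[root];
    }
    let mut cur = v;
    while cur != root {
        let next = parent[cur];
        parent[cur] = root;
        cur = next;
    }
    root
}

fn main() {
    // Chain: 4 -> 3 -> 2 -> 1 -> 0, with 0 as the root.
    let mut a: Vec<usize> = vec![0, 0, 1, 2, 3];
    let mut b = a.clone();
    let ra = find_halving(&mut a, 4);
    let rb = find_full(&mut b, 4);
    println!("halving: root {ra}, parent {a:?}"); // root 0, parent [0, 0, 0, 2, 2]
    println!("full:    root {rb}, parent {b:?}"); // root 0, parent [0, 0, 0, 0, 0]
}
```

A single halving call leaves nodes 3 and 1 where they were, but it needs no second pass. Later finds keep shortening the paths, and the amortised analysis gives the same bound as full compression.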

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Union-Find | Mutable `int array` for parent | `struct UnionFind { parent: Vec<usize>, rank: Vec<usize> }` |
| Path compression | Recursive `find` with `parent.(v) <- root` | Iterative path halving: no recursion, stack-safe |
| Union by rank | `if rank.(pu) < rank.(pv)` | `match rank[pu].cmp(&rank[pv])`, checked for exhaustiveness |
| Edge sort | `List.sort` by weight | `sort_unstable_by_key`: typically faster, in-place |
| Cycle check | `find u = find v` | `uf.union(u, v)` returns `false` if same component |
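The OCaml `kruskal` in this example threads `(total, mst)` through `List.fold_left`. Rust can express the same shape with `Iterator::fold`, because the closure may capture the Union-Find mutably. `kruskal_fold` is a hypothetical name, not the article's function. Unlike the article's `kruskal`, it takes the edges by shared reference and sorts a copy, much like OCaml's non-destructive `List.sort`.

```rust
struct UnionFind {
    parent: Vec<usize>,
    rank: Vec<usize>,
}

impl UnionFind {
    fn new(n: usize) -> Self {
        UnionFind { parent: (0..n).collect(), rank: vec![0; n] }
    }

    fn find(&mut self, mut v: usize) -> usize {
        while self.parent[v] != v {
            self.parent[v] = self.parent[self.parent[v]]; // path halving
            v = self.parent[v];
        }
        v
    }

    fn union(&mut self, u: usize, v: usize) -> bool {
        let (pu, pv) = (self.find(u), self.find(v));
        if pu == pv {
            return false;
        }
        if self.rank[pu] < self.rank[pv] {
            self.parent[pu] = pv;
        } else if self.rank[pu] > self.rank[pv] {
            self.parent[pv] = pu;
        } else {
            self.parent[pv] = pu;
            self.rank[pu] += 1;
        }
        true
    }
}

/// Hypothetical fold-based Kruskal, shaped like the OCaml `List.fold_left` version.
fn kruskal_fold(n: usize, edges: &[(i64, usize, usize)]) -> (i64, Vec<(usize, usize, i64)>) {
    // Sort a copy, leaving the caller's edges untouched.
    let mut sorted = edges.to_vec();
    sorted.sort_unstable_by_key(|&(w, _, _)| w);
    let mut uf = UnionFind::new(n);
    // Accumulator is (total weight, accepted edges); `uf` is captured mutably.
    sorted.into_iter().fold((0, Vec::new()), |(total, mut mst), (w, u, v)| {
        if uf.union(u, v) {
            mst.push((u, v, w)); // push appends, so no List.rev is needed afterwards
            (total + w, mst)
        } else {
            (total, mst)
        }
    })
}

fn main() {
    let edges: Vec<(i64, usize, usize)> =
        vec![(2, 0, 1), (6, 0, 3), (3, 1, 2), (8, 1, 3), (5, 1, 4), (7, 2, 4), (9, 3, 4)];
    let (total, mst) = kruskal_fold(5, &edges);
    println!("MST total weight: {total}");
    for (u, v, w) in &mst {
        println!("  edge {u}-{v}  weight={w}");
    }
}
```

On the sample graph it reports a total weight of 16, built from edges 0-1, 1-2, 1-4 and 0-3. The imperative `for` loop is arguably clearer in Rust. The fold mainly shows that the functional shape carries over.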

// Kruskal's MST: sort + Union-Find, O(E log E)

struct UnionFind {
    parent: Vec<usize>,
    rank:   Vec<usize>,
}

impl UnionFind {
    fn new(n: usize) -> Self {
        UnionFind { parent: (0..n).collect(), rank: vec![0; n] }
    }

    fn find(&mut self, mut v: usize) -> usize {
        while self.parent[v] != v {
            self.parent[v] = self.parent[self.parent[v]]; // path halving
            v = self.parent[v];
        }
        v
    }

    fn union(&mut self, u: usize, v: usize) -> bool {
        let (pu, pv) = (self.find(u), self.find(v));
        if pu == pv { return false; }
        if self.rank[pu] < self.rank[pv] {
            self.parent[pu] = pv;
        } else if self.rank[pu] > self.rank[pv] {
            self.parent[pv] = pu;
        } else {
            self.parent[pv] = pu;
            self.rank[pu] += 1;
        }
        true
    }
}

fn kruskal(n: usize, edges: &mut Vec<(i64, usize, usize)>) -> (i64, Vec<(usize, usize, i64)>) {
    edges.sort_unstable_by_key(|&(w, _, _)| w);
    let mut uf    = UnionFind::new(n);
    let mut total = 0i64;
    let mut mst   = Vec::new();
    for &(w, u, v) in edges.iter() {
        if uf.union(u, v) {
            total += w;
            mst.push((u, v, w));
        }
    }
    (total, mst)
}

fn main() {
    let mut edges = vec![
        (2i64, 0, 1), (6, 0, 3),
        (3, 1, 2),    (8, 1, 3), (5, 1, 4),
        (7, 2, 4),    (9, 3, 4),
    ];
    let (total, mst) = kruskal(5, &mut edges);
    println!("MST total weight: {total}");
    for (u, v, w) in &mst {
        println!("  edge {u}-{v}  weight={w}");
    }
}

(* Kruskal's MST: sort + Union-Find, O(E log E) *)

(* Union-Find with path compression and union by rank *)
let make_uf n =
  let parent = Array.init n (fun i -> i) in
  let rank   = Array.make n 0 in
  (parent, rank)

let rec find parent v =
  if parent.(v) = v then v
  else begin
    parent.(v) <- find parent parent.(v);  (* path compression *)
    parent.(v)
  end

let union parent rank u v =
  let pu = find parent u and pv = find parent v in
  if pu = pv then false
  else begin
    if rank.(pu) < rank.(pv) then parent.(pu) <- pv
    else if rank.(pu) > rank.(pv) then parent.(pv) <- pu
    else begin parent.(pv) <- pu; rank.(pu) <- rank.(pu) + 1 end;
    true
  end

let kruskal n edges =
  (* edges: (weight, u, v) list *)
  let sorted = List.sort (fun (w1,_,_) (w2,_,_) -> compare w1 w2) edges in
  let (parent, rank) = make_uf n in
  List.fold_left (fun (total, mst) (w, u, v) ->
    if union parent rank u v then
      (total + w, (u, v, w) :: mst)
    else
      (total, mst)
  ) (0, []) sorted

let () =
  let edges = [
    (2, 0, 1); (6, 0, 3);
    (3, 1, 2); (8, 1, 3); (5, 1, 4);
    (7, 2, 4); (9, 3, 4)
  ] in
  let (total, mst) = kruskal 5 edges in
  Printf.printf "MST total weight: %d\n" total;
  List.iter (fun (u, v, w) ->
    Printf.printf "  edge %d-%d  weight=%d\n" u v w
  ) (List.rev mst)