🦀 Functional Rust

365: Union-Find — Disjoint Set Forest

Difficulty: 3
Level: Advanced

Track which elements belong to the same group and merge groups in O(α(n)) amortized time, which is effectively O(1) in practice.

The Problem This Solves

You're given a list of edges in a network and need to answer: "are nodes A and B connected?" After processing all edges, you also need to know how many connected components exist. A BFS or DFS from scratch costs O(V+E) per query and has to be rerun whenever an edge is added.

Union-Find is built exactly for this: process edges one by one, merging sets as you go, and answer "same group?" queries in near-constant time at any point. It's the workhorse behind Kruskal's minimum spanning tree algorithm, which repeatedly adds the cheapest edge that doesn't connect two already-connected nodes.

A second class of problems is cycle detection in undirected graphs: if you try to union two nodes that already share a root, they're already in the same component, so the new edge closes a cycle. This is exactly the check Kruskal's algorithm needs, and it costs O(α(n)) amortized per edge.

The Intuition

Imagine you have a collection of groups, each with a representative ("root"). To check if two people are in the same group, ask them both: "who is your group's root?" If the roots match, they're in the same group. To merge two groups, just point one root at the other.

The clever part is path compression: after finding a root, update every node along the path to point directly at the root. Next time you ask any of those nodes, the answer is one hop away. Union by rank keeps the trees shallow by always attaching the root with the lower rank (an upper bound on its tree's height) under the root with the higher rank. Together, these make the amortized cost per operation O(α(n)), where α is the inverse Ackermann function, which is ≤ 4 for any n you'll encounter in practice.
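To make the effect of path compression concrete, here is a minimal sketch that works on a bare parent array. The free function `find_compress` and the hand-built chain are assumptions made for illustration only. It uses a two-pass iterative variant rather than the recursive form, but the result is the same: every node on the walked path ends up pointing directly at the root.

```rust
/// Hypothetical helper: find the root of `x`, then repoint every node
/// on the walked path directly at that root (path compression).
fn find_compress(parent: &mut [usize], x: usize) -> usize {
    // First pass: follow parent pointers up to the root.
    let mut root = x;
    while parent[root] != root {
        root = parent[root];
    }
    // Second pass: rewrite each pointer on the path to the root.
    let mut node = x;
    while parent[node] != root {
        let next = parent[node];
        parent[node] = root;
        node = next;
    }
    root
}

fn main() {
    // A deliberately degenerate chain: 4 -> 3 -> 2 -> 1 -> 0 (0 is the root).
    let mut parent = vec![0, 0, 1, 2, 3];
    let root = find_compress(&mut parent, 4);
    // prints: root = 0, parent = [0, 0, 0, 0, 0]
    println!("root = {root}, parent = {parent:?}");
}
```

After the single call on node 4, the chain of four hops has collapsed so that every node is one hop from the root, which is why repeated queries get cheaper.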

How It Works in Rust

use std::cmp::Ordering;

struct UnionFind {
    parent: Vec<usize>,
    rank: Vec<u32>,
    count: usize, // number of distinct components
}

impl UnionFind {
    fn new(n: usize) -> Self {
        UnionFind {
            parent: (0..n).collect(), // each node is its own root
            rank: vec![0; n],
            count: n,
        }
    }

    // Find root with path compression: O(α(n)) amortized
    fn find(&mut self, x: usize) -> usize {
        if self.parent[x] != x {
            self.parent[x] = self.find(self.parent[x]); // path compression
        }
        self.parent[x]
    }

    // Union two sets; returns false if already in the same set (cycle detected)
    fn union(&mut self, x: usize, y: usize) -> bool {
        let rx = self.find(x);
        let ry = self.find(y);
        if rx == ry {
            return false; // already connected
        }

        // Union by rank: attach the lower-rank root under the higher-rank root
        match self.rank[rx].cmp(&self.rank[ry]) {
            Ordering::Less => self.parent[rx] = ry,
            Ordering::Greater => self.parent[ry] = rx,
            Ordering::Equal => {
                // Equal ranks: pick rx as the new root and bump its rank
                self.parent[ry] = rx;
                self.rank[rx] += 1;
            }
        }
        self.count -= 1;
        true
    }

    // Takes &mut self because find compresses paths as it goes
    fn connected(&mut self, x: usize, y: usize) -> bool {
        self.find(x) == self.find(y)
    }

    fn components(&self) -> usize {
        self.count
    }
}

// Usage: Kruskal's MST, processing edges in order of increasing weight
fn main() {
    let mut uf = UnionFind::new(5); // nodes 0..=4

    // (weight, u, v); sorting tuples orders them by weight first
    let mut edges = vec![(1, 0, 1), (3, 1, 2), (2, 0, 3), (5, 3, 4)];
    edges.sort();

    let mut mst_weight = 0;
    for (w, u, v) in edges {
        if uf.union(u, v) { // only add the edge if it connects two components
            mst_weight += w;
        }
    }
    println!("MST weight: {mst_weight}"); // 11: every edge here joins two components
    println!("Components: {}", uf.components()); // 1 if fully connected
}
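The cycle check described earlier falls straight out of the boolean that `union` returns. The sketch below is one way to use it, under stated assumptions: the compact `Dsu` type repeats a minimal union-find only so the example compiles on its own, and `first_cycle_edge` is a hypothetical helper name, not something defined elsewhere on this page.

```rust
// Minimal union-find, repeated here so the sketch is self-contained.
struct Dsu {
    parent: Vec<usize>,
    rank: Vec<u32>,
}

impl Dsu {
    fn new(n: usize) -> Self {
        Dsu { parent: (0..n).collect(), rank: vec![0; n] }
    }

    fn find(&mut self, x: usize) -> usize {
        if self.parent[x] != x {
            let root = self.find(self.parent[x]);
            self.parent[x] = root; // path compression
        }
        self.parent[x]
    }

    // Returns false when x and y already share a root.
    fn union(&mut self, x: usize, y: usize) -> bool {
        let (rx, ry) = (self.find(x), self.find(y));
        if rx == ry {
            return false;
        }
        // Attach the lower-rank root under the higher-rank one.
        let (hi, lo) = if self.rank[rx] >= self.rank[ry] { (rx, ry) } else { (ry, rx) };
        if self.rank[hi] == self.rank[lo] {
            self.rank[hi] += 1;
        }
        self.parent[lo] = hi;
        true
    }
}

/// Hypothetical helper: the first edge (in input order) whose endpoints are
/// already connected, i.e. the edge that closes a cycle. None for a forest.
fn first_cycle_edge(n: usize, edges: &[(usize, usize)]) -> Option<(usize, usize)> {
    let mut dsu = Dsu::new(n);
    edges.iter().copied().find(|&(u, v)| !dsu.union(u, v))
}

fn main() {
    let tree = [(0, 1), (1, 2), (2, 3)];
    let cyclic = [(0, 1), (1, 2), (2, 0), (2, 3)];
    println!("{:?}", first_cycle_edge(4, &tree)); // prints: None
    println!("{:?}", first_cycle_edge(4, &cyclic)); // prints: Some((2, 0))
}
```

This only applies to undirected graphs, as in the text above; for directed graphs a shared root does not imply a cycle.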

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Union-Find | not in stdlib | custom `Vec`-based |
| Find with path compression | N/A | O(α(n)) amortized |
| Union by rank | N/A | prevents degenerate O(n) trees |
| Cycle detection | manual DFS | `union()` returns `false` |
| Connected query | BFS/DFS each time | O(α(n)) via `find` |
| Component count | manual tracking | maintained in `count` field |

// Complete example: union-find that also tracks the size of each component
struct UnionFind {
    parent: Vec<usize>,
    rank: Vec<u32>,
    size: Vec<usize>,
    components: usize,
}

impl UnionFind {
    fn new(n: usize) -> Self {
        Self {
            parent: (0..n).collect(),
            rank: vec![0; n],
            size: vec![1; n],
            components: n,
        }
    }

    fn find(&mut self, x: usize) -> usize {
        if self.parent[x] != x {
            self.parent[x] = self.find(self.parent[x]); // path compression
        }
        self.parent[x]
    }

    fn union(&mut self, x: usize, y: usize) -> bool {
        let rx = self.find(x);
        let ry = self.find(y);
        if rx == ry { return false; } // already connected
        // union by rank
        if self.rank[rx] < self.rank[ry] {
            self.parent[rx] = ry;
            self.size[ry] += self.size[rx];
        } else if self.rank[rx] > self.rank[ry] {
            self.parent[ry] = rx;
            self.size[rx] += self.size[ry];
        } else {
            self.parent[ry] = rx;
            self.size[rx] += self.size[ry];
            self.rank[rx] += 1;
        }
        self.components -= 1;
        true
    }

    fn connected(&mut self, x: usize, y: usize) -> bool { self.find(x) == self.find(y) }
    fn component_size(&mut self, x: usize) -> usize { let r = self.find(x); self.size[r] }
}

fn count_connected_components(n: usize, edges: &[(usize,usize)]) -> usize {
    let mut uf = UnionFind::new(n);
    for &(u,v) in edges { uf.union(u, v); }
    uf.components
}

fn main() {
    let mut uf = UnionFind::new(10);
    uf.union(0,1); uf.union(2,3); uf.union(0,2);
    println!("0-3 connected: {}", uf.connected(0,3));
    println!("0-4 connected: {}", uf.connected(0,4));
    println!("Component of 0 size: {}", uf.component_size(0));
    println!("Components: {}", uf.components);

    let edges = [(0,1),(1,2),(3,4)];
    println!("Connected components in 5-node graph: {}", count_connected_components(5, &edges));
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test] fn basic_union_find() {
        let mut uf = UnionFind::new(5);
        assert!(!uf.connected(0,1));
        uf.union(0,1); assert!(uf.connected(0,1));
    }
    #[test] fn transitive() {
        let mut uf = UnionFind::new(4);
        uf.union(0,1); uf.union(1,2);
        assert!(uf.connected(0,2));
        assert!(!uf.connected(0,3));
    }
    #[test] fn components() {
        assert_eq!(count_connected_components(5, &[(0,1),(2,3)]), 3);
    }
}

(* OCaml: Union-Find with arrays *)

let parent = Array.init 10 (fun i -> i)
let rank   = Array.make 10 0

let rec find x =
  if parent.(x) = x then x
  else begin
    parent.(x) <- find parent.(x);  (* path compression *)
    parent.(x)
  end

let union x y =
  let rx = find x and ry = find y in
  if rx = ry then ()
  else if rank.(rx) < rank.(ry) then parent.(rx) <- ry
  else if rank.(rx) > rank.(ry) then parent.(ry) <- rx
  else begin parent.(ry) <- rx; rank.(rx) <- rank.(rx)+1 end

let () =
  union 0 1; union 2 3; union 0 2;
  Printf.printf "Connected(0,3): %b\n" (find 0 = find 3);
  Printf.printf "Connected(0,4): %b\n" (find 0 = find 4)