🦀 Functional Rust

801. Prim's Algorithm: Minimum Spanning Tree

Difficulty: 4 · Level: Advanced · Grow a minimum spanning tree greedily by always adding the cheapest edge crossing the cut — O(E log V) with a binary heap.

The Problem This Solves

A minimum spanning tree connects all nodes in a weighted graph with the minimum total edge cost while forming no cycles. Network infrastructure design uses MSTs constantly: laying cables between cities, designing circuit board connections, clustering data points by minimum connection cost. Prim's algorithm is the go-to MST method for dense graphs (many edges per node), while Kruskal's (example 802) is preferred for sparse graphs. Beyond pure infrastructure, MST structure appears in approximation algorithms: the 2-approximation for the metric Travelling Salesman Problem builds an MST and then doubles its edges. In hierarchical clustering, single-linkage clustering is equivalent to building an MST: its dendrogram merges clusters in the order of the MST's edges.

The Intuition

Maintain the set of nodes already in the tree. At each step, pick the cheapest edge from that set to any not-yet-included node (the cheapest edge crossing the cut), add that node, and update edge costs for its neighbours. The `key[v]` array tracks the minimum edge weight connecting v to the current MST. A min-heap on `(key[v], v)` efficiently finds the cheapest next addition. Rust's `BinaryHeap` is a max-heap by default, so we wrap entries in `Reverse<(weight, node)>` to simulate a min-heap — the standard Rust idiom for Dijkstra and Prim. OCaml might use a priority queue module; Rust uses the standard library heap with `Reverse`.

How It Works in Rust

use std::collections::BinaryHeap;
use std::cmp::Reverse;

// O(E log V) time, O(V + E) space (the lazy heap may hold up to O(E) entries)
// adj: adjacency list, adj[u] = [(neighbour, weight), ...]; list each undirected edge at both ends.
/// Computes a minimum spanning tree with Prim's algorithm (lazy-deletion binary heap).
///
/// The tree is grown from node 0, so only the connected component containing node 0
/// is spanned. Returns `(total_weight, edges)`, where each edge is
/// `(parent, child, weight)` in the order the children joined the tree.
/// An empty graph yields `(0, vec![])`.
///
/// # Panics
///
/// Panics if a neighbour index is `>= adj.len()`, or (in debug builds) if the total
/// weight overflows `i64`.
fn prim(adj: &[Vec<(usize, i64)>]) -> (i64, Vec<(usize, usize, i64)>) {
    let n = adj.len();
    let mut total = 0i64;
    let mut mst = Vec::with_capacity(n.saturating_sub(1));
    if n == 0 {
        return (total, mst); // no node 0 to grow the tree from
    }

    let mut key = vec![i64::MAX; n]; // cheapest known edge weight linking v to the MST
    let mut parent: Vec<Option<usize>> = vec![None; n]; // MST end of that edge; None = not reached
    let mut in_mst = vec![false; n];
    let mut heap = BinaryHeap::new();

    key[0] = 0;
    heap.push(Reverse((0i64, 0usize))); // (weight, node) — Reverse for min-heap

    while let Some(Reverse((w, u))) = heap.pop() {
        if in_mst[u] {
            continue; // stale entry: u already joined via a cheaper edge
        }
        in_mst[u] = true;
        // The root has no parent edge to record.
        if let Some(p) = parent[u] {
            total += w;
            mst.push((p, u, w));
        }
        for &(v, wv) in &adj[u] {
            // `parent[v].is_none()` accepts the first edge seen for v even when its
            // weight is i64::MAX, which `wv < key[v]` alone would reject.
            if !in_mst[v] && (parent[v].is_none() || wv < key[v]) {
                key[v] = wv;
                parent[v] = Some(u);
                heap.push(Reverse((wv, v))); // may leave stale entries for v behind
            }
        }
    }
    (total, mst)
}
The heap may contain stale entries — old `(key[v], v)` pairs pushed before a better edge was found. The `if in_mst[u] { continue }` guard skips them. This lazy-deletion approach is simpler than a decrease-key heap and works well in practice. The trade-off is that the heap can hold up to O(E) entries instead of O(V).

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Min-heap | Priority queue module or sorted list | `BinaryHeap<Reverse<(i64, usize)>>` |
| Stale entry handling | Separate `visited` check | `if in_mst[u] { continue }` |
| Adjacency list | `list array` or `Hashtbl` | `Vec<Vec<(usize, i64)>>` |
| Infinity key | `max_int` | `i64::MAX` |
| MST edges | Accumulate in list | `Vec<(usize, usize, i64)>` grown with `push` |
// Prim's MST — O(E log V) with BinaryHeap<Reverse<>>
use std::collections::BinaryHeap;
use std::cmp::Reverse;

/// Computes a minimum spanning tree with Prim's algorithm (lazy-deletion binary heap).
///
/// `adj[u]` lists `(neighbour, weight)` pairs; for an undirected graph each edge must
/// appear in both endpoints' lists. The tree is grown from node 0, so only the
/// connected component containing node 0 is spanned.
///
/// Returns `(total_weight, edges)`, where each edge is `(parent, child, weight)` in the
/// order the children joined the tree. An empty graph yields `(0, vec![])`.
///
/// Runs in O(E log V) time; the heap may hold up to O(E) entries.
///
/// # Panics
///
/// Panics if a neighbour index is `>= adj.len()`, or (in debug builds) if the total
/// weight overflows `i64`.
fn prim(adj: &[Vec<(usize, i64)>]) -> (i64, Vec<(usize, usize, i64)>) {
    let n = adj.len();
    let mut total = 0i64;
    let mut mst = Vec::with_capacity(n.saturating_sub(1));
    if n == 0 {
        return (total, mst); // no node 0 to grow the tree from
    }

    let mut key = vec![i64::MAX; n]; // cheapest known edge weight linking v to the MST
    let mut parent: Vec<Option<usize>> = vec![None; n]; // MST end of that edge; None = not reached
    let mut in_mst = vec![false; n];
    // `BinaryHeap` is a max-heap; `Reverse` makes it a min-heap on (weight, node).
    let mut heap = BinaryHeap::new();

    key[0] = 0;
    heap.push(Reverse((0i64, 0usize)));

    while let Some(Reverse((w, u))) = heap.pop() {
        if in_mst[u] {
            continue; // stale entry: u already joined via a cheaper edge
        }
        in_mst[u] = true;
        // The root has no parent edge to record.
        if let Some(p) = parent[u] {
            total += w;
            mst.push((p, u, w));
        }
        for &(v, wv) in &adj[u] {
            // `parent[v].is_none()` accepts the first edge seen for v even when its
            // weight is i64::MAX, which `wv < key[v]` alone would reject.
            if !in_mst[v] && (parent[v].is_none() || wv < key[v]) {
                key[v] = wv;
                parent[v] = Some(u);
                heap.push(Reverse((wv, v))); // older entries for v become stale
            }
        }
    }
    (total, mst)
}

/// Builds a small undirected example graph, runs `prim`, and prints the tree.
fn main() {
    // Undirected, weighted edges as (endpoint, endpoint, weight).
    let edges: [(usize, usize, i64); 7] = [
        (0, 1, 2),
        (0, 3, 6),
        (1, 2, 3),
        (1, 3, 8),
        (1, 4, 5),
        (2, 4, 7),
        (3, 4, 9),
    ];

    // Record every edge in both endpoints' adjacency lists.
    let mut adj: Vec<Vec<(usize, i64)>> = vec![Vec::new(); 5];
    for &(a, b, weight) in edges.iter() {
        adj[a].push((b, weight));
        adj[b].push((a, weight));
    }

    let (total, mst) = prim(&adj);
    println!("MST total weight: {total}");
    mst.iter().for_each(|&(u, v, w)| {
        println!("  edge {u}-{v}  weight={w}");
    });
}
(* Prim's MST — O(V² + E) with a linear scan for the minimum key (clean, readable).

   [adj.(u)] lists [(weight, v)] pairs and [n] is the number of vertices.
   Returns [(total, edges)], each edge being [(child, parent, weight)] in the
   order the children joined the tree.  Only the component containing vertex 0
   is spanned; an empty graph ([n = 0]) gives [(0, [])].  [max_int] marks
   "not reached yet", so an edge of weight [max_int] is never used. *)
let prim adj n =
  let key    = Array.make n max_int in
  let parent = Array.make n (-1) in
  let inMST  = Array.make n false in
  (* With n = 0 there is no vertex 0; writing key.(0) would raise Invalid_argument. *)
  if n > 0 then key.(0) <- 0;
  let total = ref 0 in
  let mst   = ref [] in
  for _ = 0 to n - 1 do
    (* Find the reached, unvisited vertex with minimum key (-1 if none is left). *)
    let u = ref (-1) in
    for v = 0 to n - 1 do
      if not inMST.(v) && key.(v) < max_int then
        if !u = -1 || key.(v) < key.(!u) then u := v
    done;
    if !u >= 0 then begin
      inMST.(!u) <- true;
      (* The root (parent = -1) contributes no edge. *)
      if parent.(!u) >= 0 then begin
        total := !total + key.(!u);
        mst   := (!u, parent.(!u), key.(!u)) :: !mst
      end;
      List.iter (fun (w, v) ->
        if not inMST.(v) && w < key.(v) then begin
          key.(v)    <- w;
          parent.(v) <- !u
        end
      ) adj.(!u)
    end
  done;
  (!total, List.rev !mst)

let () =
  (* Undirected example graph; adjacency lists hold (weight, neighbour) pairs. *)
  let n = 5 in
  let adj = Array.make n [] in
  let connect a b weight =
    adj.(a) <- (weight, b) :: adj.(a);
    adj.(b) <- (weight, a) :: adj.(b)
  in
  List.iter (fun (a, b, weight) -> connect a b weight)
    [ (0, 1, 2); (0, 3, 6);
      (1, 2, 3); (1, 3, 8); (1, 4, 5);
      (2, 4, 7);
      (3, 4, 9) ];
  let total, edges = prim adj n in
  Printf.printf "MST total weight: %d\n" total;
  (* prim returns (child, parent, weight); print parent first. *)
  edges |> List.iter (fun (child, parent, weight) ->
    Printf.printf "  edge %d-%d  weight=%d\n" parent child weight)