🦀 Functional Rust

818: Suffix Array Construction O(n log n)

Difficulty: 5 Level: Master

Sort all suffixes of a string to enable O(m log n) pattern search and unlock a family of O(n) string problems via the LCP array.

The Problem This Solves

For a single pattern search, KMP or Boyer-Moore-Horspool (BMH) suffices. But when you need to search the same text for many different patterns (a search engine index, a genome browser, a compressed archive), you want a data structure built once that answers any query fast. A suffix array (SA) is that structure: once it is built, any pattern of length m is found with binary search in O(m log n).

The LCP (Longest Common Prefix) array, built from the SA in O(n) via Kasai's algorithm, turns the SA into an even more powerful tool. With SA + LCP you can count distinct substrings, find the longest repeated substring, solve longest common substring for multiple strings, and compress repetitive data, all in O(n) or O(n log n). These problems are foundational in bioinformatics (genome assembly), data compression (BWT/FM-index), and search engine construction.

The simpler O(n log² n) prefix-doubling approach (implemented here) is almost always fast enough in practice and far easier to understand than the linear-time SA-IS algorithm. Replacing the comparison sort in each round with a radix sort brings prefix doubling down to the O(n log n) of the title. Knowing when "simpler and fast enough" beats "theoretically optimal" is itself a valuable engineering judgment.
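As a rough illustration of two of the SA + LCP applications named above, the sketch below counts distinct substrings and finds the longest repeated substring. The function names `count_distinct_substrings` and `longest_repeated_substring` are made up for this example. It assumes the Kasai convention used later in this example: `lcp[r]` is the length of the common prefix of the suffixes at `sa[r - 1]` and `sa[r]`, and `lcp[0]` is 0.

```rust
/// Number of distinct non-empty substrings of a text of length `n`.
/// A text of length n has n(n+1)/2 (suffix, prefix-length) pairs; every LCP
/// entry counts prefixes already contributed by the previous suffix in
/// sorted order, so those are subtracted.
fn count_distinct_substrings(n: usize, lcp: &[usize]) -> usize {
    n * (n + 1) / 2 - lcp.iter().sum::<usize>()
}

/// Longest substring that occurs at least twice, or `None` if no substring
/// repeats. Assumes `lcp[r]` compares the suffixes at `sa[r - 1]` and `sa[r]`.
fn longest_repeated_substring<'a>(
    s: &'a [u8],
    sa: &[usize],
    lcp: &[usize],
) -> Option<&'a [u8]> {
    // The largest LCP entry marks two adjacent suffixes sharing the longest prefix.
    let (r, &len) = lcp.iter().enumerate().max_by_key(|&(_, &l)| l)?;
    if len == 0 {
        None
    } else {
        Some(&s[sa[r]..sa[r] + len])
    }
}
```

For `b"banana"`, whose suffix array is `[5, 3, 1, 0, 4, 2]` and LCP array is `[0, 1, 3, 0, 0, 2]`, the first function gives 21 - 6 = 15 and the second gives `ana`. Both run in O(n) once the two arrays exist.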

The Intuition

Prefix doubling: in round k, sort suffixes by their first 2^k characters. Each suffix i is represented by the pair `(rank[i], rank[i + 2^(k-1)])`: the rank of its first half and the rank of its second half, both taken from the previous round. A suffix too short to have a second half gets -1 in that slot, so it sorts before longer suffixes with the same first half. After ⌈log₂ n⌉ rounds, the pairs uniquely identify every suffix. Each round is O(n log n) with a comparison sort, giving O(n log² n) total.

Binary search on the sorted array then finds any pattern in O(m log n).

In OCaml, `Array.sort` with a comparison closure implements each round's sort directly. Rust's `sort_unstable_by` plays the same role and sorts in place without allocating.
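To make the pair representation concrete, here is a minimal sketch of prefix doubling that builds the pair explicitly as a sort key. It is a variant written for illustration, not the listing used in this example: the names `pair_key` and `build_sa_pairs` are hypothetical, and it uses `Option<usize>` instead of a -1 sentinel, relying on `None` ordering before every `Some(_)`.

```rust
/// Sort key for suffix `i` in a round with the given gap: the rank of its
/// first half and, if the suffix is long enough, the rank of its second half.
/// `None` compares less than any `Some(_)`, so a suffix that runs off the end
/// sorts before longer suffixes sharing the same first half.
fn pair_key(rank: &[usize], i: usize, gap: usize) -> (usize, Option<usize>) {
    (rank[i], rank.get(i + gap).copied())
}

/// Suffix array by prefix doubling with explicit pair keys: O(n log² n).
fn build_sa_pairs(s: &[u8]) -> Vec<usize> {
    let n = s.len();
    let mut sa: Vec<usize> = (0..n).collect();
    // Round 0: rank every suffix by its first byte.
    let mut rank: Vec<usize> = s.iter().map(|&c| c as usize).collect();
    let mut gap = 1;

    while gap < n {
        // Sort by the first 2 * gap characters, encoded as a pair of old ranks.
        sa.sort_unstable_by_key(|&i| pair_key(&rank, i, gap));

        // Re-rank: neighbours in sorted order share a rank iff their keys match.
        let mut next = vec![0usize; n];
        for w in 1..n {
            let (prev, cur) = (sa[w - 1], sa[w]);
            let differs = pair_key(&rank, prev, gap) != pair_key(&rank, cur, gap);
            next[cur] = next[prev] + usize::from(differs);
        }
        rank = next;
        gap *= 2;
    }
    sa
}
```

On `b"banana"` this returns `[5, 3, 1, 0, 4, 2]`, i.e. the suffixes `a`, `ana`, `anana`, `banana`, `na`, `nana`. Allocating a fresh `next` vector each round keeps the sketch short; reusing a temp buffer avoids that allocation.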

How It Works in Rust

fn build_sa(s: &[u8]) -> Vec<usize> {
    let n = s.len();
    let mut sa: Vec<usize> = (0..n).collect();          // Initial: suffixes in text order
    let mut rank: Vec<i64> = s.iter().map(|&c| c as i64).collect(); // Rank by first char
    let mut tmp  = vec![0i64; n];
    let mut gap  = 1usize;

    while gap < n {
        let g = gap;
        let rank_ref = &rank;
        // Sort by (rank[i], rank[i+gap]): O(n log n) per round
        sa.sort_unstable_by(|&i, &j| {
            let ri = rank_ref[i];
            let rj = rank_ref[j];
            if ri != rj { return ri.cmp(&rj); }
            // A suffix with no second half gets -1, so it sorts first
            let ri2 = if i + g < n { rank_ref[i + g] } else { -1 };
            let rj2 = if j + g < n { rank_ref[j + g] } else { -1 };
            ri2.cmp(&rj2)
        });
        // Rebuild ranks: assign 0 to first, increment when adjacent pair differs
        tmp[sa[0]] = 0;
        for i in 1..n {
            let (pi, ci) = (sa[i - 1], sa[i]);
            let same = rank[pi] == rank[ci]
                && (pi + g < n) == (ci + g < n)
                && (pi + g >= n || rank[pi + g] == rank[ci + g]);
            tmp[ci] = tmp[pi] + if same { 0 } else { 1 };
        }
        rank.copy_from_slice(&tmp);
        gap *= 2;  // Double the comparison length each round
    }
    sa
}

// Binary search on SA: O(m log n) per query
fn sa_search(s: &[u8], sa: &[usize], pattern: &[u8]) -> Vec<usize> {
    let m = pattern.len();
    // partition_point is std binary search; two calls give the [left, right) range of matches
    let left  = sa.partition_point(|&i| &s[i..s.len().min(i + m)] < pattern);
    let right = sa.partition_point(|&i| &s[i..s.len().min(i + m)] <= pattern);
    let mut positions = sa[left..right].to_vec();
    positions.sort_unstable();  // Return in text order
    positions
}

`partition_point` (stable since Rust 1.52) is the idiomatic binary search for range queries: two calls give the half-open `[left, right)` range of matches directly, which is cleaner than a pair of `binary_search` calls.
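When only the number of matches is needed, the same two `partition_point` calls suffice and nothing has to be copied or sorted. The sketch below shows that idea; `clipped`, `sa_range` and `sa_count` are names invented for this example, and the suffix array is passed in already built.

```rust
use std::ops::Range;

/// The first `m` bytes of the suffix starting at `i` (fewer if the text ends sooner).
fn clipped(s: &[u8], i: usize, m: usize) -> &[u8] {
    &s[i..s.len().min(i + m)]
}

/// Half-open range of suffix-array slots whose suffixes start with `pattern`.
/// Works because clipped suffixes appear in non-decreasing order along `sa`.
fn sa_range(s: &[u8], sa: &[usize], pattern: &[u8]) -> Range<usize> {
    let m = pattern.len();
    let left = sa.partition_point(|&i| clipped(s, i, m) < pattern);
    let right = sa.partition_point(|&i| clipped(s, i, m) <= pattern);
    left..right
}

/// Number of occurrences of `pattern` in `s`: O(m log n), no allocation.
fn sa_count(s: &[u8], sa: &[usize], pattern: &[u8]) -> usize {
    sa_range(s, sa, pattern).len()
}
```

With the text `b"banana"` and its suffix array `[5, 3, 1, 0, 4, 2]`, `sa_range` for `b"an"` is `1..3`, so `sa_count` reports 2. An empty pattern matches every suffix, so the count is then the text length.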

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Suffix sort | `Array.sort` with `compare` closure | `sort_unstable_by` (no stability needed) |
| Rank as pair | Tuple `(rank[i], rank[i+gap])` | Compared inline in closure |
| Rank update | `Array.blit` from temp buffer | `copy_from_slice` from temp buffer |
| Binary search | Manual `while` loop over `lo`/`hi` | `slice.partition_point` for range queries |
| LCP construction | Kasai's algorithm | Identical: rank inverse + linear scan |

// Suffix Array — O(n log² n) prefix doubling + O(n) LCP (Kasai)

fn build_sa(s: &[u8]) -> Vec<usize> {
    let n = s.len();
    let mut sa: Vec<usize> = (0..n).collect();
    let mut rank: Vec<i64> = s.iter().map(|&c| c as i64).collect();
    let mut tmp  = vec![0i64; n];
    let mut gap  = 1usize;

    while gap < n {
        let g = gap;
        let rank_ref = &rank;
        sa.sort_unstable_by(|&i, &j| {
            let ri = rank_ref[i];
            let rj = rank_ref[j];
            if ri != rj { return ri.cmp(&rj); }
            let ri2 = if i + g < n { rank_ref[i + g] } else { -1 };
            let rj2 = if j + g < n { rank_ref[j + g] } else { -1 };
            ri2.cmp(&rj2)
        });
        // Update ranks
        tmp[sa[0]] = 0;
        for i in 1..n {
            let (pi, ci) = (sa[i - 1], sa[i]);
            let same = rank[pi] == rank[ci]
                && (pi + g < n) == (ci + g < n)
                && (pi + g >= n || rank[pi + g] == rank[ci + g]);
            tmp[ci] = tmp[pi] + if same { 0 } else { 1 };
        }
        rank.copy_from_slice(&tmp);
        gap *= 2;
    }
    sa
}

fn build_lcp(s: &[u8], sa: &[usize]) -> Vec<usize> {
    let n = s.len();
    let mut rank = vec![0usize; n];
    for (i, &v) in sa.iter().enumerate() { rank[v] = i; }
    let mut lcp = vec![0usize; n];
    let mut k   = 0usize;
    for i in 0..n {
        if rank[i] > 0 {
            let j = sa[rank[i] - 1];
            while i + k < n && j + k < n && s[i + k] == s[j + k] { k += 1; }
            lcp[rank[i]] = k;
            if k > 0 { k -= 1; }
        }
    }
    lcp
}

fn sa_search(s: &[u8], sa: &[usize], pattern: &[u8]) -> Vec<usize> {
    let m = pattern.len();
    let left  = sa.partition_point(|&i| &s[i..s.len().min(i + m)] < pattern);
    let right = sa.partition_point(|&i| &s[i..s.len().min(i + m)] <= pattern);
    let mut positions: Vec<usize> = sa[left..right].to_vec();
    positions.sort_unstable();
    positions
}

fn main() {
    let s   = b"banana";
    let sa  = build_sa(s);
    let lcp = build_lcp(s, &sa);

    println!("String: \"banana\"");
    println!("SA:  {:?}", sa);
    println!("LCP: {:?}", lcp);
    println!("Suffixes in order:");
    for &i in &sa {
        println!("  {i}: {:?}", std::str::from_utf8(&s[i..]).unwrap());
    }
    let positions = sa_search(s, &sa, b"an");
    println!("Search 'an': {:?}", positions);

    // Larger example
    let s2  = b"mississippi";
    let sa2 = build_sa(s2);
    println!("\n\"mississippi\" SA: {:?}", sa2);
    println!("Search 'issi': {:?}", sa_search(s2, &sa2, b"issi"));
}
(* Suffix Array — O(n log² n) prefix doubling + O(n) LCP (Kasai) *)

let build_sa s =
  let n    = String.length s in
  let sa   = Array.init n (fun i -> i) in
  let rank = Array.init n (fun i -> Char.code s.[i]) in
  let tmp  = Array.make n 0 in

  let gap = ref 1 in
  while !gap < n do
    let g = !gap in
    let cmp i j =
      let ri = rank.(i) and rj = rank.(j) in
      if ri <> rj then compare ri rj
      else
        let ri2 = if i + g < n then rank.(i + g) else (-1) in
        let rj2 = if j + g < n then rank.(j + g) else (-1) in
        compare ri2 rj2
    in
    Array.sort cmp sa;
    (* Update ranks *)
    tmp.(sa.(0)) <- 0;
    for i = 1 to n - 1 do
      tmp.(sa.(i)) <- tmp.(sa.(i-1)) + (if cmp sa.(i-1) sa.(i) = 0 then 0 else 1)
    done;
    Array.blit tmp 0 rank 0 n;
    gap := !gap * 2
  done;
  sa

(* Kasai's LCP array — O(n) *)
let build_lcp s sa =
  let n    = String.length s in
  let rank = Array.make n 0 in
  Array.iteri (fun i v -> rank.(v) <- i) sa;
  let lcp  = Array.make n 0 in
  let k    = ref 0 in
  for i = 0 to n - 1 do
    if rank.(i) > 0 then begin
      let j = sa.(rank.(i) - 1) in
      while i + !k < n && j + !k < n && s.[i + !k] = s.[j + !k] do
        incr k
      done;
      lcp.(rank.(i)) <- !k;
      if !k > 0 then decr k
    end
  done;
  lcp

let sa_search s sa pattern =
  let n = Array.length sa in
  let m = String.length pattern in
  (* Binary search: left bound *)
  let lo = ref 0 and hi = ref n in
  while !lo < !hi do
    let mid = (!lo + !hi) / 2 in
    if String.sub s sa.(mid) (min m (String.length s - sa.(mid))) < pattern then
      lo := mid + 1
    else hi := mid
  done;
  let left = !lo in
  hi := n;
  while !lo < !hi do
    let mid = (!lo + !hi) / 2 in
    let suf = String.sub s sa.(mid) (min m (String.length s - sa.(mid))) in
    if suf <= pattern then lo := mid + 1
    else hi := mid
  done;
  (* Collect positions *)
  let positions = Array.sub sa left (!lo - left) in
  Array.sort compare positions;
  Array.to_list positions

let () =
  let s  = "banana" in
  let sa = build_sa s in
  let lcp = build_lcp s sa in
  Printf.printf "String: %S\n" s;
  Printf.printf "SA:  [%s]\n" (String.concat "," (Array.to_list (Array.map string_of_int sa)));
  Printf.printf "LCP: [%s]\n" (String.concat "," (Array.to_list (Array.map string_of_int lcp)));
  Printf.printf "Suffixes in order:\n";
  Array.iter (fun i -> Printf.printf "  %d: %S\n" i (String.sub s i (String.length s - i))) sa;
  let pos = sa_search s sa "an" in
  Printf.printf "Search 'an': [%s]\n"
    (String.concat "," (List.map string_of_int pos))