🦀 Functional Rust

822: Burrows-Wheeler Transform

Difficulty: 4 Level: Advanced

Permute a string into a form that compresses dramatically better: the transformation at the heart of bzip2.

The Problem This Solves

The Burrows-Wheeler Transform (BWT) rearranges a string's characters so that identical characters tend to cluster together, which makes the output much more compressible by run-length encoding or move-to-front coding. The transformation is reversible: you can recover the original string exactly from the BWT output plus a single index.

BWT is the core of bzip2 compression, which works well on repetitive text such as source code, log files, and biological sequences. It also underpins the FM-index, the data structure behind short-read DNA aligners such as BWA and Bowtie. That makes it worth understanding for anyone working on text compression or genomic data processing.

This example implements both the forward transform (string → BWT + index) and the inverse (BWT + index → original string), so you can verify correctness with a round trip and see how the structure works.
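To make the clustering claim concrete, here is a minimal sketch that counts runs before and after the transform. The helper name `run_length_encode` and the input `abababab` are illustrative assumptions, and the transformed string `bbbb$aaaa` was worked out by hand from the sorted rotations of `abababab$` rather than computed by this snippet.

```rust
/// Collapse a byte string into (byte, run length) pairs.
/// Hypothetical helper for illustration; not part of the original example.
fn run_length_encode(data: &[u8]) -> Vec<(u8, usize)> {
    let mut runs: Vec<(u8, usize)> = Vec::new();
    for &b in data {
        // Extend the current run if the byte repeats
        if let Some((prev, len)) = runs.last_mut() {
            if *prev == b {
                *len += 1;
                continue;
            }
        }
        // Otherwise start a new run
        runs.push((b, 1));
    }
    runs
}

fn main() {
    let input = b"abababab";
    // BWT of "abababab" with a `$` sentinel appended (derived by hand)
    let transformed = b"bbbb$aaaa";

    println!("input runs: {}", run_length_encode(input).len()); // 8
    println!("bwt runs:   {}", run_length_encode(transformed).len()); // 3
}
```

On short, non-repetitive inputs the run count can stay flat or even grow, so the benefit shows up mainly on longer, repetitive text.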

The Intuition

Forward BWT: append a sentinel character `$` to the string. The sentinel must not occur in the input and must sort before every input character. Form all n+1 cyclic rotations of the result and sort them lexicographically. The BWT is the last column of this sorted matrix. The index records which row holds the original string (the rotation that starts with the real first character and ends with `$`).

Why does this compress well? Rows that share a prefix sort next to each other, and the last character of each row is the character that precedes that prefix in the text. Similar contexts tend to be preceded by the same character, so the last column contains runs of identical characters, which RLE and move-to-front compress efficiently.

Inverse BWT: from the last column `L`, obtain the first column `F` by sorting. Build `T[i]` = the position in `F` of the character `L[i]`: the start of that character's block in `F` plus the rank of `L[i]` among equal characters in `L`. Then follow the chain: start at `index`, read `L[index]`, move to `T[index]`, and repeat until all characters are recovered. The text comes out back to front.

Cost: sorting takes O(n log n) rotation comparisons, and each comparison inspects up to n characters, so the forward transform here is O(n² log n) in the worst case. In Rust, we avoid materializing all rotations (O(n²) space) by sorting indices with a custom comparator that compares rotations through modular indexing.
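The forward steps can be sketched literally, building every rotation as its own string. This is a deliberately naive version for small inputs, not the approach used in the rest of this example; the function name `bwt_naive` is hypothetical, and the code assumes ASCII input that contains no `$` and no byte smaller than `$`.

```rust
/// Naive forward BWT: materializes every rotation (O(n^2) memory).
/// Assumes ASCII `text` with no `$` and no byte that sorts before `$`.
fn bwt_naive(text: &str) -> (String, usize) {
    let s = format!("{}$", text);
    let n = s.len();

    // Rotation i is s[i..] followed by s[..i]
    let mut rotations: Vec<String> = (0..n)
        .map(|i| format!("{}{}", &s[i..], &s[..i]))
        .collect();
    rotations.sort();

    // The BWT is the last character of each sorted rotation
    let last_column: String = rotations
        .iter()
        .map(|r| r.as_bytes()[n - 1] as char)
        .collect();

    // The original string is the rotation equal to s itself
    let original_row = rotations.iter().position(|r| *r == s).unwrap();
    (last_column, original_row)
}

fn main() {
    let (last_column, row) = bwt_naive("banana");
    println!("{} (row {})", last_column, row); // annb$aa (row 4)
}
```

For `banana` the sorted rotations are `$banana`, `a$banan`, `ana$ban`, `anana$b`, `banana$`, `na$bana`, `nana$ba`; reading their last characters top to bottom gives `annb$aa`, and the original sits in row 4.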

How It Works in Rust

fn bwt_forward(s: &str) -> (Vec<u8>, usize) {
    let mut bytes: Vec<u8> = s.bytes().collect();
    // Sentinel: assumed absent from `s` and smaller than every byte in `s`
    bytes.push(b'$');
    let n = bytes.len();

    // Sort rotation indices without materializing rotations (saves O(n²) space)
    let mut indices: Vec<usize> = (0..n).collect();
    indices.sort_unstable_by(|&a, &b| {
        // Compare rotation a vs rotation b as circular slices
        for k in 0..n {
            let ca = bytes[(a + k) % n];
            let cb = bytes[(b + k) % n];
            if ca != cb { return ca.cmp(&cb); }
        }
        std::cmp::Ordering::Equal
    });

    // Last column = character before each sorted rotation's start
    let last_col: Vec<u8> = indices.iter()
        .map(|&i| bytes[(i + n - 1) % n])
        .collect();

    // The original string is rotation 0 (it ends with the sentinel);
    // find the row where it landed after sorting
    let orig_row = indices.iter().position(|&i| i == 0).unwrap();
    (last_col, orig_row)
}

fn bwt_inverse(last_col: &[u8], orig_row: usize) -> Vec<u8> {
    let n = last_col.len();

    // The first column is the sorted last column. It is never built:
    // the counting arrays below give each byte's position in it.
    let mut count = [0usize; 256];
    for &b in last_col { count[b as usize] += 1; }

    // Prefix sum to find each byte's start position in the first column
    let mut start = [0usize; 256];
    for i in 1..256 { start[i] = start[i-1] + count[i-1]; }

    // t[i] = start of last_col[i]'s block in the first column
    //        + rank of last_col[i] among equal bytes seen so far
    let mut seen = [0usize; 256];
    let mut t = vec![0usize; n];
    for (i, &b) in last_col.iter().enumerate() {
        t[i] = start[b as usize] + seen[b as usize];
        seen[b as usize] += 1;
    }

    // Follow the T chain from orig_row, filling the text back to front
    let mut result = vec![0u8; n];
    let mut row = orig_row;
    for i in (0..n).rev() {
        result[i] = last_col[row];
        row = t[row];
    }
    // The sentinel $ is always the last byte; drop it
    result.pop();
    result
}

`sort_unstable_by` on indices is the key space optimization: rotations are compared via modular indexing `(a + k) % n` instead of allocating `n` separate strings. The `count`/`start`/`seen` arrays replace a full sort during inversion, which then runs in O(n + |alphabet|).
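As a worked instance of the counting-sort step, the sketch below builds the mapping for the last column `annb$aa` (the BWT of `banana`) and walks it from row 4. The function name `lf_mapping` is an assumption made for this illustration; the logic mirrors the `count`/`start`/`seen` arrays described above.

```rust
/// Map each position in the last column to its position in the first column.
/// Hypothetical helper name; same counting-sort idea as in `bwt_inverse`.
fn lf_mapping(last_col: &[u8]) -> Vec<usize> {
    // count[c] = occurrences of byte c
    let mut count = [0usize; 256];
    for &b in last_col {
        count[b as usize] += 1;
    }
    // start[c] = index of the first c in the (implicit) sorted first column
    let mut start = [0usize; 256];
    for c in 1..256 {
        start[c] = start[c - 1] + count[c - 1];
    }
    // Each byte lands at its block start plus its rank among equal bytes
    let mut seen = [0usize; 256];
    last_col
        .iter()
        .map(|&b| {
            let pos = start[b as usize] + seen[b as usize];
            seen[b as usize] += 1;
            pos
        })
        .collect()
}

fn main() {
    let last_col = b"annb$aa";
    let t = lf_mapping(last_col);
    println!("{:?}", t); // [1, 5, 6, 4, 0, 2, 3]

    // Walk the chain from the original row, filling the text back to front
    let mut row = 4;
    let mut text = vec![0u8; last_col.len()];
    for slot in text.iter_mut().rev() {
        *slot = last_col[row];
        row = t[row];
    }
    println!("{}", String::from_utf8_lossy(&text)); // banana$
}
```

The first byte read is the sentinel, which is why the inverse fills the buffer from the end and drops the final byte afterwards.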

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Rotation representation | Array of start indices compared with `mod n` (no rotation strings allocated) | Sorted indices with modular indexing, O(1) space per rotation |
| Sorting comparator | Recursive `compare_rotations` passed to `Array.sort` | Closure with `% n` arithmetic, no allocation |
| Byte vs char | `char` values read with `s.[i]` | `u8` directly; BWT is inherently a byte operation |
| Inversion T-array | `Hashtbl` counters for `rank` and `first_occ` | Mutable `count`/`start`/`seen` arrays, O(alphabet) |
| Sentinel character | `"$"` appended with `^` | `b'$'`, an explicit byte literal |

/// Burrows-Wheeler Transform (BWT)
///
/// Forward: append '$', sort all cyclic rotations, take last column.
/// Inverse: use LF-mapping to recover original string in O(n).
///
/// Assumes ASCII input that contains no '$' and no byte smaller than '$'.
/// Returns (transformed string, index of original rotation in sorted order).
fn bwt(input: &str) -> (String, usize) {
    let mut s = input.to_string();
    s.push('$');
    let bytes = s.as_bytes();
    let n = bytes.len();

    // Sort rotation indices lexicographically
    let mut indices: Vec<usize> = (0..n).collect();
    indices.sort_unstable_by(|&a, &b| {
        for k in 0..n {
            let ca = bytes[(a + k) % n];
            let cb = bytes[(b + k) % n];
            match ca.cmp(&cb) {
                std::cmp::Ordering::Equal => continue,
                other => return other,
            }
        }
        std::cmp::Ordering::Equal
    });

    // Last column: character immediately before each sorted rotation
    let transformed: String = indices
        .iter()
        .map(|&i| bytes[(i + n - 1) % n] as char)
        .collect();

    // Row where the original string (rotation starting at 0) appears
    let original_row = indices.iter().position(|&i| i == 0).unwrap();

    (transformed, original_row)
}

/// Inverse BWT using the LF (last-to-first) mapping.
fn ibwt(bwt_str: &str, original_row: usize) -> String {
    let l: Vec<u8> = bwt_str.bytes().collect();
    let n = l.len();

    // First column F = sorted L
    let mut f = l.clone();
    f.sort_unstable();

    // rank[i] = how many times l[i] occurred before position i in L
    let mut rank = vec![0usize; n];
    let mut count = [0usize; 256];
    for (i, &c) in l.iter().enumerate() {
        rank[i] = count[c as usize];
        count[c as usize] += 1;
    }

    // first_occ[c] = first position of c in F
    let mut first_occ = [0usize; 256];
    let mut seen = [false; 256];
    for (i, &c) in f.iter().enumerate() {
        if !seen[c as usize] {
            first_occ[c as usize] = i;
            seen[c as usize] = true;
        }
    }

    // Follow the LF-mapping n times to recover the original + '$', back to front.
    // The first byte read is '$', so n - 1 steps would miss the first character.
    let mut result = Vec::with_capacity(n);
    let mut row = original_row;
    for _ in 0..n {
        let c = l[row];
        result.push(c);
        row = first_occ[c as usize] + rank[row];
    }

    // Bytes were collected last-to-first, so reverse them
    result.reverse();
    // The sentinel '$' is now the final byte; drop it
    result.pop();
    String::from_utf8(result).unwrap()
}

fn main() {
    let words = ["banana", "abracadabra", "mississippi", "hello", "rust"];
    for word in &words {
        let (transformed, row) = bwt(word);
        let recovered = ibwt(&transformed, row);
        println!(
            "bwt({:?}) = {:?} (row={}), recovered = {:?}, ok={}",
            word, transformed, row, recovered, &recovered == word
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn roundtrip(s: &str) -> bool {
        let (t, row) = bwt(s);
        ibwt(&t, row) == s
    }

    #[test]
    fn test_banana() {
        let (t, _) = bwt("banana");
        // Classic result: "annb$aa", the last column of the sorted rotations
        assert_eq!(t, "annb$aa");
    }

    #[test]
    fn test_roundtrip_banana() {
        assert!(roundtrip("banana"));
    }

    #[test]
    fn test_roundtrip_abracadabra() {
        assert!(roundtrip("abracadabra"));
    }

    #[test]
    fn test_roundtrip_mississippi() {
        assert!(roundtrip("mississippi"));
    }

    #[test]
    fn test_roundtrip_single() {
        assert!(roundtrip("a"));
    }

    #[test]
    fn test_roundtrip_hello() {
        assert!(roundtrip("hello"));
    }

    #[test]
    fn test_clustering() {
        // BWT of "mississippi" should cluster repeated characters
        let (t, _) = bwt("mississippi");
        // The transform is a permutation, so all four 'i's must survive
        let i_count = t.chars().filter(|&c| c == 'i').count();
        assert_eq!(i_count, 4);
        // Known output: note the runs "ss", "ss", and "ii"
        assert_eq!(t, "ipssm$pissii");
    }
}
(* Burrows-Wheeler Transform in OCaml *)

(* Compare two rotations of s: rotation starting at i vs rotation starting at j *)
let compare_rotations (s : string) (n : int) (i : int) (j : int) : int =
  let rec cmp k =
    if k = n then 0
    else
      let ci = s.[(i + k) mod n] and cj = s.[(j + k) mod n] in
      if ci < cj then -1
      else if ci > cj then 1
      else cmp (k + 1)
  in
  cmp 0

(* Forward BWT: returns (transformed_string, index_of_original_row) *)
let bwt (input : string) : string * int =
  let s = input ^ "$" in
  let n = String.length s in
  (* Sort rotation indices *)
  let indices = Array.init n (fun i -> i) in
  Array.sort (compare_rotations s n) indices;
  (* Last column = character before the start of each sorted rotation *)
  let transformed = String.init n (fun i -> s.[(indices.(i) + n - 1) mod n]) in
  (* Find the row corresponding to the original string *)
  let original_row =
    let found = ref 0 in
    Array.iteri (fun row i -> if i = 0 then found := row) indices;
    !found
  in
  (transformed, original_row)

(* Inverse BWT using the LF-mapping *)
let ibwt (bwt_str : string) (original_row : int) : string =
  let n = String.length bwt_str in
  let l = Array.init n (String.get bwt_str) in
  (* First column F = sorted last column L *)
  let f = Array.copy l in
  Array.sort Char.compare f;
  (* Count occurrences of each char in F (prefix counts) *)
  (* LF-mapping: next.(i) = j where l.(i) maps to f.(j) *)
  (* Build rank array: rank.(i) = how many times l.(i) appeared before i *)
  let rank = Array.make n 0 in
  let seen = Hashtbl.create 26 in
  Array.iteri (fun i c ->
    let cnt = match Hashtbl.find_opt seen c with None -> 0 | Some v -> v in
    rank.(i) <- cnt;
    Hashtbl.replace seen c (cnt + 1)
  ) l;
  (* For each char c, first_occ.(c) = first position of c in f *)
  let first_occ = Hashtbl.create 26 in
  Array.iteri (fun i c ->
    if not (Hashtbl.mem first_occ c) then
      Hashtbl.add first_occ c i
  ) f;
  (* Recover the original string plus '$' by following the LF-mapping n times;
     the first character read is '$', so n - 1 steps would miss the first char *)
  let result = Buffer.create n in
  let row = ref original_row in
  for _ = 0 to n - 1 do
    let c = l.(!row) in
    Buffer.add_char result c;
    row := (Hashtbl.find first_occ c) + rank.(!row)
  done;
  (* The recovered string is reversed and includes '$', strip it *)
  let s = Buffer.contents result in
  let reversed = String.init (String.length s) (fun i -> s.[String.length s - 1 - i]) in
  (* Remove trailing '$' *)
  String.sub reversed 0 (String.length reversed - 1)

let () =
  let tests = ["banana"; "abracadabra"; "mississippi"; "hello"] in
  List.iter (fun s ->
    let (t, row) = bwt s in
    let recovered = ibwt t row in
    Printf.printf "BWT(%S) = %S (row=%d), inverse = %S, ok=%b\n"
      s t row recovered (recovered = s)
  ) tests