🦀 Functional Rust

838: Interval Tree for Stabbing Queries

Difficulty: 4 Level: Advanced

Answer "which intervals contain point x?" in O(log n + k) per query after O(n log n) preprocessing, using a centered (median-split) interval tree.

The Problem This Solves

Given a set of n intervals [lo, hi] and a query point x, find all intervals that contain x (i.e., lo ≤ x ≤ hi). This is the "stabbing query": which intervals are "stabbed" by a vertical line drawn at x?

Stabbing queries appear in genomics (find all genes overlapping position p on a chromosome), database range queries (find all events active at time t), rendering pipelines (find all bounding boxes containing a pixel), and scheduling systems (find all tasks running at a given moment).

A sorted-list approach still costs O(n) per query in the worst case; an interval tree answers in O(log n + k), where k is the number of reported intervals.

This example uses the median-based interval tree: each node stores a median point together with the intervals that span it, kept in two sorted orders (by lo and by hi). Intervals that don't span the node's median are passed down to the left or right subtree.
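For comparison, the O(n)-per-query baseline that the tree improves on can be sketched in a few lines. The function name `brute_stab` and the `(lo, hi, id)` tuple layout are assumptions chosen for illustration, not part of the tree implementation.

```rust
/// Linear-scan baseline: test every interval against x, O(n) per query.
/// Intervals are (lo, hi, id) tuples with closed bounds (illustrative layout).
fn brute_stab(x: i64, intervals: &[(i64, i64, usize)]) -> Vec<usize> {
    intervals
        .iter()
        // Keep an interval only if lo <= x <= hi.
        .filter(|&&(lo, hi, _)| lo <= x && x <= hi)
        .map(|&(_, _, id)| id)
        .collect()
}
```

Calling `brute_stab(4, &[(1, 5, 1), (2, 8, 2), (6, 10, 3)])` returns `[1, 2]`. A scan like this is also a convenient oracle for testing a tree-based implementation.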

The Intuition

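At each tree node, pick the median point m of all interval endpoints. Split intervals into three groups:

- Left: intervals entirely below m (hi < m), passed to the left subtree.
- Right: intervals entirely above m (lo > m), passed to the right subtree.
- Spanning: intervals that contain m (lo ≤ m ≤ hi), stored at this node.

At each node, store spanning intervals twice:

- by_lo: sorted ascending by lo.
- by_hi: sorted descending by hi.

Query: at each node, if x < m, every spanning interval already has hi ≥ m > x, so scan the lo-sorted list and stop at the first lo > x. If x > m, every spanning interval already has lo ≤ m < x, so scan the hi-sorted list and stop at the first hi < x. If x = m, report all spanning intervals. Then recurse into the subtree on the same side as x.

O(n log n) construction when the endpoints are sorted once up front (re-sorting them at every node, as the simple code here does, can add up to one extra log factor). O(log n + k) query: the tree depth is O(log n), and each node visit does work proportional to the intervals it reports plus one comparison to stop.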
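The three-way split around the median can be sketched on its own, separately from the tree. The names `Side`, `classify`, and `endpoint_median` are hypothetical helpers introduced for this illustration; `endpoint_median` assumes a non-empty input and takes the upper median when the endpoint count is even.

```rust
/// Where a closed interval [lo, hi] falls relative to a split point m.
#[derive(Debug, PartialEq)]
enum Side {
    Left,     // hi < m: belongs in the left subtree
    Right,    // lo > m: belongs in the right subtree
    Spanning, // lo <= m <= hi: stored at the node itself
}

fn classify(lo: i64, hi: i64, m: i64) -> Side {
    if hi < m {
        Side::Left
    } else if lo > m {
        Side::Right
    } else {
        Side::Spanning
    }
}

/// Median of all endpoints. Panics on an empty slice (assumption: callers
/// check for emptiness first, as a tree builder would).
fn endpoint_median(intervals: &[(i64, i64)]) -> i64 {
    let mut endpoints: Vec<i64> = intervals
        .iter()
        .flat_map(|&(lo, hi)| [lo, hi])
        .collect();
    endpoints.sort_unstable();
    endpoints[endpoints.len() / 2]
}
```

For the intervals [1, 5], [2, 8], [6, 10], [3, 7], [9, 12], the sorted endpoints are 1, 2, 3, 5, 6, 7, 8, 9, 10, 12, so the median is 7. Then [1, 5] goes left, [9, 12] goes right, and the other three span the median and stay at the root.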

How It Works in Rust

#[derive(Debug, Clone)]
struct Interval { lo: i64, hi: i64, id: usize }

enum IntervalTree {
 Leaf,
 Node {
     median: i64,
     by_lo: Vec<Interval>, // spanning intervals sorted by lo (ascending)
     by_hi: Vec<Interval>, // spanning intervals sorted by hi (descending)
     left:  Box<IntervalTree>,
     right: Box<IntervalTree>,
 }
}

impl IntervalTree {
 fn build(mut intervals: Vec<Interval>) -> Self {
     if intervals.is_empty() { return IntervalTree::Leaf; }

     // Collect all endpoints, find median
     let mut endpoints: Vec<i64> = intervals.iter()
         .flat_map(|iv| [iv.lo, iv.hi])
         .collect();
     endpoints.sort_unstable();
     let median = endpoints[endpoints.len() / 2];

     // Partition intervals
     let mut left_ivs  = vec![];
     let mut right_ivs = vec![];
     let mut spanning  = vec![];

     for iv in intervals.drain(..) {
         if iv.hi < median      { left_ivs.push(iv); }
         else if iv.lo > median { right_ivs.push(iv); }
         else                   { spanning.push(iv); }
     }

     let mut by_lo = spanning.clone();
     let mut by_hi = spanning;
     by_lo.sort_unstable_by_key(|iv| iv.lo);
     by_hi.sort_unstable_by_key(|iv| std::cmp::Reverse(iv.hi));

     IntervalTree::Node {
         median,
         by_lo,
         by_hi,
         left:  Box::new(IntervalTree::build(left_ivs)),
         right: Box::new(IntervalTree::build(right_ivs)),
     }
 }

 fn stab(&self, x: i64, result: &mut Vec<usize>) {
     match self {
         IntervalTree::Leaf => {}
         IntervalTree::Node { median, by_lo, by_hi, left, right } => {
             if x < *median {
                 // Every spanning interval has hi >= median > x, so only lo
                 // matters: scan ascending by lo, stop at the first lo > x.
                 for iv in by_lo {
                     if iv.lo > x { break; }
                     result.push(iv.id);
                 }
                 left.stab(x, result);
             } else if x > *median {
                 // Every spanning interval has lo <= median < x, so only hi
                 // matters: scan descending by hi, stop at the first hi < x.
                 for iv in by_hi {
                     if iv.hi < x { break; }
                     result.push(iv.id);
                 }
                 right.stab(x, result);
             } else {
                 // x == median: all spanning intervals match. Neither subtree
                 // can match (left has hi < median, right has lo > median).
                 for iv in by_lo { result.push(iv.id); }
             }
         }
     }
 }
}
`std::cmp::Reverse` wraps a value to reverse its sort order: `sort_unstable_by_key(|iv| Reverse(iv.hi))` sorts descending by hi without writing a custom comparator. The `drain(..)` idiom moves every element out of `intervals` and leaves the vector empty, so the partition step never clones an interval.
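Both idioms can be tried in isolation. The sketch below uses two hypothetical helpers, `sort_desc` and `partition_by_len`, that are not part of the interval tree; they only demonstrate `Reverse` as a sort key and `drain(..)` as a way to move owned values into separate buckets.

```rust
use std::cmp::Reverse;

/// Sort descending by wrapping the key in `Reverse`; no custom comparator needed.
fn sort_desc(mut values: Vec<i64>) -> Vec<i64> {
    values.sort_unstable_by_key(|&v| Reverse(v));
    values
}

/// Move each word out of `words` into one of two buckets.
/// `drain(..)` yields owned `String`s, so nothing is cloned,
/// and `words` is left empty (its allocation is kept).
fn partition_by_len(words: &mut Vec<String>, limit: usize) -> (Vec<String>, Vec<String>) {
    let mut short = Vec::new();
    let mut long = Vec::new();
    for w in words.drain(..) {
        if w.len() <= limit {
            short.push(w);
        } else {
            long.push(w);
        }
    }
    (short, long)
}
```

`sort_desc(vec![3, 1, 2])` returns `[3, 2, 1]`. When the source vector is owned and not needed afterwards, iterating it directly with `for w in words` moves the elements in the same way; `drain(..)` is the choice when only a mutable borrow is available or the vector will be reused.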

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Recursive ADT | `type tree = Leaf \| Node of ...` | `enum IntervalTree { Leaf, Node { ... } }` |
| Boxed subtree | `tree` (GC-managed heap) | `Box<IntervalTree>`: explicit heap allocation |
| Reverse sort | `List.sort (fun a b -> compare b.hi a.hi)` | `sort_unstable_by_key(\|iv\| Reverse(iv.hi))` |
| Drain / consume | `List.iter` (no ownership) | `Vec::drain(..)`: moves elements, leaves vec empty |
| Result accumulation | Returns list via recursion | Mutable `&mut Vec<usize>`: avoids allocation on each call |

//! Interval Tree for Stabbing Queries.
//!
//! Each node stores its median and intervals spanning it in two sorted orders.
//! Query answers "which intervals contain x?" in O(log n + k).

#[derive(Clone, Debug)]
struct Interval {
    lo: f64,
    hi: f64,
    id: usize,
}

struct IntervalNode {
    median: f64,
    by_lo: Vec<Interval>,  // sorted ascending by lo
    by_hi: Vec<Interval>,  // sorted descending by hi
    left: Option<Box<IntervalNode>>,
    right: Option<Box<IntervalNode>>,
}

/// Build an interval tree from a list of intervals.
fn build(ivs: Vec<Interval>) -> Option<Box<IntervalNode>> {
    if ivs.is_empty() { return None; }

    // Compute median of all endpoints
    let mut endpoints: Vec<f64> = ivs.iter().flat_map(|iv| [iv.lo, iv.hi]).collect();
    endpoints.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let median = endpoints[endpoints.len() / 2];

    let mut spanning = Vec::new();
    let mut left_ivs = Vec::new();
    let mut right_ivs = Vec::new();

    for iv in ivs {
        if iv.hi < median {
            left_ivs.push(iv);
        } else if iv.lo > median {
            right_ivs.push(iv);
        } else {
            spanning.push(iv);
        }
    }

    let mut by_lo = spanning.clone();
    let mut by_hi = spanning;
    by_lo.sort_by(|a, b| a.lo.partial_cmp(&b.lo).unwrap());
    by_hi.sort_by(|a, b| b.hi.partial_cmp(&a.hi).unwrap());

    Some(Box::new(IntervalNode {
        median,
        by_lo,
        by_hi,
        left: build(left_ivs),
        right: build(right_ivs),
    }))
}

/// Stabbing query: return all intervals containing x.
fn query(x: f64, node: &Option<Box<IntervalNode>>, results: &mut Vec<usize>) {
    let Some(n) = node else { return; };

    if x <= n.median {
        // Scan by_lo until lo > x (early termination)
        for iv in &n.by_lo {
            if iv.lo > x { break; }
            results.push(iv.id);
        }
        query(x, &n.left, results);
    } else {
        // Scan by_hi (desc) until hi < x
        for iv in &n.by_hi {
            if iv.hi < x { break; }
            results.push(iv.id);
        }
        query(x, &n.right, results);
    }
}

struct IntervalTree {
    root: Option<Box<IntervalNode>>,
}

impl IntervalTree {
    fn new(intervals: Vec<Interval>) -> Self {
        Self { root: build(intervals) }
    }

    fn stab(&self, x: f64) -> Vec<usize> {
        let mut results = Vec::new();
        query(x, &self.root, &mut results);
        results.sort();
        results
    }
}

fn main() {
    let ivs = vec![
        Interval { lo: 1.0, hi: 5.0,  id: 1 },
        Interval { lo: 2.0, hi: 8.0,  id: 2 },
        Interval { lo: 6.0, hi: 10.0, id: 3 },
        Interval { lo: 3.0, hi: 7.0,  id: 4 },
        Interval { lo: 9.0, hi: 12.0, id: 5 },
    ];
    let tree = IntervalTree::new(ivs);

    for x in [0.0f64, 3.0, 6.5, 9.5, 15.0] {
        println!("stab({x}): {:?}", tree.stab(x));
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_tree() -> IntervalTree {
        IntervalTree::new(vec![
            Interval { lo: 1.0, hi: 5.0,  id: 1 },
            Interval { lo: 2.0, hi: 8.0,  id: 2 },
            Interval { lo: 6.0, hi: 10.0, id: 3 },
            Interval { lo: 3.0, hi: 7.0,  id: 4 },
            Interval { lo: 9.0, hi: 12.0, id: 5 },
        ])
    }

    // Brute-force stab for verification
    fn brute_stab(x: f64, ivs: &[(f64, f64, usize)]) -> Vec<usize> {
        let mut r: Vec<usize> = ivs.iter()
            .filter(|&&(lo, hi, _)| lo <= x && x <= hi)
            .map(|&(_, _, id)| id)
            .collect();
        r.sort();
        r
    }

    #[test]
    fn test_stab_miss() {
        assert_eq!(make_tree().stab(0.0), vec![]);
        assert_eq!(make_tree().stab(15.0), vec![]);
    }

    #[test]
    fn test_stab_hit() {
        let t = make_tree();
        assert_eq!(t.stab(3.0), vec![1, 2, 4]);
        assert_eq!(t.stab(6.5), vec![2, 3, 4]);
        assert_eq!(t.stab(9.5), vec![3, 5]);
    }

    #[test]
    fn test_boundary() {
        let t = make_tree();
        assert!(t.stab(1.0).contains(&1));
        assert!(t.stab(5.0).contains(&1));
        assert!(!t.stab(5.1).contains(&1));
    }

    #[test]
    fn test_matches_brute() {
        let raw = vec![(1.0f64, 5.0, 1), (2.0, 8.0, 2), (6.0, 10.0, 3), (3.0, 7.0, 4), (9.0, 12.0, 5)];
        let t = make_tree();
        for x_int in 0..=15 {
            let x = x_int as f64;
            let mut tree_res = t.stab(x);
            let mut brute = brute_stab(x, &raw);
            tree_res.sort();
            brute.sort();
            assert_eq!(tree_res, brute, "mismatch at x={x}");
        }
    }
}
(* Interval Tree for Stabbing Queries in OCaml *)

type interval = { lo: float; hi: float; id: int }

(* Each node: median, intervals spanning median (sorted by lo and hi),
   left subtree (all intervals < median), right subtree (all > median) *)
type tree =
  | Empty
  | Node of {
      median  : float;
      by_lo   : interval list;  (* sorted ascending by lo *)
      by_hi   : interval list;  (* sorted descending by hi *)
      left    : tree;
      right   : tree;
    }

let build (intervals : interval list) : tree =
  let rec build_rec ivs =
    match ivs with
    | [] -> Empty
    | _ ->
      (* Median of endpoints *)
      let endpoints = List.concat_map (fun iv -> [iv.lo; iv.hi]) ivs in
      let sorted_ep = List.sort compare endpoints in
      let n = List.length sorted_ep in
      let median = List.nth sorted_ep (n / 2) in
      (* Partition: spanning, left, right *)
      let spanning = List.filter (fun iv -> iv.lo <= median && iv.hi >= median) ivs in
      let left_ivs = List.filter (fun iv -> iv.hi < median) ivs in
      let right_ivs = List.filter (fun iv -> iv.lo > median) ivs in
      let by_lo = List.sort (fun a b -> compare a.lo b.lo) spanning in
      let by_hi = List.sort (fun a b -> compare b.hi a.hi) spanning in
      Node {
        median;
        by_lo; by_hi;
        left = build_rec left_ivs;
        right = build_rec right_ivs;
      }
  in
  build_rec intervals

(* Stab query: all intervals containing point x *)
let query (x : float) (tree : tree) : interval list =
  let rec query_rec t acc =
    match t with
    | Empty -> acc
    | Node { median; by_lo; by_hi; left; right } ->
      let acc = if x <= median then begin
        (* Scan by_lo until lo > x *)
        let rec scan = function
          | [] -> acc
          | iv :: rest -> if iv.lo > x then acc else iv :: scan rest
        in
        scan by_lo
      end else begin
        (* Scan by_hi (desc) until hi < x *)
        let rec scan = function
          | [] -> acc
          | iv :: rest -> if iv.hi < x then acc else iv :: scan rest
        in
        scan by_hi
      end in
      let acc = if x < median then query_rec left acc else query_rec right acc in
      acc
  in
  query_rec tree []

let () =
  let ivs = [
    {lo=1.0; hi=5.0; id=1};
    {lo=2.0; hi=8.0; id=2};
    {lo=6.0; hi=10.0; id=3};
    {lo=3.0; hi=7.0; id=4};
    {lo=9.0; hi=12.0; id=5};
  ] in
  let tree = build ivs in
  let queries = [0.0; 3.0; 6.5; 9.5; 15.0] in
  List.iter (fun x ->
    let results = query x tree in
    let ids = List.map (fun iv -> iv.id) results in
    Printf.printf "stab(%.1f): ids=[%s]\n" x
      (String.concat "," (List.map string_of_int (List.sort compare ids)))
  ) queries