838: Interval Tree for Stabbing Queries
Difficulty: 4 Level: Advanced
Answer "which intervals contain point x?" in O(log n + k) per query after O(n log n) preprocessing, using the centered (median-split) interval tree.
The Problem This Solves
Given a set of n intervals [lo, hi] and a query point x, find all intervals that contain x (i.e., lo ≤ x ≤ hi). This is the "stabbing query": which intervals are "stabbed" by the vertical line at x? Stabbing queries appear in many settings:
- genomics: find all genes overlapping position p on a chromosome;
- database range queries: find all events active at time t;
- rendering pipelines: find all bounding boxes containing a pixel;
- scheduling systems: find all tasks running at a given moment.

A linear scan costs O(n) per query. Sorting the list by lo does not fix the worst case, because every interval that starts at or before x may still need checking. An interval tree answers in O(log n + k), where k is the number of reported intervals.

This example uses the median-based (centered) interval tree. Each node stores its median point plus the intervals that span that median, kept in two sorted orders (by lo and by hi). Intervals that do not span the node's median are passed down to the left or right subtree.

The O(n log n) preprocessing bound requires one of two things: sorting the endpoints once up front, or selecting each node's median in linear time (for example with `select_nth_unstable`). The listings below re-sort the endpoints at every node, which costs O(n log² n) overall.
The Intuition
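At each tree node, pick the median point m of all interval endpoints. Split the intervals into three groups:
- Spans the median (lo ≤ m ≤ hi): store at this node
- Entirely left of m (hi < m): recurse into the left subtree
- Entirely right of m (lo > m): recurse into the right subtree
The spanning intervals are stored twice, so a query can stop scanning at the first non-match:
- Sorted by lo ascending (for queries where x ≤ m): every spanning interval already has hi ≥ m ≥ x, so report those with lo ≤ x.
- Sorted by hi descending (for queries where x ≥ m): every spanning interval already has lo ≤ m ≤ x, so report those with hi ≥ x.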
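The per-node work described above can be sketched in isolation. This is a minimal sketch that assumes intervals are plain `(lo, hi, id)` tuples with integer endpoints. The alias `Iv` and the helpers `median_of_endpoints`, `partition`, and `scan_spanning` are hypothetical names for illustration only. The sketch reuses the five intervals from the listings, where the median endpoint is 7.

```rust
// Sketch of the work done at ONE node, assuming intervals are plain
// (lo, hi, id) tuples with integer endpoints (an illustrative choice).
type Iv = (i64, i64, usize);

/// Median of all 2n endpoints, picked as in the listings: sort, take index len / 2.
fn median_of_endpoints(ivs: &[Iv]) -> i64 {
    let mut eps: Vec<i64> = ivs.iter().flat_map(|&(lo, hi, _)| [lo, hi]).collect();
    eps.sort_unstable();
    eps[eps.len() / 2]
}

/// Three-way split around m: (entirely left, spanning, entirely right).
fn partition(ivs: &[Iv], m: i64) -> (Vec<Iv>, Vec<Iv>, Vec<Iv>) {
    let (mut left, mut span, mut right) = (Vec::new(), Vec::new(), Vec::new());
    for &(lo, hi, id) in ivs {
        if hi < m {
            left.push((lo, hi, id));
        } else if lo > m {
            right.push((lo, hi, id));
        } else {
            span.push((lo, hi, id));
        }
    }
    (left, span, right)
}

/// Report the spanning intervals that contain x. Each loop stops at the first
/// non-match, so the cost is the number of hits plus at most one extra check.
fn scan_spanning(x: i64, m: i64, by_lo: &[Iv], by_hi: &[Iv], out: &mut Vec<usize>) {
    if x <= m {
        // Every spanning interval has hi >= m >= x; only lo can rule one out.
        for &(lo, _, id) in by_lo {
            if lo > x {
                break;
            }
            out.push(id);
        }
    } else {
        // Every spanning interval has lo <= m < x; only hi can rule one out.
        for &(_, hi, id) in by_hi {
            if hi < x {
                break;
            }
            out.push(id);
        }
    }
}

fn main() {
    // The five intervals used in the document's listings.
    let ivs: Vec<Iv> = vec![(1, 5, 1), (2, 8, 2), (6, 10, 3), (3, 7, 4), (9, 12, 5)];
    let m = median_of_endpoints(&ivs);
    let (left, span, right) = partition(&ivs, m);

    let mut by_lo = span.clone();
    by_lo.sort_unstable_by_key(|&(lo, _, _)| lo);
    let mut by_hi = span;
    by_hi.sort_unstable_by_key(|&(_, hi, _)| std::cmp::Reverse(hi));

    for x in [4, 9] {
        let mut hits = Vec::new();
        scan_spanning(x, m, &by_lo, &by_hi, &mut hits);
        println!("m = {m}, x = {x}: spanning hits {hits:?}, left {left:?}, right {right:?}");
    }
}
```

The node's scan is only part of the answer. At x = 4 it reports intervals 2 and 4 from the spanning set, while interval 1 ([1, 5]) is found by recursing into the left subtree.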
How It Works in Rust
#[derive(Debug, Clone)]
struct Interval { lo: i64, hi: i64, id: usize }
enum IntervalTree {
Leaf,
Node {
median: i64,
by_lo: Vec<Interval>, // spanning intervals sorted by lo (ascending)
by_hi: Vec<Interval>, // spanning intervals sorted by hi (descending)
left: Box<IntervalTree>,
right: Box<IntervalTree>,
}
}
impl IntervalTree {
fn build(mut intervals: Vec<Interval>) -> Self {
if intervals.is_empty() { return IntervalTree::Leaf; }
// Collect all endpoints, find median
let mut endpoints: Vec<i64> = intervals.iter()
.flat_map(|iv| [iv.lo, iv.hi])
.collect();
endpoints.sort_unstable();
let median = endpoints[endpoints.len() / 2];
// Partition intervals
let mut left_ivs = vec![];
let mut right_ivs = vec![];
let mut spanning = vec![];
for iv in intervals.drain(..) {
if iv.hi < median { left_ivs.push(iv); }
else if iv.lo > median { right_ivs.push(iv); }
else { spanning.push(iv); }
}
let mut by_lo = spanning.clone();
let mut by_hi = spanning;
by_lo.sort_unstable_by_key(|iv| iv.lo);
by_hi.sort_unstable_by_key(|iv| std::cmp::Reverse(iv.hi));
IntervalTree::Node {
median,
by_lo,
by_hi,
left: Box::new(IntervalTree::build(left_ivs)),
right: Box::new(IntervalTree::build(right_ivs)),
}
}
fn stab(&self, x: i64, result: &mut Vec<usize>) {
match self {
IntervalTree::Leaf => {}
IntervalTree::Node { median, by_lo, by_hi, left, right } => {
if x < *median {
// Every spanning interval has hi >= median > x, so only lo can fail:
// scan ascending by lo and stop at the first lo > x.
for iv in by_lo {
if iv.lo > x { break; }
result.push(iv.id);
}
left.stab(x, result);
} else if x > *median {
// Every spanning interval has lo <= median < x, so only hi can fail:
// scan descending by hi and stop at the first hi < x.
for iv in by_hi {
if iv.hi < x { break; }
result.push(iv.id);
}
right.stab(x, result);
} else {
// x == median: every spanning interval contains x, and neither subtree
// can (left intervals end before x, right intervals start after it).
for iv in by_lo { result.push(iv.id); }
}
}
}
}
}
`std::cmp::Reverse` wraps a value to reverse its sort order: `sort_unstable_by_key(|iv| Reverse(iv.hi))` sorts descending by hi without writing a custom comparator.
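A small self-contained check of the `Reverse` idiom may help. The following sketch uses a local `Interval` struct with integer endpoints (a stand-in mirroring the first listing) and compares the key-wrapper form against an explicit comparator with swapped arguments.

```rust
use std::cmp::Reverse;

#[derive(Debug, Clone, PartialEq)]
struct Interval {
    lo: i64,
    hi: i64,
    id: usize,
}

fn main() {
    let ivs = vec![
        Interval { lo: 2, hi: 8, id: 2 },
        Interval { lo: 6, hi: 10, id: 3 },
        Interval { lo: 3, hi: 7, id: 4 },
    ];

    // Descending by hi: wrap the key so its Ord is reversed.
    let mut by_key = ivs.clone();
    by_key.sort_unstable_by_key(|iv| Reverse(iv.hi));

    // Same order with an explicit comparator: swap the arguments to cmp.
    let mut by_cmp = ivs.clone();
    by_cmp.sort_unstable_by(|a, b| b.hi.cmp(&a.hi));

    // The hi values are distinct, so both unstable sorts agree.
    assert_eq!(by_key, by_cmp);
    let ids: Vec<usize> = by_key.iter().map(|iv| iv.id).collect();
    println!("{ids:?}"); // prints [3, 2, 4]
}
```

This only works because `i64` is `Ord`. With `f64` endpoints, as in the full listing, `Reverse(iv.hi)` cannot serve as a sort key because `f64` is only `PartialOrd`. That is why the listing falls back to `sort_by` with `partial_cmp`. `b.hi.total_cmp(&a.hi)` is a panic-free alternative.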
The `drain(..)` idiom moves all elements out of `intervals`, consuming the vector efficiently. This avoids cloning during the partition step.
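The following sketch illustrates what `drain(..)` leaves behind. It uses plain `(lo, hi)` pairs and treats 7 as a node's median; both are illustrative choices, not taken from the listings' API.

```rust
fn main() {
    // (lo, hi) pairs; 7 stands in for a node's median (illustrative value).
    let mut intervals: Vec<(i64, i64)> = vec![(1, 5), (2, 8), (9, 12)];
    let median = 7;

    let mut left = Vec::new();
    let mut not_left = Vec::new();
    // drain(..) yields each element by value, so moving it needs no clone.
    for iv in intervals.drain(..) {
        if iv.1 < median {
            left.push(iv);
        } else {
            not_left.push(iv);
        }
    }

    // The source Vec is now empty but still usable, and it keeps its buffer.
    assert!(intervals.is_empty());
    assert!(intervals.capacity() >= 3);
    println!("left = {left:?}, not_left = {not_left:?}");
}
```

Since `build` takes its vector by value and never touches it again, `for iv in intervals` (plain `into_iter`) would move the elements just as cheaply. The second listing does exactly that with `for iv in ivs`. `drain(..)` matters mainly when the emptied vector must stay alive for reuse.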
What This Unlocks
- Genomic interval queries: find all annotated features (genes, exons, repeats) overlapping a given chromosomal position, a standard operation in bioinformatics pipelines.
- Event scheduling: query all active tasks or reservations at a given time, as in calendar systems, resource schedulers, and simulation engines.
- Rendering: bounding-volume hierarchies (BVHs) apply a related hierarchical split-and-prune idea to 2D/3D boxes for fast ray-object intersection in ray tracing.
Key Differences
| Concept | OCaml | Rust |
|---|---|---|
| Recursive ADT | `type tree = Empty \| Node of { ... }` | `enum IntervalTree { Leaf, Node { ... } }` |
| Boxed subtree | `tree` (GC-managed heap) | `Box<IntervalTree>`: explicit heap allocation |
| Reverse sort | `List.sort (fun a b -> compare b.hi a.hi)` | `sort_unstable_by_key(\|iv\| Reverse(iv.hi))` |
| Drain / consume | `List.filter` (three passes; records are shared, not moved) | `Vec::drain(..)`: moves elements out, leaves the vec empty |
| Result accumulation | Threads an accumulator list through the recursion | Mutable `&mut Vec<usize>`: avoids allocating a new vector on each call |
/// Interval Tree for Stabbing Queries.
///
/// Each node stores its median and intervals spanning it in two sorted orders.
/// Query answers "which intervals contain x?" in O(log n + k).
#[derive(Clone, Debug)]
struct Interval {
lo: f64,
hi: f64,
id: usize,
}
struct IntervalNode {
median: f64,
by_lo: Vec<Interval>, // sorted ascending by lo
by_hi: Vec<Interval>, // sorted descending by hi
left: Option<Box<IntervalNode>>,
right: Option<Box<IntervalNode>>,
}
/// Build an interval tree from a list of intervals.
fn build(ivs: Vec<Interval>) -> Option<Box<IntervalNode>> {
if ivs.is_empty() { return None; }
// Compute median of all endpoints
let mut endpoints: Vec<f64> = ivs.iter().flat_map(|iv| [iv.lo, iv.hi]).collect();
endpoints.sort_by(|a, b| a.partial_cmp(b).unwrap());
let median = endpoints[endpoints.len() / 2];
let mut spanning = Vec::new();
let mut left_ivs = Vec::new();
let mut right_ivs = Vec::new();
for iv in ivs {
if iv.hi < median {
left_ivs.push(iv);
} else if iv.lo > median {
right_ivs.push(iv);
} else {
spanning.push(iv);
}
}
let mut by_lo = spanning.clone();
let mut by_hi = spanning;
by_lo.sort_by(|a, b| a.lo.partial_cmp(&b.lo).unwrap());
by_hi.sort_by(|a, b| b.hi.partial_cmp(&a.hi).unwrap());
Some(Box::new(IntervalNode {
median,
by_lo,
by_hi,
left: build(left_ivs),
right: build(right_ivs),
}))
}
/// Stabbing query: return all intervals containing x.
fn query(x: f64, node: &Option<Box<IntervalNode>>, results: &mut Vec<usize>) {
let Some(n) = node else { return; };
if x <= n.median {
// Scan by_lo until lo > x (early termination)
for iv in &n.by_lo {
if iv.lo > x { break; }
results.push(iv.id);
}
query(x, &n.left, results);
} else {
// Scan by_hi (desc) until hi < x
for iv in &n.by_hi {
if iv.hi < x { break; }
results.push(iv.id);
}
query(x, &n.right, results);
}
}
struct IntervalTree {
root: Option<Box<IntervalNode>>,
}
impl IntervalTree {
fn new(intervals: Vec<Interval>) -> Self {
Self { root: build(intervals) }
}
fn stab(&self, x: f64) -> Vec<usize> {
let mut results = Vec::new();
query(x, &self.root, &mut results);
results.sort();
results
}
}
fn main() {
let ivs = vec![
Interval { lo: 1.0, hi: 5.0, id: 1 },
Interval { lo: 2.0, hi: 8.0, id: 2 },
Interval { lo: 6.0, hi: 10.0, id: 3 },
Interval { lo: 3.0, hi: 7.0, id: 4 },
Interval { lo: 9.0, hi: 12.0, id: 5 },
];
let tree = IntervalTree::new(ivs);
for x in [0.0f64, 3.0, 6.5, 9.5, 15.0] {
println!("stab({x}): {:?}", tree.stab(x));
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_tree() -> IntervalTree {
IntervalTree::new(vec![
Interval { lo: 1.0, hi: 5.0, id: 1 },
Interval { lo: 2.0, hi: 8.0, id: 2 },
Interval { lo: 6.0, hi: 10.0, id: 3 },
Interval { lo: 3.0, hi: 7.0, id: 4 },
Interval { lo: 9.0, hi: 12.0, id: 5 },
])
}
// Brute-force stab for verification
fn brute_stab(x: f64, ivs: &[(f64, f64, usize)]) -> Vec<usize> {
let mut r: Vec<usize> = ivs.iter()
.filter(|&&(lo, hi, _)| lo <= x && x <= hi)
.map(|&(_, _, id)| id)
.collect();
r.sort();
r
}
#[test]
fn test_stab_miss() {
assert_eq!(make_tree().stab(0.0), vec![]);
assert_eq!(make_tree().stab(15.0), vec![]);
}
#[test]
fn test_stab_hit() {
let t = make_tree();
assert_eq!(t.stab(3.0), vec![1, 2, 4]);
assert_eq!(t.stab(6.5), vec![2, 3, 4]);
assert_eq!(t.stab(9.5), vec![3, 5]);
}
#[test]
fn test_boundary() {
let t = make_tree();
assert!(t.stab(1.0).contains(&1));
assert!(t.stab(5.0).contains(&1));
assert!(!t.stab(5.1).contains(&1));
}
#[test]
fn test_matches_brute() {
let raw = vec![(1.0f64, 5.0, 1), (2.0, 8.0, 2), (6.0, 10.0, 3), (3.0, 7.0, 4), (9.0, 12.0, 5)];
let t = make_tree();
for x_int in 0..=15 {
let x = x_int as f64;
let mut tree_res = t.stab(x);
let mut brute = brute_stab(x, &raw);
tree_res.sort();
brute.sort();
assert_eq!(tree_res, brute, "mismatch at x={x}");
}
}
}
(* Interval Tree for Stabbing Queries in OCaml *)
type interval = { lo: float; hi: float; id: int }
(* Each node: median, intervals spanning median (sorted by lo and hi),
left subtree (all intervals < median), right subtree (all > median) *)
type tree =
| Empty
| Node of {
median : float;
by_lo : interval list; (* sorted ascending by lo *)
by_hi : interval list; (* sorted descending by hi *)
left : tree;
right : tree;
}
let build (intervals : interval list) : tree =
let rec build_rec ivs =
match ivs with
| [] -> Empty
| _ ->
(* Median of endpoints *)
let endpoints = List.concat_map (fun iv -> [iv.lo; iv.hi]) ivs in
let sorted_ep = List.sort compare endpoints in
let n = List.length sorted_ep in
let median = List.nth sorted_ep (n / 2) in
(* Partition: spanning, left, right *)
let spanning = List.filter (fun iv -> iv.lo <= median && iv.hi >= median) ivs in
let left_ivs = List.filter (fun iv -> iv.hi < median) ivs in
let right_ivs = List.filter (fun iv -> iv.lo > median) ivs in
let by_lo = List.sort (fun a b -> compare a.lo b.lo) spanning in
let by_hi = List.sort (fun a b -> compare b.hi a.hi) spanning in
Node {
median;
by_lo; by_hi;
left = build_rec left_ivs;
right = build_rec right_ivs;
}
in
build_rec intervals
(* Stab query: all intervals containing point x *)
let query (x : float) (tree : tree) : interval list =
let rec query_rec t acc =
match t with
| Empty -> acc
| Node { median; by_lo; by_hi; left; right } ->
let acc = if x <= median then begin
(* Scan by_lo until lo > x *)
let rec scan = function
| [] -> acc
| iv :: rest -> if iv.lo > x then acc else iv :: scan rest
in
scan by_lo
end else begin
(* Scan by_hi (desc) until hi < x *)
let rec scan = function
| [] -> acc
| iv :: rest -> if iv.hi < x then acc else iv :: scan rest
in
scan by_hi
end in
let acc = if x < median then query_rec left acc else query_rec right acc in
acc
in
query_rec tree []
let () =
let ivs = [
{lo=1.0; hi=5.0; id=1};
{lo=2.0; hi=8.0; id=2};
{lo=6.0; hi=10.0; id=3};
{lo=3.0; hi=7.0; id=4};
{lo=9.0; hi=12.0; id=5};
] in
let tree = build ivs in
let queries = [0.0; 3.0; 6.5; 9.5; 15.0] in
List.iter (fun x ->
let results = query x tree in
let ids = List.map (fun iv -> iv.id) results in
Printf.printf "stab(%.1f): ids=[%s]\n" x
(String.concat "," (List.map string_of_int (List.sort compare ids)))
) queries