🦀 Functional Rust

965: Segment Tree

Difficulty: Intermediate · Category: Data Structures / Trees
Concept: A binary tree stored in an array, giving O(log n) range-sum queries and point updates
Key Insight: Both languages use the same 1-indexed tree layout (node `i` has children `2i` and `2i+1`); OCaml stores the recursive helpers as module-level functions, while Rust stores them as private `fn` methods — the recursive structure is identical.
// 965: Segment Tree for Range Sum Queries
// 1-indexed internal nodes; O(log n) point update and range sum

/// Segment tree over `i64` values supporting O(log n) point updates and
/// inclusive range-sum queries.
///
/// Nodes are stored 1-indexed in a flat vector: node `i` covers a range
/// `[lo, hi]`, and its children `2 * i` and `2 * i + 1` cover `[lo, mid]` and
/// `[mid + 1, hi]`. A capacity of `4 * n` slots is always enough for this
/// layout; slot 0 is never used.
#[derive(Debug, Clone)]
pub struct SegmentTree {
    /// Number of logical positions (leaves).
    n: usize,
    /// Flat node storage; `tree[node]` holds the sum of the range that node covers.
    tree: Vec<i64>,
}

impl SegmentTree {
    /// Creates a tree with `n` positions, every one initialised to 0.
    pub fn new(n: usize) -> Self {
        SegmentTree {
            n,
            tree: vec![0i64; 4 * n],
        }
    }

    /// Loads `arr` into the tree in O(n), overwriting every node in use.
    ///
    /// # Panics
    ///
    /// Panics if `arr.len()` differs from the length given to [`SegmentTree::new`].
    pub fn build(&mut self, arr: &[i64]) {
        assert_eq!(
            arr.len(),
            self.n,
            "SegmentTree::build: expected {} values, got {}",
            self.n,
            arr.len()
        );
        // An empty tree has no root range; `self.n - 1` would underflow.
        if self.n > 0 {
            self.build_rec(1, 0, self.n - 1, arr);
        }
    }

    /// Fills `node`, which covers `arr[lo..=hi]`: a leaf copies its value, an
    /// internal node recurses into both halves and stores their sum.
    fn build_rec(&mut self, node: usize, lo: usize, hi: usize, arr: &[i64]) {
        if lo == hi {
            self.tree[node] = arr[lo];
        } else {
            let mid = (lo + hi) / 2;
            self.build_rec(2 * node, lo, mid, arr);
            self.build_rec(2 * node + 1, mid + 1, hi, arr);
            self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1];
        }
    }

    /// Point update: sets position `pos` to `value` in O(log n).
    ///
    /// # Panics
    ///
    /// Panics if `pos >= n`. Without this check, an out-of-range `pos` would
    /// descend to the last leaf and silently overwrite position `n - 1`.
    pub fn update(&mut self, pos: usize, value: i64) {
        assert!(
            pos < self.n,
            "SegmentTree::update: position {} out of range for length {}",
            pos,
            self.n
        );
        self.update_rec(1, 0, self.n - 1, pos, value);
    }

    /// Walks from `node` (covering `[lo, hi]`) down to the leaf for `pos`,
    /// writes `value`, and recomputes the sums on the way back up.
    fn update_rec(&mut self, node: usize, lo: usize, hi: usize, pos: usize, value: i64) {
        if lo == hi {
            self.tree[node] = value;
        } else {
            let mid = (lo + hi) / 2;
            if pos <= mid {
                self.update_rec(2 * node, lo, mid, pos, value);
            } else {
                self.update_rec(2 * node + 1, mid + 1, hi, pos, value);
            }
            self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1];
        }
    }

    /// Range sum over `[l, r]` (inclusive, 0-indexed) in O(log n).
    ///
    /// Positions at or beyond `n` contribute nothing, so an `r` past the end is
    /// effectively clamped. An empty range (`l > r`) or an empty tree yields 0.
    pub fn query(&self, l: usize, r: usize) -> i64 {
        if self.n == 0 || l > r {
            return 0;
        }
        self.query_rec(1, 0, self.n - 1, l, r)
    }

    /// Sum of `[l, r]` restricted to the range `[lo, hi]` covered by `node`.
    /// Returns 0 when disjoint, the stored sum when fully covered, otherwise
    /// the sum of both halves.
    fn query_rec(&self, node: usize, lo: usize, hi: usize, l: usize, r: usize) -> i64 {
        if r < lo || hi < l {
            0
        } else if l <= lo && hi <= r {
            self.tree[node]
        } else {
            let mid = (lo + hi) / 2;
            self.query_rec(2 * node, lo, mid, l, r)
                + self.query_rec(2 * node + 1, mid + 1, hi, l, r)
        }
    }
}

/// Demo: builds a tree over a small array, prints a few range sums, applies one
/// point update, then prints the affected sums again.
fn main() {
    let values = [1i64, 3, 5, 7, 9, 11];
    let last = values.len() - 1;

    let mut seg = SegmentTree::new(values.len());
    seg.build(&values);

    let total = seg.query(0, last);
    println!("total sum: {}", total);
    println!("sum [0,2]: {}", seg.query(0, 2));
    println!("sum [2,4]: {}", seg.query(2, 4));

    // Replace values[2] (5) with 10 and show the new sums.
    seg.update(2, 10);
    println!("after update arr[2]=10:");
    let total = seg.query(0, last);
    println!("total sum: {}", total);
    println!("sum [0,2]: {}", seg.query(0, 2));
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the shared fixture tree over `[1, 3, 5, 7, 9, 11]` (sum 36).
    fn make_tree() -> SegmentTree {
        let arr = vec![1i64, 3, 5, 7, 9, 11];
        let mut st = SegmentTree::new(arr.len());
        st.build(&arr);
        st
    }

    #[test]
    fn test_total_sum() {
        let st = make_tree();
        assert_eq!(st.query(0, 5), 36);
    }

    #[test]
    fn test_range_queries() {
        let st = make_tree();
        assert_eq!(st.query(0, 2), 9); // 1+3+5
        assert_eq!(st.query(2, 4), 21); // 5+7+9
        assert_eq!(st.query(1, 3), 15); // 3+5+7
        assert_eq!(st.query(5, 5), 11); // single element
    }

    #[test]
    fn test_empty_range_is_zero() {
        let st = make_tree();
        assert_eq!(st.query(3, 2), 0); // l > r selects nothing
    }

    #[test]
    fn test_point_update() {
        let mut st = make_tree();
        st.update(2, 10); // replace 5 with 10
        assert_eq!(st.query(0, 5), 41); // 36 - 5 + 10
        assert_eq!(st.query(0, 2), 14); // 1+3+10
        assert_eq!(st.query(2, 4), 26); // 10+7+9
    }

    #[test]
    fn test_single_element() {
        let arr = vec![42i64];
        let mut st = SegmentTree::new(1);
        st.build(&arr);
        assert_eq!(st.query(0, 0), 42);
        st.update(0, 100);
        assert_eq!(st.query(0, 0), 100);
    }

    #[test]
    fn test_multiple_updates() {
        let mut st = make_tree();
        st.update(0, 0);
        st.update(5, 0);
        assert_eq!(st.query(0, 5), 24); // 0+3+5+7+9+0
    }

    /// Cross-checks every inclusive range against a naive slice sum after each
    /// update, on an odd-length array containing negative values.
    #[test]
    fn test_matches_naive_sums_after_updates() {
        let mut arr = vec![4i64, -2, 7, 0, 3, -5, 8];
        let len = arr.len();
        let mut st = SegmentTree::new(len);
        st.build(&arr);
        for pos in 0..len {
            let value = (pos as i64) * 3 - 4;
            arr[pos] = value;
            st.update(pos, value);
            for l in 0..len {
                for r in l..len {
                    let expected: i64 = arr[l..=r].iter().sum();
                    assert_eq!(st.query(l, r), expected, "range [{}, {}]", l, r);
                }
            }
        }
    }
}
(* 965: Segment Tree for Range Sum Queries *)
(* 1-indexed internal nodes, supports point update + range query in O(log n) *)

(* A segment tree over integers.
   [n]    : number of logical positions (leaves).
   [tree] : flat node storage, 1-indexed; node i has children 2i and 2i+1.
            It is sized 4*n, which is always enough for this layout; slot 0 is
            unused. *)
type segment_tree = {
  n : int;
  tree : int array;
}

(* [create n] returns a tree of [n] positions, all set to 0. *)
let create n =
  let tree = Array.make (4 * n) 0 in
  { n = n; tree = tree }

(* [build st node lo hi arr] fills [node], which covers arr.(lo..hi).
   A leaf copies arr.(lo). An internal node recurses into both halves, then
   stores the sum of its two children (2*node and 2*node+1). *)
let rec build st node lo hi arr =
  if lo = hi then st.tree.(node) <- arr.(lo)
  else begin
    let mid = (lo + hi) / 2 in
    let left = 2 * node in
    let right = left + 1 in
    build st left lo mid arr;
    build st right (mid + 1) hi arr;
    st.tree.(node) <- st.tree.(left) + st.tree.(right)
  end

(* [update st node lo hi pos value] sets position [pos] to [value] within the
   subtree rooted at [node] (covering lo..hi), then recomputes the sums on the
   path back up.
   NOTE(review): pos is assumed to lie in lo..hi. It is not checked here; an
   out-of-range pos ends up at the nearest boundary leaf. *)
let rec update st node lo hi pos value =
  if lo = hi then st.tree.(node) <- value
  else begin
    let mid = (lo + hi) / 2 in
    let left = 2 * node in
    let (child, child_lo, child_hi) =
      if pos <= mid then (left, lo, mid) else (left + 1, mid + 1, hi)
    in
    update st child child_lo child_hi pos value;
    st.tree.(node) <- st.tree.(left) + st.tree.(left + 1)
  end

(* [query st node lo hi l r] returns the sum of positions l..r (inclusive),
   restricted to the range lo..hi covered by [node]. The result is 0 when the
   ranges are disjoint, the stored node sum when lo..hi is fully covered, and
   otherwise the sum of both halves. *)
let rec query st node lo hi l r =
  let disjoint = hi < l || r < lo in
  let covered = l <= lo && hi <= r in
  if disjoint then 0
  else if covered then st.tree.(node)
  else begin
    let mid = (lo + hi) / 2 in
    let left_sum = query st (2 * node) lo mid l r in
    let right_sum = query st (2 * node + 1) (mid + 1) hi l r in
    left_sum + right_sum
  end

(* Public entry points. Each starts the recursion at the root (node 1), which
   covers positions 0..n-1. *)

(* [st_update st pos value] sets position [pos] to [value].
   Raises [Invalid_argument] if [pos] is outside 0..n-1. Without this check the
   recursion would silently overwrite the nearest boundary element instead. *)
let st_update st pos value =
  if pos < 0 || pos >= st.n then invalid_arg "st_update: position out of range";
  update st 1 0 (st.n - 1) pos value

(* [st_query st l r] returns the sum of positions l..r (inclusive).
   Parts of the interval outside 0..n-1 contribute 0; an empty tree yields 0. *)
let st_query st l r =
  if st.n = 0 then 0 else query st 1 0 (st.n - 1) l r

(* [st_build st arr] loads [arr] into the tree.
   Raises [Invalid_argument] if [arr] does not have exactly [st.n] elements.
   An empty tree needs no work (the root range 0..-1 would be invalid). *)
let st_build st arr =
  if Array.length arr <> st.n then invalid_arg "st_build: length mismatch";
  if st.n > 0 then build st 1 0 (st.n - 1) arr

(* Self-check: builds a tree from a small array, verifies range sums before and
   after a point update, and prints a confirmation once every assert holds. *)
let () =
  let arr = [| 1; 3; 5; 7; 9; 11 |] in
  let n = Array.length arr in
  let st = create n in
  st_build st arr;

  (* Sum of entire array *)
  assert (st_query st 0 (n-1) = 36);

  (* Range queries *)
  assert (st_query st 0 2 = 9);   (* 1+3+5 *)
  assert (st_query st 2 4 = 21);  (* 5+7+9 *)
  assert (st_query st 1 3 = 15);  (* 3+5+7 *)
  assert (st_query st 5 5 = 11);  (* single element *)
  assert (st_query st 3 2 = 0);   (* empty range: l > r *)

  (* Point update *)
  st_update st 2 10;   (* arr[2] = 10 instead of 5 *)
  assert (st_query st 0 (n-1) = 41);  (* 36 - 5 + 10 *)
  assert (st_query st 0 2 = 14);  (* 1+3+10 *)
  assert (st_query st 2 4 = 26);  (* 10+7+9 *)

  (* The original literal was mojibake of U+2713 CHECK MARK. *)
  Printf.printf "✓ All tests passed\n"

📊 Detailed Comparison

Segment Tree — Comparison

Core Insight

A segment tree stores aggregate values (sums, min, max) for array ranges in a binary tree laid out heap-style in an array. Node `i` covers a range, and its children at `2i` and `2i+1` cover the two halves. Both OCaml and Rust implement identical recursive build/update/query — the tree layout and algorithm are language-agnostic.

OCaml Approach

  • `Array.make (4 * n) 0` — 4n slots are always enough, for any n
  • Recursive `build`, `update`, `query` as top-level functions
  • Public wrappers `st_build`, `st_update`, `st_query` start at node 1
  • `st.tree.(node) <- st.tree.(2*node) + st.tree.(2*node+1)` — push-up
  • `query` yields `0` for a node outside the range and `st.tree.(node)` for a fully covered node

Rust Approach

  • `vec![0i64; 4 * n]` — same layout
  • Private `build_rec`, `update_rec`, `query_rec` methods on the struct
  • Public `build`, `update`, `query` as a clean API
  • `self.tree[node] = self.tree[2*node] + self.tree[2*node+1]` — push-up
  • `usize` indices (never negative — avoids signed/unsigned confusion)

Comparison Table

| Aspect | OCaml | Rust |
|---|---|---|
| Storage | `int array` (4n) | `Vec<i64>` (4n) |
| Recursion | Top-level functions with `st` arg | Private `_rec` methods |
| Index type | `int` | `usize` |
| Push-up | `tree.(i) <- tree.(2*i) + tree.(2*i+1)` | `tree[i] = tree[2*i] + tree[2*i+1]` |
| Range miss | `0` | `0` |
| Range cover | `tree.(node)` | `tree[node]` |
| Split point | `mid = (lo + hi) / 2` | `mid = (lo + hi) / 2` |
| Build init | `Array.make (4 * n) 0` | `vec![0i64; 4 * n]` |