🦀 Functional Rust
🎬 Closures in Rust: Fn/FnMut/FnOnce, capturing the environment, move closures, higher-order functions.
📝 Text version (for readers / accessibility)

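• Closures capture variables from their environment: by shared reference, by mutable reference, or by value (move)

• Three traits: Fn (called through a shared borrow, any number of times), FnMut (called through a mutable borrow; may mutate captured state), FnOnce (called by value; may consume captured values, so it can run only once)

• Higher-order functions like .map(), .filter(), .fold() accept closures as arguments

• move closures take ownership of captured variables, which is essential for threading: std::thread::spawn requires a 'static closure, so it cannot borrow locals from the spawning function

• Closures enable functional patterns: partial application, composition, and the strategy pattern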
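A minimal sketch of the three closure traits and a `move` closure handed to a thread. The helpers `call_twice`, `call_n_times`, and `call_once` are hypothetical names chosen for illustration; what matters is the trait bound each one declares. Note that the traits describe how a closure may be called, not how it captures: a `move` closure that only reads its data still implements `Fn`.

```rust
use std::thread;

// Fn: the closure is called through a shared reference, so it can run any number of times.
fn call_twice<F: Fn() -> i32>(f: F) -> i32 {
    f() + f()
}

// FnMut: the closure mutates captured state, so calling it needs a mutable borrow.
fn call_n_times<F: FnMut()>(mut f: F, n: usize) {
    for _ in 0..n {
        f();
    }
}

// FnOnce: the closure may move captured values out, so it can be called only once.
fn call_once<F: FnOnce() -> String>(f: F) -> String {
    f()
}

fn main() {
    let base = 10;
    let read_only = || base + 1; // only reads `base`: implements Fn
    println!("Fn: {}", call_twice(read_only)); // prints "Fn: 22"

    let mut counter = 0;
    call_n_times(|| counter += 1, 3); // mutates `counter`: implements FnMut
    println!("FnMut: counter = {}", counter); // prints "FnMut: counter = 3"

    let name = String::from("ferris");
    let consume = move || name; // returns the captured String by value: FnOnce only
    println!("FnOnce: {}", call_once(consume)); // prints "FnOnce: ferris"

    // thread::spawn needs a 'static closure, so `move` transfers ownership of `data`.
    let data = vec![1, 2, 3];
    let handle = thread::spawn(move || data.iter().sum::<i32>());
    println!("thread sum: {}", handle.join().unwrap()); // prints "thread sum: 6"
}
```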

520: Higher-Order Functions

Difficulty: 2
Level: Beginner-Intermediate

Functions that take or return other functions: the backbone of Rust's iterator API and of its functional programming style.
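Most examples in this section pass functions in; returning them is the other half of the definition. A small sketch with hypothetical helpers `make_adder` and `compose`: `impl Fn` in return position hands back a concrete closure without boxing, while `Box<dyn Fn>` lets closures of different types share one collection (a simple form of the strategy pattern).

```rust
// Partial application: fix the first argument of addition and return the rest.
fn make_adder(n: i32) -> impl Fn(i32) -> i32 {
    move |x| x + n // `move` copies `n` into the closure so it outlives this call
}

// Composition: return a closure that runs `f`, then feeds its result to `g`.
fn compose<A, B, C>(f: impl Fn(A) -> B, g: impl Fn(B) -> C) -> impl Fn(A) -> C {
    move |x| g(f(x))
}

fn main() {
    let add5 = make_adder(5);
    let double_then_add5 = compose(|x: i32| x * 2, add5);
    println!("{}", double_then_add5(10)); // prints 25

    // Box<dyn Fn> erases each closure's concrete type so different strategies fit in one Vec.
    let mut strategies: Vec<Box<dyn Fn(i32) -> i32>> = Vec::new();
    strategies.push(Box::new(make_adder(1)));
    strategies.push(Box::new(|x: i32| x * x));
    let results: Vec<i32> = strategies.iter().map(|s| s(3)).collect();
    println!("{:?}", results); // prints [4, 9]
}
```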

The Problem This Solves

Without higher-order functions, every data transformation is a loop with a slightly different body. Want the sum of squares of even numbers? One loop. Want the maximum after a transformation? Another loop. Want to group by a property? Yet another loop. The structure is identical; only the inner logic changes. Higher-order functions (HOFs) extract the structure and let you supply only what varies: `filter(is_even)` + `map(square)` + `sum()` replaces a hand-written loop, and each step is named and independently testable.

The secondary problem: with eager HOFs, each step allocates an intermediate collection. Rust's iterator adapters are lazy: a chain like `filter` + `map` + `sum` pulls one element at a time through every stage in a single pass, so no intermediate collection is ever allocated, and after inlining the compiler typically optimizes the chain into code comparable to a hand-written loop.
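To make the loop-versus-pipeline comparison concrete, here is a sketch assuming two small named helpers, `is_even` and `square` (hypothetical, matching the names in the paragraph above). Both functions compute the same result; the pipeline version lets the iterator adapters supply the looping structure.

```rust
// Named, independently testable pieces: the part that varies.
fn is_even(x: &i32) -> bool {
    x % 2 == 0
}

fn square(x: i32) -> i32 {
    x * x
}

// Hand-written loop: the iteration structure and the logic are tangled together.
fn sum_even_squares_loop(nums: &[i32]) -> i32 {
    let mut total = 0;
    for &x in nums {
        if x % 2 == 0 {
            total += x * x;
        }
    }
    total
}

// HOF pipeline: filter, map, and sum provide the structure; we pass in the behavior.
fn sum_even_squares_hof(nums: &[i32]) -> i32 {
    nums.iter().copied().filter(is_even).map(square).sum()
}

fn main() {
    let nums: Vec<i32> = (1..=10).collect();
    println!("{}", sum_even_squares_loop(&nums)); // prints 220
    println!("{}", sum_even_squares_hof(&nums)); // prints 220
}
```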

The Intuition

Higher-order functions treat behavior as data. `map` is a machine that applies your transformation to each element. `filter` is a machine that applies your predicate. `fold` is a machine that threads an accumulator through your combining function. You supply the behavior; they supply the infrastructure. Python has `map()`, `filter()`, and `functools.reduce()`. JavaScript has `Array.map()`, `.filter()`, `.reduce()`. Rust's iterator methods play the same roles, with one crucial difference from JavaScript's array methods: they're lazy (Python 3's `map()` and `filter()` are lazy too). Nothing is computed until you call a consuming adapter like `.collect()` or `.sum()`. Laziness means that `(1..1_000_000).filter(is_prime).take(5)`, once consumed, stops as soon as it has found 5 primes; it never checks the rest of the million numbers.
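A sketch to observe laziness directly. `is_prime` here is a hypothetical trial-division helper, and the counter inside the `filter` closure records how many candidates the pipeline actually inspected. Because `take(5)` stops requesting elements after the fifth prime, only the numbers 1 through 11 are ever tested.

```rust
// Hypothetical trial-division primality test (takes &u32 so it can be passed to `filter`).
fn is_prime(n: &u32) -> bool {
    let n = *n;
    if n < 2 {
        return false;
    }
    let mut d = 2;
    while d * d <= n {
        if n % d == 0 {
            return false;
        }
        d += 1;
    }
    true
}

// Collect the first `k` primes and report how many candidates `filter` inspected.
fn first_primes(k: usize) -> (Vec<u32>, usize) {
    let mut checked = 0;
    let primes: Vec<u32> = (1..1_000_000)
        .filter(|n| {
            checked += 1; // counts each candidate pulled from the range
            is_prime(n)
        })
        .take(k)
        .collect();
    (primes, checked)
}

fn main() {
    let (primes, checked) = first_primes(5);
    println!("{:?} after checking {} numbers", primes, checked);
    // prints "[2, 3, 5, 7, 11] after checking 11 numbers"
}
```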

How It Works in Rust

let nums: Vec<i32> = (1..=10).collect();

// map: transform each element
let squares: Vec<i32> = nums.iter().map(|&x| x * x).collect();

// filter: keep elements matching predicate
let evens: Vec<i32> = nums.iter().filter(|&&x| x % 2 == 0).copied().collect();

// fold: accumulate (the most general of these HOFs: map and filter can both be written as folds)
let sum: i32 = nums.iter().fold(0, |acc, &x| acc + x);

// chained pipeline: lazy, single pass, no intermediate collections
let result: i32 = nums.iter()
    .filter(|&&x| x % 2 == 0)   // keep evens
    .map(|&x| x * x)            // square them
    .sum();                     // consume: add them up

// flat_map: map each element to an iterable, then flatten
// (like Python's itertools.chain.from_iterable(map(f, xs)))
let pairs: Vec<i32> = [1, 2, 3].iter()
    .flat_map(|&x| [x, x * 10])  // each element becomes two elements
    .collect();                  // [1, 10, 2, 20, 3, 30]

// any/all: short-circuit HOFs
println!("{}", nums.iter().any(|&x| x > 5));  // true (stops at first match)
println!("{}", nums.iter().all(|&x| x > 0));  // true

// Custom HOF: zip two slices with a combining function
fn zip_with<A, B, C, F: Fn(&A, &B) -> C>(a: &[A], b: &[B], f: F) -> Vec<C> {
    a.iter().zip(b.iter()).map(|(x, y)| f(x, y)).collect()
}
let sums = zip_with(&[1, 2, 3], &[10, 20, 30], |x, y| x + y); // [11, 22, 33]

// Custom HOF: scan (running totals: every intermediate fold value, starting with init)
fn scan_left<T: Clone, U: Clone, F: Fn(U, &T) -> U>(
    items: &[T], init: U, f: F
) -> Vec<U> {
    let mut acc = init.clone();
    std::iter::once(init)
        .chain(items.iter().map(move |item| {
            acc = f(acc.clone(), item); // advance the accumulator
            acc.clone()                 // emit a copy of the new state
        }))
        .collect()
}
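The standard library already ships a lazy scan: `Iterator::scan` keeps mutable state and lets the closure return an `Option`, where `None` ends the iteration early. Unlike `scan_left` above, it does not emit the initial value, so this sketch prepends it with `std::iter::once` when matching `scan_left`'s output. The function names are hypothetical.

```rust
// Running totals with the standard library's Iterator::scan.
fn running_totals(items: &[i32]) -> Vec<i32> {
    items
        .iter()
        .scan(0, |acc, &x| {
            *acc += x; // update the state in place
            Some(*acc) // emit the new state
        })
        .collect()
}

// Prepend the initial state to reproduce scan_left's output shape.
fn running_totals_with_init(items: &[i32]) -> Vec<i32> {
    std::iter::once(0).chain(running_totals(items)).collect()
}

// Returning None from the closure ends the scan early.
fn totals_until(items: &[i32], limit: i32) -> Vec<i32> {
    items
        .iter()
        .scan(0, |acc, &x| {
            *acc += x;
            if *acc > limit { None } else { Some(*acc) }
        })
        .collect()
}

fn main() {
    let v = [1, 2, 3, 4];
    println!("{:?}", running_totals(&v)); // prints [1, 3, 6, 10]
    println!("{:?}", running_totals_with_init(&v)); // prints [0, 1, 3, 6, 10]
    println!("{:?}", totals_until(&v, 5)); // prints [1, 3]
}
```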

What This Unlocks

Key Differences

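| Concept | OCaml | Rust |
|---|---|---|
| Map | `List.map f xs` | `iter.map(f).collect()` |
| Filter | `List.filter pred xs` | `iter.filter(pred).collect()` |
| Fold | `List.fold_left f init xs` | `iter.fold(init, f)` |
| Flat map | `List.concat_map f xs` | `iter.flat_map(f).collect()` |
| Laziness | Eager by default (`List` functions build a new list at each step; the `Seq` module is lazy) | Lazy by default: adapters do nothing until a consuming adapter such as `.collect()`, `.sum()`, or `.fold()` runs |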
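The Fold row deserves a closer look, since `fold` is the most general of these operations. As an illustration only, here are map and filter rebuilt on top of `fold` (the names `map_via_fold` and `filter_via_fold` are hypothetical). These versions are eager and build a `Vec`, like OCaml's `List` functions; the real `Iterator::map` and `Iterator::filter` are lazy adapters.

```rust
// map expressed as a fold: push f(x) for every element.
fn map_via_fold<T, U>(items: &[T], f: impl Fn(&T) -> U) -> Vec<U> {
    items.iter().fold(Vec::new(), |mut acc, x| {
        acc.push(f(x));
        acc
    })
}

// filter expressed as a fold: push only the elements the predicate accepts.
fn filter_via_fold<T: Clone>(items: &[T], pred: impl Fn(&T) -> bool) -> Vec<T> {
    items.iter().fold(Vec::new(), |mut acc, x| {
        if pred(x) {
            acc.push(x.clone());
        }
        acc
    })
}

fn main() {
    let nums = [1, 2, 3, 4, 5];
    println!("{:?}", map_via_fold(&nums, |x| x * 10)); // prints [10, 20, 30, 40, 50]
    println!("{:?}", filter_via_fold(&nums, |x| x % 2 == 1)); // prints [1, 3, 5]
}
```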
//! # 520. Higher-Order Functions
//! Rust's iterator HOFs: map, filter, fold, flat_map, zip, and custom ones.

/// Custom HOF: zip two slices with a combining function
fn zip_with<A, B, C, F>(a: &[A], b: &[B], f: F) -> Vec<C>
where
    F: Fn(&A, &B) -> C,
{
    a.iter().zip(b.iter()).map(|(x, y)| f(x, y)).collect()
}

/// Custom HOF: scan (like fold but returns all intermediate values)
fn scan_left<T: Clone, U: Clone, F>(items: &[T], init: U, f: F) -> Vec<U>
where
    F: Fn(U, &T) -> U,
{
    let mut acc = init;
    let mut result = vec![acc.clone()];
    for item in items {
        acc = f(acc, item);
        result.push(acc.clone());
    }
    result
}

/// Custom HOF: group consecutive elements by a key
fn group_by<T, K, F>(items: Vec<T>, key: F) -> Vec<(K, Vec<T>)>
where
    K: PartialEq,
    F: Fn(&T) -> K,
{
    let mut groups: Vec<(K, Vec<T>)> = Vec::new();
    for item in items {
        let k = key(&item);
        if let Some(last) = groups.last_mut() {
            if last.0 == k {
                last.1.push(item);
                continue;
            }
        }
        groups.push((k, vec![item]));
    }
    groups
}

fn main() {
    let nums: Vec<i32> = (1..=10).collect();

    // map
    let squares: Vec<i32> = nums.iter().map(|&x| x * x).collect();
    println!("squares: {:?}", squares);

    // filter
    let evens: Vec<i32> = nums.iter().filter(|&&x| x % 2 == 0).copied().collect();
    println!("evens: {:?}", evens);

    // fold
    let sum: i32 = nums.iter().fold(0, |acc, &x| acc + x);
    println!("sum: {}", sum);

    // chained pipeline (lazy: single pass, no intermediate collections)
    let sum_even_squares: i32 = nums.iter()
        .filter(|&&x| x % 2 == 0)
        .map(|&x| x * x)
        .sum();
    println!("sum of even squares: {}", sum_even_squares);

    // flat_map
    let expanded: Vec<i32> = [1, 2, 3].iter()
        .flat_map(|&x| [x, x * 10]) // an array avoids allocating a Vec per element
        .collect();
    println!("flat_map: {:?}", expanded);

    // zip
    let a = [1, 2, 3];
    let b = [10, 20, 30];
    let sums = zip_with(&a, &b, |x, y| x + y);
    println!("zip_with(+): {:?}", sums);

    // any / all
    println!("any > 5: {}", nums.iter().any(|&x| x > 5));
    println!("all > 0: {}", nums.iter().all(|&x| x > 0));

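    // take_while / skip_while: split at the first element that fails the predicate
    let up_to_five: Vec<i32> = nums.iter().copied().take_while(|&x| x <= 5).collect();
    println!("take_while <=5: {:?}", up_to_five);
    let after_five: Vec<i32> = nums.iter().copied().skip_while(|&x| x <= 5).collect();
    println!("skip_while <=5: {:?}", after_five);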

    // scan (running totals)
    let running = scan_left(&nums[..5], 0, |acc, &x| acc + x);
    println!("running totals: {:?}", running);

    // group_by
    let words = vec!["ant", "ape", "bear", "bee", "cat"];
    let grouped = group_by(words, |w| w.chars().next().unwrap());
    for (letter, group) in &grouped {
        println!("  '{}': {:?}", letter, group);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_map_filter_fold() {
        let v = vec![1, 2, 3, 4, 5];
        let result: i32 = v.iter()
            .filter(|&&x| x % 2 != 0)
            .map(|&x| x * x)
            .sum();
        assert_eq!(result, 1 + 9 + 25); // 35
    }

    #[test]
    fn test_zip_with() {
        let a = [1, 2, 3];
        let b = [10, 20, 30];
        assert_eq!(zip_with(&a, &b, |x, y| x * y), vec![10, 40, 90]);
    }

    #[test]
    fn test_scan_left() {
        let v = [1, 2, 3, 4];
        let running = scan_left(&v, 0, |acc, &x| acc + x);
        assert_eq!(running, vec![0, 1, 3, 6, 10]);
    }

    #[test]
    fn test_flat_map() {
        let v: Vec<i32> = [1, 2, 3].iter().flat_map(|&x| [x, -x]).collect();
        assert_eq!(v, [1, -1, 2, -2, 3, -3]);
    }
}
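The custom `group_by` above groups consecutive elements by a key. On Rust 1.77 or later, slices offer a similar built-in, `slice::chunk_by`, which takes a predicate over adjacent pairs instead of a key function and yields borrowed subslices without copying. A sketch, assuming a toolchain new enough to have `chunk_by`; `group_by_first_letter` is a hypothetical name.

```rust
// Group consecutive words that share a first letter, using slice::chunk_by (Rust 1.77+).
fn group_by_first_letter<'a>(words: &'a [&'a str]) -> Vec<&'a [&'a str]> {
    words
        .chunk_by(|a, b| a.chars().next() == b.chars().next()) // compare adjacent pairs
        .collect()
}

fn main() {
    let words = ["ant", "ape", "bear", "bee", "cat"];
    for group in group_by_first_letter(&words) {
        println!("{:?}", group);
    }
    // prints ["ant", "ape"], then ["bear", "bee"], then ["cat"], one per line
}
```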
(* Higher-order functions in OCaml *)

(* Custom HOFs: thin wrappers that partially apply the Stdlib functions *)
let my_map f = List.map f
let my_filter pred = List.filter pred
let my_fold_left f init = List.fold_left f init
let my_for_all pred = List.for_all pred

(* Pipeline operator: already provided by Stdlib since OCaml 4.01; redefined here to show how little it takes *)
let ( |> ) x f = f x

let () =
  let nums = [1; 2; 3; 4; 5; 6; 7; 8; 9; 10] in

  (* Map *)
  let squares = List.map (fun x -> x * x) nums in
  Printf.printf "squares: [%s]\n" (String.concat ";" (List.map string_of_int squares));

  (* Filter *)
  let evens = List.filter (fun x -> x mod 2 = 0) nums in
  Printf.printf "evens: [%s]\n" (String.concat ";" (List.map string_of_int evens));

  (* Fold *)
  let sum = List.fold_left (+) 0 nums in
  Printf.printf "sum: %d\n" sum;

  (* Pipeline *)
  let result =
    nums
    |> List.filter (fun x -> x mod 2 = 0)
    |> List.map (fun x -> x * x)
    |> List.fold_left (+) 0
  in
  Printf.printf "sum of even squares: %d\n" result;

  (* any / all *)
  Printf.printf "any >5: %b\n" (List.exists (fun x -> x > 5) nums);
  Printf.printf "all >0: %b\n" (List.for_all (fun x -> x > 0) nums);

  (* flat_map *)
  let expanded = List.concat_map (fun x -> [x; x*10]) [1;2;3] in
  Printf.printf "flat_map: [%s]\n" (String.concat ";" (List.map string_of_int expanded))