🦀 Functional Rust

149: Extension Traits

Difficulty: 3 Level: Intermediate

Add new methods to types you don't own, without modifying them.

The Problem This Solves

You're using a library type or trait, such as `String`, `Iterator`, or `Option`, and you wish it had a method it doesn't have. In many other languages you'd subclass or monkey-patch. Rust offers neither, and its orphan rule forbids implementing a foreign trait for a foreign type, but there's a clean pattern: the extension trait.

Define a new trait in your crate with the methods you want, then provide a blanket `impl` for all types that meet the requirements. Now every `T: Ord + Clone` gains your `clamp_to` and `is_between` methods without touching the standard library.

This is how iterator adapter libraries work. It's how test utilities add methods such as `.assert_eq_sorted()` to iterators. Unlike monkey-patching, it's structured and composable: the new methods are visible only where the trait is in scope.

The Intuition

An extension trait is just a regular trait with:
1. Default method implementations derived from existing capabilities
2. A blanket `impl` that gives those methods to every qualifying type automatically

pub trait OrdExt: Ord + Sized + Clone {
 fn clamp_to(&self, lo: &Self, hi: &Self) -> Self { ... }
}
// Every T: Ord + Clone gets clamp_to for free:
impl<T: Ord + Clone> OrdExt for T {}
The blanket impl says: "for any type `T` that already knows how to compare itself (`Ord`) and duplicate itself (`Clone`), give it these extra methods." You write the methods once and they work everywhere.
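To make the "works everywhere" claim concrete, here is a minimal sketch showing a user-defined type picking up `clamp_to` with no extra code. The `Version` struct is a hypothetical example type, not part of the original lesson; it qualifies only because its derives supply `Ord` and `Clone`.

```rust
// Extension trait with a default method, mirroring the snippet above.
pub trait OrdExt: Ord + Sized + Clone {
    fn clamp_to(&self, lo: &Self, hi: &Self) -> Self {
        if self < lo {
            lo.clone()
        } else if self > hi {
            hi.clone()
        } else {
            self.clone()
        }
    }
}

// Blanket impl: every T: Ord + Clone gets the method.
impl<T: Ord + Clone> OrdExt for T {}

// Hypothetical user type. The derives provide Ord (field-by-field,
// left to right) and Clone, so the blanket impl applies to it.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct Version(u32, u32, u32);

fn main() {
    let min = Version(1, 0, 0);
    let max = Version(2, 0, 0);

    // No `impl OrdExt for Version` was written anywhere.
    let clamped = Version(3, 1, 4).clamp_to(&min, &max);
    println!("{:?}", clamped); // Version(2, 0, 0)
}
```

If `Version` dropped its `Clone` derive, the call would stop compiling, because the type would no longer satisfy the bounds on the blanket impl.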

How It Works in Rust

// Extension trait for any Ord + Clone type
pub trait OrdExt: Ord + Sized + Clone {
 fn clamp_to(&self, lo: &Self, hi: &Self) -> Self {
     if self < lo { lo.clone() }
     else if self > hi { hi.clone() }
     else { self.clone() }
 }

 fn is_between(&self, lo: &Self, hi: &Self) -> bool {
     self >= lo && self <= hi
 }
}

// Blanket impl: automatically applies to every T: Ord + Clone
impl<T: Ord + Clone> OrdExt for T {}

// Now all of these work:
15_i32.clamp_to(&0, &10);          // 10
"banana".clamp_to(&"apple", &"cherry");  // "banana"
Extension trait for iterators:
pub trait IterExt: Iterator + Sized {
 fn sorted(self) -> Vec<Self::Item> where Self::Item: Ord {
     let mut v: Vec<_> = self.collect();
     v.sort();
     v
 }

 fn join_display(self, sep: &str) -> String where Self::Item: std::fmt::Display {
     self.map(|x| x.to_string()).collect::<Vec<_>>().join(sep)
 }
}

impl<I: Iterator> IterExt for I {}  // blanket impl

// Every iterator gets these:
[3, 1, 4, 1, 5].iter().copied().sorted()   // [1, 1, 3, 4, 5]
[1, 2, 3].iter().join_display(", ")         // "1, 2, 3"
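The methods above collect eagerly into a `Vec` or `String`. Iterator adapter libraries usually return a new lazy iterator instead, using the same extension-trait shape. The following is a rough sketch of that variant; `EveryOther`, `LazyIterExt`, and `every_other` are hypothetical names invented for illustration.

```rust
/// Hypothetical lazy adapter: yields every second item of the inner
/// iterator, starting with the first.
pub struct EveryOther<I> {
    inner: I,
}

impl<I: Iterator> Iterator for EveryOther<I> {
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        // Take the item to yield; stop if the inner iterator is done.
        let item = self.inner.next()?;
        // Discard the item that follows it.
        self.inner.next();
        Some(item)
    }
}

/// Extension trait whose method returns an adapter instead of a Vec.
pub trait LazyIterExt: Iterator + Sized {
    fn every_other(self) -> EveryOther<Self> {
        EveryOther { inner: self }
    }
}

// Blanket impl: every iterator gains `every_other`.
impl<I: Iterator> LazyIterExt for I {}

fn main() {
    let picked: Vec<u32> = (1..=7).every_other().collect();
    println!("{:?}", picked); // [1, 3, 5, 7]

    // Because the adapter is lazy, it also works on unbounded iterators.
    let first_three: Vec<i32> = (0..).every_other().take(3).collect();
    println!("{:?}", first_three); // [0, 2, 4]
}
```

Returning an adapter keeps the method composable with `take`, `map`, and the rest of the iterator chain, at the cost of defining one extra struct.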
Extension trait on `str` (a foreign type):
pub trait StrExt {
 fn title_case(&self) -> String;
 fn is_palindrome(&self) -> bool;
}

impl StrExt for str {  // str is foreign, but our trait is local, so this is allowed
 fn title_case(&self) -> String { /* ... */ }
 fn is_palindrome(&self) -> bool { /* ... */ }
}

"hello world".title_case()   // "Hello World"
"racecar".is_palindrome()    // true
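One consequence of implementing the trait for `str` rather than `String` is worth spelling out: method-call auto-dereferencing makes the methods available on anything that derefs to `str`. A small sketch, reusing only the `is_palindrome` method from above (the body here is one possible implementation):

```rust
pub trait StrExt {
    fn is_palindrome(&self) -> bool;
}

impl StrExt for str {
    fn is_palindrome(&self) -> bool {
        // Keep letters and digits only, lowercased, then compare the
        // sequence against its reverse.
        let cleaned: Vec<char> = self
            .chars()
            .filter(|c| c.is_alphanumeric())
            .flat_map(|c| c.to_lowercase())
            .collect();
        cleaned.iter().eq(cleaned.iter().rev())
    }
}

fn main() {
    let owned: String = String::from("Racecar");
    let boxed: Box<str> = "level".into();

    // String and Box<str> both deref to str, so the method resolves
    // without any additional impls.
    println!("{}", owned.is_palindrome()); // true
    println!("{}", boxed.is_palindrome()); // true
    println!("{}", "hello".is_palindrome()); // false
}
```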

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Mechanism | Functor: `OrdExt(O: ORD)` produces a module with derived ops | Blanket `impl<T: Ord + Clone> OrdExt for T {}` |
| Activation | Explicit: `module IntOrdExt = OrdExt(Int)` | Automatic: import the trait, all qualifying types gain it |
| Foreign types | Functors work regardless of where `T` is from | Extension trait on foreign type works as long as your trait is local |
| Multiple instances | Multiple explicit module applications | One blanket impl, or separate impls for specific types (the two must not overlap) |
| Discoverability | Explicit module application required | Trait must be in scope (`use crate::IterExt`) |

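The "trait must be in scope" row is the part that most often surprises newcomers, so here is a minimal sketch of it. The module name `text_ext` and the trait name `WordCount` are hypothetical, chosen only for this illustration.

```rust
mod text_ext {
    /// Hypothetical extension trait living in its own module.
    pub trait WordCount {
        fn word_count(&self) -> usize;
    }

    impl WordCount for str {
        fn word_count(&self) -> usize {
            self.split_whitespace().count()
        }
    }
}

// Without this import, `"...".word_count()` below does not compile:
// the impl exists, but method syntax only sees traits that are in scope.
use text_ext::WordCount;

fn main() {
    println!("{}", "one two three".word_count()); // 3

    // The fully qualified form names the trait by path instead of
    // relying on the import.
    println!("{}", text_ext::WordCount::word_count("a b")); // 2
}
```

This scoping is what keeps extension traits from colliding the way monkey-patches do: two crates can each add a `word_count` method, and the caller picks one by importing it.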
// Extension traits: define extra methods on types you don't own, or
// provide default implementations derived from a minimal interface.
//
// OCaml achieves this with functors (OrdExt).
// Rust uses the extension-trait pattern: a new public trait with blanket impls.

use std::fmt::Display;

// ── OrdExt: derived ordering operations from a minimal Ord impl ───────────────

/// Extension trait: any `T: Ord + Clone` gains these convenience methods.
pub trait OrdExt: Ord + Sized + Clone {
    fn clamp_to(&self, lo: &Self, hi: &Self) -> Self {
        if self < lo { lo.clone() }
        else if self > hi { hi.clone() }
        else { self.clone() }
    }

    fn is_between(&self, lo: &Self, hi: &Self) -> bool {
        self >= lo && self <= hi
    }

    fn min_of(self, other: Self) -> Self {
        std::cmp::min(self, other)
    }

    fn max_of(self, other: Self) -> Self {
        std::cmp::max(self, other)
    }
}

/// Blanket impl: every type with `Ord + Clone` gets OrdExt for free.
impl<T: Ord + Clone> OrdExt for T {}

// ── IterExt: additional iterator combinators ─────────────────────────────────

pub trait IterExt: Iterator + Sized {
    /// Collect into a sorted Vec.
    fn sorted(self) -> Vec<Self::Item>
    where
        Self::Item: Ord,
    {
        let mut v: Vec<Self::Item> = self.collect();
        v.sort();
        v
    }

    /// Collect into a sorted Vec (by key).
    fn sorted_by_key_ext<K: Ord, F: FnMut(&Self::Item) -> K>(self, f: F) -> Vec<Self::Item>
    {
        let mut v: Vec<Self::Item> = self.collect();
        v.sort_by_key(f);
        v
    }

    /// Running sum (prefix scan).
    fn running_sum(self) -> Vec<i64>
    where
        Self::Item: Into<i64>,
    {
        let mut acc = 0_i64;
        self.map(|x| {
            acc += x.into();
            acc
        })
        .collect()
    }

    /// Collect into a String, joining with a separator.
    fn join_display(self, sep: &str) -> String
    where
        Self::Item: Display,
    {
        let parts: Vec<String> = self.map(|x| x.to_string()).collect();
        parts.join(sep)
    }

    /// Collect elements into chunks of `n`. Panics if `n` is 0, as `slice::chunks` does.
    fn chunks_of(self, n: usize) -> Vec<Vec<Self::Item>>
    where
        Self::Item: Clone,
    {
        let v: Vec<Self::Item> = self.collect();
        v.chunks(n).map(|c| c.to_vec()).collect()
    }

    /// Interleave with another iterator.
    fn interleave<I: Iterator<Item = Self::Item>>(self, other: I) -> Vec<Self::Item> {
        // Plain iterators suffice here: nothing is ever peeked.
        let mut a = self;
        let mut b = other;
        let mut out = Vec::new();
        loop {
            match (a.next(), b.next()) {
                (Some(x), Some(y)) => { out.push(x); out.push(y); }
                (Some(x), None)    => { out.push(x); out.extend(a); break; }
                (None,    Some(y)) => { out.push(y); out.extend(b); break; }
                (None,    None)    => break,
            }
        }
        out
    }
}

/// Blanket impl: every `Iterator` gets `IterExt`.
impl<I: Iterator> IterExt for I {}

// ── StrExt: extra string operations ──────────────────────────────────────────

pub trait StrExt {
    fn word_count(&self) -> usize;
    fn title_case(&self) -> String;
    fn is_palindrome(&self) -> bool;
    fn indent(&self, n: usize) -> String;
}

impl StrExt for str {
    fn word_count(&self) -> usize {
        self.split_whitespace().count()
    }

    fn title_case(&self) -> String {
        self.split_whitespace()
            .map(|word| {
                let mut chars = word.chars();
                match chars.next() {
                    None => String::new(),
                    Some(c) => {
                        let upper: String = c.to_uppercase().collect();
                        upper + &chars.as_str().to_lowercase()
                    }
                }
            })
            .collect::<Vec<_>>()
            .join(" ")
    }

    fn is_palindrome(&self) -> bool {
        let s: String = self.chars().filter(|c| c.is_alphanumeric()).collect();
        let s = s.to_lowercase();
        s == s.chars().rev().collect::<String>()
    }

    fn indent(&self, n: usize) -> String {
        let prefix = " ".repeat(n);
        self.lines()
            .map(|line| format!("{}{}", prefix, line))
            .collect::<Vec<_>>()
            .join("\n")
    }
}

// ── OptionExt: extra Option combinators ──────────────────────────────────────

pub trait OptionExt<T> {
    /// Filter with a predicate; None if predicate fails.
    fn filter_ext<F: FnOnce(&T) -> bool>(self, f: F) -> Option<T>;

    /// Convert None to Err.
    fn ok_or_msg(self, msg: &'static str) -> Result<T, &'static str>;

    /// Tap: peek at the value if Some, then pass it through.
    fn tap<F: FnOnce(&T)>(self, f: F) -> Option<T>;
}

impl<T> OptionExt<T> for Option<T> {
    fn filter_ext<F: FnOnce(&T) -> bool>(self, f: F) -> Option<T> {
        self.and_then(|v| if f(&v) { Some(v) } else { None })
    }

    fn ok_or_msg(self, msg: &'static str) -> Result<T, &'static str> {
        self.ok_or(msg)
    }

    fn tap<F: FnOnce(&T)>(self, f: F) -> Option<T> {
        if let Some(ref v) = self { f(v); }
        self
    }
}

fn main() {
    // OrdExt
    println!("min_of 3 5 = {}", 3_i32.min_of(5));
    println!("max_of 3 5 = {}", 3_i32.max_of(5));
    println!("clamp 0..10: 15 → {}", 15_i32.clamp_to(&0, &10));
    println!("clamp 0..10:  5 → {}", 5_i32.clamp_to(&0, &10));
    println!("between 0..10: 7 = {}", 7_i32.is_between(&0, &10));
    println!("sorted strings: {:?}",
        ["banana", "apple", "cherry"].iter().copied().sorted());

    // IterExt
    let sum: Vec<i64> = [1_i32, 2, 3, 4, 5].iter().copied().running_sum();
    println!("running sum: {:?}", sum);

    let joined = [1, 2, 3, 4].iter().join_display(", ");
    println!("joined: {}", joined);

    let chunks = (0..8_i32).chunks_of(3);
    println!("chunks(3): {:?}", chunks);

    let merged = [1, 3, 5].iter().copied()
        .interleave([2, 4, 6].iter().copied());
    println!("interleaved: {:?}", merged);

    // StrExt
    println!("word_count: {}", "hello world foo".word_count());
    println!("title_case: {}", "hello world foo".title_case());
    println!("is_palindrome 'racecar': {}", "racecar".is_palindrome());
    println!("is_palindrome 'hello': {}", "hello".is_palindrome());
    println!("is_palindrome 'A man a plan a canal Panama': {}",
        "A man a plan a canal Panama".is_palindrome());

    // OptionExt
    let v = Some(42_i32)
        .filter_ext(|&n| n > 0)
        .tap(|n| print!("tap: {} → ", n))
        .map(|n| n * 2);
    println!("result: {:?}", v);
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ord_ext_clamp() {
        assert_eq!(15_i32.clamp_to(&0, &10), 10);
        assert_eq!((-5_i32).clamp_to(&0, &10), 0);
        assert_eq!(5_i32.clamp_to(&0, &10), 5);
    }

    #[test]
    fn test_ord_ext_between() {
        assert!(5_i32.is_between(&0, &10));
        assert!(!15_i32.is_between(&0, &10));
    }

    #[test]
    fn test_iter_sorted() {
        let v = vec![3, 1, 4, 1, 5].into_iter().sorted();
        assert_eq!(v, vec![1, 1, 3, 4, 5]);
    }

    #[test]
    fn test_iter_running_sum() {
        let v = [1_i32, 2, 3, 4].iter().copied().running_sum();
        assert_eq!(v, vec![1, 3, 6, 10]);
    }

    #[test]
    fn test_str_title_case() {
        assert_eq!("hello world".title_case(), "Hello World");
        assert_eq!("the quick BROWN fox".title_case(), "The Quick Brown Fox");
    }

    #[test]
    fn test_str_palindrome() {
        assert!("racecar".is_palindrome());
        assert!("A man a plan a canal Panama".is_palindrome());
        assert!(!"hello".is_palindrome());
    }

    #[test]
    fn test_option_ext_filter() {
        assert_eq!(Some(5_i32).filter_ext(|&n| n > 3), Some(5));
        assert_eq!(Some(2_i32).filter_ext(|&n| n > 3), None);
    }

    #[test]
    fn test_option_ext_tap() {
        let mut saw = 0_i32;
        let _ = Some(42_i32).tap(|&n| saw = n);
        assert_eq!(saw, 42);
    }
}
(* Extension pattern: given a minimal interface, derive many extra functions.
   Like Rust's extension traits, but via OCaml functors. *)

module type ORD = sig
  type t
  val compare : t -> t -> int
end

(* Extend any ORD with derived operations *)
module OrdExt (O : ORD) = struct
  include O
  let ( <  ) a b = compare a b < 0
  let ( >  ) a b = compare a b > 0
  let ( <= ) a b = compare a b <= 0
  let ( >= ) a b = compare a b >= 0
  let ( =  ) a b = compare a b = 0
  let min a b = if a < b then a else b
  let max a b = if a > b then a else b
  let clamp ~lo ~hi x = max lo (min hi x)
  let between ~lo ~hi x = lo <= x && x <= hi
  let sort lst = List.sort compare lst
end

module IntOrd = OrdExt (struct
  type t = int
  let compare = Int.compare
end)

module StringOrd = OrdExt (struct
  type t = string
  let compare = String.compare
end)

let () =
  let open IntOrd in
  Printf.printf "min 3 5 = %d\n" (min 3 5);
  Printf.printf "max 3 5 = %d\n" (max 3 5);
  Printf.printf "clamp 0 10 15 = %d\n" (clamp ~lo:0 ~hi:10 15);
  Printf.printf "between 0 10 7 = %b\n" (between ~lo:0 ~hi:10 7);
  Printf.printf "sort [3;1;4;1;5] = %s\n"
    (sort [3;1;4;1;5] |> List.map string_of_int |> String.concat ",");

  let open StringOrd in
  Printf.printf "sort strings: %s\n"
    (sort ["banana";"apple";"cherry"] |> String.concat ",")

📊 Detailed Comparison

Comparison: Extension Traits

OCaml

(* Module-level functions: no method syntax *)
module StringExt = struct
  let is_blank s = String.trim s = ""
  let truncate n s =
    if String.length s <= n then s
    else String.sub s 0 n ^ "..."
end

let _ = StringExt.is_blank "   "

Rust

// Extension trait: method syntax on external types!
trait StringExt {
 fn is_blank(&self) -> bool;
 fn truncate_with_ellipsis(&self, max_len: usize) -> String;
}

impl StringExt for str {
 fn is_blank(&self) -> bool { self.trim().is_empty() }
 fn truncate_with_ellipsis(&self, n: usize) -> String { /* ... */ }
}

// Natural method call
"   ".is_blank()  // true
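
The body of `truncate_with_ellipsis` is elided above. As a sketch only, one plausible implementation follows the OCaml `truncate` (keep the first `n` units, then append `"..."`), with the assumption that `max_len` counts characters rather than bytes so that slicing never lands inside a multi-byte character.

```rust
trait StringExt {
    fn is_blank(&self) -> bool;
    fn truncate_with_ellipsis(&self, max_len: usize) -> String;
}

impl StringExt for str {
    fn is_blank(&self) -> bool {
        self.trim().is_empty()
    }

    // Assumed behavior: keep at most `max_len` characters and append
    // "..." only when something was actually cut off.
    fn truncate_with_ellipsis(&self, max_len: usize) -> String {
        match self.char_indices().nth(max_len) {
            // A character exists at position `max_len`, so the text is
            // too long; `byte_idx` is a valid boundary to cut at.
            Some((byte_idx, _)) => format!("{}...", &self[..byte_idx]),
            // The text has `max_len` characters or fewer: keep it whole.
            None => self.to_string(),
        }
    }
}

fn main() {
    println!("{}", "   ".is_blank()); // true
    println!("{}", "hello world".truncate_with_ellipsis(5)); // hello...
    println!("{}", "hi".truncate_with_ellipsis(5)); // hi
}
```

Implementing this on `str` means a `String` value can call it too, and the differently named method avoids any confusion with the inherent, in-place `String::truncate`.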