// 078: Where Clauses
// Complex trait bounds using where syntax
use std::fmt::{Display, Write};
use std::ops::{Add, Mul};
// Approach 1: Where clause for readability
/// Renders `a` and `b` as `"a == b"` when they compare equal, `"a != b"` otherwise.
///
/// The bounds live in a `where` clause for readability. `?Sized` lets callers
/// pass unsized values such as `str` directly (`print_if_equal("x", "y")`),
/// not just sized types like integers or `String`.
fn print_if_equal<T>(a: &T, b: &T) -> String
where
    T: Display + PartialEq + ?Sized,
{
    // Only the operator differs between the two outcomes.
    let op = if a == b { "==" } else { "!=" };
    format!("{} {} {}", a, op, b)
}
// Approach 2: Multiple type params with where
/// Combines two slices element-wise with `f`, stopping at the end of the
/// shorter slice.
///
/// Each element is cloned out of its borrowed slice so `f` can take it by value.
fn zip_with<A, B, C, F>(a: &[A], b: &[B], f: F) -> Vec<C>
where
    A: Clone,
    B: Clone,
    F: Fn(A, B) -> C,
{
    let paired_len = a.len().min(b.len());
    let mut combined = Vec::with_capacity(paired_len);
    for (left, right) in a.iter().zip(b.iter()) {
        combined.push(f(left.clone(), right.clone()));
    }
    combined
}
// Approach 3: Bounds on an associated type (I::Item) in a where clause
/// Adds up every item yielded by `iter`, starting from `I::Item::default()`.
///
/// The bounds are placed on the associated type `I::Item`. An empty iterator
/// returns the default value, e.g. `0` for integers.
fn sum_items<I>(iter: I) -> I::Item
where
    I: Iterator,
    I::Item: Add<Output = I::Item> + Default,
{
    let mut total = I::Item::default();
    for item in iter {
        total = total + item;
    }
    total
}
/// Computes the sum of pairwise products `a[i] * b[i]`, starting from `T::default()`.
///
/// Pairing follows `Iterator::zip`, so when the slices differ in length the
/// extra tail of the longer one is ignored.
fn dot_product<T>(a: &[T], b: &[T]) -> T
where
    T: Add<Output = T> + Mul<Output = T> + Default + Copy,
{
    let mut acc = T::default();
    for (&x, &y) in a.iter().zip(b) {
        acc = acc + x * y;
    }
    acc
}
// Complex: display collection of displayable items
/// Renders every item of `iter` via `Display` as `"[a, b, c]"`.
///
/// An empty input renders as `"[]"`. Items are written straight into a single
/// output buffer. This avoids the per-item `String` and the intermediate
/// `Vec<String>` that `format!` + `join` would allocate.
///
/// # Panics
///
/// Panics if an item's `Display` implementation returns an error. `format!`
/// does the same.
fn display_collection<I>(iter: I) -> String
where
    I: IntoIterator,
    I::Item: Display,
{
    let mut out = String::from("[");
    for (i, item) in iter.into_iter().enumerate() {
        if i > 0 {
            out.push_str(", ");
        }
        // Writing into a String never fails by itself; only a broken Display impl can.
        write!(out, "{}", item).expect("a Display implementation returned an error");
    }
    out.push(']');
    out
}
/// Demo entry point: runs each where-clause helper on small integer inputs
/// and prints the results.
fn main() {
    // Approach 1: one equal pair, one unequal pair.
    for (lhs, rhs) in &[(5, 5), (3, 4)] {
        println!("{}", print_if_equal(lhs, rhs));
    }

    // Approach 2: element-wise combination of two slices.
    let sums = zip_with(&[1, 2, 3], &[4, 5, 6], |a, b| a + b);
    println!("zip_with: {:?}", sums);

    // Approach 3: several arithmetic bounds on a single type.
    let dot = dot_product(&[1, 2, 3], &[4, 5, 6]);
    println!("dot: {}", dot);

    // Bounds on IntoIterator's associated Item type.
    let rendered = display_collection(vec![1, 2, 3]);
    println!("{}", rendered);
}
/// Unit tests for the where-clause helpers, including empty-input and
/// length-mismatch edge cases.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_print_if_equal() {
        assert_eq!(print_if_equal(&5, &5), "5 == 5");
        assert_eq!(print_if_equal(&3, &4), "3 != 4");
        // Any Display + PartialEq type works, not just integers.
        assert_eq!(
            print_if_equal(&String::from("x"), &String::from("y")),
            "x != y"
        );
    }

    #[test]
    fn test_zip_with() {
        assert_eq!(zip_with(&[1, 2, 3], &[4, 5, 6], |a, b| a + b), vec![5, 7, 9]);
        assert_eq!(zip_with(&[1, 2], &[3, 4], |a, b| a * b), vec![3, 8]);
    }

    #[test]
    fn test_zip_with_stops_at_shorter_input() {
        assert_eq!(zip_with(&[1, 2, 3], &[10], |a, b| a + b), vec![11]);
        assert!(zip_with::<i32, i32, i32, _>(&[], &[1, 2], |a, b| a + b).is_empty());
    }

    #[test]
    fn test_sum_items() {
        assert_eq!(sum_items(vec![1, 2, 3, 4, 5].into_iter()), 15);
    }

    #[test]
    fn test_sum_items_empty_is_default() {
        assert_eq!(sum_items(std::iter::empty::<i32>()), 0);
    }

    #[test]
    fn test_dot_product() {
        assert_eq!(dot_product(&[1, 2, 3], &[4, 5, 6]), 32);
    }

    #[test]
    fn test_dot_product_empty_is_default() {
        assert_eq!(dot_product::<i32>(&[], &[]), 0);
    }

    #[test]
    fn test_display_collection() {
        assert_eq!(display_collection(vec![1, 2, 3]), "[1, 2, 3]");
    }

    #[test]
    fn test_display_collection_edge_cases() {
        assert_eq!(display_collection(Vec::<i32>::new()), "[]");
        assert_eq!(display_collection(vec!["a", "b"]), "[a, b]");
    }
}
(* 078: Where Clauses — OCaml functor constraints *)
(* OCaml uses module signatures as the equivalent of Rust's where clauses *)
module type SUMMABLE = sig
  (** The abstract carrier type. *)
  type t

  (** Binary addition on [t]. *)
  val add : t -> t -> t

  (** Distinguished value named [zero]; presumably the additive identity. *)
  val zero : t

  (** Converts a value to a printable string. *)
  val to_string : t -> string
end
module type MULTIPLIABLE = sig
  (** The abstract carrier type. *)
  type t

  (** Binary multiplication on [t]. *)
  val mul : t -> t -> t

  (** Distinguished value named [one]; presumably the multiplicative identity. *)
  val one : t
end
(* Functor whose single parameter signature bundles several requirements (zero, add, to_string), like `T: A + B + C` *)
module MathOps (S : SUMMABLE) = struct
  (* Left fold over [xs] with S.add, seeded with S.zero. *)
  let sum xs = List.fold_left (fun acc x -> S.add acc x) S.zero xs

  (* Total of [xs], rendered with S.to_string. *)
  let sum_to_string xs = xs |> sum |> S.to_string
end
(* MathOps instantiated for native ints. *)
module IntSum = MathOps (struct
  type t = int
  let zero = 0
  let add x y = x + y
  let to_string n = string_of_int n
end)
(* MathOps instantiated for floats; to_string uses string_of_float. *)
module FloatSum = MathOps (struct
  type t = float
  let zero = 0.0
  let add x y = x +. y
  let to_string f = string_of_float f
end)
(* Complex constraint: both summable and multipliable *)
module type RING = sig
  (* Bring in the abstract carrier [t] together with [one] and [mul]... *)
  include MULTIPLIABLE

  (* ...then [zero], [add] and [to_string]. The destructive substitution [:=]
     replaces SUMMABLE's [t] with the one already declared, so both halves
     share a single type. *)
  include SUMMABLE with type t := t
end
module RingOps (R : RING) = struct
  (* Sum of pairwise products of [xs] and [ys], seeded with R.zero.
     List.fold_left2 raises Invalid_argument when the lists differ in length. *)
  let dot_product xs ys =
    let accumulate acc x y = R.add acc (R.mul x y) in
    List.fold_left2 accumulate R.zero xs ys
end
(* RingOps instantiated for native ints. *)
module IntRing = RingOps (struct
  type t = int
  let zero = 0
  let one = 1
  let add x y = x + y
  let mul x y = x * y
  let to_string n = string_of_int n
end)
(* Tests *)
(* Runtime checks: [assert] aborts on the first failure; otherwise a success
   line is printed. *)
let () =
  assert (IntSum.sum [1; 2; 3; 4; 5] = 15);
  assert (IntSum.sum_to_string [1; 2; 3] = "6");
  assert (abs_float (FloatSum.sum [1.0; 2.0; 3.0] -. 6.0) < 0.001);
  assert (IntRing.dot_product [1; 2; 3] [4; 5; 6] = 32);
  (* Edge cases: empty inputs fall back to the zero element. *)
  assert (IntSum.sum [] = 0);
  assert (IntRing.dot_product [] [] = 0);
  (* List.fold_left2 rejects lists of different lengths. *)
  assert (
    try ignore (IntRing.dot_product [1; 2] [3]); false
    with Invalid_argument _ -> true);
  Printf.printf "✓ All tests passed\n"