// 466. Concurrent HashMap — sharded locking
use std::collections::HashMap;
use std::hash::{Hash,Hasher,DefaultHasher};
use std::sync::{Arc,RwLock};
use std::thread;
/// A concurrent hash map that spreads its entries over `n` independently
/// locked shards, so threads working on keys in different shards do not
/// contend for the same lock.
///
/// Each shard is a [`HashMap`] guarded by its own [`RwLock`]; a key always
/// lives in the shard chosen by hashing it with [`DefaultHasher`].
pub struct ShardedMap<K, V> {
    shards: Vec<RwLock<HashMap<K, V>>>,
    /// Number of shards; equal to `shards.len()` and never zero.
    n: usize,
}

impl<K: Hash + Eq, V> ShardedMap<K, V> {
    /// Creates an empty map with `n` shards.
    ///
    /// # Panics
    ///
    /// Panics if `n == 0`, since every key must map to some shard.
    pub fn new(n: usize) -> Self {
        assert!(n > 0, "ShardedMap requires at least one shard");
        let shards = (0..n).map(|_| RwLock::new(HashMap::new())).collect();
        ShardedMap { shards, n }
    }

    /// Index of the shard that owns `k`.
    ///
    /// Generic over borrowed forms so that `get` hashes e.g. a `&str` exactly
    /// as `insert` hashed the owning `String` (guaranteed by `Borrow`'s contract).
    fn idx<Q: Hash + ?Sized>(&self, k: &Q) -> usize {
        let mut hasher = DefaultHasher::new();
        k.hash(&mut hasher);
        // Reduce in u64 first: the remainder is < n, so narrowing to usize is
        // lossless, and the full 64-bit hash participates even on 32-bit targets.
        (hasher.finish() % self.n as u64) as usize
    }

    /// Inserts `v` under `k`, replacing any existing value.
    ///
    /// Only the shard owning `k` is write-locked.
    ///
    /// # Panics
    ///
    /// Panics if that shard's lock was poisoned by a thread that panicked
    /// while holding it.
    pub fn insert(&self, k: K, v: V) {
        let shard = &self.shards[self.idx(&k)];
        shard
            .write()
            .expect("ShardedMap shard lock poisoned")
            .insert(k, v);
    }

    /// Returns a clone of the value stored under `k`, or `None` if absent.
    ///
    /// Accepts any borrowed form of the key (e.g. `&str` for `String` keys).
    /// Only the owning shard is read-locked.
    ///
    /// # Panics
    ///
    /// Panics if the owning shard's lock is poisoned.
    pub fn get<Q>(&self, k: &Q) -> Option<V>
    where
        K: std::borrow::Borrow<Q>,
        Q: Hash + Eq + ?Sized,
        V: Clone,
    {
        self.shards[self.idx(k)]
            .read()
            .expect("ShardedMap shard lock poisoned")
            .get(k)
            .cloned()
    }

    /// Total number of entries across all shards.
    ///
    /// Shards are read-locked one at a time, so if other threads are writing
    /// concurrently the result is not an atomic snapshot of the whole map.
    pub fn len(&self) -> usize {
        self.shards
            .iter()
            .map(|s| s.read().expect("ShardedMap shard lock poisoned").len())
            .sum()
    }

    /// Returns `true` if no shard holds any entry.
    ///
    /// Same locking caveat as [`len`](Self::len): shards are inspected one at a time.
    pub fn is_empty(&self) -> bool {
        self.shards
            .iter()
            .all(|s| s.read().expect("ShardedMap shard lock poisoned").is_empty())
    }
}
/// Demo: four threads each insert 25 distinct keys into a shared 16-shard map,
/// then the main thread reports the total entry count and one sample lookup.
fn main() {
    let map: Arc<ShardedMap<String, u64>> = Arc::new(ShardedMap::new(16));

    let mut workers = Vec::new();
    for id in 0..4u64 {
        let shared = Arc::clone(&map);
        workers.push(thread::spawn(move || {
            (0..25u64).for_each(|i| shared.insert(format!("k{}-{}", id, i), id * 25 + i));
        }));
    }
    // Wait for every writer before reading, so the count below is final.
    workers.into_iter().for_each(|w| w.join().unwrap());

    println!("entries: {} (expected 100)", map.len());
    let probe = String::from("k2-10");
    println!("k2-10 = {:?}", map.get(&probe));
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 100 scoped threads insert disjoint keys; every entry must be present
    /// afterwards with the value its thread wrote.
    #[test]
    fn test_concurrent() {
        // `thread::scope` lets the workers borrow `m` directly; no `Arc` needed.
        let m = ShardedMap::<u32, u32>::new(8);
        thread::scope(|s| {
            for i in 0..100u32 {
                let m = &m;
                s.spawn(move || m.insert(i, i * 2));
            }
        });
        assert_eq!(m.len(), 100);
        for i in 0..100u32 {
            assert_eq!(m.get(&i), Some(i * 2), "key {i}");
        }
    }

    /// A stored value is returned; an absent key yields `None`.
    #[test]
    fn test_get() {
        let m = ShardedMap::<String, i32>::new(4);
        m.insert("hi".to_string(), 42);
        assert_eq!(m.get(&"hi".to_string()), Some(42));
        assert_eq!(m.get(&"bye".to_string()), None);
    }

    /// Re-inserting a key replaces its value without adding an entry.
    #[test]
    fn test_overwrite() {
        let m = ShardedMap::<u32, u32>::new(2);
        m.insert(1, 10);
        m.insert(1, 20);
        assert_eq!(m.get(&1), Some(20));
        assert_eq!(m.len(), 1);
    }
}
(* 466. Sharded concurrent HashMap — OCaml *)
(* Number of independent shards, each guarded by its own mutex. *)
let n = 16

(* One (table, mutex) pair per shard; a table is only touched while its
   paired mutex is held. *)
let shards =
  Array.init n (fun _ ->
      let tbl = Hashtbl.create 8 in
      let lock = Mutex.create () in
      (tbl, lock))

(* Index of the shard owning key [k]. Hashtbl.hash is documented to be
   non-negative, so the result lies in [0, n). *)
let shard k = Hashtbl.hash k mod n
(* Bind [k] to [v] in its shard, replacing any previous binding, while holding
   that shard's mutex. Fun.protect releases the mutex even if Hashtbl.replace
   raises (e.g. if comparing keys raises), so the shard cannot stay locked. *)
let insert k v =
  let (t, m) = shards.(shard k) in
  Mutex.lock m;
  Fun.protect ~finally:(fun () -> Mutex.unlock m) (fun () -> Hashtbl.replace t k v)
(* Look up [k] in its shard under that shard's mutex; [None] if unbound.
   Fun.protect releases the mutex even if the lookup raises. *)
let find k =
  let (t, m) = shards.(shard k) in
  Mutex.lock m;
  Fun.protect ~finally:(fun () -> Mutex.unlock m) (fun () -> Hashtbl.find_opt t k)
(* Total number of bindings across all shards. Each shard is locked only
   while its own count is read, so the sum is not an atomic snapshot if
   writers are running concurrently. *)
let len () =
  let total = ref 0 in
  Array.iter
    (fun (tbl, lock) ->
      Mutex.lock lock;
      let c = Hashtbl.length tbl in
      Mutex.unlock lock;
      total := !total + c)
    shards;
  !total
(* Demo: four threads each insert 25 distinct keys, then the main thread
   prints the total count and one sample lookup. *)
let () =
  let worker id =
    for i = 0 to 24 do
      insert (Printf.sprintf "k%d-%d" id i) (id * 25 + i)
    done
  in
  let threads = Array.init 4 (fun id -> Thread.create worker id) in
  (* Join every writer before reading so the count is final. *)
  Array.iter Thread.join threads;
  Printf.printf "total=%d\n" (len ());
  let shown =
    match find "k2-10" with
    | Some v -> string_of_int v
    | None -> "?"
  in
  Printf.printf "k2-10=%s\n" shown