diff --git a/Cargo.lock b/Cargo.lock index a94108f..5c6bcd6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,11 +2,17 @@ name = "rust_hanabi" version = "0.1.0" dependencies = [ + "crossbeam 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)", "getopts 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", "rand 0.3.14 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "crossbeam" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "getopts" version = "0.2.14" diff --git a/Cargo.toml b/Cargo.toml index 2768528..3331f4c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,3 +7,4 @@ authors = ["Jeff Wu "] rand = "*" log = "*" getopts = "*" +crossbeam = "0.2.5" diff --git a/src/main.rs b/src/main.rs index a99a20d..678f32c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ extern crate getopts; #[macro_use] extern crate log; extern crate rand; +extern crate crossbeam; mod game; mod simulator; @@ -40,6 +41,7 @@ fn main() { let mut opts = Options::new(); opts.optopt("l", "loglevel", "Log level, one of 'trace', 'debug', 'info', 'warn', and 'error'", "LOGLEVEL"); opts.optopt("n", "ntrials", "Number of games to simulate", "NTRIALS"); + opts.optopt("t", "nthreads", "Number of threads to use for simulation", "NTHREADS"); opts.optopt("s", "seed", "Seed for PRNG (can only be used with n=1)", "SEED"); opts.optflag("h", "help", "Print this help menu"); let matches = match opts.parse(&args[1..]) { @@ -75,7 +77,8 @@ fn main() { let seed = matches.opt_str("s").map(|seed_str| { u32::from_str(&seed_str).unwrap() }); - // TODO: make these configurable + let n_threads = u32::from_str(&matches.opt_str("t").unwrap_or("1".to_string())).unwrap(); + let opts = game::GameOptions { num_players: 5, hand_size: 4, @@ -84,10 +87,10 @@ fn main() { }; // TODO: make this configurable - let strategy_config = strategies::examples::RandomStrategyConfig { - hint_probability: 0.4, - play_probability: 0.2, - }; - // let strategy_config = strategies::cheating::CheatingStrategyConfig::new(); - simulator::simulate(&opts, &strategy_config, seed, n); + // let strategy_config = strategies::examples::RandomStrategyConfig { + // hint_probability: 0.4, + // play_probability: 0.2, + // }; + let strategy_config = strategies::cheating::CheatingStrategyConfig::new(); + simulator::simulate(&opts, &strategy_config, seed, n, n_threads); } diff --git a/src/simulator.rs b/src/simulator.rs index 89862dd..c45982f 100644 --- a/src/simulator.rs +++ b/src/simulator.rs @@ -1,6 +1,8 @@ use rand::{self, Rng}; use game::*; use std::collections::HashMap; +use std::fmt; +use crossbeam; // Traits to implement for any valid Hanabi strategy @@ -68,40 +70,105 @@ pub fn simulate_once( score } +struct Histogram { + pub hist: HashMap, + pub sum: Score, + pub total_count: usize, +} +impl Histogram { + pub fn new() -> Histogram { + Histogram { + hist: HashMap::new(), + sum: 0, + total_count: 0, + } + } + fn insert_many(&mut self, val: Score, count: usize) { + let new_count = self.get_count(&val) + count; + self.hist.insert(val, new_count); + self.sum += val * (count as u32); + self.total_count += count; + } + pub fn insert(&mut self, val: Score) { + self.insert_many(val, 1); + } + pub fn get_count(&self, val: &Score) -> usize { + *self.hist.get(&val).unwrap_or(&0) + } + pub fn average(&self) -> f32 { + (self.sum as f32) / (self.total_count as f32) + } + pub fn merge(&mut self, other: Histogram) { + for (val, count) in other.hist.iter() { + self.insert_many(*val, *count); + } + } +} +impl fmt::Display for Histogram { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut keys = self.hist.keys().collect::>(); + keys.sort(); + for val in keys { + try!(f.write_str(&format!( + "{}: {}\n", val, self.get_count(val), + ))); + } + Ok(()) + } +} + // TODO: multithreaded -pub fn simulate( +pub fn simulate( opts: &GameOptions, - strat_config: &GameStrategyConfig, + strat_config: &T, first_seed_opt: Option, n_trials: u32, - ) -> f32 { - - let mut total_score = 0; - let mut non_perfect_seeds = Vec::new(); + n_threads: u32, + ) -> f32 where T: GameStrategyConfig + Sync { let first_seed = first_seed_opt.unwrap_or(rand::thread_rng().next_u32()); - info!("Initial seed: {}\n", first_seed); - let mut histogram = HashMap::::new(); - for i in 0..n_trials { - if (i > 0) && (i % 1000 == 0) { - let average: f32 = (total_score as f32) / (i as f32); - info!("Trials: {}, Average so far: {}", i, average); - } - let seed = first_seed + i; - let score = simulate_once(&opts, strat_config.initialize(&opts), Some(seed)); - let count = histogram.get(&score).unwrap_or(&0) + 1; - histogram.insert(score, count); - if score != 25 { - non_perfect_seeds.push((score, seed)); - } - total_score += score; - } + crossbeam::scope(|scope| { + let mut join_handles = Vec::new(); + for i in 0..n_threads { + let start = first_seed + ((n_trials * i) / n_threads); + let end = first_seed + ((n_trials * (i+1)) / n_threads); + join_handles.push(scope.spawn(move || { + info!("Thread {} spawned: seeds {} to {}", i, start, end); + let mut non_perfect_seeds = Vec::new(); - non_perfect_seeds.sort(); - info!("Score histogram: {:?}", histogram); - info!("Seeds with non-perfect score: {:?}", non_perfect_seeds); - let average: f32 = (total_score as f32) / (n_trials as f32); - info!("Average score: {:?}", average); - average + let mut histogram = Histogram::new(); + + for seed in start..end { + if (seed > start) && ((seed-start) % 1000 == 0) { + info!( + "Thread {}, Trials: {}, Average so far: {}", + i, seed-start, histogram.average() + ); + } + let score = simulate_once(&opts, strat_config.initialize(&opts), Some(seed)); + histogram.insert(score); + if score != 25 { non_perfect_seeds.push((score, seed)); } + } + info!("Thread {} done", i); + (non_perfect_seeds, histogram) + })); + } + + let mut non_perfect_seeds : Vec<(Score,u32)> = Vec::new(); + let mut histogram = Histogram::new(); + for join_handle in join_handles { + let (thread_non_perfect_seeds, thread_histogram) = join_handle.join(); + info!("Thread joined"); + non_perfect_seeds.extend(thread_non_perfect_seeds.iter()); + histogram.merge(thread_histogram); + } + + non_perfect_seeds.sort(); + info!("Seeds with non-perfect score: {:?}", non_perfect_seeds); + info!("Score histogram:\n{}", histogram); + let average = histogram.average(); + info!("Average score: {:?}", average); + average + }) }