add nthreads option, histogram
This commit is contained in:
parent
706f5b52b6
commit
e371e2f112
4 changed files with 112 additions and 35 deletions
6
Cargo.lock
generated
6
Cargo.lock
generated
|
@ -2,11 +2,17 @@
|
|||
name = "rust_hanabi"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"crossbeam 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"getopts 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"log 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand 0.3.14 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam"
|
||||
version = "0.2.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "getopts"
|
||||
version = "0.2.14"
|
||||
|
|
|
@ -7,3 +7,4 @@ authors = ["Jeff Wu <wuthefwasthat@gmail.com>"]
|
|||
rand = "*"
|
||||
log = "*"
|
||||
getopts = "*"
|
||||
crossbeam = "0.2.5"
|
||||
|
|
17
src/main.rs
17
src/main.rs
|
@ -2,6 +2,7 @@ extern crate getopts;
|
|||
#[macro_use]
|
||||
extern crate log;
|
||||
extern crate rand;
|
||||
extern crate crossbeam;
|
||||
|
||||
mod game;
|
||||
mod simulator;
|
||||
|
@ -40,6 +41,7 @@ fn main() {
|
|||
let mut opts = Options::new();
|
||||
opts.optopt("l", "loglevel", "Log level, one of 'trace', 'debug', 'info', 'warn', and 'error'", "LOGLEVEL");
|
||||
opts.optopt("n", "ntrials", "Number of games to simulate", "NTRIALS");
|
||||
opts.optopt("t", "nthreads", "Number of threads to use for simulation", "NTHREADS");
|
||||
opts.optopt("s", "seed", "Seed for PRNG (can only be used with n=1)", "SEED");
|
||||
opts.optflag("h", "help", "Print this help menu");
|
||||
let matches = match opts.parse(&args[1..]) {
|
||||
|
@ -75,7 +77,8 @@ fn main() {
|
|||
|
||||
let seed = matches.opt_str("s").map(|seed_str| { u32::from_str(&seed_str).unwrap() });
|
||||
|
||||
// TODO: make these configurable
|
||||
let n_threads = u32::from_str(&matches.opt_str("t").unwrap_or("1".to_string())).unwrap();
|
||||
|
||||
let opts = game::GameOptions {
|
||||
num_players: 5,
|
||||
hand_size: 4,
|
||||
|
@ -84,10 +87,10 @@ fn main() {
|
|||
};
|
||||
|
||||
// TODO: make this configurable
|
||||
let strategy_config = strategies::examples::RandomStrategyConfig {
|
||||
hint_probability: 0.4,
|
||||
play_probability: 0.2,
|
||||
};
|
||||
// let strategy_config = strategies::cheating::CheatingStrategyConfig::new();
|
||||
simulator::simulate(&opts, &strategy_config, seed, n);
|
||||
// let strategy_config = strategies::examples::RandomStrategyConfig {
|
||||
// hint_probability: 0.4,
|
||||
// play_probability: 0.2,
|
||||
// };
|
||||
let strategy_config = strategies::cheating::CheatingStrategyConfig::new();
|
||||
simulator::simulate(&opts, &strategy_config, seed, n, n_threads);
|
||||
}
|
||||
|
|
123
src/simulator.rs
123
src/simulator.rs
|
@ -1,6 +1,8 @@
|
|||
use rand::{self, Rng};
|
||||
use game::*;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use crossbeam;
|
||||
|
||||
// Traits to implement for any valid Hanabi strategy
|
||||
|
||||
|
@ -68,40 +70,105 @@ pub fn simulate_once(
|
|||
score
|
||||
}
|
||||
|
||||
struct Histogram {
|
||||
pub hist: HashMap<Score, usize>,
|
||||
pub sum: Score,
|
||||
pub total_count: usize,
|
||||
}
|
||||
impl Histogram {
|
||||
pub fn new() -> Histogram {
|
||||
Histogram {
|
||||
hist: HashMap::new(),
|
||||
sum: 0,
|
||||
total_count: 0,
|
||||
}
|
||||
}
|
||||
fn insert_many(&mut self, val: Score, count: usize) {
|
||||
let new_count = self.get_count(&val) + count;
|
||||
self.hist.insert(val, new_count);
|
||||
self.sum += val * (count as u32);
|
||||
self.total_count += count;
|
||||
}
|
||||
pub fn insert(&mut self, val: Score) {
|
||||
self.insert_many(val, 1);
|
||||
}
|
||||
pub fn get_count(&self, val: &Score) -> usize {
|
||||
*self.hist.get(&val).unwrap_or(&0)
|
||||
}
|
||||
pub fn average(&self) -> f32 {
|
||||
(self.sum as f32) / (self.total_count as f32)
|
||||
}
|
||||
pub fn merge(&mut self, other: Histogram) {
|
||||
for (val, count) in other.hist.iter() {
|
||||
self.insert_many(*val, *count);
|
||||
}
|
||||
}
|
||||
}
|
||||
impl fmt::Display for Histogram {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
let mut keys = self.hist.keys().collect::<Vec<_>>();
|
||||
keys.sort();
|
||||
for val in keys {
|
||||
try!(f.write_str(&format!(
|
||||
"{}: {}\n", val, self.get_count(val),
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: multithreaded
|
||||
pub fn simulate(
|
||||
pub fn simulate<T>(
|
||||
opts: &GameOptions,
|
||||
strat_config: &GameStrategyConfig,
|
||||
strat_config: &T,
|
||||
first_seed_opt: Option<u32>,
|
||||
n_trials: u32,
|
||||
) -> f32 {
|
||||
|
||||
let mut total_score = 0;
|
||||
let mut non_perfect_seeds = Vec::new();
|
||||
n_threads: u32,
|
||||
) -> f32 where T: GameStrategyConfig + Sync {
|
||||
|
||||
let first_seed = first_seed_opt.unwrap_or(rand::thread_rng().next_u32());
|
||||
info!("Initial seed: {}\n", first_seed);
|
||||
let mut histogram = HashMap::<Score, usize>::new();
|
||||
|
||||
for i in 0..n_trials {
|
||||
if (i > 0) && (i % 1000 == 0) {
|
||||
let average: f32 = (total_score as f32) / (i as f32);
|
||||
info!("Trials: {}, Average so far: {}", i, average);
|
||||
}
|
||||
let seed = first_seed + i;
|
||||
let score = simulate_once(&opts, strat_config.initialize(&opts), Some(seed));
|
||||
let count = histogram.get(&score).unwrap_or(&0) + 1;
|
||||
histogram.insert(score, count);
|
||||
if score != 25 {
|
||||
non_perfect_seeds.push((score, seed));
|
||||
}
|
||||
total_score += score;
|
||||
}
|
||||
crossbeam::scope(|scope| {
|
||||
let mut join_handles = Vec::new();
|
||||
for i in 0..n_threads {
|
||||
let start = first_seed + ((n_trials * i) / n_threads);
|
||||
let end = first_seed + ((n_trials * (i+1)) / n_threads);
|
||||
join_handles.push(scope.spawn(move || {
|
||||
info!("Thread {} spawned: seeds {} to {}", i, start, end);
|
||||
let mut non_perfect_seeds = Vec::new();
|
||||
|
||||
non_perfect_seeds.sort();
|
||||
info!("Score histogram: {:?}", histogram);
|
||||
info!("Seeds with non-perfect score: {:?}", non_perfect_seeds);
|
||||
let average: f32 = (total_score as f32) / (n_trials as f32);
|
||||
info!("Average score: {:?}", average);
|
||||
average
|
||||
let mut histogram = Histogram::new();
|
||||
|
||||
for seed in start..end {
|
||||
if (seed > start) && ((seed-start) % 1000 == 0) {
|
||||
info!(
|
||||
"Thread {}, Trials: {}, Average so far: {}",
|
||||
i, seed-start, histogram.average()
|
||||
);
|
||||
}
|
||||
let score = simulate_once(&opts, strat_config.initialize(&opts), Some(seed));
|
||||
histogram.insert(score);
|
||||
if score != 25 { non_perfect_seeds.push((score, seed)); }
|
||||
}
|
||||
info!("Thread {} done", i);
|
||||
(non_perfect_seeds, histogram)
|
||||
}));
|
||||
}
|
||||
|
||||
let mut non_perfect_seeds : Vec<(Score,u32)> = Vec::new();
|
||||
let mut histogram = Histogram::new();
|
||||
for join_handle in join_handles {
|
||||
let (thread_non_perfect_seeds, thread_histogram) = join_handle.join();
|
||||
info!("Thread joined");
|
||||
non_perfect_seeds.extend(thread_non_perfect_seeds.iter());
|
||||
histogram.merge(thread_histogram);
|
||||
}
|
||||
|
||||
non_perfect_seeds.sort();
|
||||
info!("Seeds with non-perfect score: {:?}", non_perfect_seeds);
|
||||
info!("Score histogram:\n{}", histogram);
|
||||
let average = histogram.average();
|
||||
info!("Average score: {:?}", average);
|
||||
average
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue