add nthreads option, histogram
This commit is contained in:
parent
706f5b52b6
commit
e371e2f112
4 changed files with 112 additions and 35 deletions
6
Cargo.lock
generated
6
Cargo.lock
generated
|
@ -2,11 +2,17 @@
|
||||||
name = "rust_hanabi"
|
name = "rust_hanabi"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"crossbeam 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"getopts 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)",
|
"getopts 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"log 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)",
|
"log 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"rand 0.3.14 (registry+https://github.com/rust-lang/crates.io-index)",
|
"rand 0.3.14 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam"
|
||||||
|
version = "0.2.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "getopts"
|
name = "getopts"
|
||||||
version = "0.2.14"
|
version = "0.2.14"
|
||||||
|
|
|
@ -7,3 +7,4 @@ authors = ["Jeff Wu <wuthefwasthat@gmail.com>"]
|
||||||
rand = "*"
|
rand = "*"
|
||||||
log = "*"
|
log = "*"
|
||||||
getopts = "*"
|
getopts = "*"
|
||||||
|
crossbeam = "0.2.5"
|
||||||
|
|
17
src/main.rs
17
src/main.rs
|
@ -2,6 +2,7 @@ extern crate getopts;
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
extern crate log;
|
extern crate log;
|
||||||
extern crate rand;
|
extern crate rand;
|
||||||
|
extern crate crossbeam;
|
||||||
|
|
||||||
mod game;
|
mod game;
|
||||||
mod simulator;
|
mod simulator;
|
||||||
|
@ -40,6 +41,7 @@ fn main() {
|
||||||
let mut opts = Options::new();
|
let mut opts = Options::new();
|
||||||
opts.optopt("l", "loglevel", "Log level, one of 'trace', 'debug', 'info', 'warn', and 'error'", "LOGLEVEL");
|
opts.optopt("l", "loglevel", "Log level, one of 'trace', 'debug', 'info', 'warn', and 'error'", "LOGLEVEL");
|
||||||
opts.optopt("n", "ntrials", "Number of games to simulate", "NTRIALS");
|
opts.optopt("n", "ntrials", "Number of games to simulate", "NTRIALS");
|
||||||
|
opts.optopt("t", "nthreads", "Number of threads to use for simulation", "NTHREADS");
|
||||||
opts.optopt("s", "seed", "Seed for PRNG (can only be used with n=1)", "SEED");
|
opts.optopt("s", "seed", "Seed for PRNG (can only be used with n=1)", "SEED");
|
||||||
opts.optflag("h", "help", "Print this help menu");
|
opts.optflag("h", "help", "Print this help menu");
|
||||||
let matches = match opts.parse(&args[1..]) {
|
let matches = match opts.parse(&args[1..]) {
|
||||||
|
@ -75,7 +77,8 @@ fn main() {
|
||||||
|
|
||||||
let seed = matches.opt_str("s").map(|seed_str| { u32::from_str(&seed_str).unwrap() });
|
let seed = matches.opt_str("s").map(|seed_str| { u32::from_str(&seed_str).unwrap() });
|
||||||
|
|
||||||
// TODO: make these configurable
|
let n_threads = u32::from_str(&matches.opt_str("t").unwrap_or("1".to_string())).unwrap();
|
||||||
|
|
||||||
let opts = game::GameOptions {
|
let opts = game::GameOptions {
|
||||||
num_players: 5,
|
num_players: 5,
|
||||||
hand_size: 4,
|
hand_size: 4,
|
||||||
|
@ -84,10 +87,10 @@ fn main() {
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: make this configurable
|
// TODO: make this configurable
|
||||||
let strategy_config = strategies::examples::RandomStrategyConfig {
|
// let strategy_config = strategies::examples::RandomStrategyConfig {
|
||||||
hint_probability: 0.4,
|
// hint_probability: 0.4,
|
||||||
play_probability: 0.2,
|
// play_probability: 0.2,
|
||||||
};
|
// };
|
||||||
// let strategy_config = strategies::cheating::CheatingStrategyConfig::new();
|
let strategy_config = strategies::cheating::CheatingStrategyConfig::new();
|
||||||
simulator::simulate(&opts, &strategy_config, seed, n);
|
simulator::simulate(&opts, &strategy_config, seed, n, n_threads);
|
||||||
}
|
}
|
||||||
|
|
107
src/simulator.rs
107
src/simulator.rs
|
@ -1,6 +1,8 @@
|
||||||
use rand::{self, Rng};
|
use rand::{self, Rng};
|
||||||
use game::*;
|
use game::*;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::fmt;
|
||||||
|
use crossbeam;
|
||||||
|
|
||||||
// Traits to implement for any valid Hanabi strategy
|
// Traits to implement for any valid Hanabi strategy
|
||||||
|
|
||||||
|
@ -68,40 +70,105 @@ pub fn simulate_once(
|
||||||
score
|
score
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct Histogram {
|
||||||
|
pub hist: HashMap<Score, usize>,
|
||||||
|
pub sum: Score,
|
||||||
|
pub total_count: usize,
|
||||||
|
}
|
||||||
|
impl Histogram {
|
||||||
|
pub fn new() -> Histogram {
|
||||||
|
Histogram {
|
||||||
|
hist: HashMap::new(),
|
||||||
|
sum: 0,
|
||||||
|
total_count: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn insert_many(&mut self, val: Score, count: usize) {
|
||||||
|
let new_count = self.get_count(&val) + count;
|
||||||
|
self.hist.insert(val, new_count);
|
||||||
|
self.sum += val * (count as u32);
|
||||||
|
self.total_count += count;
|
||||||
|
}
|
||||||
|
pub fn insert(&mut self, val: Score) {
|
||||||
|
self.insert_many(val, 1);
|
||||||
|
}
|
||||||
|
pub fn get_count(&self, val: &Score) -> usize {
|
||||||
|
*self.hist.get(&val).unwrap_or(&0)
|
||||||
|
}
|
||||||
|
pub fn average(&self) -> f32 {
|
||||||
|
(self.sum as f32) / (self.total_count as f32)
|
||||||
|
}
|
||||||
|
pub fn merge(&mut self, other: Histogram) {
|
||||||
|
for (val, count) in other.hist.iter() {
|
||||||
|
self.insert_many(*val, *count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl fmt::Display for Histogram {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
let mut keys = self.hist.keys().collect::<Vec<_>>();
|
||||||
|
keys.sort();
|
||||||
|
for val in keys {
|
||||||
|
try!(f.write_str(&format!(
|
||||||
|
"{}: {}\n", val, self.get_count(val),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: multithreaded
|
// TODO: multithreaded
|
||||||
pub fn simulate(
|
pub fn simulate<T>(
|
||||||
opts: &GameOptions,
|
opts: &GameOptions,
|
||||||
strat_config: &GameStrategyConfig,
|
strat_config: &T,
|
||||||
first_seed_opt: Option<u32>,
|
first_seed_opt: Option<u32>,
|
||||||
n_trials: u32,
|
n_trials: u32,
|
||||||
) -> f32 {
|
n_threads: u32,
|
||||||
|
) -> f32 where T: GameStrategyConfig + Sync {
|
||||||
let mut total_score = 0;
|
|
||||||
let mut non_perfect_seeds = Vec::new();
|
|
||||||
|
|
||||||
let first_seed = first_seed_opt.unwrap_or(rand::thread_rng().next_u32());
|
let first_seed = first_seed_opt.unwrap_or(rand::thread_rng().next_u32());
|
||||||
info!("Initial seed: {}\n", first_seed);
|
|
||||||
let mut histogram = HashMap::<Score, usize>::new();
|
|
||||||
|
|
||||||
for i in 0..n_trials {
|
crossbeam::scope(|scope| {
|
||||||
if (i > 0) && (i % 1000 == 0) {
|
let mut join_handles = Vec::new();
|
||||||
let average: f32 = (total_score as f32) / (i as f32);
|
for i in 0..n_threads {
|
||||||
info!("Trials: {}, Average so far: {}", i, average);
|
let start = first_seed + ((n_trials * i) / n_threads);
|
||||||
|
let end = first_seed + ((n_trials * (i+1)) / n_threads);
|
||||||
|
join_handles.push(scope.spawn(move || {
|
||||||
|
info!("Thread {} spawned: seeds {} to {}", i, start, end);
|
||||||
|
let mut non_perfect_seeds = Vec::new();
|
||||||
|
|
||||||
|
let mut histogram = Histogram::new();
|
||||||
|
|
||||||
|
for seed in start..end {
|
||||||
|
if (seed > start) && ((seed-start) % 1000 == 0) {
|
||||||
|
info!(
|
||||||
|
"Thread {}, Trials: {}, Average so far: {}",
|
||||||
|
i, seed-start, histogram.average()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
let seed = first_seed + i;
|
|
||||||
let score = simulate_once(&opts, strat_config.initialize(&opts), Some(seed));
|
let score = simulate_once(&opts, strat_config.initialize(&opts), Some(seed));
|
||||||
let count = histogram.get(&score).unwrap_or(&0) + 1;
|
histogram.insert(score);
|
||||||
histogram.insert(score, count);
|
if score != 25 { non_perfect_seeds.push((score, seed)); }
|
||||||
if score != 25 {
|
|
||||||
non_perfect_seeds.push((score, seed));
|
|
||||||
}
|
}
|
||||||
total_score += score;
|
info!("Thread {} done", i);
|
||||||
|
(non_perfect_seeds, histogram)
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut non_perfect_seeds : Vec<(Score,u32)> = Vec::new();
|
||||||
|
let mut histogram = Histogram::new();
|
||||||
|
for join_handle in join_handles {
|
||||||
|
let (thread_non_perfect_seeds, thread_histogram) = join_handle.join();
|
||||||
|
info!("Thread joined");
|
||||||
|
non_perfect_seeds.extend(thread_non_perfect_seeds.iter());
|
||||||
|
histogram.merge(thread_histogram);
|
||||||
}
|
}
|
||||||
|
|
||||||
non_perfect_seeds.sort();
|
non_perfect_seeds.sort();
|
||||||
info!("Score histogram: {:?}", histogram);
|
|
||||||
info!("Seeds with non-perfect score: {:?}", non_perfect_seeds);
|
info!("Seeds with non-perfect score: {:?}", non_perfect_seeds);
|
||||||
let average: f32 = (total_score as f32) / (n_trials as f32);
|
info!("Score histogram:\n{}", histogram);
|
||||||
|
let average = histogram.average();
|
||||||
info!("Average score: {:?}", average);
|
info!("Average score: {:?}", average);
|
||||||
average
|
average
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue