Merge pull request #7 from felixbauckholt/information-efficiency

Refactoring to be able to improve information efficiency
This commit is contained in:
Felix Bauckholt 2019-03-20 02:35:57 +01:00 committed by GitHub
commit 7f384cc15d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 829 additions and 752 deletions

7
Cargo.lock generated
View file

@ -3,6 +3,11 @@ name = "crossbeam"
version = "0.2.8" version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "float-ord"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]] [[package]]
name = "fnv" name = "fnv"
version = "1.0.6" version = "1.0.6"
@ -39,6 +44,7 @@ name = "rust_hanabi"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"crossbeam 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)", "crossbeam 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)",
"float-ord 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
"fnv 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)", "fnv 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)",
"getopts 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)", "getopts 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)",
@ -47,6 +53,7 @@ dependencies = [
[metadata] [metadata]
"checksum crossbeam 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)" = "348228ce9f93d20ffc30c18e575f82fa41b9c8bf064806c65d41eba4771595a0" "checksum crossbeam 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)" = "348228ce9f93d20ffc30c18e575f82fa41b9c8bf064806c65d41eba4771595a0"
"checksum float-ord 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7bad48618fdb549078c333a7a8528acb57af271d0433bdecd523eb620628364e"
"checksum fnv 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)" = "2fad85553e09a6f881f739c29f0b00b0f01357c743266d478b68951ce23285f3" "checksum fnv 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)" = "2fad85553e09a6f881f739c29f0b00b0f01357c743266d478b68951ce23285f3"
"checksum getopts 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)" = "d9047cfbd08a437050b363d35ef160452c5fe8ea5187ae0a624708c91581d685" "checksum getopts 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)" = "d9047cfbd08a437050b363d35ef160452c5fe8ea5187ae0a624708c91581d685"
"checksum libc 0.2.7 (registry+https://github.com/rust-lang/crates.io-index)" = "4870ef6725dde13394134e587e4ab4eca13cb92e916209a31c851b49131d3c75" "checksum libc 0.2.7 (registry+https://github.com/rust-lang/crates.io-index)" = "4870ef6725dde13394134e587e4ab4eca13cb92e916209a31c851b49131d3c75"

View file

@ -8,4 +8,5 @@ rand = "*"
log = "*" log = "*"
getopts = "*" getopts = "*"
fnv = "*" fnv = "*"
float-ord = "*"
crossbeam = "0.2.5" crossbeam = "0.2.5"

View file

@ -73,5 +73,5 @@ On the first 20000 seeds, we have these scores and win rates (average ± standar
|---------|------------------|------------------|------------------|------------------| |---------|------------------|------------------|------------------|------------------|
| cheat | 24.8594 ± 0.0036 | 24.9785 ± 0.0012 | 24.9720 ± 0.0014 | 24.9557 ± 0.0018 | | cheat | 24.8594 ± 0.0036 | 24.9785 ± 0.0012 | 24.9720 ± 0.0014 | 24.9557 ± 0.0018 |
| | 90.59 ± 0.21 % | 98.17 ± 0.09 % | 97.76 ± 0.10 % | 96.42 ± 0.13 % | | | 90.59 ± 0.21 % | 98.17 ± 0.09 % | 97.76 ± 0.10 % | 96.42 ± 0.13 % |
| info | 22.3249 ± 0.0128 | 24.7278 ± 0.0046 | 24.8919 ± 0.0029 | 24.8961 ± 0.0027 | | info | 22.5194 ± 0.0125 | 24.7942 ± 0.0039 | 24.9354 ± 0.0022 | 24.9220 ± 0.0024 |
| | 09.81 ± 0.21 % | 80.54 ± 0.28 % | 91.67 ± 0.20 % | 91.90 ± 0.19 % | | | 12.58 ± 0.23 % | 84.46 ± 0.26 % | 95.03 ± 0.15 % | 94.01 ± 0.17 % |

View file

@ -44,7 +44,7 @@ impl fmt::Debug for Card {
} }
} }
#[derive(Debug,Clone)] #[derive(Debug,Clone,Eq,PartialEq)]
pub struct CardCounts { pub struct CardCounts {
counts: FnvHashMap<Card, u32>, counts: FnvHashMap<Card, u32>,
} }
@ -99,7 +99,7 @@ impl fmt::Display for CardCounts {
pub type Cards = Vec<Card>; pub type Cards = Vec<Card>;
#[derive(Debug,Clone)] #[derive(Debug,Clone,Eq,PartialEq)]
pub struct Discard { pub struct Discard {
pub cards: Cards, pub cards: Cards,
counts: CardCounts, counts: CardCounts,
@ -137,7 +137,7 @@ impl fmt::Display for Discard {
pub type Score = u32; pub type Score = u32;
pub const PERFECT_SCORE: Score = (NUM_COLORS * NUM_VALUES) as u32; pub const PERFECT_SCORE: Score = (NUM_COLORS * NUM_VALUES) as u32;
#[derive(Debug,Clone)] #[derive(Debug,Clone,Eq,PartialEq)]
pub struct Firework { pub struct Firework {
pub color: Color, pub color: Color,
pub top: Value, pub top: Value,
@ -198,14 +198,14 @@ impl fmt::Display for Hinted {
} }
} }
#[derive(Debug,Clone)] #[derive(Debug,Clone,Eq,PartialEq)]
pub struct Hint { pub struct Hint {
pub player: Player, pub player: Player,
pub hinted: Hinted, pub hinted: Hinted,
} }
// represents the choice a player made in a given turn // represents the choice a player made in a given turn
#[derive(Debug,Clone)] #[derive(Debug,Clone,Eq,PartialEq)]
pub enum TurnChoice { pub enum TurnChoice {
Hint(Hint), Hint(Hint),
Discard(usize), // index of card to discard Discard(usize), // index of card to discard
@ -213,7 +213,7 @@ pub enum TurnChoice {
} }
// represents what happened in a turn // represents what happened in a turn
#[derive(Debug,Clone)] #[derive(Debug,Clone,Eq,PartialEq)]
pub enum TurnResult { pub enum TurnResult {
Hint(Vec<bool>), // vector of whether each was in the hint Hint(Vec<bool>), // vector of whether each was in the hint
Discard(Card), // card discarded Discard(Card), // card discarded
@ -221,7 +221,7 @@ pub enum TurnResult {
} }
// represents a turn taken in the game // represents a turn taken in the game
#[derive(Debug,Clone)] #[derive(Debug,Clone,Eq,PartialEq)]
pub struct TurnRecord { pub struct TurnRecord {
pub player: Player, pub player: Player,
pub choice: TurnChoice, pub choice: TurnChoice,
@ -243,7 +243,7 @@ pub struct GameOptions {
// State of everything except the player's hands // State of everything except the player's hands
// Is all completely common knowledge // Is all completely common knowledge
#[derive(Debug,Clone)] #[derive(Debug,Clone,Eq,PartialEq)]
pub struct BoardState { pub struct BoardState {
pub deck_size: u32, pub deck_size: u32,
pub total_cards: u32, pub total_cards: u32,

View file

@ -246,7 +246,7 @@ impl fmt::Display for SimpleCardInfo {
// Can represent information of the form: // Can represent information of the form:
// this card is/isn't possible // this card is/isn't possible
// also, maintains integer weights for the cards // also, maintains integer weights for the cards
#[derive(Clone,Debug)] #[derive(Clone,Debug,Eq,PartialEq)]
pub struct CardPossibilityTable { pub struct CardPossibilityTable {
possible: HashMap<Card, u32>, possible: HashMap<Card, u32>,
} }
@ -369,7 +369,7 @@ impl fmt::Display for CardPossibilityTable {
} }
} }
#[derive(Clone)] #[derive(Clone,Eq,PartialEq)]
pub struct HandInfo<T> where T: CardInfo { pub struct HandInfo<T> where T: CardInfo {
pub hand_info: Vec<T> pub hand_info: Vec<T>
} }

View file

@ -4,6 +4,7 @@ extern crate log;
extern crate rand; extern crate rand;
extern crate crossbeam; extern crate crossbeam;
extern crate fnv; extern crate fnv;
extern crate float_ord;
mod helpers; mod helpers;
mod game; mod game;
@ -12,6 +13,7 @@ mod strategy;
mod strategies { mod strategies {
pub mod examples; pub mod examples;
pub mod cheating; pub mod cheating;
mod hat_helpers;
pub mod information; pub mod information;
} }

View file

@ -0,0 +1,243 @@
use game::*;
use helpers::*;
#[derive(Debug,Clone)]
pub struct ModulusInformation {
pub modulus: u32,
pub value: u32,
}
impl ModulusInformation {
pub fn new(modulus: u32, value: u32) -> Self {
assert!(value < modulus);
ModulusInformation {
modulus: modulus,
value: value,
}
}
pub fn none() -> Self {
Self::new(1, 0)
}
pub fn combine(&mut self, other: Self, max_modulus: u32) {
assert!(other.modulus <= self.info_remaining(max_modulus));
self.value = self.value + self.modulus * other.value;
self.modulus = std::cmp::min(max_modulus, self.modulus * other.modulus);
assert!(self.value < self.modulus);
}
pub fn info_remaining(&self, max_modulus: u32) -> u32 {
// We want to find the largest number `result` such that
// `self.combine(other, max_modulus)` works whenever `other.modulus == result`.
// `other.value` can be up to `result - 1`, so calling combine could increase our value to
// up to `self.value + self.modulus * (result - 1)`, which must always be less than
// `max_modulus`.
// Therefore, we compute the largest number `result` such that
// `self.value + self.modulus * (result - 1) < max_modulus`.
let result = (max_modulus - self.value - 1) / self.modulus + 1;
assert!(self.value + self.modulus * (result - 1) < max_modulus);
assert!(self.value + self.modulus * ((result + 1) - 1) >= max_modulus);
result
}
pub fn split(&mut self, modulus: u32) -> Self {
assert!(self.modulus >= modulus);
let original_modulus = self.modulus;
let original_value = self.value;
let value = self.value % modulus;
self.value = self.value / modulus;
// `self.modulus` is the largest number such that
// `value + (self.modulus - 1) * modulus < original_modulus`.
// TODO: find an explanation of why this makes everything work out
self.modulus = (original_modulus - value - 1) / modulus + 1;
assert!(original_value == value + modulus * self.value);
Self::new(modulus, value)
}
pub fn cast_up(&mut self, modulus: u32) {
assert!(self.modulus <= modulus);
self.modulus = modulus;
}
// pub fn cast_down(&mut self, modulus: u32) {
// assert!(self.modulus >= modulus);
// assert!(self.value < modulus);
// self.modulus = modulus;
// }
pub fn add(&mut self, other: &Self) {
assert!(self.modulus == other.modulus);
self.value = (self.value + other.value) % self.modulus;
}
pub fn subtract(&mut self, other: &Self) {
assert!(self.modulus == other.modulus);
self.value = (self.modulus + self.value - other.value) % self.modulus;
}
}
pub trait Question {
// how much info does this question ask for?
fn info_amount(&self) -> u32;
// get the answer to this question, given cards
fn answer(&self, &Cards, &BoardState) -> u32;
// process the answer to this question, updating card info
fn acknowledge_answer(
&self, value: u32, &mut HandInfo<CardPossibilityTable>, &BoardState
);
fn answer_info(&self, hand: &Cards, board: &BoardState) -> ModulusInformation {
ModulusInformation::new(
self.info_amount(),
self.answer(hand, board)
)
}
fn acknowledge_answer_info(
&self,
answer: ModulusInformation,
hand_info: &mut HandInfo<CardPossibilityTable>,
board: &BoardState
) {
assert!(self.info_amount() == answer.modulus);
self.acknowledge_answer(answer.value, hand_info, board);
}
}
pub trait PublicInformation: Clone {
fn get_player_info(&self, &Player) -> HandInfo<CardPossibilityTable>;
fn set_player_info(&mut self, &Player, HandInfo<CardPossibilityTable>);
fn new(&BoardState) -> Self;
fn set_board(&mut self, &BoardState);
/// If we store more state than just `HandInfo<CardPossibilityTable>`s, update it after `set_player_info` has been called.
fn update_other_info(&mut self) {
}
fn agrees_with(&self, other: Self) -> bool;
/// By defining `ask_question`, we decides which `Question`s a player learns the answers to.
///
/// Whenever we need to compute a "hat value", this method will be called repeatedly, either
/// until the information runs out, or until it returns `None`. These questions can depend on
/// the answers to earlier questions: We are given a `&HandInfo<CardPossibilityTable>` that
/// reflect the answers of previous questions for the same "hat value computation".
///
/// Note that `self` does not reflect the answers to previous questions; it reflects the state
/// before the entire "hat value" calculation.
fn ask_question(&self, &Player, &HandInfo<CardPossibilityTable>, total_info: u32) -> Option<Box<Question>>;
fn ask_question_wrapper(&self, player: &Player, hand_info: &HandInfo<CardPossibilityTable>, total_info: u32)
-> Option<Box<Question>>
{
assert!(total_info > 0);
if total_info == 1 {
None
} else {
let result = self.ask_question(player, hand_info, total_info);
if let Some(ref question) = result {
if question.info_amount() > total_info {
panic!("ask_question returned question with info_amount = {} > total_info = {}!",
question.info_amount(), total_info);
}
if question.info_amount() == 1 {
panic!("ask_question returned a trivial question!");
}
}
result
}
}
fn set_player_infos(&mut self, infos: Vec<(Player, HandInfo<CardPossibilityTable>)>) {
for (player, new_hand_info) in infos {
self.set_player_info(&player, new_hand_info);
}
self.update_other_info();
}
fn get_hat_info_for_player(
&self, player: &Player, hand_info: &mut HandInfo<CardPossibilityTable>, total_info: u32, view: &OwnedGameView
) -> ModulusInformation {
assert!(player != &view.player);
let mut answer_info = ModulusInformation::none();
while let Some(question) = self.ask_question_wrapper(player, hand_info, answer_info.info_remaining(total_info)) {
let new_answer_info = question.answer_info(view.get_hand(player), view.get_board());
question.acknowledge_answer_info(new_answer_info.clone(), hand_info, view.get_board());
answer_info.combine(new_answer_info, total_info);
}
answer_info.cast_up(total_info);
answer_info
}
fn update_from_hat_info_for_player(
&self,
player: &Player,
hand_info: &mut HandInfo<CardPossibilityTable>,
board: &BoardState,
mut info: ModulusInformation,
) {
while let Some(question) = self.ask_question_wrapper(player, hand_info, info.modulus) {
let answer_info = info.split(question.info_amount());
question.acknowledge_answer_info(answer_info, hand_info, board);
}
assert!(info.value == 0);
}
/// When deciding on a move, if we can choose between `total_info` choices,
/// `self.get_hat_sum(total_info, view)` tells us which choice to take, and at the same time
/// mutates `self` to simulate the choice becoming common knowledge.
fn get_hat_sum(&mut self, total_info: u32, view: &OwnedGameView) -> ModulusInformation {
let (infos, new_player_hands): (Vec<_>, Vec<_>) = view.get_other_players().iter().map(|player| {
let mut hand_info = self.get_player_info(player);
let info = self.get_hat_info_for_player(player, &mut hand_info, total_info, view);
(info, (player.clone(), hand_info))
}).unzip();
self.set_player_infos(new_player_hands);
infos.into_iter().fold(
ModulusInformation::new(total_info, 0),
|mut sum_info, info| {
sum_info.add(&info);
sum_info
}
)
}
/// When updating on a move, if we infer that the player making the move called `get_hat_sum()`
/// and got the result `info`, we can call `self.update_from_hat_sum(info, view)` to update
/// from that fact.
fn update_from_hat_sum(&mut self, mut info: ModulusInformation, view: &OwnedGameView) {
let info_source = view.board.player;
let (other_infos, mut new_player_hands): (Vec<_>, Vec<_>) = view.get_other_players().into_iter().filter(|player| {
*player != info_source
}).map(|player| {
let mut hand_info = self.get_player_info(&player);
let player_info = self.get_hat_info_for_player(&player, &mut hand_info, info.modulus, view);
(player_info, (player.clone(), hand_info))
}).unzip();
for other_info in other_infos {
info.subtract(&other_info);
}
let me = view.player;
if me == info_source {
assert!(info.value == 0);
} else {
let mut my_hand = self.get_player_info(&me);
self.update_from_hat_info_for_player(&me, &mut my_hand, &view.board, info);
new_player_hands.push((me, my_hand));
}
self.set_player_infos(new_player_hands);
}
fn get_private_info(&self, view: &OwnedGameView) -> HandInfo<CardPossibilityTable> {
let mut info = self.get_player_info(&view.player);
for card_table in info.iter_mut() {
for (_, hand) in &view.other_hands {
for card in hand {
card_table.decrement_weight_if_possible(card);
}
}
}
info
}
}

File diff suppressed because it is too large Load diff