add function to enumerate possible next states upon specific action
This commit is contained in:
parent
6ae9244fa6
commit
a4e9560753
3 changed files with 71 additions and 23 deletions
32
game_state.h
32
game_state.h
|
@ -69,16 +69,18 @@ namespace Hanabi {
|
||||||
|
|
||||||
std::ostream &operator<<(std::ostream &os, const Card &card);
|
std::ostream &operator<<(std::ostream &os, const Card &card);
|
||||||
|
|
||||||
constexpr Card r0 = {0, 0};
|
constexpr Card r0 = {0, 5};
|
||||||
constexpr Card r1 = {0, 1};
|
constexpr Card r1 = {0, 4};
|
||||||
constexpr Card r2 = {0, 2};
|
constexpr Card r2 = {0, 3};
|
||||||
constexpr Card r3 = {0, 3};
|
constexpr Card r3 = {0, 2};
|
||||||
constexpr Card r4 = {0, 4};
|
constexpr Card r4 = {0, 1};
|
||||||
constexpr Card y0 = {1, 0};
|
constexpr Card r5 = {0, 0};
|
||||||
constexpr Card y1 = {1, 1};
|
constexpr Card y0 = {1, 5};
|
||||||
constexpr Card y2 = {1, 2};
|
constexpr Card y1 = {1, 4};
|
||||||
constexpr Card y3 = {1, 3};
|
constexpr Card y2 = {1, 3};
|
||||||
constexpr Card y4 = {1, 4};
|
constexpr Card y3 = {1, 2};
|
||||||
|
constexpr Card y4 = {1, 1};
|
||||||
|
constexpr Card y5 = {1, 0};
|
||||||
constexpr Card unknown_card = {0, 6};
|
constexpr Card unknown_card = {0, 6};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -163,6 +165,8 @@ public:
|
||||||
virtual void discard(hand_index_t index) = 0;
|
virtual void discard(hand_index_t index) = 0;
|
||||||
virtual void play(hand_index_t index) = 0;
|
virtual void play(hand_index_t index) = 0;
|
||||||
|
|
||||||
|
virtual void rotate_next_draw(const Card& card) = 0;
|
||||||
|
|
||||||
virtual void revert() = 0;
|
virtual void revert() = 0;
|
||||||
|
|
||||||
[[nodiscard]] virtual hand_index_t find_card_in_hand(const Card& card) const = 0;
|
[[nodiscard]] virtual hand_index_t find_card_in_hand(const Card& card) const = 0;
|
||||||
|
@ -178,6 +182,7 @@ public:
|
||||||
virtual probability_t evaluate_state() = 0;
|
virtual probability_t evaluate_state() = 0;
|
||||||
|
|
||||||
virtual std::vector<std::pair<Action, std::optional<probability_t>>> get_reasonable_actions() = 0;
|
virtual std::vector<std::pair<Action, std::optional<probability_t>>> get_reasonable_actions() = 0;
|
||||||
|
virtual std::vector<std::pair<CardMultiplicity,probability_t>> possible_next_states(hand_index_t index, bool play) = 0;
|
||||||
|
|
||||||
virtual ~HanabiStateIF() = default;
|
virtual ~HanabiStateIF() = default;
|
||||||
|
|
||||||
|
@ -197,6 +202,8 @@ public:
|
||||||
void discard(hand_index_t index) final;
|
void discard(hand_index_t index) final;
|
||||||
void play(hand_index_t index) final;
|
void play(hand_index_t index) final;
|
||||||
|
|
||||||
|
void rotate_next_draw(const Card& card) final;
|
||||||
|
|
||||||
void revert() final;
|
void revert() final;
|
||||||
|
|
||||||
[[nodiscard]] hand_index_t find_card_in_hand(const Card& card) const final;
|
[[nodiscard]] hand_index_t find_card_in_hand(const Card& card) const final;
|
||||||
|
@ -214,6 +221,7 @@ public:
|
||||||
std::optional<probability_t> lookup() const;
|
std::optional<probability_t> lookup() const;
|
||||||
|
|
||||||
std::vector<std::pair<Action, std::optional<probability_t>>> get_reasonable_actions() final;
|
std::vector<std::pair<Action, std::optional<probability_t>>> get_reasonable_actions() final;
|
||||||
|
std::vector<std::pair<CardMultiplicity,probability_t>> possible_next_states(hand_index_t index, bool play) final;
|
||||||
|
|
||||||
auto operator<=>(const HanabiState &) const = default;
|
auto operator<=>(const HanabiState &) const = default;
|
||||||
|
|
||||||
|
@ -276,8 +284,8 @@ private:
|
||||||
|
|
||||||
void update_tablebase(unsigned long id, probability_t probability);
|
void update_tablebase(unsigned long id, probability_t probability);
|
||||||
|
|
||||||
template<bool play, class Function>
|
template<class Function>
|
||||||
void do_for_each_potential_draw(hand_index_t index, Function f);
|
void do_for_each_potential_draw(hand_index_t index, bool play, Function f);
|
||||||
|
|
||||||
void incr_turn();
|
void incr_turn();
|
||||||
void decr_turn();
|
void decr_turn();
|
||||||
|
|
|
@ -469,6 +469,23 @@ namespace Hanabi {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
|
||||||
|
std::vector<std::pair<CardMultiplicity, probability_t>> HanabiState<num_suits, num_players, hand_size>::possible_next_states(hand_index_t index, bool play) {
|
||||||
|
std::vector<std::pair<CardMultiplicity, probability_t>> next_states;
|
||||||
|
do_for_each_potential_draw(index, play, [this, &next_states](unsigned long multiplicity){
|
||||||
|
auto prob = lookup();
|
||||||
|
ASSERT(lookup().has_value());
|
||||||
|
|
||||||
|
// bit hacky to get drawn card here
|
||||||
|
decr_turn();
|
||||||
|
const CardMultiplicity drawn_card = {_hands[_turn], multiplicity};
|
||||||
|
incr_turn();
|
||||||
|
|
||||||
|
next_states.emplace_back(drawn_card, prob.value());
|
||||||
|
});
|
||||||
|
return next_states;
|
||||||
|
}
|
||||||
|
|
||||||
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
|
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
|
||||||
std::vector<std::pair<Action, std::optional<probability_t>>> HanabiState<num_suits, num_players, hand_size>::get_reasonable_actions() {
|
std::vector<std::pair<Action, std::optional<probability_t>>> HanabiState<num_suits, num_players, hand_size>::get_reasonable_actions() {
|
||||||
std::vector<std::pair<Action, std::optional<probability_t>>> reasonable_actions {};
|
std::vector<std::pair<Action, std::optional<probability_t>>> reasonable_actions {};
|
||||||
|
@ -485,7 +502,7 @@ namespace Hanabi {
|
||||||
bool known = true;
|
bool known = true;
|
||||||
probability_t sum_of_probabilities = 0;
|
probability_t sum_of_probabilities = 0;
|
||||||
|
|
||||||
do_for_each_potential_draw<true>(index, [this, &sum_of_probabilities, &known](const unsigned long multiplicity){
|
do_for_each_potential_draw(index, true, [this, &sum_of_probabilities, &known](const unsigned long multiplicity){
|
||||||
const std::optional<probability_t> prob = lookup();
|
const std::optional<probability_t> prob = lookup();
|
||||||
if (prob.has_value()) {
|
if (prob.has_value()) {
|
||||||
sum_of_probabilities += prob.value() * multiplicity;
|
sum_of_probabilities += prob.value() * multiplicity;
|
||||||
|
@ -511,7 +528,7 @@ namespace Hanabi {
|
||||||
bool known = true;
|
bool known = true;
|
||||||
probability_t sum_of_probabilities = 0;
|
probability_t sum_of_probabilities = 0;
|
||||||
|
|
||||||
do_for_each_potential_draw<false>(index, [this, &sum_of_probabilities, &known](const unsigned long multiplicity){
|
do_for_each_potential_draw(index, false, [this, &sum_of_probabilities, &known](const unsigned long multiplicity){
|
||||||
const std::optional<probability_t> prob = lookup();
|
const std::optional<probability_t> prob = lookup();
|
||||||
if (prob.has_value()) {
|
if (prob.has_value()) {
|
||||||
sum_of_probabilities += prob.value() * multiplicity;
|
sum_of_probabilities += prob.value() * multiplicity;
|
||||||
|
@ -554,6 +571,15 @@ namespace Hanabi {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
|
||||||
|
void HanabiState<num_suits, num_players, hand_size>::rotate_next_draw(const Card& card) {
|
||||||
|
auto card_it = std::find_if(_draw_pile.begin(), _draw_pile.end(), [&card](const CardMultiplicity& card_multiplicity){
|
||||||
|
return card_multiplicity.card.rank == card.rank and card_multiplicity.card.suit == card.suit;
|
||||||
|
});
|
||||||
|
ASSERT(card_it != _draw_pile.end());
|
||||||
|
std::swap(*card_it, _draw_pile.front());
|
||||||
|
}
|
||||||
|
|
||||||
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
|
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
|
||||||
probability_t HanabiState<num_suits, num_players, hand_size>::evaluate_state() {
|
probability_t HanabiState<num_suits, num_players, hand_size>::evaluate_state() {
|
||||||
ASSERT(_relative_representation.initialized);
|
ASSERT(_relative_representation.initialized);
|
||||||
|
@ -580,7 +606,7 @@ namespace Hanabi {
|
||||||
if(is_playable(hand[index])) {
|
if(is_playable(hand[index])) {
|
||||||
probability_t sum_of_probabilities = 0;
|
probability_t sum_of_probabilities = 0;
|
||||||
|
|
||||||
do_for_each_potential_draw<true>(index, [this, &sum_of_probabilities](const unsigned long multiplicity){
|
do_for_each_potential_draw(index, true, [this, &sum_of_probabilities](const unsigned long multiplicity){
|
||||||
sum_of_probabilities += evaluate_state() * multiplicity;
|
sum_of_probabilities += evaluate_state() * multiplicity;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -601,7 +627,7 @@ namespace Hanabi {
|
||||||
if (is_trash(hand[index])) {
|
if (is_trash(hand[index])) {
|
||||||
probability_t sum_of_probabilities = 0;
|
probability_t sum_of_probabilities = 0;
|
||||||
|
|
||||||
do_for_each_potential_draw<false>(index, [this, &sum_of_probabilities](const unsigned long multiplicity){
|
do_for_each_potential_draw(index, false, [this, &sum_of_probabilities](const unsigned long multiplicity){
|
||||||
sum_of_probabilities += evaluate_state() * multiplicity;
|
sum_of_probabilities += evaluate_state() * multiplicity;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -638,18 +664,18 @@ namespace Hanabi {
|
||||||
}
|
}
|
||||||
|
|
||||||
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
|
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
|
||||||
template<bool play, class Function>
|
template<class Function>
|
||||||
void HanabiState<num_suits, num_players, hand_size>::do_for_each_potential_draw(hand_index_t index, Function f) {
|
void HanabiState<num_suits, num_players, hand_size>::do_for_each_potential_draw(hand_index_t index, bool play, Function f) {
|
||||||
auto do_action = [this, index](){
|
auto do_action = [this, index, play](){
|
||||||
if constexpr (play) {
|
if (play) {
|
||||||
return play_and_potentially_update(index);
|
return play_and_potentially_update(index);
|
||||||
} else {
|
} else {
|
||||||
return discard_and_potentially_update(index);
|
return discard_and_potentially_update(index);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
auto revert_action = [this](){
|
auto revert_action = [this, play](){
|
||||||
if constexpr (play) {
|
if (play) {
|
||||||
revert_play();
|
revert_play();
|
||||||
} else {
|
} else {
|
||||||
revert_discard();
|
revert_discard();
|
||||||
|
|
16
main.cpp
16
main.cpp
|
@ -23,6 +23,20 @@ namespace Hanabi {
|
||||||
std::cout << "Probability with optimal play: " << res << std::endl;
|
std::cout << "Probability with optimal play: " << res << std::endl;
|
||||||
std::cout << "Enumerated " << game->enumerated_states() << " states" << std::endl;
|
std::cout << "Enumerated " << game->enumerated_states() << " states" << std::endl;
|
||||||
std::cout << "Visited " << game->position_tablebase().size() << " unique game states. " << std::endl;
|
std::cout << "Visited " << game->position_tablebase().size() << " unique game states. " << std::endl;
|
||||||
|
|
||||||
|
game->rotate_next_draw(r1);
|
||||||
|
game->discard(0);
|
||||||
|
|
||||||
|
game->rotate_next_draw(r1);
|
||||||
|
game->play(game->find_card_in_hand(y3));
|
||||||
|
|
||||||
|
game->clue();
|
||||||
|
|
||||||
|
game->rotate_next_draw(r1);
|
||||||
|
game->play(game->find_card_in_hand(y4));
|
||||||
|
|
||||||
|
std::cout << *game << std::endl;
|
||||||
|
|
||||||
for (const auto &[action, probability] : game->get_reasonable_actions()) {
|
for (const auto &[action, probability] : game->get_reasonable_actions()) {
|
||||||
std::cout << action;
|
std::cout << action;
|
||||||
if(probability.has_value()) {
|
if(probability.has_value()) {
|
||||||
|
@ -117,7 +131,7 @@ void check_games(unsigned num_players, unsigned max_draw_pile_size, unsigned fir
|
||||||
|
|
||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
// test();
|
test();
|
||||||
#endif
|
#endif
|
||||||
if(argc == 3) {
|
if(argc == 3) {
|
||||||
std::string game(argv[1]);
|
std::string game(argv[1]);
|
||||||
|
|
Loading…
Reference in a new issue