From a4e95607531598b9fbb405a687562c017e79ab73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20Ke=C3=9Fler?= Date: Fri, 11 Aug 2023 18:28:12 +0200 Subject: [PATCH] add function to enumerate possible next states upon specific action --- game_state.h | 32 ++++++++++++++++++++------------ game_state.hpp | 46 ++++++++++++++++++++++++++++++++++++---------- main.cpp | 16 +++++++++++++++- 3 files changed, 71 insertions(+), 23 deletions(-) diff --git a/game_state.h b/game_state.h index 6a38ddb..c0cbe3c 100644 --- a/game_state.h +++ b/game_state.h @@ -69,16 +69,18 @@ namespace Hanabi { std::ostream &operator<<(std::ostream &os, const Card &card); -constexpr Card r0 = {0, 0}; -constexpr Card r1 = {0, 1}; -constexpr Card r2 = {0, 2}; -constexpr Card r3 = {0, 3}; -constexpr Card r4 = {0, 4}; -constexpr Card y0 = {1, 0}; -constexpr Card y1 = {1, 1}; -constexpr Card y2 = {1, 2}; -constexpr Card y3 = {1, 3}; -constexpr Card y4 = {1, 4}; +constexpr Card r0 = {0, 5}; +constexpr Card r1 = {0, 4}; +constexpr Card r2 = {0, 3}; +constexpr Card r3 = {0, 2}; +constexpr Card r4 = {0, 1}; +constexpr Card r5 = {0, 0}; +constexpr Card y0 = {1, 5}; +constexpr Card y1 = {1, 4}; +constexpr Card y2 = {1, 3}; +constexpr Card y3 = {1, 2}; +constexpr Card y4 = {1, 1}; +constexpr Card y5 = {1, 0}; constexpr Card unknown_card = {0, 6}; /** @@ -163,6 +165,8 @@ public: virtual void discard(hand_index_t index) = 0; virtual void play(hand_index_t index) = 0; + virtual void rotate_next_draw(const Card& card) = 0; + virtual void revert() = 0; [[nodiscard]] virtual hand_index_t find_card_in_hand(const Card& card) const = 0; @@ -178,6 +182,7 @@ public: virtual probability_t evaluate_state() = 0; virtual std::vector>> get_reasonable_actions() = 0; + virtual std::vector> possible_next_states(hand_index_t index, bool play) = 0; virtual ~HanabiStateIF() = default; @@ -197,6 +202,8 @@ public: void discard(hand_index_t index) final; void play(hand_index_t index) final; + void rotate_next_draw(const Card& card) final; + void revert() final; [[nodiscard]] hand_index_t find_card_in_hand(const Card& card) const final; @@ -214,6 +221,7 @@ public: std::optional lookup() const; std::vector>> get_reasonable_actions() final; + std::vector> possible_next_states(hand_index_t index, bool play) final; auto operator<=>(const HanabiState &) const = default; @@ -276,8 +284,8 @@ private: void update_tablebase(unsigned long id, probability_t probability); - template - void do_for_each_potential_draw(hand_index_t index, Function f); + template + void do_for_each_potential_draw(hand_index_t index, bool play, Function f); void incr_turn(); void decr_turn(); diff --git a/game_state.hpp b/game_state.hpp index 24e8f1d..0cfa684 100644 --- a/game_state.hpp +++ b/game_state.hpp @@ -469,6 +469,23 @@ namespace Hanabi { } } + template + std::vector> HanabiState::possible_next_states(hand_index_t index, bool play) { + std::vector> next_states; + do_for_each_potential_draw(index, play, [this, &next_states](unsigned long multiplicity){ + auto prob = lookup(); + ASSERT(lookup().has_value()); + + // bit hacky to get drawn card here + decr_turn(); + const CardMultiplicity drawn_card = {_hands[_turn], multiplicity}; + incr_turn(); + + next_states.emplace_back(drawn_card, prob.value()); + }); + return next_states; + } + template std::vector>> HanabiState::get_reasonable_actions() { std::vector>> reasonable_actions {}; @@ -485,7 +502,7 @@ namespace Hanabi { bool known = true; probability_t sum_of_probabilities = 0; - do_for_each_potential_draw(index, [this, &sum_of_probabilities, &known](const unsigned long multiplicity){ + do_for_each_potential_draw(index, true, [this, &sum_of_probabilities, &known](const unsigned long multiplicity){ const std::optional prob = lookup(); if (prob.has_value()) { sum_of_probabilities += prob.value() * multiplicity; @@ -511,7 +528,7 @@ namespace Hanabi { bool known = true; probability_t sum_of_probabilities = 0; - do_for_each_potential_draw(index, [this, &sum_of_probabilities, &known](const unsigned long multiplicity){ + do_for_each_potential_draw(index, false, [this, &sum_of_probabilities, &known](const unsigned long multiplicity){ const std::optional prob = lookup(); if (prob.has_value()) { sum_of_probabilities += prob.value() * multiplicity; @@ -554,6 +571,15 @@ namespace Hanabi { } } + template + void HanabiState::rotate_next_draw(const Card& card) { + auto card_it = std::find_if(_draw_pile.begin(), _draw_pile.end(), [&card](const CardMultiplicity& card_multiplicity){ + return card_multiplicity.card.rank == card.rank and card_multiplicity.card.suit == card.suit; + }); + ASSERT(card_it != _draw_pile.end()); + std::swap(*card_it, _draw_pile.front()); + } + template probability_t HanabiState::evaluate_state() { ASSERT(_relative_representation.initialized); @@ -580,7 +606,7 @@ namespace Hanabi { if(is_playable(hand[index])) { probability_t sum_of_probabilities = 0; - do_for_each_potential_draw(index, [this, &sum_of_probabilities](const unsigned long multiplicity){ + do_for_each_potential_draw(index, true, [this, &sum_of_probabilities](const unsigned long multiplicity){ sum_of_probabilities += evaluate_state() * multiplicity; }); @@ -601,7 +627,7 @@ namespace Hanabi { if (is_trash(hand[index])) { probability_t sum_of_probabilities = 0; - do_for_each_potential_draw(index, [this, &sum_of_probabilities](const unsigned long multiplicity){ + do_for_each_potential_draw(index, false, [this, &sum_of_probabilities](const unsigned long multiplicity){ sum_of_probabilities += evaluate_state() * multiplicity; }); @@ -638,18 +664,18 @@ namespace Hanabi { } template - template - void HanabiState::do_for_each_potential_draw(hand_index_t index, Function f) { - auto do_action = [this, index](){ - if constexpr (play) { + template + void HanabiState::do_for_each_potential_draw(hand_index_t index, bool play, Function f) { + auto do_action = [this, index, play](){ + if (play) { return play_and_potentially_update(index); } else { return discard_and_potentially_update(index); } }; - auto revert_action = [this](){ - if constexpr (play) { + auto revert_action = [this, play](){ + if (play) { revert_play(); } else { revert_discard(); diff --git a/main.cpp b/main.cpp index e04d574..7c7d620 100644 --- a/main.cpp +++ b/main.cpp @@ -23,6 +23,20 @@ namespace Hanabi { std::cout << "Probability with optimal play: " << res << std::endl; std::cout << "Enumerated " << game->enumerated_states() << " states" << std::endl; std::cout << "Visited " << game->position_tablebase().size() << " unique game states. " << std::endl; + + game->rotate_next_draw(r1); + game->discard(0); + + game->rotate_next_draw(r1); + game->play(game->find_card_in_hand(y3)); + + game->clue(); + + game->rotate_next_draw(r1); + game->play(game->find_card_in_hand(y4)); + + std::cout << *game << std::endl; + for (const auto &[action, probability] : game->get_reasonable_actions()) { std::cout << action; if(probability.has_value()) { @@ -117,7 +131,7 @@ void check_games(unsigned num_players, unsigned max_draw_pile_size, unsigned fir int main(int argc, char *argv[]) { #ifndef NDEBUG -// test(); + test(); #endif if(argc == 3) { std::string game(argv[1]);