From 385ba6650bd5f0f5bf5cab8bd1bdbebb3038eca0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20Ke=C3=9Fler?= Date: Sun, 6 Aug 2023 22:06:58 +0200 Subject: [PATCH] add virtual interface for all hanabi states --- download.h | 107 ++++++++++++++++++++++++++++++++++++++++--------- game_state.h | 76 ++++++++++++++++++++++------------- game_state.hpp | 25 ++++++++---- main.cpp | 34 +++------------- 4 files changed, 160 insertions(+), 82 deletions(-) diff --git a/download.h b/download.h index 83412b2..668c314 100644 --- a/download.h +++ b/download.h @@ -69,7 +69,7 @@ namespace Download { return action; } - std::vector parse_deck(const boost::json::value &deck_json) { + std::pair, Hanabi::rank_t> parse_deck(const boost::json::value &deck_json) { auto deck = boost::json::value_to>(deck_json); for (auto &card: deck) { ASSERT(card.rank < 5); @@ -77,7 +77,11 @@ namespace Download { ASSERT(card.suit < 6); ASSERT(card.suit >= 0); } - return deck; + Hanabi::rank_t num_suits = 0; + for(const auto& card: deck) { + num_suits = std::max(num_suits, card.suit); + } + return {deck, num_suits + 1}; } std::vector parse_actions(const boost::json::value &action_json) { @@ -104,39 +108,39 @@ namespace Download { } template - Hanabi::HanabiState produce_state( + std::unique_ptr produce_state( const std::vector& deck, const std::vector& actions, size_t num_turns_to_replicate ) { - Hanabi::HanabiState game(deck); + auto game = std::unique_ptr(new Hanabi::HanabiState(deck)); std::uint8_t index; for (size_t i = 0; i < num_turns_to_replicate; i++) { switch(actions[i].type) { case Hanabi::ActionType::color_clue: case Hanabi::ActionType::rank_clue: - game.clue(); + game->clue(); break; case Hanabi::ActionType::discard: - index = game.find_card_in_hand(deck[actions[i].target]); + index = game->find_card_in_hand(deck[actions[i].target]); ASSERT(index != std::uint8_t(-1)); - game.discard(index); + game->discard(index); break; case Hanabi::ActionType::play: - index = game.find_card_in_hand(deck[actions[i].target]); + index = game->find_card_in_hand(deck[actions[i].target]); ASSERT(index != std::uint8_t(-1)); - game.play(index); + game->play(index); break; case Hanabi::ActionType::vote_terminate: case Hanabi::ActionType::end_game: return game; } } + game->normalize_draw_and_positions(); return game; } - template - Hanabi::HanabiState get_game(std::variant game_spec, unsigned turn) { + std::unique_ptr get_game(std::variant game_spec, unsigned turn) { const boost::json::object game_json = [&game_spec]() { if (game_spec.index() == 0) { return download_game_json(std::get(game_spec)); @@ -144,17 +148,84 @@ namespace Download { return open_game_json(std::get(game_spec)); } }(); - const std::vector deck = parse_deck(game_json.at("deck")); - const std::vector actions = parse_actions(game_json.at("actions")); - const size_t num_players_js = game_json.at("players").as_array().size(); - ASSERT (num_players_js == num_players); - auto game = produce_state(deck, actions, turn); - game.normalize_draw_and_positions(); - return game; + const auto [deck, num_suits] = parse_deck(game_json.at("deck")); + const std::vector actions = parse_actions(game_json.at("actions")); + const size_t num_players = game_json.at("players").as_array().size(); + + switch(num_players) { + case 2: + switch(num_suits) { + case 3: + return produce_state<3,2,5>(deck, actions, turn); + case 4: + return produce_state<4,2,5>(deck, actions, turn); + case 5: + return produce_state<5,2,5>(deck, actions, turn); + case 6: + return produce_state<6,2,5>(deck, actions, turn); + default: + throw std::runtime_error("Invalid number of suits: " + std::to_string(num_suits)); + } + case 3: + switch(num_suits) { + case 3: + return produce_state<3,3,5>(deck, actions, turn); + case 4: + return produce_state<4,3,5>(deck, actions, turn); + case 5: + return produce_state<5,3,5>(deck, actions, turn); + case 6: + return produce_state<6,3,5>(deck, actions, turn); + default: + throw std::runtime_error("Invalid number of suits: " + std::to_string(num_suits)); + } + case 4: + switch(num_suits) { + case 3: + return produce_state<3,4,4>(deck, actions, turn); + case 4: + return produce_state<4,4,4>(deck, actions, turn); + case 5: + return produce_state<5,4,4>(deck, actions, turn); + case 6: + return produce_state<6,4,4>(deck, actions, turn); + default: + throw std::runtime_error("Invalid number of suits: " + std::to_string(num_suits)); + } + case 5: + switch(num_suits) { + case 3: + return produce_state<3,5,4>(deck, actions, turn); + case 4: + return produce_state<4,5,4>(deck, actions, turn); + case 5: + return produce_state<5,5,4>(deck, actions, turn); + case 6: + return produce_state<6,5,4>(deck, actions, turn); + default: + throw std::runtime_error("Invalid number of suits: " + std::to_string(num_suits)); + } + case 6: + switch(num_suits) { + case 3: + return produce_state<3,6,3>(deck, actions, turn); + case 4: + return produce_state<4,6,3>(deck, actions, turn); + case 5: + return produce_state<5,6,3>(deck, actions, turn); + case 6: + return produce_state<6,6,3>(deck, actions, turn); + default: + throw std::runtime_error("Invalid number of suits: " + std::to_string(num_suits)); + } + default: + throw std::runtime_error("Invalid number of players: " + std::to_string(num_players)); + } } + } // namespacen Download #endif // DYNAMIC_PROGRAM_DOWNLOAD_H diff --git a/game_state.h b/game_state.h index 15c31e9..d390c25 100644 --- a/game_state.h +++ b/game_state.h @@ -139,61 +139,85 @@ struct BacktrackAction { std::uint8_t multiplicity{}; }; +class HanabiStateIF { +public: + virtual double backtrack(size_t depth) = 0; + + virtual void clue() = 0; + virtual BacktrackAction discard(std::uint8_t index) = 0; + virtual BacktrackAction play(std::uint8_t index) = 0; + + virtual void revert_clue() = 0; + virtual void revert_play(const BacktrackAction &action, bool was_on_8_clues) = 0; + virtual void revert_discard(const BacktrackAction &action) = 0; + + [[nodiscard]] virtual std::uint8_t find_card_in_hand(const Card& card) const = 0; + [[nodiscard]] virtual bool is_trash(const Card& card) const = 0; + [[nodiscard]] virtual bool is_playable(const Card& card) const = 0; + + [[nodiscard]] virtual std::uint64_t enumerated_states() const = 0; + + virtual void normalize_draw_and_positions() = 0; + + virtual ~HanabiStateIF() = default; + +protected: + virtual void print(std::ostream& os) const = 0; + + friend std::ostream& operator<<(std::ostream&, HanabiStateIF const&); +}; + template -class HanabiState { +class HanabiState : public HanabiStateIF { public: HanabiState() = default; explicit HanabiState(const std::vector& deck); - double backtrack(size_t depth); + double backtrack(size_t depth) final; - void clue(); + void clue() final; + BacktrackAction play(std::uint8_t index) final; + BacktrackAction discard(std::uint8_t index) final; - /** - * Plays a card from current hand, drawing top card of draw pile and rotating draw pile - * @param index of card in hand to be played - */ - BacktrackAction play(std::uint8_t index); + void revert_clue() final; + void revert_play(const BacktrackAction &action, bool was_on_8_clues) final; + void revert_discard(const BacktrackAction &action) final; - BacktrackAction discard(std::uint8_t index); + [[nodiscard]] std::uint8_t find_card_in_hand(const Card& card) const final; + [[nodiscard]] bool is_trash(const Card& card) const final; + [[nodiscard]] bool is_playable(const Card& card) const final; - std::uint8_t find_card_in_hand(const Card& card) const; + [[nodiscard]] std::uint64_t enumerated_states() const final; - void normalize_draw_and_positions(); + void normalize_draw_and_positions() final; - void revert_clue(); - void revert_play(const BacktrackAction &action, bool was_on_8_clues); - void revert_discard(const BacktrackAction &action); + auto operator<=>(const HanabiState &) const = default; +protected: + void print(std::ostream& os) const final; + +private: uint8_t draw(uint8_t index); - void revert_draw(std::uint8_t index, Card discarded_card); void incr_turn(); - void decr_turn(); - bool is_trash(const Card& card) const; - bool is_playable(const Card& card) const; - player_t _turn{}; clue_t _num_clues{}; std::uint8_t _weighted_draw_pile_size{}; Stacks _stacks{}; std::array, num_players> _hands{}; -// CardArray _card_positions{}; std::list _draw_pile{}; - std::uint8_t _endgame_turns_left; + std::uint8_t _endgame_turns_left{}; - static constexpr uint8_t no_endgame = std::numeric_limits::max() - 1; + static constexpr uint8_t no_endgame = std::numeric_limits::max(); // further statistics that we might want to keep track of uint8_t _pace{}; uint8_t _score{}; std::uint64_t _enumerated_states {}; - - auto operator<=>(const HanabiState &) const = default; }; template @@ -207,10 +231,6 @@ bool same_up_to_discard_permutation(HanabiState -std::ostream & operator<<(std::ostream &os, HanabiState hanabi_state); - -template class HanabiState<5, 3, 4>; } diff --git a/game_state.hpp b/game_state.hpp index 1f52c37..326a03d 100644 --- a/game_state.hpp +++ b/game_state.hpp @@ -1,9 +1,16 @@ #include #include #include "myassert.h" +#include "game_state.h" +#include namespace Hanabi { + std::ostream& operator<<(std::ostream& os, HanabiStateIF const& hanabi_state) { + hanabi_state.print(os); + return os; + } + Card &Card::operator++() { rank++; return *this; @@ -111,6 +118,11 @@ namespace Hanabi { return card.rank == _stacks[card.suit] - 1; } + template + std::uint64_t HanabiState::enumerated_states() const { + return _enumerated_states; + } + template bool HanabiState::is_trash(const Hanabi::Card &card) const { return card.rank >= _stacks[card.suit]; @@ -167,26 +179,25 @@ namespace Hanabi { } template - std::ostream &operator<<(std::ostream &os, const HanabiState hanabi_state) { - os << "Stacks: " << hanabi_state._stacks << " (score " << +hanabi_state._score << ")"; - os << ", clues: " << +hanabi_state._num_clues << ", turn: " << +hanabi_state._turn << std::endl; + void HanabiState::print(std::ostream &os) const { + os << "Stacks: " << _stacks << " (score " << +_score << ")"; + os << ", clues: " << +_num_clues << ", turn: " << +_turn << std::endl; os << "Draw pile: "; - for (const auto &[card, mul]: hanabi_state._draw_pile) { + for (const auto &[card, mul]: _draw_pile) { os << card; if (mul > 1) { os << " (" << +mul << ")"; } os << ", "; } - os << "(size " << +hanabi_state._weighted_draw_pile_size << ")" << std::endl; + os << "(size " << +_weighted_draw_pile_size << ")" << std::endl; os << "Hands: "; - for (const auto &hand: hanabi_state._hands) { + for (const auto &hand: _hands) { for (const auto &card: hand) { os << card << ", "; } os << " | "; } - return os; } template diff --git a/main.cpp b/main.cpp index fd143ca..5fe8214 100644 --- a/main.cpp +++ b/main.cpp @@ -15,37 +15,13 @@ namespace Hanabi { -void test_game() { - HanabiState<2, 2, 5> state; - state._stacks[0] = 2; - state._stacks[1] = 3; - Card r41 = {0, 4, 1}; - state._draw_pile.push_back({r41, 1}); - state._hands[0] = {y0, y1, y2, r0, r1}; - state._hands[1] = {r1, r1, y1, r3, r2}; -// state._card_positions[r1] = 0; - state._weighted_draw_pile_size = 1; - - auto state2 = state; - - auto a = state.play(4); - std::cout << state; - state.revert_play(a, false); - - std::cout << state << std::endl; - std::cout << state2 << std::endl; - ASSERT(state._hands == state2._hands); - ASSERT(state._draw_pile == state2._draw_pile); -// ASSERT(state._card_positions == state2._card_positions); - ASSERT(state == state2); -} - void download(int turn) { - auto game = Download::get_game<6,5,4>("996518", turn); - std::cout << "Analysing state: " << game << std::endl; - auto res = game.backtrack(1); + auto game = Download::get_game("1004116", turn); + std::cout << "Analysing state: " << *game << std::endl; + auto res = game->backtrack(1); + std::cout.precision(10); std::cout << "Probability with optimal play: " << res << std::endl; - std::cout << "Enumerated " << game._enumerated_states << " states" << std::endl; + std::cout << "Enumerated " << game->enumerated_states() << " states" << std::endl; } void print_sizes() {