add virtual interface for all hanabi states

This commit is contained in:
Maximilian Keßler 2023-08-06 22:06:58 +02:00
parent 87ff267f80
commit 385ba6650b
Signed by: max
GPG key ID: BCC5A619923C0BA5
4 changed files with 160 additions and 82 deletions

View file

@ -69,7 +69,7 @@ namespace Download {
return action; return action;
} }
std::vector<Hanabi::Card> parse_deck(const boost::json::value &deck_json) { std::pair<std::vector<Hanabi::Card>, Hanabi::rank_t> parse_deck(const boost::json::value &deck_json) {
auto deck = boost::json::value_to<std::vector<Hanabi::Card>>(deck_json); auto deck = boost::json::value_to<std::vector<Hanabi::Card>>(deck_json);
for (auto &card: deck) { for (auto &card: deck) {
ASSERT(card.rank < 5); ASSERT(card.rank < 5);
@ -77,7 +77,11 @@ namespace Download {
ASSERT(card.suit < 6); ASSERT(card.suit < 6);
ASSERT(card.suit >= 0); ASSERT(card.suit >= 0);
} }
return deck; Hanabi::rank_t num_suits = 0;
for(const auto& card: deck) {
num_suits = std::max(num_suits, card.suit);
}
return {deck, num_suits + 1};
} }
std::vector<Action> parse_actions(const boost::json::value &action_json) { std::vector<Action> parse_actions(const boost::json::value &action_json) {
@ -104,39 +108,39 @@ namespace Download {
} }
template<std::size_t num_suits, Hanabi::player_t num_players, std::size_t hand_size> template<std::size_t num_suits, Hanabi::player_t num_players, std::size_t hand_size>
Hanabi::HanabiState<num_suits, num_players, hand_size> produce_state( std::unique_ptr<Hanabi::HanabiStateIF> produce_state(
const std::vector<Hanabi::Card>& deck, const std::vector<Hanabi::Card>& deck,
const std::vector<Action>& actions, const std::vector<Action>& actions,
size_t num_turns_to_replicate size_t num_turns_to_replicate
) { ) {
Hanabi::HanabiState<num_suits, num_players, hand_size> game(deck); auto game = std::unique_ptr<Hanabi::HanabiStateIF>(new Hanabi::HanabiState<num_suits, num_players, hand_size>(deck));
std::uint8_t index; std::uint8_t index;
for (size_t i = 0; i < num_turns_to_replicate; i++) { for (size_t i = 0; i < num_turns_to_replicate; i++) {
switch(actions[i].type) { switch(actions[i].type) {
case Hanabi::ActionType::color_clue: case Hanabi::ActionType::color_clue:
case Hanabi::ActionType::rank_clue: case Hanabi::ActionType::rank_clue:
game.clue(); game->clue();
break; break;
case Hanabi::ActionType::discard: case Hanabi::ActionType::discard:
index = game.find_card_in_hand(deck[actions[i].target]); index = game->find_card_in_hand(deck[actions[i].target]);
ASSERT(index != std::uint8_t(-1)); ASSERT(index != std::uint8_t(-1));
game.discard(index); game->discard(index);
break; break;
case Hanabi::ActionType::play: case Hanabi::ActionType::play:
index = game.find_card_in_hand(deck[actions[i].target]); index = game->find_card_in_hand(deck[actions[i].target]);
ASSERT(index != std::uint8_t(-1)); ASSERT(index != std::uint8_t(-1));
game.play(index); game->play(index);
break; break;
case Hanabi::ActionType::vote_terminate: case Hanabi::ActionType::vote_terminate:
case Hanabi::ActionType::end_game: case Hanabi::ActionType::end_game:
return game; return game;
} }
} }
game->normalize_draw_and_positions();
return game; return game;
} }
template <std::size_t num_suits, Hanabi::player_t num_players, std::size_t hand_size> std::unique_ptr<Hanabi::HanabiStateIF> get_game(std::variant<int, const char*> game_spec, unsigned turn) {
Hanabi::HanabiState<num_suits, num_players, hand_size> get_game(std::variant<int, const char *> game_spec, unsigned turn) {
const boost::json::object game_json = [&game_spec]() { const boost::json::object game_json = [&game_spec]() {
if (game_spec.index() == 0) { if (game_spec.index() == 0) {
return download_game_json(std::get<int>(game_spec)); return download_game_json(std::get<int>(game_spec));
@ -144,15 +148,82 @@ namespace Download {
return open_game_json(std::get<const char *>(game_spec)); return open_game_json(std::get<const char *>(game_spec));
} }
}(); }();
const std::vector<Hanabi::Card> deck = parse_deck(game_json.at("deck"));
const std::vector<Action> actions = parse_actions(game_json.at("actions"));
const size_t num_players_js = game_json.at("players").as_array().size();
ASSERT (num_players_js == num_players);
auto game = produce_state<num_suits, num_players, hand_size>(deck, actions, turn); const auto [deck, num_suits] = parse_deck(game_json.at("deck"));
game.normalize_draw_and_positions(); const std::vector<Action> actions = parse_actions(game_json.at("actions"));
return game; const size_t num_players = game_json.at("players").as_array().size();
switch(num_players) {
case 2:
switch(num_suits) {
case 3:
return produce_state<3,2,5>(deck, actions, turn);
case 4:
return produce_state<4,2,5>(deck, actions, turn);
case 5:
return produce_state<5,2,5>(deck, actions, turn);
case 6:
return produce_state<6,2,5>(deck, actions, turn);
default:
throw std::runtime_error("Invalid number of suits: " + std::to_string(num_suits));
} }
case 3:
switch(num_suits) {
case 3:
return produce_state<3,3,5>(deck, actions, turn);
case 4:
return produce_state<4,3,5>(deck, actions, turn);
case 5:
return produce_state<5,3,5>(deck, actions, turn);
case 6:
return produce_state<6,3,5>(deck, actions, turn);
default:
throw std::runtime_error("Invalid number of suits: " + std::to_string(num_suits));
}
case 4:
switch(num_suits) {
case 3:
return produce_state<3,4,4>(deck, actions, turn);
case 4:
return produce_state<4,4,4>(deck, actions, turn);
case 5:
return produce_state<5,4,4>(deck, actions, turn);
case 6:
return produce_state<6,4,4>(deck, actions, turn);
default:
throw std::runtime_error("Invalid number of suits: " + std::to_string(num_suits));
}
case 5:
switch(num_suits) {
case 3:
return produce_state<3,5,4>(deck, actions, turn);
case 4:
return produce_state<4,5,4>(deck, actions, turn);
case 5:
return produce_state<5,5,4>(deck, actions, turn);
case 6:
return produce_state<6,5,4>(deck, actions, turn);
default:
throw std::runtime_error("Invalid number of suits: " + std::to_string(num_suits));
}
case 6:
switch(num_suits) {
case 3:
return produce_state<3,6,3>(deck, actions, turn);
case 4:
return produce_state<4,6,3>(deck, actions, turn);
case 5:
return produce_state<5,6,3>(deck, actions, turn);
case 6:
return produce_state<6,6,3>(deck, actions, turn);
default:
throw std::runtime_error("Invalid number of suits: " + std::to_string(num_suits));
}
default:
throw std::runtime_error("Invalid number of players: " + std::to_string(num_players));
}
}
} // namespacen Download } // namespacen Download

View file

@ -139,61 +139,85 @@ struct BacktrackAction {
std::uint8_t multiplicity{}; std::uint8_t multiplicity{};
}; };
class HanabiStateIF {
public:
virtual double backtrack(size_t depth) = 0;
virtual void clue() = 0;
virtual BacktrackAction discard(std::uint8_t index) = 0;
virtual BacktrackAction play(std::uint8_t index) = 0;
virtual void revert_clue() = 0;
virtual void revert_play(const BacktrackAction &action, bool was_on_8_clues) = 0;
virtual void revert_discard(const BacktrackAction &action) = 0;
[[nodiscard]] virtual std::uint8_t find_card_in_hand(const Card& card) const = 0;
[[nodiscard]] virtual bool is_trash(const Card& card) const = 0;
[[nodiscard]] virtual bool is_playable(const Card& card) const = 0;
[[nodiscard]] virtual std::uint64_t enumerated_states() const = 0;
virtual void normalize_draw_and_positions() = 0;
virtual ~HanabiStateIF() = default;
protected:
virtual void print(std::ostream& os) const = 0;
friend std::ostream& operator<<(std::ostream&, HanabiStateIF const&);
};
template <std::size_t num_suits, player_t num_players, std::size_t hand_size> template <std::size_t num_suits, player_t num_players, std::size_t hand_size>
class HanabiState { class HanabiState : public HanabiStateIF {
public: public:
HanabiState() = default; HanabiState() = default;
explicit HanabiState(const std::vector<Card>& deck); explicit HanabiState(const std::vector<Card>& deck);
double backtrack(size_t depth); double backtrack(size_t depth) final;
void clue(); void clue() final;
BacktrackAction play(std::uint8_t index) final;
BacktrackAction discard(std::uint8_t index) final;
/** void revert_clue() final;
* Plays a card from current hand, drawing top card of draw pile and rotating draw pile void revert_play(const BacktrackAction &action, bool was_on_8_clues) final;
* @param index of card in hand to be played void revert_discard(const BacktrackAction &action) final;
*/
BacktrackAction play(std::uint8_t index);
BacktrackAction discard(std::uint8_t index); [[nodiscard]] std::uint8_t find_card_in_hand(const Card& card) const final;
[[nodiscard]] bool is_trash(const Card& card) const final;
[[nodiscard]] bool is_playable(const Card& card) const final;
std::uint8_t find_card_in_hand(const Card& card) const; [[nodiscard]] std::uint64_t enumerated_states() const final;
void normalize_draw_and_positions(); void normalize_draw_and_positions() final;
void revert_clue(); auto operator<=>(const HanabiState &) const = default;
void revert_play(const BacktrackAction &action, bool was_on_8_clues);
void revert_discard(const BacktrackAction &action);
protected:
void print(std::ostream& os) const final;
private:
uint8_t draw(uint8_t index); uint8_t draw(uint8_t index);
void revert_draw(std::uint8_t index, Card discarded_card); void revert_draw(std::uint8_t index, Card discarded_card);
void incr_turn(); void incr_turn();
void decr_turn(); void decr_turn();
bool is_trash(const Card& card) const;
bool is_playable(const Card& card) const;
player_t _turn{}; player_t _turn{};
clue_t _num_clues{}; clue_t _num_clues{};
std::uint8_t _weighted_draw_pile_size{}; std::uint8_t _weighted_draw_pile_size{};
Stacks<num_suits> _stacks{}; Stacks<num_suits> _stacks{};
std::array<std::array<Card, hand_size>, num_players> _hands{}; std::array<std::array<Card, hand_size>, num_players> _hands{};
// CardArray<num_suits, player_t> _card_positions{};
std::list<CardMultiplicity> _draw_pile{}; std::list<CardMultiplicity> _draw_pile{};
std::uint8_t _endgame_turns_left; std::uint8_t _endgame_turns_left{};
static constexpr uint8_t no_endgame = std::numeric_limits<uint8_t>::max() - 1; static constexpr uint8_t no_endgame = std::numeric_limits<uint8_t>::max();
// further statistics that we might want to keep track of // further statistics that we might want to keep track of
uint8_t _pace{}; uint8_t _pace{};
uint8_t _score{}; uint8_t _score{};
std::uint64_t _enumerated_states {}; std::uint64_t _enumerated_states {};
auto operator<=>(const HanabiState &) const = default;
}; };
template <std::size_t num_suits, player_t num_players, std::size_t hand_size> template <std::size_t num_suits, player_t num_players, std::size_t hand_size>
@ -207,10 +231,6 @@ bool same_up_to_discard_permutation(HanabiState<num_suits, num_players, hand_siz
return state1 == state2; return state1 == state2;
} }
template <std::size_t num_suits, player_t num_players, std::size_t hand_size>
std::ostream & operator<<(std::ostream &os, HanabiState<num_suits, num_players, hand_size> hanabi_state);
template class HanabiState<5, 3, 4>;
} }

View file

@ -1,9 +1,16 @@
#include <algorithm> #include <algorithm>
#include <iterator> #include <iterator>
#include "myassert.h" #include "myassert.h"
#include "game_state.h"
#include <vector>
namespace Hanabi { namespace Hanabi {
std::ostream& operator<<(std::ostream& os, HanabiStateIF const& hanabi_state) {
hanabi_state.print(os);
return os;
}
Card &Card::operator++() { Card &Card::operator++() {
rank++; rank++;
return *this; return *this;
@ -111,6 +118,11 @@ namespace Hanabi {
return card.rank == _stacks[card.suit] - 1; return card.rank == _stacks[card.suit] - 1;
} }
template<size_t num_suits, player_t num_players, size_t hand_size>
std::uint64_t HanabiState<num_suits, num_players, hand_size>::enumerated_states() const {
return _enumerated_states;
}
template<size_t num_suits, player_t num_players, size_t hand_size> template<size_t num_suits, player_t num_players, size_t hand_size>
bool HanabiState<num_suits, num_players, hand_size>::is_trash(const Hanabi::Card &card) const { bool HanabiState<num_suits, num_players, hand_size>::is_trash(const Hanabi::Card &card) const {
return card.rank >= _stacks[card.suit]; return card.rank >= _stacks[card.suit];
@ -167,26 +179,25 @@ namespace Hanabi {
} }
template<std::size_t num_suits, player_t num_players, std::size_t hand_size> template<std::size_t num_suits, player_t num_players, std::size_t hand_size>
std::ostream &operator<<(std::ostream &os, const HanabiState<num_suits, num_players, hand_size> hanabi_state) { void HanabiState<num_suits, num_players, hand_size>::print(std::ostream &os) const {
os << "Stacks: " << hanabi_state._stacks << " (score " << +hanabi_state._score << ")"; os << "Stacks: " << _stacks << " (score " << +_score << ")";
os << ", clues: " << +hanabi_state._num_clues << ", turn: " << +hanabi_state._turn << std::endl; os << ", clues: " << +_num_clues << ", turn: " << +_turn << std::endl;
os << "Draw pile: "; os << "Draw pile: ";
for (const auto &[card, mul]: hanabi_state._draw_pile) { for (const auto &[card, mul]: _draw_pile) {
os << card; os << card;
if (mul > 1) { if (mul > 1) {
os << " (" << +mul << ")"; os << " (" << +mul << ")";
} }
os << ", "; os << ", ";
} }
os << "(size " << +hanabi_state._weighted_draw_pile_size << ")" << std::endl; os << "(size " << +_weighted_draw_pile_size << ")" << std::endl;
os << "Hands: "; os << "Hands: ";
for (const auto &hand: hanabi_state._hands) { for (const auto &hand: _hands) {
for (const auto &card: hand) { for (const auto &card: hand) {
os << card << ", "; os << card << ", ";
} }
os << " | "; os << " | ";
} }
return os;
} }
template<std::size_t num_suits, player_t num_players, std::size_t hand_size> template<std::size_t num_suits, player_t num_players, std::size_t hand_size>

View file

@ -15,37 +15,13 @@
namespace Hanabi { namespace Hanabi {
void test_game() {
HanabiState<2, 2, 5> state;
state._stacks[0] = 2;
state._stacks[1] = 3;
Card r41 = {0, 4, 1};
state._draw_pile.push_back({r41, 1});
state._hands[0] = {y0, y1, y2, r0, r1};
state._hands[1] = {r1, r1, y1, r3, r2};
// state._card_positions[r1] = 0;
state._weighted_draw_pile_size = 1;
auto state2 = state;
auto a = state.play(4);
std::cout << state;
state.revert_play(a, false);
std::cout << state << std::endl;
std::cout << state2 << std::endl;
ASSERT(state._hands == state2._hands);
ASSERT(state._draw_pile == state2._draw_pile);
// ASSERT(state._card_positions == state2._card_positions);
ASSERT(state == state2);
}
void download(int turn) { void download(int turn) {
auto game = Download::get_game<6,5,4>("996518", turn); auto game = Download::get_game("1004116", turn);
std::cout << "Analysing state: " << game << std::endl; std::cout << "Analysing state: " << *game << std::endl;
auto res = game.backtrack(1); auto res = game->backtrack(1);
std::cout.precision(10);
std::cout << "Probability with optimal play: " << res << std::endl; std::cout << "Probability with optimal play: " << res << std::endl;
std::cout << "Enumerated " << game._enumerated_states << " states" << std::endl; std::cout << "Enumerated " << game->enumerated_states() << " states" << std::endl;
} }
void print_sizes() { void print_sizes() {