add function to enumerate possible next states upon specific action

This commit is contained in:
Maximilian Keßler 2023-08-11 18:28:12 +02:00
parent 6ae9244fa6
commit a4e9560753
Signed by: max
GPG key ID: BCC5A619923C0BA5
3 changed files with 71 additions and 23 deletions

View file

@ -69,16 +69,18 @@ namespace Hanabi {
std::ostream &operator<<(std::ostream &os, const Card &card);
constexpr Card r0 = {0, 0};
constexpr Card r1 = {0, 1};
constexpr Card r2 = {0, 2};
constexpr Card r3 = {0, 3};
constexpr Card r4 = {0, 4};
constexpr Card y0 = {1, 0};
constexpr Card y1 = {1, 1};
constexpr Card y2 = {1, 2};
constexpr Card y3 = {1, 3};
constexpr Card y4 = {1, 4};
constexpr Card r0 = {0, 5};
constexpr Card r1 = {0, 4};
constexpr Card r2 = {0, 3};
constexpr Card r3 = {0, 2};
constexpr Card r4 = {0, 1};
constexpr Card r5 = {0, 0};
constexpr Card y0 = {1, 5};
constexpr Card y1 = {1, 4};
constexpr Card y2 = {1, 3};
constexpr Card y3 = {1, 2};
constexpr Card y4 = {1, 1};
constexpr Card y5 = {1, 0};
constexpr Card unknown_card = {0, 6};
/**
@ -163,6 +165,8 @@ public:
virtual void discard(hand_index_t index) = 0;
virtual void play(hand_index_t index) = 0;
virtual void rotate_next_draw(const Card& card) = 0;
virtual void revert() = 0;
[[nodiscard]] virtual hand_index_t find_card_in_hand(const Card& card) const = 0;
@ -178,6 +182,7 @@ public:
virtual probability_t evaluate_state() = 0;
virtual std::vector<std::pair<Action, std::optional<probability_t>>> get_reasonable_actions() = 0;
virtual std::vector<std::pair<CardMultiplicity,probability_t>> possible_next_states(hand_index_t index, bool play) = 0;
virtual ~HanabiStateIF() = default;
@ -197,6 +202,8 @@ public:
void discard(hand_index_t index) final;
void play(hand_index_t index) final;
void rotate_next_draw(const Card& card) final;
void revert() final;
[[nodiscard]] hand_index_t find_card_in_hand(const Card& card) const final;
@ -214,6 +221,7 @@ public:
std::optional<probability_t> lookup() const;
std::vector<std::pair<Action, std::optional<probability_t>>> get_reasonable_actions() final;
std::vector<std::pair<CardMultiplicity,probability_t>> possible_next_states(hand_index_t index, bool play) final;
auto operator<=>(const HanabiState &) const = default;
@ -276,8 +284,8 @@ private:
void update_tablebase(unsigned long id, probability_t probability);
template<bool play, class Function>
void do_for_each_potential_draw(hand_index_t index, Function f);
template<class Function>
void do_for_each_potential_draw(hand_index_t index, bool play, Function f);
void incr_turn();
void decr_turn();

View file

@ -469,6 +469,23 @@ namespace Hanabi {
}
}
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
std::vector<std::pair<CardMultiplicity, probability_t>> HanabiState<num_suits, num_players, hand_size>::possible_next_states(hand_index_t index, bool play) {
std::vector<std::pair<CardMultiplicity, probability_t>> next_states;
do_for_each_potential_draw(index, play, [this, &next_states](unsigned long multiplicity){
auto prob = lookup();
ASSERT(lookup().has_value());
// bit hacky to get drawn card here
decr_turn();
const CardMultiplicity drawn_card = {_hands[_turn], multiplicity};
incr_turn();
next_states.emplace_back(drawn_card, prob.value());
});
return next_states;
}
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
std::vector<std::pair<Action, std::optional<probability_t>>> HanabiState<num_suits, num_players, hand_size>::get_reasonable_actions() {
std::vector<std::pair<Action, std::optional<probability_t>>> reasonable_actions {};
@ -485,7 +502,7 @@ namespace Hanabi {
bool known = true;
probability_t sum_of_probabilities = 0;
do_for_each_potential_draw<true>(index, [this, &sum_of_probabilities, &known](const unsigned long multiplicity){
do_for_each_potential_draw(index, true, [this, &sum_of_probabilities, &known](const unsigned long multiplicity){
const std::optional<probability_t> prob = lookup();
if (prob.has_value()) {
sum_of_probabilities += prob.value() * multiplicity;
@ -511,7 +528,7 @@ namespace Hanabi {
bool known = true;
probability_t sum_of_probabilities = 0;
do_for_each_potential_draw<false>(index, [this, &sum_of_probabilities, &known](const unsigned long multiplicity){
do_for_each_potential_draw(index, false, [this, &sum_of_probabilities, &known](const unsigned long multiplicity){
const std::optional<probability_t> prob = lookup();
if (prob.has_value()) {
sum_of_probabilities += prob.value() * multiplicity;
@ -554,6 +571,15 @@ namespace Hanabi {
}
}
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
void HanabiState<num_suits, num_players, hand_size>::rotate_next_draw(const Card& card) {
auto card_it = std::find_if(_draw_pile.begin(), _draw_pile.end(), [&card](const CardMultiplicity& card_multiplicity){
return card_multiplicity.card.rank == card.rank and card_multiplicity.card.suit == card.suit;
});
ASSERT(card_it != _draw_pile.end());
std::swap(*card_it, _draw_pile.front());
}
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
probability_t HanabiState<num_suits, num_players, hand_size>::evaluate_state() {
ASSERT(_relative_representation.initialized);
@ -580,7 +606,7 @@ namespace Hanabi {
if(is_playable(hand[index])) {
probability_t sum_of_probabilities = 0;
do_for_each_potential_draw<true>(index, [this, &sum_of_probabilities](const unsigned long multiplicity){
do_for_each_potential_draw(index, true, [this, &sum_of_probabilities](const unsigned long multiplicity){
sum_of_probabilities += evaluate_state() * multiplicity;
});
@ -601,7 +627,7 @@ namespace Hanabi {
if (is_trash(hand[index])) {
probability_t sum_of_probabilities = 0;
do_for_each_potential_draw<false>(index, [this, &sum_of_probabilities](const unsigned long multiplicity){
do_for_each_potential_draw(index, false, [this, &sum_of_probabilities](const unsigned long multiplicity){
sum_of_probabilities += evaluate_state() * multiplicity;
});
@ -638,18 +664,18 @@ namespace Hanabi {
}
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
template<bool play, class Function>
void HanabiState<num_suits, num_players, hand_size>::do_for_each_potential_draw(hand_index_t index, Function f) {
auto do_action = [this, index](){
if constexpr (play) {
template<class Function>
void HanabiState<num_suits, num_players, hand_size>::do_for_each_potential_draw(hand_index_t index, bool play, Function f) {
auto do_action = [this, index, play](){
if (play) {
return play_and_potentially_update(index);
} else {
return discard_and_potentially_update(index);
}
};
auto revert_action = [this](){
if constexpr (play) {
auto revert_action = [this, play](){
if (play) {
revert_play();
} else {
revert_discard();

View file

@ -23,6 +23,20 @@ namespace Hanabi {
std::cout << "Probability with optimal play: " << res << std::endl;
std::cout << "Enumerated " << game->enumerated_states() << " states" << std::endl;
std::cout << "Visited " << game->position_tablebase().size() << " unique game states. " << std::endl;
game->rotate_next_draw(r1);
game->discard(0);
game->rotate_next_draw(r1);
game->play(game->find_card_in_hand(y3));
game->clue();
game->rotate_next_draw(r1);
game->play(game->find_card_in_hand(y4));
std::cout << *game << std::endl;
for (const auto &[action, probability] : game->get_reasonable_actions()) {
std::cout << action;
if(probability.has_value()) {
@ -117,7 +131,7 @@ void check_games(unsigned num_players, unsigned max_draw_pile_size, unsigned fir
int main(int argc, char *argv[]) {
#ifndef NDEBUG
// test();
test();
#endif
if(argc == 3) {
std::string game(argv[1]);