Review notes (text ahead of the first "diff --git" line is ignored by git apply and patch):
  * Line structure and indentation were rebuilt from a whitespace-collapsed copy of this patch;
    blank-line placement is inferred from the hunk counts. Apply with whitespace-insensitive
    matching and re-check the context lines against the tree.
  * Template argument lists ("<...>") are missing from that copy. They are restored on added
    lines only where the surrounding code determines them. The class-level lists on bare
    `template` / `HanabiState::` and everything on context or removed lines are left as found.
  * Changes made to the added code in game_state.hpp: sum_of_multiplicities is zero-initialised,
    the play/discard lambdas pass depth + 1 like the code they replace, a duplicated std::max
    statement is dropped, and stray ';' after if-blocks are removed. The post-image blob id on
    the game_state.hpp index line is therefore stale.

diff --git a/game_state.h b/game_state.h
index 0fc759f..1186547 100644
--- a/game_state.h
+++ b/game_state.h
@@ -235,6 +235,11 @@ private:
     template
     unsigned long draw(hand_index_t index);
     void revert_draw(hand_index_t index, Card discarded_card);
+    void update_tablebase(unsigned long id, probability_t probability);
+
+    template<bool play, class Function>
+    void do_for_each_potential_draw(hand_index_t index, Function f);
+
     void incr_turn();
     void decr_turn();
 
diff --git a/game_state.hpp b/game_state.hpp
index bbc92ba..eec2c92 100644
--- a/game_state.hpp
+++ b/game_state.hpp
@@ -280,7 +280,7 @@ namespace Hanabi {
       }
       return draw.multiplicity;
     }
-    return 0;
+    return 1;
   }
 
   template
@@ -435,19 +435,6 @@ namespace Hanabi {
     _num_clues++;
   }
 
-  #define RETURN_PROBABILITY \
-  if (_position_tablebase.contains(id_of_state)) { \
-    ASSERT(_position_tablebase[id_of_state] == best_probability); \
-  } \
-  _position_tablebase[id_of_state] = best_probability; \
-  \
-  return best_probability;
-
-  #define UPDATE_PROBABILITY(new_probability) \
-  best_probability = std::max(best_probability, new_probability); \
-  if (best_probability == 1) { \
-    RETURN_PROBABILITY; \
-  }
 
   template
   probability_t HanabiState::backtrack(size_t depth) {
@@ -466,32 +453,26 @@ namespace Hanabi {
 
     // TODO: Have some endgame analysis here?
 
-    // First, check if we have any playable cards
     probability_t best_probability = 0;
-    const std::array hand = _hands[_turn];
+    const auto& hand = _hands[_turn];
 
     // First, check for playables
     for(std::uint8_t index = 0; index < hand_size; index++) {
       if(is_playable(hand[index])) {
-        if (_draw_pile.empty()) {
-          play_and_potentially_update(index);
-          const probability_t probability_for_this_play = backtrack(depth + 1);
-          revert_play();
-          UPDATE_PROBABILITY(probability_for_this_play);
-        } else {
-          probability_t sum_of_probabilities = 0;
-          uint8_t sum_of_mults = 0;
-          for (size_t i = 0; i < _draw_pile.size(); i++) {
-            const unsigned long multiplicity = play_and_potentially_update(index);
-            sum_of_probabilities += backtrack(depth + 1) * multiplicity;
-            sum_of_mults += multiplicity;
-            revert_play();
-            ASSERT(sum_of_mults <= _weighted_draw_pile_size);
-          }
-          ASSERT(sum_of_mults == _weighted_draw_pile_size);
-          const probability_t probability_for_this_play = sum_of_probabilities / _weighted_draw_pile_size;
-          UPDATE_PROBABILITY(probability_for_this_play);
-        }
+        probability_t sum_of_probabilities = 0;
+
+        do_for_each_potential_draw<true>(index, [this, depth, &sum_of_probabilities](const unsigned long multiplicity) {
+          sum_of_probabilities += backtrack(depth + 1) * multiplicity;
+        });
+
+        const unsigned long total_weight = std::max(static_cast<unsigned long>(_weighted_draw_pile_size), 1ul);
+        const probability_t probability_play = sum_of_probabilities / total_weight;
+
+        best_probability = std::max(best_probability, probability_play);
+        if (best_probability == 1) {
+          update_tablebase(id_of_state, best_probability);
+          return best_probability;
+        }
       }
     }
 
@@ -500,23 +481,19 @@ namespace Hanabi {
     for(std::uint8_t index = 0; index < hand_size; index++) {
       if (is_trash(hand[index])) {
         probability_t sum_of_probabilities = 0;
-        if (_draw_pile.empty()) {
-          discard_and_potentially_update(index);
-          const probability_t probability_for_this_discard = backtrack(depth + 1);
-          revert_discard();
-          UPDATE_PROBABILITY(probability_for_this_discard);
-        } else {
-          uint8_t sum_of_mults = 0;
-          for (size_t i = 0; i < _draw_pile.size(); i++) {
-            const unsigned long multiplicity = discard_and_potentially_update(index);
-            sum_of_probabilities += backtrack(depth + 1) * multiplicity;
-            sum_of_mults += multiplicity;
-            revert_discard();
-          }
-          ASSERT(sum_of_mults == _weighted_draw_pile_size);
-          const probability_t probability_discard = sum_of_probabilities / _weighted_draw_pile_size;
-          UPDATE_PROBABILITY(probability_discard);
-        }
+
+        do_for_each_potential_draw<false>(index, [this, depth, &sum_of_probabilities](const unsigned long multiplicity) {
+          sum_of_probabilities += backtrack(depth + 1) * multiplicity;
+        });
+
+        const unsigned long total_weight = std::max(static_cast<unsigned long>(_weighted_draw_pile_size), 1ul);
+        const probability_t probability_discard = sum_of_probabilities / total_weight;
+
+        best_probability = std::max(best_probability, probability_discard);
+        if (best_probability == 1) {
+          update_tablebase(id_of_state, best_probability);
+          return best_probability;
+        }
 
         // All discards are equivalent, do not continue searching for different trash
         break;
@@ -529,10 +506,55 @@ namespace Hanabi {
       clue();
       const probability_t probability_stall = backtrack(depth + 1);
       revert_clue();
-      UPDATE_PROBABILITY(probability_stall);
+      best_probability = std::max(best_probability, probability_stall);
+      if (best_probability == 1) {
+        update_tablebase(id_of_state, best_probability);
+        return best_probability;
+      }
     }
 
-    RETURN_PROBABILITY;
+    update_tablebase(id_of_state, best_probability);
+    return best_probability;
+  }
+
+  // Performs the play (play == true) or discard (play == false) of the card at
+  // `index` once per iteration over the draw pile, calling f(multiplicity)
+  // after each action and reverting it afterwards. The multiplicity is the
+  // value returned by the action. With an empty draw pile the action is
+  // performed exactly once and f receives a multiplicity of 1.
+  template
+  template<bool play, class Function>
+  void HanabiState::do_for_each_potential_draw(hand_index_t index, Function f) {
+    auto do_action = [this, index](){
+      if constexpr (play) {
+        return play_and_potentially_update(index);
+      } else {
+        return discard_and_potentially_update(index);
+      }
+    };
+
+    auto revert_action = [this](){
+      if constexpr (play) {
+        revert_play();
+      } else {
+        revert_discard();
+      }
+    };
+
+    if(_draw_pile.empty()) {
+      do_action();
+      f(1);
+      revert_action();
+    } else {
+      unsigned sum_of_multiplicities = 0;
+      for(size_t i = 0; i < _draw_pile.size(); i++) {
+        const unsigned long multiplicity = do_action();
+        sum_of_multiplicities += multiplicity;
+        f(multiplicity);
+        revert_action();
+      }
+      ASSERT(sum_of_multiplicities == _weighted_draw_pile_size);
+    }
   }
 
   template
@@ -597,4 +619,16 @@ namespace Hanabi {
     return _weighted_draw_pile_size;
   }
 
+  // Stores `probability` for state `id` in the position tablebase. If the
+  // state already has an entry, asserts that it matches before overwriting.
+  template
+  void HanabiState::update_tablebase(
+          unsigned long id,
+          Hanabi::probability_t probability) {
+    if (_position_tablebase.contains(id)) {
+      ASSERT(_position_tablebase[id] == probability);
+    }
+    _position_tablebase[id] = probability;
+  }
+
 } // namespace Hanabi
\ No newline at end of file