diff --git a/game_state.h b/game_state.h index 3613be9..1c7e94b 100644 --- a/game_state.h +++ b/game_state.h @@ -143,6 +143,10 @@ enum class ActionType { vote_terminate = 10, }; +class Action { + ActionType type {}; + Card card {}; +}; /** Would like to have 2 versions: * All: @@ -185,8 +189,6 @@ public: HanabiState() = default; explicit HanabiState(const std::vector& deck); - probability_t evaluate_state() final; - void clue() final; void discard(hand_index_t index) final; void play(hand_index_t index) final; @@ -203,6 +205,11 @@ public: [[nodiscard]] const std::unordered_map& position_tablebase() const final; void init_backtracking_information() final; + probability_t evaluate_state() final; + + std::optional lookup() const; + + std::vector>> get_reasonable_actions(); auto operator<=>(const HanabiState &) const = default; diff --git a/game_state.hpp b/game_state.hpp index b4f3d82..4fbcb63 100644 --- a/game_state.hpp +++ b/game_state.hpp @@ -451,6 +451,90 @@ namespace Hanabi { } } + template + std::vector>> HanabiState::get_reasonable_actions() { + std::vector>> reasonable_actions {}; + + if(_score == 5 * num_suits or _pace < 0 or _endgame_turns_left == 0) { + return reasonable_actions; + } + + const std::array& hand = _hands[_turn]; + // First, check for playable cards + for(std::uint8_t index = 0; index < hand_size; index++) { + if(is_playable(hand[index])) { + const Action action = {ActionType::play, hand[index]}; + bool known = true; + probability_t sum_of_probabilities = 0; + + do_for_each_potential_draw(index, [this, &sum_of_probabilities, &known](const unsigned long multiplicity){ + const std::optional prob = lookup(); + if (prob.has_value()) { + sum_of_probabilities += prob.value() * multiplicity; + } else { + known = false; + } + }); + + if (known) { + const unsigned long total_weight = std::max(static_cast(_weighted_draw_pile_size), 1ul); + const probability_t probability_play = sum_of_probabilities / total_weight; + reasonable_actions.emplace_back(action, probability_play); + } else { + reasonable_actions.emplace_back(action, std::nullopt); + } + } + } + + if(_pace > 0 and _num_clues < max_num_clues) { + for(std::uint8_t index = 0; index < hand_size; index++) { + if (is_trash(hand[index])) { + const Action action = {ActionType::discard, hand[index]}; + bool known = true; + probability_t sum_of_probabilities = 0; + + do_for_each_potential_draw(index, [this, &sum_of_probabilities, &known](const unsigned long multiplicity){ + const std::optional prob = lookup(); + if (prob.has_value()) { + sum_of_probabilities += prob.value() * multiplicity; + } else { + known = false; + } + }); + + if (known) { + const unsigned long total_weight = std::max(static_cast(_weighted_draw_pile_size), 1ul); + const probability_t probability_discard = sum_of_probabilities / total_weight; + reasonable_actions.emplace_back(action, probability_discard); + } else { + reasonable_actions.emplace_back(action, std::nullopt); + } + + // All discards are equivalent, do not continue searching for different trash + break; + } + } + } + + if(_num_clues > 0) { + clue(); + const std::optional prob = lookup(); + const Action action = {ActionType::clue, unknown_card}; + reasonable_actions.emplace_back(action, prob); + const probability_t probability_stall = evaluate_state(); + revert_clue(); + } + } + + template + std::optional HanabiState::lookup() const { + const auto id = unique_id(); + if(_position_tablebase.contains(id)) { + return _position_tablebase[id]; + } else { + return std::nullopt; + } + } template probability_t HanabiState::evaluate_state() {