add method to list reasonable moves

This commit is contained in:
Maximilian Keßler 2023-08-11 15:41:03 +02:00
parent 907fb3ae47
commit 57ebc3d478
Signed by: max
GPG key ID: BCC5A619923C0BA5
2 changed files with 93 additions and 2 deletions

View file

@ -143,6 +143,10 @@ enum class ActionType {
vote_terminate = 10,
};
class Action {
ActionType type {};
Card card {};
};
/** Would like to have 2 versions:
* All:
@ -185,8 +189,6 @@ public:
HanabiState() = default;
explicit HanabiState(const std::vector<Card>& deck);
probability_t evaluate_state() final;
void clue() final;
void discard(hand_index_t index) final;
void play(hand_index_t index) final;
@ -203,6 +205,11 @@ public:
[[nodiscard]] const std::unordered_map<unsigned long, probability_t>& position_tablebase() const final;
void init_backtracking_information() final;
probability_t evaluate_state() final;
std::optional<probability_t> lookup() const;
std::vector<std::pair<Action, std::optional<probability_t>>> get_reasonable_actions();
auto operator<=>(const HanabiState &) const = default;

View file

@ -451,6 +451,90 @@ namespace Hanabi {
}
}
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
std::vector<std::pair<Action, std::optional<probability_t>>> HanabiState<num_suits, num_players, hand_size>::get_reasonable_actions() {
std::vector<std::pair<Action, std::optional<probability_t>>> reasonable_actions {};
if(_score == 5 * num_suits or _pace < 0 or _endgame_turns_left == 0) {
return reasonable_actions;
}
const std::array<Card, hand_size>& hand = _hands[_turn];
// First, check for playable cards
for(std::uint8_t index = 0; index < hand_size; index++) {
if(is_playable(hand[index])) {
const Action action = {ActionType::play, hand[index]};
bool known = true;
probability_t sum_of_probabilities = 0;
do_for_each_potential_draw<true>(index, [this, &sum_of_probabilities, &known](const unsigned long multiplicity){
const std::optional<probability_t> prob = lookup();
if (prob.has_value()) {
sum_of_probabilities += prob.value() * multiplicity;
} else {
known = false;
}
});
if (known) {
const unsigned long total_weight = std::max(static_cast<unsigned long>(_weighted_draw_pile_size), 1ul);
const probability_t probability_play = sum_of_probabilities / total_weight;
reasonable_actions.emplace_back(action, probability_play);
} else {
reasonable_actions.emplace_back(action, std::nullopt);
}
}
}
if(_pace > 0 and _num_clues < max_num_clues) {
for(std::uint8_t index = 0; index < hand_size; index++) {
if (is_trash(hand[index])) {
const Action action = {ActionType::discard, hand[index]};
bool known = true;
probability_t sum_of_probabilities = 0;
do_for_each_potential_draw<false>(index, [this, &sum_of_probabilities, &known](const unsigned long multiplicity){
const std::optional<probability_t> prob = lookup();
if (prob.has_value()) {
sum_of_probabilities += prob.value() * multiplicity;
} else {
known = false;
}
});
if (known) {
const unsigned long total_weight = std::max(static_cast<unsigned long>(_weighted_draw_pile_size), 1ul);
const probability_t probability_discard = sum_of_probabilities / total_weight;
reasonable_actions.emplace_back(action, probability_discard);
} else {
reasonable_actions.emplace_back(action, std::nullopt);
}
// All discards are equivalent, do not continue searching for different trash
break;
}
}
}
if(_num_clues > 0) {
clue();
const std::optional<probability_t> prob = lookup();
const Action action = {ActionType::clue, unknown_card};
reasonable_actions.emplace_back(action, prob);
const probability_t probability_stall = evaluate_state();
revert_clue();
}
}
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
std::optional<probability_t> HanabiState<num_suits, num_players, hand_size>::lookup() const {
const auto id = unique_id();
if(_position_tablebase.contains(id)) {
return _position_tablebase[id];
} else {
return std::nullopt;
}
}
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
probability_t HanabiState<num_suits, num_players, hand_size>::evaluate_state() {