code cleanup

- unified function to iterate over all possible draws when discarding or
  playing
- remove macro usages
This commit is contained in:
Maximilian Keßler 2023-08-10 18:23:33 +02:00
parent 78b3335414
commit 17405b0f00
Signed by: max
GPG key ID: BCC5A619923C0BA5
2 changed files with 87 additions and 54 deletions

View file

@ -235,6 +235,11 @@ private:
template<bool update_card_positions> unsigned long draw(hand_index_t index); template<bool update_card_positions> unsigned long draw(hand_index_t index);
void revert_draw(hand_index_t index, Card discarded_card); void revert_draw(hand_index_t index, Card discarded_card);
void update_tablebase(unsigned long id, probability_t probability);
template<bool play, class Function>
void do_for_each_potential_draw(hand_index_t index, Function f);
void incr_turn(); void incr_turn();
void decr_turn(); void decr_turn();

View file

@ -280,7 +280,7 @@ namespace Hanabi {
} }
return draw.multiplicity; return draw.multiplicity;
} }
return 0; return 1;
} }
template<suit_t num_suits, player_t num_players, hand_index_t hand_size> template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
@ -435,19 +435,6 @@ namespace Hanabi {
_num_clues++; _num_clues++;
} }
#define RETURN_PROBABILITY \
if (_position_tablebase.contains(id_of_state)) { \
ASSERT(_position_tablebase[id_of_state] == best_probability); \
} \
_position_tablebase[id_of_state] = best_probability; \
\
return best_probability;
#define UPDATE_PROBABILITY(new_probability) \
best_probability = std::max(best_probability, new_probability); \
if (best_probability == 1) { \
RETURN_PROBABILITY; \
}
template<suit_t num_suits, player_t num_players, hand_index_t hand_size> template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
probability_t HanabiState<num_suits, num_players, hand_size>::backtrack(size_t depth) { probability_t HanabiState<num_suits, num_players, hand_size>::backtrack(size_t depth) {
@ -466,32 +453,26 @@ namespace Hanabi {
// TODO: Have some endgame analysis here? // TODO: Have some endgame analysis here?
// First, check if we have any playable cards
probability_t best_probability = 0; probability_t best_probability = 0;
const std::array<Card, hand_size> hand = _hands[_turn]; const std::array<Card, hand_size>& hand = _hands[_turn];
// First, check for playables // First, check for playables
for(std::uint8_t index = 0; index < hand_size; index++) { for(std::uint8_t index = 0; index < hand_size; index++) {
if(is_playable(hand[index])) { if(is_playable(hand[index])) {
if (_draw_pile.empty()) {
play_and_potentially_update<true>(index);
const probability_t probability_for_this_play = backtrack(depth + 1);
revert_play();
UPDATE_PROBABILITY(probability_for_this_play);
} else {
probability_t sum_of_probabilities = 0; probability_t sum_of_probabilities = 0;
uint8_t sum_of_mults = 0;
for (size_t i = 0; i < _draw_pile.size(); i++) { do_for_each_potential_draw<true>(index, [this, &sum_of_probabilities](const unsigned long multiplicity){
const unsigned long multiplicity = play_and_potentially_update<true>(index); sum_of_probabilities += backtrack(0) * multiplicity;
sum_of_probabilities += backtrack(depth + 1) * multiplicity; });
sum_of_mults += multiplicity;
revert_play(); const unsigned long total_weight = std::max(static_cast<unsigned long>(_weighted_draw_pile_size), 1ul);
ASSERT(sum_of_mults <= _weighted_draw_pile_size); const probability_t probability_play = sum_of_probabilities / total_weight;
}
ASSERT(sum_of_mults == _weighted_draw_pile_size); best_probability = std::max(best_probability, probability_play);
const probability_t probability_for_this_play = sum_of_probabilities / _weighted_draw_pile_size; if (best_probability == 1) {
UPDATE_PROBABILITY(probability_for_this_play); update_tablebase(id_of_state, best_probability);
} return best_probability;
};
} }
} }
@ -500,23 +481,20 @@ namespace Hanabi {
for(std::uint8_t index = 0; index < hand_size; index++) { for(std::uint8_t index = 0; index < hand_size; index++) {
if (is_trash(hand[index])) { if (is_trash(hand[index])) {
probability_t sum_of_probabilities = 0; probability_t sum_of_probabilities = 0;
if (_draw_pile.empty()) {
discard_and_potentially_update<true>(index); do_for_each_potential_draw<false>(index, [this, &sum_of_probabilities](const unsigned long multiplicity){
const probability_t probability_for_this_discard = backtrack(depth + 1); sum_of_probabilities += backtrack(0) * multiplicity;
revert_discard(); });
UPDATE_PROBABILITY(probability_for_this_discard);
} else { const unsigned long total_weight = std::max(static_cast<unsigned long>(_weighted_draw_pile_size), 1ul);
uint8_t sum_of_mults = 0; const probability_t probability_discard = sum_of_probabilities / total_weight;
for (size_t i = 0; i < _draw_pile.size(); i++) { best_probability = std::max(best_probability, probability_discard);
const unsigned long multiplicity = discard_and_potentially_update<true>(index);
sum_of_probabilities += backtrack(depth + 1) * multiplicity; best_probability = std::max(best_probability, probability_discard);
sum_of_mults += multiplicity; if (best_probability == 1) {
revert_discard(); update_tablebase(id_of_state, best_probability);
} return best_probability;
ASSERT(sum_of_mults == _weighted_draw_pile_size); };
const probability_t probability_discard = sum_of_probabilities / _weighted_draw_pile_size;
UPDATE_PROBABILITY(probability_discard);
}
// All discards are equivalent, do not continue searching for different trash // All discards are equivalent, do not continue searching for different trash
break; break;
@ -529,10 +507,50 @@ namespace Hanabi {
clue(); clue();
const probability_t probability_stall = backtrack(depth + 1); const probability_t probability_stall = backtrack(depth + 1);
revert_clue(); revert_clue();
UPDATE_PROBABILITY(probability_stall); best_probability = std::max(best_probability, probability_stall);
if (best_probability == 1) {
update_tablebase(id_of_state, best_probability);
return best_probability;
};
} }
RETURN_PROBABILITY; update_tablebase(id_of_state, best_probability);
return best_probability;
}
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
template<bool play, class Function>
void HanabiState<num_suits, num_players, hand_size>::do_for_each_potential_draw(hand_index_t index, Function f) {
auto do_action = [this, index](){
if constexpr (play) {
return play_and_potentially_update<true>(index);
} else {
return discard_and_potentially_update<true>(index);
}
};
auto revert_action = [this](){
if constexpr (play) {
revert_play();
} else {
revert_discard();
}
};
if(_draw_pile.empty()) {
do_action();
f(1);
revert_action();
} else {
unsigned sum_of_multiplicities;
for(size_t i = 0; i < _draw_pile.size(); i++) {
const unsigned long multiplicity = do_action();
sum_of_multiplicities += multiplicity;
f(multiplicity);
revert_action();
}
ASSERT(sum_of_multiplicities == _weighted_draw_pile_size);
}
} }
template<suit_t num_suits, player_t num_players, hand_index_t hand_size> template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
@ -597,4 +615,14 @@ namespace Hanabi {
return _weighted_draw_pile_size; return _weighted_draw_pile_size;
} }
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
void HanabiState<num_suits, num_players, hand_size>::update_tablebase(
unsigned long id,
Hanabi::probability_t probability) {
if (_position_tablebase.contains(id)) {
ASSERT(_position_tablebase[id] == probability);
}
_position_tablebase[id] = probability;
}
} // namespace Hanabi } // namespace Hanabi