code cleanup

- unified function to iterate over all possible draws when discarding or
  playing
- remove macro usages
This commit is contained in:
Maximilian Keßler 2023-08-10 18:23:33 +02:00
parent 78b3335414
commit 17405b0f00
Signed by: max
GPG Key ID: BCC5A619923C0BA5
2 changed files with 87 additions and 54 deletions

View File

@ -235,6 +235,11 @@ private:
template<bool update_card_positions> unsigned long draw(hand_index_t index);
void revert_draw(hand_index_t index, Card discarded_card);
void update_tablebase(unsigned long id, probability_t probability);
template<bool play, class Function>
void do_for_each_potential_draw(hand_index_t index, Function f);
void incr_turn();
void decr_turn();

View File

@ -280,7 +280,7 @@ namespace Hanabi {
}
return draw.multiplicity;
}
return 0;
return 1;
}
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
@ -435,19 +435,6 @@ namespace Hanabi {
_num_clues++;
}
#define RETURN_PROBABILITY \
if (_position_tablebase.contains(id_of_state)) { \
ASSERT(_position_tablebase[id_of_state] == best_probability); \
} \
_position_tablebase[id_of_state] = best_probability; \
\
return best_probability;
#define UPDATE_PROBABILITY(new_probability) \
best_probability = std::max(best_probability, new_probability); \
if (best_probability == 1) { \
RETURN_PROBABILITY; \
}
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
probability_t HanabiState<num_suits, num_players, hand_size>::backtrack(size_t depth) {
@ -466,32 +453,26 @@ namespace Hanabi {
// TODO: Have some endgame analysis here?
// First, check if we have any playable cards
probability_t best_probability = 0;
const std::array<Card, hand_size> hand = _hands[_turn];
const std::array<Card, hand_size>& hand = _hands[_turn];
// First, check for playables
for(std::uint8_t index = 0; index < hand_size; index++) {
if(is_playable(hand[index])) {
if (_draw_pile.empty()) {
play_and_potentially_update<true>(index);
const probability_t probability_for_this_play = backtrack(depth + 1);
revert_play();
UPDATE_PROBABILITY(probability_for_this_play);
} else {
probability_t sum_of_probabilities = 0;
uint8_t sum_of_mults = 0;
for (size_t i = 0; i < _draw_pile.size(); i++) {
const unsigned long multiplicity = play_and_potentially_update<true>(index);
sum_of_probabilities += backtrack(depth + 1) * multiplicity;
sum_of_mults += multiplicity;
revert_play();
ASSERT(sum_of_mults <= _weighted_draw_pile_size);
}
ASSERT(sum_of_mults == _weighted_draw_pile_size);
const probability_t probability_for_this_play = sum_of_probabilities / _weighted_draw_pile_size;
UPDATE_PROBABILITY(probability_for_this_play);
}
probability_t sum_of_probabilities = 0;
do_for_each_potential_draw<true>(index, [this, &sum_of_probabilities](const unsigned long multiplicity){
sum_of_probabilities += backtrack(0) * multiplicity;
});
const unsigned long total_weight = std::max(static_cast<unsigned long>(_weighted_draw_pile_size), 1ul);
const probability_t probability_play = sum_of_probabilities / total_weight;
best_probability = std::max(best_probability, probability_play);
if (best_probability == 1) {
update_tablebase(id_of_state, best_probability);
return best_probability;
};
}
}
@ -500,23 +481,20 @@ namespace Hanabi {
for(std::uint8_t index = 0; index < hand_size; index++) {
if (is_trash(hand[index])) {
probability_t sum_of_probabilities = 0;
if (_draw_pile.empty()) {
discard_and_potentially_update<true>(index);
const probability_t probability_for_this_discard = backtrack(depth + 1);
revert_discard();
UPDATE_PROBABILITY(probability_for_this_discard);
} else {
uint8_t sum_of_mults = 0;
for (size_t i = 0; i < _draw_pile.size(); i++) {
const unsigned long multiplicity = discard_and_potentially_update<true>(index);
sum_of_probabilities += backtrack(depth + 1) * multiplicity;
sum_of_mults += multiplicity;
revert_discard();
}
ASSERT(sum_of_mults == _weighted_draw_pile_size);
const probability_t probability_discard = sum_of_probabilities / _weighted_draw_pile_size;
UPDATE_PROBABILITY(probability_discard);
}
do_for_each_potential_draw<false>(index, [this, &sum_of_probabilities](const unsigned long multiplicity){
sum_of_probabilities += backtrack(0) * multiplicity;
});
const unsigned long total_weight = std::max(static_cast<unsigned long>(_weighted_draw_pile_size), 1ul);
const probability_t probability_discard = sum_of_probabilities / total_weight;
best_probability = std::max(best_probability, probability_discard);
best_probability = std::max(best_probability, probability_discard);
if (best_probability == 1) {
update_tablebase(id_of_state, best_probability);
return best_probability;
};
// All discards are equivalent, do not continue searching for different trash
break;
@ -529,10 +507,50 @@ namespace Hanabi {
clue();
const probability_t probability_stall = backtrack(depth + 1);
revert_clue();
UPDATE_PROBABILITY(probability_stall);
best_probability = std::max(best_probability, probability_stall);
if (best_probability == 1) {
update_tablebase(id_of_state, best_probability);
return best_probability;
};
}
RETURN_PROBABILITY;
update_tablebase(id_of_state, best_probability);
return best_probability;
}
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
template<bool play, class Function>
void HanabiState<num_suits, num_players, hand_size>::do_for_each_potential_draw(hand_index_t index, Function f) {
auto do_action = [this, index](){
if constexpr (play) {
return play_and_potentially_update<true>(index);
} else {
return discard_and_potentially_update<true>(index);
}
};
auto revert_action = [this](){
if constexpr (play) {
revert_play();
} else {
revert_discard();
}
};
if(_draw_pile.empty()) {
do_action();
f(1);
revert_action();
} else {
unsigned sum_of_multiplicities;
for(size_t i = 0; i < _draw_pile.size(); i++) {
const unsigned long multiplicity = do_action();
sum_of_multiplicities += multiplicity;
f(multiplicity);
revert_action();
}
ASSERT(sum_of_multiplicities == _weighted_draw_pile_size);
}
}
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
@ -597,4 +615,14 @@ namespace Hanabi {
return _weighted_draw_pile_size;
}
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
void HanabiState<num_suits, num_players, hand_size>::update_tablebase(
unsigned long id,
Hanabi::probability_t probability) {
if (_position_tablebase.contains(id)) {
ASSERT(_position_tablebase[id] == probability);
}
_position_tablebase[id] = probability;
}
} // namespace Hanabi