simplify backtrack interface by storing actions in stack

This commit is contained in:
Maximilian Keßler 2023-08-10 12:06:13 +02:00
parent 24aa016d36
commit a4ee7ace1d
Signed by: max
GPG key ID: BCC5A619923C0BA5
2 changed files with 101 additions and 61 deletions

View file

@ -2,6 +2,7 @@
#define DYNAMIC_PROGRAM_GAME_STATE_H #define DYNAMIC_PROGRAM_GAME_STATE_H
#include <array> #include <array>
#include <stack>
#include <cstdint> #include <cstdint>
#include <algorithm> #include <algorithm>
#include <cstddef> #include <cstddef>
@ -80,6 +81,7 @@ constexpr Card y1 = {1, 1};
constexpr Card y2 = {1, 2}; constexpr Card y2 = {1, 2};
constexpr Card y3 = {1, 3}; constexpr Card y3 = {1, 3};
constexpr Card y4 = {1, 4}; constexpr Card y4 = {1, 4};
constexpr Card unknown_card = {0, 6};
/** /**
* To store: * To store:
@ -144,21 +146,29 @@ enum class ActionType {
}; };
struct BacktrackAction { struct BacktrackAction {
// The card that was discarded or played BacktrackAction(ActionType action_type, Card discarded_or_played, hand_index_t index);
Card discarded{};
// Index of card in hand that was discarded or played ActionType action_type{};
hand_index_t index{}; // The card that was discarded or played
// Multiplicity of new draw (needed for probability calculations) Card discarded{};
hand_index_t multiplicity{}; // Index of card in hand that was discarded or played
hand_index_t index{};
}; };
/** Would like to have 2 versions:
* All:
* - support playing cards, querying basic information
* - support going back, but with a different interface: efficient (needs arguments, does not store) or using a stack
*
*/
class HanabiStateIF { class HanabiStateIF {
public: public:
virtual probability_t backtrack(size_t depth) = 0; virtual probability_t backtrack(size_t depth) = 0;
virtual void clue() = 0; virtual void clue() = 0;
virtual BacktrackAction discard(hand_index_t index) = 0; virtual void discard(hand_index_t index) = 0;
virtual BacktrackAction play(hand_index_t index) = 0; virtual void play(hand_index_t index) = 0;
[[nodiscard]] virtual hand_index_t find_card_in_hand(const Card& card) const = 0; [[nodiscard]] virtual hand_index_t find_card_in_hand(const Card& card) const = 0;
[[nodiscard]] virtual bool is_trash(const Card& card) const = 0; [[nodiscard]] virtual bool is_trash(const Card& card) const = 0;
@ -187,12 +197,12 @@ public:
probability_t backtrack(size_t depth) final; probability_t backtrack(size_t depth) final;
void clue() final; void clue() final;
BacktrackAction play(hand_index_t index) final; void play(hand_index_t index) final;
BacktrackAction discard(hand_index_t index) final; void discard(hand_index_t index) final;
void revert_clue(); void revert_clue();
void revert_play(const BacktrackAction &action, bool was_on_8_clues); void revert_play(bool was_on_8_clues);
void revert_discard(const BacktrackAction &action); void revert_discard();
[[nodiscard]] hand_index_t find_card_in_hand(const Card& card) const final; [[nodiscard]] hand_index_t find_card_in_hand(const Card& card) const final;
[[nodiscard]] bool is_trash(const Card& card) const final; [[nodiscard]] bool is_trash(const Card& card) const final;
@ -210,10 +220,10 @@ protected:
void print(std::ostream& os) const final; void print(std::ostream& os) const final;
private: private:
template<bool update_card_positions> BacktrackAction play_and_potentially_update(hand_index_t index); template<bool update_card_positions> unsigned long play_and_potentially_update(hand_index_t index);
template<bool update_card_positions> BacktrackAction discard_and_potentially_update(hand_index_t index); template<bool update_card_positions> unsigned long discard_and_potentially_update(hand_index_t index);
template<bool update_card_positions> hand_index_t draw(hand_index_t index); template<bool update_card_positions> unsigned long draw(hand_index_t index);
void revert_draw(hand_index_t index, Card discarded_card); void revert_draw(hand_index_t index, Card discarded_card);
void incr_turn(); void incr_turn();
@ -241,8 +251,8 @@ private:
// This will indicate whether cards that were in hands initially still are in hands // This will indicate whether cards that were in hands initially still are in hands
std::bitset<num_players * hand_size> _card_positions_hands; std::bitset<num_players * hand_size> _card_positions_hands;
size_t _num_useful_cards_in_starting_hands; size_t _num_useful_cards_in_starting_hands{};
size_t _initial_draw_pile_size; size_t _initial_draw_pile_size{};
// further statistics that we might want to keep track of // further statistics that we might want to keep track of
int8_t _pace{}; int8_t _pace{};
@ -251,6 +261,8 @@ private:
std::uint64_t _enumerated_states {}; std::uint64_t _enumerated_states {};
std::unordered_map<unsigned long, probability_t> _position_tablebase; std::unordered_map<unsigned long, probability_t> _position_tablebase;
std::stack<BacktrackAction> _actions_log;
}; };
template <std::size_t num_suits, player_t num_players, std::size_t hand_size> template <std::size_t num_suits, player_t num_players, std::size_t hand_size>

View file

@ -61,6 +61,14 @@ namespace Hanabi {
return _array[card.suit][card.rank]; return _array[card.suit][card.rank];
}; };
BacktrackAction::BacktrackAction(
Hanabi::ActionType action_type, Hanabi::Card discarded_or_played, Hanabi::hand_index_t index
):
action_type(action_type),
discarded(discarded_or_played),
index(index) {
}
template<suit_t num_suits, player_t num_players, hand_index_t hand_size> template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
HanabiState<num_suits, num_players, hand_size>::HanabiState(const std::vector<Card> &deck): HanabiState<num_suits, num_players, hand_size>::HanabiState(const std::vector<Card> &deck):
_turn(0), _turn(0),
@ -95,6 +103,7 @@ namespace Hanabi {
ASSERT(_num_clues > 0); ASSERT(_num_clues > 0);
--_num_clues; --_num_clues;
_actions_log.emplace(ActionType::clue, unknown_card, 0);
incr_turn(); incr_turn();
} }
@ -130,56 +139,58 @@ namespace Hanabi {
} }
template<suit_t num_suits, player_t num_players, hand_index_t hand_size> template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
BacktrackAction HanabiState<num_suits, num_players, hand_size>::play(Hanabi::hand_index_t index) { void HanabiState<num_suits, num_players, hand_size>::play(Hanabi::hand_index_t index) {
const Card card = _hands[_turn][index]; const Card card = _hands[_turn][index];
if (!is_playable(card)) { if (!is_playable(card)) {
BacktrackAction ret{card, index, draw<false>(index)}; draw<false>(index);
incr_turn(); incr_turn();
return ret; return;
} }
return play_and_potentially_update<false>(index); play_and_potentially_update<false>(index);
} }
template<suit_t num_suits, player_t num_players, hand_index_t hand_size> template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
template<bool update_card_positions> template<bool update_card_positions>
BacktrackAction HanabiState<num_suits, num_players, hand_size>::play_and_potentially_update(hand_index_t index) { unsigned long HanabiState<num_suits, num_players, hand_size>::play_and_potentially_update(hand_index_t index) {
ASSERT(index < _hands[_turn].size()); ASSERT(index < _hands[_turn].size());
const Card card = _hands[_turn][index]; const Card played_card = _hands[_turn][index];
ASSERT(is_playable(card)); ASSERT(is_playable(played_card));
--_stacks[card.suit]; --_stacks[played_card.suit];
_score++; _score++;
if (card.rank == 0 and _num_clues < max_num_clues) { if (played_card.rank == 0 and _num_clues < max_num_clues) {
// update clues if we played the last card of a stack // update clues if we played the last played_card of a stack
_num_clues++; _num_clues++;
} }
BacktrackAction ret{card, index, draw<update_card_positions>(index)}; unsigned long multiplicity = draw<update_card_positions>(index);
_actions_log.emplace(ActionType::play, played_card, index);
incr_turn(); incr_turn();
return ret; return multiplicity;
} }
template<suit_t num_suits, player_t num_players, hand_index_t hand_size> template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
BacktrackAction HanabiState<num_suits, num_players, hand_size>::discard(std::uint8_t index) { void HanabiState<num_suits, num_players, hand_size>::discard(std::uint8_t index) {
return discard_and_potentially_update<false>(index); discard_and_potentially_update<false>(index);
} }
template<suit_t num_suits, player_t num_players, hand_index_t hand_size> template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
template<bool update_card_positions> template<bool update_card_positions>
BacktrackAction HanabiState<num_suits, num_players, hand_size>::discard_and_potentially_update(hand_index_t index) { unsigned long HanabiState<num_suits, num_players, hand_size>::discard_and_potentially_update(hand_index_t index) {
ASSERT(index < _hands[_turn].size()); ASSERT(index < _hands[_turn].size());
ASSERT(_num_clues != max_num_clues); ASSERT(_num_clues != max_num_clues);
const Card discarded = _hands[_turn][index]; const Card discarded_card = _hands[_turn][index];
_num_clues++; _num_clues++;
_pace--; _pace--;
BacktrackAction ret{discarded, index, draw<update_card_positions>(index)}; unsigned long multiplicity = draw<update_card_positions>(index);
_actions_log.emplace(ActionType::discard, discarded_card, index);
incr_turn(); incr_turn();
return ret; return multiplicity;
} }
template<suit_t num_suits, player_t num_players, hand_index_t hand_size> template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
@ -217,7 +228,7 @@ namespace Hanabi {
template<suit_t num_suits, player_t num_players, hand_index_t hand_size> template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
template<bool update_card_positions> template<bool update_card_positions>
std::uint8_t HanabiState<num_suits, num_players, hand_size>::draw(uint8_t index) { unsigned long HanabiState<num_suits, num_players, hand_size>::draw(uint8_t index) {
ASSERT(index < _hands[_turn].size()); ASSERT(index < _hands[_turn].size());
// update card position of the card we are about to discard // update card position of the card we are about to discard
@ -377,45 +388,62 @@ namespace Hanabi {
template<suit_t num_suits, player_t num_players, hand_index_t hand_size> template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
void void
HanabiState<num_suits, num_players, hand_size>::revert_play(const BacktrackAction &action, bool was_on_8_clues) { HanabiState<num_suits, num_players, hand_size>::revert_play(bool was_on_8_clues) {
const BacktrackAction last_action = _actions_log.top();
_actions_log.pop();
ASSERT(last_action.action_type == ActionType::play);
ASSERT(!was_on_8_clues or _num_clues == 8); ASSERT(!was_on_8_clues or _num_clues == 8);
decr_turn(); decr_turn();
if (action.discarded.rank == 0 and not was_on_8_clues) { if (last_action.discarded.rank == 0 and not was_on_8_clues) {
_num_clues--; _num_clues--;
} }
revert_draw(action.index, action.discarded); revert_draw(last_action.index, last_action.discarded);
_stacks[action.discarded.suit]++; _stacks[last_action.discarded.suit]++;
_score--; _score--;
} }
template<suit_t num_suits, player_t num_players, hand_index_t hand_size> template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
void HanabiState<num_suits, num_players, hand_size>::revert_discard(const BacktrackAction &action) { void HanabiState<num_suits, num_players, hand_size>::revert_discard() {
const BacktrackAction last_action = _actions_log.top();
_actions_log.pop();
ASSERT(last_action.action_type == ActionType::discard);
decr_turn(); decr_turn();
ASSERT(_num_clues > 0); ASSERT(_num_clues > 0);
_num_clues--; _num_clues--;
_pace++; _pace++;
revert_draw(action.index, action.discarded);
revert_draw(last_action.index, last_action.discarded);
} }
template<suit_t num_suits, player_t num_players, hand_index_t hand_size> template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
void HanabiState<num_suits, num_players, hand_size>::revert_clue() { void HanabiState<num_suits, num_players, hand_size>::revert_clue() {
const BacktrackAction last_action = _actions_log.top();
_actions_log.pop();
ASSERT(last_action.action_type == ActionType::clue);
decr_turn(); decr_turn();
ASSERT(_num_clues < max_num_clues); ASSERT(_num_clues < max_num_clues);
_num_clues++; _num_clues++;
} }
#define RETURN_PROBABILITY \ #define RETURN_PROBABILITY \
if (_position_tablebase.contains(id_of_state)) { \ if (_position_tablebase.contains(id_of_state)) { \
ASSERT(_position_tablebase[id_of_state] == best_probability); \ ASSERT(_position_tablebase[id_of_state] == best_probability); \
} \ } \
_position_tablebase[id_of_state] = best_probability; \ _position_tablebase[id_of_state] = best_probability; \
\ \
return best_probability; return best_probability;
#define UPDATE_PROBABILITY(new_probability) \ #define UPDATE_PROBABILITY(new_probability) \
best_probability = std::max(best_probability, new_probability); \ best_probability = std::max(best_probability, new_probability); \
if (best_probability == 1) { \ if (best_probability == 1) { \
RETURN_PROBABILITY; \ RETURN_PROBABILITY; \
} }
template<suit_t num_suits, player_t num_players, hand_index_t hand_size> template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
@ -444,19 +472,19 @@ namespace Hanabi {
if(is_playable(hand[index])) { if(is_playable(hand[index])) {
if (_draw_pile.empty()) { if (_draw_pile.empty()) {
bool on_8_clues = _num_clues == 8; bool on_8_clues = _num_clues == 8;
BacktrackAction action = play_and_potentially_update<true>(index); play_and_potentially_update<true>(index);
const probability_t probability_for_this_play = backtrack(depth + 1); const probability_t probability_for_this_play = backtrack(depth + 1);
revert_play(action, on_8_clues); revert_play(on_8_clues);
UPDATE_PROBABILITY(probability_for_this_play); UPDATE_PROBABILITY(probability_for_this_play);
} else { } else {
probability_t sum_of_probabilities = 0; probability_t sum_of_probabilities = 0;
uint8_t sum_of_mults = 0; uint8_t sum_of_mults = 0;
for (size_t i = 0; i < _draw_pile.size(); i++) { for (size_t i = 0; i < _draw_pile.size(); i++) {
bool on_8_clues = _num_clues == 8; bool on_8_clues = _num_clues == 8;
BacktrackAction action = play_and_potentially_update<true>(index); const unsigned long multiplicity = play_and_potentially_update<true>(index);
sum_of_probabilities += backtrack(depth + 1) * action.multiplicity; sum_of_probabilities += backtrack(depth + 1) * multiplicity;
sum_of_mults += action.multiplicity; sum_of_mults += multiplicity;
revert_play(action, on_8_clues); revert_play(on_8_clues);
ASSERT(sum_of_mults <= _weighted_draw_pile_size); ASSERT(sum_of_mults <= _weighted_draw_pile_size);
} }
ASSERT(sum_of_mults == _weighted_draw_pile_size); ASSERT(sum_of_mults == _weighted_draw_pile_size);
@ -472,17 +500,17 @@ namespace Hanabi {
if (is_trash(hand[index])) { if (is_trash(hand[index])) {
probability_t sum_of_probabilities = 0; probability_t sum_of_probabilities = 0;
if (_draw_pile.empty()) { if (_draw_pile.empty()) {
BacktrackAction action = discard_and_potentially_update<true>(index); discard_and_potentially_update<true>(index);
const probability_t probability_for_this_discard = backtrack(depth + 1); const probability_t probability_for_this_discard = backtrack(depth + 1);
revert_discard(action); revert_discard();
UPDATE_PROBABILITY(probability_for_this_discard); UPDATE_PROBABILITY(probability_for_this_discard);
} else { } else {
uint8_t sum_of_mults = 0; uint8_t sum_of_mults = 0;
for (size_t i = 0; i < _draw_pile.size(); i++) { for (size_t i = 0; i < _draw_pile.size(); i++) {
BacktrackAction action = discard_and_potentially_update<true>(index); const unsigned long multiplicity = discard_and_potentially_update<true>(index);
sum_of_probabilities += backtrack(depth + 1) * action.multiplicity; sum_of_probabilities += backtrack(depth + 1) * multiplicity;
sum_of_mults += action.multiplicity; sum_of_mults += multiplicity;
revert_discard(action); revert_discard();
} }
ASSERT(sum_of_mults == _weighted_draw_pile_size); ASSERT(sum_of_mults == _weighted_draw_pile_size);
const probability_t probability_discard = sum_of_probabilities / _weighted_draw_pile_size; const probability_t probability_discard = sum_of_probabilities / _weighted_draw_pile_size;