refactor backtracking function

This commit is contained in:
Maximilian Keßler 2023-08-10 18:27:25 +02:00
parent 17405b0f00
commit 44db744ae3
Signed by: max
GPG key ID: BCC5A619923C0BA5
3 changed files with 9 additions and 9 deletions

View file

@ -173,7 +173,7 @@ struct BacktrackAction {
class HanabiStateIF { class HanabiStateIF {
public: public:
virtual probability_t backtrack(size_t depth) = 0; virtual probability_t evaluate_state() = 0;
virtual void clue() = 0; virtual void clue() = 0;
virtual void discard(hand_index_t index) = 0; virtual void discard(hand_index_t index) = 0;
@ -203,7 +203,7 @@ public:
HanabiState() = default; HanabiState() = default;
explicit HanabiState(const std::vector<Card>& deck); explicit HanabiState(const std::vector<Card>& deck);
probability_t backtrack(size_t depth) final; probability_t evaluate_state() final;
void clue() final; void clue() final;
void play(hand_index_t index) final; void play(hand_index_t index) final;

View file

@ -437,7 +437,7 @@ namespace Hanabi {
template<suit_t num_suits, player_t num_players, hand_index_t hand_size> template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
probability_t HanabiState<num_suits, num_players, hand_size>::backtrack(size_t depth) { probability_t HanabiState<num_suits, num_players, hand_size>::evaluate_state() {
_enumerated_states++; _enumerated_states++;
const unsigned long id_of_state = unique_id(); const unsigned long id_of_state = unique_id();
@ -462,7 +462,7 @@ namespace Hanabi {
probability_t sum_of_probabilities = 0; probability_t sum_of_probabilities = 0;
do_for_each_potential_draw<true>(index, [this, &sum_of_probabilities](const unsigned long multiplicity){ do_for_each_potential_draw<true>(index, [this, &sum_of_probabilities](const unsigned long multiplicity){
sum_of_probabilities += backtrack(0) * multiplicity; sum_of_probabilities += evaluate_state() * multiplicity;
}); });
const unsigned long total_weight = std::max(static_cast<unsigned long>(_weighted_draw_pile_size), 1ul); const unsigned long total_weight = std::max(static_cast<unsigned long>(_weighted_draw_pile_size), 1ul);
@ -483,7 +483,7 @@ namespace Hanabi {
probability_t sum_of_probabilities = 0; probability_t sum_of_probabilities = 0;
do_for_each_potential_draw<false>(index, [this, &sum_of_probabilities](const unsigned long multiplicity){ do_for_each_potential_draw<false>(index, [this, &sum_of_probabilities](const unsigned long multiplicity){
sum_of_probabilities += backtrack(0) * multiplicity; sum_of_probabilities += evaluate_state() * multiplicity;
}); });
const unsigned long total_weight = std::max(static_cast<unsigned long>(_weighted_draw_pile_size), 1ul); const unsigned long total_weight = std::max(static_cast<unsigned long>(_weighted_draw_pile_size), 1ul);
@ -505,7 +505,7 @@ namespace Hanabi {
// Last option is to stall // Last option is to stall
if(_num_clues > 0) { if(_num_clues > 0) {
clue(); clue();
const probability_t probability_stall = backtrack(depth + 1); const probability_t probability_stall = evaluate_state();
revert_clue(); revert_clue();
best_probability = std::max(best_probability, probability_stall); best_probability = std::max(best_probability, probability_stall);
if (best_probability == 1) { if (best_probability == 1) {

View file

@ -17,7 +17,7 @@ namespace Hanabi {
void download(std::variant<int, const char*> game_id, int turn) { void download(std::variant<int, const char*> game_id, int turn) {
auto game = Download::get_game(game_id, turn); auto game = Download::get_game(game_id, turn);
std::cout << "Analysing state: " << std::endl << *game << std::endl; std::cout << "Analysing state: " << std::endl << *game << std::endl;
auto res = game->backtrack(1); auto res = game->evaluate_state();
std::cout.precision(10); std::cout.precision(10);
std::cout << std::endl; std::cout << std::endl;
std::cout << "Probability with optimal play: " << res << std::endl; std::cout << "Probability with optimal play: " << res << std::endl;
@ -52,7 +52,7 @@ namespace Hanabi {
void test() { void test() {
{ {
auto game = Download::get_game("in/1005195", 43); auto game = Download::get_game("in/1005195", 43);
auto res = game->backtrack(1); auto res = game->evaluate_state();
CHECK("1005195", res == Hanabi::probability_t (7,8)); CHECK("1005195", res == Hanabi::probability_t (7,8));
} }
} }
@ -68,7 +68,7 @@ void check_games(unsigned num_players, unsigned max_draw_pile_size, unsigned fir
for(size_t game_id = first_game; game_id <= last_game; game_id++) { for(size_t game_id = first_game; game_id <= last_game; game_id++) {
const std::string input_fname = "json/" + std::to_string(num_players) + "p/" + std::to_string(game_id) + ".json"; const std::string input_fname = "json/" + std::to_string(num_players) + "p/" + std::to_string(game_id) + ".json";
auto game = Download::get_game(input_fname.c_str(), 50, draw_pile_size); auto game = Download::get_game(input_fname.c_str(), 50, draw_pile_size);
const Hanabi::probability_t chance = game->backtrack(0); const Hanabi::probability_t chance = game->evaluate_state();
winning_percentages[game_id].push_back(chance); winning_percentages[game_id].push_back(chance);
if(chance != 1) { if(chance != 1) {
file << "Game " << game_id << ": " << chance << std::endl; file << "Game " << game_id << ": " << chance << std::endl;