refactor backtracking function
This commit is contained in:
parent
17405b0f00
commit
44db744ae3
3 changed files with 9 additions and 9 deletions
|
@ -173,7 +173,7 @@ struct BacktrackAction {
|
||||||
|
|
||||||
class HanabiStateIF {
|
class HanabiStateIF {
|
||||||
public:
|
public:
|
||||||
virtual probability_t backtrack(size_t depth) = 0;
|
virtual probability_t evaluate_state() = 0;
|
||||||
|
|
||||||
virtual void clue() = 0;
|
virtual void clue() = 0;
|
||||||
virtual void discard(hand_index_t index) = 0;
|
virtual void discard(hand_index_t index) = 0;
|
||||||
|
@ -203,7 +203,7 @@ public:
|
||||||
HanabiState() = default;
|
HanabiState() = default;
|
||||||
explicit HanabiState(const std::vector<Card>& deck);
|
explicit HanabiState(const std::vector<Card>& deck);
|
||||||
|
|
||||||
probability_t backtrack(size_t depth) final;
|
probability_t evaluate_state() final;
|
||||||
|
|
||||||
void clue() final;
|
void clue() final;
|
||||||
void play(hand_index_t index) final;
|
void play(hand_index_t index) final;
|
||||||
|
|
|
@ -437,7 +437,7 @@ namespace Hanabi {
|
||||||
|
|
||||||
|
|
||||||
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
|
template<suit_t num_suits, player_t num_players, hand_index_t hand_size>
|
||||||
probability_t HanabiState<num_suits, num_players, hand_size>::backtrack(size_t depth) {
|
probability_t HanabiState<num_suits, num_players, hand_size>::evaluate_state() {
|
||||||
_enumerated_states++;
|
_enumerated_states++;
|
||||||
const unsigned long id_of_state = unique_id();
|
const unsigned long id_of_state = unique_id();
|
||||||
|
|
||||||
|
@ -462,7 +462,7 @@ namespace Hanabi {
|
||||||
probability_t sum_of_probabilities = 0;
|
probability_t sum_of_probabilities = 0;
|
||||||
|
|
||||||
do_for_each_potential_draw<true>(index, [this, &sum_of_probabilities](const unsigned long multiplicity){
|
do_for_each_potential_draw<true>(index, [this, &sum_of_probabilities](const unsigned long multiplicity){
|
||||||
sum_of_probabilities += backtrack(0) * multiplicity;
|
sum_of_probabilities += evaluate_state() * multiplicity;
|
||||||
});
|
});
|
||||||
|
|
||||||
const unsigned long total_weight = std::max(static_cast<unsigned long>(_weighted_draw_pile_size), 1ul);
|
const unsigned long total_weight = std::max(static_cast<unsigned long>(_weighted_draw_pile_size), 1ul);
|
||||||
|
@ -483,7 +483,7 @@ namespace Hanabi {
|
||||||
probability_t sum_of_probabilities = 0;
|
probability_t sum_of_probabilities = 0;
|
||||||
|
|
||||||
do_for_each_potential_draw<false>(index, [this, &sum_of_probabilities](const unsigned long multiplicity){
|
do_for_each_potential_draw<false>(index, [this, &sum_of_probabilities](const unsigned long multiplicity){
|
||||||
sum_of_probabilities += backtrack(0) * multiplicity;
|
sum_of_probabilities += evaluate_state() * multiplicity;
|
||||||
});
|
});
|
||||||
|
|
||||||
const unsigned long total_weight = std::max(static_cast<unsigned long>(_weighted_draw_pile_size), 1ul);
|
const unsigned long total_weight = std::max(static_cast<unsigned long>(_weighted_draw_pile_size), 1ul);
|
||||||
|
@ -505,7 +505,7 @@ namespace Hanabi {
|
||||||
// Last option is to stall
|
// Last option is to stall
|
||||||
if(_num_clues > 0) {
|
if(_num_clues > 0) {
|
||||||
clue();
|
clue();
|
||||||
const probability_t probability_stall = backtrack(depth + 1);
|
const probability_t probability_stall = evaluate_state();
|
||||||
revert_clue();
|
revert_clue();
|
||||||
best_probability = std::max(best_probability, probability_stall);
|
best_probability = std::max(best_probability, probability_stall);
|
||||||
if (best_probability == 1) {
|
if (best_probability == 1) {
|
||||||
|
|
6
main.cpp
6
main.cpp
|
@ -17,7 +17,7 @@ namespace Hanabi {
|
||||||
void download(std::variant<int, const char*> game_id, int turn) {
|
void download(std::variant<int, const char*> game_id, int turn) {
|
||||||
auto game = Download::get_game(game_id, turn);
|
auto game = Download::get_game(game_id, turn);
|
||||||
std::cout << "Analysing state: " << std::endl << *game << std::endl;
|
std::cout << "Analysing state: " << std::endl << *game << std::endl;
|
||||||
auto res = game->backtrack(1);
|
auto res = game->evaluate_state();
|
||||||
std::cout.precision(10);
|
std::cout.precision(10);
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
std::cout << "Probability with optimal play: " << res << std::endl;
|
std::cout << "Probability with optimal play: " << res << std::endl;
|
||||||
|
@ -52,7 +52,7 @@ namespace Hanabi {
|
||||||
void test() {
|
void test() {
|
||||||
{
|
{
|
||||||
auto game = Download::get_game("in/1005195", 43);
|
auto game = Download::get_game("in/1005195", 43);
|
||||||
auto res = game->backtrack(1);
|
auto res = game->evaluate_state();
|
||||||
CHECK("1005195", res == Hanabi::probability_t (7,8));
|
CHECK("1005195", res == Hanabi::probability_t (7,8));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -68,7 +68,7 @@ void check_games(unsigned num_players, unsigned max_draw_pile_size, unsigned fir
|
||||||
for(size_t game_id = first_game; game_id <= last_game; game_id++) {
|
for(size_t game_id = first_game; game_id <= last_game; game_id++) {
|
||||||
const std::string input_fname = "json/" + std::to_string(num_players) + "p/" + std::to_string(game_id) + ".json";
|
const std::string input_fname = "json/" + std::to_string(num_players) + "p/" + std::to_string(game_id) + ".json";
|
||||||
auto game = Download::get_game(input_fname.c_str(), 50, draw_pile_size);
|
auto game = Download::get_game(input_fname.c_str(), 50, draw_pile_size);
|
||||||
const Hanabi::probability_t chance = game->backtrack(0);
|
const Hanabi::probability_t chance = game->evaluate_state();
|
||||||
winning_percentages[game_id].push_back(chance);
|
winning_percentages[game_id].push_back(chance);
|
||||||
if(chance != 1) {
|
if(chance != 1) {
|
||||||
file << "Game " << game_id << ": " << chance << std::endl;
|
file << "Game " << game_id << ": " << chance << std::endl;
|
||||||
|
|
Loading…
Reference in a new issue