Rework probability printing

Outputting of probabilities is now done in unified way.
This ensures that we can handle both rational and floating-point
probabilities in all program parts properly.

The compile-time macro
NUSE_RATIONAL_PROBABILITIES
can now be defined to use floating-point probabilities instead
of real rational ones.

Using floating-point results in roughly 10% speedup.
This commit is contained in:
Maximilian Keßler 2023-10-02 12:23:55 +02:00
parent 0cc069485d
commit b69cb6d974
Signed by: max
GPG key ID: BCC5A619923C0BA5
4 changed files with 57 additions and 15 deletions

View file

@ -23,7 +23,28 @@ namespace Hanabi {
using clue_t = std::uint8_t; using clue_t = std::uint8_t;
using player_t = std::uint8_t; using player_t = std::uint8_t;
using hand_index_t = std::uint8_t; using hand_index_t = std::uint8_t;
using probability_t = boost::rational<unsigned long>;
using probability_base_type = unsigned long;
using rational_probability = boost::rational<probability_base_type>;
/**
* Define macro
* NUSE_RATIONAL_PROBABILITIES
* to use floating-point arithematic for the stored probabilities
* instead of rational representations
*/
#ifndef NUSE_RATIONAL_PROBABILITIES
using probability_t = boost::rational<probability_base_type>;
#else
using probability_t = double;
#endif
inline std::ostream& print_probability(std::ostream& os, double prob);
inline std::ostream& print_probability(std::ostream& os, const rational_probability& prob);
template<typename T>
std::ostream& print_probability(std::ostream& os, const std::optional<T>& prob);
/** /**
* We will generally assume that stacks are played from n to 0 * We will generally assume that stacks are played from n to 0

View file

@ -5,6 +5,26 @@
namespace Hanabi { namespace Hanabi {
template<typename T>
std::ostream& print_probability(std::ostream& os, const std::optional<T>& prob) {
if (prob.has_value()) {
return print_probability(os, prob.value());
} else {
os << "unknown";
}
return os;
}
std::ostream& print_probability(std::ostream& os, const rational_probability & prob) {
os << prob << " ~ " << std::setprecision(5) << boost::rational_cast<double>(prob) * 100 << "%";
return os;
}
std::ostream& print_probability(std::ostream& os, double prob) {
os << std::setprecision(5) << prob;
return os;
}
std::ostream &operator<<(std::ostream &os, HanabiStateIF const &hanabi_state) { std::ostream &operator<<(std::ostream &os, HanabiStateIF const &hanabi_state) {
hanabi_state.print(os); hanabi_state.print(os);
return os; return os;

View file

@ -11,17 +11,10 @@
#include <memory> #include <memory>
#include <cmath> #include <cmath>
#include "game_state.h" #include "game_state.h"
#include <myassert.h>
namespace Hanabi { namespace Hanabi {
std::ostream& operator<<(std::ostream& os, const std::optional<probability_t>& prob) {
if (prob.has_value()) {
os << prob.value() << " ~ " << std::setprecision(5) << boost::rational_cast<double>(prob.value()) * 100 << "%";
} else {
os << "unknown";
}
return os;
}
std::string read_line_memory_safe(const char *prompt) { std::string read_line_memory_safe(const char *prompt) {
char *line = readline(prompt); char *line = readline(prompt);
@ -89,9 +82,13 @@ namespace Hanabi {
} }
} }
int representation_length(probability_t probability) { int representation_length(const rational_probability& probability) {
return 1 + static_cast<int>(std::ceil(std::log10(probability.denominator()))) + \ return 1 + static_cast<int>(std::ceil(std::log10(probability.denominator()))) + \
static_cast<int>(std::ceil(std::log10(probability.numerator()))); static_cast<int>(std::ceil(std::log10(probability.numerator())));
}
int representation_length(const double probability) {
return static_cast<int>(std::ceil(std::log10(probability)));
} }
bool ask_for_card_and_rotate_draw(const std::shared_ptr<HanabiStateIF>& game, hand_index_t index, bool play) { bool ask_for_card_and_rotate_draw(const std::shared_ptr<HanabiStateIF>& game, hand_index_t index, bool play) {
@ -132,7 +129,8 @@ namespace Hanabi {
for (const auto &[card_multiplicity, probability]: states_to_show) { for (const auto &[card_multiplicity, probability]: states_to_show) {
std::cout << card_multiplicity.card << " (" << card_multiplicity.multiplicity; std::cout << card_multiplicity.card << " (" << card_multiplicity.multiplicity;
std::cout << " copie(s) in draw) " << std::setw(max_rational_digit_len) << probability << std::endl; std::cout << " copie(s) in draw) " << std::setw(max_rational_digit_len);
print_probability(std::cout, probability) << std::endl;
} }
std::stringstream prompt; std::stringstream prompt;
@ -235,7 +233,8 @@ namespace Hanabi {
if (prompt.starts_with("state")) { if (prompt.starts_with("state")) {
std::cout << *game << std::endl; std::cout << *game << std::endl;
const std::optional<probability_t> prob = game->lookup(); const std::optional<probability_t> prob = game->lookup();
std::cout << "Winning chance: " << prob << std::endl; std::cout << "Winning chance: ";
print_probability(std::cout, prob) << std::endl;
continue; continue;
} }
@ -349,7 +348,8 @@ namespace Hanabi {
std::cout.setf(std::ios_base::left, std::ios_base::adjustfield); std::cout.setf(std::ios_base::left, std::ios_base::adjustfield);
std::cout << std::setw(7) << action << ": "; std::cout << std::setw(7) << action << ": ";
std::cout.setf(std::ios_base::right, std::ios_base::adjustfield); std::cout.setf(std::ios_base::right, std::ios_base::adjustfield);
std::cout << std::setw(max_rational_digit_len) << probability << std::endl; std::cout << std::setw(max_rational_digit_len);
print_probability(std::cout, probability) << std::endl;
} }
if(reasonable_actions.empty()) { if(reasonable_actions.empty()) {
std::cout << "Game is over, no actions to take." << std::endl; std::cout << "Game is over, no actions to take." << std::endl;

View file

@ -27,7 +27,8 @@ namespace Hanabi {
std::cout.precision(10); std::cout.precision(10);
std::cout << std::endl; std::cout << std::endl;
std::cout << "Probability with optimal play: " << res << " ~ " << std::setprecision(5) << boost::rational_cast<double>(res) * 100 << "%" << std::endl; std::cout << "Probability with optimal play: ";
print_probability(std::cout, res) << std::endl;
std::cout << "Took " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start) << "." << std::endl; std::cout << "Took " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start) << "." << std::endl;
std::cout << "Visited " << game->enumerated_states() << " states." << std::endl; std::cout << "Visited " << game->enumerated_states() << " states." << std::endl;
std::cout << "Enumerated " << game->position_tablebase().size() << " unique game states. " << std::endl; std::cout << "Enumerated " << game->position_tablebase().size() << " unique game states. " << std::endl;