diff --git a/src/hanabi/solvers/sat.py b/src/hanabi/solvers/sat.py index fc6191b..d5485ac 100644 --- a/src/hanabi/solvers/sat.py +++ b/src/hanabi/solvers/sat.py @@ -32,6 +32,15 @@ class Literals(): } } + # progress[m] = i "after move m the next card drawn from the deck has index i" + self.next_draw = { + -1: Int(0) + , **{ + m: Symbol('m{}progress'.format(m), INT) + for m in range(instance.max_winning_moves) + } + } + # strikes[m][i] == "after move m we have at least i strikes" self.strikes = { -1: {i: Bool(i == 0) for i in range(0, instance.num_strikes + 1)} # no strikes when we start @@ -213,6 +222,10 @@ def solve_sat(starting_state: hanab_game.GameState | hanab_game.HanabiInstance, Implies(And(Or(ls.discard_any[m], ls.dummyturn[m]), Not(ls.incr_clues[m])), Equals(ls.clues[m], ls.clues[m - 1])), + # change of progress + Implies(ls.draw_any[m], Equals(ls.next_draw[m], ls.next_draw[m-1] + 1)), + Implies(Not(ls.draw_any[m]), Equals(ls.next_draw[m], ls.next_draw[m-1])), + # change of pace Implies(And(ls.discard_any[m], Or(ls.strike[m], Not(ls.play[m]))), Equals(ls.pace[m], ls.pace[m - 1] - 1)), Implies(Or(Not(ls.discard_any[m]), And(Not(ls.strike[m]), ls.play[m])), Equals(ls.pace[m], ls.pace[m - 1])),