adjust sat solver to handle mid-game states

This commit is contained in:
Maximilian Keßler 2023-03-15 11:13:17 +01:00
parent d71dba523c
commit f0c1f112a0
Signed by: max
GPG key ID: BCC5A619923C0BA5

165
sat.py
View file

@ -4,8 +4,8 @@ import json
from typing import List from typing import List
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
from compress import DeckCard, Action, ActionType, link from compress import DeckCard, Action, ActionType, link, decompress_deck
from greedy_solver import GameState from greedy_solver import GameState, GreedyStrategy
COLORS = 'rygbp' COLORS = 'rygbp'
STANDARD_HAND_SIZE = {2: 5, 3: 5, 4: 4, 5: 4, 6: 3} STANDARD_HAND_SIZE = {2: 5, 3: 5, 4: 4, 5: 4, 6: 3}
@ -51,7 +51,7 @@ class Literals():
, **{ , **{
m: { m: {
0: Bool(True), # always at least 0 clues 0: Bool(True), # always at least 0 clues
**{ i: Symbol('m{}c{}'.format(m, i)) for i in range(1, 9) }, **{ i: Symbol('m{}clues{}'.format(m, i)) for i in range(1, 9) },
9: Bool(False) # never 9 or more clues. This will implicitly forbid discarding at 8 clues later 9: Bool(False) # never 9 or more clues. This will implicitly forbid discarding at 8 clues later
} }
for m in range(self.max_moves) for m in range(self.max_moves)
@ -64,7 +64,7 @@ class Literals():
, **{ , **{
m: { m: {
0: Bool(True), 0: Bool(True),
**{ s: Symbol('m{}s{}'.format(m,s)) for s in range(1, self.num_strikes) }, **{ s: Symbol('m{}strikes{}'.format(m,s)) for s in range(1, self.num_strikes) },
self.num_strikes: Bool(False) # never so many clues that we lose. Implicitly forbids striking out self.num_strikes: Bool(False) # never so many clues that we lose. Implicitly forbids striking out
} }
for m in range(self.max_moves) for m in range(self.max_moves)
@ -75,7 +75,7 @@ class Literals():
self.extraround = { self.extraround = {
-1: Bool(False) -1: Bool(False)
, **{ , **{
m: Bool(False) if m < self.draw_pile_size else Symbol('m{}e'.format(m)) # it takes at least as many turns as cards in the draw pile to start the extra round m: Bool(False) if m < self.draw_pile_size else Symbol('m{}extra'.format(m)) # it takes at least as many turns as cards in the draw pile to start the extra round
for m in range(0, self.max_moves) for m in range(0, self.max_moves)
} }
} }
@ -84,14 +84,14 @@ class Literals():
self.dummyturn = { self.dummyturn = {
-1: Bool(False) -1: Bool(False)
, **{ , **{
m: Bool(False) if m < self.draw_pile_size + self.num_players else Symbol('m{}dt'.format(m)) m: Bool(False) if m < self.draw_pile_size + self.num_players else Symbol('m{}dummy'.format(m))
for m in range(0, self.max_moves) for m in range(0, self.max_moves)
} }
} }
# draw[m][i] == "at move m we play/discard deck[i]" # draw[m][i] == "at move m we play/discard deck[i]"
self.discard = { self.discard = {
m: {i: Symbol('m{}-{}'.format(m, i)) for i in range(self.deck_size)} m: {i: Symbol('m{}discard{}'.format(m, i)) for i in range(self.deck_size)}
for m in range(self.max_moves) for m in range(self.max_moves)
} }
@ -101,7 +101,7 @@ class Literals():
, **{ , **{
m: { m: {
self.distributed_cards - 1: Bool(False), self.distributed_cards - 1: Bool(False),
**{i: Symbol('m{}+{}'.format(m, i)) for i in range(self.distributed_cards, self.deck_size)} **{i: Symbol('m{}draw{}'.format(m, i)) for i in range(self.distributed_cards, self.deck_size)}
} }
for m in range(self.max_moves) for m in range(self.max_moves)
} }
@ -111,7 +111,7 @@ class Literals():
self.strike = { self.strike = {
-1: Bool(False) -1: Bool(False)
, **{ , **{
m: Symbol('m{}s+'.format(m)) m: Symbol('m{}newstrike'.format(m))
for m in range(self.max_moves) for m in range(self.max_moves)
} }
} }
@ -122,7 +122,7 @@ class Literals():
, **{ , **{
m: { m: {
**{(s, 0): Bool(True) for s in range(0, self.num_suits)}, **{(s, 0): Bool(True) for s in range(0, self.num_suits)},
**{(s, r): Symbol('m{}:{}{}'.format(m, s, r)) for s in range(0, self.num_suits) for r in range(1, 6)} **{(s, r): Symbol('m{}progress{}{}'.format(m, s, r)) for s in range(0, self.num_suits) for r in range(1, 6)}
} }
for m in range(self.max_moves) for m in range(self.max_moves)
} }
@ -131,29 +131,66 @@ class Literals():
## Utility variables ## Utility variables
# discard_any[m] == "at move m we play/discard a card" # discard_any[m] == "at move m we play/discard a card"
self.discard_any = { m: Symbol('m{}d'.format(m)) for m in range(self.max_moves) } self.discard_any = { m: Symbol('m{}discard_any'.format(m)) for m in range(self.max_moves) }
# draw_any[m] == "at move m we draw a card" # draw_any[m] == "at move m we draw a card"
self.draw_any = {m: Symbol('m{}D'.format(m)) for m in range(self.max_moves)} self.draw_any = {m: Symbol('m{}draw_any'.format(m)) for m in range(self.max_moves)}
# play[m] == "at move m we play a card" # play[m] == "at move m we play a card"
self.play = {m: Symbol('m{}p'.format(m)) for m in range(self.max_moves)} self.play = {m: Symbol('m{}play'.format(m)) for m in range(self.max_moves)}
# play5[m] == "at move m we play a 5" # play5[m] == "at move m we play a 5"
self.play5 = {m: Symbol('m{}p5'.format(m)) for m in range(self.max_moves)} self.play5 = {m: Symbol('m{}play5'.format(m)) for m in range(self.max_moves)}
# incr_clues[m] == "at move m we obtain a clue" # incr_clues[m] == "at move m we obtain a clue"
self.incr_clues = {m: Symbol('m{}c+'.format(m)) for m in range(self.max_moves)} self.incr_clues = {m: Symbol('m{}c+'.format(m)) for m in range(self.max_moves)}
def solve(game_state: GameState):
ls = Literals(game_state.num_players, game_state.num_suits, game_state.num_dark_suits)
def solve(deck: List[DeckCard], num_players=5): ##### setup of initial game state
num_suits = max(map(lambda card: card.suitIndex, deck)) + 1 # properties used later to model valid moves
num_dark_suits = (len(deck) - 10 * num_suits) // (-5) num_dark_suits = game_state.num_dark_suits
num_suits = game_state.num_suits
deck = game_state.deck
next_draw = game_state.progress
ls = Literals(num_players, num_suits, num_dark_suits) starting_hands = [[card.deck_index for card in hand] for hand in game_state.hands]
first_turn = len(game_state.actions)
# set initial clues
for i in range(0,10):
ls.clues[first_turn - 1][i] = Bool(i <= game_state.clues)
# set initial strikes
for i in range(0, game_state.num_strikes + 1):
ls.strikes[first_turn - 1][i] = Bool(i <= game_state.strikes)
# check if extraround has started (usually not)
ls.extraround[first_turn - 1] = Bool(game_state.remaining_extra_turns < game_state.num_players)
ls.dummyturn[first_turn -1] = Bool(False)
# set recent draws: important to model progress
# we just pretend that the last card drawn was in fact drawn last turn,
# regardless of when it was actually drawn
for neg_turn in range(1, min(9, first_turn + 2)):
for i in range(game_state.num_players * game_state.hand_size, game_state.deck_size):
ls.draw[first_turn - neg_turn][i] = Bool(neg_turn == 1 and i == game_state.progress - 1)
# forbid re-drawing of the last card drawn
for m in range(first_turn, ls.max_moves):
ls.draw[m][game_state.progress - 1] = Bool(False)
# model initial progress
for s in range(0, game_state.num_suits):
for r in range(0, 6):
ls.progress[first_turn - 1][s, r] = Bool(r <= game_state.stacks[s])
### Now, model all valid moves
valid_move = lambda m: And( valid_move = lambda m: And(
# in dummy turns, nothing can be discarded # in dummy turns, nothing can be discarded
Implies(ls.dummyturn[m], Not(ls.discard_any[m])), Implies(ls.dummyturn[m], Not(ls.discard_any[m])),
@ -162,7 +199,7 @@ def solve(deck: List[DeckCard], num_players=5):
Iff(ls.discard_any[m], Or(ls.discard[m][i] for i in range(ls.deck_size))), Iff(ls.discard_any[m], Or(ls.discard[m][i] for i in range(ls.deck_size))),
# definition of draw_any # definition of draw_any
Iff(ls.draw_any[m], Or(ls.draw[m][i] for i in range(ls.distributed_cards, ls.deck_size))), Iff(ls.draw_any[m], Or(ls.draw[m][i] for i in range(next_draw, ls.deck_size))),
# ls.draw implies ls.discard (and converse true before the ls.extraround) # ls.draw implies ls.discard (and converse true before the ls.extraround)
Implies(ls.draw_any[m], ls.discard_any[m]), Implies(ls.draw_any[m], ls.discard_any[m]),
@ -194,19 +231,19 @@ def solve(deck: List[DeckCard], num_players=5):
Implies(Not(ls.discard_any[m]), Or(ls.clues[m-1][1], ls.dummyturn[m])), Implies(Not(ls.discard_any[m]), Or(ls.clues[m-1][1], ls.dummyturn[m])),
# we can only draw card i if the last ls.drawn card was i-1 # we can only draw card i if the last ls.drawn card was i-1
*[Implies(ls.draw[m][i], Or(And(ls.draw[m0][i-1], *[Not(ls.draw_any[m1]) for m1 in range(m0+1, m)]) for m0 in range(max(-1, m-9), m))) for i in range(ls.distributed_cards, ls.deck_size)], *[Implies(ls.draw[m][i], Or(And(ls.draw[m0][i-1], *[Not(ls.draw_any[m1]) for m1 in range(m0+1, m)]) for m0 in range(max(first_turn - 1, m-9), m))) for i in range(next_draw, ls.deck_size)],
# we can only draw at most one card (NOTE: redundant, FIXME: avoid quadratic formula) # we can only draw at most one card (NOTE: redundant, FIXME: avoid quadratic formula)
AtMostOne(ls.draw[m][i] for i in range(ls.distributed_cards, ls.deck_size)), AtMostOne(ls.draw[m][i] for i in range(next_draw, ls.deck_size)),
# we can only discard a card if we drew it earlier... # we can only discard a card if we drew it earlier...
*[Implies(ls.discard[m][i], Or(ls.draw[m0][i] for m0 in range(m-ls.num_players, -1, -ls.num_players))) for i in range(ls.distributed_cards, ls.deck_size)], *[Implies(ls.discard[m][i], Or(ls.draw[m0][i] for m0 in range(m-ls.num_players, first_turn - 1, -ls.num_players))) for i in range(next_draw, ls.deck_size)],
# ...or if it was part of the initial hand # ...or if it was part of the initial hand
*[Not(ls.discard[m][i]) for i in range(0, ls.distributed_cards) if i // ls.hand_size != m % ls.num_players], *[Not(ls.discard[m][i]) for i in range(0, next_draw) if i not in starting_hands[m % ls.num_players] ],
# we can only discard a card if we did not discard it yet # we can only discard a card if we did not discard it yet
*[Implies(ls.discard[m][i], And(Not(ls.discard[m0][i]) for m0 in range(m-ls.num_players, -1, -ls.num_players))) for i in range(ls.deck_size)], *[Implies(ls.discard[m][i], And(Not(ls.discard[m0][i]) for m0 in range(m-ls.num_players, first_turn - 1, -ls.num_players))) for i in range(ls.deck_size)],
# we can only discard at most one card (FIXME: avoid quadratic formula) # we can only discard at most one card (FIXME: avoid quadratic formula)
AtMostOne(ls.discard[m][i] for i in range(ls.deck_size)), AtMostOne(ls.discard[m][i] for i in range(ls.deck_size)),
@ -252,73 +289,93 @@ def solve(deck: List[DeckCard], num_players=5):
*[ *[
Or( Or(
And(ls.discard[m][i], ls.play[m]) And(ls.discard[m][i], ls.play[m])
for m in range(ls.max_moves) for m in range(first_turn, ls.max_moves)
for i in range(ls.deck_size) for i in range(ls.deck_size)
if deck[i] == DeckCard(s, r) if game_state.deck[i] == DeckCard(s, r)
) )
for s in range(0, ls.num_suits) for s in range(0, ls.num_suits)
for r in range(1, 6) for r in range(1, 6)
if r > game_state.stacks[s]
] ]
) )
constraints = And(*[valid_move(m) for m in range(ls.max_moves)], win) constraints = And(*[valid_move(m) for m in range(first_turn, ls.max_moves)], win)
# print('Solving instance with {} variables, {} nodes'.format(len(get_atoms(constraints)), get_formula_size(constraints))) # print('Solving instance with {} variables, {} nodes'.format(len(get_atoms(constraints)), get_formula_size(constraints)))
model = get_model(constraints) model = get_model(constraints)
if model: if model:
# print_model(model, deck) # print_model(model, game_state, ls)
solution = toJSON(model, deck, ls) solution = toJSON(model, game_state, ls)
return True, solution return True, solution
else: else:
return False, None
#conj = list(conjunctive_partition(constraints)) #conj = list(conjunctive_partition(constraints))
#print('statements: {}'.format(len(conj))) #print('statements: {}'.format(len(conj)))
#ucore = get_unsat_core(conj) #ucore = get_unsat_core(conj)
#print('unsat core size: {}'.format(len(ucore))) #print('unsat core size: {}'.format(len(ucore)))
#for f in ucore: #for f in ucore:
# print(f.serialize()) # print(f.serialize())
return False, None
def print_model(model, deck, num_players): def print_model(model, cur_game_state, ls: Literals):
draw = globals()['draw'][num_players] deck = cur_game_state.deck
for m in range(max_moves[num_players]):
print('=== move {} ==='.format(m))
print('clues: ' + ''.join(str(i) for i in range(1, 9) if model.get_py_value(clues[m][i])))
print('strikes: ' + ''.join(str(i) for i in range(1, NUM_STRIKES) if model.get_py_value(strikes[m][i])))
print('draw: ' + ', '.join('{} [{}{}]'.format(i, deck[i][0], deck[i][1]) for i in range(20, 50) if model.get_py_value(draw[m][i])))
print('discard: ' + ', '.join('{} [{}{}]'.format(i, deck[i][0], deck[i][1]) for i in range(50) if model.get_py_value(discard[m][i])))
for c in COLORS:
print('progress {}: '.format(c) + ''.join(str(k) for k in range(1, 6) if model.get_py_value(progress[m][c, k])))
flags = ['discard_any', 'draw_any', 'play', 'play5', 'incr_clues', 'strike', 'extraround', 'dummyturn']
print(', '.join(f for f in flags if model.get_py_value(globals()[f][m])))
def toJSON(model, deck: List[DeckCard], ls: Literals) -> dict:
gs = GameState(ls.num_players, deck)
for m in range(ls.max_moves): for m in range(ls.max_moves):
print('=== move {} ==='.format(m))
print('clues: ' + ''.join(str(i) for i in range(1, 9) if model.get_py_value(ls.clues[m][i])))
print('strikes: ' + ''.join(str(i) for i in range(1, 3) if model.get_py_value(ls.strikes[m][i])))
print('draw: ' + ', '.join('{}: {}'.format(i, deck[i]) for i in range(cur_game_state.progress, 50) if model.get_py_value(ls.draw[m][i])))
print('discard: ' + ', '.join('{}: {}'.format(i, deck[i]) for i in range(50) if model.get_py_value(ls.discard[m][i])))
for s in range(0, ls.num_suits):
print('progress {}: '.format(COLORS[s]) + ''.join(str(r) for r in range(1, 6) if model.get_py_value(ls.progress[m][s, r])))
flags = ['discard_any', 'draw_any', 'play', 'play5', 'incr_clues', 'strike', 'extraround', 'dummyturn']
print(', '.join(f for f in flags if model.get_py_value(getattr(ls, f)[m])))
def toJSON(model, cur_game_state: GameState, ls: Literals) -> dict:
for m in range(len(cur_game_state.actions), ls.max_moves):
if model.get_py_value(ls.dummyturn[m]): if model.get_py_value(ls.dummyturn[m]):
break break
if model.get_py_value(ls.discard_any[m]): if model.get_py_value(ls.discard_any[m]):
card_idx = next(i for i in range(0, ls.deck_size) if model.get_py_value(ls.discard[m][i])) card_idx = next(i for i in range(0, ls.deck_size) if model.get_py_value(ls.discard[m][i]))
if model.get_py_value(ls.play[m]) or model.get_py_value(ls.strike[m]): if model.get_py_value(ls.play[m]) or model.get_py_value(ls.strike[m]):
gs.play(card_idx) cur_game_state.play(card_idx)
else: else:
gs.discard(card_idx) cur_game_state.discard(card_idx)
else: else:
gs.clue() cur_game_state.clue()
return gs.to_json() return cur_game_state.to_json()
def run_deck(): def run_deck():
deck_str = 'p5 p3 b4 r5 y4 y4 y5 r4 b2 y2 y3 g5 g2 g3 g4 p4 r3 b2 b3 b3 p4 b1 p2 b1 b1 p2 p1 p1 g1 r4 g1 r1 r3 r1 g1 r1 p1 b4 p3 g2 g3 g4 b5 y1 y1 y1 r2 r2 y2 y3' puzzle = True
if puzzle:
deck_str = 'p5 p3 b4 r5 y4 y4 y5 r4 b2 y2 y3 g5 g2 g3 g4 p4 r3 b2 b3 b3 p4 b1 p2 b1 b1 p2 p1 p1 g1 r4 g1 r1 r3 r1 g1 r1 p1 b4 p3 g2 g3 g4 b5 y1 y1 y1 r2 r2 y2 y3'
deck = [DeckCard(COLORS.index(c[0]), int(c[1])) for c in deck_str.split(" ")]
num_p = 5
else:
deck_str = "15gfvqluvuwaqnmrkpkaignlaxpjbmsprksfcddeybfixchuhtwo"
deck = decompress_deck(deck_str)
num_p = 4
deck = [DeckCard(COLORS.index(c[0]), int(c[1])) for c in deck_str.split(" ")]
print(deck) print(deck)
solvable, sol = solve(deck, num_players=5) gs = GameState(num_p, deck)
if puzzle:
gs.play(2)
pass
else:
strat = GreedyStrategy(gs)
for _ in range(18):
strat.make_move()
print(link(gs.to_json()))
solvable, sol = solve(gs)
if solvable: if solvable:
print(sol) print(sol)
print(link(sol)) print(link(sol))
else:
print('unsolvable')
if __name__ == "__main__": if __name__ == "__main__":
run_deck() run_deck()