add option to set minimum pace in sat solver
This commit is contained in:
parent
ac261c629e
commit
f45bde1883
1 changed files with 20 additions and 4 deletions
24
sat.py
24
sat.py
|
@ -29,6 +29,14 @@ class Literals():
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.pace = {
|
||||||
|
-1: Int(instance.initial_pace)
|
||||||
|
, **{
|
||||||
|
m: Symbol('m{}pace'.format(m), INT)
|
||||||
|
for m in range(instance.max_winning_moves)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
# strikes[m][i] == "after move m we have at least i strikes"
|
# strikes[m][i] == "after move m we have at least i strikes"
|
||||||
self.strikes = {
|
self.strikes = {
|
||||||
-1: {i: Bool(i == 0) for i in range(0, instance.num_strikes + 1)} # no strikes when we start
|
-1: {i: Bool(i == 0) for i in range(0, instance.num_strikes + 1)} # no strikes when we start
|
||||||
|
@ -117,7 +125,7 @@ class Literals():
|
||||||
self.incr_clues = {m: Symbol('m{}c+'.format(m)) for m in range(instance.max_winning_moves)}
|
self.incr_clues = {m: Symbol('m{}c+'.format(m)) for m in range(instance.max_winning_moves)}
|
||||||
|
|
||||||
|
|
||||||
def solve_sat(starting_state: GameState | HanabiInstance) -> Tuple[bool, Optional[GameState]]:
|
def solve_sat(starting_state: GameState | HanabiInstance, min_pace: Optional[int] = 0) -> Tuple[bool, Optional[GameState]]:
|
||||||
if isinstance(starting_state, HanabiInstance):
|
if isinstance(starting_state, HanabiInstance):
|
||||||
instance = starting_state
|
instance = starting_state
|
||||||
game_state = GameState(instance)
|
game_state = GameState(instance)
|
||||||
|
@ -143,6 +151,9 @@ def solve_sat(starting_state: GameState | HanabiInstance) -> Tuple[bool, Optiona
|
||||||
for i in range(0, 10):
|
for i in range(0, 10):
|
||||||
ls.clues[first_turn - 1] = Int(game_state.clues)
|
ls.clues[first_turn - 1] = Int(game_state.clues)
|
||||||
|
|
||||||
|
# set initial pace
|
||||||
|
ls.pace[first_turn - 1] = Int(game_state.pace)
|
||||||
|
|
||||||
# set initial strikes
|
# set initial strikes
|
||||||
for i in range(0, instance.num_strikes + 1):
|
for i in range(0, instance.num_strikes + 1):
|
||||||
ls.strikes[first_turn - 1][i] = Bool(i <= game_state.strikes)
|
ls.strikes[first_turn - 1][i] = Bool(i <= game_state.strikes)
|
||||||
|
@ -198,8 +209,12 @@ def solve_sat(starting_state: GameState | HanabiInstance) -> Tuple[bool, Optiona
|
||||||
Implies(ls.incr_clues[m], Equals(ls.clues[m], ls.clues[m-1] + 1)),
|
Implies(ls.incr_clues[m], Equals(ls.clues[m], ls.clues[m-1] + 1)),
|
||||||
Implies(And(Or(ls.discard_any[m], ls.dummyturn[m]), Not(ls.incr_clues[m])), Equals(ls.clues[m], ls.clues[m-1])),
|
Implies(And(Or(ls.discard_any[m], ls.dummyturn[m]), Not(ls.incr_clues[m])), Equals(ls.clues[m], ls.clues[m-1])),
|
||||||
|
|
||||||
Iff(ls.half_clue[m], Or(And(ls.half_clue[m-1], Not(ls.incr_clues[m])), And(Not(ls.half_clue[m-1])), ls.incr_clues[m]))
|
# change of pace
|
||||||
if instance.clue_starved else Bool(True),
|
Implies(And(ls.discard_any[m], Or(ls.strike[m], Not(ls.play[m]))), Equals(ls.pace[m], ls.pace[m-1] - 1)),
|
||||||
|
Implies(Or(Not(ls.discard_any[m]), And(Not(ls.strike[m]), ls.play[m])), Equals(ls.pace[m], ls.pace[m-1])),
|
||||||
|
|
||||||
|
# pace is nonnegative
|
||||||
|
GE(ls.pace[m], Int(min_pace)),
|
||||||
|
|
||||||
## more than 8 clues not allowed, ls.discarding produces a strike
|
## more than 8 clues not allowed, ls.discarding produces a strike
|
||||||
# Note that this means that we will never strike while not at 8 clues.
|
# Note that this means that we will never strike while not at 8 clues.
|
||||||
|
@ -312,6 +327,7 @@ def log_model(model, cur_game_state, ls: Literals):
|
||||||
logger.debug('strikes: ' + ''.join(str(i) for i in range(1, 3) if model.get_py_value(ls.strikes[m][i])))
|
logger.debug('strikes: ' + ''.join(str(i) for i in range(1, 3) if model.get_py_value(ls.strikes[m][i])))
|
||||||
logger.debug('draw: ' + ', '.join('{}: {}'.format(i, deck[i]) for i in range(cur_game_state.progress, cur_game_state.instance.deck_size) if model.get_py_value(ls.draw[m][i])))
|
logger.debug('draw: ' + ', '.join('{}: {}'.format(i, deck[i]) for i in range(cur_game_state.progress, cur_game_state.instance.deck_size) if model.get_py_value(ls.draw[m][i])))
|
||||||
logger.debug('discard: ' + ', '.join('{}: {}'.format(i, deck[i]) for i in range(cur_game_state.instance.deck_size) if model.get_py_value(ls.discard[m][i])))
|
logger.debug('discard: ' + ', '.join('{}: {}'.format(i, deck[i]) for i in range(cur_game_state.instance.deck_size) if model.get_py_value(ls.discard[m][i])))
|
||||||
|
logger.debug('pace: {}'.format(model.get_py_value(ls.pace[m])))
|
||||||
for s in range(0, cur_game_state.instance.num_suits):
|
for s in range(0, cur_game_state.instance.num_suits):
|
||||||
logger.debug('progress {}: '.format(COLOR_INITIALS[s]) + ''.join(str(r) for r in range(1, 6) if model.get_py_value(ls.progress[m][s, r])))
|
logger.debug('progress {}: '.format(COLOR_INITIALS[s]) + ''.join(str(r) for r in range(1, 6) if model.get_py_value(ls.progress[m][s, r])))
|
||||||
flags = ['discard_any', 'draw_any', 'play', 'play5', 'incr_clues', 'strike', 'extraround', 'dummyturn']
|
flags = ['discard_any', 'draw_any', 'play', 'play5', 'incr_clues', 'strike', 'extraround', 'dummyturn']
|
||||||
|
@ -362,7 +378,7 @@ def run_deck():
|
||||||
for _ in range(17):
|
for _ in range(17):
|
||||||
strat.make_move()
|
strat.make_move()
|
||||||
|
|
||||||
solvable, sol = solve_sat(gs)
|
solvable, sol = solve_sat(gs, 0)
|
||||||
if solvable:
|
if solvable:
|
||||||
print(sol)
|
print(sol)
|
||||||
print(link(sol))
|
print(link(sol))
|
||||||
|
|
Loading…
Reference in a new issue