diff --git a/sat.py b/sat.py index 72c123c..c15530b 100644 --- a/sat.py +++ b/sat.py @@ -1,7 +1,8 @@ +import copy from pysmt.shortcuts import Symbol, Bool, Not, Implies, Iff, And, Or, AtMostOne, ExactlyOne, get_model, get_atoms, get_formula_size, get_unsat_core from pysmt.rewritings import conjunctive_partition import json -from typing import List +from typing import List, Optional, Tuple from concurrent.futures import ProcessPoolExecutor from hanabi import DeckCard, Action, ActionType, GameState, HanabiInstance @@ -117,7 +118,7 @@ class Literals(): self.incr_clues = {m: Symbol('m{}c+'.format(m)) for m in range(instance.max_winning_moves)} -def solve_sat(starting_state: GameState | HanabiInstance): +def solve_sat(starting_state: GameState | HanabiInstance) -> Tuple[bool, Optional[GameState]]: if isinstance(starting_state, HanabiInstance): instance = starting_state game_state = GameState(instance) @@ -283,7 +284,7 @@ def solve_sat(starting_state: GameState | HanabiInstance): model = get_model(constraints) if model: # print_model(model, game_state, ls) - solution = evaluate_model(model, game_state, ls) + solution = evaluate_model(model, copy.deepcopy(game_state), ls) return True, solution else: #conj = list(conjunctive_partition(constraints))