sat: annotate function return type. use deep copy for returning solution to avoid modifying passed game state

This commit is contained in:
Maximilian Keßler 2023-05-06 19:41:50 +02:00
parent bdefe7aa34
commit 303158bc25
Signed by: max
GPG Key ID: BCC5A619923C0BA5

7
sat.py
View File

@ -1,7 +1,8 @@
import copy
from pysmt.shortcuts import Symbol, Bool, Not, Implies, Iff, And, Or, AtMostOne, ExactlyOne, get_model, get_atoms, get_formula_size, get_unsat_core from pysmt.shortcuts import Symbol, Bool, Not, Implies, Iff, And, Or, AtMostOne, ExactlyOne, get_model, get_atoms, get_formula_size, get_unsat_core
from pysmt.rewritings import conjunctive_partition from pysmt.rewritings import conjunctive_partition
import json import json
from typing import List from typing import List, Optional, Tuple
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
from hanabi import DeckCard, Action, ActionType, GameState, HanabiInstance from hanabi import DeckCard, Action, ActionType, GameState, HanabiInstance
@ -117,7 +118,7 @@ class Literals():
self.incr_clues = {m: Symbol('m{}c+'.format(m)) for m in range(instance.max_winning_moves)} self.incr_clues = {m: Symbol('m{}c+'.format(m)) for m in range(instance.max_winning_moves)}
def solve_sat(starting_state: GameState | HanabiInstance): def solve_sat(starting_state: GameState | HanabiInstance) -> Tuple[bool, Optional[GameState]]:
if isinstance(starting_state, HanabiInstance): if isinstance(starting_state, HanabiInstance):
instance = starting_state instance = starting_state
game_state = GameState(instance) game_state = GameState(instance)
@ -283,7 +284,7 @@ def solve_sat(starting_state: GameState | HanabiInstance):
model = get_model(constraints) model = get_model(constraints)
if model: if model:
# print_model(model, game_state, ls) # print_model(model, game_state, ls)
solution = evaluate_model(model, game_state, ls) solution = evaluate_model(model, copy.deepcopy(game_state), ls)
return True, solution return True, solution
else: else:
#conj = list(conjunctive_partition(constraints)) #conj = list(conjunctive_partition(constraints))