start impl of better greedy solver

This commit is contained in:
Maximilian Keßler 2023-05-25 17:00:18 +02:00
parent a1345a1976
commit a427a2575e
Signed by: max
GPG Key ID: BCC5A619923C0BA5

View File

@ -3,6 +3,7 @@ import collections
import sys
from enum import Enum
from log_setup import logger
from typing import Tuple, List, Optional
from time import sleep
from hanabi import DeckCard, Action, ActionType, GameState, HanabiInstance
@ -14,7 +15,8 @@ class CardType(Enum):
Trash = 0
Playable = 1
Critical = 2
Dispensable = 3
DuplicateVisible = 3
UniqueVisible = 4
class CardState():
@ -31,8 +33,10 @@ class CardState():
return "Playable ({}) with weight {}".format(self.card, self.weight)
case CardType.Critical:
return "Critical ({})".format(self.card)
case CardType.Dispensable:
return "Dispensable ({}) with weight {}".format(self.card, self.weight)
case CardType.DuplicateVisible:
return "Useful (duplicate visible) ({}) with weight {}".format(self.card, self.weight)
case CardType.UniqueVisible:
return "Useful (unique visible) ({}) with weight {}".format(self.card, self.weight)
# TODO
@ -45,7 +49,92 @@ def card_type(game_state, card):
elif card.rank == 5 or card in game_state.trash:
return CardType.Critical
else:
return CardType.Dispensable
visible_cards = sum((game_state.hands[player] for player in range(game_state.num_players)), [])
if visible_cards.count(card) >= 2:
return CardType.DuplicateVisible
else:
return CardType.UniqueVisible
class WeightedCard:
def __init__(self, card, weight: Optional[int] = None):
self.card = card
self.weight = weight
def __repr__(self):
return "{} with weight {}".format(self.card, self.weight)
class HandState:
def __init__(self, player: int, game_state: GameState):
self.trash = []
self.playable = []
self.critical = []
self.dupes = []
self.uniques = []
for card in game_state.hands[player]:
match card_type(game_state, card):
case CardType.Trash:
self.trash.append(WeightedCard(card))
case CardType.Playable:
if card not in map(lambda c: c.card, self.playable):
self.playable.append(WeightedCard(card))
else:
self.trash.append(card)
case CardType.Critical:
self.critical.append(WeightedCard(card))
case CardType.UniqueVisible:
self.uniques.append(WeightedCard(card))
case CardType.DuplicateVisible:
copy = next((w for w in self.dupes if w.card == card), None)
if copy is not None:
self.dupes.remove(copy)
self.critical.append(copy)
self.trash.append(card)
else:
self.dupes.append(WeightedCard(card))
self.playable.sort(key=lambda c: c.card.rank)
self.dupes.sort(key=lambda c: c.card.rank)
self.uniques.sort(key=lambda c: c.card.rank)
if len(self.trash) > 0:
self.best_discard = self.trash[0]
self.discard_badness = 0
elif len(self.dupes) > 0:
self.best_discard = self.dupes[0]
self.discard_badness = 8 - game_state.num_players
elif len(self.uniques) > 0:
self.best_discard = self.uniques[-1]
self.discard_badness = 80 - 10 * self.best_discard.card.rank
elif len(self.playable) > 0:
self.best_discard = self.playable[-1]
self.discard_badness = 80 - 10 * self.best_discard.card.rank
else:
assert len(self.critical) > 0, "Programming error."
self.best_discard = self.critical[-1]
self.discard_badness = 600 - 100*self.best_discard.card.rank
def num_useful_cards(self):
return len(self.dupes) + len(self.uniques) + len(self.playable) + len(self.critical)
class CheatingStrategy:
def __init__(self, game_state: GameState):
self.game_state = game_state
def make_move(self):
hand_states = [HandState(player, self.game_state) for player in range(self.game_state.num_players)]
modified_pace = self.game_state.pace - sum(
1 for state in hand_states if len(state.trash) == self.game_state.hand_size
)
cur_hand = hand_states[self.game_state.turn]
print([state.__dict__ for state in hand_states])
print(self.game_state.pace)
exit(0)
class GreedyStrategy():
@ -157,7 +246,7 @@ class GreedyStrategy():
def run_deck(instance: HanabiInstance) -> GameState:
gs = GameState(instance)
strat = GreedyStrategy(gs)
strat = CheatingStrategy(gs)
while not gs.is_over():
strat.make_move()
return gs
@ -190,3 +279,9 @@ def run_samples(num_players, sample_size):
logger.info("Won {} ({}%) and lost {} ({}%) from sample of {} test games using greedy strategy.".format(
won, round(100 * won / sample_size, 2), lost, round(100 * lost / sample_size, 2), sample_size
))
if __name__ == "__main__":
for p in range(2, 6):
run_samples(p, int(sys.argv[1]))
print()