Start implementation of a better greedy solver

This commit is contained in:
Maximilian Keßler 2023-05-25 17:00:18 +02:00
parent a1345a1976
commit a427a2575e
Signed by: max
GPG key ID: BCC5A619923C0BA5

View file

@ -3,6 +3,7 @@ import collections
import sys import sys
from enum import Enum from enum import Enum
from log_setup import logger from log_setup import logger
from typing import Tuple, List, Optional
from time import sleep from time import sleep
from hanabi import DeckCard, Action, ActionType, GameState, HanabiInstance from hanabi import DeckCard, Action, ActionType, GameState, HanabiInstance
@ -14,7 +15,8 @@ class CardType(Enum):
Trash = 0 Trash = 0
Playable = 1 Playable = 1
Critical = 2 Critical = 2
Dispensable = 3 DuplicateVisible = 3
UniqueVisible = 4
class CardState(): class CardState():
@ -31,8 +33,10 @@ class CardState():
return "Playable ({}) with weight {}".format(self.card, self.weight) return "Playable ({}) with weight {}".format(self.card, self.weight)
case CardType.Critical: case CardType.Critical:
return "Critical ({})".format(self.card) return "Critical ({})".format(self.card)
case CardType.Dispensable: case CardType.DuplicateVisible:
return "Dispensable ({}) with weight {}".format(self.card, self.weight) return "Useful (duplicate visible) ({}) with weight {}".format(self.card, self.weight)
case CardType.UniqueVisible:
return "Useful (unique visible) ({}) with weight {}".format(self.card, self.weight)
# TODO # TODO
@ -45,7 +49,92 @@ def card_type(game_state, card):
elif card.rank == 5 or card in game_state.trash: elif card.rank == 5 or card in game_state.trash:
return CardType.Critical return CardType.Critical
else: else:
return CardType.Dispensable visible_cards = sum((game_state.hands[player] for player in range(game_state.num_players)), [])
if visible_cards.count(card) >= 2:
return CardType.DuplicateVisible
else:
return CardType.UniqueVisible
class WeightedCard:
    """A card paired with an optional integer weight.

    Attributes:
        card: The wrapped card object (stored as given).
        weight: Weight attached to the card, or None when none was supplied.
    """

    def __init__(self, card, weight: Optional[int] = None):
        self.card, self.weight = card, weight

    def __repr__(self):
        # Same textual layout for every instance; a missing weight prints as "None".
        template = "{} with weight {}"
        return template.format(self.card, self.weight)
class HandState:
def __init__(self, player: int, game_state: GameState):
self.trash = []
self.playable = []
self.critical = []
self.dupes = []
self.uniques = []
for card in game_state.hands[player]:
match card_type(game_state, card):
case CardType.Trash:
self.trash.append(WeightedCard(card))
case CardType.Playable:
if card not in map(lambda c: c.card, self.playable):
self.playable.append(WeightedCard(card))
else:
self.trash.append(card)
case CardType.Critical:
self.critical.append(WeightedCard(card))
case CardType.UniqueVisible:
self.uniques.append(WeightedCard(card))
case CardType.DuplicateVisible:
copy = next((w for w in self.dupes if w.card == card), None)
if copy is not None:
self.dupes.remove(copy)
self.critical.append(copy)
self.trash.append(card)
else:
self.dupes.append(WeightedCard(card))
self.playable.sort(key=lambda c: c.card.rank)
self.dupes.sort(key=lambda c: c.card.rank)
self.uniques.sort(key=lambda c: c.card.rank)
if len(self.trash) > 0:
self.best_discard = self.trash[0]
self.discard_badness = 0
elif len(self.dupes) > 0:
self.best_discard = self.dupes[0]
self.discard_badness = 8 - game_state.num_players
elif len(self.uniques) > 0:
self.best_discard = self.uniques[-1]
self.discard_badness = 80 - 10 * self.best_discard.card.rank
elif len(self.playable) > 0:
self.best_discard = self.playable[-1]
self.discard_badness = 80 - 10 * self.best_discard.card.rank
else:
assert len(self.critical) > 0, "Programming error."
self.best_discard = self.critical[-1]
self.discard_badness = 600 - 100*self.best_discard.card.rank
def num_useful_cards(self):
return len(self.dupes) + len(self.uniques) + len(self.playable) + len(self.critical)
class CheatingStrategy:
    """Work-in-progress strategy that evaluates every player's hand.

    ``make_move`` does not perform a move yet: it prints debugging
    information about all hands and then terminates the process.
    """

    def __init__(self, game_state: GameState):
        """Store the game state this strategy operates on."""
        self.game_state = game_state

    def make_move(self):
        """Build a ``HandState`` per player, dump them and exit the process.

        Raises:
            SystemExit: Always, with exit code 0, after the debug output.
        """
        hand_states = [
            HandState(player, self.game_state)
            for player in range(self.game_state.num_players)
        ]

        # Pace lowered by one for every hand that consists solely of trash.
        # Computed but not consumed yet (implementation in progress).
        modified_pace = self.game_state.pace - sum(
            1 for state in hand_states if len(state.trash) == self.game_state.hand_size
        )

        # Hand of the player whose turn it is; also not consumed yet.
        cur_hand = hand_states[self.game_state.turn]

        print([state.__dict__ for state in hand_states])
        print(self.game_state.pace)
        # `sys.exit` instead of the `site`-provided `exit`, which is meant for
        # the interactive shell and is missing when Python runs with `-S`.
        sys.exit(0)
class GreedyStrategy(): class GreedyStrategy():
@ -157,7 +246,7 @@ class GreedyStrategy():
def run_deck(instance: HanabiInstance) -> GameState: def run_deck(instance: HanabiInstance) -> GameState:
gs = GameState(instance) gs = GameState(instance)
strat = GreedyStrategy(gs) strat = CheatingStrategy(gs)
while not gs.is_over(): while not gs.is_over():
strat.make_move() strat.make_move()
return gs return gs
@ -190,3 +279,9 @@ def run_samples(num_players, sample_size):
logger.info("Won {} ({}%) and lost {} ({}%) from sample of {} test games using greedy strategy.".format( logger.info("Won {} ({}%) and lost {} ({}%) from sample of {} test games using greedy strategy.".format(
won, round(100 * won / sample_size, 2), lost, round(100 * lost / sample_size, 2), sample_size won, round(100 * won / sample_size, 2), lost, round(100 * lost / sample_size, 2), sample_size
)) ))
# Script entry point: calls run_samples for every player count from 2 to 5,
# using the first command-line argument (converted to int) as the sample size.
# NOTE(review): indentation is not visible in this view, so it is unclear
# whether print() runs after each player count or once after the loop.
if __name__ == "__main__":
for p in range(2, 6):
run_samples(p, int(sys.argv[1]))
print()