start impl of better greedy solver
This commit is contained in:
parent
a1345a1976
commit
a427a2575e
1 changed files with 100 additions and 5 deletions
105
greedy_solver.py
105
greedy_solver.py
|
@ -3,6 +3,7 @@ import collections
|
|||
import sys
|
||||
from enum import Enum
|
||||
from log_setup import logger
|
||||
from typing import Tuple, List, Optional
|
||||
from time import sleep
|
||||
|
||||
from hanabi import DeckCard, Action, ActionType, GameState, HanabiInstance
|
||||
|
@ -14,7 +15,8 @@ class CardType(Enum):
|
|||
Trash = 0
|
||||
Playable = 1
|
||||
Critical = 2
|
||||
Dispensable = 3
|
||||
DuplicateVisible = 3
|
||||
UniqueVisible = 4
|
||||
|
||||
|
||||
class CardState():
|
||||
|
@ -31,8 +33,10 @@ class CardState():
|
|||
return "Playable ({}) with weight {}".format(self.card, self.weight)
|
||||
case CardType.Critical:
|
||||
return "Critical ({})".format(self.card)
|
||||
case CardType.Dispensable:
|
||||
return "Dispensable ({}) with weight {}".format(self.card, self.weight)
|
||||
case CardType.DuplicateVisible:
|
||||
return "Useful (duplicate visible) ({}) with weight {}".format(self.card, self.weight)
|
||||
case CardType.UniqueVisible:
|
||||
return "Useful (unique visible) ({}) with weight {}".format(self.card, self.weight)
|
||||
|
||||
|
||||
# TODO
|
||||
|
@ -45,7 +49,92 @@ def card_type(game_state, card):
|
|||
elif card.rank == 5 or card in game_state.trash:
|
||||
return CardType.Critical
|
||||
else:
|
||||
return CardType.Dispensable
|
||||
visible_cards = sum((game_state.hands[player] for player in range(game_state.num_players)), [])
|
||||
if visible_cards.count(card) >= 2:
|
||||
return CardType.DuplicateVisible
|
||||
else:
|
||||
return CardType.UniqueVisible
|
||||
|
||||
|
||||
class WeightedCard:
|
||||
def __init__(self, card, weight: Optional[int] = None):
|
||||
self.card = card
|
||||
self.weight = weight
|
||||
|
||||
def __repr__(self):
|
||||
return "{} with weight {}".format(self.card, self.weight)
|
||||
|
||||
|
||||
class HandState:
|
||||
def __init__(self, player: int, game_state: GameState):
|
||||
self.trash = []
|
||||
self.playable = []
|
||||
self.critical = []
|
||||
self.dupes = []
|
||||
self.uniques = []
|
||||
for card in game_state.hands[player]:
|
||||
match card_type(game_state, card):
|
||||
case CardType.Trash:
|
||||
self.trash.append(WeightedCard(card))
|
||||
case CardType.Playable:
|
||||
if card not in map(lambda c: c.card, self.playable):
|
||||
self.playable.append(WeightedCard(card))
|
||||
else:
|
||||
self.trash.append(card)
|
||||
case CardType.Critical:
|
||||
self.critical.append(WeightedCard(card))
|
||||
case CardType.UniqueVisible:
|
||||
self.uniques.append(WeightedCard(card))
|
||||
case CardType.DuplicateVisible:
|
||||
copy = next((w for w in self.dupes if w.card == card), None)
|
||||
if copy is not None:
|
||||
self.dupes.remove(copy)
|
||||
self.critical.append(copy)
|
||||
self.trash.append(card)
|
||||
else:
|
||||
self.dupes.append(WeightedCard(card))
|
||||
self.playable.sort(key=lambda c: c.card.rank)
|
||||
self.dupes.sort(key=lambda c: c.card.rank)
|
||||
self.uniques.sort(key=lambda c: c.card.rank)
|
||||
if len(self.trash) > 0:
|
||||
self.best_discard = self.trash[0]
|
||||
self.discard_badness = 0
|
||||
elif len(self.dupes) > 0:
|
||||
self.best_discard = self.dupes[0]
|
||||
self.discard_badness = 8 - game_state.num_players
|
||||
elif len(self.uniques) > 0:
|
||||
self.best_discard = self.uniques[-1]
|
||||
self.discard_badness = 80 - 10 * self.best_discard.card.rank
|
||||
elif len(self.playable) > 0:
|
||||
self.best_discard = self.playable[-1]
|
||||
self.discard_badness = 80 - 10 * self.best_discard.card.rank
|
||||
else:
|
||||
assert len(self.critical) > 0, "Programming error."
|
||||
self.best_discard = self.critical[-1]
|
||||
self.discard_badness = 600 - 100*self.best_discard.card.rank
|
||||
|
||||
def num_useful_cards(self):
|
||||
return len(self.dupes) + len(self.uniques) + len(self.playable) + len(self.critical)
|
||||
|
||||
|
||||
class CheatingStrategy:
|
||||
def __init__(self, game_state: GameState):
|
||||
self.game_state = game_state
|
||||
|
||||
def make_move(self):
|
||||
hand_states = [HandState(player, self.game_state) for player in range(self.game_state.num_players)]
|
||||
|
||||
modified_pace = self.game_state.pace - sum(
|
||||
1 for state in hand_states if len(state.trash) == self.game_state.hand_size
|
||||
)
|
||||
|
||||
cur_hand = hand_states[self.game_state.turn]
|
||||
|
||||
print([state.__dict__ for state in hand_states])
|
||||
print(self.game_state.pace)
|
||||
exit(0)
|
||||
|
||||
|
||||
|
||||
|
||||
class GreedyStrategy():
|
||||
|
@ -157,7 +246,7 @@ class GreedyStrategy():
|
|||
|
||||
def run_deck(instance: HanabiInstance) -> GameState:
|
||||
gs = GameState(instance)
|
||||
strat = GreedyStrategy(gs)
|
||||
strat = CheatingStrategy(gs)
|
||||
while not gs.is_over():
|
||||
strat.make_move()
|
||||
return gs
|
||||
|
@ -190,3 +279,9 @@ def run_samples(num_players, sample_size):
|
|||
logger.info("Won {} ({}%) and lost {} ({}%) from sample of {} test games using greedy strategy.".format(
|
||||
won, round(100 * won / sample_size, 2), lost, round(100 * lost / sample_size, 2), sample_size
|
||||
))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for p in range(2, 6):
|
||||
run_samples(p, int(sys.argv[1]))
|
||||
print()
|
||||
|
|
Loading…
Reference in a new issue