start impl of better greedy solver
This commit is contained in:
parent
a1345a1976
commit
a427a2575e
1 changed file with 100 additions and 5 deletions
105
greedy_solver.py
105
greedy_solver.py
|
@ -3,6 +3,7 @@ import collections
|
||||||
import sys
|
import sys
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from log_setup import logger
|
from log_setup import logger
|
||||||
|
from typing import Tuple, List, Optional
|
||||||
from time import sleep
|
from time import sleep
|
||||||
|
|
||||||
from hanabi import DeckCard, Action, ActionType, GameState, HanabiInstance
|
from hanabi import DeckCard, Action, ActionType, GameState, HanabiInstance
|
||||||
|
@ -14,7 +15,8 @@ class CardType(Enum):
|
||||||
Trash = 0
|
Trash = 0
|
||||||
Playable = 1
|
Playable = 1
|
||||||
Critical = 2
|
Critical = 2
|
||||||
Dispensable = 3
|
DuplicateVisible = 3
|
||||||
|
UniqueVisible = 4
|
||||||
|
|
||||||
|
|
||||||
class CardState():
|
class CardState():
|
||||||
|
@ -31,8 +33,10 @@ class CardState():
|
||||||
return "Playable ({}) with weight {}".format(self.card, self.weight)
|
return "Playable ({}) with weight {}".format(self.card, self.weight)
|
||||||
case CardType.Critical:
|
case CardType.Critical:
|
||||||
return "Critical ({})".format(self.card)
|
return "Critical ({})".format(self.card)
|
||||||
case CardType.Dispensable:
|
case CardType.DuplicateVisible:
|
||||||
return "Dispensable ({}) with weight {}".format(self.card, self.weight)
|
return "Useful (duplicate visible) ({}) with weight {}".format(self.card, self.weight)
|
||||||
|
case CardType.UniqueVisible:
|
||||||
|
return "Useful (unique visible) ({}) with weight {}".format(self.card, self.weight)
|
||||||
|
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
|
@ -45,7 +49,92 @@ def card_type(game_state, card):
|
||||||
elif card.rank == 5 or card in game_state.trash:
|
elif card.rank == 5 or card in game_state.trash:
|
||||||
return CardType.Critical
|
return CardType.Critical
|
||||||
else:
|
else:
|
||||||
return CardType.Dispensable
|
visible_cards = sum((game_state.hands[player] for player in range(game_state.num_players)), [])
|
||||||
|
if visible_cards.count(card) >= 2:
|
||||||
|
return CardType.DuplicateVisible
|
||||||
|
else:
|
||||||
|
return CardType.UniqueVisible
|
||||||
|
|
||||||
|
|
||||||
|
class WeightedCard:
    """A card paired with an optional integer weight.

    Attributes are plain and public: ``card`` is stored as given and
    ``weight`` defaults to ``None`` when no weight was supplied.
    """

    def __init__(self, card, weight: Optional[int] = None):
        self.weight = weight
        self.card = card

    def __repr__(self):
        # Rendered as "<card> with weight <weight>".
        return f"{self.card} with weight {self.weight}"
|
||||||
|
|
||||||
|
|
||||||
|
class HandState:
|
||||||
|
def __init__(self, player: int, game_state: GameState):
|
||||||
|
self.trash = []
|
||||||
|
self.playable = []
|
||||||
|
self.critical = []
|
||||||
|
self.dupes = []
|
||||||
|
self.uniques = []
|
||||||
|
for card in game_state.hands[player]:
|
||||||
|
match card_type(game_state, card):
|
||||||
|
case CardType.Trash:
|
||||||
|
self.trash.append(WeightedCard(card))
|
||||||
|
case CardType.Playable:
|
||||||
|
if card not in map(lambda c: c.card, self.playable):
|
||||||
|
self.playable.append(WeightedCard(card))
|
||||||
|
else:
|
||||||
|
self.trash.append(card)
|
||||||
|
case CardType.Critical:
|
||||||
|
self.critical.append(WeightedCard(card))
|
||||||
|
case CardType.UniqueVisible:
|
||||||
|
self.uniques.append(WeightedCard(card))
|
||||||
|
case CardType.DuplicateVisible:
|
||||||
|
copy = next((w for w in self.dupes if w.card == card), None)
|
||||||
|
if copy is not None:
|
||||||
|
self.dupes.remove(copy)
|
||||||
|
self.critical.append(copy)
|
||||||
|
self.trash.append(card)
|
||||||
|
else:
|
||||||
|
self.dupes.append(WeightedCard(card))
|
||||||
|
self.playable.sort(key=lambda c: c.card.rank)
|
||||||
|
self.dupes.sort(key=lambda c: c.card.rank)
|
||||||
|
self.uniques.sort(key=lambda c: c.card.rank)
|
||||||
|
if len(self.trash) > 0:
|
||||||
|
self.best_discard = self.trash[0]
|
||||||
|
self.discard_badness = 0
|
||||||
|
elif len(self.dupes) > 0:
|
||||||
|
self.best_discard = self.dupes[0]
|
||||||
|
self.discard_badness = 8 - game_state.num_players
|
||||||
|
elif len(self.uniques) > 0:
|
||||||
|
self.best_discard = self.uniques[-1]
|
||||||
|
self.discard_badness = 80 - 10 * self.best_discard.card.rank
|
||||||
|
elif len(self.playable) > 0:
|
||||||
|
self.best_discard = self.playable[-1]
|
||||||
|
self.discard_badness = 80 - 10 * self.best_discard.card.rank
|
||||||
|
else:
|
||||||
|
assert len(self.critical) > 0, "Programming error."
|
||||||
|
self.best_discard = self.critical[-1]
|
||||||
|
self.discard_badness = 600 - 100*self.best_discard.card.rank
|
||||||
|
|
||||||
|
def num_useful_cards(self):
|
||||||
|
return len(self.dupes) + len(self.uniques) + len(self.playable) + len(self.critical)
|
||||||
|
|
||||||
|
|
||||||
|
class CheatingStrategy:
    """Strategy that analyses the hands of all players (work in progress).

    ``make_move`` currently only builds the per-hand analysis, prints it and
    terminates the process; no action is performed on the game state yet.
    """

    def __init__(self, game_state: GameState):
        """Remember the game state this strategy operates on."""
        self.game_state = game_state

    def make_move(self):
        """Analyse every hand, dump the result to stdout and exit.

        Raises:
            SystemExit: always, with exit status 0, after printing.
        """
        hand_states = [HandState(player, self.game_state) for player in range(self.game_state.num_players)]

        # Pace lowered by one for every hand whose trash list is as long as
        # a full hand. Not used yet.
        modified_pace = self.game_state.pace - sum(
            1 for state in hand_states if len(state.trash) == self.game_state.hand_size
        )

        # Analysis of the hand of the player whose turn it is. Not used yet.
        cur_hand = hand_states[self.game_state.turn]

        print([state.__dict__ for state in hand_states])
        print(self.game_state.pace)
        # sys.exit instead of the site-provided ``exit`` helper, which is
        # intended for the interactive interpreter and may be unavailable.
        sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class GreedyStrategy():
|
class GreedyStrategy():
|
||||||
|
@ -157,7 +246,7 @@ class GreedyStrategy():
|
||||||
|
|
||||||
def run_deck(instance: HanabiInstance) -> GameState:
    """Play ``instance`` with ``CheatingStrategy`` until the game is over.

    Returns the ``GameState`` that was created for ``instance`` and advanced
    by the strategy.
    """
    state = GameState(instance)
    strategy = CheatingStrategy(state)
    while True:
        if state.is_over():
            return state
        strategy.make_move()
|
||||||
|
@ -190,3 +279,9 @@ def run_samples(num_players, sample_size):
|
||||||
logger.info("Won {} ({}%) and lost {} ({}%) from sample of {} test games using greedy strategy.".format(
|
logger.info("Won {} ({}%) and lost {} ({}%) from sample of {} test games using greedy strategy.".format(
|
||||||
won, round(100 * won / sample_size, 2), lost, round(100 * lost / sample_size, 2), sample_size
|
won, round(100 * won / sample_size, 2), lost, round(100 * lost / sample_size, 2), sample_size
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # The sample size is taken from the first command-line argument; the
    # samples are run for 2, 3, 4 and 5 players, each followed by a blank line.
    sample_size = int(sys.argv[1])
    for num_players in (2, 3, 4, 5):
        run_samples(num_players, sample_size)
        print()
|
||||||
|
|
Loading…
Reference in a new issue