diff --git a/src/minimize_loss.py b/src/minimize_loss.py index 2213090..6a247d5 100644 --- a/src/minimize_loss.py +++ b/src/minimize_loss.py @@ -164,18 +164,31 @@ def calculate_minloss_ratings( variant_prior_weight=0.5, variant_random_init_stdev: int = 0, ): - player_rating_priors = np.broadcast_to(player_rating_priors, len(rated_ids)) + player_rating_priors = np.broadcast_to( + player_rating_priors, len(global_info.rated_ids) + ) variant_rating_priors = np.broadcast_to( - variant_rating_priors, len(variant_ids) * len(player_counts) + variant_rating_priors, + len(global_info.variant_ids) * len(global_info.player_counts), ) prior_rating_list = np.concatenate( ( player_rating_priors - + np.random.normal(0, player_random_init_stdev, len(player_rating_priors)), + + np.clip( + np.random.normal( + 0, player_random_init_stdev, len(player_rating_priors) + ), + -3 * player_random_init_stdev, + 3 * player_random_init_stdev, + ), variant_rating_priors - + np.random.normal( - 0, variant_random_init_stdev, len(variant_rating_priors) + + np.clip( + np.random.normal( + 0, variant_random_init_stdev, len(variant_rating_priors) + ), + -3 * variant_random_init_stdev, + 3 * variant_random_init_stdev, ), ), dtype=np.float32, @@ -306,7 +319,7 @@ def write_data(random_init_rating_lists, cv_rating_lists): f.write(s) -if __name__ == "__main__": +def read_data(): with open("../data/games.json") as game_data: game_list = json.loads(game_data.read()) @@ -356,6 +369,12 @@ if __name__ == "__main__": ] += 1 global_info.game_counts = game_counts + return game_list, global_info, p_win_lookup_table + + +if __name__ == "__main__": + game_list, global_info, p_win_lookup_table = read_data() + random_init_rating_lists = [] cv_rating_lists = []