Clip initializations

This commit is contained in:
posij118 2024-02-05 21:58:49 +08:00
parent 1de9bd7fff
commit 333a3726fd

View file

@ -164,18 +164,31 @@ def calculate_minloss_ratings(
variant_prior_weight=0.5, variant_prior_weight=0.5,
variant_random_init_stdev: int = 0, variant_random_init_stdev: int = 0,
): ):
player_rating_priors = np.broadcast_to(player_rating_priors, len(rated_ids)) player_rating_priors = np.broadcast_to(
player_rating_priors, len(global_info.rated_ids)
)
variant_rating_priors = np.broadcast_to( variant_rating_priors = np.broadcast_to(
variant_rating_priors, len(variant_ids) * len(player_counts) variant_rating_priors,
len(global_info.variant_ids) * len(global_info.player_counts),
) )
prior_rating_list = np.concatenate( prior_rating_list = np.concatenate(
( (
player_rating_priors player_rating_priors
+ np.random.normal(0, player_random_init_stdev, len(player_rating_priors)), + np.clip(
np.random.normal(
0, player_random_init_stdev, len(player_rating_priors)
),
-3 * player_random_init_stdev,
3 * player_random_init_stdev,
),
variant_rating_priors variant_rating_priors
+ np.random.normal( + np.clip(
0, variant_random_init_stdev, len(variant_rating_priors) np.random.normal(
0, variant_random_init_stdev, len(variant_rating_priors)
),
-3 * variant_random_init_stdev,
3 * variant_random_init_stdev,
), ),
), ),
dtype=np.float32, dtype=np.float32,
@ -306,7 +319,7 @@ def write_data(random_init_rating_lists, cv_rating_lists):
f.write(s) f.write(s)
if __name__ == "__main__": def read_data():
with open("../data/games.json") as game_data: with open("../data/games.json") as game_data:
game_list = json.loads(game_data.read()) game_list = json.loads(game_data.read())
@ -356,6 +369,12 @@ if __name__ == "__main__":
] += 1 ] += 1
global_info.game_counts = game_counts global_info.game_counts = game_counts
return game_list, global_info, p_win_lookup_table
if __name__ == "__main__":
game_list, global_info, p_win_lookup_table = read_data()
random_init_rating_lists = [] random_init_rating_lists = []
cv_rating_lists = [] cv_rating_lists = []