Clip initializations

This commit is contained in:
posij118 2024-02-05 21:58:49 +08:00
parent 1de9bd7fff
commit 333a3726fd

View file

@ -164,19 +164,32 @@ def calculate_minloss_ratings(
variant_prior_weight=0.5,
variant_random_init_stdev: int = 0,
):
player_rating_priors = np.broadcast_to(player_rating_priors, len(rated_ids))
player_rating_priors = np.broadcast_to(
player_rating_priors, len(global_info.rated_ids)
)
variant_rating_priors = np.broadcast_to(
variant_rating_priors, len(variant_ids) * len(player_counts)
variant_rating_priors,
len(global_info.variant_ids) * len(global_info.player_counts),
)
prior_rating_list = np.concatenate(
(
player_rating_priors
+ np.random.normal(0, player_random_init_stdev, len(player_rating_priors)),
+ np.clip(
np.random.normal(
0, player_random_init_stdev, len(player_rating_priors)
),
-3 * player_random_init_stdev,
3 * player_random_init_stdev,
),
variant_rating_priors
+ np.random.normal(
+ np.clip(
np.random.normal(
0, variant_random_init_stdev, len(variant_rating_priors)
),
-3 * variant_random_init_stdev,
3 * variant_random_init_stdev,
),
),
dtype=np.float32,
)
@ -306,7 +319,7 @@ def write_data(random_init_rating_lists, cv_rating_lists):
f.write(s)
if __name__ == "__main__":
def read_data():
with open("../data/games.json") as game_data:
game_list = json.loads(game_data.read())
@ -356,6 +369,12 @@ if __name__ == "__main__":
] += 1
global_info.game_counts = game_counts
return game_list, global_info, p_win_lookup_table
if __name__ == "__main__":
game_list, global_info, p_win_lookup_table = read_data()
random_init_rating_lists = []
cv_rating_lists = []