hanabi-league/config.py

204 lines
7.1 KiB
Python
Raw Normal View History

2023-11-22 16:21:30 +01:00
import shutil
2023-11-22 18:56:08 +01:00
from typing import Dict, List
2023-11-22 16:21:30 +01:00
2023-11-22 15:31:36 +01:00
import yaml
import platformdirs
2023-11-22 17:26:15 +01:00
import datetime
2023-11-22 17:13:44 +01:00
import dateutil.parser
2023-11-22 17:26:15 +01:00
from pathlib import Path
2023-11-22 15:31:36 +01:00
import constants
from log_setup import logger
class DBConfig:
    """Plain container for the database connection parameters (name, user, password)."""

    def __init__(self, db_name: str, db_user: str, db_pass: str):
        # Values are stored verbatim; defaulting is done by the caller (see read_db_config).
        self.db_name, self.db_user, self.db_pass = db_name, db_user, db_pass
2023-11-22 16:21:30 +01:00
def get_db_config_path() -> Path:
    """
    Return the location of the DB connection config file.

    The file lives in the per-user config directory for ``constants.APP_NAME``; that directory is
    created on demand (``ensure_exists=True``), the file itself is not.
    """
    return Path(platformdirs.user_config_dir(constants.APP_NAME, ensure_exists=True)) / constants.DB_CONFIG_FILE_NAME
2023-11-22 15:31:36 +01:00
def _db_config_value(config: Dict, key: str, label: str, default):
    """Return ``config[key]``, or ``default`` (logging the fallback) when the key is absent or null."""
    value = config.get(key, None)
    if value is None:
        logger.debug("Falling back to default DB {} {}".format(label, default))
        return default
    return value


def read_db_config() -> DBConfig:
    """
    Reads the DB connection parameters from the config file.

    The file location is given by ``get_db_config_path()``. The YAML keys ``dbname``, ``dbuser``
    and ``dbpass`` are read; each one that is absent or null falls back to the matching
    ``constants.DEFAULT_DB_*`` value. If the file does not exist, all defaults are used.
    An empty file is treated like a file that sets no keys.

    :return: The resolved DB connection parameters.
    """
    config_path = get_db_config_path()
    if not config_path.exists():
        logger.info(
            "No configuration file for database connection found, falling back to default values "
            "(dbname={}, dbuser={}, dbpass={}).".format(
                constants.DEFAULT_DB_NAME, constants.DEFAULT_DB_USER, constants.DEFAULT_DB_PASS
            )
        )
        logger.info(
            "Note: To turn off this message, create a config file at {}".format(config_path)
        )
        return DBConfig(constants.DEFAULT_DB_NAME, constants.DEFAULT_DB_USER, constants.DEFAULT_DB_PASS)

    # Only claim the file was read when it actually exists.
    logger.verbose("DB Configuration read from file {}".format(config_path))
    with open(config_path, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document; treat that as "no keys set".
        config = yaml.safe_load(f) or {}

    db_name = _db_config_value(config, 'dbname', "name", constants.DEFAULT_DB_NAME)
    db_user = _db_config_value(config, 'dbuser', "user", constants.DEFAULT_DB_USER)
    db_pass = _db_config_value(config, 'dbpass', "pass", constants.DEFAULT_DB_PASS)
    # Never write the (possibly user-supplied) password to the logs.
    logger.debug("Read config values (dbname={}, dbuser={}, dbpass={})".format(db_name, db_user, "***"))
    return DBConfig(db_name, db_user, db_pass)
2023-11-22 16:21:30 +01:00
def create_db_config() -> None:
    """
    Copy the bundled default DB config (``constants.DEFAULT_DB_CONFIG_PATH``) to the user config
    location. An already existing file is left untouched.
    """
    target = get_db_config_path()
    if target.exists():
        logger.info("DB config file at {} already exists".format(target))
        return
    shutil.copy(constants.DEFAULT_DB_CONFIG_PATH, target)
    logger.info("Created default DB config file at {}".format(target))
2023-11-22 17:13:44 +01:00
def check_config_attr(func):
    """
    Decorator for config accessors: a KeyError raised by ``func`` (a missing config entry) is
    logged as an error and the call returns None instead of propagating the exception.

    ``functools.wraps`` copies ``func``'s metadata (``__name__``, ``__doc__``, ...) onto the
    wrapper so decorated accessors, and properties built on them, keep their documentation.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except KeyError as e:
            logger.error("Missing config attribute:\n{}".format(e))
            # Deliberate best-effort: callers receive None for a missing entry.
            return None
    return wrapper
2023-11-22 17:13:44 +01:00
class Config:
    """
    Read-only accessors for the hanabi league configuration mapping (as loaded from YAML).

    Accessors decorated with ``check_config_attr`` log an error and return None instead of
    raising when a required key is missing.
    """

    def __init__(self, config: Dict):
        # yaml.safe_load yields None for an empty file; use an empty mapping instead so a missing
        # key surfaces as KeyError (handled by check_config_attr) rather than an uncaught TypeError.
        self._config = config if config is not None else {}

    @property
    @check_config_attr
    def player_base_rating(self) -> int:
        """The ``player_base_rating`` config entry."""
        return self._config["player_base_rating"]

    @property
    @check_config_attr
    def min_player_count(self) -> int:
        """The ``min_player_count`` config entry."""
        return self._config["min_player_count"]

    @property
    @check_config_attr
    def max_player_count(self) -> int:
        """The ``max_player_count`` config entry."""
        return self._config["max_player_count"]

    @property
    @check_config_attr
    def min_suit_count(self) -> int:
        """The ``min_suits`` config entry."""
        return self._config["min_suits"]

    @property
    @check_config_attr
    def max_suit_count(self) -> int:
        """The ``max_suits`` config entry."""
        return self._config["max_suits"]

    @property
    @check_config_attr
    def starting_game_id(self) -> int:
        """The ``starting_game_id`` config entry."""
        return self._config["starting_game_id"]

    @property
    @check_config_attr
    def ending_game_id(self) -> int:
        """The ``ending_game_id`` config entry."""
        return self._config["ending_game_id"]

    def _parse_time(self, key: str) -> datetime.datetime:
        """
        Parse the date/time string stored under ``key`` with ``dateutil.parser.parse``.

        The abbreviation ``EST`` is resolved to the America/New_York zone (the canonical name of
        the 'US/Eastern' alias). dateutil requires a tzinfo object (or an int offset) here: a bare
        string would be parsed as a POSIX TZ spec, and 'US/Eastern' would silently become UTC+0.
        Raises KeyError if ``key`` is missing.
        """
        from dateutil import tz

        raw = self._config[key]
        return dateutil.parser.parse(raw, tzinfos={'EST': tz.gettz('America/New_York')})

    @property
    @check_config_attr
    def starting_time(self) -> datetime.datetime:
        """The ``starting_time`` config entry parsed into a datetime."""
        return self._parse_time("starting_time")

    @property
    @check_config_attr
    def ending_time(self) -> datetime.datetime:
        """The ``ending_time`` config entry parsed into a datetime."""
        return self._parse_time("ending_time")

    @property
    @check_config_attr
    def excluded_variants(self) -> List[str]:
        """The ``excluded_variants`` config entry, lower-cased."""
        # A key present with no value loads as None in YAML; treat that as "no exclusions".
        return [var.lower() for var in (self._config["excluded_variants"] or [])]

    @check_config_attr
    def variant_base_rating(self, variant_name: str, player_count: int) -> int:
        """
        Return the base rating for ``variant_name`` played with ``player_count`` players.

        Returns None (after logging an error) if the per-variant entry has an unexpected format,
        or, via ``check_config_attr``, if ``variant_base_rating`` is missing.
        """
        global_base_rating = self._config["variant_base_rating"]
        # We use different ways of specifying base ratings here:
        # First, there is a (required) config setting for the variant base rating, which will be used as default.
        # Then, for each variant, it is possible to either specify some base rating directly,
        # or further specify a base rating for each player count.
        # Parsing this is now quite easy: We just check if there is an entry for the specific variant and if so,
        # read the base rating from there, where we will have to distinguish between a player-independent value
        # and possibly player-specific entries. Whenever we don't find an explicit entry, we use the global fallback.
        # This makes it possible to just specify the variant + player combinations that differ in their base rating
        # from the globally specified one.
        # `or {}` also covers a present-but-empty (null) 'variant_base_ratings' key.
        var_rating = (self._config.get("variant_base_ratings") or {}).get(variant_name, None)
        # bool is a subclass of int; YAML true/false is not a rating, so keep rejecting it.
        if isinstance(var_rating, int) and not isinstance(var_rating, bool):
            return var_rating
        if isinstance(var_rating, dict):
            return var_rating.get("{}p".format(player_count), global_base_rating)
        if var_rating is None:
            return global_base_rating
        logger.error("Unexpected config format for entry {} in 'variant_base_ratings'".format(variant_name))
        return None
def get_config_path():
    """
    Return the location of the hanabi league config file.

    The file lives in the per-user config directory for ``constants.APP_NAME``; that directory is
    created on demand (``ensure_exists=True``), the file itself is not.
    """
    user_dir = platformdirs.user_config_dir(constants.APP_NAME, ensure_exists=True)
    return Path(user_dir).joinpath(constants.CONFIG_FILE_NAME)
def read_config() -> Config:
    """
    Load the hanabi league configuration.

    Reads the user's config file (``get_config_path()``) if it exists, otherwise the bundled
    default file ``constants.DEFAULT_CONFIG_PATH``. An empty YAML file is treated as an empty
    mapping, so missing entries are reported by the Config accessors rather than crashing here.
    """
    config_path = get_config_path()
    if config_path.exists():
        # Only claim the user file was read when it actually exists.
        logger.verbose("Hanabi League configuration read from file {}".format(config_path))
        source_path = config_path
    else:
        logger.info("No hanabi league configuration found. Falling back to default file {}".format(
            constants.DEFAULT_CONFIG_PATH))
        logger.info(
            "Note: To turn off this message, create a config file at {}".format(config_path)
        )
        source_path = constants.DEFAULT_CONFIG_PATH
    with open(source_path, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document; Config expects a mapping.
        config = yaml.safe_load(f) or {}
    return Config(config)
def create_config() -> None:
    """
    Copy the bundled default league config (``constants.DEFAULT_CONFIG_PATH``) to the user config
    location. An already existing file is left untouched.
    """
    target = get_config_path()
    if target.exists():
        logger.info("Hanabi league config file at {} already exists".format(target))
        return
    shutil.copy(constants.DEFAULT_CONFIG_PATH, target)
    logger.info("Created default hanabi league config file at {}".format(target))