diff --git a/config.py b/config.py index a0dc0c9..d90826e 100644 --- a/config.py +++ b/config.py @@ -1,3 +1,5 @@ +import shutil + import yaml import platformdirs from pathlib import Path @@ -13,13 +15,17 @@ class DBConfig: self.db_pass = db_pass +def get_db_config_path() -> Path: + config_dir = Path(platformdirs.user_config_dir(constants.APP_NAME, ensure_exists=True)) + config_path = config_dir / constants.DB_CONFIG_FILE_NAME + return config_path + + def read_db_config() -> DBConfig: """ Reads the DB connection parameters from the config file. """ - config_dir = Path(platformdirs.user_config_dir(constants.APP_NAME, ensure_exists=True)) - config_path = config_dir / constants.DB_CONFIG_FILE_NAME - + config_path = get_db_config_path() logger.verbose("DB Configuration read from file {}".format(config_path)) if config_path.exists(): @@ -53,3 +59,15 @@ def read_db_config() -> DBConfig: "Note: To turn off this message, create a config file at {}".format(config_path) ) return DBConfig(constants.DEFAULT_DB_NAME, constants.DEFAULT_DB_USER, constants.DEFAULT_DB_PASS) + + +def create_db_config() -> None: + """ + Creates a default DB config file at the config location + """ + config_path = get_db_config_path() + if not config_path.exists(): + shutil.copy(constants.DEFAULT_DB_CONFIG_PATH, config_path) + logger.info("Created default DB config file at {}".format(config_path)) + else: + logger.info("DB config file at {} already exists".format(config_path)) diff --git a/constants.py b/constants.py index e4bbadb..5d523f0 100644 --- a/constants.py +++ b/constants.py @@ -1,5 +1,6 @@ # This file should only contain constants that we use throughout the program, # i.e. stuff that might be changed at some point, but not changed on user-level +# It's not meant to include all string constants or anything, just the ones that are important for functioning. APP_NAME = 'hanabi-league' DB_CONFIG_FILE_NAME = 'config.yaml' @@ -25,3 +26,4 @@ DB_TABLE_NAMES = [ ] DATABASE_SCHEMA_PATH = 'install/database_schema.sql' +DEFAULT_DB_CONFIG_PATH = 'install/default_db_config.yaml' diff --git a/database.py b/database.py index a7eeb94..645ff1e 100644 --- a/database.py +++ b/database.py @@ -47,12 +47,10 @@ def get_existing_tables(): return [table for (table,) in cur.fetchall()] -def init_database(erase: bool = False): - tables = get_existing_tables() - - if not erase and len(tables) > 0: - logger.error("Aborting database initialization: Tables {} already exist".format(", ".join(tables))) - return +def init_database(): + """ + Warning: This drops all existing tables from the database + """ conn = conn_manager.get_connection() cur = conn.cursor() diff --git a/install/default_db_config.yaml b/install/default_db_config.yaml new file mode 100644 index 0000000..a565d4a --- /dev/null +++ b/install/default_db_config.yaml @@ -0,0 +1,3 @@ +dbname: hanabi-league +dbuser: hanabi-league +dbpass: hanabi-league diff --git a/main.py b/main.py new file mode 100644 index 0000000..83cf5c1 --- /dev/null +++ b/main.py @@ -0,0 +1,72 @@ +import argparse + +import verboselogs + +import config +import constants +import database +import log_setup + +from log_setup import logger + + +def subcommand_init(force: bool): + tables = database.get_existing_tables() + if len(tables) > 0 and not force: + logger.info( + 'Database tables "{}" exist already, aborting. To force re-initialization, use the --force options' + .format(", ".join(tables)) + ) + return + if len(tables) > 0: + logger.info( + "WARNING: This will drop all existing tables from the database and re-initialize them." + ) + response = input("Do you wish to continue? [y/N] ") + if response not in ["y", "Y", "yes"]: + return + database.init_database() + logger.info("Successfully initialized database tables") + + +def subcommand_generate_config(): + config.create_db_config() + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog=constants.APP_NAME, + description="Data collection and analysis tool for the Hanabi League" + ) + parser.add_argument('--verbose', '-v', help='Enable verbose logging to console', action='store_true') + + subparsers = parser.add_subparsers(dest='command', required=True, help='select subcommand') + + init_parser = subparsers.add_parser('init', help='Initialize database.') + init_parser.add_argument('--force', '-f', help='Force initialization (Drops existing tables)', action='store_true') + + subparsers.add_parser('generate-config', help='Generate config file at default location') + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + subcommand_func = { + 'init': subcommand_init, + 'generate-config': subcommand_generate_config + }[args.command] + + if args.verbose: + log_setup.logger_manager.set_console_level(verboselogs.VERBOSE) + + del args.command + del args.verbose + + subcommand_func(**vars(args)) + + +if __name__ == "__main__": + main()