From 967daf1914c2cbb7bac601d9ed3526ec14acca40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20Ke=C3=9Fler?= Date: Wed, 5 Jul 2023 20:50:40 +0200 Subject: [PATCH] Improve DB connection handling: Lazy init + config Introduces a proper config file for db connection information Also, connection now has to be explicitly initialized instead of this being done on import: This is now done by the CLI function automatically --- hanabi/cli.py | 30 +++++-- hanabi/constants.py | 4 + hanabi/database/__init__.py | 7 +- hanabi/database/database.py | 139 +++++++++++++++++-------------- hanabi/database/init_database.py | 2 +- hanabi/live/download_data.py | 2 +- hanabi/live/instance_finder.py | 2 +- hanabi/live/variants.py | 34 ++++---- hanabi/solvers/deck_analyzer.py | 2 +- hanabi/solvers/greedy_solver.py | 2 +- requirements.txt | 1 + 11 files changed, 131 insertions(+), 94 deletions(-) diff --git a/hanabi/cli.py b/hanabi/cli.py index ee1406c..8f9a1da 100755 --- a/hanabi/cli.py +++ b/hanabi/cli.py @@ -9,6 +9,7 @@ from hanabi.live import check_game from hanabi.live import download_data from hanabi.live import compress from hanabi.database import init_database +from hanabi.database import global_db_connection_manager """ Commands supported: @@ -76,6 +77,10 @@ def subcommand_download( logger.info("Successfully exported games for all variants") +def subcommand_gen_config(): + global_db_connection_manager.create_config_file() + + def add_init_subparser(subparsers): parser = subparsers.add_parser( 'init', @@ -120,6 +125,10 @@ def add_analyze_subparser(subparsers): parser.add_argument('--download', '-d', help='Download game if not in database', action='store_true') +def add_config_gen_subparser(subparsers): + parser = subparsers.add_parser('gen-config', help='Generate config file at default location') + + def main_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( prog='hanabi_suite', @@ -131,20 +140,27 @@ def main_parser() -> argparse.ArgumentParser: add_init_subparser(subparsers) add_analyze_subparser(subparsers) add_download_subparser(subparsers) + add_config_gen_subparser(subparsers) return parser def hanabi_cli(): args = main_parser().parse_args() - switcher = { + subcommand_func = { 'analyze': subcommand_analyze, 'init': subcommand_init, - 'download': subcommand_download - } + 'download': subcommand_download, + 'gen-config': subcommand_gen_config + }[args.command] + + if args.command != 'gen-config': + global_db_connection_manager.read_config() + global_db_connection_manager.connect() + if args.verbose: logger_manager.set_console_level(verboselogs.VERBOSE) - method_args = dict(vars(args)) - method_args.pop('command') - method_args.pop('verbose') - switcher[args.command](**method_args) + del args.command + del args.verbose + + subcommand_func(**vars(args)) diff --git a/hanabi/constants.py b/hanabi/constants.py index 0588bcf..93d7a1a 100644 --- a/hanabi/constants.py +++ b/hanabi/constants.py @@ -8,6 +8,10 @@ NUM_STRIKES = 3 COLOR_INITIALS = 'rygbpt' PLAYER_NAMES = ["Alice", "Bob", "Cathy", "Donald", "Emily", "Frank"] +# DB connection parameters +DEFAULT_DB_NAME = 'hanabi-live' +DEFAULT_DB_USER = 'hanabi' + # hanab.live stuff diff --git a/hanabi/database/__init__.py b/hanabi/database/__init__.py index 230c331..371a971 100644 --- a/hanabi/database/__init__.py +++ b/hanabi/database/__init__.py @@ -1 +1,6 @@ -from .database import cur, conn +from .database import DBConnectionManager + +global_db_connection_manager = DBConnectionManager() + +conn = global_db_connection_manager.lazy_conn +cur = global_db_connection_manager.lazy_cur diff --git a/hanabi/database/database.py b/hanabi/database/database.py index 6392373..e75abe8 100644 --- a/hanabi/database/database.py +++ b/hanabi/database/database.py @@ -1,79 +1,90 @@ from typing import Optional +from pathlib import Path +import yaml + import psycopg2 +import platformdirs -# global connection -conn = psycopg2.connect("dbname=hanab-live-2 user=postgres") - -# cursor -cur = conn.cursor() +from hanabi import constants +from hanabi import logger -# init_database_tables() -# populate_static_tables() +class LazyDBCursor: + def __init__(self): + self.__cur: Optional[psycopg2.cursor] = None + + def __getattr__(self, item): + if self.__cur is None: + raise ValueError( + "DB cursor used in uninitialized state. Did you forget to initialize the DB connection?" + ) + return getattr(self.__cur, item) + + def set_cur(self, cur): + self.__cur = cur -class Game: - def __init__(self, info=None): - self.id = -1 - self.num_players = -1 - self.score = -1 - self.seed = "" - self.variant_id = -1 - self.deck_plays = None - self.one_extra_card = None - self.one_less_card = None - self.all_or_nothing = None - self.num_turns = None - if type(info) == dict: - self.__dict__.update(info) +class LazyDBConnection: + def __init__(self): + self.__conn: Optional[psycopg2.connection] = None - @staticmethod - def from_tuple(t): - g = Game() - g.id = t[0] - g.num_players = t[1] - g.score = t[2] - g.seed = t[3] - g.variant_id = t[4] - g.deck_plays = t[5] - g.one_extra_card = t[6] - g.one_less_card = t[7] - g.all_or_nothing = t[8] - g.num_turns = t[9] - return g + def __getattr__(self, item): + if self.__conn is None: + raise ValueError( + "DB connection used in uninitialized state. Did you forget to initialize the DB connection?" + ) + return getattr(self.__conn, item) - def __eq__(self, other): - return self.__dict__ == other.__dict__ + def set_conn(self, conn): + self.__conn = conn -def load(game_id: int) -> Optional[Game]: - cur.execute("SELECT * from games WHERE id = {};".format(game_id)) - a = cur.fetchone() - if a is None: - return None - else: - return Game.from_tuple(a) +class DBConnectionManager: + def __init__(self): + self.lazy_conn: LazyDBConnection = LazyDBConnection() + self.lazy_cur: LazyDBCursor = LazyDBCursor() + self.config_file = Path(platformdirs.user_config_dir(constants.APP_NAME, ensure_exists=True)) / 'config.yaml' + self.db_name: str = constants.DEFAULT_DB_NAME + self.db_user: str = constants.DEFAULT_DB_USER + def read_config(self): + logger.debug("DB connection configuration read from {}".format(self.config_file)) + if self.config_file.exists(): + with open(self.config_file, "r") as f: + config = yaml.safe_load(f) + self.db_name = config.get('dbname', None) + self.db_user = config.get('dbuser', None) + if self.db_name is None: + logger.verbose("Falling back to default database name {}".format(constants.DEFAULT_DB_NAME)) + self.db_name = constants.DEFAULT_DB_NAME + if self.db_user is None: + logger.verbose("Falling back to default database user {}".format(constants.DEFAULT_DB_USER)) + self.db_user = constants.DEFAULT_DB_USER + else: + logger.info( + "No configuration file for database connection found, falling back to default values " + "(dbname={}, dbuser={}).".format( + constants.DEFAULT_DB_NAME, constants.DEFAULT_DB_USER + ) + ) + logger.info( + "Note: To turn off this message, create a config file at {}".format(self.config_file) + ) -def store(game: Game): - stored = load(game.id) - if stored is None: - # print("inserting game with id {} into DB".format(game.id)) - cur.execute( - "INSERT INTO games" - "(id, num_players, score, seed, variant_id)" - "VALUES" - "(%s, %s, %s, %s, %s);", - (game.id, game.num_players, game.score, game.seed, game.variant_id) + def create_config_file(self): + if self.config_file.exists(): + raise FileExistsError("Configuration file already exists, not overriding.") + self.config_file.write_text( + "dbname: {}\n" + "dbuser: {}".format( + constants.DEFAULT_DB_NAME, + constants.DEFAULT_DB_USER + ) ) - print("Inserted game with id {}".format(game.id)) - else: - pass -# if not stored == game: -# print("Already stored game with id {}, aborting".format(game.id)) -# print("Stored game is: {}".format(stored.__dict__)) -# print("New game is: {}".format(game.__dict__)) + logger.info("Initialised default config file {}".format(self.config_file)) - -def commit(): - conn.commit() + def connect(self): + conn = psycopg2.connect("dbname={} user={}".format(self.db_name, self.db_user)) + cur = conn.cursor() + self.lazy_conn.set_conn(conn) + self.lazy_cur.set_cur(cur) diff --git a/hanabi/database/init_database.py b/hanabi/database/init_database.py index 891ec5b..654a15e 100644 --- a/hanabi/database/init_database.py +++ b/hanabi/database/init_database.py @@ -6,7 +6,7 @@ import platformdirs from hanabi import logger from hanabi import constants -from .database import cur, conn +from hanabi.database import cur, conn def get_existing_tables(): diff --git a/hanabi/live/download_data.py b/hanabi/live/download_data.py index 6877378..fc592d9 100644 --- a/hanabi/live/download_data.py +++ b/hanabi/live/download_data.py @@ -7,7 +7,7 @@ import platformdirs from hanabi import hanab_game from hanabi import constants from hanabi import logger -from hanabi.database import database +from hanabi import database from hanabi.live import site_api from hanabi.live import compress from hanabi.live import variants diff --git a/hanabi/live/instance_finder.py b/hanabi/live/instance_finder.py index e4ecd27..4b0d09e 100644 --- a/hanabi/live/instance_finder.py +++ b/hanabi/live/instance_finder.py @@ -9,7 +9,7 @@ import time from hanabi import logger from hanabi.solvers.sat import solve_sat -from hanabi.database import database +from hanabi import database from hanabi.live import download_data from hanabi.live import compress from hanabi import hanab_game diff --git a/hanabi/live/variants.py b/hanabi/live/variants.py index 770c0b0..0ba1c89 100644 --- a/hanabi/live/variants.py +++ b/hanabi/live/variants.py @@ -2,43 +2,43 @@ import enum from typing import List, Optional from hanabi import hanab_game -from hanabi.database.database import cur +from hanabi import database def variant_id(name) -> Optional[int]: - cur.execute( + database.cur.execute( "SELECT id FROM variants WHERE name = %s", (name,) ) - var_id = cur.fetchone() + var_id = database.cur.fetchone() if var_id is not None: return var_id[0] def get_all_variant_ids() -> List[int]: - cur.execute( + database.cur.execute( "SELECT id FROM variants " "ORDER BY id" ) - return [var_id for (var_id,) in cur.fetchall()] + return [var_id for (var_id,) in database.cur.fetchall()] def variant_name(var_id) -> Optional[int]: - cur.execute( + database.cur.execute( "SELECT name FROM variants WHERE id = %s", (var_id,) ) - name = cur.fetchone() + name = database.cur.fetchone() if name is not None: return name[0] def num_suits(var_id) -> Optional[int]: - cur.execute( + database.cur.execute( "SELECT num_suits FROM variants WHERE id = %s", (var_id,) ) - num = cur.fetchone() + num = database.cur.fetchone() if num is not None: return num @@ -90,19 +90,19 @@ class Suit: @staticmethod def from_db(suit_id): - cur.execute( + database.cur.execute( "SELECT name, display_name, abbreviation, rank_clues, color_clues, prism, dark, reversed " "FROM suits " "WHERE id = %s", (suit_id,) ) - suit_properties = cur.fetchone() + suit_properties = database.cur.fetchone() - cur.execute( + database.cur.execute( "SELECT color_id FROM suit_colors WHERE suit_id = %s", (suit_id,) ) - colors = list(map(lambda t: t[0], cur.fetchall())) + colors = list(map(lambda t: t[0], database.cur.fetchall())) return Suit(*suit_properties, colors) @@ -232,7 +232,7 @@ class Variant: @staticmethod def from_db(var_id): - cur.execute( + database.cur.execute( "SELECT " "name, clue_starved, throw_it_in_a_hole, alternating_clues, synesthesia, chimneys, funnels, " "no_color_clues, no_rank_clues, empty_color_clues, empty_rank_clues, odds_and_evens, up_or_down," @@ -240,14 +240,14 @@ class Variant: "FROM variants WHERE id = %s", (var_id,) ) - var_properties = cur.fetchone() + var_properties = database.cur.fetchone() - cur.execute( + database.cur.execute( "SELECT suit_id FROM variant_suits " "WHERE variant_id = %s " "ORDER BY index", (var_id,) ) - var_suits = [Suit.from_db(*s) for s in cur.fetchall()] + var_suits = [Suit.from_db(*s) for s in database.cur.fetchall()] return Variant(*var_properties, var_suits) diff --git a/hanabi/solvers/deck_analyzer.py b/hanabi/solvers/deck_analyzer.py index 93bd9ff..7994c32 100644 --- a/hanabi/solvers/deck_analyzer.py +++ b/hanabi/solvers/deck_analyzer.py @@ -1,7 +1,7 @@ from hanabi.live import compress from enum import Enum -from hanabi.database import database +from hanabi import database from hanabi import hanab_game from hanabi.live import compress diff --git a/hanabi/solvers/greedy_solver.py b/hanabi/solvers/greedy_solver.py index c78fff8..cb6c4fe 100755 --- a/hanabi/solvers/greedy_solver.py +++ b/hanabi/solvers/greedy_solver.py @@ -8,7 +8,7 @@ from typing import Optional from hanabi import logger from hanabi import hanab_game from hanabi.live import compress -from hanabi.database import database +from hanabi import database class CardType(Enum): diff --git a/requirements.txt b/requirements.txt index 6bd5085..cf70c11 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ argparse verboselogs pebble platformdirs +PyYAML