Improve DB connection handling: Lazy init + config

Introduces a proper config file for db connection information
Also, connection now has to be explicitly initialized instead of this
being done on import: This is now done by the CLI function automatically
This commit is contained in:
Maximilian Keßler 2023-07-05 20:50:40 +02:00
parent a014dee0da
commit 967daf1914
Signed by: max
GPG key ID: BCC5A619923C0BA5
11 changed files with 131 additions and 94 deletions

View file

@ -9,6 +9,7 @@ from hanabi.live import check_game
from hanabi.live import download_data from hanabi.live import download_data
from hanabi.live import compress from hanabi.live import compress
from hanabi.database import init_database from hanabi.database import init_database
from hanabi.database import global_db_connection_manager
""" """
Commands supported: Commands supported:
@ -76,6 +77,10 @@ def subcommand_download(
logger.info("Successfully exported games for all variants") logger.info("Successfully exported games for all variants")
def subcommand_gen_config():
global_db_connection_manager.create_config_file()
def add_init_subparser(subparsers): def add_init_subparser(subparsers):
parser = subparsers.add_parser( parser = subparsers.add_parser(
'init', 'init',
@ -120,6 +125,10 @@ def add_analyze_subparser(subparsers):
parser.add_argument('--download', '-d', help='Download game if not in database', action='store_true') parser.add_argument('--download', '-d', help='Download game if not in database', action='store_true')
def add_config_gen_subparser(subparsers):
parser = subparsers.add_parser('gen-config', help='Generate config file at default location')
def main_parser() -> argparse.ArgumentParser: def main_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog='hanabi_suite', prog='hanabi_suite',
@ -131,20 +140,27 @@ def main_parser() -> argparse.ArgumentParser:
add_init_subparser(subparsers) add_init_subparser(subparsers)
add_analyze_subparser(subparsers) add_analyze_subparser(subparsers)
add_download_subparser(subparsers) add_download_subparser(subparsers)
add_config_gen_subparser(subparsers)
return parser return parser
def hanabi_cli(): def hanabi_cli():
args = main_parser().parse_args() args = main_parser().parse_args()
switcher = { subcommand_func = {
'analyze': subcommand_analyze, 'analyze': subcommand_analyze,
'init': subcommand_init, 'init': subcommand_init,
'download': subcommand_download 'download': subcommand_download,
} 'gen-config': subcommand_gen_config
}[args.command]
if args.command != 'gen-config':
global_db_connection_manager.read_config()
global_db_connection_manager.connect()
if args.verbose: if args.verbose:
logger_manager.set_console_level(verboselogs.VERBOSE) logger_manager.set_console_level(verboselogs.VERBOSE)
method_args = dict(vars(args)) del args.command
method_args.pop('command') del args.verbose
method_args.pop('verbose')
switcher[args.command](**method_args) subcommand_func(**vars(args))

View file

@ -8,6 +8,10 @@ NUM_STRIKES = 3
COLOR_INITIALS = 'rygbpt' COLOR_INITIALS = 'rygbpt'
PLAYER_NAMES = ["Alice", "Bob", "Cathy", "Donald", "Emily", "Frank"] PLAYER_NAMES = ["Alice", "Bob", "Cathy", "Donald", "Emily", "Frank"]
# DB connection parameters
DEFAULT_DB_NAME = 'hanabi-live'
DEFAULT_DB_USER = 'hanabi'
# hanab.live stuff # hanab.live stuff

View file

@ -1 +1,6 @@
from .database import cur, conn from .database import DBConnectionManager
global_db_connection_manager = DBConnectionManager()
conn = global_db_connection_manager.lazy_conn
cur = global_db_connection_manager.lazy_cur

View file

@ -1,79 +1,90 @@
from typing import Optional from typing import Optional
from pathlib import Path
import yaml
import psycopg2 import psycopg2
import platformdirs
# global connection from hanabi import constants
conn = psycopg2.connect("dbname=hanab-live-2 user=postgres") from hanabi import logger
# cursor
cur = conn.cursor()
# init_database_tables() class LazyDBCursor:
# populate_static_tables() def __init__(self):
self.__cur: Optional[psycopg2.cursor] = None
def __getattr__(self, item):
class Game: if self.__cur is None:
def __init__(self, info=None): raise ValueError(
self.id = -1 "DB cursor used in uninitialized state. Did you forget to initialize the DB connection?"
self.num_players = -1
self.score = -1
self.seed = ""
self.variant_id = -1
self.deck_plays = None
self.one_extra_card = None
self.one_less_card = None
self.all_or_nothing = None
self.num_turns = None
if type(info) == dict:
self.__dict__.update(info)
@staticmethod
def from_tuple(t):
g = Game()
g.id = t[0]
g.num_players = t[1]
g.score = t[2]
g.seed = t[3]
g.variant_id = t[4]
g.deck_plays = t[5]
g.one_extra_card = t[6]
g.one_less_card = t[7]
g.all_or_nothing = t[8]
g.num_turns = t[9]
return g
def __eq__(self, other):
return self.__dict__ == other.__dict__
def load(game_id: int) -> Optional[Game]:
cur.execute("SELECT * from games WHERE id = {};".format(game_id))
a = cur.fetchone()
if a is None:
return None
else:
return Game.from_tuple(a)
def store(game: Game):
stored = load(game.id)
if stored is None:
# print("inserting game with id {} into DB".format(game.id))
cur.execute(
"INSERT INTO games"
"(id, num_players, score, seed, variant_id)"
"VALUES"
"(%s, %s, %s, %s, %s);",
(game.id, game.num_players, game.score, game.seed, game.variant_id)
) )
print("Inserted game with id {}".format(game.id)) return getattr(self.__cur, item)
def set_cur(self, cur):
self.__cur = cur
class LazyDBConnection:
def __init__(self):
self.__conn: Optional[psycopg2.connection] = None
def __getattr__(self, item):
if self.__conn is None:
raise ValueError(
"DB connection used in uninitialized state. Did you forget to initialize the DB connection?"
)
return getattr(self.__conn, item)
def set_conn(self, conn):
self.__conn = conn
class DBConnectionManager:
def __init__(self):
self.lazy_conn: LazyDBConnection = LazyDBConnection()
self.lazy_cur: LazyDBCursor = LazyDBCursor()
self.config_file = Path(platformdirs.user_config_dir(constants.APP_NAME, ensure_exists=True)) / 'config.yaml'
self.db_name: str = constants.DEFAULT_DB_NAME
self.db_user: str = constants.DEFAULT_DB_USER
def read_config(self):
logger.debug("DB connection configuration read from {}".format(self.config_file))
if self.config_file.exists():
with open(self.config_file, "r") as f:
config = yaml.safe_load(f)
self.db_name = config.get('dbname', None)
self.db_user = config.get('dbuser', None)
if self.db_name is None:
logger.verbose("Falling back to default database name {}".format(constants.DEFAULT_DB_NAME))
self.db_name = constants.DEFAULT_DB_NAME
if self.db_user is None:
logger.verbose("Falling back to default database user {}".format(constants.DEFAULT_DB_USER))
self.db_user = constants.DEFAULT_DB_USER
else: else:
pass logger.info(
# if not stored == game: "No configuration file for database connection found, falling back to default values "
# print("Already stored game with id {}, aborting".format(game.id)) "(dbname={}, dbuser={}).".format(
# print("Stored game is: {}".format(stored.__dict__)) constants.DEFAULT_DB_NAME, constants.DEFAULT_DB_USER
# print("New game is: {}".format(game.__dict__)) )
)
logger.info(
"Note: To turn off this message, create a config file at {}".format(self.config_file)
)
def create_config_file(self):
if self.config_file.exists():
raise FileExistsError("Configuration file already exists, not overriding.")
self.config_file.write_text(
"dbname: {}\n"
"dbuser: {}".format(
constants.DEFAULT_DB_NAME,
constants.DEFAULT_DB_USER
)
)
logger.info("Initialised default config file {}".format(self.config_file))
def commit(): def connect(self):
conn.commit() conn = psycopg2.connect("dbname={} user={}".format(self.db_name, self.db_user))
cur = conn.cursor()
self.lazy_conn.set_conn(conn)
self.lazy_cur.set_cur(cur)

View file

@ -6,7 +6,7 @@ import platformdirs
from hanabi import logger from hanabi import logger
from hanabi import constants from hanabi import constants
from .database import cur, conn from hanabi.database import cur, conn
def get_existing_tables(): def get_existing_tables():

View file

@ -7,7 +7,7 @@ import platformdirs
from hanabi import hanab_game from hanabi import hanab_game
from hanabi import constants from hanabi import constants
from hanabi import logger from hanabi import logger
from hanabi.database import database from hanabi import database
from hanabi.live import site_api from hanabi.live import site_api
from hanabi.live import compress from hanabi.live import compress
from hanabi.live import variants from hanabi.live import variants

View file

@ -9,7 +9,7 @@ import time
from hanabi import logger from hanabi import logger
from hanabi.solvers.sat import solve_sat from hanabi.solvers.sat import solve_sat
from hanabi.database import database from hanabi import database
from hanabi.live import download_data from hanabi.live import download_data
from hanabi.live import compress from hanabi.live import compress
from hanabi import hanab_game from hanabi import hanab_game

View file

@ -2,43 +2,43 @@ import enum
from typing import List, Optional from typing import List, Optional
from hanabi import hanab_game from hanabi import hanab_game
from hanabi.database.database import cur from hanabi import database
def variant_id(name) -> Optional[int]: def variant_id(name) -> Optional[int]:
cur.execute( database.cur.execute(
"SELECT id FROM variants WHERE name = %s", "SELECT id FROM variants WHERE name = %s",
(name,) (name,)
) )
var_id = cur.fetchone() var_id = database.cur.fetchone()
if var_id is not None: if var_id is not None:
return var_id[0] return var_id[0]
def get_all_variant_ids() -> List[int]: def get_all_variant_ids() -> List[int]:
cur.execute( database.cur.execute(
"SELECT id FROM variants " "SELECT id FROM variants "
"ORDER BY id" "ORDER BY id"
) )
return [var_id for (var_id,) in cur.fetchall()] return [var_id for (var_id,) in database.cur.fetchall()]
def variant_name(var_id) -> Optional[int]: def variant_name(var_id) -> Optional[int]:
cur.execute( database.cur.execute(
"SELECT name FROM variants WHERE id = %s", "SELECT name FROM variants WHERE id = %s",
(var_id,) (var_id,)
) )
name = cur.fetchone() name = database.cur.fetchone()
if name is not None: if name is not None:
return name[0] return name[0]
def num_suits(var_id) -> Optional[int]: def num_suits(var_id) -> Optional[int]:
cur.execute( database.cur.execute(
"SELECT num_suits FROM variants WHERE id = %s", "SELECT num_suits FROM variants WHERE id = %s",
(var_id,) (var_id,)
) )
num = cur.fetchone() num = database.cur.fetchone()
if num is not None: if num is not None:
return num return num
@ -90,19 +90,19 @@ class Suit:
@staticmethod @staticmethod
def from_db(suit_id): def from_db(suit_id):
cur.execute( database.cur.execute(
"SELECT name, display_name, abbreviation, rank_clues, color_clues, prism, dark, reversed " "SELECT name, display_name, abbreviation, rank_clues, color_clues, prism, dark, reversed "
"FROM suits " "FROM suits "
"WHERE id = %s", "WHERE id = %s",
(suit_id,) (suit_id,)
) )
suit_properties = cur.fetchone() suit_properties = database.cur.fetchone()
cur.execute( database.cur.execute(
"SELECT color_id FROM suit_colors WHERE suit_id = %s", "SELECT color_id FROM suit_colors WHERE suit_id = %s",
(suit_id,) (suit_id,)
) )
colors = list(map(lambda t: t[0], cur.fetchall())) colors = list(map(lambda t: t[0], database.cur.fetchall()))
return Suit(*suit_properties, colors) return Suit(*suit_properties, colors)
@ -232,7 +232,7 @@ class Variant:
@staticmethod @staticmethod
def from_db(var_id): def from_db(var_id):
cur.execute( database.cur.execute(
"SELECT " "SELECT "
"name, clue_starved, throw_it_in_a_hole, alternating_clues, synesthesia, chimneys, funnels, " "name, clue_starved, throw_it_in_a_hole, alternating_clues, synesthesia, chimneys, funnels, "
"no_color_clues, no_rank_clues, empty_color_clues, empty_rank_clues, odds_and_evens, up_or_down," "no_color_clues, no_rank_clues, empty_color_clues, empty_rank_clues, odds_and_evens, up_or_down,"
@ -240,14 +240,14 @@ class Variant:
"FROM variants WHERE id = %s", "FROM variants WHERE id = %s",
(var_id,) (var_id,)
) )
var_properties = cur.fetchone() var_properties = database.cur.fetchone()
cur.execute( database.cur.execute(
"SELECT suit_id FROM variant_suits " "SELECT suit_id FROM variant_suits "
"WHERE variant_id = %s " "WHERE variant_id = %s "
"ORDER BY index", "ORDER BY index",
(var_id,) (var_id,)
) )
var_suits = [Suit.from_db(*s) for s in cur.fetchall()] var_suits = [Suit.from_db(*s) for s in database.cur.fetchall()]
return Variant(*var_properties, var_suits) return Variant(*var_properties, var_suits)

View file

@ -1,7 +1,7 @@
from hanabi.live import compress from hanabi.live import compress
from enum import Enum from enum import Enum
from hanabi.database import database from hanabi import database
from hanabi import hanab_game from hanabi import hanab_game
from hanabi.live import compress from hanabi.live import compress

View file

@ -8,7 +8,7 @@ from typing import Optional
from hanabi import logger from hanabi import logger
from hanabi import hanab_game from hanabi import hanab_game
from hanabi.live import compress from hanabi.live import compress
from hanabi.database import database from hanabi import database
class CardType(Enum): class CardType(Enum):

View file

@ -9,3 +9,4 @@ argparse
verboselogs verboselogs
pebble pebble
platformdirs platformdirs
PyYAML