Improve DB connection handling: Lazy init + config
Introduces a proper config file for db connection information Also, connection now has to be explicitly initialized instead of this being done on import: This is now done by the CLI function automatically
This commit is contained in:
parent
a014dee0da
commit
967daf1914
11 changed files with 131 additions and 94 deletions
|
@ -9,6 +9,7 @@ from hanabi.live import check_game
|
||||||
from hanabi.live import download_data
|
from hanabi.live import download_data
|
||||||
from hanabi.live import compress
|
from hanabi.live import compress
|
||||||
from hanabi.database import init_database
|
from hanabi.database import init_database
|
||||||
|
from hanabi.database import global_db_connection_manager
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Commands supported:
|
Commands supported:
|
||||||
|
@ -76,6 +77,10 @@ def subcommand_download(
|
||||||
logger.info("Successfully exported games for all variants")
|
logger.info("Successfully exported games for all variants")
|
||||||
|
|
||||||
|
|
||||||
|
def subcommand_gen_config():
|
||||||
|
global_db_connection_manager.create_config_file()
|
||||||
|
|
||||||
|
|
||||||
def add_init_subparser(subparsers):
|
def add_init_subparser(subparsers):
|
||||||
parser = subparsers.add_parser(
|
parser = subparsers.add_parser(
|
||||||
'init',
|
'init',
|
||||||
|
@ -120,6 +125,10 @@ def add_analyze_subparser(subparsers):
|
||||||
parser.add_argument('--download', '-d', help='Download game if not in database', action='store_true')
|
parser.add_argument('--download', '-d', help='Download game if not in database', action='store_true')
|
||||||
|
|
||||||
|
|
||||||
|
def add_config_gen_subparser(subparsers):
|
||||||
|
parser = subparsers.add_parser('gen-config', help='Generate config file at default location')
|
||||||
|
|
||||||
|
|
||||||
def main_parser() -> argparse.ArgumentParser:
|
def main_parser() -> argparse.ArgumentParser:
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
prog='hanabi_suite',
|
prog='hanabi_suite',
|
||||||
|
@ -131,20 +140,27 @@ def main_parser() -> argparse.ArgumentParser:
|
||||||
add_init_subparser(subparsers)
|
add_init_subparser(subparsers)
|
||||||
add_analyze_subparser(subparsers)
|
add_analyze_subparser(subparsers)
|
||||||
add_download_subparser(subparsers)
|
add_download_subparser(subparsers)
|
||||||
|
add_config_gen_subparser(subparsers)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def hanabi_cli():
|
def hanabi_cli():
|
||||||
args = main_parser().parse_args()
|
args = main_parser().parse_args()
|
||||||
switcher = {
|
subcommand_func = {
|
||||||
'analyze': subcommand_analyze,
|
'analyze': subcommand_analyze,
|
||||||
'init': subcommand_init,
|
'init': subcommand_init,
|
||||||
'download': subcommand_download
|
'download': subcommand_download,
|
||||||
}
|
'gen-config': subcommand_gen_config
|
||||||
|
}[args.command]
|
||||||
|
|
||||||
|
if args.command != 'gen-config':
|
||||||
|
global_db_connection_manager.read_config()
|
||||||
|
global_db_connection_manager.connect()
|
||||||
|
|
||||||
if args.verbose:
|
if args.verbose:
|
||||||
logger_manager.set_console_level(verboselogs.VERBOSE)
|
logger_manager.set_console_level(verboselogs.VERBOSE)
|
||||||
method_args = dict(vars(args))
|
del args.command
|
||||||
method_args.pop('command')
|
del args.verbose
|
||||||
method_args.pop('verbose')
|
|
||||||
switcher[args.command](**method_args)
|
subcommand_func(**vars(args))
|
||||||
|
|
|
@ -8,6 +8,10 @@ NUM_STRIKES = 3
|
||||||
COLOR_INITIALS = 'rygbpt'
|
COLOR_INITIALS = 'rygbpt'
|
||||||
PLAYER_NAMES = ["Alice", "Bob", "Cathy", "Donald", "Emily", "Frank"]
|
PLAYER_NAMES = ["Alice", "Bob", "Cathy", "Donald", "Emily", "Frank"]
|
||||||
|
|
||||||
|
# DB connection parameters
|
||||||
|
DEFAULT_DB_NAME = 'hanabi-live'
|
||||||
|
DEFAULT_DB_USER = 'hanabi'
|
||||||
|
|
||||||
|
|
||||||
# hanab.live stuff
|
# hanab.live stuff
|
||||||
|
|
||||||
|
|
|
@ -1 +1,6 @@
|
||||||
from .database import cur, conn
|
from .database import DBConnectionManager
|
||||||
|
|
||||||
|
global_db_connection_manager = DBConnectionManager()
|
||||||
|
|
||||||
|
conn = global_db_connection_manager.lazy_conn
|
||||||
|
cur = global_db_connection_manager.lazy_cur
|
||||||
|
|
|
@ -1,79 +1,90 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from pathlib import Path
|
||||||
|
import yaml
|
||||||
|
|
||||||
import psycopg2
|
import psycopg2
|
||||||
|
import platformdirs
|
||||||
|
|
||||||
# global connection
|
from hanabi import constants
|
||||||
conn = psycopg2.connect("dbname=hanab-live-2 user=postgres")
|
from hanabi import logger
|
||||||
|
|
||||||
# cursor
|
|
||||||
cur = conn.cursor()
|
|
||||||
|
|
||||||
|
|
||||||
# init_database_tables()
|
class LazyDBCursor:
|
||||||
# populate_static_tables()
|
def __init__(self):
|
||||||
|
self.__cur: Optional[psycopg2.cursor] = None
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
if self.__cur is None:
|
||||||
|
raise ValueError(
|
||||||
|
"DB cursor used in uninitialized state. Did you forget to initialize the DB connection?"
|
||||||
|
)
|
||||||
|
return getattr(self.__cur, item)
|
||||||
|
|
||||||
|
def set_cur(self, cur):
|
||||||
|
self.__cur = cur
|
||||||
|
|
||||||
|
|
||||||
class Game:
|
class LazyDBConnection:
|
||||||
def __init__(self, info=None):
|
def __init__(self):
|
||||||
self.id = -1
|
self.__conn: Optional[psycopg2.connection] = None
|
||||||
self.num_players = -1
|
|
||||||
self.score = -1
|
|
||||||
self.seed = ""
|
|
||||||
self.variant_id = -1
|
|
||||||
self.deck_plays = None
|
|
||||||
self.one_extra_card = None
|
|
||||||
self.one_less_card = None
|
|
||||||
self.all_or_nothing = None
|
|
||||||
self.num_turns = None
|
|
||||||
if type(info) == dict:
|
|
||||||
self.__dict__.update(info)
|
|
||||||
|
|
||||||
@staticmethod
|
def __getattr__(self, item):
|
||||||
def from_tuple(t):
|
if self.__conn is None:
|
||||||
g = Game()
|
raise ValueError(
|
||||||
g.id = t[0]
|
"DB connection used in uninitialized state. Did you forget to initialize the DB connection?"
|
||||||
g.num_players = t[1]
|
)
|
||||||
g.score = t[2]
|
return getattr(self.__conn, item)
|
||||||
g.seed = t[3]
|
|
||||||
g.variant_id = t[4]
|
|
||||||
g.deck_plays = t[5]
|
|
||||||
g.one_extra_card = t[6]
|
|
||||||
g.one_less_card = t[7]
|
|
||||||
g.all_or_nothing = t[8]
|
|
||||||
g.num_turns = t[9]
|
|
||||||
return g
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
def set_conn(self, conn):
|
||||||
return self.__dict__ == other.__dict__
|
self.__conn = conn
|
||||||
|
|
||||||
|
|
||||||
def load(game_id: int) -> Optional[Game]:
|
class DBConnectionManager:
|
||||||
cur.execute("SELECT * from games WHERE id = {};".format(game_id))
|
def __init__(self):
|
||||||
a = cur.fetchone()
|
self.lazy_conn: LazyDBConnection = LazyDBConnection()
|
||||||
if a is None:
|
self.lazy_cur: LazyDBCursor = LazyDBCursor()
|
||||||
return None
|
self.config_file = Path(platformdirs.user_config_dir(constants.APP_NAME, ensure_exists=True)) / 'config.yaml'
|
||||||
else:
|
self.db_name: str = constants.DEFAULT_DB_NAME
|
||||||
return Game.from_tuple(a)
|
self.db_user: str = constants.DEFAULT_DB_USER
|
||||||
|
|
||||||
|
def read_config(self):
|
||||||
|
logger.debug("DB connection configuration read from {}".format(self.config_file))
|
||||||
|
if self.config_file.exists():
|
||||||
|
with open(self.config_file, "r") as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
self.db_name = config.get('dbname', None)
|
||||||
|
self.db_user = config.get('dbuser', None)
|
||||||
|
if self.db_name is None:
|
||||||
|
logger.verbose("Falling back to default database name {}".format(constants.DEFAULT_DB_NAME))
|
||||||
|
self.db_name = constants.DEFAULT_DB_NAME
|
||||||
|
if self.db_user is None:
|
||||||
|
logger.verbose("Falling back to default database user {}".format(constants.DEFAULT_DB_USER))
|
||||||
|
self.db_user = constants.DEFAULT_DB_USER
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"No configuration file for database connection found, falling back to default values "
|
||||||
|
"(dbname={}, dbuser={}).".format(
|
||||||
|
constants.DEFAULT_DB_NAME, constants.DEFAULT_DB_USER
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Note: To turn off this message, create a config file at {}".format(self.config_file)
|
||||||
|
)
|
||||||
|
|
||||||
def store(game: Game):
|
def create_config_file(self):
|
||||||
stored = load(game.id)
|
if self.config_file.exists():
|
||||||
if stored is None:
|
raise FileExistsError("Configuration file already exists, not overriding.")
|
||||||
# print("inserting game with id {} into DB".format(game.id))
|
self.config_file.write_text(
|
||||||
cur.execute(
|
"dbname: {}\n"
|
||||||
"INSERT INTO games"
|
"dbuser: {}".format(
|
||||||
"(id, num_players, score, seed, variant_id)"
|
constants.DEFAULT_DB_NAME,
|
||||||
"VALUES"
|
constants.DEFAULT_DB_USER
|
||||||
"(%s, %s, %s, %s, %s);",
|
)
|
||||||
(game.id, game.num_players, game.score, game.seed, game.variant_id)
|
|
||||||
)
|
)
|
||||||
print("Inserted game with id {}".format(game.id))
|
logger.info("Initialised default config file {}".format(self.config_file))
|
||||||
else:
|
|
||||||
pass
|
|
||||||
# if not stored == game:
|
|
||||||
# print("Already stored game with id {}, aborting".format(game.id))
|
|
||||||
# print("Stored game is: {}".format(stored.__dict__))
|
|
||||||
# print("New game is: {}".format(game.__dict__))
|
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
def commit():
|
conn = psycopg2.connect("dbname={} user={}".format(self.db_name, self.db_user))
|
||||||
conn.commit()
|
cur = conn.cursor()
|
||||||
|
self.lazy_conn.set_conn(conn)
|
||||||
|
self.lazy_cur.set_cur(cur)
|
||||||
|
|
|
@ -6,7 +6,7 @@ import platformdirs
|
||||||
|
|
||||||
from hanabi import logger
|
from hanabi import logger
|
||||||
from hanabi import constants
|
from hanabi import constants
|
||||||
from .database import cur, conn
|
from hanabi.database import cur, conn
|
||||||
|
|
||||||
|
|
||||||
def get_existing_tables():
|
def get_existing_tables():
|
||||||
|
|
|
@ -7,7 +7,7 @@ import platformdirs
|
||||||
from hanabi import hanab_game
|
from hanabi import hanab_game
|
||||||
from hanabi import constants
|
from hanabi import constants
|
||||||
from hanabi import logger
|
from hanabi import logger
|
||||||
from hanabi.database import database
|
from hanabi import database
|
||||||
from hanabi.live import site_api
|
from hanabi.live import site_api
|
||||||
from hanabi.live import compress
|
from hanabi.live import compress
|
||||||
from hanabi.live import variants
|
from hanabi.live import variants
|
||||||
|
|
|
@ -9,7 +9,7 @@ import time
|
||||||
|
|
||||||
from hanabi import logger
|
from hanabi import logger
|
||||||
from hanabi.solvers.sat import solve_sat
|
from hanabi.solvers.sat import solve_sat
|
||||||
from hanabi.database import database
|
from hanabi import database
|
||||||
from hanabi.live import download_data
|
from hanabi.live import download_data
|
||||||
from hanabi.live import compress
|
from hanabi.live import compress
|
||||||
from hanabi import hanab_game
|
from hanabi import hanab_game
|
||||||
|
|
|
@ -2,43 +2,43 @@ import enum
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from hanabi import hanab_game
|
from hanabi import hanab_game
|
||||||
|
|
||||||
from hanabi.database.database import cur
|
from hanabi import database
|
||||||
|
|
||||||
|
|
||||||
def variant_id(name) -> Optional[int]:
|
def variant_id(name) -> Optional[int]:
|
||||||
cur.execute(
|
database.cur.execute(
|
||||||
"SELECT id FROM variants WHERE name = %s",
|
"SELECT id FROM variants WHERE name = %s",
|
||||||
(name,)
|
(name,)
|
||||||
)
|
)
|
||||||
var_id = cur.fetchone()
|
var_id = database.cur.fetchone()
|
||||||
if var_id is not None:
|
if var_id is not None:
|
||||||
return var_id[0]
|
return var_id[0]
|
||||||
|
|
||||||
|
|
||||||
def get_all_variant_ids() -> List[int]:
|
def get_all_variant_ids() -> List[int]:
|
||||||
cur.execute(
|
database.cur.execute(
|
||||||
"SELECT id FROM variants "
|
"SELECT id FROM variants "
|
||||||
"ORDER BY id"
|
"ORDER BY id"
|
||||||
)
|
)
|
||||||
return [var_id for (var_id,) in cur.fetchall()]
|
return [var_id for (var_id,) in database.cur.fetchall()]
|
||||||
|
|
||||||
|
|
||||||
def variant_name(var_id) -> Optional[int]:
|
def variant_name(var_id) -> Optional[int]:
|
||||||
cur.execute(
|
database.cur.execute(
|
||||||
"SELECT name FROM variants WHERE id = %s",
|
"SELECT name FROM variants WHERE id = %s",
|
||||||
(var_id,)
|
(var_id,)
|
||||||
)
|
)
|
||||||
name = cur.fetchone()
|
name = database.cur.fetchone()
|
||||||
if name is not None:
|
if name is not None:
|
||||||
return name[0]
|
return name[0]
|
||||||
|
|
||||||
|
|
||||||
def num_suits(var_id) -> Optional[int]:
|
def num_suits(var_id) -> Optional[int]:
|
||||||
cur.execute(
|
database.cur.execute(
|
||||||
"SELECT num_suits FROM variants WHERE id = %s",
|
"SELECT num_suits FROM variants WHERE id = %s",
|
||||||
(var_id,)
|
(var_id,)
|
||||||
)
|
)
|
||||||
num = cur.fetchone()
|
num = database.cur.fetchone()
|
||||||
if num is not None:
|
if num is not None:
|
||||||
return num
|
return num
|
||||||
|
|
||||||
|
@ -90,19 +90,19 @@ class Suit:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_db(suit_id):
|
def from_db(suit_id):
|
||||||
cur.execute(
|
database.cur.execute(
|
||||||
"SELECT name, display_name, abbreviation, rank_clues, color_clues, prism, dark, reversed "
|
"SELECT name, display_name, abbreviation, rank_clues, color_clues, prism, dark, reversed "
|
||||||
"FROM suits "
|
"FROM suits "
|
||||||
"WHERE id = %s",
|
"WHERE id = %s",
|
||||||
(suit_id,)
|
(suit_id,)
|
||||||
)
|
)
|
||||||
suit_properties = cur.fetchone()
|
suit_properties = database.cur.fetchone()
|
||||||
|
|
||||||
cur.execute(
|
database.cur.execute(
|
||||||
"SELECT color_id FROM suit_colors WHERE suit_id = %s",
|
"SELECT color_id FROM suit_colors WHERE suit_id = %s",
|
||||||
(suit_id,)
|
(suit_id,)
|
||||||
)
|
)
|
||||||
colors = list(map(lambda t: t[0], cur.fetchall()))
|
colors = list(map(lambda t: t[0], database.cur.fetchall()))
|
||||||
return Suit(*suit_properties, colors)
|
return Suit(*suit_properties, colors)
|
||||||
|
|
||||||
|
|
||||||
|
@ -232,7 +232,7 @@ class Variant:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_db(var_id):
|
def from_db(var_id):
|
||||||
cur.execute(
|
database.cur.execute(
|
||||||
"SELECT "
|
"SELECT "
|
||||||
"name, clue_starved, throw_it_in_a_hole, alternating_clues, synesthesia, chimneys, funnels, "
|
"name, clue_starved, throw_it_in_a_hole, alternating_clues, synesthesia, chimneys, funnels, "
|
||||||
"no_color_clues, no_rank_clues, empty_color_clues, empty_rank_clues, odds_and_evens, up_or_down,"
|
"no_color_clues, no_rank_clues, empty_color_clues, empty_rank_clues, odds_and_evens, up_or_down,"
|
||||||
|
@ -240,14 +240,14 @@ class Variant:
|
||||||
"FROM variants WHERE id = %s",
|
"FROM variants WHERE id = %s",
|
||||||
(var_id,)
|
(var_id,)
|
||||||
)
|
)
|
||||||
var_properties = cur.fetchone()
|
var_properties = database.cur.fetchone()
|
||||||
|
|
||||||
cur.execute(
|
database.cur.execute(
|
||||||
"SELECT suit_id FROM variant_suits "
|
"SELECT suit_id FROM variant_suits "
|
||||||
"WHERE variant_id = %s "
|
"WHERE variant_id = %s "
|
||||||
"ORDER BY index",
|
"ORDER BY index",
|
||||||
(var_id,)
|
(var_id,)
|
||||||
)
|
)
|
||||||
var_suits = [Suit.from_db(*s) for s in cur.fetchall()]
|
var_suits = [Suit.from_db(*s) for s in database.cur.fetchall()]
|
||||||
|
|
||||||
return Variant(*var_properties, var_suits)
|
return Variant(*var_properties, var_suits)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from hanabi.live import compress
|
from hanabi.live import compress
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from hanabi.database import database
|
from hanabi import database
|
||||||
from hanabi import hanab_game
|
from hanabi import hanab_game
|
||||||
from hanabi.live import compress
|
from hanabi.live import compress
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ from typing import Optional
|
||||||
from hanabi import logger
|
from hanabi import logger
|
||||||
from hanabi import hanab_game
|
from hanabi import hanab_game
|
||||||
from hanabi.live import compress
|
from hanabi.live import compress
|
||||||
from hanabi.database import database
|
from hanabi import database
|
||||||
|
|
||||||
|
|
||||||
class CardType(Enum):
|
class CardType(Enum):
|
||||||
|
|
|
@ -9,3 +9,4 @@ argparse
|
||||||
verboselogs
|
verboselogs
|
||||||
pebble
|
pebble
|
||||||
platformdirs
|
platformdirs
|
||||||
|
PyYAML
|
||||||
|
|
Loading…
Reference in a new issue