import json

import psycopg2
import psycopg2.errors
import psycopg2.extensions
import requests
import unidecode

import constants
from config import read_db_config, read_config
from log_setup import logger


class DBConnectionManager:
    """Lazily creates and caches a single psycopg2 connection."""

    def __init__(self):
        # The connection is only opened on first use, see get_connection().
        self._conn = None

    def connect(self):
        """
        Read the database config and open a new connection to the DB on localhost.

        The new connection replaces any connection stored previously.
        """
        config = read_db_config()
        # The password is deliberately kept out of the log output.
        logger.debug("Establishing database connection with dbname={}, user={}".format(
            config.db_name, config.db_user
        ))
        # Keyword arguments instead of a hand-formatted DSN string, so that
        # credentials containing quotes or spaces cannot break the DSN.
        self._conn = psycopg2.connect(
            dbname=config.db_name,
            user=config.db_user,
            password=config.db_pass,
            host="localhost",
        )
        logger.debug("Established database connection.")

    def get_connection(self) -> psycopg2.extensions.connection:
        """
        Get the database connection.
        If not already connected, this reads the database config file and connects to the DB.
        Otherwise, the already active connection is returned.
        """
        if self._conn is None:
            self.connect()
        return self._conn

    def get_new_cursor(self) -> psycopg2.extensions.cursor:
        """Return a fresh cursor on the (lazily established) connection."""
        return self.get_connection().cursor()


# Global instance that will hold our DB connection
conn_manager = DBConnectionManager()


def get_existing_tables():
    """
    Return the names of those tables from constants.DB_TABLE_NAMES that
    currently exist in the 'public' schema, as a list of strings.
    """
    # Passing the names as one array parameter lets psycopg2 do the quoting
    # and also yields valid SQL when the list of names is empty.
    table_names = list(constants.DB_TABLE_NAMES)
    with conn_manager.get_new_cursor() as cur:
        cur.execute(
            "SELECT tablename FROM pg_tables WHERE"
            "   schemaname = 'public' AND"
            "   tablename = ANY(%s)",
            (table_names,)
        )
        return [table for (table,) in cur.fetchall()]


def init_database():
    """
    Execute the SQL schema file at constants.DATABASE_SCHEMA_PATH and commit.

    Warning: This drops all existing tables from the database
    """
    conn = conn_manager.get_connection()
    with open(constants.DATABASE_SCHEMA_PATH, "r") as f:
        schema = f.read()
    with conn.cursor() as cur:
        cur.execute(schema)
    conn.commit()
    logger.verbose("Initialized DB tables.")


def add_player_name(player_name: str):
    """
    Insert a new row with the given player name into the users table.

    If the insertion violates a unique constraint, a warning is logged and
    the transaction is rolled back instead of raising.
    """
    conn = conn_manager.get_connection()
    with conn.cursor() as cur:
        try:
            cur.execute("INSERT INTO users (player_name) VALUES (%s)", (player_name,))
            conn.commit()
        except psycopg2.errors.UniqueViolation:
            logger.warning(
                "Player name {} already exists in the database, aborting insertion.".format(player_name)
            )
            conn.rollback()


def add_user_name_to_player(hanabi_username: str, player_name: str):
    """
    Register a hanabi username for an already existing player.

    The username is stored together with a normalized form (unidecode +
    lowercase), which is also what the duplicate check is based on.
    Nothing is inserted (only a log message is written) if the player does
    not exist or the normalized username is already registered.
    """
    normalized_username = unidecode.unidecode(hanabi_username).lower()
    with conn_manager.get_new_cursor() as cur:
        cur.execute("SELECT id FROM users WHERE player_name = (%s)", (player_name,))
        user_row = cur.fetchone()
        if user_row is None:
            logger.error("Display name {} not found in database, cannot add username to it.".format(player_name))
            return
        # fetchone() returns a one-element row; unpack the plain id so that
        # a scalar (not a tuple) is bound in the INSERT below.
        (user_id,) = user_row

        cur.execute(
            "SELECT username, player_name from user_accounts "
            "INNER JOIN users"
            "   ON user_accounts.user_id = users.id "
            "WHERE "
            "   normalized_username = (%s)",
            (normalized_username,)
        )
        existing = cur.fetchone()
        if existing is not None:
            existing_username, existing_player_name = existing
            if existing_player_name == player_name:
                logger.warning(
                    "Hanabi username {} is already registered to player {}, attempted to re-register it.".format(
                        existing_username, existing_player_name
                    )
                )
            else:
                logger.error(
                    "Hanabi username {} is already associated to player {}, cannot register it to player {}.".format(
                        existing_username, existing_player_name, player_name
                    )
                )
            # Either way the username is taken, so do not insert it again.
            return

        cur.execute(
            "INSERT INTO user_accounts (username, normalized_username, user_id) VALUES (%s, %s, %s)",
            (hanabi_username, normalized_username, user_id)
        )
    conn_manager.get_connection().commit()


def add_player(player_name: str, user_name: str):
    """Create the player (if not yet present) and register user_name for it."""
    add_player_name(player_name)
    add_user_name_to_player(user_name, player_name)


def fetch_and_initialize_variants():
    """
    Download the variants JSON from constants.VARIANTS_JSON_URL and insert
    the accepted variants into the variants table.

    A variant is accepted if its suit count lies within the configured
    [min_suit_count, max_suit_count] range and its lowercased name contains
    none of the configured excluded_variants entries.
    If the download does not return HTTP 200, an error is logged and nothing
    is inserted.
    """
    response = requests.get(constants.VARIANTS_JSON_URL)
    if response.status_code != 200:
        logger.error(
            "Could not download variants.json file from github (tried url {})".format(constants.VARIANTS_JSON_URL)
        )
        return
    variants = json.loads(response.text)

    config = read_config()

    # One cursor is reused for all inserts instead of opening one per variant.
    with conn_manager.get_new_cursor() as cur:
        for variant in variants:
            variant_id = variant['id']
            name = variant['name']
            clue_starved = variant.get('clueStarved', False)
            num_suits = len(variant['suits'])

            if not config.min_suit_count <= num_suits <= config.max_suit_count:
                continue
            if any(var_name in name.lower() for var_name in config.excluded_variants):
                continue

            cur.execute(
                "INSERT INTO variants (id, name, num_suits, clue_starved) VALUES (%s, %s, %s, %s)",
                (variant_id, name, num_suits, clue_starved)
            )
    # All inserts are committed together once every variant was processed.
    conn_manager.get_connection().commit()