import json
from typing import Optional, List, Dict

import psycopg2
import psycopg2.extensions
import psycopg2.extras
import psycopg2.errors
import requests
import unidecode

import constants
import utils
from config import config_manager
from log_setup import logger


class DBConnectionManager:
    """
    Lazily opens and caches a single psycopg2 connection, using the settings
    returned by config_manager.get_db_config().
    """

    def __init__(self):
        # Cached connection; stays None until connect() is first called.
        self._conn = None

    def connect(self):
        """
        Open a new connection to the database on localhost and cache it.

        Connection parameters are passed as keyword arguments instead of being
        interpolated into a DSN string, so values containing quotes or spaces
        do not break the connection string. The password is intentionally not
        written to the log.
        """
        config = config_manager.get_db_config()
        logger.debug("Establishing database connection with dbname={}, user={}".format(
            config.db_name, config.db_user
        ))
        self._conn = psycopg2.connect(
            dbname=config.db_name,
            user=config.db_user,
            password=config.db_pass,
            host='localhost'
        )
        logger.debug("Established database connection.")

    def get_connection(self) -> psycopg2.extensions.connection:
        """
        Get the database connection.
        If not already connected, this reads the database config file and connects to the DB.
        Otherwise, the already active connection is returned.
        """
        if self._conn is None:
            self.connect()
        return self._conn

    def get_new_cursor(self) -> psycopg2.extensions.cursor:
        """
        Return a new cursor on the (lazily opened) connection.
        The caller is responsible for closing it, e.g. by using it in a `with` statement.
        """
        return self.get_connection().cursor()


# Global instance that will hold our DB connection
conn_manager = DBConnectionManager()


def get_existing_tables() -> List[str]:
    """
    Return the names from constants.DB_TABLE_NAMES that currently exist as tables in the 'public' schema.
    """
    table_names = tuple(constants.DB_TABLE_NAMES)
    if not table_names:
        # "IN ()" is invalid SQL, and there is nothing to look up anyway.
        return []
    conn = conn_manager.get_connection()
    with conn.cursor() as cur:
        # psycopg2 adapts a Python tuple to a parenthesized, properly quoted list suitable for IN.
        cur.execute(
            "SELECT tablename FROM pg_tables WHERE"
            " schemaname = 'public' AND"
            " tablename IN %s",
            (table_names,)
        )
        return [table for (table,) in cur.fetchall()]


def init_database():
    """
    Execute the SQL schema file at constants.DATABASE_SCHEMA_PATH and commit.

    Warning: This drops all existing tables from the database
    """
    conn = conn_manager.get_connection()
    with open(constants.DATABASE_SCHEMA_PATH, "r", encoding="utf-8") as f:
        schema = f.read()
    with conn.cursor() as cur:
        cur.execute(schema)
    conn.commit()
    logger.verbose("Initialized DB tables.")


def fetch_and_initialize_variants():
    """
    Download the variant list from constants.VARIANTS_JSON_URL and insert the selected variants.

    A variant is inserted when its suit count lies within
    [config.min_suit_count, config.max_suit_count] and its lower-cased name contains
    none of the strings in config.excluded_variants.
    If the download does not return HTTP 200, an error is logged and nothing is inserted.
    """
    response = requests.get(constants.VARIANTS_JSON_URL)
    if response.status_code != 200:
        logger.error(
            "Could not download variants.json file from github (tried url {})".format(constants.VARIANTS_JSON_URL))
        return
    variants = json.loads(response.text)
    config = config_manager.get_config()

    conn = conn_manager.get_connection()
    with conn.cursor() as cur:
        for variant in variants:
            variant_id = variant['id']
            name = variant['name']
            clue_starved = variant.get('clueStarved', False)
            num_suits = len(variant['suits'])
            if not (config.min_suit_count <= num_suits <= config.max_suit_count):
                continue
            if any(var_name in name.lower() for var_name in config.excluded_variants):
                continue
            cur.execute(
                "INSERT INTO variants (id, name, num_suits, clue_starved) VALUES (%s, %s, %s, %s)",
                (variant_id, name, num_suits, clue_starved)
            )
    conn.commit()


def initialize_variant_base_ratings():
    """
    Upsert a base rating for every stored variant and every player count in
    [config.min_player_count, config.max_player_count].

    Each rating is computed by config.variant_base_rating(variant_name, num_players).
    Existing rows are overwritten via ON CONFLICT ... DO UPDATE.
    """
    config = config_manager.get_config()
    conn = conn_manager.get_connection()
    with conn.cursor() as cur:
        cur.execute("SELECT id, name from variants")
        ratings = [
            (variant_id, num_players, config.variant_base_rating(variant_name, num_players))
            for variant_id, variant_name in cur.fetchall()
            for num_players in range(config.min_player_count, config.max_player_count + 1)
        ]
        psycopg2.extras.execute_values(
            cur,
            "INSERT INTO variant_base_ratings (variant_id, player_count, rating) "
            "VALUES %s "
            "ON CONFLICT (variant_id, player_count) "
            "DO UPDATE SET rating = EXCLUDED.rating",
            ratings
        )
    conn.commit()


def get_user_id(player_name: str) -> Optional[int]:
    """
    Return the id of the player with the given name, or None if no such player is registered.
    """
    with conn_manager.get_new_cursor() as cur:
        cur.execute("SELECT id FROM users WHERE player_name = (%s)", (player_name,))
        row = cur.fetchone()
    # fetchone() yields a 1-tuple; unwrap it so callers get the id itself, as annotated.
    return None if row is None else row[0]


def get_user_ids_from_normalized_usernames(normalized_usernames: List[str]) -> List[int] | str:
    """
    Look up the user ids owning the given normalized usernames.

    @param normalized_usernames: Usernames, compared against user_accounts.normalized_username.
    @rtype: If all users are registered, list of their ids in the same order.
            Otherwise, the first (in input order) username that is not registered.
            An empty input yields an empty list.
    """
    if not normalized_usernames:
        # "IN ()" is invalid SQL; an empty request trivially maps to an empty result.
        return []
    with conn_manager.get_new_cursor() as cur:
        cur.execute(
            "SELECT normalized_username, user_id "
            "FROM user_accounts "
            "WHERE normalized_username IN %s",
            (tuple(normalized_usernames),)
        )
        user_dict: Dict[str, int] = dict(cur.fetchall())

    user_ids: List[int] = []
    for normalized_username in normalized_usernames:
        if normalized_username not in user_dict:
            return normalized_username
        user_ids.append(user_dict[normalized_username])
    return user_ids


def get_variant_id(variant_name: str) -> Optional[int]:
    """
    Return the id of the variant with the given name, or None if it is not stored.
    """
    with conn_manager.get_new_cursor() as cur:
        cur.execute("SELECT id FROM variants WHERE name = %s", (variant_name,))
        row = cur.fetchone()
    # fetchone() yields a 1-tuple; unwrap it so callers get the id itself, as annotated.
    return None if row is None else row[0]


def add_player_name(player_name: str):
    """
    Insert a new player into the users table.

    If the name already exists (unique violation), a warning is logged and the transaction is rolled back.
    """
    conn = conn_manager.get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute("INSERT INTO users (player_name) VALUES (%s)", (player_name,))
        conn.commit()
    except psycopg2.errors.UniqueViolation:
        logger.warning("Player name {} already exists in the database, aborting insertion.".format(player_name))
        conn.rollback()


def add_user_name_to_player(hanabi_username: str, player_name: str):
    """
    Register a hanab.live username (stored with its normalized form) to an existing player.

    Nothing is inserted if the player does not exist (error logged) or if the
    normalized username is already registered. In that case a warning is logged
    when it belongs to the same player, and an error when it belongs to another player.
    """
    normalized_username = utils.normalize_username(hanabi_username)
    user_id = get_user_id(player_name)
    if user_id is None:
        logger.error("Player {} not found in database, cannot add username to it.".format(player_name))
        return

    conn = conn_manager.get_connection()
    with conn.cursor() as cur:
        cur.execute("SELECT username, player_name from user_accounts "
                    "INNER JOIN users"
                    " ON user_accounts.user_id = users.id "
                    "WHERE "
                    " normalized_username = (%s)",
                    (normalized_username,)
                    )
        res = cur.fetchone()
        if res is not None:
            existing_username, existing_player_name = res
            if existing_player_name == player_name:
                logger.warning(
                    "Hanabi username {} is already registered to player {}, attempted to re-register it.".format(
                        existing_username, existing_player_name
                    ))
            else:
                logger.error(
                    "Hanabi username {} is already associated to player {}, cannot register it to player {}.".format(
                        existing_username, existing_player_name, player_name
                    ))
            return
        cur.execute(
            "INSERT INTO user_accounts (username, normalized_username, user_id) VALUES (%s, %s, %s)",
            (hanabi_username, normalized_username, user_id)
        )
    conn.commit()


def init_player_base_rating(player_name: str, base_rating: Optional[int] = None):
    """
    Insert base ratings for rating types 0 and 1 of the given player.

    @param base_rating: Rating to use; defaults to config.player_base_rating.
    @raise ValueError: If the player is not registered, or base ratings for them already exist.
    """
    config = config_manager.get_config()
    if base_rating is None:
        base_rating = config.player_base_rating

    user_id = get_user_id(player_name)
    if user_id is None:
        err_msg = "Cannot initialise base rating for player {}: No such registered player.".format(player_name)
        logger.error(err_msg)
        raise ValueError(err_msg)

    # One row per rating type (types 0 and 1).
    vals = [(user_id, rating_type, base_rating) for rating_type in range(0, 2)]
    conn = conn_manager.get_connection()
    try:
        with conn.cursor() as cur:
            psycopg2.extras.execute_values(
                cur,
                "INSERT INTO user_base_ratings (user_id, type, rating) VALUES %s",
                vals
            )
        conn.commit()
    except psycopg2.errors.UniqueViolation as e:
        # Roll back so the shared connection is usable again; otherwise every later
        # statement on it fails because the transaction is aborted.
        conn.rollback()
        err_msg = "Failed to initialize base ratings for player {}: Ratings already exist".format(player_name)
        logger.error("{}:\n{}".format(err_msg, e))
        raise ValueError(err_msg) from e


def add_player(user_name: str, player_name: str, base_rating: Optional[int] = None):
    """
    Convenience function: Adds a player to the database, along with associated username
    and initializes the rating for this player.
    """
    add_player_name(player_name)
    add_user_name_to_player(user_name, player_name)
    init_player_base_rating(player_name, base_rating)


def get_variant_ids() -> List[int]:
    """
    Return the ids of all variants stored in the database.
    """
    with conn_manager.get_new_cursor() as cur:
        cur.execute("SELECT id FROM variants")
        return [var_id for (var_id,) in cur.fetchall()]