hanabi-league/database.py

251 lines
8.9 KiB
Python

import json
from typing import Optional, List, Dict
import psycopg2
import psycopg2.extensions
import psycopg2.extras
import psycopg2.errors
import requests
import unidecode
import constants
import utils
from config import config_manager
from log_setup import logger
class DBConnectionManager:
    """
    Holds a single, lazily created psycopg2 connection to the PostgreSQL server on localhost.
    """

    def __init__(self):
        # No connection is made until one is first requested.
        self._conn = None

    def connect(self):
        """
        Open a new connection using the credentials from the DB config and store it.

        Any previously stored connection object is replaced (it is not closed here).
        """
        config = config_manager.get_db_config()
        # The password is deliberately not logged.
        logger.debug("Establishing database connection with dbname={}, user={}".format(
            config.db_name, config.db_user
        ))
        # Keyword arguments let psycopg2 build and quote the DSN itself; formatting the values into
        # a DSN string by hand breaks as soon as one of them contains a quote or a backslash.
        self._conn = psycopg2.connect(
            dbname=config.db_name,
            user=config.db_user,
            password=config.db_pass,
            host='localhost',
        )
        logger.debug("Established database connection.")

    def get_connection(self) -> psycopg2.extensions.connection:
        """
        Get the database connection.
        If there is no connection yet, or the stored one has been closed, this reads the database
        config and connects to the DB. Otherwise, the already active connection is returned.
        """
        # `closed` is nonzero once psycopg2 knows the connection is closed or broken.
        if self._conn is None or self._conn.closed:
            self.connect()
        return self._conn

    def get_new_cursor(self) -> psycopg2.extensions.cursor:
        """
        Return a fresh cursor on the managed connection, connecting first if necessary.
        """
        return self.get_connection().cursor()
# Module-level singleton: the query helpers in this module obtain their connection and cursors through it.
conn_manager: DBConnectionManager = DBConnectionManager()
def get_existing_tables() -> List[str]:
    """
    Return which of the tables listed in constants.DB_TABLE_NAMES exist in the 'public' schema.

    @return: The names of the existing tables, in the order returned by the database.
    """
    conn = conn_manager.get_connection()
    with conn.cursor() as cur:
        # Pass the names as one array parameter instead of formatting them into the SQL text:
        # psycopg2 handles the quoting, and '= ANY(...)' (unlike 'IN ()') is also valid for an empty list.
        cur.execute(
            "SELECT tablename FROM pg_tables WHERE"
            " schemaname = 'public' AND"
            " tablename = ANY(%s)",
            (list(constants.DB_TABLE_NAMES),)
        )
        return [table for (table,) in cur.fetchall()]
def init_database():
    """
    Execute the SQL schema script at constants.DATABASE_SCHEMA_PATH and commit it.

    Warning: This drops all existing tables from the database
    """
    # Read the script before touching the database, and decode it explicitly instead of relying on
    # the platform's default encoding.
    with open(constants.DATABASE_SCHEMA_PATH, "r", encoding="utf-8") as f:
        schema_sql = f.read()
    conn = conn_manager.get_connection()
    with conn.cursor() as cur:
        cur.execute(schema_sql)
    conn.commit()
    logger.verbose("Initialized DB tables.")
def fetch_and_initialize_variants(timeout: float = 30):
    """
    Download the variant list from constants.VARIANTS_JSON_URL and insert the eligible variants
    into the variants table.

    A variant is inserted if its number of suits lies within
    [config.min_suit_count, config.max_suit_count] and its lower-cased name contains none of the
    strings in config.excluded_variants. If the download does not answer with HTTP 200, an error
    is logged and nothing is inserted.

    @param timeout: Seconds requests waits for the server before raising. Without it, requests never
                    times out, so a stalled download would block the caller forever.
    """
    response = requests.get(constants.VARIANTS_JSON_URL, timeout=timeout)
    if response.status_code != 200:
        logger.error(
            "Could not download variants.json file from github (tried url {})".format(constants.VARIANTS_JSON_URL))
        return
    variants = json.loads(response.text)
    config = config_manager.get_config()
    # One cursor for all inserts instead of a new cursor per variant.
    with conn_manager.get_new_cursor() as cur:
        for variant in variants:
            variant_id = variant['id']
            name = variant['name']
            clue_starved = variant.get('clueStarved', False)
            num_suits = len(variant['suits'])
            if not config.min_suit_count <= num_suits <= config.max_suit_count:
                continue
            if any(var_name in name.lower() for var_name in config.excluded_variants):
                continue
            cur.execute(
                "INSERT INTO variants (id, name, num_suits, clue_starved) VALUES (%s, %s, %s, %s)",
                (variant_id, name, num_suits, clue_starved)
            )
    conn_manager.get_connection().commit()
def initialize_variant_base_ratings():
    """
    Compute a base rating for every stored variant and every player count in
    [config.min_player_count, config.max_player_count], then upsert them into variant_base_ratings.

    Rows that already exist for a (variant_id, player_count) pair get their rating overwritten.
    """
    config = config_manager.get_config()
    conn = conn_manager.get_connection()
    cur = conn.cursor()
    cur.execute("SELECT id, name from variants")
    player_counts = range(config.min_player_count, config.max_player_count + 1)
    ratings = [
        (variant_id, player_count, config.variant_base_rating(variant_name, player_count))
        for variant_id, variant_name in cur.fetchall()
        for player_count in player_counts
    ]
    upsert_query = (
        "INSERT INTO variant_base_ratings (variant_id, player_count, rating)"
        "VALUES %s "
        "ON CONFLICT (variant_id, player_count) "
        "DO UPDATE SET rating = EXCLUDED.rating"
    )
    psycopg2.extras.execute_values(cur, upsert_query, ratings)
    conn.commit()
def get_user_id(player_name: str) -> Optional[int]:
    """
    Look up the id of a registered player.

    @param player_name: Player name, matched exactly against users.player_name.
    @return: The player's id, or None if no such player is registered.
    """
    with conn_manager.get_new_cursor() as cur:
        cur.execute("SELECT id FROM users WHERE player_name = (%s)", (player_name,))
        row = cur.fetchone()
    # fetchone() yields a 1-tuple (or None); unwrap it so callers get the int the signature promises.
    return None if row is None else row[0]
def get_user_ids_from_normalized_usernames(normalized_usernames: List[str]) -> List[int] | str:
"""
@rtype: If all users are registered, list of their ids in the same order.
Otherwise, name of a user that is not registered.
@warning If usernames are not present in the database, there is no corresponding key in the returned dictionary
"""
cur = conn_manager.get_new_cursor()
cur.execute("SELECT normalized_username, user_id "
"FROM user_accounts "
"WHERE normalized_username IN ({})".format(",".join("%s" for _ in normalized_usernames)),
normalized_usernames
)
# Build up dict from the specified user ids
user_dict: Dict[str, int] = {}
for normalized_username, user_id in cur.fetchall():
user_dict[normalized_username] = user_id
user_ids = []
for normalized_username in normalized_usernames:
if normalized_username not in user_dict.keys():
return normalized_username
else:
user_ids.append(user_dict[normalized_username])
return user_ids
def get_variant_id(variant_name: str) -> Optional[int]:
    """
    Look up the id of a variant by its exact name.

    @param variant_name: Variant name, matched exactly against variants.name.
    @return: The variant's id, or None if no variant with that name is stored.
    """
    with conn_manager.get_new_cursor() as cur:
        cur.execute("SELECT id FROM variants WHERE name = %s", (variant_name,))
        row = cur.fetchone()
    # fetchone() yields a 1-tuple (or None); unwrap it so callers get the int the signature promises.
    return None if row is None else row[0]
def add_player_name(player_name: str):
    """
    Insert a new player name into the users table and commit.

    If the name already exists, a warning is logged and the failed insert is rolled back
    instead of raising.
    """
    connection = conn_manager.get_connection()
    cursor = connection.cursor()
    insert_query = "INSERT INTO users (player_name) VALUES (%s)"
    try:
        cursor.execute(insert_query, (player_name,))
        connection.commit()
    except psycopg2.errors.UniqueViolation:
        warning = "Player name {} already exists in the database, aborting insertion.".format(player_name)
        logger.warn(warning)
        # Clear the aborted transaction so the shared connection remains usable.
        connection.rollback()
def add_user_name_to_player(hanabi_username: str, player_name: str):
    """
    Link a hanabi username to an already registered player and commit.

    Nothing is inserted if the player does not exist (an error is logged), or if the normalized
    username is already linked: a warning is logged if it belongs to this same player, an error otherwise.
    """
    normalized_username = utils.normalize_username(hanabi_username)
    user_id = get_user_id(player_name)
    if user_id is None:
        logger.error("Player {} not found in database, cannot add username to it.".format(player_name))
        return

    cur = conn_manager.get_new_cursor()
    lookup_query = (
        "SELECT username, player_name from user_accounts "
        "INNER JOIN users"
        " ON user_accounts.user_id = users.id "
        "WHERE "
        " normalized_username = (%s)"
    )
    cur.execute(lookup_query, (normalized_username,))
    existing = cur.fetchone()
    if existing is not None:
        existing_username, existing_player_name = existing
        if existing_player_name != player_name:
            logger.error(
                "Hanabi username {} is already associated to player {}, cannot register it to player {}.".format(
                    existing_username, existing_player_name, player_name
                ))
        else:
            logger.warn(
                "Hanabi username {} is already registered to player {}, attempted to re-register it.".format(
                    existing_username, existing_player_name
                ))
        return

    insert_query = "INSERT INTO user_accounts (username, normalized_username, user_id) VALUES (%s, %s, %s)"
    cur.execute(insert_query, (hanabi_username, normalized_username, user_id))
    conn_manager.get_connection().commit()
def init_player_base_rating(player_name: str, base_rating: Optional[int] = None):
    """
    Insert the starting ratings of a registered player into user_base_ratings and commit.

    One row is inserted for each of the rating types 0 and 1, all with the same rating.

    @param player_name: Name of a player already present in the users table.
    @param base_rating: Starting rating; defaults to config.player_base_rating when None.
    @raise ValueError: If the player is not registered, or if base ratings for the player already
                       exist. In the latter case the failed transaction is rolled back first.
    """
    config = config_manager.get_config()
    if base_rating is None:
        base_rating = config.player_base_rating
    user_id = get_user_id(player_name)
    if user_id is None:
        err_msg = "Cannot initialise base rating for player {}: No such registered player.".format(player_name)
        logger.error(err_msg)
        raise ValueError(err_msg)
    # NOTE(review): rating types 0 and 1 are hard-coded; their meaning is defined outside this function.
    vals = [(user_id, rating_type, base_rating) for rating_type in range(0, 2)]
    conn = conn_manager.get_connection()
    try:
        with conn.cursor() as cur:
            psycopg2.extras.execute_values(
                cur,
                "INSERT INTO user_base_ratings (user_id, type, rating) VALUES %s",
                vals
            )
        conn.commit()
    except psycopg2.errors.UniqueViolation as e:
        # Without a rollback the shared connection would stay in an aborted transaction and every
        # later query on it would fail.
        conn.rollback()
        err_msg = "Failed to initialize base ratings for player {}: Ratings already exist".format(player_name)
        logger.error("{}:\n{}".format(err_msg, e))
        raise ValueError(err_msg) from e
def add_player(user_name: str, player_name: str, base_rating: Optional[int] = None):
    """
    Register a complete new player in one call.

    In order, this:
    1. creates the player name,
    2. links the hanabi username `user_name` to that player, and
    3. initializes the player's base ratings with `base_rating`, or the configured default if None.
    """
    add_player_name(player_name=player_name)
    add_user_name_to_player(hanabi_username=user_name, player_name=player_name)
    init_player_base_rating(player_name=player_name, base_rating=base_rating)
def get_variant_ids():
    """
    Return the ids of all variants in the variants table, in the order the database yields them.
    """
    cursor = conn_manager.get_new_cursor()
    cursor.execute("SELECT id FROM variants")
    rows = cursor.fetchall()
    return [row[0] for row in rows]