2023-11-22 18:56:08 +01:00
|
|
|
import json
|
2023-11-23 12:52:56 +01:00
|
|
|
from typing import Optional, List, Dict
|
2023-11-22 18:56:08 +01:00
|
|
|
|
2023-11-22 15:31:36 +01:00
|
|
|
import psycopg2
|
2023-11-22 16:00:26 +01:00
|
|
|
import psycopg2.extensions
|
2023-11-23 01:18:26 +01:00
|
|
|
import psycopg2.extras
|
2023-11-22 18:22:53 +01:00
|
|
|
import psycopg2.errors
|
2023-11-22 18:56:08 +01:00
|
|
|
import requests
|
2023-11-22 18:22:53 +01:00
|
|
|
import unidecode
|
2023-11-22 15:31:36 +01:00
|
|
|
|
2023-11-22 16:00:26 +01:00
|
|
|
import constants
|
2023-11-23 12:52:56 +01:00
|
|
|
import utils
|
2023-11-22 23:50:20 +01:00
|
|
|
from config import config_manager
|
2023-11-22 15:31:36 +01:00
|
|
|
from log_setup import logger
|
|
|
|
|
|
|
|
|
|
|
|
class DBConnectionManager:
    """
    Lazily opens and caches a single psycopg2 connection to the database configured via config_manager.
    """

    def __init__(self):
        # Created on first use by connect(); None until then.
        self._conn = None

    def connect(self):
        """
        (Re-)connect to the database using the credentials from config_manager.get_db_config().

        The host is fixed to localhost.
        """
        config = config_manager.get_db_config()
        # Never log the password: debug output may end up in log files or bug reports.
        logger.debug("Establishing database connection with dbname={}, user={}".format(
            config.db_name, config.db_user
        ))
        # Passing keyword arguments lets psycopg2 build the DSN itself, so credentials containing
        # quotes, spaces or backslashes are escaped correctly instead of breaking the DSN string.
        self._conn = psycopg2.connect(
            dbname=config.db_name,
            user=config.db_user,
            password=config.db_pass,
            host='localhost'
        )
        logger.debug("Established database connection.")

    def get_connection(self) -> psycopg2.extensions.connection:
        """
        Get the database connection.

        If not already connected (or if the cached connection has been closed),
        this reads the database config file and connects to the DB.
        Otherwise, the already active connection is returned.
        """
        # connection.closed is non-zero once the connection is closed or broken.
        if self._conn is None or self._conn.closed:
            self.connect()
        return self._conn

    def get_new_cursor(self) -> psycopg2.extensions.cursor:
        """
        Return a new cursor on the database connection, connecting first if necessary.
        """
        return self.get_connection().cursor()
|
|
|
|
|
2023-11-22 15:31:36 +01:00
|
|
|
|
|
|
|
# Module-wide singleton: the helper functions of this module obtain their connection and cursors from it.
conn_manager: DBConnectionManager = DBConnectionManager()
|
|
|
|
|
|
|
|
|
|
|
|
def get_existing_tables():
    """
    Determine which of the tables listed in constants.DB_TABLE_NAMES exist in the 'public' schema.

    @return: List with the names of those tables that exist, in the order returned by PostgreSQL.
    """
    conn = conn_manager.get_connection()
    with conn.cursor() as cur:
        # Pass the names as a single array parameter instead of splicing them into the SQL text:
        # psycopg2 quotes them correctly, and an empty list still yields valid SQL (unlike "IN ()").
        cur.execute(
            "SELECT tablename FROM pg_tables WHERE"
            " schemaname = 'public' AND"
            " tablename = ANY(%s)",
            (list(constants.DB_TABLE_NAMES),)
        )
        return [table for (table,) in cur.fetchall()]
|
|
|
|
|
|
|
|
|
2023-11-22 16:21:30 +01:00
|
|
|
def init_database():
    """
    Initialize the DB tables by executing the SQL schema file at constants.DATABASE_SCHEMA_PATH and committing.

    Warning: This drops all existing tables from the database
    (presumably via DROP statements inside the schema file; nothing is dropped by this function directly).
    """
    conn = conn_manager.get_connection()
    # Explicit encoding so reading the schema does not depend on the platform's locale.
    with open(constants.DATABASE_SCHEMA_PATH, "r", encoding="utf-8") as f:
        schema = f.read()
    # The cursor is closed again once the schema has been executed.
    with conn.cursor() as cur:
        cur.execute(schema)
    conn.commit()
    logger.verbose("Initialized DB tables.")
|
2023-11-22 18:22:53 +01:00
|
|
|
|
|
|
|
|
2023-11-22 18:58:55 +01:00
|
|
|
def fetch_and_initialize_variants(timeout: float = 30.0):
    """
    Download the variant list from constants.VARIANTS_JSON_URL and insert the allowed variants
    into the variants table.

    A variant is inserted if its number of suits lies within
    [config.min_suit_count, config.max_suit_count] and its lowercased name contains none of the
    strings in config.excluded_variants. All inserts are committed together at the end.
    If the download does not answer with HTTP 200, an error is logged and nothing is inserted.

    @param timeout: Seconds to wait for the HTTP server before giving up
                    (requests raises requests.exceptions.Timeout in that case).
    """
    # Without a timeout, requests may wait forever on an unresponsive server.
    response = requests.get(constants.VARIANTS_JSON_URL, timeout=timeout)
    if response.status_code != 200:
        logger.error(
            "Could not download variants.json file from github (tried url {})".format(constants.VARIANTS_JSON_URL))
        return

    variants = json.loads(response.text)
    config = config_manager.get_config()

    # One cursor is enough for all inserts; they share the transaction committed below.
    cur = conn_manager.get_new_cursor()
    for variant in variants:
        variant_id = variant['id']
        name = variant['name']
        clue_starved = variant.get('clueStarved', False)
        num_suits = len(variant['suits'])

        if not config.min_suit_count <= num_suits <= config.max_suit_count:
            continue
        if any(var_name in name.lower() for var_name in config.excluded_variants):
            continue

        cur.execute(
            "INSERT INTO variants (id, name, num_suits, clue_starved) VALUES (%s, %s, %s, %s)",
            (variant_id, name, num_suits, clue_starved)
        )
    conn_manager.get_connection().commit()
|
|
|
|
|
|
|
|
|
2023-11-23 01:18:26 +01:00
|
|
|
def initialize_variant_base_ratings():
    """
    Compute the base rating of every variant in the variants table for each player count in
    [config.min_player_count, config.max_player_count] and store it in variant_base_ratings.

    Existing rows for a (variant_id, player_count) pair are overwritten with the newly computed rating.
    """
    config = config_manager.get_config()
    conn = conn_manager.get_connection()
    with conn.cursor() as cur:
        cur.execute("SELECT id, name FROM variants")
        ratings = []
        for variant_id, variant_name in cur.fetchall():
            for num_players in range(config.min_player_count, config.max_player_count + 1):
                rating = config.variant_base_rating(variant_name, num_players)
                ratings.append((variant_id, num_players, rating))

        psycopg2.extras.execute_values(
            cur,
            # Note the trailing space: the fragments are concatenated into a single statement.
            "INSERT INTO variant_base_ratings (variant_id, player_count, rating) "
            "VALUES %s "
            "ON CONFLICT (variant_id, player_count) "
            "DO UPDATE SET rating = EXCLUDED.rating",
            ratings
        )
    conn.commit()
|
|
|
|
|
2023-11-23 01:36:23 +01:00
|
|
|
|
2023-11-23 12:52:56 +01:00
|
|
|
def get_user_id(player_name: str) -> Optional[int]:
    """
    Look up the id of a registered player.

    @param player_name: Name of the player as stored in users.player_name.
    @return: The player's id, or None if no such player exists.
    """
    with conn_manager.get_new_cursor() as cur:
        cur.execute("SELECT id FROM users WHERE player_name = (%s)", (player_name,))
        row = cur.fetchone()
    # fetchone() yields a 1-tuple (or None); unwrap it so callers really get an int as annotated.
    return None if row is None else row[0]
|
2023-11-22 23:35:47 +01:00
|
|
|
|
|
|
|
|
2023-11-23 12:52:56 +01:00
|
|
|
def get_user_ids_from_normalized_usernames(normalized_usernames: List[str]) -> List[int] | str:
|
|
|
|
"""
|
|
|
|
@rtype: If all users are registered, list of their ids in the same order.
|
|
|
|
Otherwise, name of a user that is not registered.
|
|
|
|
@warning If usernames are not present in the database, there is no corresponding key in the returned dictionary
|
|
|
|
"""
|
2023-11-23 01:36:23 +01:00
|
|
|
cur = conn_manager.get_new_cursor()
|
2023-11-23 12:52:56 +01:00
|
|
|
cur.execute("SELECT normalized_username, user_id "
|
|
|
|
"FROM user_accounts "
|
|
|
|
"WHERE normalized_username IN ({})".format(",".join("%s" for _ in normalized_usernames)),
|
|
|
|
normalized_usernames
|
|
|
|
)
|
|
|
|
# Build up dict from the specified user ids
|
|
|
|
user_dict: Dict[str, int] = {}
|
|
|
|
for normalized_username, user_id in cur.fetchall():
|
|
|
|
user_dict[normalized_username] = user_id
|
|
|
|
|
|
|
|
user_ids = []
|
|
|
|
for normalized_username in normalized_usernames:
|
|
|
|
if normalized_username not in user_dict.keys():
|
|
|
|
return normalized_username
|
|
|
|
else:
|
|
|
|
user_ids.append(user_dict[normalized_username])
|
|
|
|
return user_ids
|
|
|
|
|
|
|
|
|
|
|
|
def get_variant_id(variant_name: str) -> Optional[int]:
    """
    Look up the id of a variant by its exact name.

    @param variant_name: Name as stored in variants.name.
    @return: The variant's id, or None if no variant with this name exists.
    """
    with conn_manager.get_new_cursor() as cur:
        cur.execute("SELECT id FROM variants WHERE name = %s", (variant_name,))
        row = cur.fetchone()
    # fetchone() yields a 1-tuple (or None); unwrap it so callers really get an int as annotated.
    return None if row is None else row[0]
|
|
|
|
|
|
|
|
|
2023-11-22 18:22:53 +01:00
|
|
|
def add_player_name(player_name: str):
    """
    Insert a new player into the users table and commit.

    If the insert violates a unique constraint (player name already present), a warning is logged,
    the transaction is rolled back and nothing is inserted; no exception is raised.
    """
    conn = conn_manager.get_connection()
    with conn.cursor() as cur:
        try:
            cur.execute("INSERT INTO users (player_name) VALUES (%s)", (player_name,))
            conn.commit()
        except psycopg2.errors.UniqueViolation:
            # Logger.warn is a deprecated alias; warning() is the supported name.
            logger.warning("Player name {} already exists in the database, aborting insertion.".format(player_name))
            # Roll back so the shared connection is usable again after the failed statement.
            conn.rollback()
|
|
|
|
|
|
|
|
|
|
|
|
def add_user_name_to_player(hanabi_username: str, player_name: str):
    """
    Associate a Hanabi username with an already registered player and commit.

    The username is stored together with its normalized form (utils.normalize_username).
    Nothing is inserted if
      - the player does not exist (error logged), or
      - the normalized username is already registered, either to this player (warning logged)
        or to a different player (error logged).
    """
    normalized_username = utils.normalize_username(hanabi_username)
    user_id = get_user_id(player_name)
    if user_id is None:
        logger.error("Player {} not found in database, cannot add username to it.".format(player_name))
        return

    with conn_manager.get_new_cursor() as cur:
        # Collisions are detected on the normalized form, not on the raw username.
        cur.execute("SELECT username, player_name FROM user_accounts "
                    "INNER JOIN users"
                    " ON user_accounts.user_id = users.id "
                    "WHERE "
                    " normalized_username = (%s)",
                    (normalized_username,)
                    )
        res = cur.fetchone()
        if res is not None:
            existing_username, existing_player_name = res
            if existing_player_name == player_name:
                logger.warning(
                    "Hanabi username {} is already registered to player {}, attempted to re-register it.".format(
                        existing_username, existing_player_name
                    ))
            else:
                logger.error(
                    "Hanabi username {} is already associated to player {}, cannot register it to player {}.".format(
                        existing_username, existing_player_name, player_name
                    ))
            return

        cur.execute(
            "INSERT INTO user_accounts (username, normalized_username, user_id) VALUES (%s, %s, %s)",
            (hanabi_username, normalized_username, user_id)
        )
    conn_manager.get_connection().commit()
|
|
|
|
|
|
|
|
|
2023-11-23 01:36:23 +01:00
|
|
|
def init_player_base_rating(player_name: str, base_rating: Optional[int] = None):
    """
    Insert the initial ratings of a registered player into user_base_ratings and commit.

    One row is inserted for each of the rating types 0 and 1, both with the same base rating.

    @param player_name: Name of an already registered player.
    @param base_rating: Initial rating; defaults to config.player_base_rating.
    @raise ValueError: If the player is not registered, or if base ratings for the player already exist.
    """
    config = config_manager.get_config()
    if base_rating is None:
        base_rating = config.player_base_rating

    user_id = get_user_id(player_name)
    if user_id is None:
        err_msg = "Cannot initialise base rating for player {}: No such registered player.".format(player_name)
        logger.error(err_msg)
        raise ValueError(err_msg)

    # NOTE(review): rating types 0 and 1 are hard-coded; their meaning is defined by the DB schema — confirm.
    vals = [(user_id, rating_type, base_rating) for rating_type in range(0, 2)]

    conn = conn_manager.get_connection()
    with conn.cursor() as cur:
        try:
            psycopg2.extras.execute_values(
                cur,
                "INSERT INTO user_base_ratings (user_id, type, rating) VALUES %s",
                vals
            )
            conn.commit()
        except psycopg2.errors.UniqueViolation as e:
            # Without a rollback the shared connection stays in an aborted transaction and every
            # later statement on it would fail with "current transaction is aborted".
            conn.rollback()
            err_msg = "Failed to initialize base ratings for player {}: Ratings already exist".format(player_name)
            logger.error("{}:\n{}".format(err_msg, e))
            raise ValueError(err_msg) from e
|
|
|
|
|
|
|
|
|
|
|
|
def add_player(user_name: str, player_name: str, base_rating: Optional[int] = None):
    """
    Convenience wrapper that fully sets up a new player.

    Steps, in order:
      1. register the player name,
      2. link the Hanabi account `user_name` to that player,
      3. initialize the player's base ratings (to `base_rating`, or the configured default if None).

    Each step handles its own errors as the called function does.
    """
    # The later steps look the player up by name, so it must be registered first.
    add_player_name(player_name)
    add_user_name_to_player(hanabi_username=user_name, player_name=player_name)
    init_player_base_rating(player_name, base_rating=base_rating)
|
2023-11-22 23:35:47 +01:00
|
|
|
|
|
|
|
|
|
|
|
def get_variant_ids():
    """
    Return the ids of all variants stored in the variants table, in the order the database yields them.
    """
    cursor = conn_manager.get_new_cursor()
    cursor.execute("SELECT id FROM variants")
    rows = cursor.fetchall()
    # Every row is a 1-tuple holding just the id.
    return [row[0] for row in rows]
|