hanabi-league/database.py

161 lines
5.4 KiB
Python

import json
import psycopg2
import psycopg2.extensions
import psycopg2.errors
import requests
import unidecode
import constants
from config import config_manager
from log_setup import logger
class DBConnectionManager:
    """
    Lazily opens and caches a single psycopg2 connection.

    The connection is created on first use (get_connection / get_new_cursor) from the
    settings returned by config_manager.get_db_config(), and the same connection object
    is handed out on every later call.
    """

    def __init__(self):
        # No connection is opened at construction time; see get_connection().
        self._conn = None

    def connect(self):
        """
        Open a new connection to the database on localhost and store it on this manager.

        Any previously stored connection is replaced (it is not closed here).
        """
        config = config_manager.get_db_config()
        # Do not write the password to the log.
        logger.debug("Establishing database connection with dbname={}, user={}".format(
            config.db_name, config.db_user
        ))
        # Keyword arguments let psycopg2 quote each value itself, so credentials that
        # contain quotes or spaces cannot break (or inject into) the DSN string.
        self._conn = psycopg2.connect(
            dbname=config.db_name,
            user=config.db_user,
            password=config.db_pass,
            host='localhost',
        )
        logger.debug("Established database connection.")

    def get_connection(self) -> psycopg2.extensions.connection:
        """
        Get the database connection.

        If not already connected, this fetches the DB settings from config_manager and
        connects. Otherwise, the already active connection is returned.
        """
        if self._conn is None:
            self.connect()
        return self._conn

    def get_new_cursor(self) -> psycopg2.extensions.cursor:
        """Return a new cursor on the shared connection, connecting first if needed."""
        return self.get_connection().cursor()
# Module-wide connection manager: the database helpers in this module obtain their
# connection and cursors from this one instance, so they all share a single connection.
conn_manager: DBConnectionManager = DBConnectionManager()
def get_existing_tables():
    """
    Return which of the tables named in constants.DB_TABLE_NAMES exist in the
    'public' schema.

    Names in DB_TABLE_NAMES that have no matching table are simply absent from the
    returned list. An empty DB_TABLE_NAMES yields an empty list.
    """
    conn = conn_manager.get_connection()
    with conn.cursor() as cur:
        # Bind the names as one array parameter instead of splicing hand-quoted
        # literals into the SQL; unlike "IN ()", this also stays valid when empty.
        cur.execute(
            "SELECT tablename FROM pg_tables WHERE"
            " schemaname = 'public' AND"
            " tablename = ANY(%s)",
            (list(constants.DB_TABLE_NAMES),)
        )
        return [table for (table,) in cur.fetchall()]
def init_database():
    """
    Execute the SQL script at constants.DATABASE_SCHEMA_PATH and commit it.

    Warning: This drops all existing tables from the database
    (NOTE(review): the actual effect is whatever the schema script contains).

    If the script raises a database error, the transaction is rolled back before the
    error is re-raised. This keeps the shared connection usable instead of stuck in an
    aborted transaction.
    """
    conn = conn_manager.get_connection()
    with open(constants.DATABASE_SCHEMA_PATH, "r", encoding="utf-8") as f:
        schema_sql = f.read()
    try:
        with conn.cursor() as cur:
            cur.execute(schema_sql)
    except psycopg2.Error:
        conn.rollback()
        raise
    conn.commit()
    logger.verbose("Initialized DB tables.")
def fetch_and_initialize_variants(timeout: float = 30):
    """
    Download the variant list from constants.VARIANTS_JSON_URL and insert the selected
    variants into the ``variants`` table, committing once at the end.

    A variant is inserted only if both of these hold:
    - its number of suits lies within
      [config.min_suit_count, config.max_suit_count] (inclusive);
    - its lowercased name contains none of config.excluded_variants as a substring.
      The excluded entries are compared as given, so presumably they should be
      lowercase.

    :param timeout: Seconds to wait for the HTTP request. Without a timeout,
        requests.get may block indefinitely.

    If the download fails (network error or non-200 status), the failure is logged and
    the function returns without touching the database. If an INSERT fails, the
    transaction is rolled back and the database error is re-raised.
    """
    try:
        response = requests.get(constants.VARIANTS_JSON_URL, timeout=timeout)
    except requests.RequestException as err:
        logger.error("Could not download variants.json file from github (tried url {}): {}".format(
            constants.VARIANTS_JSON_URL, err
        ))
        return
    if response.status_code != 200:
        logger.error("Could not download variants.json file from github (tried url {})".format(constants.VARIANTS_JSON_URL))
        return
    variants = json.loads(response.text)
    config = config_manager.get_config()

    rows = []
    for variant in variants:
        variant_id = variant['id']
        name = variant['name']
        clue_starved = variant.get('clueStarved', False)
        num_suits = len(variant['suits'])
        if not config.min_suit_count <= num_suits <= config.max_suit_count:
            continue
        lowered_name = name.lower()
        if any(var_name in lowered_name for var_name in config.excluded_variants):
            continue
        rows.append((variant_id, name, num_suits, clue_starved))

    conn = conn_manager.get_connection()
    try:
        # One cursor for the whole batch instead of a new cursor per variant.
        with conn.cursor() as cur:
            cur.executemany(
                "INSERT INTO variants (id, name, num_suits, clue_starved) VALUES (%s, %s, %s, %s)",
                rows
            )
    except psycopg2.Error:
        # Keep the shared connection usable rather than stuck in an aborted transaction.
        conn.rollback()
        raise
    conn.commit()
def normalize_username(username: str) -> str:
    """
    Return the canonical form of ``username`` used to detect duplicate accounts.

    The name is transliterated to ASCII with unidecode and then lowercased.
    """
    return unidecode.unidecode(username).lower()
def add_player_name(player_name: str):
    """
    Insert a new player into the ``users`` table and commit.

    If the insert raises psycopg2's UniqueViolation (the player name is already
    present), a warning is logged and the transaction is rolled back so the shared
    connection stays usable. No exception is raised in that case.
    """
    conn = conn_manager.get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute("INSERT INTO users (player_name) VALUES (%s)", (player_name,))
        # The commit stays inside the try so a UniqueViolation raised at commit time
        # (e.g. by a deferred constraint) is handled the same way.
        conn.commit()
    except psycopg2.errors.UniqueViolation:
        logger.warning("Player name {} already exists in the database, aborting insertion.".format(player_name))
        conn.rollback()
def add_user_name_to_player(hanabi_username: str, player_name: str):
    """
    Link a hanabi username to an existing player and commit.

    The username is stored both verbatim and in normalized form (normalize_username:
    unidecode transliteration + lowercasing). Duplicate detection compares the
    normalized forms.

    Nothing is inserted, and a message is logged, when:
    - no ``users`` row has this ``player_name`` (error);
    - the normalized username is already linked to a player. This logs a warning if
      it is the same player, and an error otherwise.
    """
    normalized_username = normalize_username(hanabi_username)
    conn = conn_manager.get_connection()
    with conn.cursor() as cur:
        cur.execute("SELECT id FROM users WHERE player_name = (%s)", (player_name,))
        user_row = cur.fetchone()
        if user_row is None:
            logger.error("Display name {} not found in database, cannot add username to it.".format(player_name))
            return
        # fetchone() returns a 1-tuple; pass the bare id to the INSERT, not the tuple.
        (user_id,) = user_row

        cur.execute("SELECT username, player_name from user_accounts "
                    "INNER JOIN users"
                    " ON user_accounts.user_id = users.id "
                    "WHERE "
                    " normalized_username = (%s)",
                    (normalized_username,)
                    )
        existing = cur.fetchone()
        if existing is not None:
            existing_username, existing_player_name = existing
            if existing_player_name == player_name:
                logger.warning("Hanabi username {} is already registered to player {}, attempted to re-register it.".format(
                    existing_username, existing_player_name
                ))
            else:
                logger.error("Hanabi username {} is already associated to player {}, cannot register it to player {}.".format(
                    existing_username, existing_player_name, player_name
                ))
            return

        cur.execute(
            "INSERT INTO user_accounts (username, normalized_username, user_id) VALUES (%s, %s, %s)",
            (hanabi_username, normalized_username, user_id)
        )
    conn.commit()
def add_player(player_name: str, user_name: str):
    """
    Create a player and attach a hanabi username to them.

    Convenience wrapper: first calls add_player_name(player_name), then links
    ``user_name`` to that player via add_user_name_to_player.
    """
    add_player_name(player_name)
    add_user_name_to_player(hanabi_username=user_name, player_name=player_name)
def get_variant_ids():
    """Return the ``id`` column of every row in the ``variants`` table, as a list."""
    cursor = conn_manager.get_new_cursor()
    cursor.execute("SELECT id FROM variants")
    rows = cursor.fetchall()
    ids = []
    for row in rows:
        (variant_id,) = row
        ids.append(variant_id)
    return ids