2023-11-22 15:31:36 +01:00
|
|
|
import psycopg2
|
2023-11-22 16:00:26 +01:00
|
|
|
import psycopg2.extensions
|
2023-11-22 15:31:36 +01:00
|
|
|
|
2023-11-22 16:00:26 +01:00
|
|
|
import constants
|
2023-11-22 15:31:36 +01:00
|
|
|
from config import read_db_config
|
|
|
|
from log_setup import logger
|
|
|
|
|
|
|
|
|
|
|
|
class DBConnectionManager:
    """
    Lazily creates and caches a single psycopg2 connection to the configured database.
    """

    def __init__(self):
        # psycopg2 connection object, or None until connect() has been called.
        self._conn = None

    def connect(self):
        """
        (Re)connect to the database using the settings returned by read_db_config().

        If this manager already holds an open connection, it is closed first so that
        calling connect() repeatedly does not leak connections.
        """
        config = read_db_config()

        if self._conn is not None and not self._conn.closed:
            self._conn.close()

        logger.debug("Establishing database connection.")
        # Pass the settings as keyword arguments so psycopg2/libpq handle quoting.
        # Formatting them into a "key='value'" DSN string breaks as soon as a value
        # (typically the password) contains a single quote or a backslash.
        self._conn = psycopg2.connect(
            dbname=config.db_name,
            user=config.db_user,
            password=config.db_pass,
            host="localhost",
        )
        logger.debug("Established database connection.")

    def get_connection(self) -> psycopg2.extensions.connection:
        """
        Get the database connection.

        If not connected yet, or if the held connection has been closed
        (``connection.closed`` is non-zero), this reads the database config file
        and connects to the DB. Otherwise, the already active connection is returned.
        """
        if self._conn is None or self._conn.closed:
            self.connect()
        return self._conn
|
|
|
|
|
|
|
|
|
|
|
|
# Module-wide shared manager. Creating it does not touch the database; the
# connection is only opened on the first get_connection() call.
conn_manager: DBConnectionManager = DBConnectionManager()
|
|
|
|
|
|
|
|
|
|
|
|
def get_existing_tables():
    """
    Return which of the tables listed in ``constants.DB_TABLE_NAMES`` already exist.

    Only tables in the ``public`` schema are considered (looked up in ``pg_tables``).

    Returns:
        A list of the matching table names; empty if none exist or if
        ``constants.DB_TABLE_NAMES`` is empty.
    """
    table_names = tuple(constants.DB_TABLE_NAMES)
    if not table_names:
        # "IN ()" is a SQL syntax error, and there is nothing to look up anyway.
        return []

    conn = conn_manager.get_connection()
    with conn.cursor() as cur:
        # psycopg2 adapts a tuple parameter to a parenthesised, properly quoted
        # list, so the names are never spliced into the SQL text by hand.
        cur.execute(
            "SELECT tablename FROM pg_tables WHERE"
            " schemaname = 'public' AND"
            " tablename IN %s",
            (table_names,),
        )
        return [table for (table,) in cur.fetchall()]
|
|
|
|
|
|
|
|
|
|
|
|
def init_database(erase: bool = False):
    """
    Create the database tables by executing the SQL schema file.

    Args:
        erase: If False (the default) and any table from ``constants.DB_TABLE_NAMES``
            already exists, an error is logged and nothing is executed. If True, the
            schema script at ``constants.DATABASE_SCHEMA_PATH`` is run regardless.
            NOTE(review): whether existing tables/data are actually dropped depends
            on that script's contents; confirm.

    Raises:
        Any error raised while reading or executing the schema script. In the
        execution case, the transaction is rolled back first.
    """
    tables = get_existing_tables()

    if not erase and tables:
        logger.error("Aborting database initialization: Tables {} already exist".format(", ".join(tables)))
        return

    conn = conn_manager.get_connection()

    with open(constants.DATABASE_SCHEMA_PATH, "r", encoding="utf-8") as f:
        schema_sql = f.read()

    try:
        with conn.cursor() as cur:
            cur.execute(schema_sql)
    except Exception:
        # A failed statement leaves the (shared, global) connection in an aborted
        # transaction; roll back so later users of conn_manager can still use it.
        conn.rollback()
        raise

    conn.commit()
    logger.verbose("Initialized DB tables.")
|