Refactor database configuration to support PostgreSQL, add GIN index for operation_names, and enhance backup functionality

This commit is contained in:
Joakim Hellsén 2026-02-13 23:27:18 +01:00
commit c41524e517
Signed by: Joakim Hellsén
SSH key fingerprint: SHA256:/9h/CsExpFp+PRhsfA0xznFx2CGfTT5R/kpuFfUgEQk
11 changed files with 250 additions and 74 deletions

View file

@ -1,6 +1,9 @@
from __future__ import annotations
import io
import os
import shutil
import subprocess # noqa: S404
from compression import zstd
from datetime import datetime
from pathlib import Path
@ -8,6 +11,7 @@ from typing import TYPE_CHECKING
from django.conf import settings
from django.core.management.base import BaseCommand
from django.core.management.base import CommandError
from django.db import connection as django_connection
from django.utils import timezone
@ -35,32 +39,45 @@ class Command(BaseCommand):
)
def handle(self, **options: str) -> None:
"""Run the backup command and write a zstd SQL dump."""
"""Run the backup command and write a zstd SQL dump.
Args:
**options: Command-line options for output directory and filename prefix.
Raises:
CommandError: When the database connection fails or pg_dump is not available.
"""
output_dir: Path = Path(options["output_dir"]).expanduser()
prefix: str = str(options["prefix"]).strip() or "ttvdrops"
output_dir.mkdir(parents=True, exist_ok=True)
# Use Django's database connection to ensure we connect to the test DB during tests
django_connection.ensure_connection()
connection = django_connection.connection
if connection is None:
# Force connection if not already established
django_connection.ensure_connection()
connection = django_connection.connection
msg = "Database connection could not be established."
raise CommandError(msg)
timestamp: str = timezone.localtime(timezone.now()).strftime("%Y%m%d-%H%M%S")
output_path: Path = output_dir / f"{prefix}-{timestamp}.sql.zst"
allowed_tables = _get_allowed_tables(connection, "twitch_")
allowed_tables = _get_allowed_tables("twitch_")
if not allowed_tables:
self.stdout.write(self.style.WARNING("No twitch tables found to back up."))
return
with (
output_path.open("wb") as raw_handle,
zstd.open(raw_handle, "w") as compressed,
io.TextIOWrapper(compressed, encoding="utf-8") as handle,
):
_write_dump(handle, connection, allowed_tables)
if django_connection.vendor == "postgresql":
_write_postgres_dump(output_path, allowed_tables)
elif django_connection.vendor == "sqlite":
with (
output_path.open("wb") as raw_handle,
zstd.open(raw_handle, "w") as compressed,
io.TextIOWrapper(compressed, encoding="utf-8") as handle,
):
_write_sqlite_dump(handle, connection, allowed_tables)
else:
msg = f"Unsupported database backend: {django_connection.vendor}"
raise CommandError(msg)
created_at: datetime = datetime.fromtimestamp(output_path.stat().st_mtime, tz=timezone.get_current_timezone())
self.stdout.write(
@ -71,24 +88,30 @@ class Command(BaseCommand):
self.stdout.write(self.style.SUCCESS(f"Included tables: {len(allowed_tables)}"))
def _get_allowed_tables(connection: sqlite3.Connection, prefix: str) -> list[str]:
def _get_allowed_tables(prefix: str) -> list[str]:
"""Fetch table names that match the allowed prefix.
Args:
connection: SQLite connection.
prefix: Table name prefix to include.
Returns:
List of table names.
"""
cursor = connection.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE ? ORDER BY name",
(f"{prefix}%",),
)
return [row[0] for row in cursor.fetchall()]
with django_connection.cursor() as cursor:
if django_connection.vendor == "postgresql":
cursor.execute(
"SELECT tablename FROM pg_tables WHERE schemaname = 'public' AND tablename LIKE %s ORDER BY tablename",
[f"{prefix}%"],
)
else:
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE ? ORDER BY name",
(f"{prefix}%",),
)
return [row[0] for row in cursor.fetchall()]
def _write_dump(handle: io.TextIOBase, connection: sqlite3.Connection, tables: list[str]) -> None:
def _write_sqlite_dump(handle: io.TextIOBase, connection: sqlite3.Connection, tables: list[str]) -> None:
"""Write a SQL dump containing schema and data for the requested tables.
Args:
@ -163,6 +186,85 @@ def _write_indexes(handle: io.TextIOBase, connection: sqlite3.Connection, tables
handle.write(f"{sql};\n")
def _write_postgres_dump(output_path: Path, tables: list[str]) -> None:
"""Write a PostgreSQL dump using pg_dump into a zstd-compressed file.
Args:
output_path: Destination path for the zstd file.
tables: Table names to include.
Raises:
CommandError: When pg_dump fails or is not found.
"""
pg_dump_path = shutil.which("pg_dump")
if not pg_dump_path:
msg = "pg_dump was not found. Install PostgreSQL client tools and retry."
raise CommandError(msg)
settings_dict = django_connection.settings_dict
env = os.environ.copy()
password = settings_dict.get("PASSWORD")
if password:
env["PGPASSWORD"] = str(password)
cmd = [
pg_dump_path,
"--format=plain",
"--no-owner",
"--no-privileges",
"--clean",
"--if-exists",
"--column-inserts",
"--quote-all-identifiers",
"--encoding=UTF8",
"--dbname",
str(settings_dict.get("NAME", "")),
]
host = settings_dict.get("HOST")
port = settings_dict.get("PORT")
user = settings_dict.get("USER")
if host:
cmd.extend(["--host", str(host)])
if port:
cmd.extend(["--port", str(port)])
if user:
cmd.extend(["--username", str(user)])
for table in tables:
cmd.extend(["-t", f"public.{table}"])
process = subprocess.Popen( # noqa: S603
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env,
)
if process.stdout is None or process.stderr is None:
process.kill()
msg = "Failed to start pg_dump process."
raise CommandError(msg)
if process.stdout is None or process.stderr is None:
process.kill()
msg = "pg_dump process did not provide stdout or stderr."
raise CommandError(msg)
with (
output_path.open("wb") as raw_handle,
zstd.open(raw_handle, "w") as compressed,
):
for chunk in iter(lambda: process.stdout.read(64 * 1024), b""): # pyright: ignore[reportOptionalMemberAccess]
compressed.write(chunk)
stderr_output = process.stderr.read().decode("utf-8", errors="replace")
return_code = process.wait()
if return_code != 0:
msg = f"pg_dump failed with exit code {return_code}: {stderr_output.strip()}"
raise CommandError(msg)
def _sql_literal(value: object) -> str:
"""Convert a Python value to a SQL literal.