Refactor database configuration to support PostgreSQL, add GIN index for operation_names, and enhance backup functionality
This commit is contained in:
parent
477bb753ae
commit
c41524e517
11 changed files with 250 additions and 74 deletions
|
|
@ -1,6 +1,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
import shutil
|
||||
import subprocess # noqa: S404
|
||||
from compression import zstd
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
|
@ -8,6 +11,7 @@ from typing import TYPE_CHECKING
|
|||
|
||||
from django.conf import settings
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.core.management.base import CommandError
|
||||
from django.db import connection as django_connection
|
||||
from django.utils import timezone
|
||||
|
||||
|
|
@ -35,32 +39,45 @@ class Command(BaseCommand):
|
|||
)
|
||||
|
||||
def handle(self, **options: str) -> None:
|
||||
"""Run the backup command and write a zstd SQL dump."""
|
||||
"""Run the backup command and write a zstd SQL dump.
|
||||
|
||||
Args:
|
||||
**options: Command-line options for output directory and filename prefix.
|
||||
|
||||
Raises:
|
||||
CommandError: When the database connection fails or pg_dump is not available.
|
||||
"""
|
||||
output_dir: Path = Path(options["output_dir"]).expanduser()
|
||||
prefix: str = str(options["prefix"]).strip() or "ttvdrops"
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Use Django's database connection to ensure we connect to the test DB during tests
|
||||
django_connection.ensure_connection()
|
||||
connection = django_connection.connection
|
||||
if connection is None:
|
||||
# Force connection if not already established
|
||||
django_connection.ensure_connection()
|
||||
connection = django_connection.connection
|
||||
msg = "Database connection could not be established."
|
||||
raise CommandError(msg)
|
||||
|
||||
timestamp: str = timezone.localtime(timezone.now()).strftime("%Y%m%d-%H%M%S")
|
||||
output_path: Path = output_dir / f"{prefix}-{timestamp}.sql.zst"
|
||||
|
||||
allowed_tables = _get_allowed_tables(connection, "twitch_")
|
||||
allowed_tables = _get_allowed_tables("twitch_")
|
||||
if not allowed_tables:
|
||||
self.stdout.write(self.style.WARNING("No twitch tables found to back up."))
|
||||
return
|
||||
|
||||
with (
|
||||
output_path.open("wb") as raw_handle,
|
||||
zstd.open(raw_handle, "w") as compressed,
|
||||
io.TextIOWrapper(compressed, encoding="utf-8") as handle,
|
||||
):
|
||||
_write_dump(handle, connection, allowed_tables)
|
||||
if django_connection.vendor == "postgresql":
|
||||
_write_postgres_dump(output_path, allowed_tables)
|
||||
elif django_connection.vendor == "sqlite":
|
||||
with (
|
||||
output_path.open("wb") as raw_handle,
|
||||
zstd.open(raw_handle, "w") as compressed,
|
||||
io.TextIOWrapper(compressed, encoding="utf-8") as handle,
|
||||
):
|
||||
_write_sqlite_dump(handle, connection, allowed_tables)
|
||||
else:
|
||||
msg = f"Unsupported database backend: {django_connection.vendor}"
|
||||
raise CommandError(msg)
|
||||
|
||||
created_at: datetime = datetime.fromtimestamp(output_path.stat().st_mtime, tz=timezone.get_current_timezone())
|
||||
self.stdout.write(
|
||||
|
|
@ -71,24 +88,30 @@ class Command(BaseCommand):
|
|||
self.stdout.write(self.style.SUCCESS(f"Included tables: {len(allowed_tables)}"))
|
||||
|
||||
|
||||
def _get_allowed_tables(prefix: str) -> list[str]:
    """Fetch table names that match the allowed prefix.

    Uses Django's default database connection and dispatches on the backend
    vendor, so the same helper works for both PostgreSQL and SQLite.

    Args:
        prefix: Table name prefix to include (e.g. ``"twitch_"``).

    Returns:
        Alphabetically sorted list of table names starting with ``prefix``.
    """
    with django_connection.cursor() as cursor:
        if django_connection.vendor == "postgresql":
            # pg_tables only lists real tables; restrict to the default schema.
            cursor.execute(
                "SELECT tablename FROM pg_tables WHERE schemaname = 'public' AND tablename LIKE %s ORDER BY tablename",
                [f"{prefix}%"],
            )
        else:
            cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE ? ORDER BY name",
                (f"{prefix}%",),
            )
        return [row[0] for row in cursor.fetchall()]
|
||||
|
||||
|
||||
def _write_dump(handle: io.TextIOBase, connection: sqlite3.Connection, tables: list[str]) -> None:
|
||||
def _write_sqlite_dump(handle: io.TextIOBase, connection: sqlite3.Connection, tables: list[str]) -> None:
|
||||
"""Write a SQL dump containing schema and data for the requested tables.
|
||||
|
||||
Args:
|
||||
|
|
@ -163,6 +186,85 @@ def _write_indexes(handle: io.TextIOBase, connection: sqlite3.Connection, tables
|
|||
handle.write(f"{sql};\n")
|
||||
|
||||
|
||||
def _write_postgres_dump(output_path: Path, tables: list[str]) -> None:
    """Write a PostgreSQL dump using pg_dump into a zstd-compressed file.

    Streams pg_dump's stdout straight into the zstd compressor so the dump
    is never held fully in memory.

    Args:
        output_path: Destination path for the zstd file.
        tables: Table names to include (assumed to live in the ``public`` schema).

    Raises:
        CommandError: When pg_dump fails or is not found.
    """
    pg_dump_path = shutil.which("pg_dump")
    if not pg_dump_path:
        msg = "pg_dump was not found. Install PostgreSQL client tools and retry."
        raise CommandError(msg)

    settings_dict = django_connection.settings_dict
    env = os.environ.copy()
    password = settings_dict.get("PASSWORD")
    if password:
        # pg_dump reads credentials from the environment; never pass them on argv.
        env["PGPASSWORD"] = str(password)

    cmd = [
        pg_dump_path,
        "--format=plain",
        "--no-owner",
        "--no-privileges",
        "--clean",
        "--if-exists",
        "--column-inserts",
        "--quote-all-identifiers",
        "--encoding=UTF8",
        "--dbname",
        str(settings_dict.get("NAME", "")),
    ]

    host = settings_dict.get("HOST")
    port = settings_dict.get("PORT")
    user = settings_dict.get("USER")

    if host:
        cmd.extend(["--host", str(host)])
    if port:
        cmd.extend(["--port", str(port)])
    if user:
        cmd.extend(["--username", str(user)])

    for table in tables:
        cmd.extend(["-t", f"public.{table}"])

    process = subprocess.Popen(  # noqa: S603
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=env,
    )
    # Single guard (the original repeated this check verbatim; the second
    # copy was unreachable). Both pipes must exist before we stream.
    if process.stdout is None or process.stderr is None:
        process.kill()
        msg = "Failed to start pg_dump process."
        raise CommandError(msg)

    with (
        output_path.open("wb") as raw_handle,
        zstd.open(raw_handle, "w") as compressed,
    ):
        # Stream in 64 KiB chunks so large dumps never sit fully in memory.
        for chunk in iter(lambda: process.stdout.read(64 * 1024), b""):  # pyright: ignore[reportOptionalMemberAccess]
            compressed.write(chunk)

    # NOTE(review): stderr is drained only after stdout is exhausted; if pg_dump
    # fills the stderr pipe buffer first this could stall — confirm with a dump
    # that emits heavy warnings.
    stderr_output = process.stderr.read().decode("utf-8", errors="replace")
    return_code = process.wait()
    if return_code != 0:
        msg = f"pg_dump failed with exit code {return_code}: {stderr_output.strip()}"
        raise CommandError(msg)
|
||||
|
||||
|
||||
def _sql_literal(value: object) -> str:
|
||||
"""Convert a Python value to a SQL literal.
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue