Refactor database configuration to support PostgreSQL, add GIN index for operation_names, and enhance backup functionality

This commit is contained in:
Joakim Hellsén 2026-02-13 23:27:18 +01:00
commit c41524e517
Signed by: Joakim Hellsén
SSH key fingerprint: SHA256:/9h/CsExpFp+PRhsfA0xznFx2CGfTT5R/kpuFfUgEQk
11 changed files with 250 additions and 74 deletions

View file

@ -1,6 +1,9 @@
from __future__ import annotations
import io
import os
import shutil
import subprocess # noqa: S404
from compression import zstd
from datetime import datetime
from pathlib import Path
@ -8,6 +11,7 @@ from typing import TYPE_CHECKING
from django.conf import settings
from django.core.management.base import BaseCommand
from django.core.management.base import CommandError
from django.db import connection as django_connection
from django.utils import timezone
@ -35,32 +39,45 @@ class Command(BaseCommand):
)
def handle(self, **options: str) -> None:
"""Run the backup command and write a zstd SQL dump."""
"""Run the backup command and write a zstd SQL dump.
Args:
**options: Command-line options for output directory and filename prefix.
Raises:
CommandError: When the database connection fails or pg_dump is not available.
"""
output_dir: Path = Path(options["output_dir"]).expanduser()
prefix: str = str(options["prefix"]).strip() or "ttvdrops"
output_dir.mkdir(parents=True, exist_ok=True)
# Use Django's database connection to ensure we connect to the test DB during tests
django_connection.ensure_connection()
connection = django_connection.connection
if connection is None:
# Force connection if not already established
django_connection.ensure_connection()
connection = django_connection.connection
msg = "Database connection could not be established."
raise CommandError(msg)
timestamp: str = timezone.localtime(timezone.now()).strftime("%Y%m%d-%H%M%S")
output_path: Path = output_dir / f"{prefix}-{timestamp}.sql.zst"
allowed_tables = _get_allowed_tables(connection, "twitch_")
allowed_tables = _get_allowed_tables("twitch_")
if not allowed_tables:
self.stdout.write(self.style.WARNING("No twitch tables found to back up."))
return
with (
output_path.open("wb") as raw_handle,
zstd.open(raw_handle, "w") as compressed,
io.TextIOWrapper(compressed, encoding="utf-8") as handle,
):
_write_dump(handle, connection, allowed_tables)
if django_connection.vendor == "postgresql":
_write_postgres_dump(output_path, allowed_tables)
elif django_connection.vendor == "sqlite":
with (
output_path.open("wb") as raw_handle,
zstd.open(raw_handle, "w") as compressed,
io.TextIOWrapper(compressed, encoding="utf-8") as handle,
):
_write_sqlite_dump(handle, connection, allowed_tables)
else:
msg = f"Unsupported database backend: {django_connection.vendor}"
raise CommandError(msg)
created_at: datetime = datetime.fromtimestamp(output_path.stat().st_mtime, tz=timezone.get_current_timezone())
self.stdout.write(
@ -71,24 +88,30 @@ class Command(BaseCommand):
self.stdout.write(self.style.SUCCESS(f"Included tables: {len(allowed_tables)}"))
def _get_allowed_tables(prefix: str) -> list[str]:
    """Fetch table names that match the allowed prefix.

    Queries through Django's default connection so the lookup targets the
    active (possibly test) database, on either supported backend.

    Args:
        prefix: Table name prefix to include.

    Returns:
        Alphabetically sorted list of matching table names.
    """
    with django_connection.cursor() as cursor:
        if django_connection.vendor == "postgresql":
            # pg_tables only lists real tables; restrict to the public schema.
            cursor.execute(
                "SELECT tablename FROM pg_tables WHERE schemaname = 'public' AND tablename LIKE %s ORDER BY tablename",
                [f"{prefix}%"],
            )
        else:
            # SQLite keeps its catalog in sqlite_master.
            cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE ? ORDER BY name",
                (f"{prefix}%",),
            )
        return [row[0] for row in cursor.fetchall()]
def _write_dump(handle: io.TextIOBase, connection: sqlite3.Connection, tables: list[str]) -> None:
def _write_sqlite_dump(handle: io.TextIOBase, connection: sqlite3.Connection, tables: list[str]) -> None:
"""Write a SQL dump containing schema and data for the requested tables.
Args:
@ -163,6 +186,85 @@ def _write_indexes(handle: io.TextIOBase, connection: sqlite3.Connection, tables
handle.write(f"{sql};\n")
def _write_postgres_dump(output_path: Path, tables: list[str]) -> None:
    """Write a PostgreSQL dump using pg_dump into a zstd-compressed file.

    Args:
        output_path: Destination path for the zstd file.
        tables: Table names to include.

    Raises:
        CommandError: When pg_dump fails or is not found.
    """
    pg_dump_path = shutil.which("pg_dump")
    if not pg_dump_path:
        msg = "pg_dump was not found. Install PostgreSQL client tools and retry."
        raise CommandError(msg)

    settings_dict = django_connection.settings_dict
    env = os.environ.copy()
    password = settings_dict.get("PASSWORD")
    if password:
        # pg_dump reads PGPASSWORD from the environment; never pass it on argv
        # where other users could see it in the process list.
        env["PGPASSWORD"] = str(password)

    cmd = [
        pg_dump_path,
        "--format=plain",
        "--no-owner",
        "--no-privileges",
        "--clean",
        "--if-exists",
        "--column-inserts",
        "--quote-all-identifiers",
        "--encoding=UTF8",
        "--dbname",
        str(settings_dict.get("NAME", "")),
    ]
    host = settings_dict.get("HOST")
    port = settings_dict.get("PORT")
    user = settings_dict.get("USER")
    if host:
        cmd.extend(["--host", str(host)])
    if port:
        cmd.extend(["--port", str(port)])
    if user:
        cmd.extend(["--username", str(user)])
    for table in tables:
        # Explicit schema qualification so only public-schema tables match.
        cmd.extend(["-t", f"public.{table}"])

    process = subprocess.Popen(  # noqa: S603
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=env,
    )
    # Bind the pipes to locals once; the previous duplicated None-check was
    # unreachable dead code, and locals let type checkers narrow Optional.
    stdout_pipe = process.stdout
    stderr_pipe = process.stderr
    if stdout_pipe is None or stderr_pipe is None:
        process.kill()
        msg = "Failed to start pg_dump process."
        raise CommandError(msg)

    with (
        output_path.open("wb") as raw_handle,
        zstd.open(raw_handle, "w") as compressed,
    ):
        # Stream in 64 KiB chunks so large dumps never sit fully in memory.
        for chunk in iter(lambda: stdout_pipe.read(64 * 1024), b""):
            compressed.write(chunk)

    # NOTE(review): stderr is drained only after stdout hits EOF; this assumes
    # pg_dump emits little stderr output (a full stderr pipe would stall it).
    stderr_output = stderr_pipe.read().decode("utf-8", errors="replace")
    return_code = process.wait()
    if return_code != 0:
        msg = f"pg_dump failed with exit code {return_code}: {stderr_output.strip()}"
        raise CommandError(msg)
def _sql_literal(value: object) -> str:
"""Convert a Python value to a SQL literal.

View file

@ -0,0 +1,24 @@
# Generated by Django 6.0.2 on 2026-02-12 12:00
from __future__ import annotations
from django.contrib.postgres.indexes import GinIndex
from django.db import migrations
class Migration(migrations.Migration):
    """Replace the JSONField btree index with a GIN index for Postgres."""

    # Runs after the previous twitch migration that touched these models.
    dependencies = [
        ("twitch", "0011_dropbenefit_image_height_dropbenefit_image_width_and_more"),
    ]

    operations = [
        # Drop the old default (btree) index on the operation_names JSONField.
        migrations.RemoveIndex(
            model_name="dropcampaign",
            name="twitch_drop_operati_fe3bc8_idx",
        ),
        # GIN indexes serve JSON containment lookups (operation_names__contains).
        migrations.AddIndex(
            model_name="dropcampaign",
            index=GinIndex(fields=["operation_names"], name="twitch_drop_operati_gin_idx"),
        ),
    ]

View file

@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
import auto_prefetch
from django.contrib.humanize.templatetags.humanize import naturaltime
from django.contrib.postgres.indexes import GinIndex
from django.db import models
from django.urls import reverse
from django.utils import timezone
@ -420,7 +421,7 @@ class DropCampaign(auto_prefetch.Model):
models.Index(fields=["name"]),
models.Index(fields=["description"]),
models.Index(fields=["allow_is_enabled"]),
models.Index(fields=["operation_names"]),
GinIndex(fields=["operation_names"], name="twitch_drop_operati_gin_idx"),
models.Index(fields=["added_at"]),
models.Index(fields=["updated_at"]),
# Composite indexes for common queries

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import io
import math
import os
import shutil
from compression import zstd
from typing import TYPE_CHECKING
@ -14,7 +15,8 @@ from django.urls import reverse
from twitch.management.commands.backup_db import _get_allowed_tables
from twitch.management.commands.backup_db import _sql_literal
from twitch.management.commands.backup_db import _write_dump
from twitch.management.commands.backup_db import _write_postgres_dump
from twitch.management.commands.backup_db import _write_sqlite_dump
from twitch.models import Game
from twitch.models import Organization
@ -25,12 +27,18 @@ if TYPE_CHECKING:
from django.test.client import _MonkeyPatchedWSGIResponse
def _skip_if_pg_dump_missing() -> None:
    """Skip the current test when Postgres is active but pg_dump is absent."""
    running_postgres = connection.vendor == "postgresql"
    if running_postgres and shutil.which("pg_dump") is None:
        pytest.skip("pg_dump is not available")
@pytest.mark.django_db
class TestBackupCommand:
"""Tests for the backup_db management command."""
def test_backup_creates_file(self, tmp_path: Path) -> None:
"""Test that backup command creates a zstd compressed file."""
_skip_if_pg_dump_missing()
# Create test data so tables exist
Organization.objects.create(twitch_id="test000", name="Test Org")
@ -46,6 +54,7 @@ class TestBackupCommand:
def test_backup_contains_sql_content(self, tmp_path: Path) -> None:
"""Test that backup file contains valid SQL content."""
_skip_if_pg_dump_missing()
output_dir = tmp_path / "backups"
output_dir.mkdir()
@ -66,15 +75,20 @@ class TestBackupCommand:
):
content = handle.read()
assert "PRAGMA foreign_keys=OFF;" in content
assert "BEGIN TRANSACTION;" in content
assert "COMMIT;" in content
if connection.vendor == "postgresql":
assert "CREATE TABLE" in content
assert "INSERT INTO" in content
else:
assert "PRAGMA foreign_keys=OFF;" in content
assert "BEGIN TRANSACTION;" in content
assert "COMMIT;" in content
assert "twitch_organization" in content
assert "twitch_game" in content
assert "Test Org" in content
def test_backup_excludes_non_twitch_tables(self, tmp_path: Path) -> None:
"""Test that backup only includes twitch_ prefixed tables."""
_skip_if_pg_dump_missing()
# Create test data so tables exist
Organization.objects.create(twitch_id="test001", name="Test Org")
@ -103,6 +117,7 @@ class TestBackupCommand:
def test_backup_with_custom_prefix(self, tmp_path: Path) -> None:
"""Test that custom prefix is used in filename."""
_skip_if_pg_dump_missing()
# Create test data so tables exist
Organization.objects.create(twitch_id="test002", name="Test Org")
@ -116,6 +131,7 @@ class TestBackupCommand:
def test_backup_creates_output_directory(self, tmp_path: Path) -> None:
"""Test that backup command creates output directory if missing."""
_skip_if_pg_dump_missing()
# Create test data so tables exist
Organization.objects.create(twitch_id="test003", name="Test Org")
@ -128,6 +144,7 @@ class TestBackupCommand:
def test_backup_uses_default_directory(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that backup uses DATA_DIR/datasets by default."""
_skip_if_pg_dump_missing()
# Create test data so tables exist
Organization.objects.create(twitch_id="test004", name="Test Org")
@ -148,8 +165,7 @@ class TestBackupHelperFunctions:
def test_get_allowed_tables_filters_by_prefix(self) -> None:
"""Test that _get_allowed_tables returns only matching tables."""
# Use Django's connection to access the test database
db_connection = connection.connection
tables = _get_allowed_tables(db_connection, "twitch_")
tables = _get_allowed_tables("twitch_")
assert len(tables) > 0
assert all(table.startswith("twitch_") for table in tables)
@ -159,8 +175,7 @@ class TestBackupHelperFunctions:
def test_get_allowed_tables_excludes_non_matching(self) -> None:
"""Test that _get_allowed_tables excludes non-matching tables."""
# Use Django's connection to access the test database
db_connection = connection.connection
tables = _get_allowed_tables(db_connection, "twitch_")
tables = _get_allowed_tables("twitch_")
# Should not include django, silk, or debug toolbar tables
assert not any(table.startswith("django_") for table in tables)
@ -192,33 +207,41 @@ class TestBackupHelperFunctions:
assert _sql_literal(b"\x00\x01\x02") == "X'000102'"
assert _sql_literal(b"hello") == "X'68656c6c6f'"
def test_write_dump_includes_schema_and_data(self) -> None:
def test_write_dump_includes_schema_and_data(self, tmp_path: Path) -> None:
"""Test _write_dump writes complete SQL dump."""
# Create test data
Organization.objects.create(twitch_id="test789", name="Write Test Org")
# Use Django's connection to access the test database
db_connection = connection.connection
output = io.StringIO()
tables = _get_allowed_tables("twitch_")
tables = _get_allowed_tables(db_connection, "twitch_")
_write_dump(output, db_connection, tables)
content = output.getvalue()
# Check for SQL structure
assert "PRAGMA foreign_keys=OFF;" in content
assert "BEGIN TRANSACTION;" in content
assert "COMMIT;" in content
assert "PRAGMA foreign_keys=ON;" in content
# Check for schema
assert "CREATE TABLE" in content
assert "twitch_organization" in content
# Check for data
assert "INSERT INTO" in content
assert "Write Test Org" in content
if connection.vendor == "postgresql":
if not shutil.which("pg_dump"):
pytest.skip("pg_dump is not available")
output_path = tmp_path / "backup.sql.zst"
_write_postgres_dump(output_path, tables)
with (
output_path.open("rb") as raw_handle,
zstd.open(raw_handle, "r") as compressed,
io.TextIOWrapper(compressed, encoding="utf-8") as handle,
):
content = handle.read()
assert "CREATE TABLE" in content
assert "INSERT INTO" in content
assert "twitch_organization" in content
assert "Write Test Org" in content
else:
db_connection = connection.connection
output = io.StringIO()
_write_sqlite_dump(output, db_connection, tables)
content = output.getvalue()
assert "PRAGMA foreign_keys=OFF;" in content
assert "BEGIN TRANSACTION;" in content
assert "COMMIT;" in content
assert "PRAGMA foreign_keys=ON;" in content
assert "CREATE TABLE" in content
assert "twitch_organization" in content
assert "INSERT INTO" in content
assert "Write Test Org" in content
@pytest.mark.django_db

View file

@ -2,7 +2,6 @@ from __future__ import annotations
import json
from pathlib import Path
from typing import TYPE_CHECKING
from django.test import TestCase
@ -16,9 +15,6 @@ from twitch.models import Organization
from twitch.models import TimeBasedDrop
from twitch.schemas import DropBenefitSchema
if TYPE_CHECKING:
from debug_toolbar.panels.templates.panel import QuerySet
class GetOrUpdateBenefitTests(TestCase):
"""Tests for the _get_or_update_benefit method in better_import_drops.Command."""
@ -467,11 +463,9 @@ class OperationNameFilteringTests(TestCase):
command.process_responses([viewer_drops_payload], Path("viewer.json"), {})
command.process_responses([inventory_payload], Path("inventory.json"), {})
# Verify we can filter by operation_names
# SQLite doesn't support JSON contains, so we filter in Python
all_campaigns: QuerySet[DropCampaign, DropCampaign] = DropCampaign.objects.all()
viewer_campaigns: list[DropCampaign] = [c for c in all_campaigns if "ViewerDropsDashboard" in c.operation_names]
inventory_campaigns: list[DropCampaign] = [c for c in all_campaigns if "Inventory" in c.operation_names]
# Verify we can filter by operation_names with JSON containment
viewer_campaigns = DropCampaign.objects.filter(operation_names__contains=["ViewerDropsDashboard"])
inventory_campaigns = DropCampaign.objects.filter(operation_names__contains=["Inventory"])
assert len(viewer_campaigns) >= 1
assert len(inventory_campaigns) >= 1

View file

@ -1617,7 +1617,7 @@ def debug_view(request: HttpRequest) -> HttpResponse:
campaigns_missing_dropcampaigndetails: QuerySet[DropCampaign] = (
DropCampaign.objects
.filter(
Q(operation_names__isnull=True) | ~Q(operation_names__icontains="DropCampaignDetails"),
Q(operation_names__isnull=True) | ~Q(operation_names__contains=["DropCampaignDetails"]),
)
.select_related("game")
.order_by("game__display_name", "name")