From c41524e5172725a520ea6f7396047027180631c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Hells=C3=A9n?= Date: Fri, 13 Feb 2026 23:27:18 +0100 Subject: [PATCH] Refactor database configuration to support PostgreSQL, add GIN index for operation_names, and enhance backup functionality --- Dockerfile | 1 + config/settings.py | 22 ++- docker-compose.yml | 29 ++++ pyproject.toml | 1 + start.sh | 3 + twitch/management/commands/backup_db.py | 140 +++++++++++++++--- ..._dropcampaign_operation_names_gin_index.py | 24 +++ twitch/models.py | 3 +- twitch/tests/test_backup.py | 83 +++++++---- twitch/tests/test_better_import_drops.py | 12 +- twitch/views.py | 2 +- 11 files changed, 248 insertions(+), 72 deletions(-) create mode 100644 twitch/migrations/0012_dropcampaign_operation_names_gin_index.py diff --git a/Dockerfile b/Dockerfile index 78e1c32..73fbd15 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,6 +5,7 @@ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy ENV UV_NO_DEV=1 ENV UV_PYTHON_DOWNLOADS=0 +ENV UV_NO_CACHE=1 WORKDIR /app COPY . /app/ diff --git a/config/settings.py b/config/settings.py index 3b99501..709434f 100644 --- a/config/settings.py +++ b/config/settings.py @@ -53,7 +53,7 @@ def get_data_dir() -> Path: For example, on Windows, it might be: `C:\Users\lovinator\AppData\Roaming\TheLovinator\TTVDrops` - In this directory, the SQLite database file will be stored as `db.sqlite3`. + In this directory, application data such as media and static files will be stored. 
""" data_dir: str = user_data_dir( appname="TTVDrops", @@ -177,18 +177,16 @@ TEMPLATES: list[dict[str, Any]] = [ }, ] - -# https://blog.pecar.me/django-sqlite-benchmark -DATABASES: dict[str, dict[str, str | Path | dict[str, str]]] = { +DATABASES: dict[str, dict[str, Any]] = { "default": { - "ENGINE": "django.db.backends.sqlite3", - "NAME": DATA_DIR / "ttvdrops.sqlite3", - "OPTIONS": { - "init_command": ( - "PRAGMA foreign_keys = ON; PRAGMA journal_mode=WAL; PRAGMA synchronous=NORMAL; PRAGMA mmap_size = 134217728; PRAGMA journal_size_limit = 27103364; PRAGMA cache_size=2000;" # noqa: E501 - ), - "transaction_mode": "IMMEDIATE", - }, + "ENGINE": "django.db.backends.postgresql", + "NAME": os.getenv("POSTGRES_DB", "ttvdrops"), + "USER": os.getenv("POSTGRES_USER", "ttvdrops"), + "PASSWORD": os.getenv("POSTGRES_PASSWORD", ""), + "HOST": os.getenv("POSTGRES_HOST", "localhost"), + "PORT": env_int("POSTGRES_PORT", 5432), + "CONN_MAX_AGE": env_int("CONN_MAX_AGE", 60), + "CONN_HEALTH_CHECKS": env_bool("CONN_HEALTH_CHECKS", default=True), }, } diff --git a/docker-compose.yml b/docker-compose.yml index abb4d38..b1e30a7 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,4 +1,22 @@ services: + ttvdrops_postgres: + container_name: ttvdrops_postgres + image: postgres:17 + environment: + - POSTGRES_DB=ttvdrops + - POSTGRES_USER=ttvdrops + - POSTGRES_PASSWORD=changeme + volumes: + - /mnt/Docker/Data/ttvdrops/postgres:/var/lib/postgresql + restart: unless-stopped + networks: + - internal + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ttvdrops"] + interval: 10s + timeout: 5s + retries: 5 + ttvdrops: container_name: ttvdrops image: ghcr.io/thelovinator1/ttvdrops:latest @@ -16,13 +34,24 @@ services: - EMAIL_HOST_PASSWORD= - EMAIL_USE_TLS=True - EMAIL_USE_SSL=False + - POSTGRES_DB=ttvdrops + - POSTGRES_USER=ttvdrops + - POSTGRES_PASSWORD=changeme + - POSTGRES_HOST=ttvdrops_postgres + - POSTGRES_PORT=5432 volumes: # Data is stored in /root/.local/share/TTVDrops" inside the 
container - /mnt/Docker/Data/ttvdrops/data:/root/.local/share/TTVDrops restart: unless-stopped networks: - web + - internal + depends_on: + ttvdrops_postgres: + condition: service_healthy networks: web: external: true + internal: + driver: bridge diff --git a/pyproject.toml b/pyproject.toml index 4f1dd4d..12438de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ requires-python = ">=3.14" dependencies = [ "dateparser", "django", + "psycopg[binary]", "json-repair", "pillow", "platformdirs", diff --git a/start.sh b/start.sh index 9361711..a6fdf33 100755 --- a/start.sh +++ b/start.sh @@ -1,6 +1,9 @@ #!/bin/sh set -e +echo "Running database migrations..." +uv run python manage.py migrate --noinput + echo "Collecting static files..." uv run python manage.py collectstatic --noinput diff --git a/twitch/management/commands/backup_db.py b/twitch/management/commands/backup_db.py index 101fe49..53ec0da 100644 --- a/twitch/management/commands/backup_db.py +++ b/twitch/management/commands/backup_db.py @@ -1,6 +1,9 @@ from __future__ import annotations import io +import os +import shutil +import subprocess # noqa: S404 from compression import zstd from datetime import datetime from pathlib import Path @@ -8,6 +11,7 @@ from typing import TYPE_CHECKING from django.conf import settings from django.core.management.base import BaseCommand +from django.core.management.base import CommandError from django.db import connection as django_connection from django.utils import timezone @@ -35,32 +39,45 @@ class Command(BaseCommand): ) def handle(self, **options: str) -> None: - """Run the backup command and write a zstd SQL dump.""" + """Run the backup command and write a zstd SQL dump. + + Args: + **options: Command-line options for output directory and filename prefix. + + Raises: + CommandError: When the database connection fails or pg_dump is not available. 
+ """ output_dir: Path = Path(options["output_dir"]).expanduser() prefix: str = str(options["prefix"]).strip() or "ttvdrops" output_dir.mkdir(parents=True, exist_ok=True) # Use Django's database connection to ensure we connect to the test DB during tests + django_connection.ensure_connection() connection = django_connection.connection if connection is None: - # Force connection if not already established - django_connection.ensure_connection() - connection = django_connection.connection + msg = "Database connection could not be established." + raise CommandError(msg) timestamp: str = timezone.localtime(timezone.now()).strftime("%Y%m%d-%H%M%S") output_path: Path = output_dir / f"{prefix}-{timestamp}.sql.zst" - allowed_tables = _get_allowed_tables(connection, "twitch_") + allowed_tables = _get_allowed_tables("twitch_") if not allowed_tables: self.stdout.write(self.style.WARNING("No twitch tables found to back up.")) return - with ( - output_path.open("wb") as raw_handle, - zstd.open(raw_handle, "w") as compressed, - io.TextIOWrapper(compressed, encoding="utf-8") as handle, - ): - _write_dump(handle, connection, allowed_tables) + if django_connection.vendor == "postgresql": + _write_postgres_dump(output_path, allowed_tables) + elif django_connection.vendor == "sqlite": + with ( + output_path.open("wb") as raw_handle, + zstd.open(raw_handle, "w") as compressed, + io.TextIOWrapper(compressed, encoding="utf-8") as handle, + ): + _write_sqlite_dump(handle, connection, allowed_tables) + else: + msg = f"Unsupported database backend: {django_connection.vendor}" + raise CommandError(msg) created_at: datetime = datetime.fromtimestamp(output_path.stat().st_mtime, tz=timezone.get_current_timezone()) self.stdout.write( @@ -71,24 +88,30 @@ class Command(BaseCommand): self.stdout.write(self.style.SUCCESS(f"Included tables: {len(allowed_tables)}")) -def _get_allowed_tables(connection: sqlite3.Connection, prefix: str) -> list[str]: +def _get_allowed_tables(prefix: str) -> list[str]: 
     """Fetch table names that match the allowed prefix.
 
     Args:
-        connection: SQLite connection.
         prefix: Table name prefix to include.
 
     Returns:
         List of table names.
     """
-    cursor = connection.execute(
-        "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE ? ORDER BY name",
-        (f"{prefix}%",),
-    )
-    return [row[0] for row in cursor.fetchall()]
+    with django_connection.cursor() as cursor:
+        if django_connection.vendor == "postgresql":
+            cursor.execute(
+                "SELECT tablename FROM pg_tables WHERE schemaname = 'public' AND tablename LIKE %s ORDER BY tablename",
+                [f"{prefix}%"],
+            )
+        else:
+            cursor.execute(
+                "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE ? ORDER BY name",
+                (f"{prefix}%",),
+            )
+        return [row[0] for row in cursor.fetchall()]
 
 
-def _write_dump(handle: io.TextIOBase, connection: sqlite3.Connection, tables: list[str]) -> None:
+def _write_sqlite_dump(handle: io.TextIOBase, connection: sqlite3.Connection, tables: list[str]) -> None:
     """Write a SQL dump containing schema and data for the requested tables.
 
     Args:
@@ -163,6 +186,80 @@ def _write_indexes(handle: io.TextIOBase, connection: sqlite3.Connection, tables
     handle.write(f"{sql};\n")
 
 
+def _write_postgres_dump(output_path: Path, tables: list[str]) -> None:
+    """Write a PostgreSQL dump using pg_dump into a zstd-compressed file.
+
+    Args:
+        output_path: Destination path for the zstd file.
+        tables: Table names to include.
+
+    Raises:
+        CommandError: When pg_dump fails or is not found.
+    """
+    pg_dump_path = shutil.which("pg_dump")
+    if not pg_dump_path:
+        msg = "pg_dump was not found. Install PostgreSQL client tools and retry."
+        raise CommandError(msg)
+
+    settings_dict = django_connection.settings_dict
+    env = os.environ.copy()
+    password = settings_dict.get("PASSWORD")
+    if password:
+        env["PGPASSWORD"] = str(password)
+
+    cmd = [
+        pg_dump_path,
+        "--format=plain",
+        "--no-owner",
+        "--no-privileges",
+        "--clean",
+        "--if-exists",
+        "--column-inserts",
+        "--quote-all-identifiers",
+        "--encoding=UTF8",
+        "--dbname",
+        str(settings_dict.get("NAME", "")),
+    ]
+
+    host = settings_dict.get("HOST")
+    port = settings_dict.get("PORT")
+    user = settings_dict.get("USER")
+
+    if host:
+        cmd.extend(["--host", str(host)])
+    if port:
+        cmd.extend(["--port", str(port)])
+    if user:
+        cmd.extend(["--username", str(user)])
+
+    for table in tables:
+        cmd.extend(["-t", f"public.{table}"])
+
+    process = subprocess.Popen(  # noqa: S603
+        cmd,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        env=env,
+    )
+    if process.stdout is None or process.stderr is None:
+        process.kill()
+        msg = "Failed to start pg_dump process."
+        raise CommandError(msg)
+
+    with (
+        output_path.open("wb") as raw_handle,
+        zstd.open(raw_handle, "w") as compressed,
+    ):
+        for chunk in iter(lambda: process.stdout.read(64 * 1024), b""):  # pyright: ignore[reportOptionalMemberAccess]
+            compressed.write(chunk)
+
+    stderr_output = process.stderr.read().decode("utf-8", errors="replace")
+    return_code = process.wait()
+    if return_code != 0:
+        msg = f"pg_dump failed with exit code {return_code}: {stderr_output.strip()}"
+        raise CommandError(msg)
+
+
 def _sql_literal(value: object) -> str:
     """Convert a Python value to a SQL literal.
 
diff --git a/twitch/migrations/0012_dropcampaign_operation_names_gin_index.py b/twitch/migrations/0012_dropcampaign_operation_names_gin_index.py new file mode 100644 index 0000000..30cb566 --- /dev/null +++ b/twitch/migrations/0012_dropcampaign_operation_names_gin_index.py @@ -0,0 +1,24 @@ +# Generated by Django 6.0.2 on 2026-02-12 12:00 +from __future__ import annotations + +from django.contrib.postgres.indexes import GinIndex +from django.db import migrations + + +class Migration(migrations.Migration): + """Replace the JSONField btree index with a GIN index for Postgres.""" + + dependencies = [ + ("twitch", "0011_dropbenefit_image_height_dropbenefit_image_width_and_more"), + ] + + operations = [ + migrations.RemoveIndex( + model_name="dropcampaign", + name="twitch_drop_operati_fe3bc8_idx", + ), + migrations.AddIndex( + model_name="dropcampaign", + index=GinIndex(fields=["operation_names"], name="twitch_drop_operati_gin_idx"), + ), + ] diff --git a/twitch/models.py b/twitch/models.py index 78d2893..7d88c28 100644 --- a/twitch/models.py +++ b/twitch/models.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING import auto_prefetch from django.contrib.humanize.templatetags.humanize import naturaltime +from django.contrib.postgres.indexes import GinIndex from django.db import models from django.urls import reverse from django.utils import timezone @@ -420,7 +421,7 @@ class DropCampaign(auto_prefetch.Model): models.Index(fields=["name"]), models.Index(fields=["description"]), models.Index(fields=["allow_is_enabled"]), - models.Index(fields=["operation_names"]), + GinIndex(fields=["operation_names"], name="twitch_drop_operati_gin_idx"), models.Index(fields=["added_at"]), models.Index(fields=["updated_at"]), # Composite indexes for common queries diff --git a/twitch/tests/test_backup.py b/twitch/tests/test_backup.py index 5d35325..a0dd790 100644 --- a/twitch/tests/test_backup.py +++ b/twitch/tests/test_backup.py @@ -3,6 +3,7 @@ from __future__ import annotations import io 
import math import os +import shutil from compression import zstd from typing import TYPE_CHECKING @@ -14,7 +15,8 @@ from django.urls import reverse from twitch.management.commands.backup_db import _get_allowed_tables from twitch.management.commands.backup_db import _sql_literal -from twitch.management.commands.backup_db import _write_dump +from twitch.management.commands.backup_db import _write_postgres_dump +from twitch.management.commands.backup_db import _write_sqlite_dump from twitch.models import Game from twitch.models import Organization @@ -25,12 +27,18 @@ if TYPE_CHECKING: from django.test.client import _MonkeyPatchedWSGIResponse +def _skip_if_pg_dump_missing() -> None: + if connection.vendor == "postgresql" and not shutil.which("pg_dump"): + pytest.skip("pg_dump is not available") + + @pytest.mark.django_db class TestBackupCommand: """Tests for the backup_db management command.""" def test_backup_creates_file(self, tmp_path: Path) -> None: """Test that backup command creates a zstd compressed file.""" + _skip_if_pg_dump_missing() # Create test data so tables exist Organization.objects.create(twitch_id="test000", name="Test Org") @@ -46,6 +54,7 @@ class TestBackupCommand: def test_backup_contains_sql_content(self, tmp_path: Path) -> None: """Test that backup file contains valid SQL content.""" + _skip_if_pg_dump_missing() output_dir = tmp_path / "backups" output_dir.mkdir() @@ -66,15 +75,20 @@ class TestBackupCommand: ): content = handle.read() - assert "PRAGMA foreign_keys=OFF;" in content - assert "BEGIN TRANSACTION;" in content - assert "COMMIT;" in content + if connection.vendor == "postgresql": + assert "CREATE TABLE" in content + assert "INSERT INTO" in content + else: + assert "PRAGMA foreign_keys=OFF;" in content + assert "BEGIN TRANSACTION;" in content + assert "COMMIT;" in content assert "twitch_organization" in content assert "twitch_game" in content assert "Test Org" in content def test_backup_excludes_non_twitch_tables(self, tmp_path: Path) 
-> None: """Test that backup only includes twitch_ prefixed tables.""" + _skip_if_pg_dump_missing() # Create test data so tables exist Organization.objects.create(twitch_id="test001", name="Test Org") @@ -103,6 +117,7 @@ class TestBackupCommand: def test_backup_with_custom_prefix(self, tmp_path: Path) -> None: """Test that custom prefix is used in filename.""" + _skip_if_pg_dump_missing() # Create test data so tables exist Organization.objects.create(twitch_id="test002", name="Test Org") @@ -116,6 +131,7 @@ class TestBackupCommand: def test_backup_creates_output_directory(self, tmp_path: Path) -> None: """Test that backup command creates output directory if missing.""" + _skip_if_pg_dump_missing() # Create test data so tables exist Organization.objects.create(twitch_id="test003", name="Test Org") @@ -128,6 +144,7 @@ class TestBackupCommand: def test_backup_uses_default_directory(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: """Test that backup uses DATA_DIR/datasets by default.""" + _skip_if_pg_dump_missing() # Create test data so tables exist Organization.objects.create(twitch_id="test004", name="Test Org") @@ -148,8 +165,7 @@ class TestBackupHelperFunctions: def test_get_allowed_tables_filters_by_prefix(self) -> None: """Test that _get_allowed_tables returns only matching tables.""" # Use Django's connection to access the test database - db_connection = connection.connection - tables = _get_allowed_tables(db_connection, "twitch_") + tables = _get_allowed_tables("twitch_") assert len(tables) > 0 assert all(table.startswith("twitch_") for table in tables) @@ -159,8 +175,7 @@ class TestBackupHelperFunctions: def test_get_allowed_tables_excludes_non_matching(self) -> None: """Test that _get_allowed_tables excludes non-matching tables.""" # Use Django's connection to access the test database - db_connection = connection.connection - tables = _get_allowed_tables(db_connection, "twitch_") + tables = _get_allowed_tables("twitch_") # Should not include 
django, silk, or debug toolbar tables assert not any(table.startswith("django_") for table in tables) @@ -192,33 +207,41 @@ class TestBackupHelperFunctions: assert _sql_literal(b"\x00\x01\x02") == "X'000102'" assert _sql_literal(b"hello") == "X'68656c6c6f'" - def test_write_dump_includes_schema_and_data(self) -> None: + def test_write_dump_includes_schema_and_data(self, tmp_path: Path) -> None: """Test _write_dump writes complete SQL dump.""" # Create test data Organization.objects.create(twitch_id="test789", name="Write Test Org") - # Use Django's connection to access the test database - db_connection = connection.connection - output = io.StringIO() + tables = _get_allowed_tables("twitch_") - tables = _get_allowed_tables(db_connection, "twitch_") - _write_dump(output, db_connection, tables) - - content = output.getvalue() - - # Check for SQL structure - assert "PRAGMA foreign_keys=OFF;" in content - assert "BEGIN TRANSACTION;" in content - assert "COMMIT;" in content - assert "PRAGMA foreign_keys=ON;" in content - - # Check for schema - assert "CREATE TABLE" in content - assert "twitch_organization" in content - - # Check for data - assert "INSERT INTO" in content - assert "Write Test Org" in content + if connection.vendor == "postgresql": + if not shutil.which("pg_dump"): + pytest.skip("pg_dump is not available") + output_path = tmp_path / "backup.sql.zst" + _write_postgres_dump(output_path, tables) + with ( + output_path.open("rb") as raw_handle, + zstd.open(raw_handle, "r") as compressed, + io.TextIOWrapper(compressed, encoding="utf-8") as handle, + ): + content = handle.read() + assert "CREATE TABLE" in content + assert "INSERT INTO" in content + assert "twitch_organization" in content + assert "Write Test Org" in content + else: + db_connection = connection.connection + output = io.StringIO() + _write_sqlite_dump(output, db_connection, tables) + content = output.getvalue() + assert "PRAGMA foreign_keys=OFF;" in content + assert "BEGIN TRANSACTION;" in content 
+ assert "COMMIT;" in content + assert "PRAGMA foreign_keys=ON;" in content + assert "CREATE TABLE" in content + assert "twitch_organization" in content + assert "INSERT INTO" in content + assert "Write Test Org" in content @pytest.mark.django_db diff --git a/twitch/tests/test_better_import_drops.py b/twitch/tests/test_better_import_drops.py index e7fed52..5d32285 100644 --- a/twitch/tests/test_better_import_drops.py +++ b/twitch/tests/test_better_import_drops.py @@ -2,7 +2,6 @@ from __future__ import annotations import json from pathlib import Path -from typing import TYPE_CHECKING from django.test import TestCase @@ -16,9 +15,6 @@ from twitch.models import Organization from twitch.models import TimeBasedDrop from twitch.schemas import DropBenefitSchema -if TYPE_CHECKING: - from debug_toolbar.panels.templates.panel import QuerySet - class GetOrUpdateBenefitTests(TestCase): """Tests for the _get_or_update_benefit method in better_import_drops.Command.""" @@ -467,11 +463,9 @@ class OperationNameFilteringTests(TestCase): command.process_responses([viewer_drops_payload], Path("viewer.json"), {}) command.process_responses([inventory_payload], Path("inventory.json"), {}) - # Verify we can filter by operation_names - # SQLite doesn't support JSON contains, so we filter in Python - all_campaigns: QuerySet[DropCampaign, DropCampaign] = DropCampaign.objects.all() - viewer_campaigns: list[DropCampaign] = [c for c in all_campaigns if "ViewerDropsDashboard" in c.operation_names] - inventory_campaigns: list[DropCampaign] = [c for c in all_campaigns if "Inventory" in c.operation_names] + # Verify we can filter by operation_names with JSON containment + viewer_campaigns = DropCampaign.objects.filter(operation_names__contains=["ViewerDropsDashboard"]) + inventory_campaigns = DropCampaign.objects.filter(operation_names__contains=["Inventory"]) assert len(viewer_campaigns) >= 1 assert len(inventory_campaigns) >= 1 diff --git a/twitch/views.py b/twitch/views.py index 043fb01..d2d17ff 
100644 --- a/twitch/views.py +++ b/twitch/views.py @@ -1617,7 +1617,7 @@ def debug_view(request: HttpRequest) -> HttpResponse: campaigns_missing_dropcampaigndetails: QuerySet[DropCampaign] = ( DropCampaign.objects .filter( - Q(operation_names__isnull=True) | ~Q(operation_names__icontains="DropCampaignDetails"), + Q(operation_names__isnull=True) | ~Q(operation_names__contains=["DropCampaignDetails"]), ) .select_related("game") .order_by("game__display_name", "name")