Refactor database configuration to support PostgreSQL, add GIN index for operation_names, and enhance backup functionality
This commit is contained in:
parent
477bb753ae
commit
c41524e517
11 changed files with 250 additions and 74 deletions
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||
import io
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
from compression import zstd
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
|
@ -14,7 +15,8 @@ from django.urls import reverse
|
|||
|
||||
from twitch.management.commands.backup_db import _get_allowed_tables
|
||||
from twitch.management.commands.backup_db import _sql_literal
|
||||
from twitch.management.commands.backup_db import _write_dump
|
||||
from twitch.management.commands.backup_db import _write_postgres_dump
|
||||
from twitch.management.commands.backup_db import _write_sqlite_dump
|
||||
from twitch.models import Game
|
||||
from twitch.models import Organization
|
||||
|
||||
|
|
@ -25,12 +27,18 @@ if TYPE_CHECKING:
|
|||
from django.test.client import _MonkeyPatchedWSGIResponse
|
||||
|
||||
|
||||
def _skip_if_pg_dump_missing() -> None:
|
||||
if connection.vendor == "postgresql" and not shutil.which("pg_dump"):
|
||||
pytest.skip("pg_dump is not available")
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestBackupCommand:
|
||||
"""Tests for the backup_db management command."""
|
||||
|
||||
def test_backup_creates_file(self, tmp_path: Path) -> None:
|
||||
"""Test that backup command creates a zstd compressed file."""
|
||||
_skip_if_pg_dump_missing()
|
||||
# Create test data so tables exist
|
||||
Organization.objects.create(twitch_id="test000", name="Test Org")
|
||||
|
||||
|
|
@ -46,6 +54,7 @@ class TestBackupCommand:
|
|||
|
||||
def test_backup_contains_sql_content(self, tmp_path: Path) -> None:
|
||||
"""Test that backup file contains valid SQL content."""
|
||||
_skip_if_pg_dump_missing()
|
||||
output_dir = tmp_path / "backups"
|
||||
output_dir.mkdir()
|
||||
|
||||
|
|
@ -66,15 +75,20 @@ class TestBackupCommand:
|
|||
):
|
||||
content = handle.read()
|
||||
|
||||
assert "PRAGMA foreign_keys=OFF;" in content
|
||||
assert "BEGIN TRANSACTION;" in content
|
||||
assert "COMMIT;" in content
|
||||
if connection.vendor == "postgresql":
|
||||
assert "CREATE TABLE" in content
|
||||
assert "INSERT INTO" in content
|
||||
else:
|
||||
assert "PRAGMA foreign_keys=OFF;" in content
|
||||
assert "BEGIN TRANSACTION;" in content
|
||||
assert "COMMIT;" in content
|
||||
assert "twitch_organization" in content
|
||||
assert "twitch_game" in content
|
||||
assert "Test Org" in content
|
||||
|
||||
def test_backup_excludes_non_twitch_tables(self, tmp_path: Path) -> None:
|
||||
"""Test that backup only includes twitch_ prefixed tables."""
|
||||
_skip_if_pg_dump_missing()
|
||||
# Create test data so tables exist
|
||||
Organization.objects.create(twitch_id="test001", name="Test Org")
|
||||
|
||||
|
|
@ -103,6 +117,7 @@ class TestBackupCommand:
|
|||
|
||||
def test_backup_with_custom_prefix(self, tmp_path: Path) -> None:
|
||||
"""Test that custom prefix is used in filename."""
|
||||
_skip_if_pg_dump_missing()
|
||||
# Create test data so tables exist
|
||||
Organization.objects.create(twitch_id="test002", name="Test Org")
|
||||
|
||||
|
|
@ -116,6 +131,7 @@ class TestBackupCommand:
|
|||
|
||||
def test_backup_creates_output_directory(self, tmp_path: Path) -> None:
|
||||
"""Test that backup command creates output directory if missing."""
|
||||
_skip_if_pg_dump_missing()
|
||||
# Create test data so tables exist
|
||||
Organization.objects.create(twitch_id="test003", name="Test Org")
|
||||
|
||||
|
|
@ -128,6 +144,7 @@ class TestBackupCommand:
|
|||
|
||||
def test_backup_uses_default_directory(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test that backup uses DATA_DIR/datasets by default."""
|
||||
_skip_if_pg_dump_missing()
|
||||
# Create test data so tables exist
|
||||
Organization.objects.create(twitch_id="test004", name="Test Org")
|
||||
|
||||
|
|
@ -148,8 +165,7 @@ class TestBackupHelperFunctions:
|
|||
def test_get_allowed_tables_filters_by_prefix(self) -> None:
|
||||
"""Test that _get_allowed_tables returns only matching tables."""
|
||||
# Use Django's connection to access the test database
|
||||
db_connection = connection.connection
|
||||
tables = _get_allowed_tables(db_connection, "twitch_")
|
||||
tables = _get_allowed_tables("twitch_")
|
||||
|
||||
assert len(tables) > 0
|
||||
assert all(table.startswith("twitch_") for table in tables)
|
||||
|
|
@ -159,8 +175,7 @@ class TestBackupHelperFunctions:
|
|||
def test_get_allowed_tables_excludes_non_matching(self) -> None:
|
||||
"""Test that _get_allowed_tables excludes non-matching tables."""
|
||||
# Use Django's connection to access the test database
|
||||
db_connection = connection.connection
|
||||
tables = _get_allowed_tables(db_connection, "twitch_")
|
||||
tables = _get_allowed_tables("twitch_")
|
||||
|
||||
# Should not include django, silk, or debug toolbar tables
|
||||
assert not any(table.startswith("django_") for table in tables)
|
||||
|
|
@ -192,33 +207,41 @@ class TestBackupHelperFunctions:
|
|||
assert _sql_literal(b"\x00\x01\x02") == "X'000102'"
|
||||
assert _sql_literal(b"hello") == "X'68656c6c6f'"
|
||||
|
||||
def test_write_dump_includes_schema_and_data(self) -> None:
|
||||
def test_write_dump_includes_schema_and_data(self, tmp_path: Path) -> None:
|
||||
"""Test _write_dump writes complete SQL dump."""
|
||||
# Create test data
|
||||
Organization.objects.create(twitch_id="test789", name="Write Test Org")
|
||||
|
||||
# Use Django's connection to access the test database
|
||||
db_connection = connection.connection
|
||||
output = io.StringIO()
|
||||
tables = _get_allowed_tables("twitch_")
|
||||
|
||||
tables = _get_allowed_tables(db_connection, "twitch_")
|
||||
_write_dump(output, db_connection, tables)
|
||||
|
||||
content = output.getvalue()
|
||||
|
||||
# Check for SQL structure
|
||||
assert "PRAGMA foreign_keys=OFF;" in content
|
||||
assert "BEGIN TRANSACTION;" in content
|
||||
assert "COMMIT;" in content
|
||||
assert "PRAGMA foreign_keys=ON;" in content
|
||||
|
||||
# Check for schema
|
||||
assert "CREATE TABLE" in content
|
||||
assert "twitch_organization" in content
|
||||
|
||||
# Check for data
|
||||
assert "INSERT INTO" in content
|
||||
assert "Write Test Org" in content
|
||||
if connection.vendor == "postgresql":
|
||||
if not shutil.which("pg_dump"):
|
||||
pytest.skip("pg_dump is not available")
|
||||
output_path = tmp_path / "backup.sql.zst"
|
||||
_write_postgres_dump(output_path, tables)
|
||||
with (
|
||||
output_path.open("rb") as raw_handle,
|
||||
zstd.open(raw_handle, "r") as compressed,
|
||||
io.TextIOWrapper(compressed, encoding="utf-8") as handle,
|
||||
):
|
||||
content = handle.read()
|
||||
assert "CREATE TABLE" in content
|
||||
assert "INSERT INTO" in content
|
||||
assert "twitch_organization" in content
|
||||
assert "Write Test Org" in content
|
||||
else:
|
||||
db_connection = connection.connection
|
||||
output = io.StringIO()
|
||||
_write_sqlite_dump(output, db_connection, tables)
|
||||
content = output.getvalue()
|
||||
assert "PRAGMA foreign_keys=OFF;" in content
|
||||
assert "BEGIN TRANSACTION;" in content
|
||||
assert "COMMIT;" in content
|
||||
assert "PRAGMA foreign_keys=ON;" in content
|
||||
assert "CREATE TABLE" in content
|
||||
assert "twitch_organization" in content
|
||||
assert "INSERT INTO" in content
|
||||
assert "Write Test Org" in content
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ from __future__ import annotations
|
|||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
|
|
@ -16,9 +15,6 @@ from twitch.models import Organization
|
|||
from twitch.models import TimeBasedDrop
|
||||
from twitch.schemas import DropBenefitSchema
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from debug_toolbar.panels.templates.panel import QuerySet
|
||||
|
||||
|
||||
class GetOrUpdateBenefitTests(TestCase):
|
||||
"""Tests for the _get_or_update_benefit method in better_import_drops.Command."""
|
||||
|
|
@ -467,11 +463,9 @@ class OperationNameFilteringTests(TestCase):
|
|||
command.process_responses([viewer_drops_payload], Path("viewer.json"), {})
|
||||
command.process_responses([inventory_payload], Path("inventory.json"), {})
|
||||
|
||||
# Verify we can filter by operation_names
|
||||
# SQLite doesn't support JSON contains, so we filter in Python
|
||||
all_campaigns: QuerySet[DropCampaign, DropCampaign] = DropCampaign.objects.all()
|
||||
viewer_campaigns: list[DropCampaign] = [c for c in all_campaigns if "ViewerDropsDashboard" in c.operation_names]
|
||||
inventory_campaigns: list[DropCampaign] = [c for c in all_campaigns if "Inventory" in c.operation_names]
|
||||
# Verify we can filter by operation_names with JSON containment
|
||||
viewer_campaigns = DropCampaign.objects.filter(operation_names__contains=["ViewerDropsDashboard"])
|
||||
inventory_campaigns = DropCampaign.objects.filter(operation_names__contains=["Inventory"])
|
||||
|
||||
assert len(viewer_campaigns) >= 1
|
||||
assert len(inventory_campaigns) >= 1
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue