ttvdrops/twitch/tests/test_backup.py
Joakim Helleśen 9fd22ba8a8
All checks were successful
Deploy to Server / deploy (push) Successful in 40s
Add JSON and CSV functionality to database backup command
2026-03-17 00:38:01 +01:00

619 lines
23 KiB
Python

import csv
import io
import json
import math
import os
import shutil
from compression import zstd
from datetime import datetime as dt
from typing import TYPE_CHECKING
import pytest
from django.conf import settings
from django.core.management import call_command
from django.db import connection
from django.urls import reverse
from twitch.management.commands.backup_db import _get_allowed_tables
from twitch.management.commands.backup_db import _json_default
from twitch.management.commands.backup_db import _sql_literal
from twitch.management.commands.backup_db import _write_csv_dumps
from twitch.management.commands.backup_db import _write_json_dump
from twitch.management.commands.backup_db import _write_postgres_dump
from twitch.management.commands.backup_db import _write_sqlite_dump
from twitch.models import Game
from twitch.models import Organization
if TYPE_CHECKING:
from csv import Reader
from datetime import datetime
from pathlib import Path
from django.test import Client
from django.test.client import _MonkeyPatchedWSGIResponse
def _skip_if_pg_dump_missing() -> None:
    """Skip the current test when PostgreSQL is the backend but pg_dump is absent.

    The backup command shells out to ``pg_dump`` on PostgreSQL, so tests that
    invoke it cannot run without the binary on PATH. SQLite backends never skip.
    """
    pg_dump_unavailable = (
        connection.vendor == "postgresql" and shutil.which("pg_dump") is None
    )
    if pg_dump_unavailable:
        pytest.skip("pg_dump is not available")
@pytest.mark.django_db
class TestBackupCommand:
    """Tests for the backup_db management command."""

    def test_backup_creates_file(self, tmp_path: Path) -> None:
        """Test that backup command creates a zstd compressed file."""
        _skip_if_pg_dump_missing()
        # Create test data so tables exist
        Organization.objects.create(twitch_id="test000", name="Test Org")
        output_dir = tmp_path / "backups"
        output_dir.mkdir()
        call_command("backup_db", output_dir=str(output_dir), prefix="test")
        # Filenames follow the "<prefix>-<timestamp>.sql.zst" scheme.
        backup_files = list(output_dir.glob("test-*.sql.zst"))
        assert len(backup_files) == 1
        assert backup_files[0].exists()
        # A non-empty file is the minimal sanity check on the dump.
        assert backup_files[0].stat().st_size > 0

    def test_backup_contains_sql_content(self, tmp_path: Path) -> None:
        """Test that backup file contains valid SQL content."""
        _skip_if_pg_dump_missing()
        output_dir = tmp_path / "backups"
        output_dir.mkdir()
        # Create some test data
        org = Organization.objects.create(twitch_id="test123", name="Test Org")
        game = Game.objects.create(twitch_id="game456", display_name="Test Game")
        game.owners.add(org)
        call_command("backup_db", output_dir=str(output_dir), prefix="test")
        backup_file = next(iter(output_dir.glob("test-*.sql.zst")))
        # Decompress and read content
        with (
            backup_file.open("rb") as raw_handle,
            zstd.open(raw_handle, "r") as compressed,
            io.TextIOWrapper(compressed, encoding="utf-8") as handle,
        ):
            content = handle.read()
        # The dump format differs per backend: pg_dump output vs a
        # hand-rolled SQLite-style dump.
        if connection.vendor == "postgresql":
            assert "CREATE TABLE" in content
            assert "INSERT INTO" in content
        else:
            assert "PRAGMA foreign_keys=OFF;" in content
            assert "BEGIN TRANSACTION;" in content
            assert "COMMIT;" in content
        # Both backends must include the app tables and the row data.
        assert "twitch_organization" in content
        assert "twitch_game" in content
        assert "Test Org" in content

    def test_backup_excludes_non_app_tables(self, tmp_path: Path) -> None:
        """Test that backup includes app tables and excludes non-app tables."""
        _skip_if_pg_dump_missing()
        # Create test data so tables exist
        Organization.objects.create(twitch_id="test001", name="Test Org")
        output_dir: Path = tmp_path / "backups"
        output_dir.mkdir()
        call_command("backup_db", output_dir=str(output_dir), prefix="test")
        backup_file: Path = next(iter(output_dir.glob("test-*.sql.zst")))
        with (
            backup_file.open("rb") as raw_handle,
            zstd.open(raw_handle, "r") as compressed,
            io.TextIOWrapper(compressed, encoding="utf-8") as handle,
        ):
            content: str = handle.read()
        # Should NOT contain django admin, silk, or debug toolbar tables
        assert "django_session" not in content
        assert "django_migrations" not in content
        assert "django_content_type" not in content
        assert "silk_" not in content
        assert "debug_toolbar_" not in content
        assert "django_admin_log" not in content
        assert "auth_" not in content
        assert "youtube_" not in content
        # Should contain twitch and kick tables
        assert "twitch_" in content
        assert "kick_" in content

    def test_backup_with_custom_prefix(self, tmp_path: Path) -> None:
        """Test that custom prefix is used in filename."""
        _skip_if_pg_dump_missing()
        # Create test data so tables exist
        Organization.objects.create(twitch_id="test002", name="Test Org")
        output_dir = tmp_path / "backups"
        output_dir.mkdir()
        call_command("backup_db", output_dir=str(output_dir), prefix="custom")
        backup_files = list(output_dir.glob("custom-*.sql.zst"))
        assert len(backup_files) == 1

    def test_backup_creates_output_directory(self, tmp_path: Path) -> None:
        """Test that backup command creates output directory if missing."""
        _skip_if_pg_dump_missing()
        # Create test data so tables exist
        Organization.objects.create(twitch_id="test003", name="Test Org")
        # Deliberately point at a directory that does not exist yet.
        output_dir = tmp_path / "nonexistent" / "backups"
        call_command("backup_db", output_dir=str(output_dir), prefix="test")
        assert output_dir.exists()
        assert len(list(output_dir.glob("test-*.sql.zst"))) == 1

    def test_backup_uses_default_directory(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Test that backup uses DATA_DIR/datasets by default."""
        _skip_if_pg_dump_missing()
        # Create test data so tables exist
        Organization.objects.create(twitch_id="test004", name="Test Org")
        # Redirect DATA_DIR so the default output lands inside tmp_path.
        monkeypatch.setattr(settings, "DATA_DIR", tmp_path)
        datasets_dir = tmp_path / "datasets"
        datasets_dir.mkdir(exist_ok=True, parents=True)
        call_command("backup_db")
        # Default prefix is "ttvdrops" when none is supplied.
        backup_files = list(datasets_dir.glob("ttvdrops-*.sql.zst"))
        assert len(backup_files) >= 1

    def test_backup_creates_json_file(self, tmp_path: Path) -> None:
        """Test that backup command creates a JSON file alongside the SQL dump."""
        _skip_if_pg_dump_missing()
        Organization.objects.create(twitch_id="test_json", name="Test Org JSON")
        output_dir: Path = tmp_path / "backups"
        output_dir.mkdir()
        call_command("backup_db", output_dir=str(output_dir), prefix="test")
        json_files: list[Path] = list(output_dir.glob("test-*.json.zst"))
        assert len(json_files) == 1
        with (
            json_files[0].open("rb") as raw_handle,
            zstd.open(raw_handle, "r") as compressed,
            io.TextIOWrapper(compressed, encoding="utf-8") as handle,
        ):
            data = json.load(handle)
        # JSON dump maps table name -> list of row dicts.
        assert isinstance(data, dict)
        assert "twitch_organization" in data
        assert any(
            row.get("name") == "Test Org JSON" for row in data["twitch_organization"]
        )

    def test_backup_creates_csv_files(self, tmp_path: Path) -> None:
        """Test that backup command creates per-table CSV files alongside the SQL dump."""
        _skip_if_pg_dump_missing()
        Organization.objects.create(twitch_id="test_csv", name="Test Org CSV")
        output_dir: Path = tmp_path / "backups"
        output_dir.mkdir()
        call_command("backup_db", output_dir=str(output_dir), prefix="test")
        # One CSV file per table: "<prefix>-<timestamp>-<table>.csv.zst".
        org_csv_files: list[Path] = list(
            output_dir.glob("test-*-twitch_organization.csv.zst"),
        )
        assert len(org_csv_files) == 1
        with (
            org_csv_files[0].open("rb") as raw_handle,
            zstd.open(raw_handle, "r") as compressed,
            io.TextIOWrapper(compressed, encoding="utf-8") as handle,
        ):
            # Materialize rows while the handle is still open (csv.reader is lazy).
            reader: Reader = csv.reader(handle)
            rows: list[list[str]] = list(reader)
        assert len(rows) >= 2  # header + at least one data row
        assert "name" in rows[0]
        assert any("Test Org CSV" in row for row in rows[1:])
@pytest.mark.django_db
class TestBackupHelperFunctions:
    """Tests for backup command helper functions."""

    def test_get_allowed_tables_filters_by_prefix(self) -> None:
        """Test that _get_allowed_tables returns only matching tables."""
        # Use Django's connection to access the test database
        tables = _get_allowed_tables("twitch_")
        assert len(tables) > 0
        assert all(table.startswith("twitch_") for table in tables)
        assert "twitch_organization" in tables
        assert "twitch_game" in tables

    def test_get_allowed_tables_excludes_non_matching(self) -> None:
        """Test that _get_allowed_tables excludes non-matching tables."""
        # Use Django's connection to access the test database
        tables = _get_allowed_tables("twitch_")
        # Should not include django, silk, or debug toolbar tables
        assert not any(table.startswith("django_") for table in tables)
        assert not any(table.startswith("silk_") for table in tables)
        assert not any(table.startswith("debug_toolbar_") for table in tables)

    def test_sql_literal_handles_none(self) -> None:
        """Test _sql_literal converts None to NULL."""
        assert _sql_literal(None) == "NULL"

    def test_sql_literal_handles_booleans(self) -> None:
        """Test _sql_literal converts booleans to 1/0."""
        # SQLite has no native boolean type; 1/0 is the conventional encoding.
        assert _sql_literal(True) == "1"
        assert _sql_literal(False) == "0"

    def test_sql_literal_handles_numbers(self) -> None:
        """Test _sql_literal handles int and float."""
        assert _sql_literal(42) == "42"
        assert _sql_literal(math.pi) == str(math.pi)

    def test_sql_literal_handles_strings(self) -> None:
        """Test _sql_literal quotes and escapes strings."""
        assert _sql_literal("test") == "'test'"
        # Embedded single quotes are escaped by doubling, per SQL convention.
        assert _sql_literal("o'reilly") == "'o''reilly'"
        # Newlines inside strings are preserved verbatim.
        assert _sql_literal("test\nline") == "'test\nline'"

    def test_sql_literal_handles_bytes(self) -> None:
        """Test _sql_literal converts bytes to hex notation."""
        # X'...' is the SQL blob-literal syntax.
        assert _sql_literal(b"\x00\x01\x02") == "X'000102'"
        assert _sql_literal(b"hello") == "X'68656c6c6f'"

    def test_write_dump_includes_schema_and_data(self, tmp_path: Path) -> None:
        """Test _write_dump writes complete SQL dump."""
        # Create test data
        Organization.objects.create(twitch_id="test789", name="Write Test Org")
        tables = _get_allowed_tables("twitch_")
        if connection.vendor == "postgresql":
            # Postgres path writes a compressed file via pg_dump.
            if not shutil.which("pg_dump"):
                pytest.skip("pg_dump is not available")
            output_path = tmp_path / "backup.sql.zst"
            _write_postgres_dump(output_path, tables)
            with (
                output_path.open("rb") as raw_handle,
                zstd.open(raw_handle, "r") as compressed,
                io.TextIOWrapper(compressed, encoding="utf-8") as handle,
            ):
                content = handle.read()
            assert "CREATE TABLE" in content
            assert "INSERT INTO" in content
            assert "twitch_organization" in content
            assert "Write Test Org" in content
        else:
            # SQLite path writes plain SQL to any text stream.
            db_connection = connection.connection
            output = io.StringIO()
            _write_sqlite_dump(output, db_connection, tables)
            content = output.getvalue()
            assert "PRAGMA foreign_keys=OFF;" in content
            assert "BEGIN TRANSACTION;" in content
            assert "COMMIT;" in content
            assert "PRAGMA foreign_keys=ON;" in content
            assert "CREATE TABLE" in content
            assert "twitch_organization" in content
            assert "INSERT INTO" in content
            assert "Write Test Org" in content

    def test_write_json_dump_creates_valid_json(self, tmp_path: Path) -> None:
        """Test _write_json_dump creates valid compressed JSON with all tables."""
        Organization.objects.create(
            twitch_id="test_json_helper",
            name="JSON Helper Org",
        )
        tables: list[str] = _get_allowed_tables("twitch_")
        output_path: Path = tmp_path / "backup.json.zst"
        _write_json_dump(output_path, tables)
        with (
            output_path.open("rb") as raw_handle,
            zstd.open(raw_handle, "r") as compressed,
            io.TextIOWrapper(compressed, encoding="utf-8") as handle,
        ):
            data = json.load(handle)
        assert isinstance(data, dict)
        assert "twitch_organization" in data
        # Every allowed table must be present, even if it has no rows.
        assert all(table in data for table in tables)
        assert any(
            row.get("name") == "JSON Helper Org" for row in data["twitch_organization"]
        )

    def test_write_csv_dumps_creates_per_table_files(self, tmp_path: Path) -> None:
        """Test _write_csv_dumps creates one compressed CSV file per table."""
        Organization.objects.create(twitch_id="test_csv_helper", name="CSV Helper Org")
        tables: list[str] = _get_allowed_tables("twitch_")
        paths: list[Path] = _write_csv_dumps(
            tmp_path,
            "test",
            "20260317-120000",
            tables,
        )
        # One output path per table, all actually written to disk.
        assert len(paths) == len(tables)
        assert all(p.exists() for p in paths)
        org_csv: Path = tmp_path / "test-20260317-120000-twitch_organization.csv.zst"
        assert org_csv.exists()
        with (
            org_csv.open("rb") as raw_handle,
            zstd.open(raw_handle, "r") as compressed,
            io.TextIOWrapper(compressed, encoding="utf-8") as handle,
        ):
            # Materialize rows while the handle is still open (csv.reader is lazy).
            reader: Reader = csv.reader(handle)
            rows: list[list[str]] = list(reader)
        assert len(rows) >= 2  # header + at least one data row
        assert "name" in rows[0]
        assert any("CSV Helper Org" in row for row in rows[1:])

    def test_json_default_handles_bytes(self) -> None:
        """Test _json_default converts bytes to hex string."""
        assert _json_default(b"\x00\x01") == "0001"
        assert _json_default(b"hello") == "68656c6c6f"

    def test_json_default_handles_other_types(self) -> None:
        """Test _json_default falls back to str() for other types."""
        # Use a timezone-aware datetime as a representative non-JSON-native type.
        value: datetime = dt(2026, 3, 17, 12, 0, 0, tzinfo=dt.now().astimezone().tzinfo)
        assert _json_default(value) == str(value)
@pytest.mark.django_db
class TestDatasetBackupViews:
    """Tests for dataset backup list and download views."""

    @pytest.fixture
    def datasets_dir(self, tmp_path: Path) -> Path:
        """Create a temporary datasets directory.

        Returns:
            Path to the created datasets directory.
        """
        datasets_dir = tmp_path / "datasets"
        datasets_dir.mkdir()
        return datasets_dir

    @pytest.fixture
    def sample_backup(self, datasets_dir: Path) -> Path:
        """Create a sample backup file.

        Returns:
            Path to the created backup file.
        """
        backup_file = datasets_dir / "ttvdrops-20260210-120000.sql.zst"
        with (
            backup_file.open("wb") as raw_handle,
            zstd.open(raw_handle, "w") as compressed,
            io.TextIOWrapper(compressed, encoding="utf-8") as handle,
        ):
            handle.write("-- Sample backup content\n")
        return backup_file

    def test_dataset_list_view_shows_backups(
        self,
        client: Client,
        datasets_dir: Path,
        sample_backup: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Test that dataset list view displays backup files."""
        # The view resolves DATA_DIR/datasets, so DATA_DIR is the parent.
        monkeypatch.setattr(settings, "DATA_DIR", datasets_dir.parent)
        response: _MonkeyPatchedWSGIResponse = client.get(
            reverse("core:dataset_backups"),
        )
        assert response.status_code == 200
        assert b"ttvdrops-20260210-120000.sql.zst" in response.content
        # Accept either singular or plural count wording.
        assert b"1 datasets" in response.content or b"1 dataset" in response.content

    def test_dataset_list_view_empty_directory(
        self,
        client: Client,
        datasets_dir: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Test dataset list view with empty directory."""
        monkeypatch.setattr(settings, "DATA_DIR", datasets_dir.parent)
        response: _MonkeyPatchedWSGIResponse = client.get(
            reverse("core:dataset_backups"),
        )
        assert response.status_code == 200
        assert b"No dataset backups found" in response.content

    def test_dataset_list_view_sorts_by_date(
        self,
        client: Client,
        datasets_dir: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Test that backups are sorted by modification time."""
        monkeypatch.setattr(settings, "DATA_DIR", datasets_dir.parent)
        # Create multiple backup files with different timestamps
        older_backup = datasets_dir / "ttvdrops-20260210-100000.sql.zst"
        newer_backup = datasets_dir / "ttvdrops-20260210-140000.sql.zst"
        for backup in [older_backup, newer_backup]:
            with (
                backup.open("wb") as raw_handle,
                zstd.open(raw_handle, "w") as compressed,
                io.TextIOWrapper(compressed, encoding="utf-8") as handle,
            ):
                handle.write("-- Test\n")
        # Set explicit modification times to ensure proper sorting.
        # The absolute epoch values are arbitrary (they do not match the
        # dates in the filenames) — the view sorts by mtime, so only the
        # relative order (older < newer) matters here.
        older_time = 1707561600
        newer_time = 1707575400
        os.utime(older_backup, (older_time, older_time))
        os.utime(newer_backup, (newer_time, newer_time))
        response: _MonkeyPatchedWSGIResponse = client.get(
            reverse("core:dataset_backups"),
        )
        content = response.content.decode()
        newer_pos = content.find("20260210-140000")
        older_pos = content.find("20260210-100000")
        # Newer backup should appear first (sorted descending)
        assert 0 < newer_pos < older_pos

    def test_dataset_download_view_success(
        self,
        client: Client,
        datasets_dir: Path,
        sample_backup: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Test successful backup download."""
        monkeypatch.setattr(settings, "DATA_DIR", datasets_dir.parent)
        response: _MonkeyPatchedWSGIResponse = client.get(
            reverse(
                "core:dataset_backup_download",
                args=["ttvdrops-20260210-120000.sql.zst"],
            ),
        )
        assert response.status_code == 200
        # FileResponse may use application/x-compressed for .zst files,
        # so only the disposition header is asserted, not Content-Type.
        assert "attachment" in response["Content-Disposition"]
        assert "ttvdrops-20260210-120000.sql.zst" in response["Content-Disposition"]

    def test_dataset_download_prevents_path_traversal(
        self,
        client: Client,
        datasets_dir: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Test that path traversal attempts are blocked."""
        monkeypatch.setattr(settings, "DATA_DIR", datasets_dir.parent)
        # Attempt path traversal
        response = client.get(
            reverse("core:dataset_backup_download", args=["../../../etc/passwd"]),
        )
        assert response.status_code == 404

    def test_dataset_download_rejects_invalid_extensions(
        self,
        client: Client,
        datasets_dir: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Test that files with invalid extensions cannot be downloaded."""
        monkeypatch.setattr(settings, "DATA_DIR", datasets_dir.parent)
        # Create a file with invalid extension
        invalid_file = datasets_dir / "malicious.exe"
        invalid_file.write_text("not a backup")
        response = client.get(
            reverse("core:dataset_backup_download", args=["malicious.exe"]),
        )
        assert response.status_code == 404

    def test_dataset_download_file_not_found(
        self,
        client: Client,
        datasets_dir: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Test download returns 404 for non-existent file."""
        monkeypatch.setattr(settings, "DATA_DIR", datasets_dir.parent)
        response = client.get(
            reverse("core:dataset_backup_download", args=["nonexistent.sql.zst"]),
        )
        assert response.status_code == 404

    def test_dataset_list_view_shows_file_sizes(
        self,
        client: Client,
        datasets_dir: Path,
        sample_backup: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Test that file sizes are displayed in human-readable format."""
        monkeypatch.setattr(settings, "DATA_DIR", datasets_dir.parent)
        response: _MonkeyPatchedWSGIResponse = client.get(
            reverse("core:dataset_backups"),
        )
        assert response.status_code == 200
        # Should contain size information (bytes, KB, MB, or GB)
        content = response.content.decode()
        assert any(unit in content for unit in ["bytes", "KB", "MB", "GB"])

    def test_dataset_list_ignores_non_zst_files(
        self,
        client: Client,
        datasets_dir: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Test that non-zst files are ignored in listing."""
        monkeypatch.setattr(settings, "DATA_DIR", datasets_dir.parent)
        # Create various file types
        (datasets_dir / "backup.sql.zst").write_bytes(b"valid")
        (datasets_dir / "readme.txt").write_text("should be ignored")
        (datasets_dir / "old_backup.gz").write_bytes(b"should be ignored")
        response: _MonkeyPatchedWSGIResponse = client.get(
            reverse("core:dataset_backups"),
        )
        content = response.content.decode()
        assert "backup.sql.zst" in content
        assert "readme.txt" not in content
        assert "old_backup.gz" not in content

    def test_dataset_download_view_handles_subdirectories(
        self,
        client: Client,
        datasets_dir: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Test download works with files in subdirectories."""
        monkeypatch.setattr(settings, "DATA_DIR", datasets_dir.parent)
        # Create subdirectory with backup
        subdir = datasets_dir / "2026" / "02"
        subdir.mkdir(parents=True)
        backup_file = subdir / "backup.sql.zst"
        with (
            backup_file.open("wb") as raw_handle,
            zstd.open(raw_handle, "w") as compressed,
            io.TextIOWrapper(compressed, encoding="utf-8") as handle,
        ):
            handle.write("-- Test\n")
        # Relative path with forward slashes is passed straight to the URL.
        response: _MonkeyPatchedWSGIResponse = client.get(
            reverse("core:dataset_backup_download", args=["2026/02/backup.sql.zst"]),
        )
        assert response.status_code == 200
        assert "attachment" in response["Content-Disposition"]