# Tests for the backup_db management command and the dataset backup views.
from __future__ import annotations
|
|
|
|
import io
|
|
import math
|
|
import os
|
|
from compression import zstd
|
|
from typing import TYPE_CHECKING
|
|
|
|
import pytest
|
|
from django.conf import settings
|
|
from django.core.management import call_command
|
|
from django.db import connection
|
|
from django.urls import reverse
|
|
|
|
from twitch.management.commands.backup_db import _get_allowed_tables
|
|
from twitch.management.commands.backup_db import _sql_literal
|
|
from twitch.management.commands.backup_db import _write_dump
|
|
from twitch.models import Game
|
|
from twitch.models import Organization
|
|
|
|
if TYPE_CHECKING:
|
|
from pathlib import Path
|
|
|
|
from django.test import Client
|
|
from django.test.client import _MonkeyPatchedWSGIResponse
|
|
|
|
|
|
@pytest.mark.django_db
class TestBackupCommand:
    """Exercise the backup_db management command end to end."""

    @staticmethod
    def _read_backup(backup_path: Path) -> str:
        """Decompress the zstd backup at *backup_path* and return its text."""
        with (
            backup_path.open("rb") as raw,
            zstd.open(raw, "r") as stream,
            io.TextIOWrapper(stream, encoding="utf-8") as reader,
        ):
            return reader.read()

    def test_backup_creates_file(self, tmp_path: Path) -> None:
        """The command should emit exactly one non-empty .sql.zst file."""
        # Seed the database so the twitch_ tables exist.
        Organization.objects.create(twitch_id="test000", name="Test Org")

        target = tmp_path / "backups"
        target.mkdir()

        call_command("backup_db", output_dir=str(target), prefix="test")

        dumps = list(target.glob("test-*.sql.zst"))
        assert len(dumps) == 1
        dump = dumps[0]
        assert dump.exists()
        assert dump.stat().st_size > 0

    def test_backup_contains_sql_content(self, tmp_path: Path) -> None:
        """The dump should be a transaction-wrapped SQL snapshot of the data."""
        target = tmp_path / "backups"
        target.mkdir()

        # Seed related rows so both tables and the M2M link are dumped.
        org = Organization.objects.create(twitch_id="test123", name="Test Org")
        game = Game.objects.create(twitch_id="game456", display_name="Test Game")
        game.owners.add(org)

        call_command("backup_db", output_dir=str(target), prefix="test")

        dump = next(iter(target.glob("test-*.sql.zst")))
        content = self._read_backup(dump)

        for expected in (
            "PRAGMA foreign_keys=OFF;",
            "BEGIN TRANSACTION;",
            "COMMIT;",
            "twitch_organization",
            "twitch_game",
            "Test Org",
        ):
            assert expected in content

    def test_backup_excludes_non_twitch_tables(self, tmp_path: Path) -> None:
        """Only twitch_-prefixed tables should end up in the dump."""
        # Seed the database so the twitch_ tables exist.
        Organization.objects.create(twitch_id="test001", name="Test Org")

        target = tmp_path / "backups"
        target.mkdir()

        call_command("backup_db", output_dir=str(target), prefix="test")

        content = self._read_backup(next(iter(target.glob("test-*.sql.zst"))))

        # Admin/profiling tables must never leak into the public dataset.
        assert "django_session" not in content
        assert "silk_" not in content
        assert "debug_toolbar_" not in content
        assert "django_admin_log" not in content

        # The application's own tables must be present.
        assert "twitch_" in content

    def test_backup_with_custom_prefix(self, tmp_path: Path) -> None:
        """A caller-supplied prefix should be reflected in the filename."""
        # Seed the database so the twitch_ tables exist.
        Organization.objects.create(twitch_id="test002", name="Test Org")

        target = tmp_path / "backups"
        target.mkdir()

        call_command("backup_db", output_dir=str(target), prefix="custom")

        assert len(list(target.glob("custom-*.sql.zst"))) == 1

    def test_backup_creates_output_directory(self, tmp_path: Path) -> None:
        """A missing output directory should be created rather than erroring."""
        # Seed the database so the twitch_ tables exist.
        Organization.objects.create(twitch_id="test003", name="Test Org")

        target = tmp_path / "nonexistent" / "backups"

        call_command("backup_db", output_dir=str(target), prefix="test")

        assert target.exists()
        assert len(list(target.glob("test-*.sql.zst"))) == 1

    def test_backup_uses_default_directory(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
        """Without an explicit output dir the command writes to DATA_DIR/datasets."""
        # Seed the database so the twitch_ tables exist.
        Organization.objects.create(twitch_id="test004", name="Test Org")

        monkeypatch.setattr(settings, "DATA_DIR", tmp_path)
        datasets_dir = tmp_path / "datasets"
        datasets_dir.mkdir(exist_ok=True, parents=True)

        call_command("backup_db")

        assert len(list(datasets_dir.glob("ttvdrops-*.sql.zst"))) >= 1
|
|
|
|
|
|
@pytest.mark.django_db
class TestBackupHelperFunctions:
    """Unit tests for the backup command's helper functions."""

    def test_get_allowed_tables_filters_by_prefix(self) -> None:
        """_get_allowed_tables returns only tables matching the prefix."""
        # Raw DB-API connection backing Django's default test database.
        raw_connection = connection.connection
        tables = _get_allowed_tables(raw_connection, "twitch_")

        assert tables
        assert all(name.startswith("twitch_") for name in tables)
        assert "twitch_organization" in tables
        assert "twitch_game" in tables

    def test_get_allowed_tables_excludes_non_matching(self) -> None:
        """_get_allowed_tables drops tables outside the prefix."""
        # Raw DB-API connection backing Django's default test database.
        raw_connection = connection.connection
        tables = _get_allowed_tables(raw_connection, "twitch_")

        # Framework/profiling tables must be filtered out.
        for forbidden_prefix in ("django_", "silk_", "debug_toolbar_"):
            assert not any(name.startswith(forbidden_prefix) for name in tables)

    def test_sql_literal_handles_none(self) -> None:
        """None maps to the SQL NULL keyword."""
        assert _sql_literal(None) == "NULL"

    def test_sql_literal_handles_booleans(self) -> None:
        """Booleans map to SQLite's 1/0 integer representation."""
        assert _sql_literal(True) == "1"
        assert _sql_literal(False) == "0"

    def test_sql_literal_handles_numbers(self) -> None:
        """Ints and floats are rendered via their str() form."""
        assert _sql_literal(42) == "42"
        assert _sql_literal(math.pi) == str(math.pi)

    def test_sql_literal_handles_strings(self) -> None:
        """Strings are single-quoted with embedded quotes doubled."""
        assert _sql_literal("test") == "'test'"
        assert _sql_literal("o'reilly") == "'o''reilly'"
        assert _sql_literal("test\nline") == "'test\nline'"

    def test_sql_literal_handles_bytes(self) -> None:
        """Bytes become SQLite X'...' hex blob literals."""
        assert _sql_literal(b"\x00\x01\x02") == "X'000102'"
        assert _sql_literal(b"hello") == "X'68656c6c6f'"

    def test_write_dump_includes_schema_and_data(self) -> None:
        """_write_dump emits pragmas, schema, and row data for each table."""
        Organization.objects.create(twitch_id="test789", name="Write Test Org")

        # Raw DB-API connection backing Django's default test database.
        raw_connection = connection.connection
        buffer = io.StringIO()

        allowed = _get_allowed_tables(raw_connection, "twitch_")
        _write_dump(buffer, raw_connection, allowed)

        content = buffer.getvalue()

        # Transaction framing and foreign-key pragmas.
        for marker in (
            "PRAGMA foreign_keys=OFF;",
            "BEGIN TRANSACTION;",
            "COMMIT;",
            "PRAGMA foreign_keys=ON;",
        ):
            assert marker in content

        # Schema statements.
        assert "CREATE TABLE" in content
        assert "twitch_organization" in content

        # Row data for the seeded organization.
        assert "INSERT INTO" in content
        assert "Write Test Org" in content
|
|
|
|
|
|
@pytest.mark.django_db
class TestDatasetBackupViews:
    """Tests for dataset backup list and download views."""

    @staticmethod
    def _write_zst(path: Path, text: str) -> None:
        """Write *text* to *path* as a zstd-compressed UTF-8 stream."""
        with (
            path.open("wb") as raw,
            zstd.open(raw, "w") as stream,
            io.TextIOWrapper(stream, encoding="utf-8") as writer,
        ):
            writer.write(text)

    @pytest.fixture
    def datasets_dir(self, tmp_path: Path) -> Path:
        """Create a temporary datasets directory.

        Returns:
            Path to the created datasets directory.
        """
        created = tmp_path / "datasets"
        created.mkdir()
        return created

    @pytest.fixture
    def sample_backup(self, datasets_dir: Path) -> Path:
        """Create a sample backup file.

        Returns:
            Path to the created backup file.
        """
        backup_path = datasets_dir / "ttvdrops-20260210-120000.sql.zst"
        self._write_zst(backup_path, "-- Sample backup content\n")
        return backup_path

    def test_dataset_list_view_shows_backups(
        self,
        client: Client,
        datasets_dir: Path,
        sample_backup: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """The list view renders each backup file it finds."""
        monkeypatch.setattr(settings, "DATA_DIR", datasets_dir.parent)

        response: _MonkeyPatchedWSGIResponse = client.get(reverse("twitch:dataset_backups"))

        assert response.status_code == 200
        assert b"ttvdrops-20260210-120000.sql.zst" in response.content
        # Accept either the singular or plural count label.
        assert b"1 datasets" in response.content or b"1 dataset" in response.content

    def test_dataset_list_view_empty_directory(
        self,
        client: Client,
        datasets_dir: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """An empty datasets directory renders the placeholder message."""
        monkeypatch.setattr(settings, "DATA_DIR", datasets_dir.parent)

        response: _MonkeyPatchedWSGIResponse = client.get(reverse("twitch:dataset_backups"))

        assert response.status_code == 200
        assert b"No dataset backups found" in response.content

    def test_dataset_list_view_sorts_by_date(
        self,
        client: Client,
        datasets_dir: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Backups are listed newest first (descending modification time)."""
        monkeypatch.setattr(settings, "DATA_DIR", datasets_dir.parent)

        older_backup = datasets_dir / "ttvdrops-20260210-100000.sql.zst"
        newer_backup = datasets_dir / "ttvdrops-20260210-140000.sql.zst"
        for path in (older_backup, newer_backup):
            self._write_zst(path, "-- Test\n")

        # Pin distinct mtimes so the sort order is deterministic.
        earlier, later = 1707561600, 1707575400
        os.utime(older_backup, (earlier, earlier))
        os.utime(newer_backup, (later, later))

        response: _MonkeyPatchedWSGIResponse = client.get(reverse("twitch:dataset_backups"))

        page = response.content.decode()
        newer_at = page.find("20260210-140000")
        older_at = page.find("20260210-100000")

        # The newer backup must appear before the older one.
        assert 0 < newer_at < older_at

    def test_dataset_download_view_success(
        self,
        client: Client,
        datasets_dir: Path,
        sample_backup: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Downloading an existing backup returns it as an attachment."""
        monkeypatch.setattr(settings, "DATA_DIR", datasets_dir.parent)

        response: _MonkeyPatchedWSGIResponse = client.get(
            reverse("twitch:dataset_backup_download", args=["ttvdrops-20260210-120000.sql.zst"]),
        )

        assert response.status_code == 200
        # Content-Type for .zst varies by platform, so only check the disposition.
        disposition = response["Content-Disposition"]
        assert "attachment" in disposition
        assert "ttvdrops-20260210-120000.sql.zst" in disposition

    def test_dataset_download_prevents_path_traversal(
        self,
        client: Client,
        datasets_dir: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Path traversal attempts must be rejected with a 404."""
        monkeypatch.setattr(settings, "DATA_DIR", datasets_dir.parent)

        response = client.get(reverse("twitch:dataset_backup_download", args=["../../../etc/passwd"]))
        assert response.status_code == 404

    def test_dataset_download_rejects_invalid_extensions(
        self,
        client: Client,
        datasets_dir: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Files without the backup extension cannot be downloaded."""
        monkeypatch.setattr(settings, "DATA_DIR", datasets_dir.parent)

        # Even though the file exists on disk, the view must refuse it.
        (datasets_dir / "malicious.exe").write_text("not a backup")

        response = client.get(reverse("twitch:dataset_backup_download", args=["malicious.exe"]))
        assert response.status_code == 404

    def test_dataset_download_file_not_found(
        self,
        client: Client,
        datasets_dir: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """A non-existent backup name yields a 404."""
        monkeypatch.setattr(settings, "DATA_DIR", datasets_dir.parent)

        response = client.get(reverse("twitch:dataset_backup_download", args=["nonexistent.sql.zst"]))
        assert response.status_code == 404

    def test_dataset_list_view_shows_file_sizes(
        self,
        client: Client,
        datasets_dir: Path,
        sample_backup: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """File sizes are displayed in a human-readable unit."""
        monkeypatch.setattr(settings, "DATA_DIR", datasets_dir.parent)

        response: _MonkeyPatchedWSGIResponse = client.get(reverse("twitch:dataset_backups"))

        assert response.status_code == 200
        page = response.content.decode()
        assert any(unit in page for unit in ["bytes", "KB", "MB", "GB"])

    def test_dataset_list_ignores_non_zst_files(
        self,
        client: Client,
        datasets_dir: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Non-zst files in the datasets directory are excluded from the listing."""
        monkeypatch.setattr(settings, "DATA_DIR", datasets_dir.parent)

        # Mix a valid backup with files the view should skip.
        (datasets_dir / "backup.sql.zst").write_bytes(b"valid")
        (datasets_dir / "readme.txt").write_text("should be ignored")
        (datasets_dir / "old_backup.gz").write_bytes(b"should be ignored")

        response: _MonkeyPatchedWSGIResponse = client.get(reverse("twitch:dataset_backups"))

        page = response.content.decode()
        assert "backup.sql.zst" in page
        assert "readme.txt" not in page
        assert "old_backup.gz" not in page

    def test_dataset_download_view_handles_subdirectories(
        self,
        client: Client,
        datasets_dir: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Backups nested in subdirectories can still be downloaded."""
        monkeypatch.setattr(settings, "DATA_DIR", datasets_dir.parent)

        nested = datasets_dir / "2026" / "02"
        nested.mkdir(parents=True)
        self._write_zst(nested / "backup.sql.zst", "-- Test\n")

        response: _MonkeyPatchedWSGIResponse = client.get(
            reverse("twitch:dataset_backup_download", args=["2026/02/backup.sql.zst"]),
        )

        assert response.status_code == 200
        assert "attachment" in response["Content-Disposition"]
|