diff --git a/twitch/management/commands/backup_db.py b/twitch/management/commands/backup_db.py index 7c4e587..b1b380d 100644 --- a/twitch/management/commands/backup_db.py +++ b/twitch/management/commands/backup_db.py @@ -1,4 +1,6 @@ +import csv import io +import json import os import shutil import subprocess # noqa: S404 @@ -82,6 +84,16 @@ class Command(BaseCommand): msg = f"Unsupported database backend: {django_connection.vendor}" raise CommandError(msg) + json_path: Path = output_dir / f"{prefix}-{timestamp}.json.zst" + _write_json_dump(json_path, allowed_tables) + + csv_paths: list[Path] = _write_csv_dumps( + output_dir, + prefix, + timestamp, + allowed_tables, + ) + created_at: datetime = datetime.fromtimestamp( output_path.stat().st_mtime, tz=timezone.get_current_timezone(), @@ -91,6 +103,10 @@ class Command(BaseCommand): f"Backup created: {output_path} (updated {created_at.isoformat()})", ), ) + self.stdout.write(self.style.SUCCESS(f"JSON backup created: {json_path}")) + self.stdout.write( + self.style.SUCCESS(f"CSV backups created: {len(csv_paths)} files"), + ) self.stdout.write(self.style.SUCCESS(f"Included tables: {len(allowed_tables)}")) @@ -298,3 +314,77 @@ def _sql_literal(value: object) -> str: if isinstance(value, bytes): return "X'" + value.hex() + "'" return "'" + str(value).replace("'", "''") + "'" + + +def _json_default(value: object) -> str: + """Convert non-serializable values to JSON-compatible strings. + + Args: + value: Value to convert. + + Returns: + String representation. + """ + if isinstance(value, (bytes, memoryview)): # bytea arrives as memoryview + return bytes(value).hex() + return str(value) + + +def _write_json_dump(output_path: Path, tables: list[str]) -> None: + """Write a JSON dump of all tables into a zstd-compressed file. + + Args: + output_path: Destination path for the zstd file. + tables: Table names to include.
+ """ + data: dict[str, list[dict]] = {} + with django_connection.cursor() as cursor: + for table in tables: + cursor.execute(f'SELECT * FROM "{table}"') # noqa: S608 + columns: list[str] = [col[0] for col in cursor.description] + rows = cursor.fetchall() + data[table] = [dict(zip(columns, row, strict=False)) for row in rows] + + with ( + output_path.open("wb") as raw_handle, + zstd.open(raw_handle, "w") as compressed, + io.TextIOWrapper(compressed, encoding="utf-8") as handle, + ): + json.dump(data, handle, default=_json_default) + + +def _write_csv_dumps( + output_dir: Path, + prefix: str, + timestamp: str, + tables: list[str], +) -> list[Path]: + """Write per-table CSV files into zstd-compressed files. + + Args: + output_dir: Directory where CSV files will be written. + prefix: Filename prefix. + timestamp: Timestamp string for filenames. + tables: Table names to include. + + Returns: + List of created file paths. + """ + paths: list[Path] = [] + with django_connection.cursor() as cursor: + for table in tables: + cursor.execute(f'SELECT * FROM "{table}"') # noqa: S608 + columns: list[str] = [col[0] for col in cursor.description] + rows: list[tuple] = cursor.fetchall() + + output_path: Path = output_dir / f"{prefix}-{timestamp}-{table}.csv.zst" + with ( + output_path.open("wb") as raw_handle, + zstd.open(raw_handle, "w") as compressed, + io.TextIOWrapper(compressed, encoding="utf-8") as handle, + ): + writer: csv.Writer = csv.writer(handle) + writer.writerow(columns) + writer.writerows(rows) + paths.append(output_path) + return paths diff --git a/twitch/tests/test_backup.py b/twitch/tests/test_backup.py index 6bc0b0d..40e88ca 100644 --- a/twitch/tests/test_backup.py +++ b/twitch/tests/test_backup.py @@ -1,8 +1,11 @@ +import csv import io +import json import math import os import shutil from compression import zstd +from datetime import datetime as dt from typing import TYPE_CHECKING import pytest @@ -12,13 +15,18 @@ from django.db import connection from 
django.urls import reverse from twitch.management.commands.backup_db import _get_allowed_tables +from twitch.management.commands.backup_db import _json_default from twitch.management.commands.backup_db import _sql_literal +from twitch.management.commands.backup_db import _write_csv_dumps +from twitch.management.commands.backup_db import _write_json_dump from twitch.management.commands.backup_db import _write_postgres_dump from twitch.management.commands.backup_db import _write_sqlite_dump from twitch.models import Game from twitch.models import Organization if TYPE_CHECKING: + from csv import Reader + from datetime import datetime from pathlib import Path from django.test import Client @@ -164,6 +172,59 @@ class TestBackupCommand: backup_files = list(datasets_dir.glob("ttvdrops-*.sql.zst")) assert len(backup_files) >= 1 + def test_backup_creates_json_file(self, tmp_path: Path) -> None: + """Test that backup command creates a JSON file alongside the SQL dump.""" + _skip_if_pg_dump_missing() + Organization.objects.create(twitch_id="test_json", name="Test Org JSON") + + output_dir: Path = tmp_path / "backups" + output_dir.mkdir() + + call_command("backup_db", output_dir=str(output_dir), prefix="test") + + json_files: list[Path] = list(output_dir.glob("test-*.json.zst")) + assert len(json_files) == 1 + + with ( + json_files[0].open("rb") as raw_handle, + zstd.open(raw_handle, "r") as compressed, + io.TextIOWrapper(compressed, encoding="utf-8") as handle, + ): + data = json.load(handle) + + assert isinstance(data, dict) + assert "twitch_organization" in data + assert any( + row.get("name") == "Test Org JSON" for row in data["twitch_organization"] + ) + + def test_backup_creates_csv_files(self, tmp_path: Path) -> None: + """Test that backup command creates per-table CSV files alongside the SQL dump.""" + _skip_if_pg_dump_missing() + Organization.objects.create(twitch_id="test_csv", name="Test Org CSV") + + output_dir: Path = tmp_path / "backups" + output_dir.mkdir() + + 
call_command("backup_db", output_dir=str(output_dir), prefix="test") + + org_csv_files: list[Path] = list( + output_dir.glob("test-*-twitch_organization.csv.zst"), + ) + assert len(org_csv_files) == 1 + + with ( + org_csv_files[0].open("rb") as raw_handle, + zstd.open(raw_handle, "r") as compressed, + io.TextIOWrapper(compressed, encoding="utf-8") as handle, + ): + reader: Reader = csv.reader(handle) + rows: list[list[str]] = list(reader) + + assert len(rows) >= 2 # header + at least one data row + assert "name" in rows[0] + assert any("Test Org CSV" in row for row in rows[1:]) + @pytest.mark.django_db class TestBackupHelperFunctions: @@ -250,6 +311,71 @@ class TestBackupHelperFunctions: assert "INSERT INTO" in content assert "Write Test Org" in content + def test_write_json_dump_creates_valid_json(self, tmp_path: Path) -> None: + """Test _write_json_dump creates valid compressed JSON with all tables.""" + Organization.objects.create( + twitch_id="test_json_helper", + name="JSON Helper Org", + ) + + tables: list[str] = _get_allowed_tables("twitch_") + output_path: Path = tmp_path / "backup.json.zst" + _write_json_dump(output_path, tables) + + with ( + output_path.open("rb") as raw_handle, + zstd.open(raw_handle, "r") as compressed, + io.TextIOWrapper(compressed, encoding="utf-8") as handle, + ): + data = json.load(handle) + + assert isinstance(data, dict) + assert "twitch_organization" in data + assert all(table in data for table in tables) + assert any( + row.get("name") == "JSON Helper Org" for row in data["twitch_organization"] + ) + + def test_write_csv_dumps_creates_per_table_files(self, tmp_path: Path) -> None: + """Test _write_csv_dumps creates one compressed CSV file per table.""" + Organization.objects.create(twitch_id="test_csv_helper", name="CSV Helper Org") + + tables: list[str] = _get_allowed_tables("twitch_") + paths: list[Path] = _write_csv_dumps( + tmp_path, + "test", + "20260317-120000", + tables, + ) + + assert len(paths) == len(tables) + assert 
all(p.exists() for p in paths) + + org_csv: Path = tmp_path / "test-20260317-120000-twitch_organization.csv.zst" + assert org_csv.exists() + + with ( + org_csv.open("rb") as raw_handle, + zstd.open(raw_handle, "r") as compressed, + io.TextIOWrapper(compressed, encoding="utf-8") as handle, + ): + reader: Reader = csv.reader(handle) + rows: list[list[str]] = list(reader) + + assert len(rows) >= 2 # header + at least one data row + assert "name" in rows[0] + assert any("CSV Helper Org" in row for row in rows[1:]) + + def test_json_default_handles_bytes(self) -> None: + """Test _json_default converts bytes to hex string.""" + assert _json_default(b"\x00\x01") == "0001" + assert _json_default(b"hello") == "68656c6c6f" + + def test_json_default_handles_other_types(self) -> None: + """Test _json_default falls back to str() for other types.""" + value: datetime = dt(2026, 3, 17, 12, 0, 0, tzinfo=dt.now().astimezone().tzinfo) + assert _json_default(value) == str(value) + @pytest.mark.django_db class TestDatasetBackupViews: