Improve performance and add type hints

This commit is contained in:
Joakim Hellsén 2026-04-11 00:44:16 +02:00
commit b7e10e766e
Signed by: Joakim Hellsén
SSH key fingerprint: SHA256:/9h/CsExpFp+PRhsfA0xznFx2CGfTT5R/kpuFfUgEQk
23 changed files with 745 additions and 178 deletions

View file

@ -5,7 +5,7 @@ import logging
from celery import shared_task from celery import shared_task
from django.core.management import call_command from django.core.management import call_command
logger = logging.getLogger("ttvdrops.tasks") logger: logging.Logger = logging.getLogger("ttvdrops.tasks")
@shared_task(bind=True, queue="imports", max_retries=3, default_retry_delay=60) @shared_task(bind=True, queue="imports", max_retries=3, default_retry_delay=60)

View file

@ -136,7 +136,7 @@ class ImportChzzkCampaignRangeCommandTest(TestCase):
stdout = StringIO() stdout = StringIO()
stderr = StringIO() stderr = StringIO()
def side_effect(command: str, *args: str, **kwargs: object) -> None: def side_effect(command: str, *args: str, **kwargs: StringIO) -> None:
if "4" in args: if "4" in args:
msg = "Campaign 4 not found" msg = "Campaign 4 not found"
raise CommandError(msg) raise CommandError(msg)

View file

@ -1,7 +1,6 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from django.db.models import Q from django.db.models import Q
from django.db.models.query import QuerySet
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.shortcuts import render from django.shortcuts import render
from django.urls import reverse from django.urls import reverse
@ -16,9 +15,9 @@ from twitch.feeds import TTVDropsBaseFeed
if TYPE_CHECKING: if TYPE_CHECKING:
import datetime import datetime
from django.db.models.query import QuerySet
from django.http import HttpResponse from django.http import HttpResponse
from django.http.request import HttpRequest from django.http.request import HttpRequest
from pytest_django.asserts import QuerySet
def dashboard_view(request: HttpRequest) -> HttpResponse: def dashboard_view(request: HttpRequest) -> HttpResponse:

View file

@ -224,7 +224,7 @@ DATABASES: dict[str, dict[str, Any]] = configure_databases(
base_dir=BASE_DIR, base_dir=BASE_DIR,
) )
if DEBUG: if DEBUG or TESTING:
INSTALLED_APPS.append("zeal") INSTALLED_APPS.append("zeal")
MIDDLEWARE.append("zeal.middleware.zeal_middleware") MIDDLEWARE.append("zeal.middleware.zeal_middleware")

25
conftest.py Normal file
View file

@ -0,0 +1,25 @@
from typing import TYPE_CHECKING
import pytest
from zeal import zeal_context
if TYPE_CHECKING:
from collections.abc import Generator
@pytest.fixture(autouse=True)
def use_zeal(request: pytest.FixtureRequest) -> Generator[None]:
"""Enable Zeal N+1 detection context for each pytest test.
Use @pytest.mark.no_zeal for tests that intentionally exercise import paths
where Zeal's strict get() heuristics are too noisy.
Yields:
None: Control back to pytest for test execution.
"""
if request.node.get_closest_marker("no_zeal") is not None:
yield
return
with zeal_context():
yield

View file

@ -69,7 +69,7 @@ class _TTVDropsSite:
domain: str domain: str
def get_current_site(request: object) -> _TTVDropsSite: def get_current_site(request: HttpRequest | None) -> _TTVDropsSite:
"""Return a site-like object with domain derived from BASE_URL.""" """Return a site-like object with domain derived from BASE_URL."""
base_url: str = _get_base_url() base_url: str = _get_base_url()
parts: SplitResult = urlsplit(base_url) parts: SplitResult = urlsplit(base_url)

View file

@ -5,7 +5,7 @@ import logging
from celery import shared_task from celery import shared_task
from django.core.management import call_command from django.core.management import call_command
logger = logging.getLogger("ttvdrops.tasks") logger: logging.Logger = logging.getLogger("ttvdrops.tasks")
@shared_task(bind=True, queue="default", max_retries=3, default_retry_delay=300) @shared_task(bind=True, queue="default", max_retries=3, default_retry_delay=300)

View file

@ -5,6 +5,7 @@ from django.urls import reverse
if TYPE_CHECKING: if TYPE_CHECKING:
from django.test.client import Client from django.test.client import Client
from pytest_django.fixtures import SettingsWrapper
def _extract_locs(xml_bytes: bytes) -> list[str]: def _extract_locs(xml_bytes: bytes) -> list[str]:
@ -15,7 +16,7 @@ def _extract_locs(xml_bytes: bytes) -> list[str]:
def test_sitemap_static_contains_expected_links( def test_sitemap_static_contains_expected_links(
client: Client, client: Client,
settings: object, settings: SettingsWrapper,
) -> None: ) -> None:
"""Ensure the static sitemap contains the main site links across apps. """Ensure the static sitemap contains the main site links across apps.

View file

@ -15,11 +15,9 @@ from django.db.models import Max
from django.db.models import OuterRef from django.db.models import OuterRef
from django.db.models import Prefetch from django.db.models import Prefetch
from django.db.models import Q from django.db.models import Q
from django.db.models import QuerySet
from django.db.models.functions import Trim from django.db.models.functions import Trim
from django.http import FileResponse from django.http import FileResponse
from django.http import Http404 from django.http import Http404
from django.http import HttpRequest
from django.http import HttpResponse from django.http import HttpResponse
from django.shortcuts import render from django.shortcuts import render
from django.template.defaultfilters import filesizeformat from django.template.defaultfilters import filesizeformat

View file

@ -206,8 +206,8 @@ class KickOrganizationFeed(TTVDropsBaseFeed):
def __call__( def __call__(
self, self,
request: HttpRequest, request: HttpRequest,
*args: object, *args: str | int,
**kwargs: object, **kwargs: str | int,
) -> HttpResponse: ) -> HttpResponse:
"""Capture optional ?limit query parameter. """Capture optional ?limit query parameter.
@ -283,8 +283,8 @@ class KickCategoryFeed(TTVDropsBaseFeed):
def __call__( def __call__(
self, self,
request: HttpRequest, request: HttpRequest,
*args: object, *args: str | int,
**kwargs: object, **kwargs: str | int,
) -> HttpResponse: ) -> HttpResponse:
"""Capture optional ?limit query parameter. """Capture optional ?limit query parameter.
@ -372,8 +372,8 @@ class KickCampaignFeed(TTVDropsBaseFeed):
def __call__( def __call__(
self, self,
request: HttpRequest, request: HttpRequest,
*args: object, *args: str | int,
**kwargs: object, **kwargs: str | int,
) -> HttpResponse: ) -> HttpResponse:
"""Capture optional ?limit query parameter. """Capture optional ?limit query parameter.
@ -481,8 +481,8 @@ class KickCategoryCampaignFeed(TTVDropsBaseFeed):
def __call__( def __call__(
self, self,
request: HttpRequest, request: HttpRequest,
*args: object, *args: str | int,
**kwargs: object, **kwargs: str | int,
) -> HttpResponse: ) -> HttpResponse:
"""Capture optional ?limit query parameter. """Capture optional ?limit query parameter.

View file

@ -1,4 +1,7 @@
from __future__ import annotations
import logging import logging
from datetime import datetime
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import httpx import httpx
@ -14,6 +17,8 @@ from kick.models import KickUser
from kick.schemas import KickDropsResponseSchema from kick.schemas import KickDropsResponseSchema
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Mapping
from django.core.management.base import CommandParser from django.core.management.base import CommandParser
from kick.schemas import KickCategorySchema from kick.schemas import KickCategorySchema
@ -23,6 +28,26 @@ if TYPE_CHECKING:
logger: logging.Logger = logging.getLogger("ttvdrops") logger: logging.Logger = logging.getLogger("ttvdrops")
type KickImportModel = (
KickOrganization
| KickCategory
| KickDropCampaign
| KickUser
| KickChannel
| KickReward
)
type KickFieldValue = (
str
| bool
| int
| datetime
| KickOrganization
| KickCategory
| KickDropCampaign
| KickUser
| None
)
KICK_DROPS_API_URL = "https://web.kick.com/api/v1/drops/campaigns" KICK_DROPS_API_URL = "https://web.kick.com/api/v1/drops/campaigns"
# Kick's public API requires a browser-like User-Agent. # Kick's public API requires a browser-like User-Agent.
@ -48,7 +73,26 @@ class Command(BaseCommand):
help="API endpoint to fetch (default: %(default)s).", help="API endpoint to fetch (default: %(default)s).",
) )
def handle(self, *args: object, **options: object) -> None: # noqa: ARG002 @staticmethod
def _save_if_changed(
obj: KickImportModel,
defaults: Mapping[str, KickFieldValue],
) -> None:
"""Persist only changed fields to avoid unnecessary updates."""
changed_fields: list[str] = []
for field, new_value in defaults.items():
if getattr(obj, field, None) != new_value:
setattr(obj, field, new_value)
changed_fields.append(field)
if changed_fields:
obj.save(update_fields=changed_fields)
def handle(
self,
*_args: str,
**options: str | bool | int | None,
) -> None:
"""Main entry point for the command.""" """Main entry point for the command."""
url: str = str(options["url"]) url: str = str(options["url"])
self.stdout.write(f"Fetching Kick drops from {url} ...") self.stdout.write(f"Fetching Kick drops from {url} ...")
@ -99,54 +143,75 @@ class Command(BaseCommand):
self.style.SUCCESS(f"Imported {imported}/{len(campaigns)} campaign(s)."), self.style.SUCCESS(f"Imported {imported}/{len(campaigns)} campaign(s)."),
) )
def _import_campaign(self, data: KickDropCampaignSchema) -> None: def _import_campaign(self, data: KickDropCampaignSchema) -> None: # noqa: PLR0914, PLR0915
"""Import a single campaign and all its related objects.""" """Import a single campaign and all its related objects."""
# Organisation # Organization
org_data: KickOrganizationSchema = data.organization org_data: KickOrganizationSchema = data.organization
org, created = KickOrganization.objects.update_or_create( org_defaults: dict[str, str | bool] = {
"name": org_data.name,
"logo_url": org_data.logo_url,
"url": org_data.url,
"restricted": org_data.restricted,
}
org: KickOrganization | None = KickOrganization.objects.filter(
kick_id=org_data.id, kick_id=org_data.id,
defaults={ ).first()
"name": org_data.name, created: bool = org is None
"logo_url": org_data.logo_url, if org is None:
"url": org_data.url, org = KickOrganization.objects.create(kick_id=org_data.id, **org_defaults)
"restricted": org_data.restricted, else:
}, self._save_if_changed(org, org_defaults)
)
if created: if created:
logger.info("Created new organization: %s", org.kick_id) logger.info("Created new organization: %s", org.kick_id)
# Category # Category
cat_data: KickCategorySchema = data.category cat_data: KickCategorySchema = data.category
category, created = KickCategory.objects.update_or_create( category_defaults: dict[str, KickFieldValue] = {
"name": cat_data.name,
"slug": cat_data.slug,
"image_url": cat_data.image_url,
}
category: KickCategory | None = KickCategory.objects.filter(
kick_id=cat_data.id, kick_id=cat_data.id,
defaults={ ).first()
"name": cat_data.name, created = category is None
"slug": cat_data.slug, if category is None:
"image_url": cat_data.image_url, category = KickCategory.objects.create(
}, kick_id=cat_data.id,
) **category_defaults,
)
else:
self._save_if_changed(category, category_defaults)
if created: if created:
logger.info("Created new category: %s", category.kick_id) logger.info("Created new category: %s", category.kick_id)
# Campaign # Campaign
campaign, created = KickDropCampaign.objects.update_or_create( campaign_defaults: dict[str, KickFieldValue] = {
"name": data.name,
"status": data.status,
"starts_at": data.starts_at,
"ends_at": data.ends_at,
"connect_url": data.connect_url,
"url": data.url,
"rule_id": data.rule.id,
"rule_name": data.rule.name,
"organization": org,
"category": category,
"created_at": data.created_at,
"api_updated_at": data.updated_at,
"is_fully_imported": True,
}
campaign: KickDropCampaign | None = KickDropCampaign.objects.filter(
kick_id=data.id, kick_id=data.id,
defaults={ ).first()
"name": data.name, created = campaign is None
"status": data.status, if campaign is None:
"starts_at": data.starts_at, campaign = KickDropCampaign.objects.create(
"ends_at": data.ends_at, kick_id=data.id,
"connect_url": data.connect_url, **campaign_defaults,
"url": data.url, )
"rule_id": data.rule.id, else:
"rule_name": data.rule.name, self._save_if_changed(campaign, campaign_defaults)
"organization": org,
"category": category,
"created_at": data.created_at,
"api_updated_at": data.updated_at,
"is_fully_imported": True,
},
)
if created: if created:
logger.info("Created new campaign: %s", campaign.kick_id) logger.info("Created new campaign: %s", campaign.kick_id)
@ -154,25 +219,38 @@ class Command(BaseCommand):
channel_objs: list[KickChannel] = [] channel_objs: list[KickChannel] = []
for ch_data in data.channels: for ch_data in data.channels:
user_data: KickUserSchema = ch_data.user user_data: KickUserSchema = ch_data.user
user, created = KickUser.objects.update_or_create( user_defaults: dict[str, KickFieldValue] = {
"username": user_data.username,
"profile_picture": user_data.profile_picture,
}
user: KickUser | None = KickUser.objects.filter(
kick_id=user_data.id, kick_id=user_data.id,
defaults={ ).first()
"username": user_data.username, created = user is None
"profile_picture": user_data.profile_picture, if user is None:
}, user = KickUser.objects.create(kick_id=user_data.id, **user_defaults)
) else:
self._save_if_changed(user, user_defaults)
if created: if created:
logger.info("Created new user: %s", user.kick_id) logger.info("Created new user: %s", user.kick_id)
channel, created = KickChannel.objects.update_or_create( channel_defaults: dict[str, KickFieldValue] = {
"slug": ch_data.slug,
"description": ch_data.description,
"banner_picture_url": ch_data.banner_picture_url,
"user": user,
}
channel: KickChannel | None = KickChannel.objects.filter(
kick_id=ch_data.id, kick_id=ch_data.id,
defaults={ ).first()
"slug": ch_data.slug, created = channel is None
"description": ch_data.description, if channel is None:
"banner_picture_url": ch_data.banner_picture_url, channel = KickChannel.objects.create(
"user": user, kick_id=ch_data.id,
}, **channel_defaults,
) )
else:
self._save_if_changed(channel, channel_defaults)
if created: if created:
logger.info("Created new channel: %s", channel.kick_id) logger.info("Created new channel: %s", channel.kick_id)
@ -184,36 +262,46 @@ class Command(BaseCommand):
# Resolve reward's category (may differ from campaign category) # Resolve reward's category (may differ from campaign category)
reward_category: KickCategory = category reward_category: KickCategory = category
if reward_data.category_id != cat_data.id: if reward_data.category_id != cat_data.id:
reward_category, created = KickCategory.objects.get_or_create( reward_category = KickCategory.objects.filter(
kick_id=reward_data.category_id, kick_id=reward_data.category_id,
defaults={"name": "", "slug": "", "image_url": ""}, ).first() or KickCategory.objects.create(
kick_id=reward_data.category_id,
name="",
slug="",
image_url="",
) )
created = not reward_category.name and not reward_category.slug
if created: if created:
logger.info("Created new category: %s", reward_category.kick_id) logger.info("Created new category: %s", reward_category.kick_id)
# Resolve reward's organization (may differ from campaign org) # Resolve reward's organization (may differ from campaign org)
reward_org: KickOrganization = org reward_org: KickOrganization = org
if reward_data.organization_id != org_data.id: if reward_data.organization_id != org_data.id:
reward_org, created = KickOrganization.objects.get_or_create( reward_org = KickOrganization.objects.filter(
kick_id=reward_data.organization_id, kick_id=reward_data.organization_id,
defaults={ ).first() or KickOrganization.objects.create(
"name": "", kick_id=reward_data.organization_id,
"logo_url": "", name="",
"url": "", logo_url="",
"restricted": False, url="",
}, restricted=False,
) )
created = not reward_org.name and not reward_org.url
if created: if created:
logger.info("Created new organization: %s", reward_org.kick_id) logger.info("Created new organization: %s", reward_org.kick_id)
KickReward.objects.update_or_create( reward_defaults: dict[str, KickFieldValue] = {
"name": reward_data.name,
"image_url": reward_data.image_url,
"required_units": reward_data.required_units,
"campaign": campaign,
"category": reward_category,
"organization": reward_org,
}
reward: KickReward | None = KickReward.objects.filter(
kick_id=reward_data.id, kick_id=reward_data.id,
defaults={ ).first()
"name": reward_data.name, if reward is None:
"image_url": reward_data.image_url, KickReward.objects.create(kick_id=reward_data.id, **reward_defaults)
"required_units": reward_data.required_units, else:
"campaign": campaign, self._save_if_changed(reward, reward_defaults)
"category": reward_category,
"organization": reward_org,
},
)

View file

@ -669,20 +669,25 @@ class KickDashboardViewTest(TestCase):
class KickCampaignListViewTest(TestCase): class KickCampaignListViewTest(TestCase):
"""Tests for the kick campaign list view.""" """Tests for the kick campaign list view."""
@classmethod
def setUpTestData(cls) -> None:
"""Set up shared test data for campaign list view tests."""
cls.org: KickOrganization = KickOrganization.objects.create(
kick_id="org-list",
name="List Org",
)
cls.cat: KickCategory = KickCategory.objects.create(
kick_id=300,
name="List Cat",
slug="list-cat",
)
def _make_campaign( def _make_campaign(
self, self,
kick_id: str, kick_id: str,
name: str, name: str,
status: str = "active", status: str = "active",
) -> KickDropCampaign: ) -> KickDropCampaign:
org, _ = KickOrganization.objects.get_or_create(
kick_id="org-list",
defaults={"name": "List Org"},
)
cat, _ = KickCategory.objects.get_or_create(
kick_id=300,
defaults={"name": "List Cat", "slug": "list-cat"},
)
# Set dates so the active/expired filter works correctly # Set dates so the active/expired filter works correctly
if status == "active": if status == "active":
starts_at = dt(2020, 1, 1, tzinfo=UTC) starts_at = dt(2020, 1, 1, tzinfo=UTC)
@ -696,8 +701,8 @@ class KickCampaignListViewTest(TestCase):
status=status, status=status,
starts_at=starts_at, starts_at=starts_at,
ends_at=ends_at, ends_at=ends_at,
organization=org, organization=self.org,
category=cat, category=self.cat,
rule_id=1, rule_id=1,
rule_name="Watch to redeem", rule_name="Watch to redeem",
is_fully_imported=True, is_fully_imported=True,

View file

@ -52,6 +52,7 @@ dev = [
DJANGO_SETTINGS_MODULE = "config.settings" DJANGO_SETTINGS_MODULE = "config.settings"
python_files = ["test_*.py", "*_test.py"] python_files = ["test_*.py", "*_test.py"]
addopts = "--tb=short -n auto --cov" addopts = "--tb=short -n auto --cov"
markers = ["no_zeal: run test without zeal_context N+1 checks"]
filterwarnings = [ filterwarnings = [
"ignore:Parsing dates involving a day of month without a year specified is ambiguous:DeprecationWarning", "ignore:Parsing dates involving a day of month without a year specified is ambiguous:DeprecationWarning",
] ]

View file

@ -69,7 +69,7 @@
<div> <div>
<a href="{% url 'twitch:campaign_detail' campaign_data.campaign.twitch_id %}"> <a href="{% url 'twitch:campaign_detail' campaign_data.campaign.twitch_id %}">
{% picture campaign_data.image_url alt="Image for "|add:campaign_data.campaign.name width=120 %} {% picture campaign_data.image_url alt="Image for "|add:campaign_data.campaign.name width=120 %}
<h4 style="margin: 0.5rem 0; text-align: left;">{{ campaign_data.campaign.clean_name }}</h4> <h4 style="margin: 0.5rem 0; text-align: left;">{{ campaign_data.clean_name }}</h4>
</a> </a>
<!-- End time --> <!-- End time -->
<time datetime="{{ campaign_data.campaign.end_at|date:'c' }}" <time datetime="{{ campaign_data.campaign.end_at|date:'c' }}"
@ -114,11 +114,11 @@
{% endfor %} {% endfor %}
{% else %} {% else %}
<!-- No allowed channels means drops are available in any stream of the game's category --> <!-- No allowed channels means drops are available in any stream of the game's category -->
{% if campaign.game.twitch_directory_url %} {% if campaign_data.game_twitch_directory_url %}
<li> <li>
<a href="{{ campaign.game.twitch_directory_url }}" <a href="{{ campaign_data.game_twitch_directory_url }}"
title="Open Twitch category page for {{ campaign_data.campaign.game.display_name }} with Drops filter"> title="Open Twitch category page for {{ campaign_data.game_display_name }} with Drops filter">
Browse {{ campaign_data.campaign.game.display_name }} category Browse {{ campaign_data.game_display_name }} category
</a> </a>
</li> </li>
{% else %} {% else %}
@ -131,10 +131,10 @@
</li> </li>
{% endif %} {% endif %}
{% else %} {% else %}
{% if campaign_data.campaign.game.twitch_directory_url %} {% if campaign_data.game_twitch_directory_url %}
<li> <li>
<a href="{{ campaign_data.campaign.game.twitch_directory_url }}" <a href="{{ campaign_data.game_twitch_directory_url }}"
title="Find streamers playing {{ campaign_data.campaign.game.display_name }} with drops enabled"> title="Find streamers playing {{ campaign_data.game_display_name }} with drops enabled">
Go to a participating live channel Go to a participating live channel
</a> </a>
</li> </li>

View file

@ -161,7 +161,7 @@ class TTVDropsBaseFeed(Feed):
response.content = content.encode(encoding) response.content = content.encode(encoding)
def get_feed(self, obj: object, request: HttpRequest) -> SyndicationFeed: def get_feed(self, obj: Model | None, request: HttpRequest) -> SyndicationFeed:
"""Use deterministic BASE_URL handling for syndication feed generation. """Use deterministic BASE_URL handling for syndication feed generation.
Returns: Returns:
@ -199,8 +199,8 @@ class TTVDropsBaseFeed(Feed):
def __call__( def __call__(
self, self,
request: HttpRequest, request: HttpRequest,
*args: object, *args: str | int,
**kwargs: object, **kwargs: str | int,
) -> HttpResponse: ) -> HttpResponse:
"""Return feed response with inline content disposition for browser display.""" """Return feed response with inline content disposition for browser display."""
original_stylesheets: list[str] = self.stylesheets original_stylesheets: list[str] = self.stylesheets
@ -745,8 +745,8 @@ class OrganizationRSSFeed(TTVDropsBaseFeed):
def __call__( def __call__(
self, self,
request: HttpRequest, request: HttpRequest,
*args: object, *args: str | int,
**kwargs: object, **kwargs: str | int,
) -> HttpResponse: ) -> HttpResponse:
"""Override to capture limit parameter from request. """Override to capture limit parameter from request.
@ -822,8 +822,8 @@ class GameFeed(TTVDropsBaseFeed):
def __call__( def __call__(
self, self,
request: HttpRequest, request: HttpRequest,
*args: object, *args: str | int,
**kwargs: object, **kwargs: str | int,
) -> HttpResponse: ) -> HttpResponse:
"""Override to capture limit parameter from request. """Override to capture limit parameter from request.
@ -975,8 +975,8 @@ class DropCampaignFeed(TTVDropsBaseFeed):
def __call__( def __call__(
self, self,
request: HttpRequest, request: HttpRequest,
*args: object, *args: str | int,
**kwargs: object, **kwargs: str | int,
) -> HttpResponse: ) -> HttpResponse:
"""Override to capture limit parameter from request. """Override to capture limit parameter from request.
@ -1114,8 +1114,8 @@ class GameCampaignFeed(TTVDropsBaseFeed):
def __call__( def __call__(
self, self,
request: HttpRequest, request: HttpRequest,
*args: object, *args: str | int,
**kwargs: object, **kwargs: str | int,
) -> HttpResponse: ) -> HttpResponse:
"""Override to capture limit parameter from request. """Override to capture limit parameter from request.
@ -1293,8 +1293,8 @@ class RewardCampaignFeed(TTVDropsBaseFeed):
def __call__( def __call__(
self, self,
request: HttpRequest, request: HttpRequest,
*args: object, *args: str | int,
**kwargs: object, **kwargs: str | int,
) -> HttpResponse: ) -> HttpResponse:
"""Override to capture limit parameter from request. """Override to capture limit parameter from request.

View file

@ -7,6 +7,7 @@ from compression import zstd
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Protocol
from django.conf import settings from django.conf import settings
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
@ -19,6 +20,15 @@ if TYPE_CHECKING:
from argparse import ArgumentParser from argparse import ArgumentParser
class SupportsStr(Protocol):
"""Protocol for values that provide a string representation."""
def __str__(self) -> str: ...
type SqlSerializable = bool | int | float | bytes | SupportsStr | None
class Command(BaseCommand): class Command(BaseCommand):
"""Create a compressed SQL dump of the Twitch and Kick dataset tables.""" """Create a compressed SQL dump of the Twitch and Kick dataset tables."""
@ -285,7 +295,7 @@ def _write_postgres_dump(output_path: Path, tables: list[str]) -> None:
raise CommandError(msg) raise CommandError(msg)
def _sql_literal(value: object) -> str: def _sql_literal(value: SqlSerializable) -> str:
"""Convert a Python value to a SQL literal. """Convert a Python value to a SQL literal.
Args: Args:
@ -305,7 +315,7 @@ def _sql_literal(value: object) -> str:
return "'" + str(value).replace("'", "''") + "'" return "'" + str(value).replace("'", "''") + "'"
def _json_default(value: object) -> str: def _json_default(value: bytes | SupportsStr) -> str:
"""Convert non-serializable values to JSON-compatible strings. """Convert non-serializable values to JSON-compatible strings.
Args: Args:

View file

@ -583,16 +583,31 @@ class Command(BaseCommand):
Returns: Returns:
Organization instance. Organization instance.
""" """
org_obj, created = Organization.objects.get_or_create( cache: dict[str, Organization] = getattr(self, "_org_cache", {})
if not hasattr(self, "_org_cache"):
self._org_cache = cache
cached_org: Organization | None = cache.get(org_data.twitch_id)
if cached_org is not None:
self._save_if_changed(cached_org, {"name": org_data.name})
return cached_org
org_obj: Organization | None = Organization.objects.filter(
twitch_id=org_data.twitch_id, twitch_id=org_data.twitch_id,
defaults={"name": org_data.name}, ).first()
) _created: bool = org_obj is None
if not created: if org_obj is None:
self._save_if_changed(org_obj, {"name": org_data.name}) org_obj = Organization.objects.create(
else: twitch_id=org_data.twitch_id,
name=org_data.name,
)
tqdm.write( tqdm.write(
f"{Fore.GREEN}{Style.RESET_ALL} Created new organization: {org_data.name}", f"{Fore.GREEN}{Style.RESET_ALL} Created new organization: {org_data.name}",
) )
else:
self._save_if_changed(org_obj, {"name": org_data.name})
cache[org_data.twitch_id] = org_obj
return org_obj return org_obj
@ -621,6 +636,10 @@ class Command(BaseCommand):
if campaign_org_obj: if campaign_org_obj:
owner_orgs.add(campaign_org_obj) owner_orgs.add(campaign_org_obj)
cache: dict[str, Game] = getattr(self, "_game_cache", {})
if not hasattr(self, "_game_cache"):
self._game_cache = cache
defaults: dict[str, object] = { defaults: dict[str, object] = {
"display_name": game_data.display_name or (game_data.name or ""), "display_name": game_data.display_name or (game_data.name or ""),
"name": game_data.name or "", "name": game_data.name or "",
@ -628,10 +647,22 @@ class Command(BaseCommand):
"box_art": game_data.box_art_url or "", "box_art": game_data.box_art_url or "",
} }
game_obj, created = Game.objects.get_or_create( cached_game: Game | None = cache.get(game_data.twitch_id)
if cached_game is not None:
if owner_orgs:
cached_game.owners.add(*owner_orgs)
self._save_if_changed(cached_game, defaults)
return cached_game
game_obj: Game | None = Game.objects.filter(
twitch_id=game_data.twitch_id, twitch_id=game_data.twitch_id,
defaults=defaults, ).first()
) created: bool = game_obj is None
if game_obj is None:
game_obj = Game.objects.create(
twitch_id=game_data.twitch_id,
**defaults,
)
# Set owners (ManyToMany) # Set owners (ManyToMany)
if created or owner_orgs: if created or owner_orgs:
game_obj.owners.add(*owner_orgs) game_obj.owners.add(*owner_orgs)
@ -642,6 +673,7 @@ class Command(BaseCommand):
f"{Fore.GREEN}{Style.RESET_ALL} Created new game: {game_data.display_name}", f"{Fore.GREEN}{Style.RESET_ALL} Created new game: {game_data.display_name}",
) )
self._download_game_box_art(game_obj, game_obj.box_art) self._download_game_box_art(game_obj, game_obj.box_art)
cache[game_data.twitch_id] = game_obj
return game_obj return game_obj
def _download_game_box_art(self, game_obj: Game, box_art_url: str | None) -> None: def _download_game_box_art(self, game_obj: Game, box_art_url: str | None) -> None:
@ -701,7 +733,7 @@ class Command(BaseCommand):
return channel_obj return channel_obj
def process_responses( def process_responses( # noqa: PLR0915
self, self,
responses: list[dict[str, Any]], responses: list[dict[str, Any]],
file_path: Path, file_path: Path,
@ -792,13 +824,18 @@ class Command(BaseCommand):
"account_link_url": drop_campaign.account_link_url, "account_link_url": drop_campaign.account_link_url,
} }
campaign_obj, created = DropCampaign.objects.get_or_create( campaign_obj: DropCampaign | None = DropCampaign.objects.filter(
twitch_id=drop_campaign.twitch_id, twitch_id=drop_campaign.twitch_id,
defaults=defaults, ).first()
) created: bool = campaign_obj is None
if not created: if campaign_obj is None:
self._save_if_changed(campaign_obj, defaults) campaign_obj = DropCampaign.objects.create(
twitch_id=drop_campaign.twitch_id,
**defaults,
)
else: else:
self._save_if_changed(campaign_obj, defaults)
if created:
tqdm.write( tqdm.write(
f"{Fore.GREEN}{Style.RESET_ALL} Created new campaign: {drop_campaign.name}", f"{Fore.GREEN}{Style.RESET_ALL} Created new campaign: {drop_campaign.name}",
) )
@ -882,13 +919,18 @@ class Command(BaseCommand):
if end_at_dt is not None: if end_at_dt is not None:
drop_defaults["end_at"] = end_at_dt drop_defaults["end_at"] = end_at_dt
drop_obj, created = TimeBasedDrop.objects.get_or_create( drop_obj: TimeBasedDrop | None = TimeBasedDrop.objects.filter(
twitch_id=drop_schema.twitch_id, twitch_id=drop_schema.twitch_id,
defaults=drop_defaults, ).first()
) created: bool = drop_obj is None
if not created: if drop_obj is None:
self._save_if_changed(drop_obj, drop_defaults) drop_obj = TimeBasedDrop.objects.create(
twitch_id=drop_schema.twitch_id,
**drop_defaults,
)
else: else:
self._save_if_changed(drop_obj, drop_defaults)
if created:
tqdm.write( tqdm.write(
f"{Fore.GREEN}{Style.RESET_ALL} Created TimeBasedDrop: {drop_schema.name}", f"{Fore.GREEN}{Style.RESET_ALL} Created TimeBasedDrop: {drop_schema.name}",
) )
@ -900,6 +942,10 @@ class Command(BaseCommand):
def _get_or_update_benefit(self, benefit_schema: DropBenefitSchema) -> DropBenefit: def _get_or_update_benefit(self, benefit_schema: DropBenefitSchema) -> DropBenefit:
"""Return a DropBenefit, creating or updating as needed.""" """Return a DropBenefit, creating or updating as needed."""
cache: dict[str, DropBenefit] = getattr(self, "_benefit_cache", {})
if not hasattr(self, "_benefit_cache"):
self._benefit_cache = cache
distribution_type: str = (benefit_schema.distribution_type or "").strip() distribution_type: str = (benefit_schema.distribution_type or "").strip()
benefit_defaults: dict[str, str | int | datetime | bool | None] = { benefit_defaults: dict[str, str | int | datetime | bool | None] = {
"name": benefit_schema.name, "name": benefit_schema.name,
@ -914,10 +960,20 @@ class Command(BaseCommand):
if created_at_dt: if created_at_dt:
benefit_defaults["created_at"] = created_at_dt benefit_defaults["created_at"] = created_at_dt
benefit_obj, created = DropBenefit.objects.get_or_create( cached_benefit: DropBenefit | None = cache.get(benefit_schema.twitch_id)
if cached_benefit is not None:
self._save_if_changed(cached_benefit, benefit_defaults)
return cached_benefit
benefit_obj: DropBenefit | None = DropBenefit.objects.filter(
twitch_id=benefit_schema.twitch_id, twitch_id=benefit_schema.twitch_id,
defaults=benefit_defaults, ).first()
) created: bool = benefit_obj is None
if benefit_obj is None:
benefit_obj = DropBenefit.objects.create(
twitch_id=benefit_schema.twitch_id,
**benefit_defaults,
)
if not created: if not created:
self._save_if_changed(benefit_obj, benefit_defaults) self._save_if_changed(benefit_obj, benefit_defaults)
else: else:
@ -925,6 +981,8 @@ class Command(BaseCommand):
f"{Fore.GREEN}{Style.RESET_ALL} Created DropBenefit: {benefit_schema.name}", f"{Fore.GREEN}{Style.RESET_ALL} Created DropBenefit: {benefit_schema.name}",
) )
cache[benefit_schema.twitch_id] = benefit_obj
return benefit_obj return benefit_obj
def _process_benefit_edges( def _process_benefit_edges(
@ -946,11 +1004,17 @@ class Command(BaseCommand):
) )
defaults = {"entitlement_limit": edge_schema.entitlement_limit} defaults = {"entitlement_limit": edge_schema.entitlement_limit}
edge_obj, created = DropBenefitEdge.objects.get_or_create( edge_obj: DropBenefitEdge | None = DropBenefitEdge.objects.filter(
drop=drop_obj, drop=drop_obj,
benefit=benefit_obj, benefit=benefit_obj,
defaults=defaults, ).first()
) created: bool = edge_obj is None
if edge_obj is None:
edge_obj = DropBenefitEdge.objects.create(
drop=drop_obj,
benefit=benefit_obj,
**defaults,
)
if not created: if not created:
self._save_if_changed(edge_obj, defaults) self._save_if_changed(edge_obj, defaults)
else: else:

View file

@ -39,9 +39,13 @@ class Command(BaseCommand):
help="Re-download even if a local box art file already exists.", help="Re-download even if a local box art file already exists.",
) )
def handle(self, *_args: object, **options: object) -> None: # noqa: PLR0914, PLR0915 def handle( # noqa: PLR0914, PLR0915
self,
*_args: str,
**options: str | bool | int | None,
) -> None:
"""Download Twitch box art images for all games.""" """Download Twitch box art images for all games."""
limit_value: object | None = options.get("limit") limit_value: str | bool | int | None = options.get("limit")
limit: int | None = limit_value if isinstance(limit_value, int) else None limit: int | None = limit_value if isinstance(limit_value, int) else None
force: bool = bool(options.get("force")) force: bool = bool(options.get("force"))

View file

@ -50,10 +50,14 @@ class Command(BaseCommand):
help="Re-download even if a local image file already exists.", help="Re-download even if a local image file already exists.",
) )
def handle(self, *_args: object, **options: object) -> None: def handle(
self,
*_args: str,
**options: str | bool | int | None,
) -> None:
"""Download images for campaigns, benefits, and/or rewards.""" """Download images for campaigns, benefits, and/or rewards."""
model_choice: str = str(options.get("model", "all")) model_choice: str = str(options.get("model", "all"))
limit_value: object | None = options.get("limit") limit_value: str | bool | int | None = options.get("limit")
limit: int | None = limit_value if isinstance(limit_value, int) else None limit: int | None = limit_value if isinstance(limit_value, int) else None
force: bool = bool(options.get("force")) force: bool = bool(options.get("force"))

View file

@ -196,9 +196,12 @@ class Command(BaseCommand):
Returns: Returns:
Tuple of (ChatBadgeSet instance, created flag) Tuple of (ChatBadgeSet instance, created flag)
""" """
badge_set_obj, created = ChatBadgeSet.objects.get_or_create( badge_set_obj: ChatBadgeSet | None = ChatBadgeSet.objects.filter(
set_id=badge_set_schema.set_id, set_id=badge_set_schema.set_id,
) ).first()
created: bool = badge_set_obj is None
if badge_set_obj is None:
badge_set_obj = ChatBadgeSet.objects.create(set_id=badge_set_schema.set_id)
if created: if created:
self.stdout.write( self.stdout.write(
@ -258,11 +261,25 @@ class Command(BaseCommand):
"click_url": version_schema.click_url, "click_url": version_schema.click_url,
} }
_badge_obj, created = ChatBadge.objects.update_or_create( badge_obj: ChatBadge | None = ChatBadge.objects.filter(
badge_set=badge_set_obj, badge_set=badge_set_obj,
badge_id=version_schema.badge_id, badge_id=version_schema.badge_id,
defaults=defaults, ).first()
) created: bool = badge_obj is None
if badge_obj is None:
badge_obj = ChatBadge.objects.create(
badge_set=badge_set_obj,
badge_id=version_schema.badge_id,
**defaults,
)
else:
changed_fields: list[str] = []
for field, value in defaults.items():
if getattr(badge_obj, field) != value:
setattr(badge_obj, field, value)
changed_fields.append(field)
if changed_fields:
badge_obj.save(update_fields=changed_fields)
if created: if created:
msg: str = ( msg: str = (

View file

@ -2,6 +2,7 @@ import logging
from collections import OrderedDict from collections import OrderedDict
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Any from typing import Any
from typing import cast
import auto_prefetch import auto_prefetch
from django.conf import settings from django.conf import settings
@ -531,6 +532,8 @@ class DropCampaign(auto_prefetch.Model):
"name", "name",
"image_url", "image_url",
"image_file", "image_file",
"image_width",
"image_height",
"start_at", "start_at",
"end_at", "end_at",
"allow_is_enabled", "allow_is_enabled",
@ -540,6 +543,8 @@ class DropCampaign(auto_prefetch.Model):
"game__slug", "game__slug",
"game__box_art", "game__box_art",
"game__box_art_file", "game__box_art_file",
"game__box_art_width",
"game__box_art_height",
) )
.select_related("game") .select_related("game")
.prefetch_related( .prefetch_related(
@ -577,26 +582,90 @@ class DropCampaign(auto_prefetch.Model):
""" """
campaigns_by_game: OrderedDict[str, dict[str, Any]] = OrderedDict() campaigns_by_game: OrderedDict[str, dict[str, Any]] = OrderedDict()
for campaign in campaigns: campaigns_list: list[DropCampaign] = list(campaigns)
game: Game = campaign.game game_pks: list[int] = sorted({
game_id: str = game.twitch_id cast("Any", campaign).game_id for campaign in campaigns_list
})
games: models.QuerySet[Game, Game] = (
Game.objects
.filter(pk__in=game_pks)
.only(
"pk",
"twitch_id",
"display_name",
"slug",
"box_art",
"box_art_file",
"box_art_width",
"box_art_height",
)
.prefetch_related(
models.Prefetch(
"owners",
queryset=Organization.objects.only("twitch_id", "name"),
),
)
)
games_by_pk: dict[int, Game] = {game.pk: game for game in games}
def _clean_name(campaign_name: str, game_display_name: str) -> str:
if not game_display_name:
return campaign_name
game_variations: list[str] = [game_display_name]
if "&" in game_display_name:
game_variations.append(game_display_name.replace("&", "and"))
if "and" in game_display_name:
game_variations.append(game_display_name.replace("and", "&"))
for game_name in game_variations:
for separator in [" - ", " | ", " "]:
prefix_to_check: str = game_name + separator
if campaign_name.startswith(prefix_to_check):
return campaign_name.removeprefix(prefix_to_check).strip()
return campaign_name
for campaign in campaigns_list:
game_pk: int = cast("Any", campaign).game_id
game: Game | None = games_by_pk.get(game_pk)
game_id: str = game.twitch_id if game else ""
game_display_name: str = game.display_name if game else ""
if game_id not in campaigns_by_game: if game_id not in campaigns_by_game:
campaigns_by_game[game_id] = { campaigns_by_game[game_id] = {
"name": game.display_name, "name": game_display_name,
"box_art": game.box_art_best_url, "box_art": game.box_art_best_url if game else "",
"owners": list(game.owners.all()), "owners": list(game.owners.all()) if game else [],
"campaigns": [], "campaigns": [],
} }
campaigns_by_game[game_id]["campaigns"].append({ campaigns_by_game[game_id]["campaigns"].append({
"campaign": campaign, "campaign": campaign,
"clean_name": _clean_name(campaign.name, game_display_name),
"image_url": campaign.listing_image_url, "image_url": campaign.listing_image_url,
"allowed_channels": getattr(campaign, "channels_ordered", []), "allowed_channels": getattr(campaign, "channels_ordered", []),
"game_display_name": game_display_name,
"game_twitch_directory_url": game.twitch_directory_url if game else "",
}) })
return campaigns_by_game return campaigns_by_game
@classmethod
def campaigns_by_game_for_dashboard(
cls,
now: datetime.datetime,
) -> OrderedDict[str, dict[str, Any]]:
"""Return active campaigns grouped by game for dashboard rendering.
Args:
now: Current timestamp used to evaluate active campaigns.
Returns:
Ordered mapping keyed by game twitch_id.
"""
return cls.grouped_by_game(cls.active_for_dashboard(now))
@property @property
def is_active(self) -> bool: def is_active(self) -> bool:
"""Check if the campaign is currently active.""" """Check if the campaign is currently active."""

View file

@ -7,6 +7,7 @@ from typing import Any
from typing import Literal from typing import Literal
import pytest import pytest
from django.core.files.base import ContentFile
from django.core.handlers.wsgi import WSGIRequest from django.core.handlers.wsgi import WSGIRequest
from django.core.paginator import Paginator from django.core.paginator import Paginator
from django.db import connection from django.db import connection
@ -40,12 +41,13 @@ if TYPE_CHECKING:
from django.test import Client from django.test import Client
from django.test.client import _MonkeyPatchedWSGIResponse from django.test.client import _MonkeyPatchedWSGIResponse
from django.test.utils import ContextList from django.test.utils import ContextList
from pytest_django.fixtures import SettingsWrapper
from twitch.views import Page from twitch.views import Page
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def apply_base_url_override(settings: object) -> None: def apply_base_url_override(settings: SettingsWrapper) -> None:
"""Ensure BASE_URL is globally overridden for all tests.""" """Ensure BASE_URL is globally overridden for all tests."""
settings.BASE_URL = "https://ttvdrops.lovinator.space" # pyright: ignore[reportAttributeAccessIssue] settings.BASE_URL = "https://ttvdrops.lovinator.space" # pyright: ignore[reportAttributeAccessIssue]
@ -492,10 +494,10 @@ class TestChannelListView:
@pytest.mark.django_db @pytest.mark.django_db
def test_dashboard_view(self, client: Client) -> None: def test_dashboard_view(self, client: Client) -> None:
"""Test dashboard view returns 200 and has active_campaigns in context.""" """Test dashboard view returns 200 and has grouped campaign data in context."""
response: _MonkeyPatchedWSGIResponse = client.get(reverse("twitch:dashboard")) response: _MonkeyPatchedWSGIResponse = client.get(reverse("twitch:dashboard"))
assert response.status_code == 200 assert response.status_code == 200
assert "active_campaigns" in response.context assert "campaigns_by_game" in response.context
@pytest.mark.django_db @pytest.mark.django_db
def test_dashboard_dedupes_campaigns_for_multi_owner_game( def test_dashboard_dedupes_campaigns_for_multi_owner_game(
@ -622,10 +624,7 @@ class TestChannelListView:
now, now,
) )
active_reward_campaigns_qs: QuerySet[RewardCampaign] = ( active_reward_campaigns_qs: QuerySet[RewardCampaign] = (
RewardCampaign.objects RewardCampaign.active_for_dashboard(now)
.filter(starts_at__lte=now, ends_at__gte=now)
.select_related("game")
.order_by("-starts_at")
) )
campaigns_plan: str = active_campaigns_qs.explain() campaigns_plan: str = active_campaigns_qs.explain()
@ -759,6 +758,291 @@ class TestChannelListView:
f"baseline={baseline_select_count}, scaled={scaled_select_count}" f"baseline={baseline_select_count}, scaled={scaled_select_count}"
) )
@pytest.mark.django_db
def test_dashboard_avoids_n_plus_one_game_queries_in_drop_loop(
self,
client: Client,
) -> None:
"""Dashboard should not issue per-campaign Game SELECTs while rendering drops."""
now: datetime.datetime = timezone.now()
org: Organization = Organization.objects.create(
twitch_id="org_no_n_plus_one_game",
name="Org No N+1 Game",
)
game: Game = Game.objects.create(
twitch_id="game_no_n_plus_one_game",
name="game_no_n_plus_one_game",
display_name="Game No N+1 Game",
)
game.owners.add(org)
campaigns: list[DropCampaign] = [
DropCampaign(
twitch_id=f"no_n_plus_one_campaign_{i}",
name=f"No N+1 campaign {i}",
game=game,
operation_names=["DropCampaignDetails"],
start_at=now - timedelta(hours=2),
end_at=now + timedelta(hours=2),
)
for i in range(10)
]
DropCampaign.objects.bulk_create(campaigns)
with CaptureQueriesContext(connection) as queries:
response: _MonkeyPatchedWSGIResponse = client.get(
reverse("twitch:dashboard"),
)
assert response.status_code == 200
context: ContextList | dict[str, Any] = response.context # type: ignore[assignment]
if isinstance(context, list):
context = context[-1]
grouped_campaigns: list[dict[str, Any]] = context["campaigns_by_game"][
game.twitch_id
]["campaigns"]
assert grouped_campaigns
assert all(
"game_display_name" in campaign_data for campaign_data in grouped_campaigns
)
assert all(
"game_twitch_directory_url" in campaign_data
for campaign_data in grouped_campaigns
)
game_select_queries: list[str] = [
query_info["sql"]
for query_info in queries.captured_queries
if query_info["sql"].lstrip().upper().startswith("SELECT")
and "twitch_game" in query_info["sql"].lower()
and "join" not in query_info["sql"].lower()
]
assert len(game_select_queries) <= 1, (
"Expected at most one standalone Game SELECT for dashboard drop grouping; "
f"got {len(game_select_queries)}. Queries: {game_select_queries}"
)
@pytest.mark.django_db
def test_dashboard_avoids_n_plus_one_game_queries_with_multiple_games(
self,
client: Client,
) -> None:
"""Dashboard should keep standalone Game SELECTs bounded with many campaigns and games."""
now: datetime.datetime = timezone.now()
game_ids: list[str] = []
for i in range(5):
org: Organization = Organization.objects.create(
twitch_id=f"org_multi_game_{i}",
name=f"Org Multi Game {i}",
)
game: Game = Game.objects.create(
twitch_id=f"game_multi_game_{i}",
name=f"game_multi_game_{i}",
display_name=f"Game Multi Game {i}",
)
game.owners.add(org)
game_ids.append(game.twitch_id)
campaigns: list[DropCampaign] = [
DropCampaign(
twitch_id=f"multi_game_campaign_{i}_{j}",
name=f"Multi game campaign {i}-{j}",
game=game,
operation_names=["DropCampaignDetails"],
start_at=now - timedelta(hours=2),
end_at=now + timedelta(hours=2),
)
for j in range(20)
]
DropCampaign.objects.bulk_create(campaigns)
with CaptureQueriesContext(connection) as queries:
response: _MonkeyPatchedWSGIResponse = client.get(
reverse("twitch:dashboard"),
)
assert response.status_code == 200
context: ContextList | dict[str, Any] = response.context # type: ignore[assignment]
if isinstance(context, list):
context = context[-1]
campaigns_by_game: dict[str, Any] = context["campaigns_by_game"]
for game_id in game_ids:
assert game_id in campaigns_by_game
grouped_campaigns: list[dict[str, Any]] = campaigns_by_game[game_id][
"campaigns"
]
assert len(grouped_campaigns) == 20
assert all(
"game_display_name" in campaign_data
for campaign_data in grouped_campaigns
)
assert all(
"game_twitch_directory_url" in campaign_data
for campaign_data in grouped_campaigns
)
game_select_queries: list[str] = [
query_info["sql"]
for query_info in queries.captured_queries
if query_info["sql"].lstrip().upper().startswith("SELECT")
and "twitch_game" in query_info["sql"].lower()
and "join" not in query_info["sql"].lower()
]
assert len(game_select_queries) <= 1, (
"Expected a bounded number of standalone Game SELECTs for dashboard grouping; "
f"got {len(game_select_queries)}. Queries: {game_select_queries}"
)
@pytest.mark.django_db
def test_dashboard_does_not_refresh_dropcampaign_rows_for_image_dimensions(
self,
client: Client,
) -> None:
"""Dashboard should not issue per-row DropCampaign refreshes for image dimensions."""
now: datetime.datetime = timezone.now()
org: Organization = Organization.objects.create(
twitch_id="org_image_dimensions",
name="Org Image Dimensions",
)
game: Game = Game.objects.create(
twitch_id="game_image_dimensions",
name="game_image_dimensions",
display_name="Game Image Dimensions",
)
game.owners.add(org)
# 1x1 transparent PNG
png_1x1: bytes = (
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01"
b"\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89"
b"\x00\x00\x00\x0bIDATx\x9cc\x00\x01\x00\x00\x05\x00\x01"
b"\r\n-\xb4\x00\x00\x00\x00IEND\xaeB`\x82"
)
campaigns: list[DropCampaign] = []
for i in range(3):
campaign: DropCampaign = DropCampaign.objects.create(
twitch_id=f"image_dim_campaign_{i}",
name=f"Image dim campaign {i}",
game=game,
operation_names=["DropCampaignDetails"],
start_at=now - timedelta(hours=2),
end_at=now + timedelta(hours=2),
)
assert campaign.image_file is not None
campaign.image_file.save(
f"image_dim_campaign_{i}.png",
ContentFile(png_1x1),
save=True,
)
campaigns.append(campaign)
with CaptureQueriesContext(connection) as queries:
response: _MonkeyPatchedWSGIResponse = client.get(
reverse("twitch:dashboard"),
)
assert response.status_code == 200
context: ContextList | dict[str, Any] = response.context # type: ignore[assignment]
if isinstance(context, list):
context = context[-1]
grouped_campaigns: list[dict[str, Any]] = context["campaigns_by_game"][
game.twitch_id
]["campaigns"]
assert len(grouped_campaigns) == len(campaigns)
per_row_refresh_queries: list[str] = [
query_info["sql"]
for query_info in queries.captured_queries
if query_info["sql"].lstrip().upper().startswith("SELECT")
and 'from "twitch_dropcampaign"' in query_info["sql"].lower()
and 'where "twitch_dropcampaign"."id" =' in query_info["sql"].lower()
]
assert not per_row_refresh_queries, (
"Dashboard unexpectedly refreshed DropCampaign rows one-by-one while "
"resolving image dimensions. Queries: "
f"{per_row_refresh_queries}"
)
@pytest.mark.django_db
def test_dashboard_does_not_refresh_game_rows_for_box_art_dimensions(
self,
client: Client,
) -> None:
"""Dashboard should not issue per-row Game refreshes for box art dimensions."""
now: datetime.datetime = timezone.now()
org: Organization = Organization.objects.create(
twitch_id="org_box_art_dimensions",
name="Org Box Art Dimensions",
)
game: Game = Game.objects.create(
twitch_id="game_box_art_dimensions",
name="game_box_art_dimensions",
display_name="Game Box Art Dimensions",
)
game.owners.add(org)
# 1x1 transparent PNG
png_1x1: bytes = (
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01"
b"\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89"
b"\x00\x00\x00\x0bIDATx\x9cc\x00\x01\x00\x00\x05\x00\x01"
b"\r\n-\xb4\x00\x00\x00\x00IEND\xaeB`\x82"
)
assert game.box_art_file is not None
game.box_art_file.save(
"game_box_art_dimensions.png",
ContentFile(png_1x1),
save=True,
)
DropCampaign.objects.create(
twitch_id="game_box_art_campaign",
name="Game box art campaign",
game=game,
operation_names=["DropCampaignDetails"],
start_at=now - timedelta(hours=2),
end_at=now + timedelta(hours=2),
)
with CaptureQueriesContext(connection) as queries:
response: _MonkeyPatchedWSGIResponse = client.get(
reverse("twitch:dashboard"),
)
assert response.status_code == 200
context: ContextList | dict[str, Any] = response.context # type: ignore[assignment]
if isinstance(context, list):
context = context[-1]
campaigns_by_game: dict[str, Any] = context["campaigns_by_game"]
assert game.twitch_id in campaigns_by_game
per_row_refresh_queries: list[str] = [
query_info["sql"]
for query_info in queries.captured_queries
if query_info["sql"].lstrip().upper().startswith("SELECT")
and 'from "twitch_game"' in query_info["sql"].lower()
and 'where "twitch_game"."id" =' in query_info["sql"].lower()
]
assert not per_row_refresh_queries, (
"Dashboard unexpectedly refreshed Game rows one-by-one while resolving "
"box art dimensions. Queries: "
f"{per_row_refresh_queries}"
)
@pytest.mark.django_db @pytest.mark.django_db
def test_debug_view(self, client: Client) -> None: def test_debug_view(self, client: Client) -> None:
"""Test debug view returns 200 and has games_without_owner in context.""" """Test debug view returns 200 and has games_without_owner in context."""
@ -1079,7 +1363,7 @@ class TestChannelListView:
assert "page=2" in content assert "page=2" in content
@pytest.mark.django_db @pytest.mark.django_db
def test_drop_campaign_detail_view(self, client: Client, db: object) -> None: def test_drop_campaign_detail_view(self, client: Client, db: None) -> None:
"""Test campaign detail view returns 200 and has campaign in context.""" """Test campaign detail view returns 200 and has campaign in context."""
game: Game = Game.objects.create( game: Game = Game.objects.create(
twitch_id="g1", twitch_id="g1",
@ -1164,7 +1448,7 @@ class TestChannelListView:
assert "games" in response.context assert "games" in response.context
@pytest.mark.django_db @pytest.mark.django_db
def test_game_detail_view(self, client: Client, db: object) -> None: def test_game_detail_view(self, client: Client, db: None) -> None:
"""Test game detail view returns 200 and has game in context.""" """Test game detail view returns 200 and has game in context."""
game: Game = Game.objects.create( game: Game = Game.objects.create(
twitch_id="g2", twitch_id="g2",
@ -1177,7 +1461,7 @@ class TestChannelListView:
assert "game" in response.context assert "game" in response.context
@pytest.mark.django_db @pytest.mark.django_db
def test_game_detail_image_aspect_ratio(self, client: Client, db: object) -> None: def test_game_detail_image_aspect_ratio(self, client: Client, db: None) -> None:
"""Box art should render with a width attribute only, preserving aspect ratio.""" """Box art should render with a width attribute only, preserving aspect ratio."""
game: Game = Game.objects.create( game: Game = Game.objects.create(
twitch_id="g3", twitch_id="g3",
@ -1232,7 +1516,7 @@ class TestChannelListView:
assert "orgs" in response.context assert "orgs" in response.context
@pytest.mark.django_db @pytest.mark.django_db
def test_organization_detail_view(self, client: Client, db: object) -> None: def test_organization_detail_view(self, client: Client, db: None) -> None:
"""Test organization detail view returns 200 and has organization in context.""" """Test organization detail view returns 200 and has organization in context."""
org: Organization = Organization.objects.create(twitch_id="o1", name="Org1") org: Organization = Organization.objects.create(twitch_id="o1", name="Org1")
url: str = reverse("twitch:organization_detail", args=[org.twitch_id]) url: str = reverse("twitch:organization_detail", args=[org.twitch_id])
@ -1241,7 +1525,7 @@ class TestChannelListView:
assert "organization" in response.context assert "organization" in response.context
@pytest.mark.django_db @pytest.mark.django_db
def test_channel_detail_view(self, client: Client, db: object) -> None: def test_channel_detail_view(self, client: Client, db: None) -> None:
"""Test channel detail view returns 200 and has channel in context.""" """Test channel detail view returns 200 and has channel in context."""
channel: Channel = Channel.objects.create( channel: Channel = Channel.objects.create(
twitch_id="ch1", twitch_id="ch1",

View file

@ -875,7 +875,7 @@ class GameDetailView(DetailView):
return game return game
def get_context_data(self, **kwargs: object) -> dict[str, Any]: # noqa: PLR0914 def get_context_data(self, **kwargs) -> dict[str, Any]: # noqa: PLR0914
"""Add additional context data. """Add additional context data.
Args: Args:
@ -1071,9 +1071,8 @@ def dashboard(request: HttpRequest) -> HttpResponse:
HttpResponse: The rendered dashboard template. HttpResponse: The rendered dashboard template.
""" """
now: datetime.datetime = timezone.now() now: datetime.datetime = timezone.now()
active_campaigns: QuerySet[DropCampaign] = DropCampaign.active_for_dashboard(now) campaigns_by_game: OrderedDict[str, dict[str, Any]] = (
campaigns_by_game: OrderedDict[str, dict[str, Any]] = DropCampaign.grouped_by_game( DropCampaign.campaigns_by_game_for_dashboard(now)
active_campaigns,
) )
# Get active reward campaigns (Quest rewards) # Get active reward campaigns (Quest rewards)
@ -1112,7 +1111,6 @@ def dashboard(request: HttpRequest) -> HttpResponse:
request, request,
"twitch/dashboard.html", "twitch/dashboard.html",
{ {
"active_campaigns": active_campaigns,
"campaigns_by_game": campaigns_by_game, "campaigns_by_game": campaigns_by_game,
"active_reward_campaigns": active_reward_campaigns, "active_reward_campaigns": active_reward_campaigns,
"now": now, "now": now,
@ -1441,7 +1439,7 @@ class ChannelDetailView(DetailView):
return channel return channel
def get_context_data(self, **kwargs: object) -> dict[str, Any]: # noqa: PLR0914 def get_context_data(self, **kwargs) -> dict[str, Any]: # noqa: PLR0914
"""Add additional context data. """Add additional context data.
Args: Args: