Refactor reader dependency injection in FastAPI routes and tests

This commit is contained in:
Joakim Hellsén 2026-03-15 15:39:05 +01:00
commit 727057439e
Signed by: Joakim Hellsén
SSH key fingerprint: SHA256:/9h/CsExpFp+PRhsfA0xznFx2CGfTT5R/kpuFfUgEQk
3 changed files with 193 additions and 38 deletions

View file

@ -19,6 +19,7 @@ import httpx
import sentry_sdk import sentry_sdk
import uvicorn import uvicorn
from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.schedulers.asyncio import AsyncIOScheduler
from fastapi import Depends
from fastapi import FastAPI from fastapi import FastAPI
from fastapi import Form from fastapi import Form
from fastapi import HTTPException from fastapi import HTTPException
@ -100,7 +101,16 @@ LOGGING_CONFIG: dict[str, Any] = {
logging.config.dictConfig(LOGGING_CONFIG) logging.config.dictConfig(LOGGING_CONFIG)
logger: logging.Logger = logging.getLogger(__name__) logger: logging.Logger = logging.getLogger(__name__)
reader: Reader = get_reader()
def get_reader_dependency() -> Reader:
"""Provide the app Reader instance as a FastAPI dependency.
Returns:
Reader: The shared Reader instance.
"""
return get_reader()
# Time constants for relative time formatting # Time constants for relative time formatting
SECONDS_PER_MINUTE = 60 SECONDS_PER_MINUTE = 60
@ -146,6 +156,7 @@ def relative_time(dt: datetime | None) -> str:
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None]: async def lifespan(app: FastAPI) -> AsyncGenerator[None]:
"""Lifespan function for the FastAPI app.""" """Lifespan function for the FastAPI app."""
reader: Reader = get_reader()
add_missing_tags(reader) add_missing_tags(reader)
scheduler: AsyncIOScheduler = AsyncIOScheduler(timezone=UTC) scheduler: AsyncIOScheduler = AsyncIOScheduler(timezone=UTC)
scheduler.add_job( scheduler.add_job(
@ -181,12 +192,14 @@ templates.env.globals["get_backup_path"] = get_backup_path
async def post_add_webhook( async def post_add_webhook(
webhook_name: Annotated[str, Form()], webhook_name: Annotated[str, Form()],
webhook_url: Annotated[str, Form()], webhook_url: Annotated[str, Form()],
reader: Annotated[Reader, Depends(get_reader_dependency)],
) -> RedirectResponse: ) -> RedirectResponse:
"""Add a feed to the database. """Add a feed to the database.
Args: Args:
webhook_name: The name of the webhook. webhook_name: The name of the webhook.
webhook_url: The url of the webhook. webhook_url: The url of the webhook.
reader: The Reader instance.
Returns: Returns:
RedirectResponse: Redirect to the index page. RedirectResponse: Redirect to the index page.
@ -219,11 +232,15 @@ async def post_add_webhook(
@app.post("/delete_webhook") @app.post("/delete_webhook")
async def post_delete_webhook(webhook_url: Annotated[str, Form()]) -> RedirectResponse: async def post_delete_webhook(
webhook_url: Annotated[str, Form()],
reader: Annotated[Reader, Depends(get_reader_dependency)],
) -> RedirectResponse:
"""Delete a webhook from the database. """Delete a webhook from the database.
Args: Args:
webhook_url: The url of the webhook. webhook_url: The url of the webhook.
reader: The Reader instance.
Returns: Returns:
RedirectResponse: Redirect to the index page. RedirectResponse: Redirect to the index page.
@ -266,12 +283,14 @@ async def post_delete_webhook(webhook_url: Annotated[str, Form()]) -> RedirectRe
async def post_create_feed( async def post_create_feed(
feed_url: Annotated[str, Form()], feed_url: Annotated[str, Form()],
webhook_dropdown: Annotated[str, Form()], webhook_dropdown: Annotated[str, Form()],
reader: Annotated[Reader, Depends(get_reader_dependency)],
) -> RedirectResponse: ) -> RedirectResponse:
"""Add a feed to the database. """Add a feed to the database.
Args: Args:
feed_url: The feed to add. feed_url: The feed to add.
webhook_dropdown: The webhook to use. webhook_dropdown: The webhook to use.
reader: The Reader instance.
Returns: Returns:
RedirectResponse: Redirect to the feed page. RedirectResponse: Redirect to the feed page.
@ -283,11 +302,15 @@ async def post_create_feed(
@app.post("/pause") @app.post("/pause")
async def post_pause_feed(feed_url: Annotated[str, Form()]) -> RedirectResponse: async def post_pause_feed(
feed_url: Annotated[str, Form()],
reader: Annotated[Reader, Depends(get_reader_dependency)],
) -> RedirectResponse:
"""Pause a feed. """Pause a feed.
Args: Args:
feed_url: The feed to pause. feed_url: The feed to pause.
reader: The Reader instance.
Returns: Returns:
RedirectResponse: Redirect to the feed page. RedirectResponse: Redirect to the feed page.
@ -298,11 +321,15 @@ async def post_pause_feed(feed_url: Annotated[str, Form()]) -> RedirectResponse:
@app.post("/unpause") @app.post("/unpause")
async def post_unpause_feed(feed_url: Annotated[str, Form()]) -> RedirectResponse: async def post_unpause_feed(
feed_url: Annotated[str, Form()],
reader: Annotated[Reader, Depends(get_reader_dependency)],
) -> RedirectResponse:
"""Unpause a feed. """Unpause a feed.
Args: Args:
feed_url: The Feed to unpause. feed_url: The Feed to unpause.
reader: The Reader instance.
Returns: Returns:
RedirectResponse: Redirect to the feed page. RedirectResponse: Redirect to the feed page.
@ -314,6 +341,7 @@ async def post_unpause_feed(feed_url: Annotated[str, Form()]) -> RedirectRespons
@app.post("/whitelist") @app.post("/whitelist")
async def post_set_whitelist( async def post_set_whitelist(
reader: Annotated[Reader, Depends(get_reader_dependency)],
whitelist_title: Annotated[str, Form()] = "", whitelist_title: Annotated[str, Form()] = "",
whitelist_summary: Annotated[str, Form()] = "", whitelist_summary: Annotated[str, Form()] = "",
whitelist_content: Annotated[str, Form()] = "", whitelist_content: Annotated[str, Form()] = "",
@ -336,6 +364,7 @@ async def post_set_whitelist(
regex_whitelist_content: Whitelisted regex for when checking the content. regex_whitelist_content: Whitelisted regex for when checking the content.
regex_whitelist_author: Whitelisted regex for when checking the author. regex_whitelist_author: Whitelisted regex for when checking the author.
feed_url: The feed we should set the whitelist for. feed_url: The feed we should set the whitelist for.
reader: The Reader instance.
Returns: Returns:
RedirectResponse: Redirect to the feed page. RedirectResponse: Redirect to the feed page.
@ -356,12 +385,17 @@ async def post_set_whitelist(
@app.get("/whitelist", response_class=HTMLResponse) @app.get("/whitelist", response_class=HTMLResponse)
async def get_whitelist(feed_url: str, request: Request): async def get_whitelist(
feed_url: str,
request: Request,
reader: Annotated[Reader, Depends(get_reader_dependency)],
):
"""Get the whitelist. """Get the whitelist.
Args: Args:
feed_url: What feed we should get the whitelist for. feed_url: What feed we should get the whitelist for.
request: The request object. request: The request object.
reader: The Reader instance.
Returns: Returns:
HTMLResponse: The whitelist page. HTMLResponse: The whitelist page.
@ -395,6 +429,7 @@ async def get_whitelist(feed_url: str, request: Request):
@app.post("/blacklist") @app.post("/blacklist")
async def post_set_blacklist( async def post_set_blacklist(
reader: Annotated[Reader, Depends(get_reader_dependency)],
blacklist_title: Annotated[str, Form()] = "", blacklist_title: Annotated[str, Form()] = "",
blacklist_summary: Annotated[str, Form()] = "", blacklist_summary: Annotated[str, Form()] = "",
blacklist_content: Annotated[str, Form()] = "", blacklist_content: Annotated[str, Form()] = "",
@ -420,6 +455,7 @@ async def post_set_blacklist(
regex_blacklist_content: Blacklisted regex for when checking the content. regex_blacklist_content: Blacklisted regex for when checking the content.
regex_blacklist_author: Blacklisted regex for when checking the author. regex_blacklist_author: Blacklisted regex for when checking the author.
feed_url: What feed we should set the blacklist for. feed_url: What feed we should set the blacklist for.
reader: The Reader instance.
Returns: Returns:
RedirectResponse: Redirect to the feed page. RedirectResponse: Redirect to the feed page.
@ -438,12 +474,17 @@ async def post_set_blacklist(
@app.get("/blacklist", response_class=HTMLResponse) @app.get("/blacklist", response_class=HTMLResponse)
async def get_blacklist(feed_url: str, request: Request): async def get_blacklist(
feed_url: str,
request: Request,
reader: Annotated[Reader, Depends(get_reader_dependency)],
):
"""Get the blacklist. """Get the blacklist.
Args: Args:
feed_url: What feed we should get the blacklist for. feed_url: What feed we should get the blacklist for.
request: The request object. request: The request object.
reader: The Reader instance.
Returns: Returns:
HTMLResponse: The blacklist page. HTMLResponse: The blacklist page.
@ -477,6 +518,7 @@ async def get_blacklist(feed_url: str, request: Request):
@app.post("/custom") @app.post("/custom")
async def post_set_custom( async def post_set_custom(
feed_url: Annotated[str, Form()], feed_url: Annotated[str, Form()],
reader: Annotated[Reader, Depends(get_reader_dependency)],
custom_message: Annotated[str, Form()] = "", custom_message: Annotated[str, Form()] = "",
) -> RedirectResponse: ) -> RedirectResponse:
"""Set the custom message, this is used when sending the message. """Set the custom message, this is used when sending the message.
@ -484,6 +526,7 @@ async def post_set_custom(
Args: Args:
custom_message: The custom message. custom_message: The custom message.
feed_url: The feed we should set the custom message for. feed_url: The feed we should set the custom message for.
reader: The Reader instance.
Returns: Returns:
RedirectResponse: Redirect to the feed page. RedirectResponse: Redirect to the feed page.
@ -505,12 +548,17 @@ async def post_set_custom(
@app.get("/custom", response_class=HTMLResponse) @app.get("/custom", response_class=HTMLResponse)
async def get_custom(feed_url: str, request: Request): async def get_custom(
feed_url: str,
request: Request,
reader: Annotated[Reader, Depends(get_reader_dependency)],
):
"""Get the custom message. This is used when sending the message to Discord. """Get the custom message. This is used when sending the message to Discord.
Args: Args:
feed_url: What feed we should get the custom message for. feed_url: What feed we should get the custom message for.
request: The request object. request: The request object.
reader: The Reader instance.
Returns: Returns:
HTMLResponse: The custom message page. HTMLResponse: The custom message page.
@ -531,12 +579,17 @@ async def get_custom(feed_url: str, request: Request):
@app.get("/embed", response_class=HTMLResponse) @app.get("/embed", response_class=HTMLResponse)
async def get_embed_page(feed_url: str, request: Request): async def get_embed_page(
feed_url: str,
request: Request,
reader: Annotated[Reader, Depends(get_reader_dependency)],
):
"""Get the custom message. This is used when sending the message to Discord. """Get the custom message. This is used when sending the message to Discord.
Args: Args:
feed_url: What feed we should get the custom message for. feed_url: What feed we should get the custom message for.
request: The request object. request: The request object.
reader: The Reader instance.
Returns: Returns:
HTMLResponse: The embed page. HTMLResponse: The embed page.
@ -572,6 +625,7 @@ async def get_embed_page(feed_url: str, request: Request):
@app.post("/embed", response_class=HTMLResponse) @app.post("/embed", response_class=HTMLResponse)
async def post_embed( async def post_embed(
feed_url: Annotated[str, Form()], feed_url: Annotated[str, Form()],
reader: Annotated[Reader, Depends(get_reader_dependency)],
title: Annotated[str, Form()] = "", title: Annotated[str, Form()] = "",
description: Annotated[str, Form()] = "", description: Annotated[str, Form()] = "",
color: Annotated[str, Form()] = "", color: Annotated[str, Form()] = "",
@ -597,7 +651,7 @@ async def post_embed(
author_icon_url: The author icon url of the embed. author_icon_url: The author icon url of the embed.
footer_text: The footer text of the embed. footer_text: The footer text of the embed.
footer_icon_url: The footer icon url of the embed. footer_icon_url: The footer icon url of the embed.
reader: The Reader instance.
Returns: Returns:
RedirectResponse: Redirect to the embed page. RedirectResponse: Redirect to the embed page.
@ -625,11 +679,15 @@ async def post_embed(
@app.post("/use_embed") @app.post("/use_embed")
async def post_use_embed(feed_url: Annotated[str, Form()]) -> RedirectResponse: async def post_use_embed(
feed_url: Annotated[str, Form()],
reader: Annotated[Reader, Depends(get_reader_dependency)],
) -> RedirectResponse:
"""Use embed instead of text. """Use embed instead of text.
Args: Args:
feed_url: The feed to change. feed_url: The feed to change.
reader: The Reader instance.
Returns: Returns:
RedirectResponse: Redirect to the feed page. RedirectResponse: Redirect to the feed page.
@ -641,11 +699,15 @@ async def post_use_embed(feed_url: Annotated[str, Form()]) -> RedirectResponse:
@app.post("/use_text") @app.post("/use_text")
async def post_use_text(feed_url: Annotated[str, Form()]) -> RedirectResponse: async def post_use_text(
feed_url: Annotated[str, Form()],
reader: Annotated[Reader, Depends(get_reader_dependency)],
) -> RedirectResponse:
"""Use text instead of embed. """Use text instead of embed.
Args: Args:
feed_url: The feed to change. feed_url: The feed to change.
reader: The Reader instance.
Returns: Returns:
RedirectResponse: Redirect to the feed page. RedirectResponse: Redirect to the feed page.
@ -659,6 +721,7 @@ async def post_use_text(feed_url: Annotated[str, Form()]) -> RedirectResponse:
@app.post("/set_update_interval") @app.post("/set_update_interval")
async def post_set_update_interval( async def post_set_update_interval(
feed_url: Annotated[str, Form()], feed_url: Annotated[str, Form()],
reader: Annotated[Reader, Depends(get_reader_dependency)],
interval_minutes: Annotated[int | None, Form()] = None, interval_minutes: Annotated[int | None, Form()] = None,
redirect_to: Annotated[str, Form()] = "", redirect_to: Annotated[str, Form()] = "",
) -> RedirectResponse: ) -> RedirectResponse:
@ -668,6 +731,7 @@ async def post_set_update_interval(
feed_url: The feed to change. feed_url: The feed to change.
interval_minutes: The update interval in minutes (None to reset to global default). interval_minutes: The update interval in minutes (None to reset to global default).
redirect_to: Optional redirect URL (defaults to feed page). redirect_to: Optional redirect URL (defaults to feed page).
reader: The Reader instance.
Returns: Returns:
RedirectResponse: Redirect to the specified page or feed page. RedirectResponse: Redirect to the specified page or feed page.
@ -703,12 +767,14 @@ async def post_set_update_interval(
async def post_change_feed_url( async def post_change_feed_url(
old_feed_url: Annotated[str, Form()], old_feed_url: Annotated[str, Form()],
new_feed_url: Annotated[str, Form()], new_feed_url: Annotated[str, Form()],
reader: Annotated[Reader, Depends(get_reader_dependency)],
) -> RedirectResponse: ) -> RedirectResponse:
"""Change the URL for an existing feed. """Change the URL for an existing feed.
Args: Args:
old_feed_url: Current feed URL. old_feed_url: Current feed URL.
new_feed_url: New feed URL to change to. new_feed_url: New feed URL to change to.
reader: The Reader instance.
Returns: Returns:
RedirectResponse: Redirect to the feed page for the resulting URL. RedirectResponse: Redirect to the feed page for the resulting URL.
@ -754,6 +820,7 @@ async def post_change_feed_url(
@app.post("/reset_update_interval") @app.post("/reset_update_interval")
async def post_reset_update_interval( async def post_reset_update_interval(
feed_url: Annotated[str, Form()], feed_url: Annotated[str, Form()],
reader: Annotated[Reader, Depends(get_reader_dependency)],
redirect_to: Annotated[str, Form()] = "", redirect_to: Annotated[str, Form()] = "",
) -> RedirectResponse: ) -> RedirectResponse:
"""Reset the update interval for a feed to use the global default. """Reset the update interval for a feed to use the global default.
@ -761,6 +828,7 @@ async def post_reset_update_interval(
Args: Args:
feed_url: The feed to change. feed_url: The feed to change.
redirect_to: Optional redirect URL (defaults to feed page). redirect_to: Optional redirect URL (defaults to feed page).
reader: The Reader instance.
Returns: Returns:
RedirectResponse: Redirect to the specified page or feed page. RedirectResponse: Redirect to the specified page or feed page.
@ -787,11 +855,15 @@ async def post_reset_update_interval(
@app.post("/set_global_update_interval") @app.post("/set_global_update_interval")
async def post_set_global_update_interval(interval_minutes: Annotated[int, Form()]) -> RedirectResponse: async def post_set_global_update_interval(
interval_minutes: Annotated[int, Form()],
reader: Annotated[Reader, Depends(get_reader_dependency)],
) -> RedirectResponse:
"""Set the global default update interval. """Set the global default update interval.
Args: Args:
interval_minutes: The update interval in minutes. interval_minutes: The update interval in minutes.
reader: The Reader instance.
Returns: Returns:
RedirectResponse: Redirect to the settings page. RedirectResponse: Redirect to the settings page.
@ -805,11 +877,15 @@ async def post_set_global_update_interval(interval_minutes: Annotated[int, Form(
@app.get("/add", response_class=HTMLResponse) @app.get("/add", response_class=HTMLResponse)
def get_add(request: Request): def get_add(
request: Request,
reader: Annotated[Reader, Depends(get_reader_dependency)],
):
"""Page for adding a new feed. """Page for adding a new feed.
Args: Args:
request: The request object. request: The request object.
reader: The Reader instance.
Returns: Returns:
HTMLResponse: The add feed page. HTMLResponse: The add feed page.
@ -822,13 +898,19 @@ def get_add(request: Request):
@app.get("/feed", response_class=HTMLResponse) @app.get("/feed", response_class=HTMLResponse)
async def get_feed(feed_url: str, request: Request, starting_after: str = ""): # noqa: C901, PLR0912, PLR0914, PLR0915 async def get_feed( # noqa: C901, PLR0912, PLR0914, PLR0915
feed_url: str,
request: Request,
reader: Annotated[Reader, Depends(get_reader_dependency)],
starting_after: str = "",
):
"""Get a feed by URL. """Get a feed by URL.
Args: Args:
feed_url: The feed to add. feed_url: The feed to add.
request: The request object. request: The request object.
starting_after: The entry to start after. Used for pagination. starting_after: The entry to start after. Used for pagination.
reader: The Reader instance.
Returns: Returns:
HTMLResponse: The feed page. HTMLResponse: The feed page.
@ -1083,6 +1165,7 @@ def get_data_from_hook_url(hook_name: str, hook_url: str) -> WebhookInfo:
hook_name (str): The webhook name. hook_name (str): The webhook name.
hook_url (str): The webhook URL. hook_url (str): The webhook URL.
Returns: Returns:
WebhookInfo: The webhook username, avatar, guild id, etc. WebhookInfo: The webhook username, avatar, guild id, etc.
""" """
@ -1104,11 +1187,15 @@ def get_data_from_hook_url(hook_name: str, hook_url: str) -> WebhookInfo:
@app.get("/settings", response_class=HTMLResponse) @app.get("/settings", response_class=HTMLResponse)
async def get_settings(request: Request): async def get_settings(
request: Request,
reader: Annotated[Reader, Depends(get_reader_dependency)],
):
"""Settings page. """Settings page.
Args: Args:
request: The request object. request: The request object.
reader: The Reader instance.
Returns: Returns:
HTMLResponse: The settings page. HTMLResponse: The settings page.
@ -1154,11 +1241,15 @@ async def get_settings(request: Request):
@app.get("/webhooks", response_class=HTMLResponse) @app.get("/webhooks", response_class=HTMLResponse)
async def get_webhooks(request: Request): async def get_webhooks(
request: Request,
reader: Annotated[Reader, Depends(get_reader_dependency)],
):
"""Page for adding a new webhook. """Page for adding a new webhook.
Args: Args:
request: The request object. request: The request object.
reader: The Reader instance.
Returns: Returns:
HTMLResponse: The add webhook page. HTMLResponse: The add webhook page.
@ -1179,54 +1270,65 @@ async def get_webhooks(request: Request):
@app.get("/", response_class=HTMLResponse) @app.get("/", response_class=HTMLResponse)
def get_index(request: Request, message: str = ""): def get_index(
request: Request,
reader: Annotated[Reader, Depends(get_reader_dependency)],
message: str = "",
):
"""This is the root of the website. """This is the root of the website.
Args: Args:
request: The request object. request: The request object.
message: Optional message to display to the user. message: Optional message to display to the user.
reader: The Reader instance.
Returns: Returns:
HTMLResponse: The index page. HTMLResponse: The index page.
""" """
return templates.TemplateResponse(request=request, name="index.html", context=make_context_index(request, message)) return templates.TemplateResponse(
request=request,
name="index.html",
context=make_context_index(request, message, reader),
)
def make_context_index(request: Request, message: str = ""): def make_context_index(request: Request, message: str = "", reader: Reader | None = None):
"""Create the needed context for the index page. """Create the needed context for the index page.
Args: Args:
request: The request object. request: The request object.
message: Optional message to display to the user. message: Optional message to display to the user.
reader: The Reader instance.
Returns: Returns:
dict: The context for the index page. dict: The context for the index page.
""" """
hooks: list[dict[str, str]] = cast("list[dict[str, str]]", list(reader.get_tag((), "webhooks", []))) effective_reader: Reader = reader or get_reader_dependency()
hooks: list[dict[str, str]] = cast("list[dict[str, str]]", list(effective_reader.get_tag((), "webhooks", [])))
feed_list = [] feed_list: list[dict[str, JSONType | Feed | str]] = []
broken_feeds = [] broken_feeds: list[Feed] = []
feeds_without_attached_webhook = [] feeds_without_attached_webhook: list[Feed] = []
# Get all feeds and organize them # Get all feeds and organize them
feeds: Iterable[Feed] = reader.get_feeds() feeds: Iterable[Feed] = effective_reader.get_feeds()
for feed in feeds: for feed in feeds:
try: try:
webhook = reader.get_tag(feed.url, "webhook") webhook: JSONType = effective_reader.get_tag(feed.url, "webhook")
feed_list.append({"feed": feed, "webhook": webhook, "domain": extract_domain(feed.url)}) feed_list.append({"feed": feed, "webhook": webhook, "domain": extract_domain(feed.url)})
except TagNotFoundError: except TagNotFoundError:
broken_feeds.append(feed) broken_feeds.append(feed)
continue continue
webhook_list = [hook["url"] for hook in hooks] webhook_list: list[str] = [hook["url"] for hook in hooks]
if webhook not in webhook_list: if webhook not in webhook_list:
feeds_without_attached_webhook.append(feed) feeds_without_attached_webhook.append(feed)
return { return {
"request": request, "request": request,
"feeds": feed_list, "feeds": feed_list,
"feed_count": reader.get_feed_counts(), "feed_count": effective_reader.get_feed_counts(),
"entry_count": reader.get_entry_counts(), "entry_count": effective_reader.get_entry_counts(),
"webhooks": hooks, "webhooks": hooks,
"broken_feeds": broken_feeds, "broken_feeds": broken_feeds,
"feeds_without_attached_webhook": feeds_without_attached_webhook, "feeds_without_attached_webhook": feeds_without_attached_webhook,
@ -1235,12 +1337,15 @@ def make_context_index(request: Request, message: str = ""):
@app.post("/remove", response_class=HTMLResponse) @app.post("/remove", response_class=HTMLResponse)
async def remove_feed(feed_url: Annotated[str, Form()]): async def remove_feed(
feed_url: Annotated[str, Form()],
reader: Annotated[Reader, Depends(get_reader_dependency)],
):
"""Get a feed by URL. """Get a feed by URL.
Args: Args:
feed_url: The feed to add. feed_url: The feed to add.
reader: The Reader instance.
Returns: Returns:
RedirectResponse: Redirect to the index page. RedirectResponse: Redirect to the index page.
@ -1259,13 +1364,17 @@ async def remove_feed(feed_url: Annotated[str, Form()]):
@app.get("/update", response_class=HTMLResponse) @app.get("/update", response_class=HTMLResponse)
async def update_feed(request: Request, feed_url: str): async def update_feed(
request: Request,
feed_url: str,
reader: Annotated[Reader, Depends(get_reader_dependency)],
):
"""Update a feed. """Update a feed.
Args: Args:
request: The request object. request: The request object.
feed_url: The feed URL to update. feed_url: The feed URL to update.
reader: The Reader instance.
Returns: Returns:
RedirectResponse: Redirect to the feed page. RedirectResponse: Redirect to the feed page.
@ -1283,11 +1392,15 @@ async def update_feed(request: Request, feed_url: str):
@app.post("/backup") @app.post("/backup")
async def manual_backup(request: Request) -> RedirectResponse: async def manual_backup(
request: Request,
reader: Annotated[Reader, Depends(get_reader_dependency)],
) -> RedirectResponse:
"""Manually trigger a git backup of the current state. """Manually trigger a git backup of the current state.
Args: Args:
request: The request object. request: The request object.
reader: The Reader instance.
Returns: Returns:
RedirectResponse: Redirect to the index page with a success or error message. RedirectResponse: Redirect to the index page with a success or error message.
@ -1310,12 +1423,17 @@ async def manual_backup(request: Request) -> RedirectResponse:
@app.get("/search", response_class=HTMLResponse) @app.get("/search", response_class=HTMLResponse)
async def search(request: Request, query: str): async def search(
request: Request,
query: str,
reader: Annotated[Reader, Depends(get_reader_dependency)],
):
"""Get entries matching a full-text search query. """Get entries matching a full-text search query.
Args: Args:
query: The query to search for. query: The query to search for.
request: The request object. request: The request object.
reader: The Reader instance.
Returns: Returns:
HTMLResponse: The search page. HTMLResponse: The search page.
@ -1326,11 +1444,15 @@ async def search(request: Request, query: str):
@app.get("/post_entry", response_class=HTMLResponse) @app.get("/post_entry", response_class=HTMLResponse)
async def post_entry(entry_id: str): async def post_entry(
entry_id: str,
reader: Annotated[Reader, Depends(get_reader_dependency)],
):
"""Send single entry to Discord. """Send single entry to Discord.
Args: Args:
entry_id: The entry to send. entry_id: The entry to send.
reader: The Reader instance.
Returns: Returns:
RedirectResponse: Redirect to the feed page. RedirectResponse: Redirect to the feed page.
@ -1349,12 +1471,17 @@ async def post_entry(entry_id: str):
@app.post("/modify_webhook", response_class=HTMLResponse) @app.post("/modify_webhook", response_class=HTMLResponse)
def modify_webhook(old_hook: Annotated[str, Form()], new_hook: Annotated[str, Form()]): def modify_webhook(
old_hook: Annotated[str, Form()],
new_hook: Annotated[str, Form()],
reader: Annotated[Reader, Depends(get_reader_dependency)],
):
"""Modify a webhook. """Modify a webhook.
Args: Args:
old_hook: The webhook to modify. old_hook: The webhook to modify.
new_hook: The new webhook. new_hook: The new webhook.
reader: The Reader instance.
Returns: Returns:
RedirectResponse: Redirect to the webhook page. RedirectResponse: Redirect to the webhook page.
@ -1424,6 +1551,7 @@ def extract_youtube_video_id(url: str) -> str | None:
async def get_webhook_entries( # noqa: C901, PLR0912, PLR0914 async def get_webhook_entries( # noqa: C901, PLR0912, PLR0914
webhook_url: str, webhook_url: str,
request: Request, request: Request,
reader: Annotated[Reader, Depends(get_reader_dependency)],
starting_after: str = "", starting_after: str = "",
) -> HTMLResponse: ) -> HTMLResponse:
"""Get all latest entries from all feeds for a specific webhook. """Get all latest entries from all feeds for a specific webhook.
@ -1432,6 +1560,7 @@ async def get_webhook_entries( # noqa: C901, PLR0912, PLR0914
webhook_url: The webhook URL to get entries for. webhook_url: The webhook URL to get entries for.
request: The request object. request: The request object.
starting_after: The entry to start after. Used for pagination. starting_after: The entry to start after. Used for pagination.
reader: The Reader instance.
Returns: Returns:
HTMLResponse: The webhook entries page. HTMLResponse: The webhook entries page.

View file

@ -57,7 +57,7 @@ def pytest_sessionstart(session: pytest.Session) -> None:
current_reader.close() current_reader.close()
get_reader: Any = getattr(settings_module, "get_reader", None) get_reader: Any = getattr(settings_module, "get_reader", None)
if callable(get_reader): if callable(get_reader):
main_module.reader = get_reader() get_reader()
def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None: def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None:

View file

@ -14,6 +14,7 @@ from fastapi.testclient import TestClient
import discord_rss_bot.main as main_module import discord_rss_bot.main as main_module
from discord_rss_bot.main import app from discord_rss_bot.main import app
from discord_rss_bot.main import create_html_for_feed from discord_rss_bot.main import create_html_for_feed
from discord_rss_bot.main import get_reader_dependency
if TYPE_CHECKING: if TYPE_CHECKING:
from pathlib import Path from pathlib import Path
@ -324,7 +325,7 @@ def test_change_feed_url_marks_entries_as_read() -> None:
mock_entry_b = MagicMock() mock_entry_b = MagicMock()
mock_entry_b.id = "entry-b" mock_entry_b.id = "entry-b"
real_reader = main_module.reader real_reader = main_module.get_reader_dependency()
# Use a no-redirect client so the POST response is inspected directly; the # Use a no-redirect client so the POST response is inspected directly; the
# redirect target (/feed?feed_url=…) would 404 because change_feed_url is mocked. # redirect target (/feed?feed_url=…) would 404 because change_feed_url is mocked.
@ -927,3 +928,28 @@ def test_webhook_entries_url_encoding() -> None:
# Clean up # Clean up
client.post(url="/remove", data={"feed_url": feed_url}) client.post(url="/remove", data={"feed_url": feed_url})
def test_reader_dependency_override_is_used() -> None:
"""Reader should be injectable and overridable via FastAPI dependency overrides."""
class StubReader:
def get_tag(self, _resource: str, _key: str, default: str | None = None) -> str | None:
"""Stub get_tag that always returns the default value.
Args:
_resource: Ignored.
_key: Ignored.
default: The value to return.
Returns:
The default value, simulating a missing tag.
"""
return default
app.dependency_overrides[get_reader_dependency] = StubReader
try:
response: Response = client.get(url="/add")
assert response.status_code == 200, f"Expected /add to render with overridden reader: {response.text}"
finally:
app.dependency_overrides = {}