diff --git a/discord_rss_bot/main.py b/discord_rss_bot/main.py index 050de04..755cfc0 100644 --- a/discord_rss_bot/main.py +++ b/discord_rss_bot/main.py @@ -19,6 +19,7 @@ import httpx import sentry_sdk import uvicorn from apscheduler.schedulers.asyncio import AsyncIOScheduler +from fastapi import Depends from fastapi import FastAPI from fastapi import Form from fastapi import HTTPException @@ -100,7 +101,16 @@ LOGGING_CONFIG: dict[str, Any] = { logging.config.dictConfig(LOGGING_CONFIG) logger: logging.Logger = logging.getLogger(__name__) -reader: Reader = get_reader() + + +def get_reader_dependency() -> Reader: + """Provide the app Reader instance as a FastAPI dependency. + + Returns: + Reader: The shared Reader instance. + """ + return get_reader() + # Time constants for relative time formatting SECONDS_PER_MINUTE = 60 @@ -146,6 +156,7 @@ def relative_time(dt: datetime | None) -> str: @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None]: """Lifespan function for the FastAPI app.""" + reader: Reader = get_reader() add_missing_tags(reader) scheduler: AsyncIOScheduler = AsyncIOScheduler(timezone=UTC) scheduler.add_job( @@ -181,12 +192,14 @@ templates.env.globals["get_backup_path"] = get_backup_path async def post_add_webhook( webhook_name: Annotated[str, Form()], webhook_url: Annotated[str, Form()], + reader: Annotated[Reader, Depends(get_reader_dependency)], ) -> RedirectResponse: """Add a feed to the database. Args: webhook_name: The name of the webhook. webhook_url: The url of the webhook. + reader: The Reader instance. Returns: RedirectResponse: Redirect to the index page. @@ -219,11 +232,15 @@ async def post_add_webhook( @app.post("/delete_webhook") -async def post_delete_webhook(webhook_url: Annotated[str, Form()]) -> RedirectResponse: +async def post_delete_webhook( + webhook_url: Annotated[str, Form()], + reader: Annotated[Reader, Depends(get_reader_dependency)], +) -> RedirectResponse: """Delete a webhook from the database. Args: webhook_url: The url of the webhook. + reader: The Reader instance. Returns: RedirectResponse: Redirect to the index page. @@ -266,12 +283,14 @@ async def post_delete_webhook(webhook_url: Annotated[str, Form()]) -> RedirectRe async def post_create_feed( feed_url: Annotated[str, Form()], webhook_dropdown: Annotated[str, Form()], + reader: Annotated[Reader, Depends(get_reader_dependency)], ) -> RedirectResponse: """Add a feed to the database. Args: feed_url: The feed to add. webhook_dropdown: The webhook to use. + reader: The Reader instance. Returns: RedirectResponse: Redirect to the feed page. @@ -283,11 +302,15 @@ async def post_create_feed( @app.post("/pause") -async def post_pause_feed(feed_url: Annotated[str, Form()]) -> RedirectResponse: +async def post_pause_feed( + feed_url: Annotated[str, Form()], + reader: Annotated[Reader, Depends(get_reader_dependency)], +) -> RedirectResponse: """Pause a feed. Args: feed_url: The feed to pause. + reader: The Reader instance. Returns: RedirectResponse: Redirect to the feed page. @@ -298,11 +321,15 @@ async def post_pause_feed(feed_url: Annotated[str, Form()]) -> RedirectResponse: @app.post("/unpause") -async def post_unpause_feed(feed_url: Annotated[str, Form()]) -> RedirectResponse: +async def post_unpause_feed( + feed_url: Annotated[str, Form()], + reader: Annotated[Reader, Depends(get_reader_dependency)], +) -> RedirectResponse: """Unpause a feed. Args: feed_url: The Feed to unpause. + reader: The Reader instance. Returns: RedirectResponse: Redirect to the feed page. @@ -314,6 +341,7 @@ async def post_unpause_feed(feed_url: Annotated[str, Form()]) -> RedirectRespons @app.post("/whitelist") async def post_set_whitelist( + reader: Annotated[Reader, Depends(get_reader_dependency)], whitelist_title: Annotated[str, Form()] = "", whitelist_summary: Annotated[str, Form()] = "", whitelist_content: Annotated[str, Form()] = "", @@ -336,6 +364,7 @@ async def post_set_whitelist( regex_whitelist_content: Whitelisted regex for when checking the content. regex_whitelist_author: Whitelisted regex for when checking the author. feed_url: The feed we should set the whitelist for. + reader: The Reader instance. Returns: RedirectResponse: Redirect to the feed page. @@ -356,12 +385,17 @@ async def post_set_whitelist( @app.get("/whitelist", response_class=HTMLResponse) -async def get_whitelist(feed_url: str, request: Request): +async def get_whitelist( + feed_url: str, + request: Request, + reader: Annotated[Reader, Depends(get_reader_dependency)], +): """Get the whitelist. Args: feed_url: What feed we should get the whitelist for. request: The request object. + reader: The Reader instance. Returns: HTMLResponse: The whitelist page. @@ -395,6 +429,7 @@ async def get_whitelist(feed_url: str, request: Request): @app.post("/blacklist") async def post_set_blacklist( + reader: Annotated[Reader, Depends(get_reader_dependency)], blacklist_title: Annotated[str, Form()] = "", blacklist_summary: Annotated[str, Form()] = "", blacklist_content: Annotated[str, Form()] = "", @@ -420,6 +455,7 @@ async def post_set_blacklist( regex_blacklist_content: Blacklisted regex for when checking the content. regex_blacklist_author: Blacklisted regex for when checking the author. feed_url: What feed we should set the blacklist for. + reader: The Reader instance. Returns: RedirectResponse: Redirect to the feed page. @@ -438,12 +474,17 @@ async def post_set_blacklist( @app.get("/blacklist", response_class=HTMLResponse) -async def get_blacklist(feed_url: str, request: Request): +async def get_blacklist( + feed_url: str, + request: Request, + reader: Annotated[Reader, Depends(get_reader_dependency)], +): """Get the blacklist. Args: feed_url: What feed we should get the blacklist for. request: The request object. + reader: The Reader instance. Returns: HTMLResponse: The blacklist page. @@ -477,6 +518,7 @@ async def get_blacklist(feed_url: str, request: Request): @app.post("/custom") async def post_set_custom( feed_url: Annotated[str, Form()], + reader: Annotated[Reader, Depends(get_reader_dependency)], custom_message: Annotated[str, Form()] = "", ) -> RedirectResponse: """Set the custom message, this is used when sending the message. @@ -484,6 +526,7 @@ async def post_set_custom( Args: custom_message: The custom message. feed_url: The feed we should set the custom message for. + reader: The Reader instance. Returns: RedirectResponse: Redirect to the feed page. @@ -505,12 +548,17 @@ async def post_set_custom( @app.get("/custom", response_class=HTMLResponse) -async def get_custom(feed_url: str, request: Request): +async def get_custom( + feed_url: str, + request: Request, + reader: Annotated[Reader, Depends(get_reader_dependency)], +): """Get the custom message. This is used when sending the message to Discord. Args: feed_url: What feed we should get the custom message for. request: The request object. + reader: The Reader instance. Returns: HTMLResponse: The custom message page. @@ -531,12 +579,17 @@ async def get_custom(feed_url: str, request: Request): @app.get("/embed", response_class=HTMLResponse) -async def get_embed_page(feed_url: str, request: Request): +async def get_embed_page( + feed_url: str, + request: Request, + reader: Annotated[Reader, Depends(get_reader_dependency)], +): """Get the custom message. This is used when sending the message to Discord. Args: feed_url: What feed we should get the custom message for. request: The request object. + reader: The Reader instance. Returns: HTMLResponse: The embed page. @@ -572,6 +625,7 @@ async def get_embed_page(feed_url: str, request: Request): @app.post("/embed", response_class=HTMLResponse) async def post_embed( feed_url: Annotated[str, Form()], + reader: Annotated[Reader, Depends(get_reader_dependency)], title: Annotated[str, Form()] = "", description: Annotated[str, Form()] = "", color: Annotated[str, Form()] = "", @@ -597,7 +651,7 @@ async def post_embed( author_icon_url: The author icon url of the embed. footer_text: The footer text of the embed. footer_icon_url: The footer icon url of the embed. - + reader: The Reader instance. Returns: RedirectResponse: Redirect to the embed page. @@ -625,11 +679,15 @@ async def post_embed( @app.post("/use_embed") -async def post_use_embed(feed_url: Annotated[str, Form()]) -> RedirectResponse: +async def post_use_embed( + feed_url: Annotated[str, Form()], + reader: Annotated[Reader, Depends(get_reader_dependency)], +) -> RedirectResponse: """Use embed instead of text. Args: feed_url: The feed to change. + reader: The Reader instance. Returns: RedirectResponse: Redirect to the feed page. @@ -641,11 +699,15 @@ async def post_use_embed(feed_url: Annotated[str, Form()]) -> RedirectResponse: @app.post("/use_text") -async def post_use_text(feed_url: Annotated[str, Form()]) -> RedirectResponse: +async def post_use_text( + feed_url: Annotated[str, Form()], + reader: Annotated[Reader, Depends(get_reader_dependency)], +) -> RedirectResponse: """Use text instead of embed. Args: feed_url: The feed to change. + reader: The Reader instance. Returns: RedirectResponse: Redirect to the feed page. @@ -659,6 +721,7 @@ async def post_use_text(feed_url: Annotated[str, Form()]) -> RedirectResponse: @app.post("/set_update_interval") async def post_set_update_interval( feed_url: Annotated[str, Form()], + reader: Annotated[Reader, Depends(get_reader_dependency)], interval_minutes: Annotated[int | None, Form()] = None, redirect_to: Annotated[str, Form()] = "", ) -> RedirectResponse: @@ -668,6 +731,7 @@ async def post_set_update_interval( feed_url: The feed to change. interval_minutes: The update interval in minutes (None to reset to global default). redirect_to: Optional redirect URL (defaults to feed page). + reader: The Reader instance. Returns: RedirectResponse: Redirect to the specified page or feed page. @@ -703,12 +767,14 @@ async def post_set_update_interval( async def post_change_feed_url( old_feed_url: Annotated[str, Form()], new_feed_url: Annotated[str, Form()], + reader: Annotated[Reader, Depends(get_reader_dependency)], ) -> RedirectResponse: """Change the URL for an existing feed. Args: old_feed_url: Current feed URL. new_feed_url: New feed URL to change to. + reader: The Reader instance. Returns: RedirectResponse: Redirect to the feed page for the resulting URL. @@ -754,6 +820,7 @@ async def post_change_feed_url( @app.post("/reset_update_interval") async def post_reset_update_interval( feed_url: Annotated[str, Form()], + reader: Annotated[Reader, Depends(get_reader_dependency)], redirect_to: Annotated[str, Form()] = "", ) -> RedirectResponse: """Reset the update interval for a feed to use the global default. @@ -761,6 +828,7 @@ async def post_reset_update_interval( Args: feed_url: The feed to change. redirect_to: Optional redirect URL (defaults to feed page). + reader: The Reader instance. Returns: RedirectResponse: Redirect to the specified page or feed page. @@ -787,11 +855,15 @@ async def post_reset_update_interval( @app.post("/set_global_update_interval") -async def post_set_global_update_interval(interval_minutes: Annotated[int, Form()]) -> RedirectResponse: +async def post_set_global_update_interval( + interval_minutes: Annotated[int, Form()], + reader: Annotated[Reader, Depends(get_reader_dependency)], +) -> RedirectResponse: """Set the global default update interval. Args: interval_minutes: The update interval in minutes. + reader: The Reader instance. Returns: RedirectResponse: Redirect to the settings page. @@ -805,11 +877,15 @@ async def post_set_global_update_interval(interval_minutes: Annotated[int, Form( @app.get("/add", response_class=HTMLResponse) -def get_add(request: Request): +def get_add( + request: Request, + reader: Annotated[Reader, Depends(get_reader_dependency)], +): """Page for adding a new feed. Args: request: The request object. + reader: The Reader instance. Returns: HTMLResponse: The add feed page. @@ -822,13 +898,19 @@ def get_add(request: Request): @app.get("/feed", response_class=HTMLResponse) -async def get_feed(feed_url: str, request: Request, starting_after: str = ""): # noqa: C901, PLR0912, PLR0914, PLR0915 +async def get_feed( # noqa: C901, PLR0912, PLR0914, PLR0915 + feed_url: str, + request: Request, + reader: Annotated[Reader, Depends(get_reader_dependency)], + starting_after: str = "", +): """Get a feed by URL. Args: feed_url: The feed to add. request: The request object. starting_after: The entry to start after. Used for pagination. + reader: The Reader instance. Returns: HTMLResponse: The feed page. @@ -1083,6 +1165,7 @@ def get_data_from_hook_url(hook_name: str, hook_url: str) -> WebhookInfo: hook_name (str): The webhook name. hook_url (str): The webhook URL. + Returns: WebhookInfo: The webhook username, avatar, guild id, etc. """ @@ -1104,11 +1187,15 @@ def get_data_from_hook_url(hook_name: str, hook_url: str) -> WebhookInfo: @app.get("/settings", response_class=HTMLResponse) -async def get_settings(request: Request): +async def get_settings( + request: Request, + reader: Annotated[Reader, Depends(get_reader_dependency)], +): """Settings page. Args: request: The request object. + reader: The Reader instance. Returns: HTMLResponse: The settings page. @@ -1154,11 +1241,15 @@ async def get_settings(request: Request): @app.get("/webhooks", response_class=HTMLResponse) -async def get_webhooks(request: Request): +async def get_webhooks( + request: Request, + reader: Annotated[Reader, Depends(get_reader_dependency)], +): """Page for adding a new webhook. Args: request: The request object. + reader: The Reader instance. Returns: HTMLResponse: The add webhook page. @@ -1179,54 +1270,65 @@ async def get_webhooks(request: Request): @app.get("/", response_class=HTMLResponse) -def get_index(request: Request, message: str = ""): +def get_index( + request: Request, + reader: Annotated[Reader, Depends(get_reader_dependency)], + message: str = "", +): """This is the root of the website. Args: request: The request object. message: Optional message to display to the user. + reader: The Reader instance. Returns: HTMLResponse: The index page. """ - return templates.TemplateResponse(request=request, name="index.html", context=make_context_index(request, message)) + return templates.TemplateResponse( + request=request, + name="index.html", + context=make_context_index(request, message, reader), + ) -def make_context_index(request: Request, message: str = ""): +def make_context_index(request: Request, message: str = "", reader: Reader | None = None): """Create the needed context for the index page. Args: request: The request object. message: Optional message to display to the user. + reader: The Reader instance. Returns: dict: The context for the index page. """ - hooks: list[dict[str, str]] = cast("list[dict[str, str]]", list(reader.get_tag((), "webhooks", []))) + effective_reader: Reader = reader or get_reader_dependency() + hooks: list[dict[str, str]] = cast("list[dict[str, str]]", list(effective_reader.get_tag((), "webhooks", []))) - feed_list = [] - broken_feeds = [] - feeds_without_attached_webhook = [] + feed_list: list[dict[str, JSONType | Feed | str]] = [] + broken_feeds: list[Feed] = [] + feeds_without_attached_webhook: list[Feed] = [] # Get all feeds and organize them - feeds: Iterable[Feed] = reader.get_feeds() + feeds: Iterable[Feed] = effective_reader.get_feeds() for feed in feeds: try: - webhook = reader.get_tag(feed.url, "webhook") + webhook: JSONType = effective_reader.get_tag(feed.url, "webhook") feed_list.append({"feed": feed, "webhook": webhook, "domain": extract_domain(feed.url)}) except TagNotFoundError: broken_feeds.append(feed) continue - webhook_list = [hook["url"] for hook in hooks] + webhook_list: list[str] = [hook["url"] for hook in hooks] if webhook not in webhook_list: feeds_without_attached_webhook.append(feed) return { "request": request, "feeds": feed_list, - "feed_count": reader.get_feed_counts(), - "entry_count": reader.get_entry_counts(), + "feed_count": effective_reader.get_feed_counts(), + "entry_count": effective_reader.get_entry_counts(), "webhooks": hooks, "broken_feeds": broken_feeds, "feeds_without_attached_webhook": feeds_without_attached_webhook, @@ -1235,12 +1337,15 @@ def make_context_index(request: Request, message: str = ""): @app.post("/remove", response_class=HTMLResponse) -async def remove_feed(feed_url: Annotated[str, Form()]): +async def remove_feed( + feed_url: Annotated[str, Form()], + reader: Annotated[Reader, Depends(get_reader_dependency)], +): """Get a feed by URL. Args: feed_url: The feed to add. - + reader: The Reader instance. Returns: RedirectResponse: Redirect to the index page. @@ -1259,13 +1364,17 @@ async def remove_feed(feed_url: Annotated[str, Form()]): @app.get("/update", response_class=HTMLResponse) -async def update_feed(request: Request, feed_url: str): +async def update_feed( + request: Request, + feed_url: str, + reader: Annotated[Reader, Depends(get_reader_dependency)], +): """Update a feed. Args: request: The request object. feed_url: The feed URL to update. - + reader: The Reader instance. Returns: RedirectResponse: Redirect to the feed page. @@ -1283,11 +1392,15 @@ async def update_feed(request: Request, feed_url: str): @app.post("/backup") -async def manual_backup(request: Request) -> RedirectResponse: +async def manual_backup( + request: Request, + reader: Annotated[Reader, Depends(get_reader_dependency)], +) -> RedirectResponse: """Manually trigger a git backup of the current state. Args: request: The request object. + reader: The Reader instance. Returns: RedirectResponse: Redirect to the index page with a success or error message. @@ -1310,12 +1423,17 @@ async def manual_backup(request: Request) -> RedirectResponse: @app.get("/search", response_class=HTMLResponse) -async def search(request: Request, query: str): +async def search( + request: Request, + query: str, + reader: Annotated[Reader, Depends(get_reader_dependency)], +): """Get entries matching a full-text search query. Args: query: The query to search for. request: The request object. + reader: The Reader instance. Returns: HTMLResponse: The search page. @@ -1326,11 +1444,15 @@ async def search(request: Request, query: str): @app.get("/post_entry", response_class=HTMLResponse) -async def post_entry(entry_id: str): +async def post_entry( + entry_id: str, + reader: Annotated[Reader, Depends(get_reader_dependency)], +): """Send single entry to Discord. Args: entry_id: The entry to send. + reader: The Reader instance. Returns: RedirectResponse: Redirect to the feed page. @@ -1349,12 +1471,17 @@ async def post_entry(entry_id: str): @app.post("/modify_webhook", response_class=HTMLResponse) -def modify_webhook(old_hook: Annotated[str, Form()], new_hook: Annotated[str, Form()]): +def modify_webhook( + old_hook: Annotated[str, Form()], + new_hook: Annotated[str, Form()], + reader: Annotated[Reader, Depends(get_reader_dependency)], +): """Modify a webhook. Args: old_hook: The webhook to modify. new_hook: The new webhook. + reader: The Reader instance. Returns: RedirectResponse: Redirect to the webhook page. @@ -1424,6 +1551,7 @@ def extract_youtube_video_id(url: str) -> str | None: async def get_webhook_entries( # noqa: C901, PLR0912, PLR0914 webhook_url: str, request: Request, + reader: Annotated[Reader, Depends(get_reader_dependency)], starting_after: str = "", ) -> HTMLResponse: """Get all latest entries from all feeds for a specific webhook. @@ -1432,6 +1560,7 @@ async def get_webhook_entries( # noqa: C901, PLR0912, PLR0914 webhook_url: The webhook URL to get entries for. request: The request object. starting_after: The entry to start after. Used for pagination. + reader: The Reader instance. Returns: HTMLResponse: The webhook entries page. diff --git a/tests/conftest.py b/tests/conftest.py index 30c6274..c7d9170 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,7 +57,7 @@ def pytest_sessionstart(session: pytest.Session) -> None: current_reader.close() get_reader: Any = getattr(settings_module, "get_reader", None) if callable(get_reader): - main_module.reader = get_reader() + get_reader() def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None: diff --git a/tests/test_main.py b/tests/test_main.py index eee423e..8ddb0b8 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -14,6 +14,7 @@ from fastapi.testclient import TestClient import discord_rss_bot.main as main_module from discord_rss_bot.main import app from discord_rss_bot.main import create_html_for_feed +from discord_rss_bot.main import get_reader_dependency if TYPE_CHECKING: from pathlib import Path @@ -324,7 +325,7 @@ def test_change_feed_url_marks_entries_as_read() -> None: mock_entry_b = MagicMock() mock_entry_b.id = "entry-b" - real_reader = main_module.reader + real_reader = main_module.get_reader_dependency() # Use a no-redirect client so the POST response is inspected directly; the # redirect target (/feed?feed_url=…) would 404 because change_feed_url is mocked. @@ -927,3 +928,28 @@ def test_webhook_entries_url_encoding() -> None: # Clean up client.post(url="/remove", data={"feed_url": feed_url}) + + +def test_reader_dependency_override_is_used() -> None: + """Reader should be injectable and overridable via FastAPI dependency overrides.""" + + class StubReader: + def get_tag(self, _resource: str, _key: str, default: str | None = None) -> str | None: + """Stub get_tag that always returns the default value. + + Args: + _resource: Ignored. + _key: Ignored. + default: The value to return. + + Returns: + The default value, simulating a missing tag. + """ + return default + + app.dependency_overrides[get_reader_dependency] = StubReader + try: + response: Response = client.get(url="/add") + assert response.status_code == 200, f"Expected /add to render with overridden reader: {response.text}" + finally: + app.dependency_overrides = {}