Fix all the bugs

This commit is contained in:
2024-05-18 04:05:05 +02:00
parent 9c63916716
commit 73cf7c489c
33 changed files with 831 additions and 396 deletions

View File

@ -1,20 +1,24 @@
from __future__ import annotations
import json
import typing
import urllib.parse
from collections.abc import Iterable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import datetime, timezone
from functools import lru_cache
from typing import cast
from typing import TYPE_CHECKING, cast
import httpx
import uvicorn
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from fastapi import FastAPI, Form, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from httpx import Response
from reader import Entry, Feed, FeedNotFoundError, Reader, TagNotFoundError
from reader.types import JSONType
from starlette.responses import RedirectResponse
from discord_rss_bot import settings
@ -38,11 +42,32 @@ from discord_rss_bot.search import create_html_for_search_results
from discord_rss_bot.settings import get_reader
from discord_rss_bot.webhook import add_webhook, remove_webhook
if TYPE_CHECKING:
from collections.abc import Iterable
reader: Reader = get_reader()
@asynccontextmanager
async def lifespan(app: FastAPI) -> typing.AsyncGenerator[None, None]:
"""This is needed for the ASGI server to run."""
add_missing_tags(reader=reader)
scheduler: AsyncIOScheduler = AsyncIOScheduler()
# Update all feeds every 15 minutes.
# TODO(TheLovinator): Make this configurable.
scheduler.add_job(send_to_discord, "interval", minutes=15, next_run_time=datetime.now(tz=timezone.utc))
scheduler.start()
yield
reader.close()
scheduler.shutdown(wait=True)
app: FastAPI = FastAPI()
app.mount("/static", StaticFiles(directory="discord_rss_bot/static"), name="static")
templates: Jinja2Templates = Jinja2Templates(directory="discord_rss_bot/templates")
reader: Reader = get_reader()
# Add the filters to the Jinja2 environment so they can be used in html templates.
templates.env.filters["encode_url"] = encode_url
@ -70,7 +95,7 @@ async def post_delete_webhook(webhook_url: str = Form()) -> RedirectResponse:
Args:
webhook_url: The url of the webhook.
"""
# TODO: Check if the webhook is in use by any feeds before deleting it.
# TODO(TheLovinator): Check if the webhook is in use by any feeds before deleting it.
remove_webhook(reader, webhook_url)
return RedirectResponse(url="/", status_code=303)
@ -131,19 +156,19 @@ async def post_set_whitelist(
"""
clean_feed_url: str = feed_url.strip()
if whitelist_title:
reader.set_tag(clean_feed_url, "whitelist_title", whitelist_title) # type: ignore
reader.set_tag(clean_feed_url, "whitelist_title", whitelist_title) # type: ignore[call-overload]
if whitelist_summary:
reader.set_tag(clean_feed_url, "whitelist_summary", whitelist_summary) # type: ignore
reader.set_tag(clean_feed_url, "whitelist_summary", whitelist_summary) # type: ignore[call-overload]
if whitelist_content:
reader.set_tag(clean_feed_url, "whitelist_content", whitelist_content) # type: ignore
reader.set_tag(clean_feed_url, "whitelist_content", whitelist_content) # type: ignore[call-overload]
if whitelist_author:
reader.set_tag(clean_feed_url, "whitelist_author", whitelist_author) # type: ignore
reader.set_tag(clean_feed_url, "whitelist_author", whitelist_author) # type: ignore[call-overload]
return RedirectResponse(url=f"/feed/?feed_url={urllib.parse.quote(clean_feed_url)}", status_code=303)
@app.get("/whitelist", response_class=HTMLResponse)
async def get_whitelist(feed_url: str, request: Request): # noqa: ANN201
async def get_whitelist(feed_url: str, request: Request):
"""Get the whitelist.
Args:
@ -167,7 +192,7 @@ async def get_whitelist(feed_url: str, request: Request): # noqa: ANN201
"whitelist_content": whitelist_content,
"whitelist_author": whitelist_author,
}
return templates.TemplateResponse("whitelist.html", context)
return templates.TemplateResponse(request=request, name="whitelist.html", context=context)
@app.post("/blacklist")
@ -192,19 +217,28 @@ async def post_set_blacklist(
"""
clean_feed_url: str = feed_url.strip()
if blacklist_title:
reader.set_tag(clean_feed_url, "blacklist_title", blacklist_title) # type: ignore
reader.set_tag(clean_feed_url, "blacklist_title", blacklist_title) # type: ignore[call-overload]
if blacklist_summary:
reader.set_tag(clean_feed_url, "blacklist_summary", blacklist_summary) # type: ignore
reader.set_tag(clean_feed_url, "blacklist_summary", blacklist_summary) # type: ignore[call-overload]
if blacklist_content:
reader.set_tag(clean_feed_url, "blacklist_content", blacklist_content) # type: ignore
reader.set_tag(clean_feed_url, "blacklist_content", blacklist_content) # type: ignore[call-overload]
if blacklist_author:
reader.set_tag(clean_feed_url, "blacklist_author", blacklist_author) # type: ignore
reader.set_tag(clean_feed_url, "blacklist_author", blacklist_author) # type: ignore[call-overload]
return RedirectResponse(url=f"/feed/?feed_url={urllib.parse.quote(clean_feed_url)}", status_code=303)
@app.get("/blacklist", response_class=HTMLResponse)
async def get_blacklist(feed_url: str, request: Request): # noqa: ANN201
async def get_blacklist(feed_url: str, request: Request):
"""Get the blacklist.
Args:
feed_url: What feed we should get the blacklist for.
request: The request object.
Returns:
HTMLResponse: The blacklist page.
"""
feed: Feed = reader.get_feed(urllib.parse.unquote(feed_url))
# Get previous data, this is used when creating the form.
@ -221,7 +255,7 @@ async def get_blacklist(feed_url: str, request: Request): # noqa: ANN201
"blacklist_content": blacklist_content,
"blacklist_author": blacklist_author,
}
return templates.TemplateResponse("blacklist.html", context)
return templates.TemplateResponse(request=request, name="blacklist.html", context=context)
@app.post("/custom")
@ -232,17 +266,23 @@ async def post_set_custom(custom_message: str = Form(""), feed_url: str = Form()
custom_message: The custom message.
feed_url: The feed we should set the custom message for.
"""
if custom_message:
reader.set_tag(feed_url, "custom_message", custom_message.strip()) # type: ignore
our_custom_message: JSONType | str = custom_message.strip()
our_custom_message = typing.cast(JSONType, our_custom_message)
default_custom_message: JSONType | str = settings.default_custom_message
default_custom_message = typing.cast(JSONType, default_custom_message)
if our_custom_message:
reader.set_tag(feed_url, "custom_message", our_custom_message)
else:
reader.set_tag(feed_url, "custom_message", settings.default_custom_message) # type: ignore
reader.set_tag(feed_url, "custom_message", default_custom_message)
clean_feed_url: str = feed_url.strip()
return RedirectResponse(url=f"/feed/?feed_url={urllib.parse.quote(clean_feed_url)}", status_code=303)
@app.get("/custom", response_class=HTMLResponse)
async def get_custom(feed_url: str, request: Request): # noqa: ANN201
async def get_custom(feed_url: str, request: Request):
"""Get the custom message. This is used when sending the message to Discord.
Args:
@ -261,11 +301,11 @@ async def get_custom(feed_url: str, request: Request): # noqa: ANN201
for entry in reader.get_entries(feed=feed, limit=1):
context["entry"] = entry
return templates.TemplateResponse("custom.html", context)
return templates.TemplateResponse(request=request, name="custom.html", context=context)
@app.get("/embed", response_class=HTMLResponse)
async def get_embed_page(feed_url: str, request: Request): # noqa: ANN201
async def get_embed_page(feed_url: str, request: Request):
"""Get the custom message. This is used when sending the message to Discord.
Args:
@ -297,11 +337,11 @@ async def get_embed_page(feed_url: str, request: Request): # noqa: ANN201
for entry in reader.get_entries(feed=feed, limit=1):
# Append to context.
context["entry"] = entry
return templates.TemplateResponse("embed.html", context)
return templates.TemplateResponse(request=request, name="embed.html", context=context)
@app.post("/embed", response_class=HTMLResponse)
async def post_embed( # noqa: PLR0913
async def post_embed( # noqa: PLR0913, PLR0917
feed_url: str = Form(),
title: str = Form(""),
description: str = Form(""),
@ -385,22 +425,23 @@ async def post_use_text(feed_url: str = Form()) -> RedirectResponse:
@app.get("/add", response_class=HTMLResponse)
def get_add(request: Request): # noqa: ANN201
def get_add(request: Request):
"""Page for adding a new feed."""
context = {
"request": request,
"webhooks": reader.get_tag((), "webhooks", []),
}
return templates.TemplateResponse("add.html", context)
return templates.TemplateResponse(request=request, name="add.html", context=context)
@app.get("/feed", response_class=HTMLResponse)
async def get_feed(feed_url: str, request: Request): # noqa: ANN201
async def get_feed(feed_url: str, request: Request, starting_after: str | None = None):
"""Get a feed by URL.
Args:
feed_url: The feed to add.
request: The request object.
starting_after: The entry to start after. Used for pagination.
Returns:
HTMLResponse: The feed page.
@ -410,7 +451,7 @@ async def get_feed(feed_url: str, request: Request): # noqa: ANN201
feed: Feed = reader.get_feed(clean_feed_url)
# Get entries from the feed.
entries: Iterable[Entry] = reader.get_entries(feed=clean_feed_url)
entries: typing.Iterable[Entry] = reader.get_entries(feed=clean_feed_url, limit=10)
# Create the html for the entries.
html: str = create_html_for_feed(entries)
@ -428,8 +469,49 @@ async def get_feed(feed_url: str, request: Request): # noqa: ANN201
"feed_counts": reader.get_feed_counts(feed=clean_feed_url),
"html": html,
"should_send_embed": should_send_embed,
"show_more_button": True,
}
return templates.TemplateResponse("feed.html", context)
return templates.TemplateResponse(request=request, name="feed.html", context=context)
@app.get("/feed_more", response_class=HTMLResponse)
async def get_all_entries(feed_url: str, request: Request):
"""Get a feed by URL and show more entries.
Args:
feed_url: The feed to add.
request: The request object.
starting_after: The entry to start after. Used for pagination.
Returns:
HTMLResponse: The feed page.
"""
clean_feed_url: str = urllib.parse.unquote(feed_url.strip())
feed: Feed = reader.get_feed(clean_feed_url)
# Get entries from the feed.
entries: typing.Iterable[Entry] = reader.get_entries(feed=clean_feed_url, limit=200)
# Create the html for the entries.
html: str = create_html_for_feed(entries)
try:
should_send_embed: bool = bool(reader.get_tag(feed, "should_send_embed"))
except TagNotFoundError:
add_missing_tags(reader)
should_send_embed: bool = bool(reader.get_tag(feed, "should_send_embed"))
context = {
"request": request,
"feed": feed,
"entries": entries,
"feed_counts": reader.get_feed_counts(feed=clean_feed_url),
"html": html,
"should_send_embed": should_send_embed,
"show_more_button": False,
}
return templates.TemplateResponse(request=request, name="feed.html", context=context)
def create_html_for_feed(entries: Iterable[Entry]) -> str:
@ -468,7 +550,7 @@ def create_html_for_feed(entries: Iterable[Entry]) -> str:
html += f"""<div class="p-2 mb-2 border border-dark">
{blacklisted}{whitelisted}<a class="text-muted text-decoration-none" href="{entry.link}"><h2>{entry.title}</h2></a>
{f"By { entry.author } @" if entry.author else ""}{published} - {to_discord_html}
{f"By {entry.author} @" if entry.author else ""}{published} - {to_discord_html}
{text}
{image_html}
@ -478,7 +560,7 @@ def create_html_for_feed(entries: Iterable[Entry]) -> str:
@app.get("/add_webhook", response_class=HTMLResponse)
async def get_add_webhook(request: Request): # noqa: ANN201
async def get_add_webhook(request: Request):
"""Page for adding a new webhook.
Args:
@ -487,7 +569,7 @@ async def get_add_webhook(request: Request): # noqa: ANN201
Returns:
HTMLResponse: The add webhook page.
"""
return templates.TemplateResponse("add_webhook.html", {"request": request})
return templates.TemplateResponse(request=request, name="add_webhook.html", context={"request": request})
@dataclass()
@ -533,7 +615,7 @@ def get_data_from_hook_url(hook_name: str, hook_url: str) -> WebhookInfo:
@app.get("/webhooks", response_class=HTMLResponse)
async def get_webhooks(request: Request): # noqa: ANN201
async def get_webhooks(request: Request):
"""Page for adding a new webhook.
Args:
@ -549,11 +631,11 @@ async def get_webhooks(request: Request): # noqa: ANN201
hooks_with_data.append(our_hook)
context = {"request": request, "hooks_with_data": hooks_with_data}
return templates.TemplateResponse("webhooks.html", context)
return templates.TemplateResponse(request=request, name="webhooks.html", context=context)
@app.get("/", response_class=HTMLResponse)
def get_index(request: Request): # noqa: ANN201
def get_index(request: Request):
"""This is the root of the website.
Args:
@ -562,10 +644,10 @@ def get_index(request: Request): # noqa: ANN201
Returns:
HTMLResponse: The index page.
"""
return templates.TemplateResponse("index.html", make_context_index(request))
return templates.TemplateResponse(request=request, name="index.html", context=make_context_index(request))
def make_context_index(request: Request): # noqa: ANN201
def make_context_index(request: Request):
"""Create the needed context for the index page.
Args:
@ -605,7 +687,7 @@ def make_context_index(request: Request): # noqa: ANN201
@app.post("/remove", response_class=HTMLResponse)
async def remove_feed(feed_url: str = Form()): # noqa: ANN201
async def remove_feed(feed_url: str = Form()):
"""Get a feed by URL.
Args:
@ -623,7 +705,7 @@ async def remove_feed(feed_url: str = Form()): # noqa: ANN201
@app.get("/search", response_class=HTMLResponse)
async def search(request: Request, query: str): # noqa: ANN201
async def search(request: Request, query: str):
"""Get entries matching a full-text search query.
Args:
@ -641,11 +723,11 @@ async def search(request: Request, query: str): # noqa: ANN201
"query": query,
"search_amount": reader.search_entry_counts(query),
}
return templates.TemplateResponse("search.html", context)
return templates.TemplateResponse(request=request, name="search.html", context=context)
@app.get("/post_entry", response_class=HTMLResponse)
async def post_entry(entry_id: str): # noqa: ANN201
async def post_entry(entry_id: str):
"""Send single entry to Discord.
Args:
@ -668,7 +750,7 @@ async def post_entry(entry_id: str): # noqa: ANN201
@app.post("/modify_webhook", response_class=HTMLResponse)
def modify_webhook(old_hook: str = Form(), new_hook: str = Form()): # noqa: ANN201
def modify_webhook(old_hook: str = Form(), new_hook: str = Form()):
"""Modify a webhook.
Args:
@ -682,7 +764,7 @@ def modify_webhook(old_hook: str = Form(), new_hook: str = Form()): # noqa: ANN
webhooks = list(reader.get_tag((), "webhooks", []))
# Webhooks are stored as a list of dictionaries.
# Example: [{"name": "webhook_name", "url": "webhook_url"}] # noqa: ERA001
# Example: [{"name": "webhook_name", "url": "webhook_url"}]
webhooks = cast(list[dict[str, str]], webhooks)
for hook in webhooks:
@ -712,24 +794,8 @@ def modify_webhook(old_hook: str = Form(), new_hook: str = Form()): # noqa: ANN
return RedirectResponse(url="/webhooks", status_code=303)
@app.on_event("startup")
def startup() -> None:
"""This is called when the server starts.
It adds missing tags and starts the scheduler.
"""
add_missing_tags(reader=reader)
scheduler: BackgroundScheduler = BackgroundScheduler()
# Update all feeds every 15 minutes.
# TODO: Make this configurable.
scheduler.add_job(send_to_discord, "interval", minutes=15, next_run_time=datetime.now(tz=timezone.utc))
scheduler.start()
if __name__ == "__main__":
# TODO: Make this configurable.
# TODO(TheLovinator): Make this configurable.
uvicorn.run(
"main:app",
log_level="info",