Fix all the bugs
This commit is contained in:
@ -1,20 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import typing
|
||||
import urllib.parse
|
||||
from collections.abc import Iterable
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from functools import lru_cache
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import httpx
|
||||
import uvicorn
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from fastapi import FastAPI, Form, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from httpx import Response
|
||||
from reader import Entry, Feed, FeedNotFoundError, Reader, TagNotFoundError
|
||||
from reader.types import JSONType
|
||||
from starlette.responses import RedirectResponse
|
||||
|
||||
from discord_rss_bot import settings
|
||||
@ -38,11 +42,32 @@ from discord_rss_bot.search import create_html_for_search_results
|
||||
from discord_rss_bot.settings import get_reader
|
||||
from discord_rss_bot.webhook import add_webhook, remove_webhook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
|
||||
reader: Reader = get_reader()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> typing.AsyncGenerator[None, None]:
|
||||
"""This is needed for the ASGI server to run."""
|
||||
add_missing_tags(reader=reader)
|
||||
scheduler: AsyncIOScheduler = AsyncIOScheduler()
|
||||
|
||||
# Update all feeds every 15 minutes.
|
||||
# TODO(TheLovinator): Make this configurable.
|
||||
scheduler.add_job(send_to_discord, "interval", minutes=15, next_run_time=datetime.now(tz=timezone.utc))
|
||||
scheduler.start()
|
||||
yield
|
||||
reader.close()
|
||||
scheduler.shutdown(wait=True)
|
||||
|
||||
|
||||
app: FastAPI = FastAPI()
|
||||
app.mount("/static", StaticFiles(directory="discord_rss_bot/static"), name="static")
|
||||
templates: Jinja2Templates = Jinja2Templates(directory="discord_rss_bot/templates")
|
||||
|
||||
reader: Reader = get_reader()
|
||||
|
||||
# Add the filters to the Jinja2 environment so they can be used in html templates.
|
||||
templates.env.filters["encode_url"] = encode_url
|
||||
@ -70,7 +95,7 @@ async def post_delete_webhook(webhook_url: str = Form()) -> RedirectResponse:
|
||||
Args:
|
||||
webhook_url: The url of the webhook.
|
||||
"""
|
||||
# TODO: Check if the webhook is in use by any feeds before deleting it.
|
||||
# TODO(TheLovinator): Check if the webhook is in use by any feeds before deleting it.
|
||||
remove_webhook(reader, webhook_url)
|
||||
return RedirectResponse(url="/", status_code=303)
|
||||
|
||||
@ -131,19 +156,19 @@ async def post_set_whitelist(
|
||||
"""
|
||||
clean_feed_url: str = feed_url.strip()
|
||||
if whitelist_title:
|
||||
reader.set_tag(clean_feed_url, "whitelist_title", whitelist_title) # type: ignore
|
||||
reader.set_tag(clean_feed_url, "whitelist_title", whitelist_title) # type: ignore[call-overload]
|
||||
if whitelist_summary:
|
||||
reader.set_tag(clean_feed_url, "whitelist_summary", whitelist_summary) # type: ignore
|
||||
reader.set_tag(clean_feed_url, "whitelist_summary", whitelist_summary) # type: ignore[call-overload]
|
||||
if whitelist_content:
|
||||
reader.set_tag(clean_feed_url, "whitelist_content", whitelist_content) # type: ignore
|
||||
reader.set_tag(clean_feed_url, "whitelist_content", whitelist_content) # type: ignore[call-overload]
|
||||
if whitelist_author:
|
||||
reader.set_tag(clean_feed_url, "whitelist_author", whitelist_author) # type: ignore
|
||||
reader.set_tag(clean_feed_url, "whitelist_author", whitelist_author) # type: ignore[call-overload]
|
||||
|
||||
return RedirectResponse(url=f"/feed/?feed_url={urllib.parse.quote(clean_feed_url)}", status_code=303)
|
||||
|
||||
|
||||
@app.get("/whitelist", response_class=HTMLResponse)
|
||||
async def get_whitelist(feed_url: str, request: Request): # noqa: ANN201
|
||||
async def get_whitelist(feed_url: str, request: Request):
|
||||
"""Get the whitelist.
|
||||
|
||||
Args:
|
||||
@ -167,7 +192,7 @@ async def get_whitelist(feed_url: str, request: Request): # noqa: ANN201
|
||||
"whitelist_content": whitelist_content,
|
||||
"whitelist_author": whitelist_author,
|
||||
}
|
||||
return templates.TemplateResponse("whitelist.html", context)
|
||||
return templates.TemplateResponse(request=request, name="whitelist.html", context=context)
|
||||
|
||||
|
||||
@app.post("/blacklist")
|
||||
@ -192,19 +217,28 @@ async def post_set_blacklist(
|
||||
"""
|
||||
clean_feed_url: str = feed_url.strip()
|
||||
if blacklist_title:
|
||||
reader.set_tag(clean_feed_url, "blacklist_title", blacklist_title) # type: ignore
|
||||
reader.set_tag(clean_feed_url, "blacklist_title", blacklist_title) # type: ignore[call-overload]
|
||||
if blacklist_summary:
|
||||
reader.set_tag(clean_feed_url, "blacklist_summary", blacklist_summary) # type: ignore
|
||||
reader.set_tag(clean_feed_url, "blacklist_summary", blacklist_summary) # type: ignore[call-overload]
|
||||
if blacklist_content:
|
||||
reader.set_tag(clean_feed_url, "blacklist_content", blacklist_content) # type: ignore
|
||||
reader.set_tag(clean_feed_url, "blacklist_content", blacklist_content) # type: ignore[call-overload]
|
||||
if blacklist_author:
|
||||
reader.set_tag(clean_feed_url, "blacklist_author", blacklist_author) # type: ignore
|
||||
reader.set_tag(clean_feed_url, "blacklist_author", blacklist_author) # type: ignore[call-overload]
|
||||
|
||||
return RedirectResponse(url=f"/feed/?feed_url={urllib.parse.quote(clean_feed_url)}", status_code=303)
|
||||
|
||||
|
||||
@app.get("/blacklist", response_class=HTMLResponse)
|
||||
async def get_blacklist(feed_url: str, request: Request): # noqa: ANN201
|
||||
async def get_blacklist(feed_url: str, request: Request):
|
||||
"""Get the blacklist.
|
||||
|
||||
Args:
|
||||
feed_url: What feed we should get the blacklist for.
|
||||
request: The request object.
|
||||
|
||||
Returns:
|
||||
HTMLResponse: The blacklist page.
|
||||
"""
|
||||
feed: Feed = reader.get_feed(urllib.parse.unquote(feed_url))
|
||||
|
||||
# Get previous data, this is used when creating the form.
|
||||
@ -221,7 +255,7 @@ async def get_blacklist(feed_url: str, request: Request): # noqa: ANN201
|
||||
"blacklist_content": blacklist_content,
|
||||
"blacklist_author": blacklist_author,
|
||||
}
|
||||
return templates.TemplateResponse("blacklist.html", context)
|
||||
return templates.TemplateResponse(request=request, name="blacklist.html", context=context)
|
||||
|
||||
|
||||
@app.post("/custom")
|
||||
@ -232,17 +266,23 @@ async def post_set_custom(custom_message: str = Form(""), feed_url: str = Form()
|
||||
custom_message: The custom message.
|
||||
feed_url: The feed we should set the custom message for.
|
||||
"""
|
||||
if custom_message:
|
||||
reader.set_tag(feed_url, "custom_message", custom_message.strip()) # type: ignore
|
||||
our_custom_message: JSONType | str = custom_message.strip()
|
||||
our_custom_message = typing.cast(JSONType, our_custom_message)
|
||||
|
||||
default_custom_message: JSONType | str = settings.default_custom_message
|
||||
default_custom_message = typing.cast(JSONType, default_custom_message)
|
||||
|
||||
if our_custom_message:
|
||||
reader.set_tag(feed_url, "custom_message", our_custom_message)
|
||||
else:
|
||||
reader.set_tag(feed_url, "custom_message", settings.default_custom_message) # type: ignore
|
||||
reader.set_tag(feed_url, "custom_message", default_custom_message)
|
||||
|
||||
clean_feed_url: str = feed_url.strip()
|
||||
return RedirectResponse(url=f"/feed/?feed_url={urllib.parse.quote(clean_feed_url)}", status_code=303)
|
||||
|
||||
|
||||
@app.get("/custom", response_class=HTMLResponse)
|
||||
async def get_custom(feed_url: str, request: Request): # noqa: ANN201
|
||||
async def get_custom(feed_url: str, request: Request):
|
||||
"""Get the custom message. This is used when sending the message to Discord.
|
||||
|
||||
Args:
|
||||
@ -261,11 +301,11 @@ async def get_custom(feed_url: str, request: Request): # noqa: ANN201
|
||||
for entry in reader.get_entries(feed=feed, limit=1):
|
||||
context["entry"] = entry
|
||||
|
||||
return templates.TemplateResponse("custom.html", context)
|
||||
return templates.TemplateResponse(request=request, name="custom.html", context=context)
|
||||
|
||||
|
||||
@app.get("/embed", response_class=HTMLResponse)
|
||||
async def get_embed_page(feed_url: str, request: Request): # noqa: ANN201
|
||||
async def get_embed_page(feed_url: str, request: Request):
|
||||
"""Get the custom message. This is used when sending the message to Discord.
|
||||
|
||||
Args:
|
||||
@ -297,11 +337,11 @@ async def get_embed_page(feed_url: str, request: Request): # noqa: ANN201
|
||||
for entry in reader.get_entries(feed=feed, limit=1):
|
||||
# Append to context.
|
||||
context["entry"] = entry
|
||||
return templates.TemplateResponse("embed.html", context)
|
||||
return templates.TemplateResponse(request=request, name="embed.html", context=context)
|
||||
|
||||
|
||||
@app.post("/embed", response_class=HTMLResponse)
|
||||
async def post_embed( # noqa: PLR0913
|
||||
async def post_embed( # noqa: PLR0913, PLR0917
|
||||
feed_url: str = Form(),
|
||||
title: str = Form(""),
|
||||
description: str = Form(""),
|
||||
@ -385,22 +425,23 @@ async def post_use_text(feed_url: str = Form()) -> RedirectResponse:
|
||||
|
||||
|
||||
@app.get("/add", response_class=HTMLResponse)
|
||||
def get_add(request: Request): # noqa: ANN201
|
||||
def get_add(request: Request):
|
||||
"""Page for adding a new feed."""
|
||||
context = {
|
||||
"request": request,
|
||||
"webhooks": reader.get_tag((), "webhooks", []),
|
||||
}
|
||||
return templates.TemplateResponse("add.html", context)
|
||||
return templates.TemplateResponse(request=request, name="add.html", context=context)
|
||||
|
||||
|
||||
@app.get("/feed", response_class=HTMLResponse)
|
||||
async def get_feed(feed_url: str, request: Request): # noqa: ANN201
|
||||
async def get_feed(feed_url: str, request: Request, starting_after: str | None = None):
|
||||
"""Get a feed by URL.
|
||||
|
||||
Args:
|
||||
feed_url: The feed to add.
|
||||
request: The request object.
|
||||
starting_after: The entry to start after. Used for pagination.
|
||||
|
||||
Returns:
|
||||
HTMLResponse: The feed page.
|
||||
@ -410,7 +451,7 @@ async def get_feed(feed_url: str, request: Request): # noqa: ANN201
|
||||
feed: Feed = reader.get_feed(clean_feed_url)
|
||||
|
||||
# Get entries from the feed.
|
||||
entries: Iterable[Entry] = reader.get_entries(feed=clean_feed_url)
|
||||
entries: typing.Iterable[Entry] = reader.get_entries(feed=clean_feed_url, limit=10)
|
||||
|
||||
# Create the html for the entries.
|
||||
html: str = create_html_for_feed(entries)
|
||||
@ -428,8 +469,49 @@ async def get_feed(feed_url: str, request: Request): # noqa: ANN201
|
||||
"feed_counts": reader.get_feed_counts(feed=clean_feed_url),
|
||||
"html": html,
|
||||
"should_send_embed": should_send_embed,
|
||||
"show_more_button": True,
|
||||
}
|
||||
return templates.TemplateResponse("feed.html", context)
|
||||
return templates.TemplateResponse(request=request, name="feed.html", context=context)
|
||||
|
||||
|
||||
@app.get("/feed_more", response_class=HTMLResponse)
|
||||
async def get_all_entries(feed_url: str, request: Request):
|
||||
"""Get a feed by URL and show more entries.
|
||||
|
||||
Args:
|
||||
feed_url: The feed to add.
|
||||
request: The request object.
|
||||
starting_after: The entry to start after. Used for pagination.
|
||||
|
||||
Returns:
|
||||
HTMLResponse: The feed page.
|
||||
"""
|
||||
clean_feed_url: str = urllib.parse.unquote(feed_url.strip())
|
||||
|
||||
feed: Feed = reader.get_feed(clean_feed_url)
|
||||
|
||||
# Get entries from the feed.
|
||||
entries: typing.Iterable[Entry] = reader.get_entries(feed=clean_feed_url, limit=200)
|
||||
|
||||
# Create the html for the entries.
|
||||
html: str = create_html_for_feed(entries)
|
||||
|
||||
try:
|
||||
should_send_embed: bool = bool(reader.get_tag(feed, "should_send_embed"))
|
||||
except TagNotFoundError:
|
||||
add_missing_tags(reader)
|
||||
should_send_embed: bool = bool(reader.get_tag(feed, "should_send_embed"))
|
||||
|
||||
context = {
|
||||
"request": request,
|
||||
"feed": feed,
|
||||
"entries": entries,
|
||||
"feed_counts": reader.get_feed_counts(feed=clean_feed_url),
|
||||
"html": html,
|
||||
"should_send_embed": should_send_embed,
|
||||
"show_more_button": False,
|
||||
}
|
||||
return templates.TemplateResponse(request=request, name="feed.html", context=context)
|
||||
|
||||
|
||||
def create_html_for_feed(entries: Iterable[Entry]) -> str:
|
||||
@ -468,7 +550,7 @@ def create_html_for_feed(entries: Iterable[Entry]) -> str:
|
||||
|
||||
html += f"""<div class="p-2 mb-2 border border-dark">
|
||||
{blacklisted}{whitelisted}<a class="text-muted text-decoration-none" href="{entry.link}"><h2>{entry.title}</h2></a>
|
||||
{f"By { entry.author } @" if entry.author else ""}{published} - {to_discord_html}
|
||||
{f"By {entry.author} @" if entry.author else ""}{published} - {to_discord_html}
|
||||
|
||||
{text}
|
||||
{image_html}
|
||||
@ -478,7 +560,7 @@ def create_html_for_feed(entries: Iterable[Entry]) -> str:
|
||||
|
||||
|
||||
@app.get("/add_webhook", response_class=HTMLResponse)
|
||||
async def get_add_webhook(request: Request): # noqa: ANN201
|
||||
async def get_add_webhook(request: Request):
|
||||
"""Page for adding a new webhook.
|
||||
|
||||
Args:
|
||||
@ -487,7 +569,7 @@ async def get_add_webhook(request: Request): # noqa: ANN201
|
||||
Returns:
|
||||
HTMLResponse: The add webhook page.
|
||||
"""
|
||||
return templates.TemplateResponse("add_webhook.html", {"request": request})
|
||||
return templates.TemplateResponse(request=request, name="add_webhook.html", context={"request": request})
|
||||
|
||||
|
||||
@dataclass()
|
||||
@ -533,7 +615,7 @@ def get_data_from_hook_url(hook_name: str, hook_url: str) -> WebhookInfo:
|
||||
|
||||
|
||||
@app.get("/webhooks", response_class=HTMLResponse)
|
||||
async def get_webhooks(request: Request): # noqa: ANN201
|
||||
async def get_webhooks(request: Request):
|
||||
"""Page for adding a new webhook.
|
||||
|
||||
Args:
|
||||
@ -549,11 +631,11 @@ async def get_webhooks(request: Request): # noqa: ANN201
|
||||
hooks_with_data.append(our_hook)
|
||||
|
||||
context = {"request": request, "hooks_with_data": hooks_with_data}
|
||||
return templates.TemplateResponse("webhooks.html", context)
|
||||
return templates.TemplateResponse(request=request, name="webhooks.html", context=context)
|
||||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
def get_index(request: Request): # noqa: ANN201
|
||||
def get_index(request: Request):
|
||||
"""This is the root of the website.
|
||||
|
||||
Args:
|
||||
@ -562,10 +644,10 @@ def get_index(request: Request): # noqa: ANN201
|
||||
Returns:
|
||||
HTMLResponse: The index page.
|
||||
"""
|
||||
return templates.TemplateResponse("index.html", make_context_index(request))
|
||||
return templates.TemplateResponse(request=request, name="index.html", context=make_context_index(request))
|
||||
|
||||
|
||||
def make_context_index(request: Request): # noqa: ANN201
|
||||
def make_context_index(request: Request):
|
||||
"""Create the needed context for the index page.
|
||||
|
||||
Args:
|
||||
@ -605,7 +687,7 @@ def make_context_index(request: Request): # noqa: ANN201
|
||||
|
||||
|
||||
@app.post("/remove", response_class=HTMLResponse)
|
||||
async def remove_feed(feed_url: str = Form()): # noqa: ANN201
|
||||
async def remove_feed(feed_url: str = Form()):
|
||||
"""Get a feed by URL.
|
||||
|
||||
Args:
|
||||
@ -623,7 +705,7 @@ async def remove_feed(feed_url: str = Form()): # noqa: ANN201
|
||||
|
||||
|
||||
@app.get("/search", response_class=HTMLResponse)
|
||||
async def search(request: Request, query: str): # noqa: ANN201
|
||||
async def search(request: Request, query: str):
|
||||
"""Get entries matching a full-text search query.
|
||||
|
||||
Args:
|
||||
@ -641,11 +723,11 @@ async def search(request: Request, query: str): # noqa: ANN201
|
||||
"query": query,
|
||||
"search_amount": reader.search_entry_counts(query),
|
||||
}
|
||||
return templates.TemplateResponse("search.html", context)
|
||||
return templates.TemplateResponse(request=request, name="search.html", context=context)
|
||||
|
||||
|
||||
@app.get("/post_entry", response_class=HTMLResponse)
|
||||
async def post_entry(entry_id: str): # noqa: ANN201
|
||||
async def post_entry(entry_id: str):
|
||||
"""Send single entry to Discord.
|
||||
|
||||
Args:
|
||||
@ -668,7 +750,7 @@ async def post_entry(entry_id: str): # noqa: ANN201
|
||||
|
||||
|
||||
@app.post("/modify_webhook", response_class=HTMLResponse)
|
||||
def modify_webhook(old_hook: str = Form(), new_hook: str = Form()): # noqa: ANN201
|
||||
def modify_webhook(old_hook: str = Form(), new_hook: str = Form()):
|
||||
"""Modify a webhook.
|
||||
|
||||
Args:
|
||||
@ -682,7 +764,7 @@ def modify_webhook(old_hook: str = Form(), new_hook: str = Form()): # noqa: ANN
|
||||
webhooks = list(reader.get_tag((), "webhooks", []))
|
||||
|
||||
# Webhooks are stored as a list of dictionaries.
|
||||
# Example: [{"name": "webhook_name", "url": "webhook_url"}] # noqa: ERA001
|
||||
# Example: [{"name": "webhook_name", "url": "webhook_url"}]
|
||||
webhooks = cast(list[dict[str, str]], webhooks)
|
||||
|
||||
for hook in webhooks:
|
||||
@ -712,24 +794,8 @@ def modify_webhook(old_hook: str = Form(), new_hook: str = Form()): # noqa: ANN
|
||||
return RedirectResponse(url="/webhooks", status_code=303)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
def startup() -> None:
|
||||
"""This is called when the server starts.
|
||||
|
||||
It adds missing tags and starts the scheduler.
|
||||
"""
|
||||
add_missing_tags(reader=reader)
|
||||
|
||||
scheduler: BackgroundScheduler = BackgroundScheduler()
|
||||
|
||||
# Update all feeds every 15 minutes.
|
||||
# TODO: Make this configurable.
|
||||
scheduler.add_job(send_to_discord, "interval", minutes=15, next_run_time=datetime.now(tz=timezone.utc))
|
||||
scheduler.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# TODO: Make this configurable.
|
||||
# TODO(TheLovinator): Make this configurable.
|
||||
uvicorn.run(
|
||||
"main:app",
|
||||
log_level="info",
|
||||
|
Reference in New Issue
Block a user