def extract_domain(url: str) -> str:  # noqa: PLR0911
    """Extract a short, display-friendly site name from a feed URL.

    The label is derived from the URL's host rather than from substring
    matches on the whole URL, so text in the path or query cannot change
    how a feed is classified, and neither a port nor ``user:pass@``
    credentials can end up in the returned label.

    Args:
        url: The URL to extract the domain from.

    Returns:
        str: ``"YouTube"`` for YouTube video feeds, ``"Reddit"`` for Reddit
        hosts and Reddit-style ``/r/<name>/....rss`` paths, ``"GitHub"`` for
        github.com, otherwise the capitalized first label of the host with a
        leading ``www.`` removed. ``"Other"`` is returned when the URL is
        empty, has no host, or cannot be parsed.
    """
    # Check for empty URL first
    if not url:
        return "Other"

    try:
        # Special handling for YouTube feeds
        if "youtube.com/feeds/videos.xml" in url:
            return "YouTube"

        parsed_url: ParseResult = urlparse(url)

        # ``hostname`` (unlike ``netloc``) is lower-cased and excludes any
        # ``user:pass@`` prefix and ``:port`` suffix.
        host: str | None = parsed_url.hostname

        # If we couldn't extract a host (e.g. "not-a-valid-url"), return "Other"
        if not host:
            return "Other"

        # Reddit itself, including subdomains such as old.reddit.com.
        # Matching on the host avoids false positives like "notreddit.com".
        if host == "reddit.com" or host.endswith(".reddit.com"):
            return "Reddit"

        # Reddit-style feed path on another host: /r/<name>/....rss
        # Anchoring both ends keeps paths like "/user/feed.rss" from matching.
        path: str = parsed_url.path
        if path.startswith("/r/") and path.endswith(".rss"):
            return "Reddit"

        # Remove www. prefix if present
        domain: str = re.sub(r"^www\.", "", host)
        if not domain:
            return "Other"

        # Special handling for common domains whose brand casing differs from
        # what str.capitalize() would produce.
        domain_mapping: dict[str, str] = {"github.com": "GitHub"}
        if domain in domain_mapping:
            return domain_mapping[domain]

        # For other domains, capitalize the first label of the host. A host
        # without any dot (e.g. "localhost") is its own first label.
        return domain.split(".")[0].capitalize()
    except (ValueError, AttributeError, TypeError) as e:
        # urlparse raises ValueError on malformed hosts (e.g. a bad IPv6
        # literal); non-str input surfaces as AttributeError/TypeError.
        logger.warning("Error extracting domain from %s: %s", url, e)
        return "Other"
diff --git a/discord_rss_bot/main.py b/discord_rss_bot/main.py index 00349ac..7ae706f 100644 --- a/discord_rss_bot/main.py +++ b/discord_rss_bot/main.py @@ -37,7 +37,7 @@ from discord_rss_bot.custom_message import ( replace_tags_in_text_message, save_embed, ) -from discord_rss_bot.feeds import create_feed, send_entry_to_discord, send_to_discord +from discord_rss_bot.feeds import create_feed, extract_domain, send_entry_to_discord, send_to_discord from discord_rss_bot.missing_tags import add_missing_tags from discord_rss_bot.search import create_html_for_search_results from discord_rss_bot.settings import get_reader @@ -875,11 +875,12 @@ def make_context_index(request: Request): broken_feeds = [] feeds_without_attached_webhook = [] + # Get all feeds and organize them feeds: Iterable[Feed] = reader.get_feeds() for feed in feeds: try: webhook = reader.get_tag(feed.url, "webhook") - feed_list.append({"feed": feed, "webhook": webhook}) + feed_list.append({"feed": feed, "webhook": webhook, "domain": extract_domain(feed.url)}) except TagNotFoundError: broken_feeds.append(feed) continue diff --git a/discord_rss_bot/templates/index.html b/discord_rss_bot/templates/index.html index 3db4a50..f9dfc0d 100644 --- a/discord_rss_bot/templates/index.html +++ b/discord_rss_bot/templates/index.html @@ -28,45 +28,66 @@ {{ entry_count.averages[2]|round(1) }})

- + + {% for hook_from_context in webhooks %} -
-

+
+

{{ hook_from_context.name }}

- +
+ +
+
+ {% endfor %} + {% else %} +

No feeds associated with this webhook.

+ {% endif %}

{% endfor %} {% else %}

Hello there!
+
You need to add a webhook here to get started. After that, you can add feeds here. You can find both of these links in the navigation bar above. @@ -79,6 +100,7 @@ Thanks!

{% endif %} + {% if broken_feeds %}
@@ -103,6 +125,7 @@
{% endif %} + {% if feeds_without_attached_webhook %}
def test_extract_domain_youtube_feed() -> None:
    """Test extract_domain for YouTube feeds."""
    url: str = "https://www.youtube.com/feeds/videos.xml?channel_id=123456"
    assert extract_domain(url) == "YouTube", "YouTube feeds should return 'YouTube' as the domain."


def test_extract_domain_reddit_feed() -> None:
    """Test extract_domain for Reddit feeds."""
    url: str = "https://www.reddit.com/r/Python/.rss"
    assert extract_domain(url) == "Reddit", "Reddit feeds should return 'Reddit' as the domain."


def test_extract_domain_github_feed() -> None:
    """Test extract_domain for GitHub feeds."""
    url: str = "https://www.github.com/user/repo"
    assert extract_domain(url) == "GitHub", "GitHub feeds should return 'GitHub' as the domain."


def test_extract_domain_custom_domain() -> None:
    """Test extract_domain for custom domains."""
    url: str = "https://www.example.com/feed"
    assert extract_domain(url) == "Example", "Custom domains should return the capitalized first part of the domain."


def test_extract_domain_no_www_prefix() -> None:
    """Test extract_domain gives the same label with and without the 'www.' prefix."""
    # Checking both forms keeps this test distinct from the custom-domain test
    # above, which already covers the "www." URL on its own.
    with_www: str = "https://www.example.com/feed"
    without_www: str = "https://example.com/feed"
    assert extract_domain(with_www) == "Example", "The 'www.' prefix should be removed from the domain."
    assert extract_domain(without_www) == "Example", "A domain without 'www.' should yield the same label."


def test_extract_domain_no_tld() -> None:
    """Test extract_domain for domains without a TLD."""
    url: str = "https://localhost/feed"
    assert extract_domain(url) == "Localhost", "Domains without a TLD should return the capitalized domain."


def test_extract_domain_invalid_url() -> None:
    """Test extract_domain for invalid URLs."""
    url: str = "not-a-valid-url"
    assert extract_domain(url) == "Other", "Invalid URLs should return 'Other' as the domain."


def test_extract_domain_empty_url() -> None:
    """Test extract_domain for empty URLs."""
    url: str = ""
    assert extract_domain(url) == "Other", "Empty URLs should return 'Other' as the domain."


def test_extract_domain_special_characters() -> None:
    """Test extract_domain for URLs with special characters."""
    url: str = "https://www.ex-ample.com/feed"
    assert extract_domain(url) == "Ex-ample", "Domains with special characters should return the capitalized domain."
response = client.get("/") assert response.status_code == 200, f"Failed to get /: {response.text}" - assert feed_url in response.text, f"Feed not found in /: {response.text}" + assert encoded_feed_url(feed_url) in response.text, f"Feed not found in /: {response.text}" response: Response = client.get(url="/add") assert response.status_code == 200, f"/add failed: {response.text}" @@ -157,7 +157,7 @@ def test_pause_feed() -> None: # Check that the feed was paused. response = client.get(url="/") assert response.status_code == 200, f"Failed to get /: {response.text}" - assert feed_url in response.text, f"Feed not found in /: {response.text}" + assert encoded_feed_url(feed_url) in response.text, f"Feed not found in /: {response.text}" def test_unpause_feed() -> None: @@ -184,7 +184,7 @@ def test_unpause_feed() -> None: # Check that the feed was unpaused. response = client.get(url="/") assert response.status_code == 200, f"Failed to get /: {response.text}" - assert feed_url in response.text, f"Feed not found in /: {response.text}" + assert encoded_feed_url(feed_url) in response.text, f"Feed not found in /: {response.text}" def test_remove_feed() -> None: