From cd0f63d59a99224a915c23112b7bcf777011cfb7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Joakim=20Hells=C3=A9n?=
Date: Wed, 16 Apr 2025 13:32:31 +0200
Subject: [PATCH] Add tldextract for improved domain extraction and add new tests for extract_domain function

---
 discord_rss_bot/feeds.py | 11 +++++------
 pyproject.toml           |  1 +
 tests/test_feeds.py      | 19 +++++++++++++++++++
 3 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/discord_rss_bot/feeds.py b/discord_rss_bot/feeds.py
index 83ac2fd..7852b0d 100644
--- a/discord_rss_bot/feeds.py
+++ b/discord_rss_bot/feeds.py
@@ -7,6 +7,7 @@ import re
 from typing import TYPE_CHECKING
 from urllib.parse import ParseResult, urlparse
 
+import tldextract
 from discord_webhook import DiscordEmbed, DiscordWebhook
 from fastapi import HTTPException
 from reader import Entry, EntryNotFoundError, Feed, FeedExistsError, Reader, ReaderError, StorageError, TagNotFoundError
@@ -70,12 +71,10 @@ def extract_domain(url: str) -> str:  # noqa: PLR0911
         if domain in domain_mapping:
             return domain_mapping[domain]
 
-        # For other domains, capitalize the first part before the TLD
-        parts: list[str] = domain.split(".")
-        min_domain_parts = 2
-        if len(parts) >= min_domain_parts:
-            return parts[0].capitalize()
-
+        # Use tldextract to get the domain (SLD)
+        ext = tldextract.extract(url)
+        if ext.domain:
+            return ext.domain.capitalize()
         return domain.capitalize()
     except (ValueError, AttributeError, TypeError) as e:
         logger.warning("Error extracting domain from %s: %s", url, e)
diff --git a/pyproject.toml b/pyproject.toml
index 21ab35a..f5758e7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,6 +17,7 @@ dependencies = [
     "python-multipart",
     "reader",
     "sentry-sdk[fastapi]",
+    "tldextract",
     "uvicorn",
 ]
 
diff --git a/tests/test_feeds.py b/tests/test_feeds.py
index 8fa6c4b..2b3a2b4 100644
--- a/tests/test_feeds.py
+++ b/tests/test_feeds.py
@@ -257,3 +257,22 @@ def test_extract_domain_special_characters() -> None:
     """Test extract_domain for URLs with special characters."""
     url: str = "https://www.ex-ample.com/feed"
     assert extract_domain(url) == "Ex-ample", "Domains with special characters should return the capitalized domain."
+
+
+@pytest.mark.parametrize(
+    argnames=("url", "expected"),
+    argvalues=[
+        ("https://blog.something.com", "Something"),
+        ("https://www.something.com", "Something"),
+        ("https://subdomain.example.co.uk", "Example"),
+        ("https://github.com/user/repo", "GitHub"),
+        ("https://youtube.com/feeds/videos.xml?channel_id=abc", "YouTube"),
+        ("https://reddit.com/r/python/.rss", "Reddit"),
+        ("", "Other"),
+        ("not a url", "Other"),
+        ("https://www.example.com", "Example"),
+        ("https://foo.bar.baz.com", "Baz"),
+    ],
+)
+def test_extract_domain(url: str, expected: str) -> None:
+    assert extract_domain(url) == expected