Add tldextract for improved domain extraction and add new tests for extract_domain function

This commit is contained in:
2025-04-16 13:32:31 +02:00
parent 8b50003eda
commit cd0f63d59a
3 changed files with 25 additions and 6 deletions

View File

@ -7,6 +7,7 @@ import re
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from urllib.parse import ParseResult, urlparse from urllib.parse import ParseResult, urlparse
import tldextract
from discord_webhook import DiscordEmbed, DiscordWebhook from discord_webhook import DiscordEmbed, DiscordWebhook
from fastapi import HTTPException from fastapi import HTTPException
from reader import Entry, EntryNotFoundError, Feed, FeedExistsError, Reader, ReaderError, StorageError, TagNotFoundError from reader import Entry, EntryNotFoundError, Feed, FeedExistsError, Reader, ReaderError, StorageError, TagNotFoundError
@ -70,12 +71,10 @@ def extract_domain(url: str) -> str: # noqa: PLR0911
if domain in domain_mapping: if domain in domain_mapping:
return domain_mapping[domain] return domain_mapping[domain]
# For other domains, capitalize the first part before the TLD # Use tldextract to get the domain (SLD)
parts: list[str] = domain.split(".") ext = tldextract.extract(url)
min_domain_parts = 2 if ext.domain:
if len(parts) >= min_domain_parts: return ext.domain.capitalize()
return parts[0].capitalize()
return domain.capitalize() return domain.capitalize()
except (ValueError, AttributeError, TypeError) as e: except (ValueError, AttributeError, TypeError) as e:
logger.warning("Error extracting domain from %s: %s", url, e) logger.warning("Error extracting domain from %s: %s", url, e)

View File

@ -17,6 +17,7 @@ dependencies = [
"python-multipart", "python-multipart",
"reader", "reader",
"sentry-sdk[fastapi]", "sentry-sdk[fastapi]",
"tldextract",
"uvicorn", "uvicorn",
] ]

View File

@ -257,3 +257,22 @@ def test_extract_domain_special_characters() -> None:
"""Test extract_domain for URLs with special characters.""" """Test extract_domain for URLs with special characters."""
url: str = "https://www.ex-ample.com/feed" url: str = "https://www.ex-ample.com/feed"
assert extract_domain(url) == "Ex-ample", "Domains with special characters should return the capitalized domain." assert extract_domain(url) == "Ex-ample", "Domains with special characters should return the capitalized domain."
@pytest.mark.parametrize(
argnames=("url", "expected"),
argvalues=[
("https://blog.something.com", "Something"),
("https://www.something.com", "Something"),
("https://subdomain.example.co.uk", "Example"),
("https://github.com/user/repo", "GitHub"),
("https://youtube.com/feeds/videos.xml?channel_id=abc", "YouTube"),
("https://reddit.com/r/python/.rss", "Reddit"),
("", "Other"),
("not a url", "Other"),
("https://www.example.com", "Example"),
("https://foo.bar.baz.com", "Baz"),
],
)
def test_extract_domain(url: str, expected: str) -> None:
assert extract_domain(url) == expected