Add tldextract for improved domain extraction and add new tests for extract_domain function

This commit is contained in:
2025-04-16 13:32:31 +02:00
parent 8b50003eda
commit cd0f63d59a
3 changed files with 25 additions and 6 deletions

View File

@ -7,6 +7,7 @@ import re
from typing import TYPE_CHECKING
from urllib.parse import ParseResult, urlparse
import tldextract
from discord_webhook import DiscordEmbed, DiscordWebhook
from fastapi import HTTPException
from reader import Entry, EntryNotFoundError, Feed, FeedExistsError, Reader, ReaderError, StorageError, TagNotFoundError
@ -70,12 +71,10 @@ def extract_domain(url: str) -> str: # noqa: PLR0911
if domain in domain_mapping:
return domain_mapping[domain]
# For other domains, capitalize the first part before the TLD
parts: list[str] = domain.split(".")
min_domain_parts = 2
if len(parts) >= min_domain_parts:
return parts[0].capitalize()
# Use tldextract to get the domain (SLD)
ext = tldextract.extract(url)
if ext.domain:
return ext.domain.capitalize()
return domain.capitalize()
except (ValueError, AttributeError, TypeError) as e:
logger.warning("Error extracting domain from %s: %s", url, e)

View File

@ -17,6 +17,7 @@ dependencies = [
"python-multipart",
"reader",
"sentry-sdk[fastapi]",
"tldextract",
"uvicorn",
]

View File

@ -257,3 +257,22 @@ def test_extract_domain_special_characters() -> None:
"""Test extract_domain for URLs with special characters."""
url: str = "https://www.ex-ample.com/feed"
assert extract_domain(url) == "Ex-ample", "Domains with special characters should return the capitalized domain."
@pytest.mark.parametrize(
argnames=("url", "expected"),
argvalues=[
("https://blog.something.com", "Something"),
("https://www.something.com", "Something"),
("https://subdomain.example.co.uk", "Example"),
("https://github.com/user/repo", "GitHub"),
("https://youtube.com/feeds/videos.xml?channel_id=abc", "YouTube"),
("https://reddit.com/r/python/.rss", "Reddit"),
("", "Other"),
("not a url", "Other"),
("https://www.example.com", "Example"),
("https://foo.bar.baz.com", "Baz"),
],
)
def test_extract_domain(url: str, expected: str) -> None:
assert extract_domain(url) == expected