Add tldextract for improved domain extraction and add new tests for extract_domain function
This commit is contained in:
@ -7,6 +7,7 @@ import re
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import ParseResult, urlparse
|
||||
|
||||
import tldextract
|
||||
from discord_webhook import DiscordEmbed, DiscordWebhook
|
||||
from fastapi import HTTPException
|
||||
from reader import Entry, EntryNotFoundError, Feed, FeedExistsError, Reader, ReaderError, StorageError, TagNotFoundError
|
||||
@ -70,12 +71,10 @@ def extract_domain(url: str) -> str: # noqa: PLR0911
|
||||
if domain in domain_mapping:
|
||||
return domain_mapping[domain]
|
||||
|
||||
# For other domains, capitalize the first part before the TLD
|
||||
parts: list[str] = domain.split(".")
|
||||
min_domain_parts = 2
|
||||
if len(parts) >= min_domain_parts:
|
||||
return parts[0].capitalize()
|
||||
|
||||
# Use tldextract to get the domain (SLD)
|
||||
ext = tldextract.extract(url)
|
||||
if ext.domain:
|
||||
return ext.domain.capitalize()
|
||||
return domain.capitalize()
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning("Error extracting domain from %s: %s", url, e)
|
||||
|
@ -17,6 +17,7 @@ dependencies = [
|
||||
"python-multipart",
|
||||
"reader",
|
||||
"sentry-sdk[fastapi]",
|
||||
"tldextract",
|
||||
"uvicorn",
|
||||
]
|
||||
|
||||
|
@ -257,3 +257,22 @@ def test_extract_domain_special_characters() -> None:
|
||||
"""Test extract_domain for URLs with special characters."""
|
||||
url: str = "https://www.ex-ample.com/feed"
|
||||
assert extract_domain(url) == "Ex-ample", "Domains with special characters should return the capitalized domain."
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
argnames=("url", "expected"),
|
||||
argvalues=[
|
||||
("https://blog.something.com", "Something"),
|
||||
("https://www.something.com", "Something"),
|
||||
("https://subdomain.example.co.uk", "Example"),
|
||||
("https://github.com/user/repo", "GitHub"),
|
||||
("https://youtube.com/feeds/videos.xml?channel_id=abc", "YouTube"),
|
||||
("https://reddit.com/r/python/.rss", "Reddit"),
|
||||
("", "Other"),
|
||||
("not a url", "Other"),
|
||||
("https://www.example.com", "Example"),
|
||||
("https://foo.bar.baz.com", "Baz"),
|
||||
],
|
||||
)
|
||||
def test_extract_domain(url: str, expected: str) -> None:
|
||||
assert extract_domain(url) == expected
|
||||
|
Reference in New Issue
Block a user