Add tldextract for improved domain extraction and add new tests for the extract_domain function
This commit is contained in:
@ -7,6 +7,7 @@ import re
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from urllib.parse import ParseResult, urlparse
|
from urllib.parse import ParseResult, urlparse
|
||||||
|
|
||||||
|
import tldextract
|
||||||
from discord_webhook import DiscordEmbed, DiscordWebhook
|
from discord_webhook import DiscordEmbed, DiscordWebhook
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from reader import Entry, EntryNotFoundError, Feed, FeedExistsError, Reader, ReaderError, StorageError, TagNotFoundError
|
from reader import Entry, EntryNotFoundError, Feed, FeedExistsError, Reader, ReaderError, StorageError, TagNotFoundError
|
||||||
@ -70,12 +71,10 @@ def extract_domain(url: str) -> str: # noqa: PLR0911
|
|||||||
if domain in domain_mapping:
|
if domain in domain_mapping:
|
||||||
return domain_mapping[domain]
|
return domain_mapping[domain]
|
||||||
|
|
||||||
# For other domains, capitalize the first part before the TLD
|
# Use tldextract to get the registered domain label (the part directly before the public suffix, e.g. "example" in "sub.example.co.uk")
|
||||||
parts: list[str] = domain.split(".")
|
ext = tldextract.extract(url)
|
||||||
min_domain_parts = 2
|
if ext.domain:
|
||||||
if len(parts) >= min_domain_parts:
|
return ext.domain.capitalize()
|
||||||
return parts[0].capitalize()
|
|
||||||
|
|
||||||
return domain.capitalize()
|
return domain.capitalize()
|
||||||
except (ValueError, AttributeError, TypeError) as e:
|
except (ValueError, AttributeError, TypeError) as e:
|
||||||
logger.warning("Error extracting domain from %s: %s", url, e)
|
logger.warning("Error extracting domain from %s: %s", url, e)
|
||||||
|
@ -17,6 +17,7 @@ dependencies = [
|
|||||||
"python-multipart",
|
"python-multipart",
|
||||||
"reader",
|
"reader",
|
||||||
"sentry-sdk[fastapi]",
|
"sentry-sdk[fastapi]",
|
||||||
|
"tldextract",
|
||||||
"uvicorn",
|
"uvicorn",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -257,3 +257,22 @@ def test_extract_domain_special_characters() -> None:
|
|||||||
"""Test extract_domain for URLs with special characters."""
|
"""Test extract_domain for URLs with special characters."""
|
||||||
url: str = "https://www.ex-ample.com/feed"
|
url: str = "https://www.ex-ample.com/feed"
|
||||||
assert extract_domain(url) == "Ex-ample", "Domains with special characters should return the capitalized domain."
|
assert extract_domain(url) == "Ex-ample", "Domains with special characters should return the capitalized domain."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
argnames=("url", "expected"),
|
||||||
|
argvalues=[
|
||||||
|
("https://blog.something.com", "Something"),
|
||||||
|
("https://www.something.com", "Something"),
|
||||||
|
("https://subdomain.example.co.uk", "Example"),
|
||||||
|
("https://github.com/user/repo", "GitHub"),
|
||||||
|
("https://youtube.com/feeds/videos.xml?channel_id=abc", "YouTube"),
|
||||||
|
("https://reddit.com/r/python/.rss", "Reddit"),
|
||||||
|
("", "Other"),
|
||||||
|
("not a url", "Other"),
|
||||||
|
("https://www.example.com", "Example"),
|
||||||
|
("https://foo.bar.baz.com", "Baz"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_extract_domain(url: str, expected: str) -> None:
|
||||||
|
assert extract_domain(url) == expected
|
||||||
|
Reference in New Issue
Block a user