Add type hints
This commit is contained in:
@ -18,7 +18,7 @@ def generate_html_for_videos(url: str, width: int, height: int, screenshot: str,
|
|||||||
Returns:
|
Returns:
|
||||||
Returns HTML for video.
|
Returns HTML for video.
|
||||||
"""
|
"""
|
||||||
video_html = f"""
|
video_html: str = f"""
|
||||||
<!DOCTYPE html>
|
<!DOCTYPE html>
|
||||||
<html>
|
<html>
|
||||||
<!-- Generated at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} -->
|
<!-- Generated at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} -->
|
||||||
@ -33,13 +33,13 @@ def generate_html_for_videos(url: str, width: int, height: int, screenshot: str,
|
|||||||
</head>
|
</head>
|
||||||
</html>
|
</html>
|
||||||
"""
|
"""
|
||||||
domain = settings.serve_domain
|
domain: str = settings.serve_domain
|
||||||
html_url: str = urljoin(domain, filename)
|
html_url: str = urljoin(domain, filename)
|
||||||
|
|
||||||
# Take the filename and append .html to it.
|
# Take the filename and append .html to it.
|
||||||
filename += ".html"
|
filename += ".html"
|
||||||
|
|
||||||
file_path = os.path.join(settings.upload_folder, filename)
|
file_path: str = os.path.join(settings.upload_folder, filename)
|
||||||
with open(file_path, "w", encoding="utf-8") as f:
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
f.write(video_html)
|
f.write(video_html)
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ from discord_embed import settings
|
|||||||
from discord_embed.video_file_upload import do_things
|
from discord_embed.video_file_upload import do_things
|
||||||
from discord_embed.webhook import send_webhook
|
from discord_embed.webhook import send_webhook
|
||||||
|
|
||||||
app = FastAPI(
|
app: FastAPI = FastAPI(
|
||||||
title="discord-nice-embed",
|
title="discord-nice-embed",
|
||||||
description=settings.DESCRIPTION,
|
description=settings.DESCRIPTION,
|
||||||
version="0.0.1",
|
version="0.0.1",
|
||||||
@ -26,7 +26,7 @@ app = FastAPI(
|
|||||||
)
|
)
|
||||||
|
|
||||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||||
templates = Jinja2Templates(directory="templates")
|
templates: Jinja2Templates = Jinja2Templates(directory="templates")
|
||||||
|
|
||||||
|
|
||||||
@app.post("/uploadfiles/")
|
@app.post("/uploadfiles/")
|
||||||
@ -46,12 +46,12 @@ async def upload_file(file: UploadFile = File(...)) -> Dict[str, str]:
|
|||||||
if file.content_type.startswith("video/"):
|
if file.content_type.startswith("video/"):
|
||||||
return await do_things(file)
|
return await do_things(file)
|
||||||
|
|
||||||
filename = await remove_illegal_chars(file.filename)
|
filename: str = await remove_illegal_chars(file.filename)
|
||||||
|
|
||||||
with open(f"{settings.upload_folder}/{filename}", "wb+") as f:
|
with open(f"{settings.upload_folder}/{filename}", "wb+") as f:
|
||||||
f.write(file.file.read())
|
f.write(file.file.read())
|
||||||
|
|
||||||
domain_url = urljoin(settings.serve_domain, filename)
|
domain_url: str = urljoin(settings.serve_domain, filename)
|
||||||
send_webhook(f"{domain_url} was uploaded.")
|
send_webhook(f"{domain_url} was uploaded.")
|
||||||
return {"html_url": domain_url}
|
return {"html_url": domain_url}
|
||||||
|
|
||||||
@ -66,8 +66,8 @@ async def remove_illegal_chars(filename: str) -> str:
|
|||||||
Returns a string with the filename without illegal characters.
|
Returns a string with the filename without illegal characters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
filename = filename.replace(" ", ".")
|
filename: str = filename.replace(" ", ".") # type: ignore
|
||||||
illegal_characters = [
|
illegal_characters: list[str] = [
|
||||||
"*",
|
"*",
|
||||||
'"',
|
'"',
|
||||||
"<",
|
"<",
|
||||||
@ -91,7 +91,8 @@ async def remove_illegal_chars(filename: str) -> str:
|
|||||||
",",
|
",",
|
||||||
]
|
]
|
||||||
for character in illegal_characters:
|
for character in illegal_characters:
|
||||||
filename = filename.replace(character, "")
|
filename: str = filename.replace(character, "") # type: ignore
|
||||||
|
|
||||||
return filename
|
return filename
|
||||||
|
|
||||||
|
|
||||||
@ -102,7 +103,7 @@ async def main(request: Request):
|
|||||||
You can upload files here.
|
You can upload files here.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
HTMLResponse: Returns HTML for site.
|
TemplateResponse: Returns HTML for site.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return templates.TemplateResponse("index.html", {"request": request})
|
return templates.TemplateResponse("index.html", {"request": request})
|
||||||
|
@ -4,7 +4,7 @@ import sys
|
|||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
DESCRIPTION = (
|
DESCRIPTION: str = (
|
||||||
"Discord will only create embeds for videos and images if they are "
|
"Discord will only create embeds for videos and images if they are "
|
||||||
"smaller than 8 mb. We can 'abuse' this by creating a .html that "
|
"smaller than 8 mb. We can 'abuse' this by creating a .html that "
|
||||||
"contains the 'twitter:player' HTML meta tag linking to the video."
|
"contains the 'twitter:player' HTML meta tag linking to the video."
|
||||||
@ -14,7 +14,7 @@ load_dotenv()
|
|||||||
|
|
||||||
# Check if user has added a domain to the environment.
|
# Check if user has added a domain to the environment.
|
||||||
try:
|
try:
|
||||||
serve_domain = os.environ["SERVE_DOMAIN"]
|
serve_domain: str = os.environ["SERVE_DOMAIN"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
sys.exit("discord-embed: Environment variable 'SERVE_DOMAIN' is missing!")
|
sys.exit("discord-embed: Environment variable 'SERVE_DOMAIN' is missing!")
|
||||||
|
|
||||||
@ -24,7 +24,7 @@ if serve_domain.endswith("/"):
|
|||||||
|
|
||||||
# Check if we have a folder for uploads.
|
# Check if we have a folder for uploads.
|
||||||
try:
|
try:
|
||||||
upload_folder = os.environ["UPLOAD_FOLDER"]
|
upload_folder: str = os.environ["UPLOAD_FOLDER"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
sys.exit("Environment variable 'UPLOAD_FOLDER' is missing!")
|
sys.exit("Environment variable 'UPLOAD_FOLDER' is missing!")
|
||||||
|
|
||||||
@ -37,6 +37,6 @@ if upload_folder.endswith("/"):
|
|||||||
|
|
||||||
# Discord webhook URL
|
# Discord webhook URL
|
||||||
try:
|
try:
|
||||||
webhook_url = os.environ["WEBHOOK_URL"]
|
webhook_url: str = os.environ["WEBHOOK_URL"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
sys.exit("Environment variable 'WEBHOOK_URL' is missing!")
|
sys.exit("Environment variable 'WEBHOOK_URL' is missing!")
|
||||||
|
@ -33,8 +33,8 @@ def video_resolution(path_to_video: str) -> Resolution:
|
|||||||
print("No video stream found", file=sys.stderr)
|
print("No video stream found", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
width = int(video_stream["width"])
|
width: int = int(video_stream["width"])
|
||||||
height = int(video_stream["height"])
|
height: int = int(video_stream["height"])
|
||||||
|
|
||||||
return Resolution(height, width)
|
return Resolution(height, width)
|
||||||
|
|
||||||
|
@ -35,14 +35,14 @@ def save_to_disk(file: UploadFile) -> VideoFile:
|
|||||||
VideoFile object with the filename and location.
|
VideoFile object with the filename and location.
|
||||||
"""
|
"""
|
||||||
# Create the folder where we should save the files
|
# Create the folder where we should save the files
|
||||||
folder_video = os.path.join(settings.upload_folder, "video")
|
folder_video: str = os.path.join(settings.upload_folder, "video")
|
||||||
Path(folder_video).mkdir(parents=True, exist_ok=True)
|
Path(folder_video).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Replace spaces with dots in the filename.
|
# Replace spaces with dots in the filename.
|
||||||
filename = file.filename.replace(" ", ".")
|
filename: str = file.filename.replace(" ", ".")
|
||||||
|
|
||||||
# Save the uploaded file to disk.
|
# Save the uploaded file to disk.
|
||||||
file_location = os.path.join(folder_video, filename)
|
file_location: str = os.path.join(folder_video, filename)
|
||||||
with open(file_location, "wb+") as f:
|
with open(file_location, "wb+") as f:
|
||||||
f.write(file.file.read())
|
f.write(file.file.read())
|
||||||
|
|
||||||
@ -61,10 +61,10 @@ async def do_things(file: UploadFile) -> Dict[str, str]:
|
|||||||
|
|
||||||
video_file: VideoFile = save_to_disk(file)
|
video_file: VideoFile = save_to_disk(file)
|
||||||
|
|
||||||
file_url = f"{settings.serve_domain}/video/{video_file.filename}"
|
file_url: str = f"{settings.serve_domain}/video/{video_file.filename}"
|
||||||
res: Resolution = video_resolution(video_file.location)
|
res: Resolution = video_resolution(video_file.location)
|
||||||
screenshot_url = make_thumbnail(video_file.location, video_file.filename)
|
screenshot_url: str = make_thumbnail(video_file.location, video_file.filename)
|
||||||
html_url = generate_html_for_videos(
|
html_url: str = generate_html_for_videos(
|
||||||
url=file_url,
|
url=file_url,
|
||||||
width=res.width,
|
width=res.width,
|
||||||
height=res.height,
|
height=res.height,
|
||||||
|
@ -9,7 +9,7 @@ def send_webhook(message: str) -> None:
|
|||||||
Args:
|
Args:
|
||||||
message: The message to send.
|
message: The message to send.
|
||||||
"""
|
"""
|
||||||
webhook = DiscordWebhook(
|
webhook: DiscordWebhook = DiscordWebhook(
|
||||||
url=settings.webhook_url,
|
url=settings.webhook_url,
|
||||||
content=message,
|
content=message,
|
||||||
rate_limit_retry=True,
|
rate_limit_retry=True,
|
||||||
|
@ -1,67 +1,68 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
from fastapi import Response
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from discord_embed import __version__, settings
|
from discord_embed import __version__, settings
|
||||||
from discord_embed.main import app
|
from discord_embed.main import app
|
||||||
|
|
||||||
client = TestClient(app)
|
client: TestClient = TestClient(app)
|
||||||
TEST_FILE = "tests/test.mp4"
|
TEST_FILE: str = "tests/test.mp4"
|
||||||
|
|
||||||
|
|
||||||
def test_version():
|
def test_version() -> None:
|
||||||
"""Test version is correct."""
|
"""Test version is correct."""
|
||||||
assert __version__ == "1.0.0"
|
assert __version__ == "1.0.0"
|
||||||
|
|
||||||
|
|
||||||
def test_domain_ends_with_slash():
|
def test_domain_ends_with_slash() -> None:
|
||||||
"""Test domain ends with a slash."""
|
"""Test domain ends with a slash."""
|
||||||
assert not settings.serve_domain.endswith("/")
|
assert not settings.serve_domain.endswith("/")
|
||||||
|
|
||||||
|
|
||||||
def test_save_to_disk():
|
def test_save_to_disk() -> None:
|
||||||
"""Test save_to_disk() works."""
|
"""Test save_to_disk() works."""
|
||||||
# TODO: Implement this test. I need to mock the UploadFile object.
|
# TODO: Implement this test. I need to mock the UploadFile object.
|
||||||
|
|
||||||
|
|
||||||
def test_do_things():
|
def test_do_things() -> None:
|
||||||
"""Test do_things() works."""
|
"""Test do_things() works."""
|
||||||
# TODO: Implement this test. I need to mock the UploadFile object.
|
# TODO: Implement this test. I need to mock the UploadFile object.
|
||||||
|
|
||||||
|
|
||||||
def test_main():
|
def test_main() -> None:
|
||||||
"""Test main() works."""
|
"""Test main() works."""
|
||||||
data_without_trailing_nl = ""
|
data_without_trailing_nl = ""
|
||||||
response = client.get("/")
|
response: Response = client.get("/")
|
||||||
|
|
||||||
# Check if response is our HTML.
|
# Check if response is our HTML.
|
||||||
with open("templates/index.html", encoding="utf8") as our_html:
|
with open("templates/index.html", encoding="utf8") as our_html:
|
||||||
data = our_html.read()
|
data: str = our_html.read()
|
||||||
|
|
||||||
# index.html has a trailing newline that we need to remove.
|
# index.html has a trailing newline that we need to remove.
|
||||||
if data[-1:] == "\n":
|
if data[-1:] == "\n":
|
||||||
data_without_trailing_nl = data[:-1]
|
data_without_trailing_nl: str = data[:-1] # type: ignore
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.text == data_without_trailing_nl
|
assert response.text == data_without_trailing_nl
|
||||||
|
|
||||||
|
|
||||||
def test_upload_file():
|
def test_upload_file() -> None:
|
||||||
"""Test if we can upload files."""
|
"""Test if we can upload files."""
|
||||||
domain = os.environ["SERVE_DOMAIN"]
|
domain = os.environ["SERVE_DOMAIN"]
|
||||||
|
|
||||||
# Remove trailing slash from domain
|
# Remove trailing slash from domain
|
||||||
if domain.endswith("/"):
|
if domain.endswith("/"):
|
||||||
domain = domain[:-1]
|
domain: str = domain[:-1] # type: ignore
|
||||||
|
|
||||||
# Upload our video file and check if it returns the html_url.
|
# Upload our video file and check if it returns the html_url.
|
||||||
with open(TEST_FILE, "rb") as uploaded_file:
|
with open(TEST_FILE, "rb") as uploaded_file:
|
||||||
response = client.post(
|
response: Response = client.post(
|
||||||
url="/uploadfiles/",
|
url="/uploadfiles/",
|
||||||
files={"file": uploaded_file},
|
files={"file": uploaded_file},
|
||||||
)
|
)
|
||||||
returned_json = response.json()
|
returned_json = response.json()
|
||||||
html_url = returned_json["html_url"]
|
html_url: str = returned_json["html_url"]
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert html_url == f"{domain}/test.mp4"
|
assert html_url == f"{domain}/test.mp4"
|
||||||
|
@ -3,19 +3,19 @@ import os
|
|||||||
from discord_embed.generate_html import generate_html_for_videos
|
from discord_embed.generate_html import generate_html_for_videos
|
||||||
|
|
||||||
|
|
||||||
def test_generate_html_for_videos():
|
def test_generate_html_for_videos() -> None:
|
||||||
"""Test generate_html_for_videos() works."""
|
"""Test generate_html_for_videos() works."""
|
||||||
domain = os.environ["SERVE_DOMAIN"]
|
domain: str = os.environ["SERVE_DOMAIN"]
|
||||||
|
|
||||||
# Remove trailing slash from domain
|
# Remove trailing slash from domain
|
||||||
if domain.endswith("/"):
|
if domain.endswith("/"):
|
||||||
domain = domain[:-1]
|
domain = domain[:-1]
|
||||||
|
|
||||||
# Delete the old HTML file if it exists
|
# Delete the old HTML file if it exists
|
||||||
if os.path.exists(f"Uploads/test_video.mp4.html"):
|
if os.path.exists("Uploads/test_video.mp4.html"):
|
||||||
os.remove(f"Uploads/test_video.mp4.html")
|
os.remove("Uploads/test_video.mp4.html")
|
||||||
|
|
||||||
generated_html = generate_html_for_videos(
|
generated_html: str = generate_html_for_videos(
|
||||||
url="https://www.youtube.com/watch?v=dQw4w9WgXcQ",
|
url="https://www.youtube.com/watch?v=dQw4w9WgXcQ",
|
||||||
width=1920,
|
width=1920,
|
||||||
height=1080,
|
height=1080,
|
||||||
@ -27,7 +27,7 @@ def test_generate_html_for_videos():
|
|||||||
# Open the generated HTML and check if it contains the correct URL, width, height, and screenshot.
|
# Open the generated HTML and check if it contains the correct URL, width, height, and screenshot.
|
||||||
|
|
||||||
with open("Uploads/test_video.mp4.html", "r") as generated_html_file:
|
with open("Uploads/test_video.mp4.html", "r") as generated_html_file:
|
||||||
generated_html_lines = generated_html_file.readlines()
|
generated_html_lines: list[str] = generated_html_file.readlines()
|
||||||
"""
|
"""
|
||||||
<!DOCTYPE html>
|
<!DOCTYPE html>
|
||||||
<html>
|
<html>
|
||||||
@ -46,32 +46,35 @@ def test_generate_html_for_videos():
|
|||||||
|
|
||||||
for line, html in enumerate(generated_html_lines):
|
for line, html in enumerate(generated_html_lines):
|
||||||
# Strip spaces and newlines
|
# Strip spaces and newlines
|
||||||
html = html.strip()
|
stripped_html: str = html.strip()
|
||||||
|
|
||||||
|
rick: str = "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
|
||||||
|
|
||||||
# Check each line
|
# Check each line
|
||||||
if line == 1:
|
if line == 1:
|
||||||
assert html == "<!DOCTYPE html>"
|
assert stripped_html == "<!DOCTYPE html>"
|
||||||
elif line == 2:
|
elif line == 2:
|
||||||
assert html == "<html>"
|
assert stripped_html == "<html>"
|
||||||
elif line == 3:
|
elif line == 3:
|
||||||
assert html.startswith("<!-- Generated at ")
|
assert stripped_html.startswith("<!-- Generated at ")
|
||||||
elif line == 4:
|
elif line == 4:
|
||||||
assert html == "<head>"
|
assert stripped_html == "<head>"
|
||||||
elif line == 5:
|
elif line == 5:
|
||||||
assert html == '<meta property="og:type" content="video.other">'
|
assert stripped_html == '<meta property="og:type" content="video.other">'
|
||||||
elif line == 6:
|
elif line == 6:
|
||||||
assert html == '<meta property="twitter:player" content="https://www.youtube.com/watch?v=dQw4w9WgXcQ">'
|
assert stripped_html == f'<meta property="twitter:player" content="{rick}">'
|
||||||
elif line == 7:
|
elif line == 7:
|
||||||
assert html == '<meta property="og:video:type" content="text/html">'
|
assert stripped_html == '<meta property="og:video:type" content="text/html">'
|
||||||
elif line == 8:
|
elif line == 8:
|
||||||
assert html == '<meta property="og:video:width" content="1920">'
|
assert stripped_html == '<meta property="og:video:width" content="1920">'
|
||||||
elif line == 9:
|
elif line == 9:
|
||||||
assert html == '<meta property="og:video:height" content="1080">'
|
assert stripped_html == '<meta property="og:video:height" content="1080">'
|
||||||
elif line == 10:
|
elif line == 10:
|
||||||
assert html == '<meta name="twitter:image" content="https://i.ytimg.com/vi/dQw4w9WgXcQ/hqdefault.jpg">'
|
thumb: str = "https://i.ytimg.com/vi/dQw4w9WgXcQ/hqdefault.jpg"
|
||||||
|
assert stripped_html == f'<meta name="twitter:image" content="{thumb}">'
|
||||||
elif line == 11:
|
elif line == 11:
|
||||||
assert html == '<meta http-equiv="refresh" content="0;url=https://www.youtube.com/watch?v=dQw4w9WgXcQ">'
|
assert stripped_html == f'<meta http-equiv="refresh" content="0;url={rick}">'
|
||||||
elif line == 12:
|
elif line == 12:
|
||||||
assert html == "</head>"
|
assert stripped_html == "</head>"
|
||||||
elif line == 13:
|
elif line == 13:
|
||||||
assert html == "</html>"
|
assert stripped_html == "</html>"
|
||||||
|
@ -7,24 +7,24 @@ from discord_embed.video import Resolution, make_thumbnail, video_resolution
|
|||||||
TEST_FILE = "tests/test.mp4"
|
TEST_FILE = "tests/test.mp4"
|
||||||
|
|
||||||
|
|
||||||
def test_video_resolution():
|
def test_video_resolution() -> None:
|
||||||
"""Test video_resolution() works."""
|
"""Test video_resolution() works."""
|
||||||
assert video_resolution(TEST_FILE) == Resolution(height=422, width=422)
|
assert video_resolution(TEST_FILE) == Resolution(height=422, width=422)
|
||||||
|
|
||||||
|
|
||||||
def test_make_thumbnail():
|
def test_make_thumbnail() -> None:
|
||||||
"""Test make_thumbnail() works."""
|
"""Test make_thumbnail() works."""
|
||||||
domain = os.environ["SERVE_DOMAIN"]
|
domain: str = os.environ["SERVE_DOMAIN"]
|
||||||
|
|
||||||
# Remove trailing slash from domain
|
# Remove trailing slash from domain
|
||||||
if domain.endswith("/"):
|
if domain.endswith("/"):
|
||||||
domain = domain[:-1]
|
domain: str = domain[:-1] # type: ignore
|
||||||
|
|
||||||
# Remove thumbnail if it exists
|
# Remove thumbnail if it exists
|
||||||
if os.path.exists(f"{settings.upload_folder}/test.mp4.jpg"):
|
if os.path.exists(f"{settings.upload_folder}/test.mp4.jpg"):
|
||||||
os.remove(f"{settings.upload_folder}/test.mp4.jpg")
|
os.remove(f"{settings.upload_folder}/test.mp4.jpg")
|
||||||
|
|
||||||
thumbnail = make_thumbnail(TEST_FILE, "test.mp4")
|
thumbnail: str = make_thumbnail(TEST_FILE, "test.mp4")
|
||||||
|
|
||||||
# Check if thumbnail is a jpeg.
|
# Check if thumbnail is a jpeg.
|
||||||
assert imghdr.what(f"{settings.upload_folder}/test.mp4.jpg") == "jpeg"
|
assert imghdr.what(f"{settings.upload_folder}/test.mp4.jpg") == "jpeg"
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from discord_embed.webhook import send_webhook
|
from discord_embed.webhook import send_webhook
|
||||||
|
|
||||||
|
|
||||||
def test_send_webhook():
|
def test_send_webhook() -> None:
|
||||||
"""Test send_webhook() works."""
|
"""Test send_webhook() works."""
|
||||||
send_webhook("Running Pytest")
|
send_webhook("Running Pytest")
|
||||||
|
Reference in New Issue
Block a user