Modify OpenAPI and only have one return in upload_file()

This commit is contained in:
2023-01-09 05:07:56 +01:00
parent f6261107bd
commit 88074d8ab5
3 changed files with 16 additions and 31 deletions

View File

@ -1,8 +1,7 @@
from typing import Dict
from urllib.parse import urljoin
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
@ -12,16 +11,9 @@ from discord_embed.webhook import send_webhook
app: FastAPI = FastAPI(
title="discord-nice-embed",
description=settings.DESCRIPTION,
version="0.0.1",
contact={
"name": "Joakim Hellsén",
"url": "https://github.com/TheLovinator1",
"email": "tlovinator@gmail.com",
},
license_info={
"name": "GPL-3.0",
"url": "https://www.gnu.org/licenses/gpl-3.0.txt",
"name": "Github repo",
"url": "https://github.com/TheLovinator1/discord-embed",
},
)
@ -29,8 +21,8 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
templates: Jinja2Templates = Jinja2Templates(directory="templates")
@app.post("/uploadfiles/")
async def upload_file(file: UploadFile = File(...)) -> Dict[str, str]:
@app.post("/uploadfiles/", description="Where to send a POST request to upload files.")
async def upload_file(file: UploadFile = File()):
"""Page for uploading files.
If it is a video, we need to make an HTML file, and a thumbnail
@ -44,16 +36,17 @@ async def upload_file(file: UploadFile = File(...)) -> Dict[str, str]:
Returns a dict with the filename, or a link to the .html if it was a video.
"""
if file.content_type.startswith("video/"):
return await do_things(file)
html_url: str = await do_things(file)
else:
filename: str = await remove_illegal_chars(file.filename)
filename: str = await remove_illegal_chars(file.filename)
with open(f"{settings.upload_folder}/{filename}", "wb+") as f:
f.write(file.file.read())
with open(f"{settings.upload_folder}/{filename}", "wb+") as f:
f.write(file.file.read())
html_url: str = urljoin(settings.serve_domain, filename) # type: ignore
domain_url: str = urljoin(settings.serve_domain, filename)
send_webhook(f"{domain_url} was uploaded.")
return {"html_url": domain_url}
send_webhook(f"{html_url} was uploaded.")
return JSONResponse(content={"html_url": html_url})
async def remove_illegal_chars(file_name: str) -> str:
@ -96,7 +89,7 @@ async def remove_illegal_chars(file_name: str) -> str:
return filename
@app.get("/", response_class=HTMLResponse)
@app.get("/", response_class=HTMLResponse, include_in_schema=False)
async def main(request: Request):
"""Our index view.