Refactor chat function to support asynchronous image processing and enhance image extraction from user messages
This commit is contained in:
parent
2aec54d51b
commit
08c286cff5
2 changed files with 148 additions and 84 deletions
122
main.py
122
main.py
|
|
@ -1,23 +1,25 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import io
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
import cv2
|
||||
import discord
|
||||
import httpx
|
||||
import numpy as np
|
||||
import openai
|
||||
import sentry_sdk
|
||||
from discord import app_commands
|
||||
from openai import OpenAI
|
||||
from discord import Forbidden, HTTPException, NotFound, app_commands
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from misc import add_message_to_memory, chat, get_allowed_users, should_respond_without_trigger, update_trigger_time
|
||||
from misc import add_message_to_memory, chat, get_allowed_users, get_raw_images_from_text, should_respond_without_trigger, update_trigger_time
|
||||
from settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn="https://ebbd2cdfbd08dba008d628dad7941091@o4505228040339456.ingest.us.sentry.io/4507630719401984",
|
||||
send_default_pii=True,
|
||||
|
|
@ -32,7 +34,7 @@ discord_token: str = settings.discord_token
|
|||
openai_api_key: str = settings.openai_api_key
|
||||
|
||||
|
||||
openai_client = OpenAI(api_key=openai_api_key)
|
||||
openai_client = AsyncOpenAI(api_key=openai_api_key)
|
||||
|
||||
|
||||
class LoviBotClient(discord.Client):
|
||||
|
|
@ -92,7 +94,7 @@ class LoviBotClient(discord.Client):
|
|||
|
||||
async with message.channel.typing():
|
||||
try:
|
||||
response: str | None = chat(
|
||||
response: str | None = await chat(
|
||||
user_message=incoming_message,
|
||||
openai_client=openai_client,
|
||||
current_channel=message.channel,
|
||||
|
|
@ -185,7 +187,7 @@ async def ask(interaction: discord.Interaction, text: str) -> None:
|
|||
return
|
||||
|
||||
try:
|
||||
response: str | None = chat(
|
||||
response: str | None = await chat(
|
||||
user_message=text,
|
||||
openai_client=openai_client,
|
||||
current_channel=interaction.channel,
|
||||
|
|
@ -321,6 +323,23 @@ def enhance_image3(image: bytes) -> bytes:
|
|||
return enhanced_webp.tobytes()
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) -> T: # noqa: ANN401
|
||||
"""Run a blocking function in a separate thread.
|
||||
|
||||
Args:
|
||||
func (Callable[..., T]): The blocking function to run.
|
||||
*args (tuple[Any, ...]): Positional arguments to pass to the function.
|
||||
**kwargs (dict[str, Any]): Keyword arguments to pass to the function.
|
||||
|
||||
Returns:
|
||||
T: The result of the function.
|
||||
"""
|
||||
return await asyncio.to_thread(func, *args, **kwargs)
|
||||
|
||||
|
||||
@client.tree.context_menu(name="Enhance Image")
|
||||
@app_commands.allowed_installs(guilds=True, users=True)
|
||||
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
||||
|
|
@ -329,82 +348,39 @@ async def enhance_image_command(interaction: discord.Interaction, message: disco
|
|||
await interaction.response.defer()
|
||||
|
||||
# Check if message has attachments or embeds with images
|
||||
image_url: str | None = extract_image_url(message)
|
||||
if not image_url:
|
||||
await interaction.followup.send("No image found in the message.", ephemeral=True)
|
||||
images: list[bytes] = await get_raw_images_from_text(message.content)
|
||||
|
||||
# Also check attachments
|
||||
for attachment in message.attachments:
|
||||
if attachment.content_type and attachment.content_type.startswith("image/"):
|
||||
try:
|
||||
img_bytes: bytes = await attachment.read()
|
||||
images.append(img_bytes)
|
||||
except (TimeoutError, HTTPException, Forbidden, NotFound):
|
||||
logger.exception("Failed to read attachment %s", attachment.url)
|
||||
|
||||
if not images:
|
||||
await interaction.followup.send(f"No images found in the message: \n{message.content=}")
|
||||
return
|
||||
|
||||
try:
|
||||
# Download the image
|
||||
async with httpx.AsyncClient() as client:
|
||||
response: httpx.Response = await client.get(image_url)
|
||||
response.raise_for_status()
|
||||
image_bytes: bytes = response.content
|
||||
|
||||
for image in images:
|
||||
timestamp: str = datetime.datetime.now(tz=datetime.UTC).isoformat()
|
||||
|
||||
enhanced_image1: bytes = enhance_image1(image_bytes)
|
||||
enhanced_image1, enhanced_image2, enhanced_image3 = await asyncio.gather(
|
||||
run_in_thread(enhance_image1, image),
|
||||
run_in_thread(enhance_image2, image),
|
||||
run_in_thread(enhance_image3, image),
|
||||
)
|
||||
|
||||
# Prepare files
|
||||
file1 = discord.File(fp=io.BytesIO(enhanced_image1), filename=f"enhanced1-{timestamp}.webp")
|
||||
|
||||
enhanced_image2: bytes = enhance_image2(image_bytes)
|
||||
file2 = discord.File(fp=io.BytesIO(enhanced_image2), filename=f"enhanced2-{timestamp}.webp")
|
||||
|
||||
enhanced_image3: bytes = enhance_image3(image_bytes)
|
||||
file3 = discord.File(fp=io.BytesIO(enhanced_image3), filename=f"enhanced3-{timestamp}.webp")
|
||||
|
||||
files: list[discord.File] = [file1, file2, file3]
|
||||
logger.info("Enhanced image: %s", image_url)
|
||||
logger.info("Enhanced image files: %s", files)
|
||||
|
||||
await interaction.followup.send("Enhanced version:", files=files)
|
||||
|
||||
except (httpx.HTTPError, openai.OpenAIError) as e:
|
||||
logger.exception("Failed to enhance image")
|
||||
await interaction.followup.send(f"An error occurred: {e}")
|
||||
|
||||
|
||||
def extract_image_url(message: discord.Message) -> str | None:
|
||||
"""Extracts the first image URL from a given Discord message.
|
||||
|
||||
This function checks the attachments of the provided message for any image
|
||||
attachments. If none are found, it then examines the message embeds to see if
|
||||
they include an image. Finally, if no images are found in attachments or embeds,
|
||||
the function searches the message content for any direct links ending in
|
||||
common image file extensions (e.g., .png, .jpg, .jpeg, .gif, .webp).
|
||||
|
||||
Additionally, it handles Twitter image URLs and normalizes them to a standard format.
|
||||
|
||||
Args:
|
||||
message (discord.Message): The message from which to extract the image URL.
|
||||
|
||||
Returns:
|
||||
str | None: The URL of the first image found, or None if no image is found.
|
||||
"""
|
||||
image_url: str | None = None
|
||||
if message.attachments:
|
||||
for attachment in message.attachments:
|
||||
if attachment.content_type and attachment.content_type.startswith("image/"):
|
||||
image_url = attachment.url
|
||||
break
|
||||
|
||||
elif message.embeds:
|
||||
for embed in message.embeds:
|
||||
if embed.image:
|
||||
image_url = embed.image.url
|
||||
break
|
||||
|
||||
if not image_url:
|
||||
match: re.Match[str] | None = re.search(r"(https?://[^\s]+\.(png|jpg|jpeg|gif|webp)(\?[^\s]*)?)", message.content, re.IGNORECASE)
|
||||
if match:
|
||||
image_url = match.group(0)
|
||||
|
||||
# Handle Twitter image URLs
|
||||
if image_url and "pbs.twimg.com/media/" in image_url:
|
||||
# Normalize Twitter image URLs to the highest quality format
|
||||
image_url = re.sub(r"\?format=[^&]+&name=[^&]+", "?format=jpg&name=orig", image_url)
|
||||
|
||||
return image_url
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("Starting the bot.")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue