Refactor chat function to support asynchronous image processing and enhance image extraction from user messages

This commit is contained in:
Joakim Hellsén 2025-09-06 01:58:19 +02:00
commit 08c286cff5
2 changed files with 148 additions and 84 deletions

122
main.py
View file

@ -1,23 +1,25 @@
from __future__ import annotations
import asyncio
import datetime
import io
import logging
import re
from typing import Any
from typing import TYPE_CHECKING, Any, TypeVar
import cv2
import discord
import httpx
import numpy as np
import openai
import sentry_sdk
from discord import app_commands
from openai import OpenAI
from discord import Forbidden, HTTPException, NotFound, app_commands
from openai import AsyncOpenAI
from misc import add_message_to_memory, chat, get_allowed_users, should_respond_without_trigger, update_trigger_time
from misc import add_message_to_memory, chat, get_allowed_users, get_raw_images_from_text, should_respond_without_trigger, update_trigger_time
from settings import Settings
if TYPE_CHECKING:
from collections.abc import Callable
sentry_sdk.init(
dsn="https://ebbd2cdfbd08dba008d628dad7941091@o4505228040339456.ingest.us.sentry.io/4507630719401984",
send_default_pii=True,
@ -32,7 +34,7 @@ discord_token: str = settings.discord_token
openai_api_key: str = settings.openai_api_key
openai_client = OpenAI(api_key=openai_api_key)
openai_client = AsyncOpenAI(api_key=openai_api_key)
class LoviBotClient(discord.Client):
@ -92,7 +94,7 @@ class LoviBotClient(discord.Client):
async with message.channel.typing():
try:
response: str | None = chat(
response: str | None = await chat(
user_message=incoming_message,
openai_client=openai_client,
current_channel=message.channel,
@ -185,7 +187,7 @@ async def ask(interaction: discord.Interaction, text: str) -> None:
return
try:
response: str | None = chat(
response: str | None = await chat(
user_message=text,
openai_client=openai_client,
current_channel=interaction.channel,
@ -321,6 +323,23 @@ def enhance_image3(image: bytes) -> bytes:
return enhanced_webp.tobytes()
T = TypeVar("T")
async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) -> T: # noqa: ANN401
"""Run a blocking function in a separate thread.
Args:
func (Callable[..., T]): The blocking function to run.
*args (tuple[Any, ...]): Positional arguments to pass to the function.
**kwargs (dict[str, Any]): Keyword arguments to pass to the function.
Returns:
T: The result of the function.
"""
return await asyncio.to_thread(func, *args, **kwargs)
@client.tree.context_menu(name="Enhance Image")
@app_commands.allowed_installs(guilds=True, users=True)
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
@ -329,82 +348,39 @@ async def enhance_image_command(interaction: discord.Interaction, message: disco
await interaction.response.defer()
# Check if message has attachments or embeds with images
image_url: str | None = extract_image_url(message)
if not image_url:
await interaction.followup.send("No image found in the message.", ephemeral=True)
images: list[bytes] = await get_raw_images_from_text(message.content)
# Also check attachments
for attachment in message.attachments:
if attachment.content_type and attachment.content_type.startswith("image/"):
try:
img_bytes: bytes = await attachment.read()
images.append(img_bytes)
except (TimeoutError, HTTPException, Forbidden, NotFound):
logger.exception("Failed to read attachment %s", attachment.url)
if not images:
await interaction.followup.send(f"No images found in the message: \n{message.content=}")
return
try:
# Download the image
async with httpx.AsyncClient() as client:
response: httpx.Response = await client.get(image_url)
response.raise_for_status()
image_bytes: bytes = response.content
for image in images:
timestamp: str = datetime.datetime.now(tz=datetime.UTC).isoformat()
enhanced_image1: bytes = enhance_image1(image_bytes)
enhanced_image1, enhanced_image2, enhanced_image3 = await asyncio.gather(
run_in_thread(enhance_image1, image),
run_in_thread(enhance_image2, image),
run_in_thread(enhance_image3, image),
)
# Prepare files
file1 = discord.File(fp=io.BytesIO(enhanced_image1), filename=f"enhanced1-{timestamp}.webp")
enhanced_image2: bytes = enhance_image2(image_bytes)
file2 = discord.File(fp=io.BytesIO(enhanced_image2), filename=f"enhanced2-{timestamp}.webp")
enhanced_image3: bytes = enhance_image3(image_bytes)
file3 = discord.File(fp=io.BytesIO(enhanced_image3), filename=f"enhanced3-{timestamp}.webp")
files: list[discord.File] = [file1, file2, file3]
logger.info("Enhanced image: %s", image_url)
logger.info("Enhanced image files: %s", files)
await interaction.followup.send("Enhanced version:", files=files)
except (httpx.HTTPError, openai.OpenAIError) as e:
logger.exception("Failed to enhance image")
await interaction.followup.send(f"An error occurred: {e}")
def extract_image_url(message: discord.Message) -> str | None:
    """Extract the first image URL from a given Discord message.

    Sources are tried in order, each one only when the previous yielded nothing:

    1. Attachments whose content type starts with ``image/``.
    2. Embeds that carry an image with a URL.
    3. A direct link in the message content ending in a common image
       extension (png, jpg, jpeg, gif, webp), optionally followed by a
       query string.

    Twitter media URLs (``pbs.twimg.com/media/``) have their
    ``?format=...&name=...`` query rewritten to ``?format=jpg&name=orig``.

    Args:
        message: The message from which to extract the image URL.

    Returns:
        The URL of the first image found, or None if no image is found.
    """
    image_url: str | None = next(
        (
            attachment.url
            for attachment in message.attachments
            if attachment.content_type and attachment.content_type.startswith("image/")
        ),
        None,
    )

    # Embeds are a fallback for "no image attachment", not for "no attachments
    # at all": a message with only non-image attachments must still reach them.
    if not image_url:
        image_url = next(
            (embed.image.url for embed in message.embeds if embed.image and embed.image.url),
            None,
        )

    if not image_url:
        match: re.Match[str] | None = re.search(r"(https?://[^\s]+\.(png|jpg|jpeg|gif|webp)(\?[^\s]*)?)", message.content, re.IGNORECASE)
        if match:
            image_url = match.group(0)

    # Handle Twitter image URLs
    if image_url and "pbs.twimg.com/media/" in image_url:
        # Normalize Twitter image URLs to the highest quality format
        image_url = re.sub(r"\?format=[^&]+&name=[^&]+", "?format=jpg&name=orig", image_url)

    return image_url
if __name__ == "__main__":
logger.info("Starting the bot.")