From 08c286cff54cd1d60e08d300cc7167f90d738c6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Hells=C3=A9n?= Date: Sat, 6 Sep 2025 01:58:19 +0200 Subject: [PATCH] Refactor chat function to support asynchronous image processing and enhance image extraction from user messages --- main.py | 122 +++++++++++++++++++++++--------------------------------- misc.py | 114 ++++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 150 insertions(+), 86 deletions(-) diff --git a/main.py b/main.py index 473701f..136dce0 100644 --- a/main.py +++ b/main.py @@ -1,23 +1,25 @@ from __future__ import annotations +import asyncio import datetime import io import logging -import re -from typing import Any +from typing import TYPE_CHECKING, Any, TypeVar import cv2 import discord -import httpx import numpy as np import openai import sentry_sdk -from discord import app_commands -from openai import OpenAI +from discord import Forbidden, HTTPException, NotFound, app_commands +from openai import AsyncOpenAI -from misc import add_message_to_memory, chat, get_allowed_users, should_respond_without_trigger, update_trigger_time +from misc import add_message_to_memory, chat, get_allowed_users, get_raw_images_from_text, should_respond_without_trigger, update_trigger_time from settings import Settings +if TYPE_CHECKING: + from collections.abc import Callable + sentry_sdk.init( dsn="https://ebbd2cdfbd08dba008d628dad7941091@o4505228040339456.ingest.us.sentry.io/4507630719401984", send_default_pii=True, @@ -32,7 +34,7 @@ discord_token: str = settings.discord_token openai_api_key: str = settings.openai_api_key -openai_client = OpenAI(api_key=openai_api_key) +openai_client = AsyncOpenAI(api_key=openai_api_key) class LoviBotClient(discord.Client): @@ -92,7 +94,7 @@ class LoviBotClient(discord.Client): async with message.channel.typing(): try: - response: str | None = chat( + response: str | None = await chat( user_message=incoming_message, openai_client=openai_client, current_channel=message.channel, @@ -185,7 +187,7 @@ async def ask(interaction: discord.Interaction, text: str) -> None: return try: - response: str | None = chat( + response: str | None = await chat( user_message=text, openai_client=openai_client, current_channel=interaction.channel, @@ -321,6 +323,23 @@ def enhance_image3(image: bytes) -> bytes: return enhanced_webp.tobytes() +T = TypeVar("T") + + +async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) -> T: # noqa: ANN401 + """Run a blocking function in a separate thread. + + Args: + func (Callable[..., T]): The blocking function to run. + *args (tuple[Any, ...]): Positional arguments to pass to the function. + **kwargs (dict[str, Any]): Keyword arguments to pass to the function. + + Returns: + T: The result of the function. + """ + return await asyncio.to_thread(func, *args, **kwargs) + + @client.tree.context_menu(name="Enhance Image") @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @@ -329,82 +348,39 @@ async def enhance_image_command(interaction: discord.Interaction, message: disco await interaction.response.defer() # Check if message has attachments or embeds with images - image_url: str | None = extract_image_url(message) - if not image_url: - await interaction.followup.send("No image found in the message.", ephemeral=True) + images: list[bytes] = await get_raw_images_from_text(message.content) + + # Also check attachments + for attachment in message.attachments: + if attachment.content_type and attachment.content_type.startswith("image/"): + try: + img_bytes: bytes = await attachment.read() + images.append(img_bytes) + except (TimeoutError, HTTPException, Forbidden, NotFound): + logger.exception("Failed to read attachment %s", attachment.url) + + if not images: + await interaction.followup.send(f"No images found in the message: \n{message.content=}") return - try: - # Download the image - async with httpx.AsyncClient() as client: - response: httpx.Response = await client.get(image_url) - response.raise_for_status() - image_bytes: bytes = response.content - + for image in images: timestamp: str = datetime.datetime.now(tz=datetime.UTC).isoformat() - enhanced_image1: bytes = enhance_image1(image_bytes) + enhanced_image1, enhanced_image2, enhanced_image3 = await asyncio.gather( + run_in_thread(enhance_image1, image), + run_in_thread(enhance_image2, image), + run_in_thread(enhance_image3, image), + ) + + # Prepare files file1 = discord.File(fp=io.BytesIO(enhanced_image1), filename=f"enhanced1-{timestamp}.webp") - - enhanced_image2: bytes = enhance_image2(image_bytes) file2 = discord.File(fp=io.BytesIO(enhanced_image2), filename=f"enhanced2-{timestamp}.webp") - - enhanced_image3: bytes = enhance_image3(image_bytes) file3 = discord.File(fp=io.BytesIO(enhanced_image3), filename=f"enhanced3-{timestamp}.webp") files: list[discord.File] = [file1, file2, file3] - logger.info("Enhanced image: %s", image_url) - logger.info("Enhanced image files: %s", files) await interaction.followup.send("Enhanced version:", files=files) - except (httpx.HTTPError, openai.OpenAIError) as e: - logger.exception("Failed to enhance image") - await interaction.followup.send(f"An error occurred: {e}") - - -def extract_image_url(message: discord.Message) -> str | None: - """Extracts the first image URL from a given Discord message. - - This function checks the attachments of the provided message for any image - attachments. If none are found, it then examines the message embeds to see if - they include an image. Finally, if no images are found in attachments or embeds, - the function searches the message content for any direct links ending in - common image file extensions (e.g., .png, .jpg, .jpeg, .gif, .webp). - - Additionally, it handles Twitter image URLs and normalizes them to a standard format. - - Args: - message (discord.Message): The message from which to extract the image URL. - - Returns: - str | None: The URL of the first image found, or None if no image is found. - """ - image_url: str | None = None - if message.attachments: - for attachment in message.attachments: - if attachment.content_type and attachment.content_type.startswith("image/"): - image_url = attachment.url - break - - elif message.embeds: - for embed in message.embeds: - if embed.image: - image_url = embed.image.url - break - - if not image_url: - match: re.Match[str] | None = re.search(r"(https?://[^\s]+\.(png|jpg|jpeg|gif|webp)(\?[^\s]*)?)", message.content, re.IGNORECASE) - if match: - image_url = match.group(0) - - # Handle Twitter image URLs - if image_url and "pbs.twimg.com/media/" in image_url: - # Normalize Twitter image URLs to the highest quality format - image_url = re.sub(r"\?format=[^&]+&name=[^&]+", "?format=jpg&name=orig", image_url) - - return image_url - if __name__ == "__main__": logger.info("Starting the bot.") diff --git a/misc.py b/misc.py index 9d918a1..a25a82b 100644 --- a/misc.py +++ b/misc.py @@ -2,11 +2,22 @@ from __future__ import annotations import datetime import logging +import re from collections import deque -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any +import httpx import psutil from discord import Emoji, Member, User, channel +from openai.types.chat import ( + ChatCompletion, + ChatCompletionContentPartImageParam, + ChatCompletionContentPartParam, + ChatCompletionContentPartTextParam, + ChatCompletionMessageParam, + ChatCompletionSystemMessageParam, + ChatCompletionUserMessageParam, +) if TYPE_CHECKING: from collections.abc import Sequence @@ -14,8 +25,8 @@ if TYPE_CHECKING: from discord.abc import MessageableChannel from discord.guild import GuildChannel from discord.interactions import InteractionChannel - from openai import OpenAI - from openai.types.responses import Response + from openai import AsyncOpenAI + from openai.types.chat import ChatCompletionMessageParam logger: logging.Logger = logging.getLogger(__name__) @@ -140,6 +151,14 @@ def extra_context(current_channel: MessageableChannel | InteractionChannel | Non context += "KillYoy's Server Information:\n" context += "- Server is for friends to hang out and chat.\n" context += "- Server was created by KillYoy (<@98468214824001536>)\n" + context += "Users in the server are from Sweden. They speak Swedish and English.\n" + context += "Everyone in the server is over 18 years old so you can talk about adult topics.\n" + context += "You should adapt your language to be suitable for an adult audience. For example, use swear words." + context += "Users in this server are members of the fictional space communist political party called the 'Swedish Space Communist Party'.\n" + context += "Try to align your responses with the values of this party, which include fast computer programs without bloat, HTML instead of JavaScript, and a focus on open source software and cooperation.\n" # noqa: E501 + context += "We don't like bloated software.\n" + + context += "The bot is called 'LoviBot'.\n" # Current date and time context += f"Current date and time: {datetime.datetime.now(tz=datetime.UTC)} UTC, but user is in CEST or CET\n" @@ -188,9 +207,9 @@ def extra_context(current_channel: MessageableChannel | InteractionChannel | Non return context -def chat( # noqa: PLR0913, PLR0917 +async def chat( # noqa: PLR0913, PLR0917 user_message: str, - openai_client: OpenAI, + openai_client: AsyncOpenAI, current_channel: MessageableChannel | InteractionChannel | None, user: User | Member, allowed_users: list[str], @@ -276,12 +295,81 @@ def chat( # noqa: PLR0913, PLR0917 logger.info("Sending request to OpenAI API with prompt: %s", prompt) - resp: Response = openai_client.responses.create( - model="gpt-5-chat-latest", - instructions=prompt, - input=user_message, - ) - response: str | None = resp.output_text - logger.info("AI response: %s", response) + # Always include text first + user_content: list[ChatCompletionContentPartParam] = [ + ChatCompletionContentPartTextParam(type="text", text=user_message), + ] - return response + # Add images if found + image_urls = await get_images_from_text(user_message) + user_content.extend( + ChatCompletionContentPartImageParam( + type="image_url", + image_url={"url": _img}, + ) + for _img in image_urls + ) + + messages: list[ChatCompletionMessageParam] = [ + ChatCompletionSystemMessageParam(role="system", content=prompt), + ChatCompletionUserMessageParam(role="user", content=user_content), + ] + + resp: ChatCompletion = await openai_client.chat.completions.create( + model="gpt-5-chat-latest", + messages=messages, + ) + + return resp.choices[0].message.content if isinstance(resp.choices[0].message.content, str) else None + + +async def get_images_from_text(text: str) -> list[str]: + """Extract all image URLs from text and return their URLs. + + Args: + text: The text to search for URLs. + + Returns: + A list of urls for each image found. + """ + # Find all URLs in the text + url_pattern = r"https?://[^\s]+" + urls: list[Any] = re.findall(url_pattern, text) + + images: list[str] = [] + async with httpx.AsyncClient(timeout=5.0) as client: + for url in urls: + try: + response: httpx.Response = await client.get(url) + if not response.is_error and response.headers.get("content-type", "").startswith("image/"): + images.append(url) + except httpx.RequestError as e: + logger.warning("GET request failed for URL %s: %s", url, e) + + return images + + +async def get_raw_images_from_text(text: str) -> list[bytes]: + """Extract all image URLs from text and return their bytes. + + Args: + text: The text to search for URLs. + + Returns: + A list of bytes for each image found. + """ + # Find all URLs in the text + url_pattern = r"https?://[^\s]+" + urls: list[Any] = re.findall(url_pattern, text) + + images: list[bytes] = [] + async with httpx.AsyncClient(timeout=5.0) as client: + for url in urls: + try: + response: httpx.Response = await client.get(url) + if not response.is_error and response.headers.get("content-type", "").startswith("image/"): + images.append(response.content) + except httpx.RequestError as e: + logger.warning("GET request failed for URL %s: %s", url, e) + + return images