Refactor chat function to support asynchronous image processing and enhance image extraction from user messages

This commit is contained in:
Joakim Hellsén 2025-09-06 01:58:19 +02:00
commit 08c286cff5
2 changed files with 148 additions and 84 deletions

122
main.py
View file

@ -1,23 +1,25 @@
from __future__ import annotations
import asyncio
import datetime
import io
import logging
import re
from typing import Any
from typing import TYPE_CHECKING, Any, TypeVar
import cv2
import discord
import httpx
import numpy as np
import openai
import sentry_sdk
from discord import app_commands
from openai import OpenAI
from discord import Forbidden, HTTPException, NotFound, app_commands
from openai import AsyncOpenAI
from misc import add_message_to_memory, chat, get_allowed_users, should_respond_without_trigger, update_trigger_time
from misc import add_message_to_memory, chat, get_allowed_users, get_raw_images_from_text, should_respond_without_trigger, update_trigger_time
from settings import Settings
if TYPE_CHECKING:
from collections.abc import Callable
sentry_sdk.init(
dsn="https://ebbd2cdfbd08dba008d628dad7941091@o4505228040339456.ingest.us.sentry.io/4507630719401984",
send_default_pii=True,
@ -32,7 +34,7 @@ discord_token: str = settings.discord_token
openai_api_key: str = settings.openai_api_key
openai_client = OpenAI(api_key=openai_api_key)
openai_client = AsyncOpenAI(api_key=openai_api_key)
class LoviBotClient(discord.Client):
@ -92,7 +94,7 @@ class LoviBotClient(discord.Client):
async with message.channel.typing():
try:
response: str | None = chat(
response: str | None = await chat(
user_message=incoming_message,
openai_client=openai_client,
current_channel=message.channel,
@ -185,7 +187,7 @@ async def ask(interaction: discord.Interaction, text: str) -> None:
return
try:
response: str | None = chat(
response: str | None = await chat(
user_message=text,
openai_client=openai_client,
current_channel=interaction.channel,
@ -321,6 +323,23 @@ def enhance_image3(image: bytes) -> bytes:
return enhanced_webp.tobytes()
T = TypeVar("T")
async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) -> T: # noqa: ANN401
"""Run a blocking function in a separate thread.
Args:
func (Callable[..., T]): The blocking function to run.
*args (tuple[Any, ...]): Positional arguments to pass to the function.
**kwargs (dict[str, Any]): Keyword arguments to pass to the function.
Returns:
T: The result of the function.
"""
return await asyncio.to_thread(func, *args, **kwargs)
@client.tree.context_menu(name="Enhance Image")
@app_commands.allowed_installs(guilds=True, users=True)
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
@ -329,82 +348,39 @@ async def enhance_image_command(interaction: discord.Interaction, message: disco
await interaction.response.defer()
# Check if message has attachments or embeds with images
image_url: str | None = extract_image_url(message)
if not image_url:
await interaction.followup.send("No image found in the message.", ephemeral=True)
images: list[bytes] = await get_raw_images_from_text(message.content)
# Also check attachments
for attachment in message.attachments:
if attachment.content_type and attachment.content_type.startswith("image/"):
try:
img_bytes: bytes = await attachment.read()
images.append(img_bytes)
except (TimeoutError, HTTPException, Forbidden, NotFound):
logger.exception("Failed to read attachment %s", attachment.url)
if not images:
await interaction.followup.send(f"No images found in the message: \n{message.content=}")
return
try:
# Download the image
async with httpx.AsyncClient() as client:
response: httpx.Response = await client.get(image_url)
response.raise_for_status()
image_bytes: bytes = response.content
for image in images:
timestamp: str = datetime.datetime.now(tz=datetime.UTC).isoformat()
enhanced_image1: bytes = enhance_image1(image_bytes)
enhanced_image1, enhanced_image2, enhanced_image3 = await asyncio.gather(
run_in_thread(enhance_image1, image),
run_in_thread(enhance_image2, image),
run_in_thread(enhance_image3, image),
)
# Prepare files
file1 = discord.File(fp=io.BytesIO(enhanced_image1), filename=f"enhanced1-{timestamp}.webp")
enhanced_image2: bytes = enhance_image2(image_bytes)
file2 = discord.File(fp=io.BytesIO(enhanced_image2), filename=f"enhanced2-{timestamp}.webp")
enhanced_image3: bytes = enhance_image3(image_bytes)
file3 = discord.File(fp=io.BytesIO(enhanced_image3), filename=f"enhanced3-{timestamp}.webp")
files: list[discord.File] = [file1, file2, file3]
logger.info("Enhanced image: %s", image_url)
logger.info("Enhanced image files: %s", files)
await interaction.followup.send("Enhanced version:", files=files)
except (httpx.HTTPError, openai.OpenAIError) as e:
logger.exception("Failed to enhance image")
await interaction.followup.send(f"An error occurred: {e}")
def extract_image_url(message: discord.Message) -> str | None:
    """Extract the first image URL from a given Discord message.

    Sources are tried in order, and the first one that yields a URL wins:

    1. Attachments whose content type starts with ``image/``.
    2. Embeds that carry an image with a URL.
    3. Direct links in the message content ending in a common image file
       extension (.png, .jpg, .jpeg, .gif, .webp), optionally followed by a
       query string.

    Each source is consulted whenever the previous ones produced nothing, so a
    message with only non-image attachments (e.g. a PDF) still has its embeds
    and content searched.

    Twitter image URLs (``pbs.twimg.com/media/``) are normalized to request the
    original-size JPEG.

    Args:
        message (discord.Message): The message from which to extract the image URL.

    Returns:
        str | None: The URL of the first image found, or None if no image is found.
    """
    # 1. Image attachments.
    image_url: str | None = next(
        (
            attachment.url
            for attachment in message.attachments
            if attachment.content_type and attachment.content_type.startswith("image/")
        ),
        None,
    )

    # 2. Embed images. Skip embeds whose image has no URL instead of stopping on them.
    if not image_url:
        image_url = next(
            (embed.image.url for embed in message.embeds if embed.image and embed.image.url),
            None,
        )

    # 3. Direct image links in the message text.
    if not image_url:
        match: re.Match[str] | None = re.search(
            r"(https?://[^\s]+\.(png|jpg|jpeg|gif|webp)(\?[^\s]*)?)",
            message.content,
            re.IGNORECASE,
        )
        if match:
            image_url = match.group(0)

    # Handle Twitter image URLs: normalize to the highest quality format.
    if image_url and "pbs.twimg.com/media/" in image_url:
        image_url = re.sub(r"\?format=[^&]+&name=[^&]+", "?format=jpg&name=orig", image_url)

    return image_url
if __name__ == "__main__":
logger.info("Starting the bot.")

114
misc.py
View file

@ -2,11 +2,22 @@ from __future__ import annotations
import datetime
import logging
import re
from collections import deque
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
import httpx
import psutil
from discord import Emoji, Member, User, channel
from openai.types.chat import (
ChatCompletion,
ChatCompletionContentPartImageParam,
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
)
if TYPE_CHECKING:
from collections.abc import Sequence
@ -14,8 +25,8 @@ if TYPE_CHECKING:
from discord.abc import MessageableChannel
from discord.guild import GuildChannel
from discord.interactions import InteractionChannel
from openai import OpenAI
from openai.types.responses import Response
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
logger: logging.Logger = logging.getLogger(__name__)
@ -140,6 +151,14 @@ def extra_context(current_channel: MessageableChannel | InteractionChannel | Non
context += "KillYoy's Server Information:\n"
context += "- Server is for friends to hang out and chat.\n"
context += "- Server was created by KillYoy (<@98468214824001536>)\n"
context += "Users in the server are from Sweden. They speak Swedish and English.\n"
context += "Everyone in the server is over 18 years old so you can talk about adult topics.\n"
context += "You should adapt your language to be suitable for an adult audience. For example, use swear words."
context += "Users in this server are members of the fictional space communist political party called the 'Swedish Space Communist Party'.\n"
context += "Try to align your responses with the values of this party, which include fast computer programs without bloat, HTML instead of JavaScript, and a focus on open source software and cooperation.\n" # noqa: E501
context += "We don't like bloated software.\n"
context += "The bot is called 'LoviBot'.\n"
# Current date and time
context += f"Current date and time: {datetime.datetime.now(tz=datetime.UTC)} UTC, but user is in CEST or CET\n"
@ -188,9 +207,9 @@ def extra_context(current_channel: MessageableChannel | InteractionChannel | Non
return context
def chat( # noqa: PLR0913, PLR0917
async def chat( # noqa: PLR0913, PLR0917
user_message: str,
openai_client: OpenAI,
openai_client: AsyncOpenAI,
current_channel: MessageableChannel | InteractionChannel | None,
user: User | Member,
allowed_users: list[str],
@ -276,12 +295,81 @@ def chat( # noqa: PLR0913, PLR0917
logger.info("Sending request to OpenAI API with prompt: %s", prompt)
resp: Response = openai_client.responses.create(
model="gpt-5-chat-latest",
instructions=prompt,
input=user_message,
)
response: str | None = resp.output_text
logger.info("AI response: %s", response)
# Always include text first
user_content: list[ChatCompletionContentPartParam] = [
ChatCompletionContentPartTextParam(type="text", text=user_message),
]
return response
# Add images if found
image_urls = await get_images_from_text(user_message)
user_content.extend(
ChatCompletionContentPartImageParam(
type="image_url",
image_url={"url": _img},
)
for _img in image_urls
)
messages: list[ChatCompletionMessageParam] = [
ChatCompletionSystemMessageParam(role="system", content=prompt),
ChatCompletionUserMessageParam(role="user", content=user_content),
]
resp: ChatCompletion = await openai_client.chat.completions.create(
model="gpt-5-chat-latest",
messages=messages,
)
return resp.choices[0].message.content if isinstance(resp.choices[0].message.content, str) else None
# Matches http(s) URLs. "<" and ">" are excluded so that Discord's
# "<https://...>" embed-suppression syntax does not leak a trailing ">" into the URL.
_URL_PATTERN: re.Pattern[str] = re.compile(r"https?://[^\s<>]+")


def _find_urls(text: str) -> list[str]:
    """Return the unique http(s) URLs found in text, in order of first appearance."""
    return list(dict.fromkeys(_URL_PATTERN.findall(text)))


async def _download_images(text: str) -> list[tuple[str, bytes]]:
    """Fetch every URL in text and keep the ones that respond with an image.

    A URL is kept when the GET response is not an error and its Content-Type
    header starts with ``image/``. URLs that fail to download, or that httpx
    rejects as malformed, are logged and skipped.

    Args:
        text: The text to search for URLs.

    Returns:
        A list of ``(url, body_bytes)`` pairs, one per image found.
    """
    urls: list[str] = _find_urls(text)
    if not urls:
        # Nothing to fetch, so avoid creating an HTTP client at all.
        return []

    found: list[tuple[str, bytes]] = []
    async with httpx.AsyncClient(timeout=5.0) as client:
        for url in urls:
            try:
                response: httpx.Response = await client.get(url)
            except (httpx.RequestError, httpx.InvalidURL) as e:
                # httpx.InvalidURL is not a RequestError subclass; user text can
                # contain malformed URLs and must not crash the caller.
                logger.warning("GET request failed for URL %s: %s", url, e)
                continue

            if not response.is_error and response.headers.get("content-type", "").startswith("image/"):
                found.append((url, response.content))

    return found


async def get_images_from_text(text: str) -> list[str]:
    """Extract all image URLs from text and return their URLs.

    Args:
        text: The text to search for URLs.

    Returns:
        A list of urls for each image found.
    """
    return [url for url, _ in await _download_images(text)]


async def get_raw_images_from_text(text: str) -> list[bytes]:
    """Extract all image URLs from text and return their bytes.

    Args:
        text: The text to search for URLs.

    Returns:
        A list of bytes for each image found.
    """
    return [body for _, body in await _download_images(text)]