Refactor chat function to support asynchronous image processing and enhance image extraction from user messages

This commit is contained in:
Joakim Hellsén 2025-09-06 01:58:19 +02:00
commit 08c286cff5
2 changed files with 148 additions and 84 deletions

114
misc.py
View file

@ -2,11 +2,22 @@ from __future__ import annotations
import datetime
import logging
import re
from collections import deque
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
import httpx
import psutil
from discord import Emoji, Member, User, channel
from openai.types.chat import (
ChatCompletion,
ChatCompletionContentPartImageParam,
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
)
if TYPE_CHECKING:
from collections.abc import Sequence
@ -14,8 +25,8 @@ if TYPE_CHECKING:
from discord.abc import MessageableChannel
from discord.guild import GuildChannel
from discord.interactions import InteractionChannel
from openai import OpenAI
from openai.types.responses import Response
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
logger: logging.Logger = logging.getLogger(__name__)
@ -140,6 +151,14 @@ def extra_context(current_channel: MessageableChannel | InteractionChannel | Non
context += "KillYoy's Server Information:\n"
context += "- Server is for friends to hang out and chat.\n"
context += "- Server was created by KillYoy (<@98468214824001536>)\n"
context += "Users in the server are from Sweden. They speak Swedish and English.\n"
context += "Everyone in the server is over 18 years old so you can talk about adult topics.\n"
context += "You should adapt your language to be suitable for an adult audience. For example, use swear words."
context += "Users in this server are members of the fictional space communist political party called the 'Swedish Space Communist Party'.\n"
context += "Try to align your responses with the values of this party, which include fast computer programs without bloat, HTML instead of JavaScript, and a focus on open source software and cooperation.\n" # noqa: E501
context += "We don't like bloated software.\n"
context += "The bot is called 'LoviBot'.\n"
# Current date and time
context += f"Current date and time: {datetime.datetime.now(tz=datetime.UTC)} UTC, but user is in CEST or CET\n"
@ -188,9 +207,9 @@ def extra_context(current_channel: MessageableChannel | InteractionChannel | Non
return context
def chat( # noqa: PLR0913, PLR0917
async def chat( # noqa: PLR0913, PLR0917
user_message: str,
openai_client: OpenAI,
openai_client: AsyncOpenAI,
current_channel: MessageableChannel | InteractionChannel | None,
user: User | Member,
allowed_users: list[str],
@ -276,12 +295,81 @@ def chat( # noqa: PLR0913, PLR0917
logger.info("Sending request to OpenAI API with prompt: %s", prompt)
resp: Response = openai_client.responses.create(
model="gpt-5-chat-latest",
instructions=prompt,
input=user_message,
)
response: str | None = resp.output_text
logger.info("AI response: %s", response)
# Always include text first
user_content: list[ChatCompletionContentPartParam] = [
ChatCompletionContentPartTextParam(type="text", text=user_message),
]
return response
# Add images if found
image_urls = await get_images_from_text(user_message)
user_content.extend(
ChatCompletionContentPartImageParam(
type="image_url",
image_url={"url": _img},
)
for _img in image_urls
)
messages: list[ChatCompletionMessageParam] = [
ChatCompletionSystemMessageParam(role="system", content=prompt),
ChatCompletionUserMessageParam(role="user", content=user_content),
]
resp: ChatCompletion = await openai_client.chat.completions.create(
model="gpt-5-chat-latest",
messages=messages,
)
return resp.choices[0].message.content if isinstance(resp.choices[0].message.content, str) else None
async def get_images_from_text(text: str) -> list[str]:
"""Extract all image URLs from text and return their URLs.
Args:
text: The text to search for URLs.
Returns:
A list of urls for each image found.
"""
# Find all URLs in the text
url_pattern = r"https?://[^\s]+"
urls: list[Any] = re.findall(url_pattern, text)
images: list[str] = []
async with httpx.AsyncClient(timeout=5.0) as client:
for url in urls:
try:
response: httpx.Response = await client.get(url)
if not response.is_error and response.headers.get("content-type", "").startswith("image/"):
images.append(url)
except httpx.RequestError as e:
logger.warning("GET request failed for URL %s: %s", url, e)
return images
async def get_raw_images_from_text(text: str) -> list[bytes]:
"""Extract all image URLs from text and return their bytes.
Args:
text: The text to search for URLs.
Returns:
A list of bytes for each image found.
"""
# Find all URLs in the text
url_pattern = r"https?://[^\s]+"
urls: list[Any] = re.findall(url_pattern, text)
images: list[bytes] = []
async with httpx.AsyncClient(timeout=5.0) as client:
for url in urls:
try:
response: httpx.Response = await client.get(url)
if not response.is_error and response.headers.get("content-type", "").startswith("image/"):
images.append(response.content)
except httpx.RequestError as e:
logger.warning("GET request failed for URL %s: %s", url, e)
return images