Refactor chat function to support asynchronous image processing and enhance image extraction from user messages
This commit is contained in:
parent
2aec54d51b
commit
08c286cff5
2 changed files with 148 additions and 84 deletions
114
misc.py
114
misc.py
|
|
@ -2,11 +2,22 @@ from __future__ import annotations
|
|||
|
||||
import datetime
|
||||
import logging
|
||||
import re
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
import psutil
|
||||
from discord import Emoji, Member, User, channel
|
||||
from openai.types.chat import (
|
||||
ChatCompletion,
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
|
@ -14,8 +25,8 @@ if TYPE_CHECKING:
|
|||
from discord.abc import MessageableChannel
|
||||
from discord.guild import GuildChannel
|
||||
from discord.interactions import InteractionChannel
|
||||
from openai import OpenAI
|
||||
from openai.types.responses import Response
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
|
@ -140,6 +151,14 @@ def extra_context(current_channel: MessageableChannel | InteractionChannel | Non
|
|||
context += "KillYoy's Server Information:\n"
|
||||
context += "- Server is for friends to hang out and chat.\n"
|
||||
context += "- Server was created by KillYoy (<@98468214824001536>)\n"
|
||||
context += "Users in the server are from Sweden. They speak Swedish and English.\n"
|
||||
context += "Everyone in the server is over 18 years old so you can talk about adult topics.\n"
|
||||
context += "You should adapt your language to be suitable for an adult audience. For example, use swear words."
|
||||
context += "Users in this server are members of the fictional space communist political party called the 'Swedish Space Communist Party'.\n"
|
||||
context += "Try to align your responses with the values of this party, which include fast computer programs without bloat, HTML instead of JavaScript, and a focus on open source software and cooperation.\n" # noqa: E501
|
||||
context += "We don't like bloated software.\n"
|
||||
|
||||
context += "The bot is called 'LoviBot'.\n"
|
||||
|
||||
# Current date and time
|
||||
context += f"Current date and time: {datetime.datetime.now(tz=datetime.UTC)} UTC, but user is in CEST or CET\n"
|
||||
|
|
@ -188,9 +207,9 @@ def extra_context(current_channel: MessageableChannel | InteractionChannel | Non
|
|||
return context
|
||||
|
||||
|
||||
def chat( # noqa: PLR0913, PLR0917
|
||||
async def chat( # noqa: PLR0913, PLR0917
|
||||
user_message: str,
|
||||
openai_client: OpenAI,
|
||||
openai_client: AsyncOpenAI,
|
||||
current_channel: MessageableChannel | InteractionChannel | None,
|
||||
user: User | Member,
|
||||
allowed_users: list[str],
|
||||
|
|
@ -276,12 +295,81 @@ def chat( # noqa: PLR0913, PLR0917
|
|||
|
||||
logger.info("Sending request to OpenAI API with prompt: %s", prompt)
|
||||
|
||||
resp: Response = openai_client.responses.create(
|
||||
model="gpt-5-chat-latest",
|
||||
instructions=prompt,
|
||||
input=user_message,
|
||||
)
|
||||
response: str | None = resp.output_text
|
||||
logger.info("AI response: %s", response)
|
||||
# Always include text first
|
||||
user_content: list[ChatCompletionContentPartParam] = [
|
||||
ChatCompletionContentPartTextParam(type="text", text=user_message),
|
||||
]
|
||||
|
||||
return response
|
||||
# Add images if found
|
||||
image_urls = await get_images_from_text(user_message)
|
||||
user_content.extend(
|
||||
ChatCompletionContentPartImageParam(
|
||||
type="image_url",
|
||||
image_url={"url": _img},
|
||||
)
|
||||
for _img in image_urls
|
||||
)
|
||||
|
||||
messages: list[ChatCompletionMessageParam] = [
|
||||
ChatCompletionSystemMessageParam(role="system", content=prompt),
|
||||
ChatCompletionUserMessageParam(role="user", content=user_content),
|
||||
]
|
||||
|
||||
resp: ChatCompletion = await openai_client.chat.completions.create(
|
||||
model="gpt-5-chat-latest",
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
return resp.choices[0].message.content if isinstance(resp.choices[0].message.content, str) else None
|
||||
|
||||
|
||||
async def get_images_from_text(text: str) -> list[str]:
    """Extract all image URLs from text and return their URLs.

    Every ``http``/``https`` link found in ``text`` is requested, and the link
    is kept when the server answers without an error status and with an
    ``image/*`` content type. Links that cannot be requested are logged and
    skipped, so one bad link never prevents the others from being checked.

    Args:
        text: The text to search for URLs.

    Returns:
        A list of urls for each image found, in order of appearance.
    """
    # "<" and ">" are never valid raw URL characters; excluding them handles
    # Discord's "<https://...>" embed-suppression syntax.
    candidates: list[str] = []
    for match in re.findall(r"https?://[^\s<>]+", text):
        # Sentence punctuation glued to the end of a link is not part of it.
        url = match.rstrip(".,;:!?\"'")
        # Drop a closing parenthesis only when it cannot belong to the URL.
        if url.endswith(")") and "(" not in url:
            url = url[:-1]
        candidates.append(url)

    # Nothing to check: avoid creating an HTTP client at all.
    if not candidates:
        return []

    images: list[str] = []
    async with httpx.AsyncClient(timeout=5.0) as client:
        for url in candidates:
            try:
                # Stream the response so only the headers are read; the image
                # body itself is not needed to decide whether this is an image.
                async with client.stream("GET", url) as response:
                    is_image = not response.is_error and response.headers.get("content-type", "").startswith("image/")
            except (httpx.RequestError, httpx.InvalidURL) as e:
                # InvalidURL is not a RequestError subclass, so it has to be
                # listed explicitly or a malformed link would abort the caller.
                logger.warning("GET request failed for URL %s: %s", url, e)
                continue

            if is_image:
                images.append(url)

    return images
|
||||
|
||||
|
||||
async def get_raw_images_from_text(text: str) -> list[bytes]:
    """Extract all image URLs from text and return their bytes.

    Every ``http``/``https`` link found in ``text`` is downloaded, and the
    response body is kept when the server answers without an error status and
    with an ``image/*`` content type. Links that cannot be requested are
    logged and skipped, so one bad link never prevents the others from being
    downloaded.

    Args:
        text: The text to search for URLs.

    Returns:
        A list of bytes for each image found, in order of appearance.
    """
    # "<" and ">" are never valid raw URL characters; excluding them handles
    # Discord's "<https://...>" embed-suppression syntax.
    candidates: list[str] = []
    for match in re.findall(r"https?://[^\s<>]+", text):
        # Sentence punctuation glued to the end of a link is not part of it.
        url = match.rstrip(".,;:!?\"'")
        # Drop a closing parenthesis only when it cannot belong to the URL.
        if url.endswith(")") and "(" not in url:
            url = url[:-1]
        candidates.append(url)

    # Nothing to download: avoid creating an HTTP client at all.
    if not candidates:
        return []

    images: list[bytes] = []
    async with httpx.AsyncClient(timeout=5.0) as client:
        for url in candidates:
            try:
                response = await client.get(url)
            except (httpx.RequestError, httpx.InvalidURL) as e:
                # InvalidURL is not a RequestError subclass, so it has to be
                # listed explicitly or a malformed link would abort the caller.
                logger.warning("GET request failed for URL %s: %s", url, e)
                continue

            if not response.is_error and response.headers.get("content-type", "").startswith("image/"):
                images.append(response.content)

    return images
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue