Refactor chat function to support asynchronous image processing and enhance image extraction from user messages
This commit is contained in:
parent
2aec54d51b
commit
08c286cff5
2 changed files with 148 additions and 84 deletions
122
main.py
122
main.py
|
|
@ -1,23 +1,25 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import io
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
import cv2
|
||||
import discord
|
||||
import httpx
|
||||
import numpy as np
|
||||
import openai
|
||||
import sentry_sdk
|
||||
from discord import app_commands
|
||||
from openai import OpenAI
|
||||
from discord import Forbidden, HTTPException, NotFound, app_commands
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from misc import add_message_to_memory, chat, get_allowed_users, should_respond_without_trigger, update_trigger_time
|
||||
from misc import add_message_to_memory, chat, get_allowed_users, get_raw_images_from_text, should_respond_without_trigger, update_trigger_time
|
||||
from settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn="https://ebbd2cdfbd08dba008d628dad7941091@o4505228040339456.ingest.us.sentry.io/4507630719401984",
|
||||
send_default_pii=True,
|
||||
|
|
@ -32,7 +34,7 @@ discord_token: str = settings.discord_token
|
|||
openai_api_key: str = settings.openai_api_key
|
||||
|
||||
|
||||
openai_client = OpenAI(api_key=openai_api_key)
|
||||
openai_client = AsyncOpenAI(api_key=openai_api_key)
|
||||
|
||||
|
||||
class LoviBotClient(discord.Client):
|
||||
|
|
@ -92,7 +94,7 @@ class LoviBotClient(discord.Client):
|
|||
|
||||
async with message.channel.typing():
|
||||
try:
|
||||
response: str | None = chat(
|
||||
response: str | None = await chat(
|
||||
user_message=incoming_message,
|
||||
openai_client=openai_client,
|
||||
current_channel=message.channel,
|
||||
|
|
@ -185,7 +187,7 @@ async def ask(interaction: discord.Interaction, text: str) -> None:
|
|||
return
|
||||
|
||||
try:
|
||||
response: str | None = chat(
|
||||
response: str | None = await chat(
|
||||
user_message=text,
|
||||
openai_client=openai_client,
|
||||
current_channel=interaction.channel,
|
||||
|
|
@ -321,6 +323,23 @@ def enhance_image3(image: bytes) -> bytes:
|
|||
return enhanced_webp.tobytes()
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) -> T: # noqa: ANN401
|
||||
"""Run a blocking function in a separate thread.
|
||||
|
||||
Args:
|
||||
func (Callable[..., T]): The blocking function to run.
|
||||
*args (tuple[Any, ...]): Positional arguments to pass to the function.
|
||||
**kwargs (dict[str, Any]): Keyword arguments to pass to the function.
|
||||
|
||||
Returns:
|
||||
T: The result of the function.
|
||||
"""
|
||||
return await asyncio.to_thread(func, *args, **kwargs)
|
||||
|
||||
|
||||
@client.tree.context_menu(name="Enhance Image")
|
||||
@app_commands.allowed_installs(guilds=True, users=True)
|
||||
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
||||
|
|
@ -329,82 +348,39 @@ async def enhance_image_command(interaction: discord.Interaction, message: disco
|
|||
await interaction.response.defer()
|
||||
|
||||
# Check if message has attachments or embeds with images
|
||||
image_url: str | None = extract_image_url(message)
|
||||
if not image_url:
|
||||
await interaction.followup.send("No image found in the message.", ephemeral=True)
|
||||
images: list[bytes] = await get_raw_images_from_text(message.content)
|
||||
|
||||
# Also check attachments
|
||||
for attachment in message.attachments:
|
||||
if attachment.content_type and attachment.content_type.startswith("image/"):
|
||||
try:
|
||||
img_bytes: bytes = await attachment.read()
|
||||
images.append(img_bytes)
|
||||
except (TimeoutError, HTTPException, Forbidden, NotFound):
|
||||
logger.exception("Failed to read attachment %s", attachment.url)
|
||||
|
||||
if not images:
|
||||
await interaction.followup.send(f"No images found in the message: \n{message.content=}")
|
||||
return
|
||||
|
||||
try:
|
||||
# Download the image
|
||||
async with httpx.AsyncClient() as client:
|
||||
response: httpx.Response = await client.get(image_url)
|
||||
response.raise_for_status()
|
||||
image_bytes: bytes = response.content
|
||||
|
||||
for image in images:
|
||||
timestamp: str = datetime.datetime.now(tz=datetime.UTC).isoformat()
|
||||
|
||||
enhanced_image1: bytes = enhance_image1(image_bytes)
|
||||
enhanced_image1, enhanced_image2, enhanced_image3 = await asyncio.gather(
|
||||
run_in_thread(enhance_image1, image),
|
||||
run_in_thread(enhance_image2, image),
|
||||
run_in_thread(enhance_image3, image),
|
||||
)
|
||||
|
||||
# Prepare files
|
||||
file1 = discord.File(fp=io.BytesIO(enhanced_image1), filename=f"enhanced1-{timestamp}.webp")
|
||||
|
||||
enhanced_image2: bytes = enhance_image2(image_bytes)
|
||||
file2 = discord.File(fp=io.BytesIO(enhanced_image2), filename=f"enhanced2-{timestamp}.webp")
|
||||
|
||||
enhanced_image3: bytes = enhance_image3(image_bytes)
|
||||
file3 = discord.File(fp=io.BytesIO(enhanced_image3), filename=f"enhanced3-{timestamp}.webp")
|
||||
|
||||
files: list[discord.File] = [file1, file2, file3]
|
||||
logger.info("Enhanced image: %s", image_url)
|
||||
logger.info("Enhanced image files: %s", files)
|
||||
|
||||
await interaction.followup.send("Enhanced version:", files=files)
|
||||
|
||||
except (httpx.HTTPError, openai.OpenAIError) as e:
|
||||
logger.exception("Failed to enhance image")
|
||||
await interaction.followup.send(f"An error occurred: {e}")
|
||||
|
||||
|
||||
def extract_image_url(message: discord.Message) -> str | None:
    """Extract the first image URL from a given Discord message.

    Sources are tried in order, and the first one that yields a URL wins:

    1. Attachments whose content type starts with ``image/``.
    2. Embeds that carry an image with a URL.
    3. A direct link in the message content ending in a common image file
       extension (.png, .jpg, .jpeg, .gif, .webp), optionally followed by a
       query string.

    Twitter media URLs (``pbs.twimg.com/media/``) have their
    ``?format=...&name=...`` query rewritten to ``?format=jpg&name=orig``.

    Args:
        message: The message from which to extract the image URL.

    Returns:
        The URL of the first image found, or None if no image is found.
    """
    image_url: str | None = next(
        (
            attachment.url
            for attachment in message.attachments
            if attachment.content_type and attachment.content_type.startswith("image/")
        ),
        None,
    )

    # Embeds are consulted whenever no *image* attachment matched. Previously a
    # message with only non-image attachments skipped its embeds entirely.
    if not image_url:
        image_url = next(
            (embed.image.url for embed in message.embeds if embed.image and embed.image.url),
            None,
        )

    # Last resort: a bare image link in the message text.
    if not image_url:
        match: re.Match[str] | None = re.search(
            r"(https?://[^\s]+\.(png|jpg|jpeg|gif|webp)(\?[^\s]*)?)",
            message.content,
            re.IGNORECASE,
        )
        if match:
            image_url = match.group(0)

    # Normalize Twitter image URLs to the highest quality format
    if image_url and "pbs.twimg.com/media/" in image_url:
        image_url = re.sub(r"\?format=[^&]+&name=[^&]+", "?format=jpg&name=orig", image_url)

    return image_url
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("Starting the bot.")
|
||||
|
|
|
|||
114
misc.py
114
misc.py
|
|
@ -2,11 +2,22 @@ from __future__ import annotations
|
|||
|
||||
import datetime
|
||||
import logging
|
||||
import re
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
import psutil
|
||||
from discord import Emoji, Member, User, channel
|
||||
from openai.types.chat import (
|
||||
ChatCompletion,
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
|
@ -14,8 +25,8 @@ if TYPE_CHECKING:
|
|||
from discord.abc import MessageableChannel
|
||||
from discord.guild import GuildChannel
|
||||
from discord.interactions import InteractionChannel
|
||||
from openai import OpenAI
|
||||
from openai.types.responses import Response
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
|
@ -140,6 +151,14 @@ def extra_context(current_channel: MessageableChannel | InteractionChannel | Non
|
|||
context += "KillYoy's Server Information:\n"
|
||||
context += "- Server is for friends to hang out and chat.\n"
|
||||
context += "- Server was created by KillYoy (<@98468214824001536>)\n"
|
||||
context += "Users in the server are from Sweden. They speak Swedish and English.\n"
|
||||
context += "Everyone in the server is over 18 years old so you can talk about adult topics.\n"
|
||||
context += "You should adapt your language to be suitable for an adult audience. For example, use swear words."
|
||||
context += "Users in this server are members of the fictional space communist political party called the 'Swedish Space Communist Party'.\n"
|
||||
context += "Try to align your responses with the values of this party, which include fast computer programs without bloat, HTML instead of JavaScript, and a focus on open source software and cooperation.\n" # noqa: E501
|
||||
context += "We don't like bloated software.\n"
|
||||
|
||||
context += "The bot is called 'LoviBot'.\n"
|
||||
|
||||
# Current date and time
|
||||
context += f"Current date and time: {datetime.datetime.now(tz=datetime.UTC)} UTC, but user is in CEST or CET\n"
|
||||
|
|
@ -188,9 +207,9 @@ def extra_context(current_channel: MessageableChannel | InteractionChannel | Non
|
|||
return context
|
||||
|
||||
|
||||
def chat( # noqa: PLR0913, PLR0917
|
||||
async def chat( # noqa: PLR0913, PLR0917
|
||||
user_message: str,
|
||||
openai_client: OpenAI,
|
||||
openai_client: AsyncOpenAI,
|
||||
current_channel: MessageableChannel | InteractionChannel | None,
|
||||
user: User | Member,
|
||||
allowed_users: list[str],
|
||||
|
|
@ -276,12 +295,81 @@ def chat( # noqa: PLR0913, PLR0917
|
|||
|
||||
logger.info("Sending request to OpenAI API with prompt: %s", prompt)
|
||||
|
||||
resp: Response = openai_client.responses.create(
|
||||
model="gpt-5-chat-latest",
|
||||
instructions=prompt,
|
||||
input=user_message,
|
||||
)
|
||||
response: str | None = resp.output_text
|
||||
logger.info("AI response: %s", response)
|
||||
# Always include text first
|
||||
user_content: list[ChatCompletionContentPartParam] = [
|
||||
ChatCompletionContentPartTextParam(type="text", text=user_message),
|
||||
]
|
||||
|
||||
return response
|
||||
# Add images if found
|
||||
image_urls = await get_images_from_text(user_message)
|
||||
user_content.extend(
|
||||
ChatCompletionContentPartImageParam(
|
||||
type="image_url",
|
||||
image_url={"url": _img},
|
||||
)
|
||||
for _img in image_urls
|
||||
)
|
||||
|
||||
messages: list[ChatCompletionMessageParam] = [
|
||||
ChatCompletionSystemMessageParam(role="system", content=prompt),
|
||||
ChatCompletionUserMessageParam(role="user", content=user_content),
|
||||
]
|
||||
|
||||
resp: ChatCompletion = await openai_client.chat.completions.create(
|
||||
model="gpt-5-chat-latest",
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
return resp.choices[0].message.content if isinstance(resp.choices[0].message.content, str) else None
|
||||
|
||||
|
||||
# Matches http(s) URLs up to the next whitespace character.
_URL_PATTERN = re.compile(r"https?://[^\s]+")

# Characters that frequently trail a URL in chat text without being part of it,
# e.g. "<https://...>" (Discord embed suppression), "[x](https://...)" or a
# sentence-ending period.
_URL_TRAILING_CHARS = ">)]}.,;:!?'\""


def _extract_urls(text: str) -> list[str]:
    """Return the distinct http(s) URLs in ``text``, in order of first appearance."""
    candidates = (raw.rstrip(_URL_TRAILING_CHARS) for raw in _URL_PATTERN.findall(text))
    # dict.fromkeys de-duplicates while preserving order, so a URL pasted twice
    # is only fetched once.
    return list(dict.fromkeys(candidates))


async def _fetch_images(text: str, *, read_body: bool) -> list[tuple[str, bytes]]:
    """Request every URL in ``text`` and keep those that respond with an image.

    Args:
        text: The text to search for URLs.
        read_body: When True, download the body of each image response. When
            False, only the response headers are inspected and the returned
            bytes are empty.

    Returns:
        A list of ``(url, body)`` pairs, one per URL whose response is not an
        error and whose Content-Type starts with ``image/``.
    """
    urls: list[str] = _extract_urls(text)
    if not urls:
        # Nothing to fetch; avoid opening an HTTP client for plain text.
        return []

    found: list[tuple[str, bytes]] = []
    async with httpx.AsyncClient(timeout=5.0) as client:
        for url in urls:
            try:
                # Streaming lets us check the headers first, so bodies of
                # non-image responses are never downloaded.
                async with client.stream("GET", url) as response:
                    content_type: str = response.headers.get("content-type", "").lower()
                    if response.is_error or not content_type.startswith("image/"):
                        continue
                    body: bytes = await response.aread() if read_body else b""
            except (httpx.RequestError, httpx.InvalidURL) as e:
                # InvalidURL is not a RequestError; without catching it a
                # malformed URL in user text would propagate to the caller.
                logger.warning("GET request failed for URL %s: %s", url, e)
                continue
            found.append((url, body))

    return found


async def get_images_from_text(text: str) -> list[str]:
    """Extract all image URLs from text and return their URLs.

    Args:
        text: The text to search for URLs.

    Returns:
        A list of urls for each image found.
    """
    return [url for url, _ in await _fetch_images(text, read_body=False)]


async def get_raw_images_from_text(text: str) -> list[bytes]:
    """Extract all image URLs from text and return their bytes.

    Args:
        text: The text to search for URLs.

    Returns:
        A list of bytes for each image found.
    """
    return [body for _, body in await _fetch_images(text, read_body=True)]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue