Refactor chat function to support asynchronous image processing and enhance image extraction from user messages
This commit is contained in:
parent
2aec54d51b
commit
08c286cff5
2 changed files with 148 additions and 84 deletions
122
main.py
122
main.py
|
|
@ -1,23 +1,25 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import re
|
from typing import TYPE_CHECKING, Any, TypeVar
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import discord
|
import discord
|
||||||
import httpx
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import openai
|
import openai
|
||||||
import sentry_sdk
|
import sentry_sdk
|
||||||
from discord import app_commands
|
from discord import Forbidden, HTTPException, NotFound, app_commands
|
||||||
from openai import OpenAI
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
from misc import add_message_to_memory, chat, get_allowed_users, should_respond_without_trigger, update_trigger_time
|
from misc import add_message_to_memory, chat, get_allowed_users, get_raw_images_from_text, should_respond_without_trigger, update_trigger_time
|
||||||
from settings import Settings
|
from settings import Settings
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
sentry_sdk.init(
|
sentry_sdk.init(
|
||||||
dsn="https://ebbd2cdfbd08dba008d628dad7941091@o4505228040339456.ingest.us.sentry.io/4507630719401984",
|
dsn="https://ebbd2cdfbd08dba008d628dad7941091@o4505228040339456.ingest.us.sentry.io/4507630719401984",
|
||||||
send_default_pii=True,
|
send_default_pii=True,
|
||||||
|
|
@ -32,7 +34,7 @@ discord_token: str = settings.discord_token
|
||||||
openai_api_key: str = settings.openai_api_key
|
openai_api_key: str = settings.openai_api_key
|
||||||
|
|
||||||
|
|
||||||
openai_client = OpenAI(api_key=openai_api_key)
|
openai_client = AsyncOpenAI(api_key=openai_api_key)
|
||||||
|
|
||||||
|
|
||||||
class LoviBotClient(discord.Client):
|
class LoviBotClient(discord.Client):
|
||||||
|
|
@ -92,7 +94,7 @@ class LoviBotClient(discord.Client):
|
||||||
|
|
||||||
async with message.channel.typing():
|
async with message.channel.typing():
|
||||||
try:
|
try:
|
||||||
response: str | None = chat(
|
response: str | None = await chat(
|
||||||
user_message=incoming_message,
|
user_message=incoming_message,
|
||||||
openai_client=openai_client,
|
openai_client=openai_client,
|
||||||
current_channel=message.channel,
|
current_channel=message.channel,
|
||||||
|
|
@ -185,7 +187,7 @@ async def ask(interaction: discord.Interaction, text: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response: str | None = chat(
|
response: str | None = await chat(
|
||||||
user_message=text,
|
user_message=text,
|
||||||
openai_client=openai_client,
|
openai_client=openai_client,
|
||||||
current_channel=interaction.channel,
|
current_channel=interaction.channel,
|
||||||
|
|
@ -321,6 +323,23 @@ def enhance_image3(image: bytes) -> bytes:
|
||||||
return enhanced_webp.tobytes()
|
return enhanced_webp.tobytes()
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) -> T: # noqa: ANN401
|
||||||
|
"""Run a blocking function in a separate thread.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func (Callable[..., T]): The blocking function to run.
|
||||||
|
*args (tuple[Any, ...]): Positional arguments to pass to the function.
|
||||||
|
**kwargs (dict[str, Any]): Keyword arguments to pass to the function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
T: The result of the function.
|
||||||
|
"""
|
||||||
|
return await asyncio.to_thread(func, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@client.tree.context_menu(name="Enhance Image")
|
@client.tree.context_menu(name="Enhance Image")
|
||||||
@app_commands.allowed_installs(guilds=True, users=True)
|
@app_commands.allowed_installs(guilds=True, users=True)
|
||||||
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
||||||
|
|
@ -329,82 +348,39 @@ async def enhance_image_command(interaction: discord.Interaction, message: disco
|
||||||
await interaction.response.defer()
|
await interaction.response.defer()
|
||||||
|
|
||||||
# Check if message has attachments or embeds with images
|
# Check if message has attachments or embeds with images
|
||||||
image_url: str | None = extract_image_url(message)
|
images: list[bytes] = await get_raw_images_from_text(message.content)
|
||||||
if not image_url:
|
|
||||||
await interaction.followup.send("No image found in the message.", ephemeral=True)
|
# Also check attachments
|
||||||
|
for attachment in message.attachments:
|
||||||
|
if attachment.content_type and attachment.content_type.startswith("image/"):
|
||||||
|
try:
|
||||||
|
img_bytes: bytes = await attachment.read()
|
||||||
|
images.append(img_bytes)
|
||||||
|
except (TimeoutError, HTTPException, Forbidden, NotFound):
|
||||||
|
logger.exception("Failed to read attachment %s", attachment.url)
|
||||||
|
|
||||||
|
if not images:
|
||||||
|
await interaction.followup.send(f"No images found in the message: \n{message.content=}")
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
for image in images:
|
||||||
# Download the image
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response: httpx.Response = await client.get(image_url)
|
|
||||||
response.raise_for_status()
|
|
||||||
image_bytes: bytes = response.content
|
|
||||||
|
|
||||||
timestamp: str = datetime.datetime.now(tz=datetime.UTC).isoformat()
|
timestamp: str = datetime.datetime.now(tz=datetime.UTC).isoformat()
|
||||||
|
|
||||||
enhanced_image1: bytes = enhance_image1(image_bytes)
|
enhanced_image1, enhanced_image2, enhanced_image3 = await asyncio.gather(
|
||||||
|
run_in_thread(enhance_image1, image),
|
||||||
|
run_in_thread(enhance_image2, image),
|
||||||
|
run_in_thread(enhance_image3, image),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare files
|
||||||
file1 = discord.File(fp=io.BytesIO(enhanced_image1), filename=f"enhanced1-{timestamp}.webp")
|
file1 = discord.File(fp=io.BytesIO(enhanced_image1), filename=f"enhanced1-{timestamp}.webp")
|
||||||
|
|
||||||
enhanced_image2: bytes = enhance_image2(image_bytes)
|
|
||||||
file2 = discord.File(fp=io.BytesIO(enhanced_image2), filename=f"enhanced2-{timestamp}.webp")
|
file2 = discord.File(fp=io.BytesIO(enhanced_image2), filename=f"enhanced2-{timestamp}.webp")
|
||||||
|
|
||||||
enhanced_image3: bytes = enhance_image3(image_bytes)
|
|
||||||
file3 = discord.File(fp=io.BytesIO(enhanced_image3), filename=f"enhanced3-{timestamp}.webp")
|
file3 = discord.File(fp=io.BytesIO(enhanced_image3), filename=f"enhanced3-{timestamp}.webp")
|
||||||
|
|
||||||
files: list[discord.File] = [file1, file2, file3]
|
files: list[discord.File] = [file1, file2, file3]
|
||||||
logger.info("Enhanced image: %s", image_url)
|
|
||||||
logger.info("Enhanced image files: %s", files)
|
|
||||||
|
|
||||||
await interaction.followup.send("Enhanced version:", files=files)
|
await interaction.followup.send("Enhanced version:", files=files)
|
||||||
|
|
||||||
except (httpx.HTTPError, openai.OpenAIError) as e:
|
|
||||||
logger.exception("Failed to enhance image")
|
|
||||||
await interaction.followup.send(f"An error occurred: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def extract_image_url(message: discord.Message) -> str | None:
    """Extract the first image URL from a given Discord message.

    Sources are tried in order, each one only if the previous yielded nothing:

    1. Attachments whose content type starts with ``image/``.
    2. Embeds that carry an image URL.
    3. A direct link in the message content ending in a common image file
       extension (.png, .jpg, .jpeg, .gif, .webp), optionally followed by a
       query string.

    Embeds are consulted even when the message has attachments, so a message
    with only non-image attachments can still yield an embed image.

    Twitter image URLs (``pbs.twimg.com/media/``) have their ``format``/``name``
    query rewritten to ``?format=jpg&name=orig``.

    Args:
        message (discord.Message): The message from which to extract the image URL.

    Returns:
        str | None: The URL of the first image found, or None if no image is found.
    """
    # 1. First attachment that declares an image content type.
    image_url: str | None = next(
        (
            attachment.url
            for attachment in message.attachments
            if attachment.content_type and attachment.content_type.startswith("image/")
        ),
        None,
    )

    # 2. First embed with an image URL; embeds whose image has no URL are skipped.
    if not image_url:
        image_url = next(
            (embed.image.url for embed in message.embeds if embed.image and embed.image.url),
            None,
        )

    # 3. Direct image link in the message text.
    if not image_url:
        match: re.Match[str] | None = re.search(
            r"(https?://[^\s]+\.(png|jpg|jpeg|gif|webp)(\?[^\s]*)?)",
            message.content,
            re.IGNORECASE,
        )
        if match:
            image_url = match.group(0)

    # Rewrite the Twitter size/format query to request the original image.
    if image_url and "pbs.twimg.com/media/" in image_url:
        image_url = re.sub(r"\?format=[^&]+&name=[^&]+", "?format=jpg&name=orig", image_url)

    return image_url
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
logger.info("Starting the bot.")
|
logger.info("Starting the bot.")
|
||||||
|
|
|
||||||
114
misc.py
114
misc.py
|
|
@ -2,11 +2,22 @@ from __future__ import annotations
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
import psutil
|
import psutil
|
||||||
from discord import Emoji, Member, User, channel
|
from discord import Emoji, Member, User, channel
|
||||||
|
from openai.types.chat import (
|
||||||
|
ChatCompletion,
|
||||||
|
ChatCompletionContentPartImageParam,
|
||||||
|
ChatCompletionContentPartParam,
|
||||||
|
ChatCompletionContentPartTextParam,
|
||||||
|
ChatCompletionMessageParam,
|
||||||
|
ChatCompletionSystemMessageParam,
|
||||||
|
ChatCompletionUserMessageParam,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
@ -14,8 +25,8 @@ if TYPE_CHECKING:
|
||||||
from discord.abc import MessageableChannel
|
from discord.abc import MessageableChannel
|
||||||
from discord.guild import GuildChannel
|
from discord.guild import GuildChannel
|
||||||
from discord.interactions import InteractionChannel
|
from discord.interactions import InteractionChannel
|
||||||
from openai import OpenAI
|
from openai import AsyncOpenAI
|
||||||
from openai.types.responses import Response
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
|
|
||||||
|
|
||||||
logger: logging.Logger = logging.getLogger(__name__)
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
@ -140,6 +151,14 @@ def extra_context(current_channel: MessageableChannel | InteractionChannel | Non
|
||||||
context += "KillYoy's Server Information:\n"
|
context += "KillYoy's Server Information:\n"
|
||||||
context += "- Server is for friends to hang out and chat.\n"
|
context += "- Server is for friends to hang out and chat.\n"
|
||||||
context += "- Server was created by KillYoy (<@98468214824001536>)\n"
|
context += "- Server was created by KillYoy (<@98468214824001536>)\n"
|
||||||
|
context += "Users in the server are from Sweden. They speak Swedish and English.\n"
|
||||||
|
context += "Everyone in the server is over 18 years old so you can talk about adult topics.\n"
|
||||||
|
context += "You should adapt your language to be suitable for an adult audience. For example, use swear words."
|
||||||
|
context += "Users in this server are members of the fictional space communist political party called the 'Swedish Space Communist Party'.\n"
|
||||||
|
context += "Try to align your responses with the values of this party, which include fast computer programs without bloat, HTML instead of JavaScript, and a focus on open source software and cooperation.\n" # noqa: E501
|
||||||
|
context += "We don't like bloated software.\n"
|
||||||
|
|
||||||
|
context += "The bot is called 'LoviBot'.\n"
|
||||||
|
|
||||||
# Current date and time
|
# Current date and time
|
||||||
context += f"Current date and time: {datetime.datetime.now(tz=datetime.UTC)} UTC, but user is in CEST or CET\n"
|
context += f"Current date and time: {datetime.datetime.now(tz=datetime.UTC)} UTC, but user is in CEST or CET\n"
|
||||||
|
|
@ -188,9 +207,9 @@ def extra_context(current_channel: MessageableChannel | InteractionChannel | Non
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
def chat( # noqa: PLR0913, PLR0917
|
async def chat( # noqa: PLR0913, PLR0917
|
||||||
user_message: str,
|
user_message: str,
|
||||||
openai_client: OpenAI,
|
openai_client: AsyncOpenAI,
|
||||||
current_channel: MessageableChannel | InteractionChannel | None,
|
current_channel: MessageableChannel | InteractionChannel | None,
|
||||||
user: User | Member,
|
user: User | Member,
|
||||||
allowed_users: list[str],
|
allowed_users: list[str],
|
||||||
|
|
@ -276,12 +295,81 @@ def chat( # noqa: PLR0913, PLR0917
|
||||||
|
|
||||||
logger.info("Sending request to OpenAI API with prompt: %s", prompt)
|
logger.info("Sending request to OpenAI API with prompt: %s", prompt)
|
||||||
|
|
||||||
resp: Response = openai_client.responses.create(
|
# Always include text first
|
||||||
model="gpt-5-chat-latest",
|
user_content: list[ChatCompletionContentPartParam] = [
|
||||||
instructions=prompt,
|
ChatCompletionContentPartTextParam(type="text", text=user_message),
|
||||||
input=user_message,
|
]
|
||||||
)
|
|
||||||
response: str | None = resp.output_text
|
|
||||||
logger.info("AI response: %s", response)
|
|
||||||
|
|
||||||
return response
|
# Add images if found
|
||||||
|
image_urls = await get_images_from_text(user_message)
|
||||||
|
user_content.extend(
|
||||||
|
ChatCompletionContentPartImageParam(
|
||||||
|
type="image_url",
|
||||||
|
image_url={"url": _img},
|
||||||
|
)
|
||||||
|
for _img in image_urls
|
||||||
|
)
|
||||||
|
|
||||||
|
messages: list[ChatCompletionMessageParam] = [
|
||||||
|
ChatCompletionSystemMessageParam(role="system", content=prompt),
|
||||||
|
ChatCompletionUserMessageParam(role="user", content=user_content),
|
||||||
|
]
|
||||||
|
|
||||||
|
resp: ChatCompletion = await openai_client.chat.completions.create(
|
||||||
|
model="gpt-5-chat-latest",
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
return resp.choices[0].message.content if isinstance(resp.choices[0].message.content, str) else None
|
||||||
|
|
||||||
|
|
||||||
|
# Compiled once at import time; matches any run of non-whitespace after an http(s) scheme.
_URL_PATTERN: re.Pattern[str] = re.compile(r"https?://[^\s]+")

# Characters that commonly follow a URL in chat text (sentence punctuation, the
# closing ``>`` of a ``<url>`` wrapper, quotes) and are stripped from its end.
_URL_TRAILING_PUNCTUATION = ".,;:!?>\"'"


def _extract_urls(text: str) -> list[str]:
    """Return the distinct http(s) URLs in ``text``, in order of first appearance."""
    unique_urls: dict[str, None] = {}
    for candidate in _URL_PATTERN.findall(text):
        url: str = candidate.rstrip(_URL_TRAILING_PUNCTUATION)

        # Drop a closing parenthesis only while it is unbalanced, so "(https://a/b)"
        # is cleaned up but "https://a/Foo_(bar)" is kept intact.
        while url.endswith(")") and url.count(")") > url.count("("):
            url = url[:-1].rstrip(_URL_TRAILING_PUNCTUATION)

        # Skip candidates that are nothing but a scheme once punctuation is removed.
        if url.partition("://")[2]:
            unique_urls.setdefault(url, None)

    return list(unique_urls)


async def _download_images(text: str) -> list[tuple[str, bytes]]:
    """Fetch every URL in ``text`` and return ``(url, body)`` for those that serve an image."""
    urls: list[str] = _extract_urls(text)
    if not urls:
        # Nothing to fetch, so do not open an HTTP client at all.
        return []

    downloaded: list[tuple[str, bytes]] = []
    async with httpx.AsyncClient(timeout=5.0) as client:
        for url in urls:
            try:
                response: httpx.Response = await client.get(url)
            except (httpx.RequestError, httpx.InvalidURL) as e:
                # InvalidURL is not a RequestError subclass; catching both keeps one
                # malformed link from aborting the scan of the remaining URLs.
                logger.warning("GET request failed for URL %s: %s", url, e)
                continue

            # Media types are case-insensitive, so normalise before comparing.
            content_type: str = response.headers.get("content-type", "").lower()
            if not response.is_error and content_type.startswith("image/"):
                downloaded.append((url, response.content))

    return downloaded


async def get_images_from_text(text: str) -> list[str]:
    """Extract all image URLs from text and return their URLs.

    Each distinct URL is requested once with a 5 second timeout. A URL counts as
    an image when the response is not an HTTP error and its ``Content-Type``
    starts with ``image/``. Failed requests are logged and skipped.

    Args:
        text: The text to search for URLs.

    Returns:
        A list of urls for each image found.
    """
    return [url for url, _ in await _download_images(text)]


async def get_raw_images_from_text(text: str) -> list[bytes]:
    """Extract all image URLs from text and return their bytes.

    Each distinct URL is requested once with a 5 second timeout. A URL counts as
    an image when the response is not an HTTP error and its ``Content-Type``
    starts with ``image/``. Failed requests are logged and skipped.

    Args:
        text: The text to search for URLs.

    Returns:
        A list of bytes for each image found.
    """
    return [content for _, content in await _download_images(text)]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue