diff --git a/.vscode/settings.json b/.vscode/settings.json index 1ad47b0..01e4dac 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -37,6 +37,7 @@ "testpaths", "thelovinator", "tobytes", + "twimg", "unsignedinteger" ] } \ No newline at end of file diff --git a/main.py b/main.py index ad8da7e..0ade936 100644 --- a/main.py +++ b/main.py @@ -15,7 +15,7 @@ import sentry_sdk from discord import app_commands from openai import OpenAI -from misc import chat, get_allowed_users +from misc import add_message_to_memory, chat, get_allowed_users from settings import Settings sentry_sdk.init( @@ -74,14 +74,17 @@ class LoviBotClient(discord.Client): logger.info("No message content found in the event: %s", message) return + # Add the message to memory + add_message_to_memory(str(message.channel.id), message.author.name, incoming_message) + lowercase_message: str = incoming_message.lower() if incoming_message else "" - trigger_keywords: list[str] = ["lovibot", "<@345000831499894795>"] + trigger_keywords: list[str] = ["lovibot", "@lovibot", "<@345000831499894795>", "grok", "@grok"] if any(trigger in lowercase_message for trigger in trigger_keywords): logger.info("Received message: %s from: %s", incoming_message, message.author.name) async with message.channel.typing(): try: - response: str | None = chat(incoming_message, openai_client) + response: str | None = chat(incoming_message, openai_client, str(message.channel.id)) except openai.OpenAIError as e: logger.exception("An error occurred while chatting with the AI model.") e.add_note(f"Message: {incoming_message}\nEvent: {message}\nWho: {message.author.name}") @@ -167,7 +170,7 @@ async def ask(interaction: discord.Interaction, text: str) -> None: return try: - response: str | None = chat(text, openai_client) + response: str | None = chat(text, openai_client, str(interaction.channel_id)) except openai.OpenAIError as e: logger.exception("An error occurred while chatting with the AI model.") await interaction.followup.send(f"An error occurred: {e}") @@ -343,6 +346,8 @@ def extract_image_url(message: discord.Message) -> str | None: the function searches the message content for any direct links ending in common image file extensions (e.g., .png, .jpg, .jpeg, .gif, .webp). + Additionally, it handles Twitter image URLs and normalizes them to a standard format. + Args: message (discord.Message): The message from which to extract the image URL. @@ -364,12 +369,16 @@ def extract_image_url(message: discord.Message) -> str | None: if not image_url: match: re.Match[str] | None = re.search( - pattern=r"(https?://[^\s]+(\.png|\.jpg|\.jpeg|\.gif|\.webp))", - string=message.content, - flags=re.IGNORECASE, + r"(https?://[^\s]+\.(png|jpg|jpeg|gif|webp)(\?[^\s]*)?)", message.content, re.IGNORECASE ) if match: image_url = match.group(0) + + # Handle Twitter image URLs + if image_url and "pbs.twimg.com/media/" in image_url: + # Normalize Twitter image URLs to the highest quality format + image_url = re.sub(r"\?format=[^&]+&name=[^&]+", "?format=jpg&name=orig", image_url) + return image_url diff --git a/misc.py b/misc.py index 5e3f395..60c5a91 100644 --- a/misc.py +++ b/misc.py @@ -1,6 +1,8 @@ from __future__ import annotations +import datetime import logging +from collections import deque from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -10,6 +12,9 @@ if TYPE_CHECKING: logger: logging.Logger = logging.getLogger(__name__) +# A dictionary to store recent messages per channel with a maximum length per channel +recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {} + def get_allowed_users() -> list[str]: """Get the list of allowed users to interact with the bot. @@ -27,25 +32,65 @@ def get_allowed_users() -> list[str]: ] -def chat(user_message: str, openai_client: OpenAI) -> str | None: +def add_message_to_memory(channel_id: str, user: str, message: str) -> None: + """Add a message to the memory for a specific channel. + + Args: + channel_id: The ID of the channel where the message was sent. + user: The user who sent the message. + message: The content of the message. + """ + if channel_id not in recent_messages: + recent_messages[channel_id] = deque(maxlen=50) + + timestamp: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) + recent_messages[channel_id].append((user, message, timestamp)) + + logger.info("Added message to memory: %s from %s in channel %s", message, user, channel_id) + + +def get_recent_messages(channel_id: str, threshold_minutes: int = 10) -> list[tuple[str, str]]: + """Retrieve messages from the last `threshold_minutes` minutes for a specific channel. + + Args: + channel_id: The ID of the channel to retrieve messages for. + threshold_minutes: The number of minutes to consider messages as recent. + + Returns: + A list of tuples containing user and message content. + """ + if channel_id not in recent_messages: + return [] + + threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta( + minutes=threshold_minutes + ) + return [(user, message) for user, message, timestamp in recent_messages[channel_id] if timestamp > threshold] + + +def chat(user_message: str, openai_client: OpenAI, channel_id: str) -> str | None: """Chat with the bot using the OpenAI API. Args: user_message: The message to send to OpenAI. openai_client: The OpenAI client to use. + channel_id: The ID of the channel where the conversation is happening. Returns: The response from the AI model. """ + # Include recent messages in the prompt + recent_context: str = "\n".join([f"{user}: {message}" for user, message in get_recent_messages(channel_id)]) + prompt: str = ( + "You are in a Discord group chat. People can ask you questions. " + "Use Discord Markdown to format messages if needed.\n" + f"Recent context:\n{recent_context}\n" + f"User: {user_message}" + ) + completion: ChatCompletion = openai_client.chat.completions.create( model="gpt-5-chat-latest", - messages=[ - { - "role": "system", - "content": "You are in a Discord group chat. People can ask you questions. Use Discord Markdown to format messages if needed.", # noqa: E501 - }, - {"role": "user", "content": user_message}, - ], + messages=[{"role": "system", "content": prompt}], ) response: str | None = completion.choices[0].message.content logger.info("AI response: %s", response)