Enhance chat functionality by adding message memory and context for improved responses

This commit is contained in:
Joakim Hellsén 2025-08-17 02:25:43 +02:00
commit 86cb28208d
3 changed files with 70 additions and 15 deletions

View file

@ -37,6 +37,7 @@
"testpaths",
"thelovinator",
"tobytes",
"twimg",
"unsignedinteger"
]
}

23
main.py
View file

@ -15,7 +15,7 @@ import sentry_sdk
from discord import app_commands
from openai import OpenAI
from misc import chat, get_allowed_users
from misc import add_message_to_memory, chat, get_allowed_users
from settings import Settings
sentry_sdk.init(
@ -74,14 +74,17 @@ class LoviBotClient(discord.Client):
logger.info("No message content found in the event: %s", message)
return
# Add the message to memory
add_message_to_memory(str(message.channel.id), message.author.name, incoming_message)
lowercase_message: str = incoming_message.lower() if incoming_message else ""
trigger_keywords: list[str] = ["lovibot", "<@345000831499894795>"]
trigger_keywords: list[str] = ["lovibot", "@lovibot", "<@345000831499894795>", "grok", "@grok"]
if any(trigger in lowercase_message for trigger in trigger_keywords):
logger.info("Received message: %s from: %s", incoming_message, message.author.name)
async with message.channel.typing():
try:
response: str | None = chat(incoming_message, openai_client)
response: str | None = chat(incoming_message, openai_client, str(message.channel.id))
except openai.OpenAIError as e:
logger.exception("An error occurred while chatting with the AI model.")
e.add_note(f"Message: {incoming_message}\nEvent: {message}\nWho: {message.author.name}")
@ -167,7 +170,7 @@ async def ask(interaction: discord.Interaction, text: str) -> None:
return
try:
response: str | None = chat(text, openai_client)
response: str | None = chat(text, openai_client, str(interaction.channel_id))
except openai.OpenAIError as e:
logger.exception("An error occurred while chatting with the AI model.")
await interaction.followup.send(f"An error occurred: {e}")
@ -343,6 +346,8 @@ def extract_image_url(message: discord.Message) -> str | None:
the function searches the message content for any direct links ending in
common image file extensions (e.g., .png, .jpg, .jpeg, .gif, .webp).
Additionally, it handles Twitter image URLs and normalizes them to a standard format.
Args:
message (discord.Message): The message from which to extract the image URL.
@ -364,12 +369,16 @@ def extract_image_url(message: discord.Message) -> str | None:
if not image_url:
match: re.Match[str] | None = re.search(
pattern=r"(https?://[^\s]+(\.png|\.jpg|\.jpeg|\.gif|\.webp))",
string=message.content,
flags=re.IGNORECASE,
r"(https?://[^\s]+\.(png|jpg|jpeg|gif|webp)(\?[^\s]*)?)", message.content, re.IGNORECASE
)
if match:
image_url = match.group(0)
# Handle Twitter image URLs
if image_url and "pbs.twimg.com/media/" in image_url:
# Normalize Twitter image URLs to the highest quality format
image_url = re.sub(r"\?format=[^&]+&name=[^&]+", "?format=jpg&name=orig", image_url)
return image_url

61
misc.py
View file

@ -1,6 +1,8 @@
from __future__ import annotations
import datetime
import logging
from collections import deque
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@ -10,6 +12,9 @@ if TYPE_CHECKING:
logger: logging.Logger = logging.getLogger(__name__)
# A dictionary to store recent messages per channel with a maximum length per channel
recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {}
def get_allowed_users() -> list[str]:
"""Get the list of allowed users to interact with the bot.
@ -27,25 +32,65 @@ def get_allowed_users() -> list[str]:
]
def chat(user_message: str, openai_client: OpenAI) -> str | None:
def add_message_to_memory(channel_id: str, user: str, message: str) -> None:
"""Add a message to the memory for a specific channel.
Args:
channel_id: The ID of the channel where the message was sent.
user: The user who sent the message.
message: The content of the message.
"""
if channel_id not in recent_messages:
recent_messages[channel_id] = deque(maxlen=50)
timestamp: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
recent_messages[channel_id].append((user, message, timestamp))
logger.info("Added message to memory: %s from %s in channel %s", message, user, channel_id)
def get_recent_messages(channel_id: str, threshold_minutes: int = 10) -> list[tuple[str, str]]:
"""Retrieve messages from the last `threshold_minutes` minutes for a specific channel.
Args:
channel_id: The ID of the channel to retrieve messages for.
threshold_minutes: The number of minutes to consider messages as recent.
Returns:
A list of tuples containing user and message content.
"""
if channel_id not in recent_messages:
return []
threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(
minutes=threshold_minutes
)
return [(user, message) for user, message, timestamp in recent_messages[channel_id] if timestamp > threshold]
def chat(user_message: str, openai_client: OpenAI, channel_id: str) -> str | None:
"""Chat with the bot using the OpenAI API.
Args:
user_message: The message to send to OpenAI.
openai_client: The OpenAI client to use.
channel_id: The ID of the channel where the conversation is happening.
Returns:
The response from the AI model.
"""
# Include recent messages in the prompt
recent_context: str = "\n".join([f"{user}: {message}" for user, message in get_recent_messages(channel_id)])
prompt: str = (
"You are in a Discord group chat. People can ask you questions. "
"Use Discord Markdown to format messages if needed.\n"
f"Recent context:\n{recent_context}\n"
f"User: {user_message}"
)
completion: ChatCompletion = openai_client.chat.completions.create(
model="gpt-5-chat-latest",
messages=[
{
"role": "system",
"content": "You are in a Discord group chat. People can ask you questions. Use Discord Markdown to format messages if needed.", # noqa: E501
},
{"role": "user", "content": user_message},
],
messages=[{"role": "system", "content": prompt}],
)
response: str | None = completion.choices[0].message.content
logger.info("AI response: %s", response)