Enhance chat functionality by adding message memory and context for improved responses
This commit is contained in:
parent
659fe3f13d
commit
86cb28208d
3 changed files with 70 additions and 15 deletions
1
.vscode/settings.json
vendored
1
.vscode/settings.json
vendored
|
|
@ -37,6 +37,7 @@
|
|||
"testpaths",
|
||||
"thelovinator",
|
||||
"tobytes",
|
||||
"twimg",
|
||||
"unsignedinteger"
|
||||
]
|
||||
}
|
||||
23
main.py
23
main.py
|
|
@ -15,7 +15,7 @@ import sentry_sdk
|
|||
from discord import app_commands
|
||||
from openai import OpenAI
|
||||
|
||||
from misc import chat, get_allowed_users
|
||||
from misc import add_message_to_memory, chat, get_allowed_users
|
||||
from settings import Settings
|
||||
|
||||
sentry_sdk.init(
|
||||
|
|
@ -74,14 +74,17 @@ class LoviBotClient(discord.Client):
|
|||
logger.info("No message content found in the event: %s", message)
|
||||
return
|
||||
|
||||
# Add the message to memory
|
||||
add_message_to_memory(str(message.channel.id), message.author.name, incoming_message)
|
||||
|
||||
lowercase_message: str = incoming_message.lower() if incoming_message else ""
|
||||
trigger_keywords: list[str] = ["lovibot", "<@345000831499894795>"]
|
||||
trigger_keywords: list[str] = ["lovibot", "@lovibot", "<@345000831499894795>", "grok", "@grok"]
|
||||
if any(trigger in lowercase_message for trigger in trigger_keywords):
|
||||
logger.info("Received message: %s from: %s", incoming_message, message.author.name)
|
||||
|
||||
async with message.channel.typing():
|
||||
try:
|
||||
response: str | None = chat(incoming_message, openai_client)
|
||||
response: str | None = chat(incoming_message, openai_client, str(message.channel.id))
|
||||
except openai.OpenAIError as e:
|
||||
logger.exception("An error occurred while chatting with the AI model.")
|
||||
e.add_note(f"Message: {incoming_message}\nEvent: {message}\nWho: {message.author.name}")
|
||||
|
|
@ -167,7 +170,7 @@ async def ask(interaction: discord.Interaction, text: str) -> None:
|
|||
return
|
||||
|
||||
try:
|
||||
response: str | None = chat(text, openai_client)
|
||||
response: str | None = chat(text, openai_client, str(interaction.channel_id))
|
||||
except openai.OpenAIError as e:
|
||||
logger.exception("An error occurred while chatting with the AI model.")
|
||||
await interaction.followup.send(f"An error occurred: {e}")
|
||||
|
|
@ -343,6 +346,8 @@ def extract_image_url(message: discord.Message) -> str | None:
|
|||
the function searches the message content for any direct links ending in
|
||||
common image file extensions (e.g., .png, .jpg, .jpeg, .gif, .webp).
|
||||
|
||||
Additionally, it handles Twitter image URLs and normalizes them to a standard format.
|
||||
|
||||
Args:
|
||||
message (discord.Message): The message from which to extract the image URL.
|
||||
|
||||
|
|
@ -364,12 +369,16 @@ def extract_image_url(message: discord.Message) -> str | None:
|
|||
|
||||
if not image_url:
|
||||
match: re.Match[str] | None = re.search(
|
||||
pattern=r"(https?://[^\s]+(\.png|\.jpg|\.jpeg|\.gif|\.webp))",
|
||||
string=message.content,
|
||||
flags=re.IGNORECASE,
|
||||
r"(https?://[^\s]+\.(png|jpg|jpeg|gif|webp)(\?[^\s]*)?)", message.content, re.IGNORECASE
|
||||
)
|
||||
if match:
|
||||
image_url = match.group(0)
|
||||
|
||||
# Handle Twitter image URLs
|
||||
if image_url and "pbs.twimg.com/media/" in image_url:
|
||||
# Normalize Twitter image URLs to the highest quality format
|
||||
image_url = re.sub(r"\?format=[^&]+&name=[^&]+", "?format=jpg&name=orig", image_url)
|
||||
|
||||
return image_url
|
||||
|
||||
|
||||
|
|
|
|||
61
misc.py
61
misc.py
|
|
@ -1,6 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -10,6 +12,9 @@ if TYPE_CHECKING:
|
|||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
# A dictionary to store recent messages per channel with a maximum length per channel
|
||||
recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {}
|
||||
|
||||
|
||||
def get_allowed_users() -> list[str]:
|
||||
"""Get the list of allowed users to interact with the bot.
|
||||
|
|
@ -27,25 +32,65 @@ def get_allowed_users() -> list[str]:
|
|||
]
|
||||
|
||||
|
||||
def chat(user_message: str, openai_client: OpenAI) -> str | None:
|
||||
def add_message_to_memory(channel_id: str, user: str, message: str) -> None:
    """Record a message in the rolling per-channel memory buffer.

    Each channel keeps at most 50 recent entries; older ones are dropped
    automatically by the deque's maxlen.

    Args:
        channel_id: The ID of the channel where the message was sent.
        user: The user who sent the message.
        message: The content of the message.
    """
    # Create the channel buffer on first use, then append the new entry.
    channel_history: deque[tuple[str, str, datetime.datetime]] = recent_messages.setdefault(
        channel_id, deque(maxlen=50)
    )
    now: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
    channel_history.append((user, message, now))

    logger.info("Added message to memory: %s from %s in channel %s", message, user, channel_id)
|
||||
|
||||
|
||||
def get_recent_messages(channel_id: str, threshold_minutes: int = 10) -> list[tuple[str, str]]:
    """Retrieve messages from the last `threshold_minutes` minutes for a specific channel.

    Args:
        channel_id: The ID of the channel to retrieve messages for.
        threshold_minutes: The number of minutes to consider messages as recent.

    Returns:
        A list of tuples containing user and message content.
    """
    # Anything newer than this timestamp counts as "recent".
    cutoff: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(
        minutes=threshold_minutes
    )
    # An unknown channel yields the empty default, producing [].
    return [
        (user, message)
        for user, message, timestamp in recent_messages.get(channel_id, ())
        if timestamp > cutoff
    ]
|
||||
|
||||
|
||||
def chat(user_message: str, openai_client: OpenAI, channel_id: str) -> str | None:
|
||||
"""Chat with the bot using the OpenAI API.
|
||||
|
||||
Args:
|
||||
user_message: The message to send to OpenAI.
|
||||
openai_client: The OpenAI client to use.
|
||||
channel_id: The ID of the channel where the conversation is happening.
|
||||
|
||||
Returns:
|
||||
The response from the AI model.
|
||||
"""
|
||||
# Include recent messages in the prompt
|
||||
recent_context: str = "\n".join([f"{user}: {message}" for user, message in get_recent_messages(channel_id)])
|
||||
prompt: str = (
|
||||
"You are in a Discord group chat. People can ask you questions. "
|
||||
"Use Discord Markdown to format messages if needed.\n"
|
||||
f"Recent context:\n{recent_context}\n"
|
||||
f"User: {user_message}"
|
||||
)
|
||||
|
||||
completion: ChatCompletion = openai_client.chat.completions.create(
|
||||
model="gpt-5-chat-latest",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are in a Discord group chat. People can ask you questions. Use Discord Markdown to format messages if needed.", # noqa: E501
|
||||
},
|
||||
{"role": "user", "content": user_message},
|
||||
],
|
||||
messages=[{"role": "system", "content": prompt}],
|
||||
)
|
||||
response: str | None = completion.choices[0].message.content
|
||||
logger.info("AI response: %s", response)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue