Enhance chat functionality by adding message memory and context for improved responses
This commit is contained in:
parent
659fe3f13d
commit
86cb28208d
3 changed files with 70 additions and 15 deletions
1
.vscode/settings.json
vendored
1
.vscode/settings.json
vendored
|
|
@ -37,6 +37,7 @@
|
||||||
"testpaths",
|
"testpaths",
|
||||||
"thelovinator",
|
"thelovinator",
|
||||||
"tobytes",
|
"tobytes",
|
||||||
|
"twimg",
|
||||||
"unsignedinteger"
|
"unsignedinteger"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
23
main.py
23
main.py
|
|
@ -15,7 +15,7 @@ import sentry_sdk
|
||||||
from discord import app_commands
|
from discord import app_commands
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
from misc import chat, get_allowed_users
|
from misc import add_message_to_memory, chat, get_allowed_users
|
||||||
from settings import Settings
|
from settings import Settings
|
||||||
|
|
||||||
sentry_sdk.init(
|
sentry_sdk.init(
|
||||||
|
|
@ -74,14 +74,17 @@ class LoviBotClient(discord.Client):
|
||||||
logger.info("No message content found in the event: %s", message)
|
logger.info("No message content found in the event: %s", message)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Add the message to memory
|
||||||
|
add_message_to_memory(str(message.channel.id), message.author.name, incoming_message)
|
||||||
|
|
||||||
lowercase_message: str = incoming_message.lower() if incoming_message else ""
|
lowercase_message: str = incoming_message.lower() if incoming_message else ""
|
||||||
trigger_keywords: list[str] = ["lovibot", "<@345000831499894795>"]
|
trigger_keywords: list[str] = ["lovibot", "@lovibot", "<@345000831499894795>", "grok", "@grok"]
|
||||||
if any(trigger in lowercase_message for trigger in trigger_keywords):
|
if any(trigger in lowercase_message for trigger in trigger_keywords):
|
||||||
logger.info("Received message: %s from: %s", incoming_message, message.author.name)
|
logger.info("Received message: %s from: %s", incoming_message, message.author.name)
|
||||||
|
|
||||||
async with message.channel.typing():
|
async with message.channel.typing():
|
||||||
try:
|
try:
|
||||||
response: str | None = chat(incoming_message, openai_client)
|
response: str | None = chat(incoming_message, openai_client, str(message.channel.id))
|
||||||
except openai.OpenAIError as e:
|
except openai.OpenAIError as e:
|
||||||
logger.exception("An error occurred while chatting with the AI model.")
|
logger.exception("An error occurred while chatting with the AI model.")
|
||||||
e.add_note(f"Message: {incoming_message}\nEvent: {message}\nWho: {message.author.name}")
|
e.add_note(f"Message: {incoming_message}\nEvent: {message}\nWho: {message.author.name}")
|
||||||
|
|
@ -167,7 +170,7 @@ async def ask(interaction: discord.Interaction, text: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response: str | None = chat(text, openai_client)
|
response: str | None = chat(text, openai_client, str(interaction.channel_id))
|
||||||
except openai.OpenAIError as e:
|
except openai.OpenAIError as e:
|
||||||
logger.exception("An error occurred while chatting with the AI model.")
|
logger.exception("An error occurred while chatting with the AI model.")
|
||||||
await interaction.followup.send(f"An error occurred: {e}")
|
await interaction.followup.send(f"An error occurred: {e}")
|
||||||
|
|
@ -343,6 +346,8 @@ def extract_image_url(message: discord.Message) -> str | None:
|
||||||
the function searches the message content for any direct links ending in
|
the function searches the message content for any direct links ending in
|
||||||
common image file extensions (e.g., .png, .jpg, .jpeg, .gif, .webp).
|
common image file extensions (e.g., .png, .jpg, .jpeg, .gif, .webp).
|
||||||
|
|
||||||
|
Additionally, it handles Twitter image URLs and normalizes them to a standard format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message (discord.Message): The message from which to extract the image URL.
|
message (discord.Message): The message from which to extract the image URL.
|
||||||
|
|
||||||
|
|
@ -364,12 +369,16 @@ def extract_image_url(message: discord.Message) -> str | None:
|
||||||
|
|
||||||
if not image_url:
|
if not image_url:
|
||||||
match: re.Match[str] | None = re.search(
|
match: re.Match[str] | None = re.search(
|
||||||
pattern=r"(https?://[^\s]+(\.png|\.jpg|\.jpeg|\.gif|\.webp))",
|
r"(https?://[^\s]+\.(png|jpg|jpeg|gif|webp)(\?[^\s]*)?)", message.content, re.IGNORECASE
|
||||||
string=message.content,
|
|
||||||
flags=re.IGNORECASE,
|
|
||||||
)
|
)
|
||||||
if match:
|
if match:
|
||||||
image_url = match.group(0)
|
image_url = match.group(0)
|
||||||
|
|
||||||
|
# Handle Twitter image URLs
|
||||||
|
if image_url and "pbs.twimg.com/media/" in image_url:
|
||||||
|
# Normalize Twitter image URLs to the highest quality format
|
||||||
|
image_url = re.sub(r"\?format=[^&]+&name=[^&]+", "?format=jpg&name=orig", image_url)
|
||||||
|
|
||||||
return image_url
|
return image_url
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
61
misc.py
61
misc.py
|
|
@ -1,6 +1,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
|
from collections import deque
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
@ -10,6 +12,9 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
logger: logging.Logger = logging.getLogger(__name__)
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# A dictionary to store recent messages per channel with a maximum length per channel
|
||||||
|
recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {}
|
||||||
|
|
||||||
|
|
||||||
def get_allowed_users() -> list[str]:
|
def get_allowed_users() -> list[str]:
|
||||||
"""Get the list of allowed users to interact with the bot.
|
"""Get the list of allowed users to interact with the bot.
|
||||||
|
|
@ -27,25 +32,65 @@ def get_allowed_users() -> list[str]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def chat(user_message: str, openai_client: OpenAI) -> str | None:
|
def add_message_to_memory(channel_id: str, user: str, message: str) -> None:
|
||||||
|
"""Add a message to the memory for a specific channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel_id: The ID of the channel where the message was sent.
|
||||||
|
user: The user who sent the message.
|
||||||
|
message: The content of the message.
|
||||||
|
"""
|
||||||
|
if channel_id not in recent_messages:
|
||||||
|
recent_messages[channel_id] = deque(maxlen=50)
|
||||||
|
|
||||||
|
timestamp: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
|
||||||
|
recent_messages[channel_id].append((user, message, timestamp))
|
||||||
|
|
||||||
|
logger.info("Added message to memory: %s from %s in channel %s", message, user, channel_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_recent_messages(channel_id: str, threshold_minutes: int = 10) -> list[tuple[str, str]]:
|
||||||
|
"""Retrieve messages from the last `threshold_minutes` minutes for a specific channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel_id: The ID of the channel to retrieve messages for.
|
||||||
|
threshold_minutes: The number of minutes to consider messages as recent.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of tuples containing user and message content.
|
||||||
|
"""
|
||||||
|
if channel_id not in recent_messages:
|
||||||
|
return []
|
||||||
|
|
||||||
|
threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(
|
||||||
|
minutes=threshold_minutes
|
||||||
|
)
|
||||||
|
return [(user, message) for user, message, timestamp in recent_messages[channel_id] if timestamp > threshold]
|
||||||
|
|
||||||
|
|
||||||
|
def chat(user_message: str, openai_client: OpenAI, channel_id: str) -> str | None:
|
||||||
"""Chat with the bot using the OpenAI API.
|
"""Chat with the bot using the OpenAI API.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_message: The message to send to OpenAI.
|
user_message: The message to send to OpenAI.
|
||||||
openai_client: The OpenAI client to use.
|
openai_client: The OpenAI client to use.
|
||||||
|
channel_id: The ID of the channel where the conversation is happening.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The response from the AI model.
|
The response from the AI model.
|
||||||
"""
|
"""
|
||||||
|
# Include recent messages in the prompt
|
||||||
|
recent_context: str = "\n".join([f"{user}: {message}" for user, message in get_recent_messages(channel_id)])
|
||||||
|
prompt: str = (
|
||||||
|
"You are in a Discord group chat. People can ask you questions. "
|
||||||
|
"Use Discord Markdown to format messages if needed.\n"
|
||||||
|
f"Recent context:\n{recent_context}\n"
|
||||||
|
f"User: {user_message}"
|
||||||
|
)
|
||||||
|
|
||||||
completion: ChatCompletion = openai_client.chat.completions.create(
|
completion: ChatCompletion = openai_client.chat.completions.create(
|
||||||
model="gpt-5-chat-latest",
|
model="gpt-5-chat-latest",
|
||||||
messages=[
|
messages=[{"role": "system", "content": prompt}],
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": "You are in a Discord group chat. People can ask you questions. Use Discord Markdown to format messages if needed.", # noqa: E501
|
|
||||||
},
|
|
||||||
{"role": "user", "content": user_message},
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
response: str | None = completion.choices[0].message.content
|
response: str | None = completion.choices[0].message.content
|
||||||
logger.info("AI response: %s", response)
|
logger.info("AI response: %s", response)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue