From eec1ed4f59c7c093945353381a32dec72af77e81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Hells=C3=A9n?= Date: Fri, 26 Sep 2025 01:06:18 +0200 Subject: [PATCH] Add conversation memory reset command Introduces a new /reset command to allow authorized users to reset the conversation memory for a channel. Also adds a new_conversation option to the /ask command to start a fresh conversation, and implements the reset_memory function in misc.py. --- main.py | 39 ++++++++++++++++++++++++++++++++++++--- misc.py | 14 ++++++++++++++ 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 8487451..11d9955 100644 --- a/main.py +++ b/main.py @@ -15,7 +15,7 @@ import sentry_sdk from discord import Forbidden, HTTPException, NotFound, app_commands from dotenv import load_dotenv -from misc import add_message_to_memory, chat, get_allowed_users, get_raw_images_from_text, should_respond_without_trigger, update_trigger_time +from misc import add_message_to_memory, chat, get_allowed_users, get_raw_images_from_text, reset_memory, should_respond_without_trigger, update_trigger_time if TYPE_CHECKING: from collections.abc import Callable @@ -158,8 +158,14 @@ client = LoviBotClient(intents=intents) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.describe(text="Ask LoviBot a question.") -async def ask(interaction: discord.Interaction, text: str) -> None: - """A command to ask the AI a question.""" +async def ask(interaction: discord.Interaction, text: str, new_conversation: bool = False) -> None: # noqa: FBT001, FBT002 + """A command to ask the AI a question. + + Args: + interaction (discord.Interaction): The interaction object. + text (str): The question or message to ask. + new_conversation (bool, optional): Whether to start a new conversation. Defaults to False. + """ await interaction.response.defer() if not text: @@ -167,6 +173,9 @@ async def ask(interaction: discord.Interaction, text: str) -> None: await interaction.followup.send("You need to provide a question or message.", ephemeral=True) return + if new_conversation and interaction.channel is not None: + reset_memory(str(interaction.channel.id)) + user_name_lowercase: str = interaction.user.name.lower() logger.info("Received command from: %s", user_name_lowercase) @@ -218,6 +227,30 @@ async def ask(interaction: discord.Interaction, text: str) -> None: await send_response(interaction=interaction, text=text, response=display_response) +@client.tree.command(name="reset", description="Reset the conversation memory.") +@app_commands.allowed_installs(guilds=True, users=True) +@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) +async def reset(interaction: discord.Interaction) -> None: + """A command to reset the conversation memory.""" + await interaction.response.defer() + + user_name_lowercase: str = interaction.user.name.lower() + logger.info("Received command from: %s", user_name_lowercase) + + # Only allow certain users to interact with the bot + allowed_users: list[str] = get_allowed_users() + if user_name_lowercase not in allowed_users: + await send_response(interaction=interaction, text="", response="You are not authorized to use this command.") + return + + # Reset the conversation memory + if interaction.channel is not None: + reset_memory(str(interaction.channel.id)) + await send_response(interaction=interaction, text="", response="Conversation memory has been reset.") + + await interaction.followup.send(f"Conversation memory has been reset for {interaction.channel}.") + + async def send_response(interaction: discord.Interaction, text: str, response: str) -> None: """Send a response to the interaction, handling potential errors. diff --git a/misc.py b/misc.py index 28d0e1a..166457a 100644 --- a/misc.py +++ b/misc.py @@ -57,6 +57,20 @@ agent: Agent[BotDependencies, str] = Agent( ) +def reset_memory(channel_id: str) -> None: + """Reset the conversation memory for a specific channel. + + Args: + channel_id (str): The ID of the channel to reset memory for. + """ + if channel_id in recent_messages: + del recent_messages[channel_id] + logger.info("Reset memory for channel %s", channel_id) + if channel_id in last_trigger_time: + del last_trigger_time[channel_id] + logger.info("Reset trigger times for channel %s", channel_id) + + def _message_text_length(msg: ModelRequest | ModelResponse) -> int: """Compute the total text length of all text parts in a message.