Add undo functionality for /reset command (#61)

Co-authored-by: TheLovinator1 <4153203+TheLovinator1@users.noreply.github.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Joakim Hellsén <tlovinator@gmail.com>
This commit is contained in:
Copilot 2025-12-04 02:03:40 +01:00 committed by GitHub
commit 5695722ad2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 289 additions and 0 deletions

81
main.py
View file

@ -57,6 +57,10 @@ os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_TOKEN", "")
recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {}
last_trigger_time: dict[str, dict[str, datetime.datetime]] = {}
# Storage for reset snapshots to enable undo functionality
# Each channel stores its previous state: (recent_messages_snapshot, last_trigger_time_snapshot)
reset_snapshots: dict[str, tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]]] = {}
@dataclass
class BotDependencies:
@ -117,9 +121,24 @@ def grok_it(
def reset_memory(channel_id: str) -> None:
"""Reset the conversation memory for a specific channel.
Creates a snapshot of the current state before resetting to enable undo.
Args:
channel_id (str): The ID of the channel to reset memory for.
"""
# Create snapshot before reset for undo functionality
messages_snapshot: deque[tuple[str, str, datetime.datetime]] = (
deque(recent_messages[channel_id], maxlen=50) if channel_id in recent_messages else deque(maxlen=50)
)
trigger_snapshot: dict[str, datetime.datetime] = dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {}
# Only save snapshot if there's something to restore
if messages_snapshot or trigger_snapshot:
reset_snapshots[channel_id] = (messages_snapshot, trigger_snapshot)
logger.info("Created reset snapshot for channel %s", channel_id)
# Perform the actual reset
if channel_id in recent_messages:
del recent_messages[channel_id]
logger.info("Reset memory for channel %s", channel_id)
@ -128,6 +147,41 @@ def reset_memory(channel_id: str) -> None:
logger.info("Reset trigger times for channel %s", channel_id)
# MARK: undo_reset
def undo_reset(channel_id: str) -> bool:
"""Undo the last reset operation for a specific channel.
Restores the conversation memory from the saved snapshot.
Args:
channel_id (str): The ID of the channel to undo reset for.
Returns:
bool: True if undo was successful, False if no snapshot exists.
"""
if channel_id not in reset_snapshots:
logger.info("No reset snapshot found for channel %s", channel_id)
return False
messages_snapshot, trigger_snapshot = reset_snapshots[channel_id]
# Restore recent messages
if messages_snapshot:
recent_messages[channel_id] = messages_snapshot
logger.info("Restored messages for channel %s", channel_id)
# Restore trigger times
if trigger_snapshot:
last_trigger_time[channel_id] = trigger_snapshot
logger.info("Restored trigger times for channel %s", channel_id)
# Remove the snapshot after successful undo (only one undo allowed)
del reset_snapshots[channel_id]
logger.info("Removed reset snapshot for channel %s after undo", channel_id)
return True
def _message_text_length(msg: ModelRequest | ModelResponse) -> int:
"""Compute the total text length of all text parts in a message.
@ -879,6 +933,33 @@ async def reset(interaction: discord.Interaction) -> None:
await interaction.followup.send(f"Conversation memory has been reset for {interaction.channel}.")
# MARK: /undo command
@client.tree.command(name="undo", description="Undo the last /reset command.")
@app_commands.allowed_installs(guilds=True, users=True)
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
async def undo(interaction: discord.Interaction) -> None:
"""A command to undo the last reset operation."""
await interaction.response.defer()
user_name_lowercase: str = interaction.user.name.lower()
logger.info("Received undo command from: %s", user_name_lowercase)
# Only allow certain users to interact with the bot
allowed_users: list[str] = get_allowed_users()
if user_name_lowercase not in allowed_users:
await send_response(interaction=interaction, text="", response="You are not authorized to use this command.")
return
# Undo the last reset
if interaction.channel is not None:
if undo_reset(str(interaction.channel.id)):
await interaction.followup.send(f"Successfully restored conversation memory for {interaction.channel}.")
else:
await interaction.followup.send(f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.")
else:
await interaction.followup.send("Cannot undo: No channel context available.")
# MARK: send_response
async def send_response(interaction: discord.Interaction, text: str, response: str) -> None:
"""Send a response to the interaction, handling potential errors.