Add undo functionality for /reset command (#61)
Co-authored-by: TheLovinator1 <4153203+TheLovinator1@users.noreply.github.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: Joakim Hellsén <tlovinator@gmail.com>
This commit is contained in:
parent
9738c37aba
commit
5695722ad2
3 changed files with 289 additions and 0 deletions
81
main.py
81
main.py
|
|
@ -57,6 +57,10 @@ os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_TOKEN", "")
|
|||
recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {}
|
||||
last_trigger_time: dict[str, dict[str, datetime.datetime]] = {}
|
||||
|
||||
# Storage for reset snapshots to enable undo functionality
|
||||
# Each channel stores its previous state: (recent_messages_snapshot, last_trigger_time_snapshot)
|
||||
reset_snapshots: dict[str, tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]]] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotDependencies:
|
||||
|
|
@ -117,9 +121,24 @@ def grok_it(
|
|||
def reset_memory(channel_id: str) -> None:
|
||||
"""Reset the conversation memory for a specific channel.
|
||||
|
||||
Creates a snapshot of the current state before resetting to enable undo.
|
||||
|
||||
Args:
|
||||
channel_id (str): The ID of the channel to reset memory for.
|
||||
"""
|
||||
# Create snapshot before reset for undo functionality
|
||||
messages_snapshot: deque[tuple[str, str, datetime.datetime]] = (
|
||||
deque(recent_messages[channel_id], maxlen=50) if channel_id in recent_messages else deque(maxlen=50)
|
||||
)
|
||||
|
||||
trigger_snapshot: dict[str, datetime.datetime] = dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {}
|
||||
|
||||
# Only save snapshot if there's something to restore
|
||||
if messages_snapshot or trigger_snapshot:
|
||||
reset_snapshots[channel_id] = (messages_snapshot, trigger_snapshot)
|
||||
logger.info("Created reset snapshot for channel %s", channel_id)
|
||||
|
||||
# Perform the actual reset
|
||||
if channel_id in recent_messages:
|
||||
del recent_messages[channel_id]
|
||||
logger.info("Reset memory for channel %s", channel_id)
|
||||
|
|
@ -128,6 +147,41 @@ def reset_memory(channel_id: str) -> None:
|
|||
logger.info("Reset trigger times for channel %s", channel_id)
|
||||
|
||||
|
||||
# MARK: undo_reset
|
||||
def undo_reset(channel_id: str) -> bool:
|
||||
"""Undo the last reset operation for a specific channel.
|
||||
|
||||
Restores the conversation memory from the saved snapshot.
|
||||
|
||||
Args:
|
||||
channel_id (str): The ID of the channel to undo reset for.
|
||||
|
||||
Returns:
|
||||
bool: True if undo was successful, False if no snapshot exists.
|
||||
"""
|
||||
if channel_id not in reset_snapshots:
|
||||
logger.info("No reset snapshot found for channel %s", channel_id)
|
||||
return False
|
||||
|
||||
messages_snapshot, trigger_snapshot = reset_snapshots[channel_id]
|
||||
|
||||
# Restore recent messages
|
||||
if messages_snapshot:
|
||||
recent_messages[channel_id] = messages_snapshot
|
||||
logger.info("Restored messages for channel %s", channel_id)
|
||||
|
||||
# Restore trigger times
|
||||
if trigger_snapshot:
|
||||
last_trigger_time[channel_id] = trigger_snapshot
|
||||
logger.info("Restored trigger times for channel %s", channel_id)
|
||||
|
||||
# Remove the snapshot after successful undo (only one undo allowed)
|
||||
del reset_snapshots[channel_id]
|
||||
logger.info("Removed reset snapshot for channel %s after undo", channel_id)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _message_text_length(msg: ModelRequest | ModelResponse) -> int:
|
||||
"""Compute the total text length of all text parts in a message.
|
||||
|
||||
|
|
@ -879,6 +933,33 @@ async def reset(interaction: discord.Interaction) -> None:
|
|||
await interaction.followup.send(f"Conversation memory has been reset for {interaction.channel}.")
|
||||
|
||||
|
||||
# MARK: /undo command
|
||||
@client.tree.command(name="undo", description="Undo the last /reset command.")
|
||||
@app_commands.allowed_installs(guilds=True, users=True)
|
||||
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
||||
async def undo(interaction: discord.Interaction) -> None:
|
||||
"""A command to undo the last reset operation."""
|
||||
await interaction.response.defer()
|
||||
|
||||
user_name_lowercase: str = interaction.user.name.lower()
|
||||
logger.info("Received undo command from: %s", user_name_lowercase)
|
||||
|
||||
# Only allow certain users to interact with the bot
|
||||
allowed_users: list[str] = get_allowed_users()
|
||||
if user_name_lowercase not in allowed_users:
|
||||
await send_response(interaction=interaction, text="", response="You are not authorized to use this command.")
|
||||
return
|
||||
|
||||
# Undo the last reset
|
||||
if interaction.channel is not None:
|
||||
if undo_reset(str(interaction.channel.id)):
|
||||
await interaction.followup.send(f"Successfully restored conversation memory for {interaction.channel}.")
|
||||
else:
|
||||
await interaction.followup.send(f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.")
|
||||
else:
|
||||
await interaction.followup.send("Cannot undo: No channel context available.")
|
||||
|
||||
|
||||
# MARK: send_response
|
||||
async def send_response(interaction: discord.Interaction, text: str, response: str) -> None:
|
||||
"""Send a response to the interaction, handling potential errors.
|
||||
|
|
|
|||
|
|
@ -65,6 +65,7 @@ docstring-code-line-length = 20
|
|||
"ARG", # Unused function args -> fixtures nevertheless are functionally relevant...
|
||||
"FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize()
|
||||
"PLR2004", # Magic value used in comparison, ...
|
||||
"PLR6301", # Method could be a function, class method, or static method
|
||||
"S101", # asserts allowed in tests...
|
||||
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
|
||||
]
|
||||
|
|
@ -76,3 +77,9 @@ log_cli_level = "INFO"
|
|||
log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
|
||||
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
|
||||
python_files = "test_*.py *_test.py *_tests.py"
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=9.0.1",
|
||||
"ruff>=0.14.7",
|
||||
]
|
||||
|
|
|
|||
201
reset_undo_test.py
Normal file
201
reset_undo_test.py
Normal file
|
|
@ -0,0 +1,201 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from main import (
|
||||
add_message_to_memory,
|
||||
last_trigger_time,
|
||||
recent_messages,
|
||||
reset_memory,
|
||||
reset_snapshots,
|
||||
undo_reset,
|
||||
update_trigger_time,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_state() -> None:
|
||||
"""Clear all state before each test."""
|
||||
recent_messages.clear()
|
||||
last_trigger_time.clear()
|
||||
reset_snapshots.clear()
|
||||
|
||||
|
||||
class TestResetMemory:
|
||||
"""Tests for the reset_memory function."""
|
||||
|
||||
def test_reset_memory_clears_messages(self) -> None:
|
||||
"""Test that reset_memory clears messages for the channel."""
|
||||
channel_id = "test_channel_123"
|
||||
add_message_to_memory(channel_id, "user1", "Hello")
|
||||
add_message_to_memory(channel_id, "user2", "World")
|
||||
|
||||
assert channel_id in recent_messages
|
||||
assert len(recent_messages[channel_id]) == 2
|
||||
|
||||
reset_memory(channel_id)
|
||||
|
||||
assert channel_id not in recent_messages
|
||||
|
||||
def test_reset_memory_clears_trigger_times(self) -> None:
|
||||
"""Test that reset_memory clears trigger times for the channel."""
|
||||
channel_id = "test_channel_123"
|
||||
update_trigger_time(channel_id, "user1")
|
||||
|
||||
assert channel_id in last_trigger_time
|
||||
|
||||
reset_memory(channel_id)
|
||||
|
||||
assert channel_id not in last_trigger_time
|
||||
|
||||
def test_reset_memory_creates_snapshot(self) -> None:
|
||||
"""Test that reset_memory creates a snapshot for undo."""
|
||||
channel_id = "test_channel_123"
|
||||
add_message_to_memory(channel_id, "user1", "Test message")
|
||||
update_trigger_time(channel_id, "user1")
|
||||
|
||||
reset_memory(channel_id)
|
||||
|
||||
assert channel_id in reset_snapshots
|
||||
messages_snapshot, trigger_snapshot = reset_snapshots[channel_id]
|
||||
assert len(messages_snapshot) == 1
|
||||
assert "user1" in trigger_snapshot
|
||||
|
||||
def test_reset_memory_no_snapshot_for_empty_channel(self) -> None:
|
||||
"""Test that reset_memory doesn't create snapshot for empty channel."""
|
||||
channel_id = "empty_channel"
|
||||
|
||||
reset_memory(channel_id)
|
||||
|
||||
assert channel_id not in reset_snapshots
|
||||
|
||||
|
||||
class TestUndoReset:
|
||||
"""Tests for the undo_reset function."""
|
||||
|
||||
def test_undo_reset_restores_messages(self) -> None:
|
||||
"""Test that undo_reset restores messages."""
|
||||
channel_id = "test_channel_123"
|
||||
add_message_to_memory(channel_id, "user1", "Hello")
|
||||
add_message_to_memory(channel_id, "user2", "World")
|
||||
|
||||
reset_memory(channel_id)
|
||||
assert channel_id not in recent_messages
|
||||
|
||||
result = undo_reset(channel_id)
|
||||
|
||||
assert result is True
|
||||
assert channel_id in recent_messages
|
||||
assert len(recent_messages[channel_id]) == 2
|
||||
|
||||
def test_undo_reset_restores_trigger_times(self) -> None:
|
||||
"""Test that undo_reset restores trigger times."""
|
||||
channel_id = "test_channel_123"
|
||||
update_trigger_time(channel_id, "user1")
|
||||
original_time = last_trigger_time[channel_id]["user1"]
|
||||
|
||||
reset_memory(channel_id)
|
||||
assert channel_id not in last_trigger_time
|
||||
|
||||
result = undo_reset(channel_id)
|
||||
|
||||
assert result is True
|
||||
assert channel_id in last_trigger_time
|
||||
assert last_trigger_time[channel_id]["user1"] == original_time
|
||||
|
||||
def test_undo_reset_removes_snapshot(self) -> None:
|
||||
"""Test that undo_reset removes the snapshot after restoring."""
|
||||
channel_id = "test_channel_123"
|
||||
add_message_to_memory(channel_id, "user1", "Hello")
|
||||
|
||||
reset_memory(channel_id)
|
||||
assert channel_id in reset_snapshots
|
||||
|
||||
undo_reset(channel_id)
|
||||
|
||||
assert channel_id not in reset_snapshots
|
||||
|
||||
def test_undo_reset_returns_false_when_no_snapshot(self) -> None:
|
||||
"""Test that undo_reset returns False when no snapshot exists."""
|
||||
channel_id = "nonexistent_channel"
|
||||
|
||||
result = undo_reset(channel_id)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_undo_reset_only_works_once(self) -> None:
|
||||
"""Test that undo_reset only works once (snapshot is removed after undo)."""
|
||||
channel_id = "test_channel_123"
|
||||
add_message_to_memory(channel_id, "user1", "Hello")
|
||||
|
||||
reset_memory(channel_id)
|
||||
first_undo = undo_reset(channel_id)
|
||||
second_undo = undo_reset(channel_id)
|
||||
|
||||
assert first_undo is True
|
||||
assert second_undo is False
|
||||
|
||||
|
||||
class TestResetUndoIntegration:
|
||||
"""Integration tests for reset and undo functionality."""
|
||||
|
||||
def test_reset_then_undo_preserves_content(self) -> None:
|
||||
"""Test that reset followed by undo preserves original content."""
|
||||
channel_id = "test_channel_123"
|
||||
add_message_to_memory(channel_id, "user1", "Message 1")
|
||||
add_message_to_memory(channel_id, "user2", "Message 2")
|
||||
add_message_to_memory(channel_id, "user3", "Message 3")
|
||||
update_trigger_time(channel_id, "user1")
|
||||
update_trigger_time(channel_id, "user2")
|
||||
|
||||
# Capture original state
|
||||
original_messages = list(recent_messages[channel_id])
|
||||
original_trigger_users = set(last_trigger_time[channel_id].keys())
|
||||
|
||||
reset_memory(channel_id)
|
||||
undo_reset(channel_id)
|
||||
|
||||
# Verify restored state matches original
|
||||
restored_messages = list(recent_messages[channel_id])
|
||||
restored_trigger_users = set(last_trigger_time[channel_id].keys())
|
||||
|
||||
assert len(restored_messages) == len(original_messages)
|
||||
assert restored_trigger_users == original_trigger_users
|
||||
|
||||
def test_multiple_resets_overwrite_snapshot(self) -> None:
|
||||
"""Test that multiple resets overwrite the previous snapshot."""
|
||||
channel_id = "test_channel_123"
|
||||
|
||||
# First set of messages
|
||||
add_message_to_memory(channel_id, "user1", "First message")
|
||||
reset_memory(channel_id)
|
||||
|
||||
# Second set of messages
|
||||
add_message_to_memory(channel_id, "user1", "Second message")
|
||||
add_message_to_memory(channel_id, "user1", "Third message")
|
||||
reset_memory(channel_id)
|
||||
|
||||
# Undo should restore the second set, not the first
|
||||
undo_reset(channel_id)
|
||||
|
||||
assert channel_id in recent_messages
|
||||
assert len(recent_messages[channel_id]) == 2
|
||||
|
||||
def test_different_channels_independent_undo(self) -> None:
|
||||
"""Test that different channels have independent undo functionality."""
|
||||
channel_1 = "channel_1"
|
||||
channel_2 = "channel_2"
|
||||
|
||||
add_message_to_memory(channel_1, "user1", "Channel 1 message")
|
||||
add_message_to_memory(channel_2, "user2", "Channel 2 message")
|
||||
|
||||
reset_memory(channel_1)
|
||||
reset_memory(channel_2)
|
||||
|
||||
# Undo only channel 1
|
||||
undo_reset(channel_1)
|
||||
|
||||
assert channel_1 in recent_messages
|
||||
assert channel_2 not in recent_messages
|
||||
assert channel_1 not in reset_snapshots
|
||||
assert channel_2 in reset_snapshots
|
||||
Loading…
Add table
Add a link
Reference in a new issue