From 5695722ad259f2140f433ccf95a5113d3f7d7aab Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Thu, 4 Dec 2025 02:03:40 +0100 Subject: [PATCH] Add undo functionality for /reset command (#61) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: TheLovinator1 <4153203+TheLovinator1@users.noreply.github.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: Joakim Hellsén --- main.py | 81 ++++++++++++++++++ pyproject.toml | 7 ++ reset_undo_test.py | 201 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 289 insertions(+) create mode 100644 reset_undo_test.py diff --git a/main.py b/main.py index 864c907..1375ec0 100644 --- a/main.py +++ b/main.py @@ -57,6 +57,10 @@ os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_TOKEN", "") recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {} last_trigger_time: dict[str, dict[str, datetime.datetime]] = {} +# Storage for reset snapshots to enable undo functionality +# Each channel stores its previous state: (recent_messages_snapshot, last_trigger_time_snapshot) +reset_snapshots: dict[str, tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]]] = {} + @dataclass class BotDependencies: @@ -117,9 +121,24 @@ def grok_it( def reset_memory(channel_id: str) -> None: """Reset the conversation memory for a specific channel. + Creates a snapshot of the current state before resetting to enable undo. + Args: channel_id (str): The ID of the channel to reset memory for. """ + # Create snapshot before reset for undo functionality + messages_snapshot: deque[tuple[str, str, datetime.datetime]] = ( + deque(recent_messages[channel_id], maxlen=50) if channel_id in recent_messages else deque(maxlen=50) + ) + + trigger_snapshot: dict[str, datetime.datetime] = dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {} + + # Only save snapshot if there's something to restore + if messages_snapshot or trigger_snapshot: + reset_snapshots[channel_id] = (messages_snapshot, trigger_snapshot) + logger.info("Created reset snapshot for channel %s", channel_id) + + # Perform the actual reset if channel_id in recent_messages: del recent_messages[channel_id] logger.info("Reset memory for channel %s", channel_id) @@ -128,6 +147,41 @@ def reset_memory(channel_id: str) -> None: logger.info("Reset trigger times for channel %s", channel_id) +# MARK: undo_reset +def undo_reset(channel_id: str) -> bool: + """Undo the last reset operation for a specific channel. + + Restores the conversation memory from the saved snapshot. + + Args: + channel_id (str): The ID of the channel to undo reset for. + + Returns: + bool: True if undo was successful, False if no snapshot exists. + """ + if channel_id not in reset_snapshots: + logger.info("No reset snapshot found for channel %s", channel_id) + return False + + messages_snapshot, trigger_snapshot = reset_snapshots[channel_id] + + # Restore recent messages + if messages_snapshot: + recent_messages[channel_id] = messages_snapshot + logger.info("Restored messages for channel %s", channel_id) + + # Restore trigger times + if trigger_snapshot: + last_trigger_time[channel_id] = trigger_snapshot + logger.info("Restored trigger times for channel %s", channel_id) + + # Remove the snapshot after successful undo (only one undo allowed) + del reset_snapshots[channel_id] + logger.info("Removed reset snapshot for channel %s after undo", channel_id) + + return True + + def _message_text_length(msg: ModelRequest | ModelResponse) -> int: """Compute the total text length of all text parts in a message. @@ -879,6 +933,33 @@ async def reset(interaction: discord.Interaction) -> None: await interaction.followup.send(f"Conversation memory has been reset for {interaction.channel}.") +# MARK: /undo command +@client.tree.command(name="undo", description="Undo the last /reset command.") +@app_commands.allowed_installs(guilds=True, users=True) +@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) +async def undo(interaction: discord.Interaction) -> None: + """A command to undo the last reset operation.""" + await interaction.response.defer() + + user_name_lowercase: str = interaction.user.name.lower() + logger.info("Received undo command from: %s", user_name_lowercase) + + # Only allow certain users to interact with the bot + allowed_users: list[str] = get_allowed_users() + if user_name_lowercase not in allowed_users: + await send_response(interaction=interaction, text="", response="You are not authorized to use this command.") + return + + # Undo the last reset + if interaction.channel is not None: + if undo_reset(str(interaction.channel.id)): + await interaction.followup.send(f"Successfully restored conversation memory for {interaction.channel}.") + else: + await interaction.followup.send(f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.") + else: + await interaction.followup.send("Cannot undo: No channel context available.") + + # MARK: send_response async def send_response(interaction: discord.Interaction, text: str, response: str) -> None: """Send a response to the interaction, handling potential errors. diff --git a/pyproject.toml b/pyproject.toml index eae6453..a5686b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ docstring-code-line-length = 20 "ARG", # Unused function args -> fixtures nevertheless are functionally relevant... "FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize() "PLR2004", # Magic value used in comparison, ... + "PLR6301", # Method could be a function, class method, or static method "S101", # asserts allowed in tests... "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes ] @@ -76,3 +77,9 @@ log_cli_level = "INFO" log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)" log_cli_date_format = "%Y-%m-%d %H:%M:%S" python_files = "test_*.py *_test.py *_tests.py" + +[dependency-groups] +dev = [ + "pytest>=9.0.1", + "ruff>=0.14.7", +] diff --git a/reset_undo_test.py b/reset_undo_test.py new file mode 100644 index 0000000..1a90956 --- /dev/null +++ b/reset_undo_test.py @@ -0,0 +1,201 @@ +from __future__ import annotations + +import pytest + +from main import ( + add_message_to_memory, + last_trigger_time, + recent_messages, + reset_memory, + reset_snapshots, + undo_reset, + update_trigger_time, +) + + +@pytest.fixture(autouse=True) +def clear_state() -> None: + """Clear all state before each test.""" + recent_messages.clear() + last_trigger_time.clear() + reset_snapshots.clear() + + +class TestResetMemory: + """Tests for the reset_memory function.""" + + def test_reset_memory_clears_messages(self) -> None: + """Test that reset_memory clears messages for the channel.""" + channel_id = "test_channel_123" + add_message_to_memory(channel_id, "user1", "Hello") + add_message_to_memory(channel_id, "user2", "World") + + assert channel_id in recent_messages + assert len(recent_messages[channel_id]) == 2 + + reset_memory(channel_id) + + assert channel_id not in recent_messages + + def test_reset_memory_clears_trigger_times(self) -> None: + """Test that reset_memory clears trigger times for the channel.""" + channel_id = "test_channel_123" + update_trigger_time(channel_id, "user1") + + assert channel_id in last_trigger_time + + reset_memory(channel_id) + + assert channel_id not in last_trigger_time + + def test_reset_memory_creates_snapshot(self) -> None: + """Test that reset_memory creates a snapshot for undo.""" + channel_id = "test_channel_123" + add_message_to_memory(channel_id, "user1", "Test message") + update_trigger_time(channel_id, "user1") + + reset_memory(channel_id) + + assert channel_id in reset_snapshots + messages_snapshot, trigger_snapshot = reset_snapshots[channel_id] + assert len(messages_snapshot) == 1 + assert "user1" in trigger_snapshot + + def test_reset_memory_no_snapshot_for_empty_channel(self) -> None: + """Test that reset_memory doesn't create snapshot for empty channel.""" + channel_id = "empty_channel" + + reset_memory(channel_id) + + assert channel_id not in reset_snapshots + + +class TestUndoReset: + """Tests for the undo_reset function.""" + + def test_undo_reset_restores_messages(self) -> None: + """Test that undo_reset restores messages.""" + channel_id = "test_channel_123" + add_message_to_memory(channel_id, "user1", "Hello") + add_message_to_memory(channel_id, "user2", "World") + + reset_memory(channel_id) + assert channel_id not in recent_messages + + result = undo_reset(channel_id) + + assert result is True + assert channel_id in recent_messages + assert len(recent_messages[channel_id]) == 2 + + def test_undo_reset_restores_trigger_times(self) -> None: + """Test that undo_reset restores trigger times.""" + channel_id = "test_channel_123" + update_trigger_time(channel_id, "user1") + original_time = last_trigger_time[channel_id]["user1"] + + reset_memory(channel_id) + assert channel_id not in last_trigger_time + + result = undo_reset(channel_id) + + assert result is True + assert channel_id in last_trigger_time + assert last_trigger_time[channel_id]["user1"] == original_time + + def test_undo_reset_removes_snapshot(self) -> None: + """Test that undo_reset removes the snapshot after restoring.""" + channel_id = "test_channel_123" + add_message_to_memory(channel_id, "user1", "Hello") + + reset_memory(channel_id) + assert channel_id in reset_snapshots + + undo_reset(channel_id) + + assert channel_id not in reset_snapshots + + def test_undo_reset_returns_false_when_no_snapshot(self) -> None: + """Test that undo_reset returns False when no snapshot exists.""" + channel_id = "nonexistent_channel" + + result = undo_reset(channel_id) + + assert result is False + + def test_undo_reset_only_works_once(self) -> None: + """Test that undo_reset only works once (snapshot is removed after undo).""" + channel_id = "test_channel_123" + add_message_to_memory(channel_id, "user1", "Hello") + + reset_memory(channel_id) + first_undo = undo_reset(channel_id) + second_undo = undo_reset(channel_id) + + assert first_undo is True + assert second_undo is False + + +class TestResetUndoIntegration: + """Integration tests for reset and undo functionality.""" + + def test_reset_then_undo_preserves_content(self) -> None: + """Test that reset followed by undo preserves original content.""" + channel_id = "test_channel_123" + add_message_to_memory(channel_id, "user1", "Message 1") + add_message_to_memory(channel_id, "user2", "Message 2") + add_message_to_memory(channel_id, "user3", "Message 3") + update_trigger_time(channel_id, "user1") + update_trigger_time(channel_id, "user2") + + # Capture original state + original_messages = list(recent_messages[channel_id]) + original_trigger_users = set(last_trigger_time[channel_id].keys()) + + reset_memory(channel_id) + undo_reset(channel_id) + + # Verify restored state matches original + restored_messages = list(recent_messages[channel_id]) + restored_trigger_users = set(last_trigger_time[channel_id].keys()) + + assert len(restored_messages) == len(original_messages) + assert restored_trigger_users == original_trigger_users + + def test_multiple_resets_overwrite_snapshot(self) -> None: + """Test that multiple resets overwrite the previous snapshot.""" + channel_id = "test_channel_123" + + # First set of messages + add_message_to_memory(channel_id, "user1", "First message") + reset_memory(channel_id) + + # Second set of messages + add_message_to_memory(channel_id, "user1", "Second message") + add_message_to_memory(channel_id, "user1", "Third message") + reset_memory(channel_id) + + # Undo should restore the second set, not the first + undo_reset(channel_id) + + assert channel_id in recent_messages + assert len(recent_messages[channel_id]) == 2 + + def test_different_channels_independent_undo(self) -> None: + """Test that different channels have independent undo functionality.""" + channel_1 = "channel_1" + channel_2 = "channel_2" + + add_message_to_memory(channel_1, "user1", "Channel 1 message") + add_message_to_memory(channel_2, "user2", "Channel 2 message") + + reset_memory(channel_1) + reset_memory(channel_2) + + # Undo only channel 1 + undo_reset(channel_1) + + assert channel_1 in recent_messages + assert channel_2 not in recent_messages + assert channel_1 not in reset_snapshots + assert channel_2 in reset_snapshots