Compare commits

..

No commits in common. "80e0637e8ae6c3532c671e3f9fd708079e6940fe" and "1e515a839478dd680f54efcf4e09eb38987f6da9" have entirely different histories.

7 changed files with 18 additions and 509 deletions

View file

@ -1,4 +1,3 @@
DISCORD_TOKEN= DISCORD_TOKEN=
OPENAI_TOKEN= OPENAI_TOKEN=
OLLAMA_API_KEY= OLLAMA_API_KEY=
OPENROUTER_API_KEY=

View file

@ -1,108 +0,0 @@
# Copilot Instructions for ANewDawn
## Project Overview
ANewDawn is a Discord bot written in Python 3.13+ using the discord.py library and Pydantic AI for AI-powered chat capabilities. The bot includes features such as:
- AI-powered chat responses using OpenAI and Grok models
- Conversation memory with reset/undo functionality
- Image enhancement using OpenCV
- Web search integration via Ollama
- Slash commands and context menus
## Development Environment
- **Python**: 3.13 or higher required
- **Package Manager**: Use `uv` for dependency management (see `pyproject.toml`)
- **Docker**: The project uses Docker for deployment (see `Dockerfile` and `docker-compose.yml`)
- **Environment Variables**: Copy `.env.example` to `.env` and fill in required tokens
## Code Style and Conventions
### Linting and Formatting
This project uses **Ruff** for linting and formatting with strict settings:
- All rules enabled (`lint.select = ["ALL"]`)
- Preview features enabled
- Auto-fix enabled
- Line length: 160 characters
- Google-style docstrings required
Run linting:
```bash
ruff check --exit-non-zero-on-fix --verbose
```
Run formatting check:
```bash
ruff format --check --verbose
```
### Python Conventions
- Use `from __future__ import annotations` at the top of all files (automatically added by Ruff)
- Use type hints for all function parameters and return types
- Follow Google docstring convention
- Use `logging` module for logging, not print statements
- Prefer explicit imports over wildcard imports
### Testing
- Tests use pytest
- Test files should be named `*_test.py` or `test_*.py`
- Run tests with: `pytest`
## Project Structure
- `main.py` - Main bot application with all commands and event handlers
- `pyproject.toml` - Project configuration and dependencies
- `Dockerfile` / `docker-compose.yml` - Container configuration
- `.github/workflows/` - CI/CD workflows
## Key Components
### Bot Client
The main bot client is `LoviBotClient` which extends `discord.Client`. It handles:
- Message events (`on_message`)
- Slash commands (`/ask`, `/grok`, `/reset`, `/undo`)
- Context menus (image enhancement)
### AI Integration
- `chatgpt_agent` - Pydantic AI agent using OpenAI
- `grok_it()` - Function for Grok model responses
- Message history is stored in `recent_messages` dict per channel
### Memory Management
- `add_message_to_memory()` - Store messages for context
- `reset_memory()` - Clear conversation history
- `undo_reset()` - Restore previous state
## CI/CD
The GitHub Actions workflow (`.github/workflows/docker-publish.yml`) runs:
1. Ruff linting and format check
2. Dockerfile validation
3. Docker image build and push to GitHub Container Registry
## Common Tasks
### Adding a New Slash Command
1. Add the command function with `@client.tree.command()` decorator
2. Include `@app_commands.allowed_installs()` and `@app_commands.allowed_contexts()` decorators
3. Use `await interaction.response.defer()` for long-running operations
4. Check user authorization with `get_allowed_users()`
### Adding a New AI Instruction
1. Create a function decorated with `@chatgpt_agent.instructions`
2. The function should return a string with the instruction content
3. Use `RunContext[BotDependencies]` parameter to access dependencies
### Modifying Image Enhancement
Image enhancement functions (`enhance_image1`, `enhance_image2`, `enhance_image3`) use OpenCV. Each returns WebP-encoded bytes.

View file

@ -13,7 +13,7 @@ jobs:
OPENAI_TOKEN: "0" OPENAI_TOKEN: "0"
steps: steps:
# GitHub Container Registry # GitHub Container Registry
- uses: docker/login-action@v4 - uses: docker/login-action@v3
if: github.event_name != 'pull_request' if: github.event_name != 'pull_request'
with: with:
registry: ghcr.io registry: ghcr.io
@ -21,7 +21,7 @@ jobs:
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
# Download the latest commit from the master branch # Download the latest commit from the master branch
- uses: actions/checkout@v6 - uses: actions/checkout@v4
# Install the latest version of ruff # Install the latest version of ruff
- uses: astral-sh/ruff-action@v3 - uses: astral-sh/ruff-action@v3
@ -39,17 +39,15 @@ jobs:
# Extract metadata (tags, labels) from Git reference and GitHub events for Docker # Extract metadata (tags, labels) from Git reference and GitHub events for Docker
- id: meta - id: meta
uses: docker/metadata-action@v6 uses: docker/metadata-action@v5
if: github.ref == 'refs/heads/master'
with: with:
images: ghcr.io/thelovinator1/anewdawn images: ghcr.io/thelovinator1/anewdawn
tags: type=raw,value=latest tags: type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', 'master') }}
# Build and push the Docker image # Build and push the Docker image
- uses: docker/build-push-action@v7 - uses: docker/build-push-action@v6
if: github.event_name != 'pull_request' && github.ref == 'refs/heads/master'
with: with:
context: . context: .
push: true push: ${{ github.event_name != 'pull_request' }}
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
tags: ${{ steps.meta.outputs.tags }} tags: ${{ steps.meta.outputs.tags }}

View file

@ -37,7 +37,6 @@
"numpy", "numpy",
"Ollama", "Ollama",
"opencv", "opencv",
"OPENROUTER",
"percpu", "percpu",
"phibiscarf", "phibiscarf",
"plubplub", "plubplub",

193
main.py
View file

@ -36,7 +36,6 @@ if TYPE_CHECKING:
from discord.abc import MessageableChannel from discord.abc import MessageableChannel
from discord.guild import GuildChannel from discord.guild import GuildChannel
from discord.interactions import InteractionChannel from discord.interactions import InteractionChannel
from openai.types.chat import ChatCompletion
from pydantic_ai.run import AgentRunResult from pydantic_ai.run import AgentRunResult
load_dotenv(verbose=True) load_dotenv(verbose=True)
@ -57,10 +56,6 @@ os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_TOKEN", "")
recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {} recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {}
last_trigger_time: dict[str, dict[str, datetime.datetime]] = {} last_trigger_time: dict[str, dict[str, datetime.datetime]] = {}
# Storage for reset snapshots to enable undo functionality
# Each channel stores its previous state: (recent_messages_snapshot, last_trigger_time_snapshot)
reset_snapshots: dict[str, tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]]] = {}
@dataclass @dataclass
class BotDependencies: class BotDependencies:
@ -77,68 +72,20 @@ class BotDependencies:
openai_settings = OpenAIResponsesModelSettings( openai_settings = OpenAIResponsesModelSettings(
openai_text_verbosity="low", openai_text_verbosity="low",
) )
chatgpt_agent: Agent[BotDependencies, str] = Agent( agent: Agent[BotDependencies, str] = Agent(
model="gpt-5-chat-latest", model="gpt-5-chat-latest",
deps_type=BotDependencies, deps_type=BotDependencies,
model_settings=openai_settings, model_settings=openai_settings,
) )
grok_client = openai.OpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=os.getenv("OPENROUTER_API_KEY"),
)
def grok_it(
message: discord.Message | None,
user_message: str,
) -> str | None:
"""Chat with the bot using the Pydantic AI agent.
Args:
user_message: The message from the user.
message: The original Discord message object.
Returns:
The bot's response as a string, or None if no response.
"""
allowed_users: list[str] = get_allowed_users()
if message and message.author.name not in allowed_users:
return None
response: ChatCompletion = grok_client.chat.completions.create(
model="x-ai/grok-4-fast:free",
messages=[
{
"role": "user",
"content": user_message,
},
],
)
return response.choices[0].message.content
# MARK: reset_memory # MARK: reset_memory
def reset_memory(channel_id: str) -> None: def reset_memory(channel_id: str) -> None:
"""Reset the conversation memory for a specific channel. """Reset the conversation memory for a specific channel.
Creates a snapshot of the current state before resetting to enable undo.
Args: Args:
channel_id (str): The ID of the channel to reset memory for. channel_id (str): The ID of the channel to reset memory for.
""" """
# Create snapshot before reset for undo functionality
messages_snapshot: deque[tuple[str, str, datetime.datetime]] = (
deque(recent_messages[channel_id], maxlen=50) if channel_id in recent_messages else deque(maxlen=50)
)
trigger_snapshot: dict[str, datetime.datetime] = dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {}
# Only save snapshot if there's something to restore
if messages_snapshot or trigger_snapshot:
reset_snapshots[channel_id] = (messages_snapshot, trigger_snapshot)
logger.info("Created reset snapshot for channel %s", channel_id)
# Perform the actual reset
if channel_id in recent_messages: if channel_id in recent_messages:
del recent_messages[channel_id] del recent_messages[channel_id]
logger.info("Reset memory for channel %s", channel_id) logger.info("Reset memory for channel %s", channel_id)
@ -147,41 +94,6 @@ def reset_memory(channel_id: str) -> None:
logger.info("Reset trigger times for channel %s", channel_id) logger.info("Reset trigger times for channel %s", channel_id)
# MARK: undo_reset
def undo_reset(channel_id: str) -> bool:
"""Undo the last reset operation for a specific channel.
Restores the conversation memory from the saved snapshot.
Args:
channel_id (str): The ID of the channel to undo reset for.
Returns:
bool: True if undo was successful, False if no snapshot exists.
"""
if channel_id not in reset_snapshots:
logger.info("No reset snapshot found for channel %s", channel_id)
return False
messages_snapshot, trigger_snapshot = reset_snapshots[channel_id]
# Restore recent messages
if messages_snapshot:
recent_messages[channel_id] = messages_snapshot
logger.info("Restored messages for channel %s", channel_id)
# Restore trigger times
if trigger_snapshot:
last_trigger_time[channel_id] = trigger_snapshot
logger.info("Restored trigger times for channel %s", channel_id)
# Remove the snapshot after successful undo (only one undo allowed)
del reset_snapshots[channel_id]
logger.info("Removed reset snapshot for channel %s after undo", channel_id)
return True
def _message_text_length(msg: ModelRequest | ModelResponse) -> int: def _message_text_length(msg: ModelRequest | ModelResponse) -> int:
"""Compute the total text length of all text parts in a message. """Compute the total text length of all text parts in a message.
@ -231,7 +143,7 @@ def compact_message_history(
# MARK: fetch_user_info # MARK: fetch_user_info
@chatgpt_agent.instructions @agent.instructions
def fetch_user_info(ctx: RunContext[BotDependencies]) -> str: def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
"""Fetches detailed information about the user who sent the message, including their roles, status, and activity. """Fetches detailed information about the user who sent the message, including their roles, status, and activity.
@ -252,7 +164,7 @@ def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
# MARK: get_system_performance_stats # MARK: get_system_performance_stats
@chatgpt_agent.instructions @agent.instructions
def get_system_performance_stats() -> str: def get_system_performance_stats() -> str:
"""Retrieves current system performance metrics, including CPU, memory, and disk usage. """Retrieves current system performance metrics, including CPU, memory, and disk usage.
@ -269,7 +181,7 @@ def get_system_performance_stats() -> str:
# MARK: get_channels # MARK: get_channels
@chatgpt_agent.instructions @agent.instructions
def get_channels(ctx: RunContext[BotDependencies]) -> str: def get_channels(ctx: RunContext[BotDependencies]) -> str:
"""Retrieves a list of all channels the bot is currently in. """Retrieves a list of all channels the bot is currently in.
@ -308,7 +220,7 @@ def do_web_search(query: str) -> ollama.WebSearchResponse | None:
# MARK: get_time_and_timezone # MARK: get_time_and_timezone
@chatgpt_agent.instructions @agent.instructions
def get_time_and_timezone() -> str: def get_time_and_timezone() -> str:
"""Retrieves the current time and timezone information. """Retrieves the current time and timezone information.
@ -320,7 +232,7 @@ def get_time_and_timezone() -> str:
# MARK: get_latency # MARK: get_latency
@chatgpt_agent.instructions @agent.instructions
def get_latency(ctx: RunContext[BotDependencies]) -> str: def get_latency(ctx: RunContext[BotDependencies]) -> str:
"""Retrieves the current latency information. """Retrieves the current latency information.
@ -332,7 +244,7 @@ def get_latency(ctx: RunContext[BotDependencies]) -> str:
# MARK: added_information_from_web_search # MARK: added_information_from_web_search
@chatgpt_agent.instructions @agent.instructions
def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str: def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str:
"""Adds information from a web search to the system prompt. """Adds information from a web search to the system prompt.
@ -350,7 +262,7 @@ def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str:
# MARK: get_sticker_instructions # MARK: get_sticker_instructions
@chatgpt_agent.instructions @agent.instructions
def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str: def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str:
"""Provides instructions for using stickers in the chat. """Provides instructions for using stickers in the chat.
@ -379,7 +291,7 @@ def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str:
# MARK: get_emoji_instructions # MARK: get_emoji_instructions
@chatgpt_agent.instructions @agent.instructions
def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str: def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
"""Provides instructions for using emojis in the chat. """Provides instructions for using emojis in the chat.
@ -428,7 +340,7 @@ def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
# MARK: get_system_prompt # MARK: get_system_prompt
@chatgpt_agent.instructions @agent.instructions
def get_system_prompt() -> str: def get_system_prompt() -> str:
"""Generate the core system prompt. """Generate the core system prompt.
@ -498,7 +410,7 @@ async def chat( # noqa: PLR0913, PLR0917
images: list[str] = await get_images_from_text(user_message) images: list[str] = await get_images_from_text(user_message)
result: AgentRunResult[str] = await chatgpt_agent.run( result: AgentRunResult[str] = await agent.run(
user_prompt=[ user_prompt=[
user_message, user_message,
*[ImageUrl(url=image_url) for image_url in images], *[ImageUrl(url=image_url) for image_url in images],
@ -853,62 +765,6 @@ async def ask(interaction: discord.Interaction, text: str, new_conversation: boo
await send_response(interaction=interaction, text=text, response=display_response) await send_response(interaction=interaction, text=text, response=display_response)
# MARK: /grok command
@client.tree.command(name="grok", description="Grok a question.")
@app_commands.allowed_installs(guilds=True, users=True)
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
@app_commands.describe(text="Grok a question.")
async def grok(interaction: discord.Interaction, text: str) -> None:
"""A command to ask the AI a question.
Args:
interaction (discord.Interaction): The interaction object.
text (str): The question or message to ask.
"""
await interaction.response.defer()
if not text:
logger.error("No question or message provided.")
await interaction.followup.send("You need to provide a question or message.", ephemeral=True)
return
user_name_lowercase: str = interaction.user.name.lower()
logger.info("Received command from: %s", user_name_lowercase)
# Only allow certain users to interact with the bot
allowed_users: list[str] = get_allowed_users()
if user_name_lowercase not in allowed_users:
await send_response(interaction=interaction, text=text, response="You are not authorized to use this command.")
return
# Get model response
try:
model_response: str | None = grok_it(message=interaction.message, user_message=text)
except openai.OpenAIError as e:
logger.exception("An error occurred while chatting with the AI model.")
await send_response(interaction=interaction, text=text, response=f"An error occurred: {e}")
return
truncated_text: str = truncate_user_input(text)
# Fallback if model provided no response
if not model_response:
logger.warning("No response from the AI model. Message: %s", text)
model_response = "I forgor how to think 💀"
display_response: str = f"`{truncated_text}`\n\n{model_response}"
logger.info("Responding to message: %s with: %s", text, display_response)
# If response is longer than 2000 characters, split it into multiple messages
max_discord_message_length: int = 2000
if len(display_response) > max_discord_message_length:
for i in range(0, len(display_response), max_discord_message_length):
await send_response(interaction=interaction, text=text, response=display_response[i : i + max_discord_message_length])
return
await send_response(interaction=interaction, text=text, response=display_response)
# MARK: /reset command # MARK: /reset command
@client.tree.command(name="reset", description="Reset the conversation memory.") @client.tree.command(name="reset", description="Reset the conversation memory.")
@app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_installs(guilds=True, users=True)
@ -933,33 +789,6 @@ async def reset(interaction: discord.Interaction) -> None:
await interaction.followup.send(f"Conversation memory has been reset for {interaction.channel}.") await interaction.followup.send(f"Conversation memory has been reset for {interaction.channel}.")
# MARK: /undo command
@client.tree.command(name="undo", description="Undo the last /reset command.")
@app_commands.allowed_installs(guilds=True, users=True)
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
async def undo(interaction: discord.Interaction) -> None:
"""A command to undo the last reset operation."""
await interaction.response.defer()
user_name_lowercase: str = interaction.user.name.lower()
logger.info("Received undo command from: %s", user_name_lowercase)
# Only allow certain users to interact with the bot
allowed_users: list[str] = get_allowed_users()
if user_name_lowercase not in allowed_users:
await send_response(interaction=interaction, text="", response="You are not authorized to use this command.")
return
# Undo the last reset
if interaction.channel is not None:
if undo_reset(str(interaction.channel.id)):
await interaction.followup.send(f"Successfully restored conversation memory for {interaction.channel}.")
else:
await interaction.followup.send(f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.")
else:
await interaction.followup.send("Cannot undo: No channel context available.")
# MARK: send_response # MARK: send_response
async def send_response(interaction: discord.Interaction, text: str, response: str) -> None: async def send_response(interaction: discord.Interaction, text: str, response: str) -> None:
"""Send a response to the interaction, handling potential errors. """Send a response to the interaction, handling potential errors.

View file

@ -65,7 +65,6 @@ docstring-code-line-length = 20
"ARG", # Unused function args -> fixtures nevertheless are functionally relevant... "ARG", # Unused function args -> fixtures nevertheless are functionally relevant...
"FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize() "FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize()
"PLR2004", # Magic value used in comparison, ... "PLR2004", # Magic value used in comparison, ...
"PLR6301", # Method could be a function, class method, or static method
"S101", # asserts allowed in tests... "S101", # asserts allowed in tests...
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
] ]
@ -77,9 +76,3 @@ log_cli_level = "INFO"
log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)" log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
log_cli_date_format = "%Y-%m-%d %H:%M:%S" log_cli_date_format = "%Y-%m-%d %H:%M:%S"
python_files = "test_*.py *_test.py *_tests.py" python_files = "test_*.py *_test.py *_tests.py"
[dependency-groups]
dev = [
"pytest>=9.0.1",
"ruff>=0.14.7",
]

View file

@ -1,201 +0,0 @@
from __future__ import annotations
import pytest
from main import (
add_message_to_memory,
last_trigger_time,
recent_messages,
reset_memory,
reset_snapshots,
undo_reset,
update_trigger_time,
)
@pytest.fixture(autouse=True)
def clear_state() -> None:
"""Clear all state before each test."""
recent_messages.clear()
last_trigger_time.clear()
reset_snapshots.clear()
class TestResetMemory:
"""Tests for the reset_memory function."""
def test_reset_memory_clears_messages(self) -> None:
"""Test that reset_memory clears messages for the channel."""
channel_id = "test_channel_123"
add_message_to_memory(channel_id, "user1", "Hello")
add_message_to_memory(channel_id, "user2", "World")
assert channel_id in recent_messages
assert len(recent_messages[channel_id]) == 2
reset_memory(channel_id)
assert channel_id not in recent_messages
def test_reset_memory_clears_trigger_times(self) -> None:
"""Test that reset_memory clears trigger times for the channel."""
channel_id = "test_channel_123"
update_trigger_time(channel_id, "user1")
assert channel_id in last_trigger_time
reset_memory(channel_id)
assert channel_id not in last_trigger_time
def test_reset_memory_creates_snapshot(self) -> None:
"""Test that reset_memory creates a snapshot for undo."""
channel_id = "test_channel_123"
add_message_to_memory(channel_id, "user1", "Test message")
update_trigger_time(channel_id, "user1")
reset_memory(channel_id)
assert channel_id in reset_snapshots
messages_snapshot, trigger_snapshot = reset_snapshots[channel_id]
assert len(messages_snapshot) == 1
assert "user1" in trigger_snapshot
def test_reset_memory_no_snapshot_for_empty_channel(self) -> None:
"""Test that reset_memory doesn't create snapshot for empty channel."""
channel_id = "empty_channel"
reset_memory(channel_id)
assert channel_id not in reset_snapshots
class TestUndoReset:
"""Tests for the undo_reset function."""
def test_undo_reset_restores_messages(self) -> None:
"""Test that undo_reset restores messages."""
channel_id = "test_channel_123"
add_message_to_memory(channel_id, "user1", "Hello")
add_message_to_memory(channel_id, "user2", "World")
reset_memory(channel_id)
assert channel_id not in recent_messages
result = undo_reset(channel_id)
assert result is True
assert channel_id in recent_messages
assert len(recent_messages[channel_id]) == 2
def test_undo_reset_restores_trigger_times(self) -> None:
"""Test that undo_reset restores trigger times."""
channel_id = "test_channel_123"
update_trigger_time(channel_id, "user1")
original_time = last_trigger_time[channel_id]["user1"]
reset_memory(channel_id)
assert channel_id not in last_trigger_time
result = undo_reset(channel_id)
assert result is True
assert channel_id in last_trigger_time
assert last_trigger_time[channel_id]["user1"] == original_time
def test_undo_reset_removes_snapshot(self) -> None:
"""Test that undo_reset removes the snapshot after restoring."""
channel_id = "test_channel_123"
add_message_to_memory(channel_id, "user1", "Hello")
reset_memory(channel_id)
assert channel_id in reset_snapshots
undo_reset(channel_id)
assert channel_id not in reset_snapshots
def test_undo_reset_returns_false_when_no_snapshot(self) -> None:
"""Test that undo_reset returns False when no snapshot exists."""
channel_id = "nonexistent_channel"
result = undo_reset(channel_id)
assert result is False
def test_undo_reset_only_works_once(self) -> None:
"""Test that undo_reset only works once (snapshot is removed after undo)."""
channel_id = "test_channel_123"
add_message_to_memory(channel_id, "user1", "Hello")
reset_memory(channel_id)
first_undo = undo_reset(channel_id)
second_undo = undo_reset(channel_id)
assert first_undo is True
assert second_undo is False
class TestResetUndoIntegration:
"""Integration tests for reset and undo functionality."""
def test_reset_then_undo_preserves_content(self) -> None:
"""Test that reset followed by undo preserves original content."""
channel_id = "test_channel_123"
add_message_to_memory(channel_id, "user1", "Message 1")
add_message_to_memory(channel_id, "user2", "Message 2")
add_message_to_memory(channel_id, "user3", "Message 3")
update_trigger_time(channel_id, "user1")
update_trigger_time(channel_id, "user2")
# Capture original state
original_messages = list(recent_messages[channel_id])
original_trigger_users = set(last_trigger_time[channel_id].keys())
reset_memory(channel_id)
undo_reset(channel_id)
# Verify restored state matches original
restored_messages = list(recent_messages[channel_id])
restored_trigger_users = set(last_trigger_time[channel_id].keys())
assert len(restored_messages) == len(original_messages)
assert restored_trigger_users == original_trigger_users
def test_multiple_resets_overwrite_snapshot(self) -> None:
"""Test that multiple resets overwrite the previous snapshot."""
channel_id = "test_channel_123"
# First set of messages
add_message_to_memory(channel_id, "user1", "First message")
reset_memory(channel_id)
# Second set of messages
add_message_to_memory(channel_id, "user1", "Second message")
add_message_to_memory(channel_id, "user1", "Third message")
reset_memory(channel_id)
# Undo should restore the second set, not the first
undo_reset(channel_id)
assert channel_id in recent_messages
assert len(recent_messages[channel_id]) == 2
def test_different_channels_independent_undo(self) -> None:
"""Test that different channels have independent undo functionality."""
channel_1 = "channel_1"
channel_2 = "channel_2"
add_message_to_memory(channel_1, "user1", "Channel 1 message")
add_message_to_memory(channel_2, "user2", "Channel 2 message")
reset_memory(channel_1)
reset_memory(channel_2)
# Undo only channel 1
undo_reset(channel_1)
assert channel_1 in recent_messages
assert channel_2 not in recent_messages
assert channel_1 not in reset_snapshots
assert channel_2 in reset_snapshots