Compare commits
10 commits
1e515a8394
...
80e0637e8a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
80e0637e8a |
||
|
|
faa77c38f6 |
||
|
|
ec325ed178 |
||
|
|
0dd877c227 |
||
|
71f29c3467 |
|||
|
|
dcfc76bdc9 |
||
|
|
5695722ad2 |
||
|
9738c37aba |
|||
| 350af2a3a9 | |||
| 10408c2fa7 |
7 changed files with 509 additions and 18 deletions
|
|
@ -1,3 +1,4 @@
|
|||
DISCORD_TOKEN=
|
||||
OPENAI_TOKEN=
|
||||
OLLAMA_API_KEY=
|
||||
OLLAMA_API_KEY=
|
||||
OPENROUTER_API_KEY=
|
||||
108
.github/copilot-instructions.md
vendored
Normal file
108
.github/copilot-instructions.md
vendored
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
# Copilot Instructions for ANewDawn
|
||||
|
||||
## Project Overview
|
||||
|
||||
ANewDawn is a Discord bot written in Python 3.13+ using the discord.py library and Pydantic AI for AI-powered chat capabilities. The bot includes features such as:
|
||||
|
||||
- AI-powered chat responses using OpenAI and Grok models
|
||||
- Conversation memory with reset/undo functionality
|
||||
- Image enhancement using OpenCV
|
||||
- Web search integration via Ollama
|
||||
- Slash commands and context menus
|
||||
|
||||
## Development Environment
|
||||
|
||||
- **Python**: 3.13 or higher required
|
||||
- **Package Manager**: Use `uv` for dependency management (see `pyproject.toml`)
|
||||
- **Docker**: The project uses Docker for deployment (see `Dockerfile` and `docker-compose.yml`)
|
||||
- **Environment Variables**: Copy `.env.example` to `.env` and fill in required tokens
|
||||
|
||||
## Code Style and Conventions
|
||||
|
||||
### Linting and Formatting
|
||||
|
||||
This project uses **Ruff** for linting and formatting with strict settings:
|
||||
|
||||
- All rules enabled (`lint.select = ["ALL"]`)
|
||||
- Preview features enabled
|
||||
- Auto-fix enabled
|
||||
- Line length: 160 characters
|
||||
- Google-style docstrings required
|
||||
|
||||
Run linting:
|
||||
```bash
|
||||
ruff check --exit-non-zero-on-fix --verbose
|
||||
```
|
||||
|
||||
Run formatting check:
|
||||
```bash
|
||||
ruff format --check --verbose
|
||||
```
|
||||
|
||||
### Python Conventions
|
||||
|
||||
- Use `from __future__ import annotations` at the top of all files (automatically added by Ruff)
|
||||
- Use type hints for all function parameters and return types
|
||||
- Follow Google docstring convention
|
||||
- Use `logging` module for logging, not print statements
|
||||
- Prefer explicit imports over wildcard imports
|
||||
|
||||
### Testing
|
||||
|
||||
- Tests use pytest
|
||||
- Test files should be named `*_test.py` or `test_*.py`
|
||||
- Run tests with: `pytest`
|
||||
|
||||
## Project Structure
|
||||
|
||||
- `main.py` - Main bot application with all commands and event handlers
|
||||
- `pyproject.toml` - Project configuration and dependencies
|
||||
- `Dockerfile` / `docker-compose.yml` - Container configuration
|
||||
- `.github/workflows/` - CI/CD workflows
|
||||
|
||||
## Key Components
|
||||
|
||||
### Bot Client
|
||||
|
||||
The main bot client is `LoviBotClient` which extends `discord.Client`. It handles:
|
||||
- Message events (`on_message`)
|
||||
- Slash commands (`/ask`, `/grok`, `/reset`, `/undo`)
|
||||
- Context menus (image enhancement)
|
||||
|
||||
### AI Integration
|
||||
|
||||
- `chatgpt_agent` - Pydantic AI agent using OpenAI
|
||||
- `grok_it()` - Function for Grok model responses
|
||||
- Message history is stored in `recent_messages` dict per channel
|
||||
|
||||
### Memory Management
|
||||
|
||||
- `add_message_to_memory()` - Store messages for context
|
||||
- `reset_memory()` - Clear conversation history
|
||||
- `undo_reset()` - Restore previous state
|
||||
|
||||
## CI/CD
|
||||
|
||||
The GitHub Actions workflow (`.github/workflows/docker-publish.yml`) runs:
|
||||
1. Ruff linting and format check
|
||||
2. Dockerfile validation
|
||||
3. Docker image build and push to GitHub Container Registry
|
||||
|
||||
## Common Tasks
|
||||
|
||||
### Adding a New Slash Command
|
||||
|
||||
1. Add the command function with `@client.tree.command()` decorator
|
||||
2. Include `@app_commands.allowed_installs()` and `@app_commands.allowed_contexts()` decorators
|
||||
3. Use `await interaction.response.defer()` for long-running operations
|
||||
4. Check user authorization with `get_allowed_users()`
|
||||
|
||||
### Adding a New AI Instruction
|
||||
|
||||
1. Create a function decorated with `@chatgpt_agent.instructions`
|
||||
2. The function should return a string with the instruction content
|
||||
3. Use `RunContext[BotDependencies]` parameter to access dependencies
|
||||
|
||||
### Modifying Image Enhancement
|
||||
|
||||
Image enhancement functions (`enhance_image1`, `enhance_image2`, `enhance_image3`) use OpenCV. Each returns WebP-encoded bytes.
|
||||
14
.github/workflows/docker-publish.yml
vendored
14
.github/workflows/docker-publish.yml
vendored
|
|
@ -13,7 +13,7 @@ jobs:
|
|||
OPENAI_TOKEN: "0"
|
||||
steps:
|
||||
# GitHub Container Registry
|
||||
- uses: docker/login-action@v3
|
||||
- uses: docker/login-action@v4
|
||||
if: github.event_name != 'pull_request'
|
||||
with:
|
||||
registry: ghcr.io
|
||||
|
|
@ -21,7 +21,7 @@ jobs:
|
|||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
# Download the latest commit from the master branch
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
# Install the latest version of ruff
|
||||
- uses: astral-sh/ruff-action@v3
|
||||
|
|
@ -39,15 +39,17 @@ jobs:
|
|||
|
||||
# Extract metadata (tags, labels) from Git reference and GitHub events for Docker
|
||||
- id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
uses: docker/metadata-action@v6
|
||||
if: github.ref == 'refs/heads/master'
|
||||
with:
|
||||
images: ghcr.io/thelovinator1/anewdawn
|
||||
tags: type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', 'master') }}
|
||||
tags: type=raw,value=latest
|
||||
|
||||
# Build and push the Docker image
|
||||
- uses: docker/build-push-action@v6
|
||||
- uses: docker/build-push-action@v7
|
||||
if: github.event_name != 'pull_request' && github.ref == 'refs/heads/master'
|
||||
with:
|
||||
context: .
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
push: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
|
|
|
|||
1
.vscode/settings.json
vendored
1
.vscode/settings.json
vendored
|
|
@ -37,6 +37,7 @@
|
|||
"numpy",
|
||||
"Ollama",
|
||||
"opencv",
|
||||
"OPENROUTER",
|
||||
"percpu",
|
||||
"phibiscarf",
|
||||
"plubplub",
|
||||
|
|
|
|||
193
main.py
193
main.py
|
|
@ -36,6 +36,7 @@ if TYPE_CHECKING:
|
|||
from discord.abc import MessageableChannel
|
||||
from discord.guild import GuildChannel
|
||||
from discord.interactions import InteractionChannel
|
||||
from openai.types.chat import ChatCompletion
|
||||
from pydantic_ai.run import AgentRunResult
|
||||
|
||||
load_dotenv(verbose=True)
|
||||
|
|
@ -56,6 +57,10 @@ os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_TOKEN", "")
|
|||
recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {}
|
||||
last_trigger_time: dict[str, dict[str, datetime.datetime]] = {}
|
||||
|
||||
# Storage for reset snapshots to enable undo functionality
|
||||
# Each channel stores its previous state: (recent_messages_snapshot, last_trigger_time_snapshot)
|
||||
reset_snapshots: dict[str, tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]]] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotDependencies:
|
||||
|
|
@ -72,20 +77,68 @@ class BotDependencies:
|
|||
openai_settings = OpenAIResponsesModelSettings(
|
||||
openai_text_verbosity="low",
|
||||
)
|
||||
agent: Agent[BotDependencies, str] = Agent(
|
||||
chatgpt_agent: Agent[BotDependencies, str] = Agent(
|
||||
model="gpt-5-chat-latest",
|
||||
deps_type=BotDependencies,
|
||||
model_settings=openai_settings,
|
||||
)
|
||||
grok_client = openai.OpenAI(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY"),
|
||||
)
|
||||
|
||||
|
||||
def grok_it(
|
||||
message: discord.Message | None,
|
||||
user_message: str,
|
||||
) -> str | None:
|
||||
"""Chat with the bot using the Pydantic AI agent.
|
||||
|
||||
Args:
|
||||
user_message: The message from the user.
|
||||
message: The original Discord message object.
|
||||
|
||||
Returns:
|
||||
The bot's response as a string, or None if no response.
|
||||
"""
|
||||
allowed_users: list[str] = get_allowed_users()
|
||||
if message and message.author.name not in allowed_users:
|
||||
return None
|
||||
|
||||
response: ChatCompletion = grok_client.chat.completions.create(
|
||||
model="x-ai/grok-4-fast:free",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_message,
|
||||
},
|
||||
],
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
|
||||
# MARK: reset_memory
|
||||
def reset_memory(channel_id: str) -> None:
|
||||
"""Reset the conversation memory for a specific channel.
|
||||
|
||||
Creates a snapshot of the current state before resetting to enable undo.
|
||||
|
||||
Args:
|
||||
channel_id (str): The ID of the channel to reset memory for.
|
||||
"""
|
||||
# Create snapshot before reset for undo functionality
|
||||
messages_snapshot: deque[tuple[str, str, datetime.datetime]] = (
|
||||
deque(recent_messages[channel_id], maxlen=50) if channel_id in recent_messages else deque(maxlen=50)
|
||||
)
|
||||
|
||||
trigger_snapshot: dict[str, datetime.datetime] = dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {}
|
||||
|
||||
# Only save snapshot if there's something to restore
|
||||
if messages_snapshot or trigger_snapshot:
|
||||
reset_snapshots[channel_id] = (messages_snapshot, trigger_snapshot)
|
||||
logger.info("Created reset snapshot for channel %s", channel_id)
|
||||
|
||||
# Perform the actual reset
|
||||
if channel_id in recent_messages:
|
||||
del recent_messages[channel_id]
|
||||
logger.info("Reset memory for channel %s", channel_id)
|
||||
|
|
@ -94,6 +147,41 @@ def reset_memory(channel_id: str) -> None:
|
|||
logger.info("Reset trigger times for channel %s", channel_id)
|
||||
|
||||
|
||||
# MARK: undo_reset
|
||||
def undo_reset(channel_id: str) -> bool:
|
||||
"""Undo the last reset operation for a specific channel.
|
||||
|
||||
Restores the conversation memory from the saved snapshot.
|
||||
|
||||
Args:
|
||||
channel_id (str): The ID of the channel to undo reset for.
|
||||
|
||||
Returns:
|
||||
bool: True if undo was successful, False if no snapshot exists.
|
||||
"""
|
||||
if channel_id not in reset_snapshots:
|
||||
logger.info("No reset snapshot found for channel %s", channel_id)
|
||||
return False
|
||||
|
||||
messages_snapshot, trigger_snapshot = reset_snapshots[channel_id]
|
||||
|
||||
# Restore recent messages
|
||||
if messages_snapshot:
|
||||
recent_messages[channel_id] = messages_snapshot
|
||||
logger.info("Restored messages for channel %s", channel_id)
|
||||
|
||||
# Restore trigger times
|
||||
if trigger_snapshot:
|
||||
last_trigger_time[channel_id] = trigger_snapshot
|
||||
logger.info("Restored trigger times for channel %s", channel_id)
|
||||
|
||||
# Remove the snapshot after successful undo (only one undo allowed)
|
||||
del reset_snapshots[channel_id]
|
||||
logger.info("Removed reset snapshot for channel %s after undo", channel_id)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _message_text_length(msg: ModelRequest | ModelResponse) -> int:
|
||||
"""Compute the total text length of all text parts in a message.
|
||||
|
||||
|
|
@ -143,7 +231,7 @@ def compact_message_history(
|
|||
|
||||
|
||||
# MARK: fetch_user_info
|
||||
@agent.instructions
|
||||
@chatgpt_agent.instructions
|
||||
def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
|
||||
"""Fetches detailed information about the user who sent the message, including their roles, status, and activity.
|
||||
|
||||
|
|
@ -164,7 +252,7 @@ def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
|
|||
|
||||
|
||||
# MARK: get_system_performance_stats
|
||||
@agent.instructions
|
||||
@chatgpt_agent.instructions
|
||||
def get_system_performance_stats() -> str:
|
||||
"""Retrieves current system performance metrics, including CPU, memory, and disk usage.
|
||||
|
||||
|
|
@ -181,7 +269,7 @@ def get_system_performance_stats() -> str:
|
|||
|
||||
|
||||
# MARK: get_channels
|
||||
@agent.instructions
|
||||
@chatgpt_agent.instructions
|
||||
def get_channels(ctx: RunContext[BotDependencies]) -> str:
|
||||
"""Retrieves a list of all channels the bot is currently in.
|
||||
|
||||
|
|
@ -220,7 +308,7 @@ def do_web_search(query: str) -> ollama.WebSearchResponse | None:
|
|||
|
||||
|
||||
# MARK: get_time_and_timezone
|
||||
@agent.instructions
|
||||
@chatgpt_agent.instructions
|
||||
def get_time_and_timezone() -> str:
|
||||
"""Retrieves the current time and timezone information.
|
||||
|
||||
|
|
@ -232,7 +320,7 @@ def get_time_and_timezone() -> str:
|
|||
|
||||
|
||||
# MARK: get_latency
|
||||
@agent.instructions
|
||||
@chatgpt_agent.instructions
|
||||
def get_latency(ctx: RunContext[BotDependencies]) -> str:
|
||||
"""Retrieves the current latency information.
|
||||
|
||||
|
|
@ -244,7 +332,7 @@ def get_latency(ctx: RunContext[BotDependencies]) -> str:
|
|||
|
||||
|
||||
# MARK: added_information_from_web_search
|
||||
@agent.instructions
|
||||
@chatgpt_agent.instructions
|
||||
def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str:
|
||||
"""Adds information from a web search to the system prompt.
|
||||
|
||||
|
|
@ -262,7 +350,7 @@ def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str:
|
|||
|
||||
|
||||
# MARK: get_sticker_instructions
|
||||
@agent.instructions
|
||||
@chatgpt_agent.instructions
|
||||
def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str:
|
||||
"""Provides instructions for using stickers in the chat.
|
||||
|
||||
|
|
@ -291,7 +379,7 @@ def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str:
|
|||
|
||||
|
||||
# MARK: get_emoji_instructions
|
||||
@agent.instructions
|
||||
@chatgpt_agent.instructions
|
||||
def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
|
||||
"""Provides instructions for using emojis in the chat.
|
||||
|
||||
|
|
@ -340,7 +428,7 @@ def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
|
|||
|
||||
|
||||
# MARK: get_system_prompt
|
||||
@agent.instructions
|
||||
@chatgpt_agent.instructions
|
||||
def get_system_prompt() -> str:
|
||||
"""Generate the core system prompt.
|
||||
|
||||
|
|
@ -410,7 +498,7 @@ async def chat( # noqa: PLR0913, PLR0917
|
|||
|
||||
images: list[str] = await get_images_from_text(user_message)
|
||||
|
||||
result: AgentRunResult[str] = await agent.run(
|
||||
result: AgentRunResult[str] = await chatgpt_agent.run(
|
||||
user_prompt=[
|
||||
user_message,
|
||||
*[ImageUrl(url=image_url) for image_url in images],
|
||||
|
|
@ -765,6 +853,62 @@ async def ask(interaction: discord.Interaction, text: str, new_conversation: boo
|
|||
await send_response(interaction=interaction, text=text, response=display_response)
|
||||
|
||||
|
||||
# MARK: /grok command
|
||||
@client.tree.command(name="grok", description="Grok a question.")
|
||||
@app_commands.allowed_installs(guilds=True, users=True)
|
||||
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
||||
@app_commands.describe(text="Grok a question.")
|
||||
async def grok(interaction: discord.Interaction, text: str) -> None:
|
||||
"""A command to ask the AI a question.
|
||||
|
||||
Args:
|
||||
interaction (discord.Interaction): The interaction object.
|
||||
text (str): The question or message to ask.
|
||||
"""
|
||||
await interaction.response.defer()
|
||||
|
||||
if not text:
|
||||
logger.error("No question or message provided.")
|
||||
await interaction.followup.send("You need to provide a question or message.", ephemeral=True)
|
||||
return
|
||||
|
||||
user_name_lowercase: str = interaction.user.name.lower()
|
||||
logger.info("Received command from: %s", user_name_lowercase)
|
||||
|
||||
# Only allow certain users to interact with the bot
|
||||
allowed_users: list[str] = get_allowed_users()
|
||||
if user_name_lowercase not in allowed_users:
|
||||
await send_response(interaction=interaction, text=text, response="You are not authorized to use this command.")
|
||||
return
|
||||
|
||||
# Get model response
|
||||
try:
|
||||
model_response: str | None = grok_it(message=interaction.message, user_message=text)
|
||||
except openai.OpenAIError as e:
|
||||
logger.exception("An error occurred while chatting with the AI model.")
|
||||
await send_response(interaction=interaction, text=text, response=f"An error occurred: {e}")
|
||||
return
|
||||
|
||||
truncated_text: str = truncate_user_input(text)
|
||||
|
||||
# Fallback if model provided no response
|
||||
if not model_response:
|
||||
logger.warning("No response from the AI model. Message: %s", text)
|
||||
model_response = "I forgor how to think 💀"
|
||||
|
||||
display_response: str = f"`{truncated_text}`\n\n{model_response}"
|
||||
logger.info("Responding to message: %s with: %s", text, display_response)
|
||||
|
||||
# If response is longer than 2000 characters, split it into multiple messages
|
||||
max_discord_message_length: int = 2000
|
||||
if len(display_response) > max_discord_message_length:
|
||||
for i in range(0, len(display_response), max_discord_message_length):
|
||||
await send_response(interaction=interaction, text=text, response=display_response[i : i + max_discord_message_length])
|
||||
return
|
||||
|
||||
await send_response(interaction=interaction, text=text, response=display_response)
|
||||
|
||||
|
||||
# MARK: /reset command
|
||||
@client.tree.command(name="reset", description="Reset the conversation memory.")
|
||||
@app_commands.allowed_installs(guilds=True, users=True)
|
||||
|
|
@ -789,6 +933,33 @@ async def reset(interaction: discord.Interaction) -> None:
|
|||
await interaction.followup.send(f"Conversation memory has been reset for {interaction.channel}.")
|
||||
|
||||
|
||||
# MARK: /undo command
|
||||
@client.tree.command(name="undo", description="Undo the last /reset command.")
|
||||
@app_commands.allowed_installs(guilds=True, users=True)
|
||||
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
||||
async def undo(interaction: discord.Interaction) -> None:
|
||||
"""A command to undo the last reset operation."""
|
||||
await interaction.response.defer()
|
||||
|
||||
user_name_lowercase: str = interaction.user.name.lower()
|
||||
logger.info("Received undo command from: %s", user_name_lowercase)
|
||||
|
||||
# Only allow certain users to interact with the bot
|
||||
allowed_users: list[str] = get_allowed_users()
|
||||
if user_name_lowercase not in allowed_users:
|
||||
await send_response(interaction=interaction, text="", response="You are not authorized to use this command.")
|
||||
return
|
||||
|
||||
# Undo the last reset
|
||||
if interaction.channel is not None:
|
||||
if undo_reset(str(interaction.channel.id)):
|
||||
await interaction.followup.send(f"Successfully restored conversation memory for {interaction.channel}.")
|
||||
else:
|
||||
await interaction.followup.send(f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.")
|
||||
else:
|
||||
await interaction.followup.send("Cannot undo: No channel context available.")
|
||||
|
||||
|
||||
# MARK: send_response
|
||||
async def send_response(interaction: discord.Interaction, text: str, response: str) -> None:
|
||||
"""Send a response to the interaction, handling potential errors.
|
||||
|
|
|
|||
|
|
@ -65,6 +65,7 @@ docstring-code-line-length = 20
|
|||
"ARG", # Unused function args -> fixtures nevertheless are functionally relevant...
|
||||
"FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize()
|
||||
"PLR2004", # Magic value used in comparison, ...
|
||||
"PLR6301", # Method could be a function, class method, or static method
|
||||
"S101", # asserts allowed in tests...
|
||||
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
|
||||
]
|
||||
|
|
@ -76,3 +77,9 @@ log_cli_level = "INFO"
|
|||
log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
|
||||
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
|
||||
python_files = "test_*.py *_test.py *_tests.py"
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=9.0.1",
|
||||
"ruff>=0.14.7",
|
||||
]
|
||||
|
|
|
|||
201
reset_undo_test.py
Normal file
201
reset_undo_test.py
Normal file
|
|
@ -0,0 +1,201 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from main import (
|
||||
add_message_to_memory,
|
||||
last_trigger_time,
|
||||
recent_messages,
|
||||
reset_memory,
|
||||
reset_snapshots,
|
||||
undo_reset,
|
||||
update_trigger_time,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_state() -> None:
|
||||
"""Clear all state before each test."""
|
||||
recent_messages.clear()
|
||||
last_trigger_time.clear()
|
||||
reset_snapshots.clear()
|
||||
|
||||
|
||||
class TestResetMemory:
|
||||
"""Tests for the reset_memory function."""
|
||||
|
||||
def test_reset_memory_clears_messages(self) -> None:
|
||||
"""Test that reset_memory clears messages for the channel."""
|
||||
channel_id = "test_channel_123"
|
||||
add_message_to_memory(channel_id, "user1", "Hello")
|
||||
add_message_to_memory(channel_id, "user2", "World")
|
||||
|
||||
assert channel_id in recent_messages
|
||||
assert len(recent_messages[channel_id]) == 2
|
||||
|
||||
reset_memory(channel_id)
|
||||
|
||||
assert channel_id not in recent_messages
|
||||
|
||||
def test_reset_memory_clears_trigger_times(self) -> None:
|
||||
"""Test that reset_memory clears trigger times for the channel."""
|
||||
channel_id = "test_channel_123"
|
||||
update_trigger_time(channel_id, "user1")
|
||||
|
||||
assert channel_id in last_trigger_time
|
||||
|
||||
reset_memory(channel_id)
|
||||
|
||||
assert channel_id not in last_trigger_time
|
||||
|
||||
def test_reset_memory_creates_snapshot(self) -> None:
|
||||
"""Test that reset_memory creates a snapshot for undo."""
|
||||
channel_id = "test_channel_123"
|
||||
add_message_to_memory(channel_id, "user1", "Test message")
|
||||
update_trigger_time(channel_id, "user1")
|
||||
|
||||
reset_memory(channel_id)
|
||||
|
||||
assert channel_id in reset_snapshots
|
||||
messages_snapshot, trigger_snapshot = reset_snapshots[channel_id]
|
||||
assert len(messages_snapshot) == 1
|
||||
assert "user1" in trigger_snapshot
|
||||
|
||||
def test_reset_memory_no_snapshot_for_empty_channel(self) -> None:
|
||||
"""Test that reset_memory doesn't create snapshot for empty channel."""
|
||||
channel_id = "empty_channel"
|
||||
|
||||
reset_memory(channel_id)
|
||||
|
||||
assert channel_id not in reset_snapshots
|
||||
|
||||
|
||||
class TestUndoReset:
|
||||
"""Tests for the undo_reset function."""
|
||||
|
||||
def test_undo_reset_restores_messages(self) -> None:
|
||||
"""Test that undo_reset restores messages."""
|
||||
channel_id = "test_channel_123"
|
||||
add_message_to_memory(channel_id, "user1", "Hello")
|
||||
add_message_to_memory(channel_id, "user2", "World")
|
||||
|
||||
reset_memory(channel_id)
|
||||
assert channel_id not in recent_messages
|
||||
|
||||
result = undo_reset(channel_id)
|
||||
|
||||
assert result is True
|
||||
assert channel_id in recent_messages
|
||||
assert len(recent_messages[channel_id]) == 2
|
||||
|
||||
def test_undo_reset_restores_trigger_times(self) -> None:
|
||||
"""Test that undo_reset restores trigger times."""
|
||||
channel_id = "test_channel_123"
|
||||
update_trigger_time(channel_id, "user1")
|
||||
original_time = last_trigger_time[channel_id]["user1"]
|
||||
|
||||
reset_memory(channel_id)
|
||||
assert channel_id not in last_trigger_time
|
||||
|
||||
result = undo_reset(channel_id)
|
||||
|
||||
assert result is True
|
||||
assert channel_id in last_trigger_time
|
||||
assert last_trigger_time[channel_id]["user1"] == original_time
|
||||
|
||||
def test_undo_reset_removes_snapshot(self) -> None:
|
||||
"""Test that undo_reset removes the snapshot after restoring."""
|
||||
channel_id = "test_channel_123"
|
||||
add_message_to_memory(channel_id, "user1", "Hello")
|
||||
|
||||
reset_memory(channel_id)
|
||||
assert channel_id in reset_snapshots
|
||||
|
||||
undo_reset(channel_id)
|
||||
|
||||
assert channel_id not in reset_snapshots
|
||||
|
||||
def test_undo_reset_returns_false_when_no_snapshot(self) -> None:
|
||||
"""Test that undo_reset returns False when no snapshot exists."""
|
||||
channel_id = "nonexistent_channel"
|
||||
|
||||
result = undo_reset(channel_id)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_undo_reset_only_works_once(self) -> None:
|
||||
"""Test that undo_reset only works once (snapshot is removed after undo)."""
|
||||
channel_id = "test_channel_123"
|
||||
add_message_to_memory(channel_id, "user1", "Hello")
|
||||
|
||||
reset_memory(channel_id)
|
||||
first_undo = undo_reset(channel_id)
|
||||
second_undo = undo_reset(channel_id)
|
||||
|
||||
assert first_undo is True
|
||||
assert second_undo is False
|
||||
|
||||
|
||||
class TestResetUndoIntegration:
|
||||
"""Integration tests for reset and undo functionality."""
|
||||
|
||||
def test_reset_then_undo_preserves_content(self) -> None:
|
||||
"""Test that reset followed by undo preserves original content."""
|
||||
channel_id = "test_channel_123"
|
||||
add_message_to_memory(channel_id, "user1", "Message 1")
|
||||
add_message_to_memory(channel_id, "user2", "Message 2")
|
||||
add_message_to_memory(channel_id, "user3", "Message 3")
|
||||
update_trigger_time(channel_id, "user1")
|
||||
update_trigger_time(channel_id, "user2")
|
||||
|
||||
# Capture original state
|
||||
original_messages = list(recent_messages[channel_id])
|
||||
original_trigger_users = set(last_trigger_time[channel_id].keys())
|
||||
|
||||
reset_memory(channel_id)
|
||||
undo_reset(channel_id)
|
||||
|
||||
# Verify restored state matches original
|
||||
restored_messages = list(recent_messages[channel_id])
|
||||
restored_trigger_users = set(last_trigger_time[channel_id].keys())
|
||||
|
||||
assert len(restored_messages) == len(original_messages)
|
||||
assert restored_trigger_users == original_trigger_users
|
||||
|
||||
def test_multiple_resets_overwrite_snapshot(self) -> None:
|
||||
"""Test that multiple resets overwrite the previous snapshot."""
|
||||
channel_id = "test_channel_123"
|
||||
|
||||
# First set of messages
|
||||
add_message_to_memory(channel_id, "user1", "First message")
|
||||
reset_memory(channel_id)
|
||||
|
||||
# Second set of messages
|
||||
add_message_to_memory(channel_id, "user1", "Second message")
|
||||
add_message_to_memory(channel_id, "user1", "Third message")
|
||||
reset_memory(channel_id)
|
||||
|
||||
# Undo should restore the second set, not the first
|
||||
undo_reset(channel_id)
|
||||
|
||||
assert channel_id in recent_messages
|
||||
assert len(recent_messages[channel_id]) == 2
|
||||
|
||||
def test_different_channels_independent_undo(self) -> None:
|
||||
"""Test that different channels have independent undo functionality."""
|
||||
channel_1 = "channel_1"
|
||||
channel_2 = "channel_2"
|
||||
|
||||
add_message_to_memory(channel_1, "user1", "Channel 1 message")
|
||||
add_message_to_memory(channel_2, "user2", "Channel 2 message")
|
||||
|
||||
reset_memory(channel_1)
|
||||
reset_memory(channel_2)
|
||||
|
||||
# Undo only channel 1
|
||||
undo_reset(channel_1)
|
||||
|
||||
assert channel_1 in recent_messages
|
||||
assert channel_2 not in recent_messages
|
||||
assert channel_1 not in reset_snapshots
|
||||
assert channel_2 in reset_snapshots
|
||||
Loading…
Add table
Add a link
Reference in a new issue