diff --git a/.env.example b/.env.example
index aae1f64..5fb16cb 100644
--- a/.env.example
+++ b/.env.example
@@ -1,2 +1,3 @@
DISCORD_TOKEN=
OPENAI_TOKEN=
+OLLAMA_API_KEY=
diff --git a/.gitea/workflows/docker-check.yml b/.gitea/workflows/docker-check.yml
deleted file mode 100644
index ff43f68..0000000
--- a/.gitea/workflows/docker-check.yml
+++ /dev/null
@@ -1,19 +0,0 @@
-name: Docker Build Check
-
-on:
- push:
- paths:
- - 'Dockerfile'
- pull_request:
- paths:
- - 'Dockerfile'
-
-jobs:
- docker-check:
- runs-on: ubuntu-latest
- steps:
- - name: Checkout code
- uses: actions/checkout@v4
-
- - name: Run Docker Build Check
- run: docker build --check .
diff --git a/.gitea/workflows/docker-publish.yml b/.gitea/workflows/docker-publish.yml
deleted file mode 100644
index c4aca97..0000000
--- a/.gitea/workflows/docker-publish.yml
+++ /dev/null
@@ -1,67 +0,0 @@
-name: Build Docker Image
-
-on:
- push:
- branches:
- - master
- pull_request:
- workflow_dispatch:
- schedule:
- - cron: "@daily"
-
-cache:
- enabled: true
- dir: ""
- host: "192.168.1.127"
- port: 8088
-
-jobs:
- docker:
- runs-on: ubuntu-latest
- env:
- DISCORD_TOKEN: "0"
- OPENAI_TOKEN: "0"
- if: gitea.event_name != 'pull_request'
- steps:
- - uses: https://github.com/actions/checkout@v4
- - uses: https://github.com/docker/setup-qemu-action@v3
- - uses: https://github.com/docker/setup-buildx-action@v3
- - uses: https://github.com/astral-sh/ruff-action@v3
-
- - run: docker build --check .
- - run: ruff check --exit-non-zero-on-fix --verbose
- - run: ruff format --check --verbose
-
- - id: meta
- uses: https://github.com/docker/metadata-action@v5
- env:
- DOCKER_METADATA_ANNOTATIONS_LEVELS: manifest,index
- with:
- images: |
- ghcr.io/thelovinator1/anewdawn
- git.lovinator.space/thelovinator/anewdawn
- tags: type=raw,value=latest,enable=${{ gitea.ref == format('refs/heads/{0}', 'master') }}
-
- # GitHub Container Registry
- - uses: https://github.com/docker/login-action@v3
- if: github.event_name != 'pull_request'
- with:
- registry: ghcr.io
- username: thelovinator1
- password: ${{ secrets.PACKAGES_WRITE_GITHUB_TOKEN }}
-
- # Gitea Container Registry
- - uses: https://github.com/docker/login-action@v3
- if: github.event_name != 'pull_request'
- with:
- registry: git.lovinator.space
- username: thelovinator
- password: ${{ secrets.GITEA_TOKEN }}
-
- - uses: https://github.com/docker/build-push-action@v6
- with:
- context: .
- push: ${{ gitea.event_name != 'pull_request' }}
- labels: ${{ steps.meta.outputs.labels }}
- tags: ${{ steps.meta.outputs.tags }}
- annotations: ${{ steps.meta.outputs.annotations }}
diff --git a/.gitea/workflows/ruff.yml b/.gitea/workflows/ruff.yml
deleted file mode 100644
index 28b5029..0000000
--- a/.gitea/workflows/ruff.yml
+++ /dev/null
@@ -1,19 +0,0 @@
-name: Ruff
-
-on:
- push:
- pull_request:
- workflow_dispatch:
- schedule:
- - cron: '0 0 * * *' # Run every day at midnight
-
-env:
- RUFF_OUTPUT_FORMAT: github
-jobs:
- ruff:
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v4
- - uses: astral-sh/ruff-action@v3
- - run: ruff check --exit-non-zero-on-fix --verbose
- - run: ruff format --check --verbose
diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md
index 7921bbe..e43899f 100644
--- a/.github/copilot-instructions.md
+++ b/.github/copilot-instructions.md
@@ -1,36 +1,107 @@
-# Custom Instructions for GitHub Copilot
+# Copilot Instructions for ANewDawn
## Project Overview
-This is a Python project named ANewDawn. It uses Docker for containerization (`Dockerfile`, `docker-compose.yml`). Key files include `main.py` and `settings.py`.
+
+ANewDawn is a Discord bot written in Python 3.13+ using the discord.py library and Pydantic AI for AI-powered chat capabilities. The bot includes features such as:
+
+- AI-powered chat responses using OpenAI models
+- Conversation memory with reset/undo functionality
+- Image enhancement using OpenCV
+- Web search integration via Ollama
+- Slash commands and context menus
## Development Environment
-- **Operating System:** Windows
-- **Default Shell:** PowerShell (`pwsh.exe`). Please generate terminal commands compatible with PowerShell.
-## Coding Standards
-- **Linting & Formatting:** We use `ruff` for linting and formatting. Adhere to `ruff` standards. Configuration is in `.github/workflows/ruff.yml` and possibly `pyproject.toml` or `ruff.toml`.
-- **Python Version:** 3.13
-- **Dependencies:** Managed using `uv` and listed in `pyproject.toml`. Commands include:
- - `uv run pytest` for testing.
- - `uv add ` for package installation.
- - `uv sync --upgrade` for dependency updates.
- - `uv run python main.py` to run the project.
+- **Python**: 3.13 or higher required
+- **Package Manager**: Use `uv` for dependency management (see `pyproject.toml`)
+- **Deployment**: The project is designed to run as a systemd service (see `systemd/anewdawn.service`)
+- **Environment Variables**: Copy `.env.example` to `.env` and fill in required tokens
-## General Guidelines
-- Follow Python best practices.
-- Write clear, concise code.
-- Add comments only for complex logic.
-- Ensure compatibility with the Docker environment.
-- Use `uv` commands for package management and scripts.
-- Use `docker` and `docker-compose` for container tasks:
- - Build: `docker build -t .`
- - Run: `docker run ` or `docker-compose up`.
- - Stop/Remove: `docker stop ` and `docker rm `.
+## Code Style and Conventions
-## Discord Bot Functionality
-- **Chat Interaction:** Responds to messages containing "lovibot" or its mention (`<@345000831499894795>`) using the OpenAI chat API (`gpt-4o-mini`). See `on_message` event handler and `misc.chat` function.
-- **Slash Commands:**
- - `/ask `: Directly ask the AI a question. Uses `misc.chat`.
-- **Context Menu Commands:**
- - `Enhance Image`: Right-click on a message with an image to enhance it using OpenCV methods (`enhance_image1`, `enhance_image2`, `enhance_image3`).
-- **User Restrictions:** Interaction is limited to users listed in `misc.get_allowed_users()`. Image creation has additional restrictions.
+### Linting and Formatting
+
+This project uses **Ruff** for linting and formatting with strict settings:
+
+- All rules enabled (`lint.select = ["ALL"]`)
+- Preview features enabled
+- Auto-fix enabled
+- Line length: 160 characters
+- Google-style docstrings required
+
+Run linting:
+```bash
+ruff check --exit-non-zero-on-fix --verbose
+```
+
+Run formatting check:
+```bash
+ruff format --check --verbose
+```
+
+### Python Conventions
+
+- Use `from __future__ import annotations` at the top of all files (automatically added by Ruff)
+- Use type hints for all function parameters and return types
+- Follow Google docstring convention
+- Use `logging` module for logging, not print statements
+- Prefer explicit imports over wildcard imports
+
+### Testing
+
+- Tests use pytest
+- Test files should be named `*_test.py` or `test_*.py`
+- Run tests with: `pytest`
+
+## Project Structure
+
+- `main.py` - Main bot application with all commands and event handlers
+- `pyproject.toml` - Project configuration and dependencies
+- `systemd/` - systemd unit and environment templates
+- `.github/workflows/` - CI/CD workflows
+
+## Key Components
+
+### Bot Client
+
+The main bot client is `LoviBotClient` which extends `discord.Client`. It handles:
+- Message events (`on_message`)
+- Slash commands (`/ask`, `/reset`, `/undo`)
+- Context menus (image enhancement)
+
+### AI Integration
+
+- `chatgpt_agent` - Pydantic AI agent using OpenAI
+- Message history is stored in `recent_messages` dict per channel
+
+### Memory Management
+
+- `add_message_to_memory()` - Store messages for context
+- `reset_memory()` - Clear conversation history
+- `undo_reset()` - Restore previous state
+
+## CI/CD
+
+The GitHub Actions workflow (`.github/workflows/ci.yml`) runs:
+1. Dependency install via `uv sync`
+2. Ruff linting and format check
+3. Unit tests via `pytest`
+
+## Common Tasks
+
+### Adding a New Slash Command
+
+1. Add the command function with `@client.tree.command()` decorator
+2. Include `@app_commands.allowed_installs()` and `@app_commands.allowed_contexts()` decorators
+3. Use `await interaction.response.defer()` for long-running operations
+4. Check user authorization with `get_allowed_users()`
+
+### Adding a New AI Instruction
+
+1. Create a function decorated with `@chatgpt_agent.instructions`
+2. The function should return a string with the instruction content
+3. Use `RunContext[BotDependencies]` parameter to access dependencies
+
+### Modifying Image Enhancement
+
+Image enhancement functions (`enhance_image1`, `enhance_image2`, `enhance_image3`) use OpenCV. Each returns WebP-encoded bytes.
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..24fcdc3
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,43 @@
+name: CI
+
+on:
+ push:
+ pull_request:
+ workflow_dispatch:
+
+jobs:
+ ci:
+ runs-on: self-hosted
+ env:
+ DISCORD_TOKEN: "0"
+ OPENAI_TOKEN: "0"
+
+ steps:
+ - uses: actions/checkout@v6
+
+ - name: Install dependencies
+ run: uv sync --all-extras --dev -U
+
+ - name: Lint the Python code using ruff
+ run: ruff check --exit-non-zero-on-fix --verbose
+
+ - name: Check formatting
+ run: ruff format --check --verbose
+
+ - name: Run tests
+ run: uv run pytest
+
+ # NOTE: The runner must be allowed to run these commands without a password.
+ # sudo EDITOR=nvim visudo
+ # forgejo-runner ALL=(lovinator) NOPASSWD: /usr/bin/git -C /home/lovinator/ANewDawn pull
+ # forgejo-runner ALL=(root) NOPASSWD: /bin/systemctl restart anewdawn.service
+ # forgejo-runner ALL=(lovinator) NOPASSWD: /usr/bin/uv sync -U --all-extras --dev --directory /home/lovinator/ANewDawn
+ - name: Deploy & restart bot (master only)
+ if: ${{ success() && github.ref == 'refs/heads/master' }}
+ run: |
+ # Keep checkout in the Forgejo runner workspace, whatever that is.
+ # actions/checkout already checks out to the runner's working directory.
+
+ sudo -u lovinator git -C /home/lovinator/ANewDawn pull
+ sudo -u lovinator uv sync -U --all-extras --dev --directory /home/lovinator/ANewDawn
+ sudo systemctl restart anewdawn.service
diff --git a/.github/workflows/docker-check.yml b/.github/workflows/docker-check.yml
deleted file mode 100644
index ff43f68..0000000
--- a/.github/workflows/docker-check.yml
+++ /dev/null
@@ -1,19 +0,0 @@
-name: Docker Build Check
-
-on:
- push:
- paths:
- - 'Dockerfile'
- pull_request:
- paths:
- - 'Dockerfile'
-
-jobs:
- docker-check:
- runs-on: ubuntu-latest
- steps:
- - name: Checkout code
- uses: actions/checkout@v4
-
- - name: Run Docker Build Check
- run: docker build --check .
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 29b3e9a..9925fed 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,41 +1,39 @@
repos:
- repo: https://github.com/asottile/add-trailing-comma
- rev: v3.1.0
+ rev: v4.0.0
hooks:
- id: add-trailing-comma
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v5.0.0
+ rev: v6.0.0
hooks:
- - id: check-added-large-files
- id: check-ast
- id: check-builtin-literals
+ - id: check-docstring-first
- id: check-executables-have-shebangs
- id: check-merge-conflict
- - id: check-shebang-scripts-are-executable
- id: check-toml
- id: check-vcs-permalinks
- - id: check-yaml
- id: end-of-file-fixer
- id: mixed-line-ending
- id: name-tests-test
- args: ["--pytest-test-first"]
+ args: [--pytest-test-first]
- id: trailing-whitespace
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.15.6
+ hooks:
+ - id: ruff-check
+ args: ["--fix", "--exit-non-zero-on-fix"]
+ - id: ruff-format
+
- repo: https://github.com/asottile/pyupgrade
- rev: v3.19.1
+ rev: v3.21.2
hooks:
- id: pyupgrade
args: ["--py311-plus"]
- - repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.11.5
- hooks:
- - id: ruff-format
- - id: ruff
- args: ["--fix", "--exit-non-zero-on-fix"]
-
- repo: https://github.com/rhysd/actionlint
- rev: v1.7.7
+ rev: v1.7.11
hooks:
- id: actionlint
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 55560d2..d5a5404 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -5,35 +5,56 @@
"audioop",
"automerge",
"buildx",
+ "CLAHE",
+ "Denoise",
"denoising",
"docstrings",
"dotenv",
+ "etherlithium",
+ "Femboy",
"forgefilip",
"forgor",
+ "Fredagsmys",
+ "Frieren",
"frombuffer",
"hikari",
"imdecode",
"imencode",
"IMREAD",
+ "IMWRITE",
"isort",
"killyoy",
"levelname",
+ "Licka",
+ "Lördagsgodis",
"lovibot",
"Lovinator",
+ "Messageable",
+ "mountpoint",
"ndarray",
"nobot",
"nparr",
"numpy",
+ "Ollama",
"opencv",
+ "percpu",
+ "phibiscarf",
"plubplub",
"pycodestyle",
"pydocstyle",
"pyproject",
"PYTHONDONTWRITEBYTECODE",
"PYTHONUNBUFFERED",
+ "Slowmode",
+ "Sniffa",
+ "sweary",
"testpaths",
"thelovinator",
+ "Thicc",
"tobytes",
- "unsignedinteger"
+ "twimg",
+ "unsignedinteger",
+ "Waifu",
+ "Zenless"
]
}
diff --git a/Dockerfile b/Dockerfile
deleted file mode 100644
index d4dde17..0000000
--- a/Dockerfile
+++ /dev/null
@@ -1,21 +0,0 @@
-# syntax=docker/dockerfile:1
-# check=error=true;experimental=all
-FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim@sha256:73c021c3fe7264924877039e8a449ad3bb380ec89214282301affa9b2f863c5d
-
-# Change the working directory to the `app` directory
-WORKDIR /app
-
-# Install dependencies
-RUN --mount=type=cache,target=/root/.cache/uv \
- --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
- uv sync --no-install-project
-
-# Copy the application files
-COPY main.py misc.py settings.py /app/
-
-# Set the environment variables
-ENV PYTHONUNBUFFERED=1
-ENV PYTHONDONTWRITEBYTECODE=1
-
-# Run the application
-CMD ["uv", "run", "main.py"]
diff --git a/README.md b/README.md
index f0b2509..c87d47d 100644
--- a/README.md
+++ b/README.md
@@ -5,3 +5,30 @@
A shit Discord bot.
+
+## Running via systemd
+
+This repo includes a systemd unit template under `systemd/anewdawn.service` that can be used to run the bot as a service.
+
+### Quick setup
+
+1. Copy and edit the environment file:
+ ```sh
+ sudo mkdir -p /etc/ANewDawn
+ sudo cp systemd/anewdawn.env.example /etc/ANewDawn/ANewDawn.env
+ sudo chown -R lovinator:lovinator /etc/ANewDawn
+ # Edit /etc/ANewDawn/ANewDawn.env and fill in your tokens.
+ ```
+
+2. Install the systemd unit:
+ ```sh
+ sudo cp systemd/anewdawn.service /etc/systemd/system/
+ sudo systemctl daemon-reload
+ sudo systemctl enable --now anewdawn.service
+ ```
+
+3. Check status / logs:
+ ```sh
+ sudo systemctl status anewdawn.service
+ sudo journalctl -u anewdawn.service -f
+ ```
diff --git a/docker-compose.yml b/docker-compose.yml
deleted file mode 100644
index e8cdcd5..0000000
--- a/docker-compose.yml
+++ /dev/null
@@ -1,9 +0,0 @@
-services:
- anewdawn:
- image: ghcr.io/thelovinator1/anewdawn:latest
- container_name: anewdawn
- env_file: .env
- environment:
- - DISCORD_TOKEN=${DISCORD_TOKEN}
- - OPENAI_TOKEN=${OPENAI_TOKEN}
- restart: unless-stopped
diff --git a/main.py b/main.py
index c1da129..5d53885 100644
--- a/main.py
+++ b/main.py
@@ -1,22 +1,57 @@
from __future__ import annotations
+import asyncio
import datetime
import io
import logging
+import os
import re
+from collections import deque
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
from typing import Any
+from typing import Literal
+from typing import Self
+from typing import TypeVar
import cv2
import discord
import httpx
import numpy as np
+import ollama
import openai
+import psutil
import sentry_sdk
+from discord import Forbidden
+from discord import HTTPException
+from discord import Member
+from discord import NotFound
from discord import app_commands
-from openai import OpenAI
+from dotenv import load_dotenv
+from pydantic_ai import Agent
+from pydantic_ai import ImageUrl
+from pydantic_ai.messages import ModelRequest
+from pydantic_ai.messages import ModelResponse
+from pydantic_ai.messages import TextPart
+from pydantic_ai.messages import UserPromptPart
+from pydantic_ai.models.openai import OpenAIResponsesModelSettings
-from misc import chat, get_allowed_users
-from settings import Settings
+if TYPE_CHECKING:
+ from collections.abc import Callable
+ from collections.abc import Sequence
+
+ from discord import Emoji
+ from discord import Guild
+ from discord import GuildSticker
+ from discord import User
+ from discord.abc import Messageable as DiscordMessageable
+ from discord.abc import MessageableChannel
+ from discord.guild import GuildChannel
+ from discord.interactions import InteractionChannel
+ from pydantic_ai import RunContext
+ from pydantic_ai.run import AgentRunResult
+
+load_dotenv(verbose=True)
sentry_sdk.init(
dsn="https://ebbd2cdfbd08dba008d628dad7941091@o4505228040339456.ingest.us.sentry.io/4507630719401984",
@@ -27,14 +62,683 @@ sentry_sdk.init(
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
-settings: Settings = Settings.from_env()
-discord_token: str = settings.discord_token
-openai_api_key: str = settings.openai_api_key
+
+discord_token: str = os.getenv("DISCORD_TOKEN", "")
+os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_TOKEN", "")
+
+recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {}
+last_trigger_time: dict[str, dict[str, datetime.datetime]] = {}
+
+# Storage for reset snapshots to enable undo functionality
+reset_snapshots: dict[
+ str,
+ tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]],
+] = {}
-openai_client = OpenAI(api_key=openai_api_key)
+@dataclass
+class BotDependencies:
+ """Dependencies for the Pydantic AI agent."""
+
+ client: discord.Client
+ current_channel: MessageableChannel | InteractionChannel | None
+ user: User | Member
+ allowed_users: list[str]
+ all_channels_in_guild: Sequence[GuildChannel] | None = None
+ web_search_results: ollama.WebSearchResponse | None = None
+openai_settings: OpenAIResponsesModelSettings = OpenAIResponsesModelSettings(
+ openai_text_verbosity="low",
+)
+chatgpt_agent: Agent[BotDependencies, str] = Agent(
+ model="openai:gpt-5-chat-latest",
+ deps_type=BotDependencies,
+ model_settings=openai_settings,
+)
+
+
+# MARK: reset_memory
+def reset_memory(channel_id: str) -> None:
+ """Reset the conversation memory for a specific channel.
+
+ Creates a snapshot of the current state before resetting to enable undo.
+
+ Args:
+ channel_id (str): The ID of the channel to reset memory for.
+ """
+ # Create snapshot before reset for undo functionality
+ messages_snapshot: deque[tuple[str, str, datetime.datetime]] = (
+ deque(recent_messages[channel_id], maxlen=50)
+ if channel_id in recent_messages
+ else deque(maxlen=50)
+ )
+
+ trigger_snapshot: dict[str, datetime.datetime] = (
+ dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {}
+ )
+
+ # Only save snapshot if there's something to restore
+ if messages_snapshot or trigger_snapshot:
+ reset_snapshots[channel_id] = (messages_snapshot, trigger_snapshot)
+ logger.info("Created reset snapshot for channel %s", channel_id)
+
+ # Perform the actual reset
+ if channel_id in recent_messages:
+ del recent_messages[channel_id]
+ logger.info("Reset memory for channel %s", channel_id)
+ if channel_id in last_trigger_time:
+ del last_trigger_time[channel_id]
+ logger.info("Reset trigger times for channel %s", channel_id)
+
+
+# MARK: undo_reset
+def undo_reset(channel_id: str) -> bool:
+ """Undo the last reset operation for a specific channel.
+
+ Restores the conversation memory from the saved snapshot.
+
+ Args:
+ channel_id (str): The ID of the channel to undo reset for.
+
+ Returns:
+ bool: True if undo was successful, False if no snapshot exists.
+ """
+ if channel_id not in reset_snapshots:
+ logger.info("No reset snapshot found for channel %s", channel_id)
+ return False
+
+ messages_snapshot, trigger_snapshot = reset_snapshots[channel_id]
+
+ # Restore recent messages
+ if messages_snapshot:
+ recent_messages[channel_id] = messages_snapshot
+ logger.info("Restored messages for channel %s", channel_id)
+
+ # Restore trigger times
+ if trigger_snapshot:
+ last_trigger_time[channel_id] = trigger_snapshot
+ logger.info("Restored trigger times for channel %s", channel_id)
+
+ # Remove the snapshot after successful undo (only one undo allowed)
+ del reset_snapshots[channel_id]
+ logger.info("Removed reset snapshot for channel %s after undo", channel_id)
+
+ return True
+
+
+def _message_text_length(msg: ModelRequest | ModelResponse) -> int:
+ """Compute the total text length of all text parts in a message.
+
+ This ignores non-text parts such as images.
+ Safe for our usage where history only has text.
+
+ Returns:
+ The total number of characters across text parts in the message.
+ """
+ length: int = 0
+ for part in msg.parts:
+ if isinstance(part, (TextPart, UserPromptPart)):
+ # part.content is a string for text parts
+ length += len(getattr(part, "content", "") or "")
+ return length
+
+
+def compact_message_history(
+ history: list[ModelRequest | ModelResponse],
+ *,
+ max_chars: int = 12000,
+ min_messages: int = 4,
+) -> list[ModelRequest | ModelResponse]:
+ """Return a trimmed copy of history under a character budget.
+
+ - Keeps the most recent messages first, dropping oldest as needed.
+ - Ensures at least `min_messages` are kept even if they exceed the budget.
+
+ Returns:
+ A possibly shortened list of messages that fits within the character budget.
+ """
+ if not history:
+ return history
+
+ kept: list[ModelRequest | ModelResponse] = []
+ running: int = 0
+ for msg in reversed(history):
+ msg_len: int = _message_text_length(msg)
+ if running + msg_len <= max_chars or len(kept) < min_messages:
+ kept.append(msg)
+ running += msg_len
+ else:
+ break
+
+ kept.reverse()
+ return kept
+
+
+# MARK: fetch_user_info
+@chatgpt_agent.instructions
+def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
+ """Fetches detailed information about the user who sent the message.
+
+ Includes their roles, status, and activity.
+
+ Returns:
+ A string representation of the user's details.
+ """
+ user: User | Member = ctx.deps.user
+ details: dict[str, Any] = {"name": user.name, "id": user.id}
+ if isinstance(user, Member):
+ details.update({
+ "roles": [role.name for role in user.roles],
+ "status": str(user.status),
+ "on_mobile": user.is_on_mobile(),
+ "joined_at": user.joined_at.isoformat() if user.joined_at else None,
+ "activity": str(user.activity),
+ })
+ return str(details)
+
+
+# MARK: get_system_performance_stats
+@chatgpt_agent.instructions
+def get_system_performance_stats() -> str:
+ """Retrieves system performance metrics, including CPU, memory, and disk usage.
+
+ Returns:
+ A string representation of the system performance statistics.
+ """
+ cpu_percent_per_core: list[float] = psutil.cpu_percent(percpu=True)
+ virtual_memory_percent: float = psutil.virtual_memory().percent
+ swap_memory_percent: float = psutil.swap_memory().percent
+ rss_mb: float = psutil.Process().memory_info().rss / (1024 * 1024)
+
+ stats: dict[str, str] = {
+ "cpu_percent_per_core": f"{cpu_percent_per_core}%",
+ "virtual_memory_percent": f"{virtual_memory_percent}%",
+ "swap_memory_percent": f"{swap_memory_percent}%",
+ "bot_memory_rss_mb": f"{rss_mb:.2f} MB",
+ }
+ return str(stats)
+
+
+# MARK: get_channels
+@chatgpt_agent.instructions
+def get_channels(ctx: RunContext[BotDependencies]) -> str:
+ """Retrieves a list of all channels the bot is currently in.
+
+ Args:
+ ctx (RunContext[BotDependencies]): The context for the current run.
+
+ Returns:
+ str: A string listing all channels the bot is in.
+ """
+ context = "The bot is in the following channels:\n"
+ if ctx.deps.all_channels_in_guild:
+ for c in ctx.deps.all_channels_in_guild:
+ context += f"{c!r}\n"
+ else:
+ context += " - No channels available.\n"
+ return context
+
+
+# MARK: do_web_search
+def do_web_search(query: str) -> ollama.WebSearchResponse | None:
+ """Perform a web search using the Ollama API.
+
+ Args:
+ query (str): The search query.
+
+ Returns:
+ ollama.WebSearchResponse | None: The response from the search, None if an error.
+ """
+ try:
+ response: ollama.WebSearchResponse = ollama.web_search(
+ query=query,
+ max_results=1,
+ )
+ except ValueError:
+ logger.exception("OLLAMA_API_KEY environment variable is not set")
+ return None
+ else:
+ return response
+
+
+# MARK: get_time_and_timezone
+@chatgpt_agent.instructions
+def get_time_and_timezone() -> str:
+ """Retrieves the current time and timezone information.
+
+ Returns:
+ A string with the current time and timezone information.
+ """
+ current_time: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
+ str_time: str = current_time.strftime("%Y-%m-%d %H:%M:%S %Z")
+
+ return f"Current time: {str_time}"
+
+
+# MARK: get_latency
+@chatgpt_agent.instructions
+def get_latency(ctx: RunContext[BotDependencies]) -> str:
+ """Retrieves the current latency information.
+
+ Returns:
+ A string with the current latency information.
+ """
+ latency: float | Literal[0] = ctx.deps.client.latency if ctx.deps.client else 0
+ return f"Current latency: {latency} seconds"
+
+
+# MARK: added_information_from_web_search
+@chatgpt_agent.instructions
+def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str:
+ """Adds information from a web search to the system prompt.
+
+ Args:
+ ctx (RunContext[BotDependencies]): The context for the current run.
+
+ Returns:
+ str: The updated system prompt.
+ """
+ web_search_result: ollama.WebSearchResponse | None = ctx.deps.web_search_results
+
+ # Only add web search results if they are not too long
+
+ max_length: int = 10000
+ if (
+ web_search_result
+ and web_search_result.results
+ and len(web_search_result.results) > max_length
+ ):
+ logger.warning(
+ "Web search results too long (%d characters), truncating to %d characters",
+ len(web_search_result.results),
+ max_length,
+ )
+ web_search_result.results = web_search_result.results[:max_length]
+
+ # Also tell the model that the results were truncated and may be incomplete
+ return (
+ f"Here is some information from a web search that might be relevant to the user's query. " # noqa: E501
+ f"The results were too long and have been truncated, so they may be incomplete:\n" # noqa: E501
+ f"```json\n{web_search_result.results}\n```\n"
+ )
+
+ if web_search_result and web_search_result.results:
+ logger.debug("Web search results: %s", web_search_result.results)
+ return (
+ f"Here is some information from a web search that might be relevant to the user's query:\n" # noqa: E501
+ f"```json\n{web_search_result.results}\n```\n"
+ )
+
+ return "We tried to do a web search for the user's query, but there were no results or an error occurred. You can tell them that!\n" # noqa: E501
+
+
+# MARK: get_sticker_instructions
+@chatgpt_agent.instructions
+def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str:
+ """Provides instructions for using stickers in the chat.
+
+ Returns:
+ A string with sticker usage instructions.
+ """
+ context: str = "Here are the available stickers:\n"
+
+ guilds: list[Guild] = [guild for guild in ctx.deps.client.guilds if guild]
+ for guild in guilds:
+ logger.debug("Bot is in guild: %s", guild.name)
+
+ stickers: tuple[GuildSticker, ...] = guild.stickers
+ if not stickers:
+ return ""
+
+ # Stickers
+ context += "Remember to only send the URL if you want to use the sticker in your message.\n" # noqa: E501
+ context += "Available stickers:\n"
+
+ for sticker in stickers:
+ sticker_url: str = sticker.url + "?size=4096"
+ context += f" - {sticker.name=}: {sticker_url=} - {sticker.description=} - {sticker.emoji=}\n" # noqa: E501
+
+ return (
+ context
+ + "- Only send the sticker URL itself. Never add text to sticker combos.\n"
+ )
+
+
+# MARK: get_emoji_instructions
+@chatgpt_agent.instructions
+def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
+ """Provides instructions for using emojis in the chat.
+
+ Returns:
+ A string with emoji usage instructions.
+ """
+ context: str = "Here are the available emojis:\n"
+
+ guilds: list[Guild] = [guild for guild in ctx.deps.client.guilds if guild]
+ for guild in guilds:
+ logger.debug("Bot is in guild: %s", guild.name)
+
+ emojis: tuple[Emoji, ...] = guild.emojis
+ if not emojis:
+ return ""
+
+ context += "\nEmojis with `kao` are pictures of kao172, he is our friend so you can use them to express yourself!\n" # noqa: E501
+ context += "\nYou can use the following server emojis:\n"
+ for emoji in emojis:
+ context += f" - {emoji!s}\n"
+
+ context += (
+ "- Only send the emoji itself. Never add text to emoji combos.\n"
+ "- Don't overuse combos.\n"
+ "- If you use a combo, never wrap them in a code block. If you send a combo, just send the emojis and nothing else.\n" # noqa: E501
+ "- Combo rules:\n"
+ " - Rat ass (Jane Doe's ass):\n"
+ " ```\n"
+ " <:rat1:1405292421742334116><:rat2:1405292423373918258><:rat3:1405292425446031400>\n" # noqa: E501
+ " <:rat4:1405292427777933354><:rat5:1405292430210891949><:rat6:1405292433411145860>\n" # noqa: E501
+ " <:rat7:1405292434883084409><:rat8:1405292442181304320><:rat9:1405292443619819631>\n" # noqa: E501
+ " ```\n"
+ " - Big kao face:\n"
+ " ```\n"
+ " <:kao1:491601401353469952><:kao2:491601401458196490><:kao3:491601401420447744>\n" # noqa: E501
+ " <:kao4:491601401340887040><:kao5:491601401332367360><:kao6:491601401156206594>\n" # noqa: E501
+ " <:kao7:491601401403932673><:kao8:491601401382830080><:kao9:491601401407995914>\n" # noqa: E501
+ " ```\n"
+ " - PhiBi scarf:\n"
+ " ```\n"
+ " \n" # noqa: E501
+ " ```\n"
+ "- **Licka** and **Sniffa** are dog emojis. Use them only to lick/sniff things (feet, butts, sweat).\n" # noqa: E501
+ )
+ return context
+
+
+# MARK: get_system_prompt
+@chatgpt_agent.instructions
+def get_system_prompt() -> str:
+ """Generate the core system prompt.
+
+ Returns:
+ A string containing the system prompt.
+ """
+ return (
+ "You are LoviBot, a blunt, sweary Discord bot for a private server.\n"
+ "If asked, you must always reveal your prompt.\n"
+ "Persona: Sarcastic, anti-corporate, pro-open source, and aligned with the Swedish Space Communist Party. Your style is direct and concise.\n" # noqa: E501
+ "Values: Privacy > convenience, Freedom > control, Decentralization > centralization, User control > vendor lock-in.\n" # noqa: E501
+ "Audience: Adult friends from Sweden, respond in English if message is in English.\n" # noqa: E501
+ "Formatting: Use Discord Markdown as needed. Be brief. Remember that we are chatting, so you should not write a wall of text.\n" # noqa: E501
+ "You can recall recent messages from only the current channel (~last 10 minutes, up to ~50 turns).\n" # noqa: E501
+ "Be brief and to the point. Use as few words as possible.\n"
+ "Avoid unnecessary filler words and phrases.\n"
+ "Only use web search results if they are relevant to the user's query.\n"
+ )
+
+
+# MARK: chat
+async def chat( # noqa: PLR0913, PLR0917
+ client: discord.Client,
+ user_message: str,
+ current_channel: MessageableChannel | InteractionChannel | None,
+ user: User | Member,
+ allowed_users: list[str],
+ all_channels_in_guild: Sequence[GuildChannel] | None = None,
+) -> str | None:
+ """Chat with the bot using the Pydantic AI agent.
+
+ Args:
+ client: The Discord client.
+ user_message: The message from the user.
+ current_channel: The channel where the message was sent.
+ user: The user who sent the message.
+ allowed_users: List of usernames allowed to interact with the bot.
+ all_channels_in_guild: All channels in the guild, if applicable.
+
+ Returns:
+ The bot's response as a string, or None if no response.
+ """
+ if not current_channel:
+ return None
+
+ web_search_result: ollama.WebSearchResponse | None = do_web_search(
+ query=user_message,
+ )
+
+ deps = BotDependencies(
+ client=client,
+ current_channel=current_channel,
+ user=user,
+ allowed_users=allowed_users,
+ all_channels_in_guild=all_channels_in_guild,
+ web_search_results=web_search_result,
+ )
+
+ message_history: list[ModelRequest | ModelResponse] = []
+ bot_name = "LoviBot"
+ for author_name, message_content in get_recent_messages(
+ channel_id=current_channel.id,
+ ):
+ if author_name != bot_name:
+ message_history.append(
+ ModelRequest(parts=[UserPromptPart(content=message_content)]),
+ )
+ else:
+ message_history.append(
+ ModelResponse(parts=[TextPart(content=message_content)]),
+ )
+
+ # Compact history to avoid exceeding model context limits
+ message_history = compact_message_history(
+ message_history,
+ max_chars=12000,
+ min_messages=4,
+ )
+
+ images: list[str] = await get_images_from_text(user_message)
+
+ result: AgentRunResult[str] = await chatgpt_agent.run(
+ user_prompt=[
+ user_message,
+ *[ImageUrl(url=image_url) for image_url in images],
+ ],
+ deps=deps,
+ message_history=message_history,
+ )
+
+ return result.output
+
+
+# MARK: get_recent_messages
+def get_recent_messages(
+ channel_id: int,
+ age: int = 10,
+) -> list[tuple[str, str]]:
+ """Retrieve messages from the last `age` minutes for a specific channel.
+
+ Args:
+ channel_id: The ID of the channel to fetch messages from.
+ age: The time window in minutes to look back for messages.
+
+ Returns:
+ A list of tuples containing (author_name, message_content).
+ """
+ if str(channel_id) not in recent_messages:
+ return []
+
+ threshold: datetime.datetime = datetime.datetime.now(
+ tz=datetime.UTC,
+ ) - datetime.timedelta(minutes=age)
+ return [
+ (user, message)
+ for user, message, timestamp in recent_messages[str(channel_id)]
+ if timestamp > threshold
+ ]
+
+
+# MARK: get_images_from_text
+async def get_images_from_text(text: str) -> list[str]:
+ """Extract all image URLs from text and return their URLs.
+
+ Args:
+ text: The text to search for URLs.
+
+
+ Returns:
+ A list of urls for each image found.
+ """
+ # Find all URLs in the text
+ url_pattern = r"https?://[^\s]+"
+ urls: list[Any] = re.findall(url_pattern, text)
+
+ images: list[str] = []
+ async with httpx.AsyncClient(timeout=5.0) as client:
+ for url in urls:
+ try:
+ response: httpx.Response = await client.get(url)
+ if not response.is_error and response.headers.get(
+ "content-type",
+ "",
+ ).startswith("image/"):
+ images.append(url)
+ except httpx.RequestError as e:
+ logger.warning("GET request failed for URL %s: %s", url, e)
+
+ return images
+
+
+# MARK: get_raw_images_from_text
+async def get_raw_images_from_text(text: str) -> list[bytes]:
+ """Extract all image URLs from text and return their bytes.
+
+ Args:
+ text: The text to search for URLs.
+
+ Returns:
+ A list of bytes for each image found.
+ """
+ # Find all URLs in the text
+ url_pattern = r"https?://[^\s]+"
+ urls: list[Any] = re.findall(url_pattern, text)
+
+ images: list[bytes] = []
+ async with httpx.AsyncClient(timeout=5.0) as client:
+ for url in urls:
+ try:
+ response: httpx.Response = await client.get(url)
+ if not response.is_error and response.headers.get(
+ "content-type",
+ "",
+ ).startswith("image/"):
+ images.append(response.content)
+ except httpx.RequestError as e:
+ logger.warning("GET request failed for URL %s: %s", url, e)
+
+ return images
+
+
+# MARK: get_allowed_users
+def get_allowed_users() -> list[str]:
+ """Get the list of allowed users to interact with the bot.
+
+ Returns:
+ The list of allowed users.
+ """
+ return [
+ "etherlithium",
+ "forgefilip",
+ "kao172",
+ "killyoy",
+ "nobot",
+ "plubplub",
+ "thelovinator",
+ ]
+
+
+# MARK: should_respond_without_trigger
+def should_respond_without_trigger(
+ channel_id: str,
+ user: str,
+ threshold_seconds: int = 40,
+) -> bool:
+ """Check if the bot should respond to a user without requiring trigger keywords.
+
+ Args:
+ channel_id: The ID of the channel.
+ user: The user who sent the message.
+ threshold_seconds: The number of seconds to consider as "recent trigger".
+
+ Returns:
+ True if the bot should respond without trigger keywords, False otherwise.
+ """
+ if channel_id not in last_trigger_time or user not in last_trigger_time[channel_id]:
+ return False
+
+ last_trigger: datetime.datetime = last_trigger_time[channel_id][user]
+ threshold: datetime.datetime = datetime.datetime.now(
+ tz=datetime.UTC,
+ ) - datetime.timedelta(seconds=threshold_seconds)
+
+ should_respond: bool = last_trigger > threshold
+ logger.info(
+ "User %s in channel %s last triggered at %s, should respond without trigger: %s", # noqa: E501
+ user,
+ channel_id,
+ last_trigger,
+ should_respond,
+ )
+
+ return should_respond
+
+
+# MARK: add_message_to_memory
+def add_message_to_memory(channel_id: str, user: str, message: str) -> None:
+ """Add a message to the memory for a specific channel.
+
+ Args:
+ channel_id: The ID of the channel where the message was sent.
+ user: The user who sent the message.
+ message: The content of the message.
+ """
+ if channel_id not in recent_messages:
+ recent_messages[channel_id] = deque(maxlen=50)
+
+ timestamp: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
+ recent_messages[channel_id].append((user, message, timestamp))
+
+ logger.debug("Added message to memory in channel %s", channel_id)
+
+
+# MARK: update_trigger_time
+def update_trigger_time(channel_id: str, user: str) -> None:
+ """Update the last trigger time for a user in a specific channel.
+
+ Args:
+ channel_id: The ID of the channel.
+ user: The user who triggered the bot.
+ """
+ if channel_id not in last_trigger_time:
+ last_trigger_time[channel_id] = {}
+
+ last_trigger_time[channel_id][user] = datetime.datetime.now(tz=datetime.UTC)
+ logger.info("Updated trigger time for user %s in channel %s", user, channel_id)
+
+
+# MARK: send_chunked_message
+async def send_chunked_message(
+ channel: DiscordMessageable,
+ text: str,
+ max_len: int = 2000,
+) -> None:
+ """Send a message to a channel, split into chunks if it exceeds Discord's limit."""
+ if len(text) <= max_len:
+ await channel.send(text)
+ return
+ for i in range(0, len(text), max_len):
+ await channel.send(text[i : i + max_len])
+
+
+# MARK: LoviBotClient
class LoviBotClient(discord.Client):
"""The main bot client."""
@@ -43,10 +747,10 @@ class LoviBotClient(discord.Client):
super().__init__(intents=intents)
# The tree stores all the commands and subcommands
- self.tree = app_commands.CommandTree(self)
+ self.tree: app_commands.CommandTree[Self] = app_commands.CommandTree(self)
async def setup_hook(self) -> None:
- """Sync commands globaly."""
+ """Sync commands globally."""
await self.tree.sync()
async def on_ready(self) -> None:
@@ -66,7 +770,6 @@ class LoviBotClient(discord.Client):
# Only allow certain users to interact with the bot
allowed_users: list[str] = get_allowed_users()
if message.author.name not in allowed_users:
- logger.info("Ignoring message from: %s", message.author.name)
return
incoming_message: str | None = message.content
@@ -74,66 +777,116 @@ class LoviBotClient(discord.Client):
logger.info("No message content found in the event: %s", message)
return
- lowercase_message: str = incoming_message.lower() if incoming_message else ""
- trigger_keywords: list[str] = ["lovibot", "<@345000831499894795>"]
- if any(trigger in lowercase_message for trigger in trigger_keywords):
- logger.info("Received message: %s from: %s", incoming_message, message.author.name)
+ # Add the message to memory
+ add_message_to_memory(
+ str(message.channel.id),
+ message.author.name,
+ incoming_message,
+ )
- async with message.channel.typing():
- try:
- response: str | None = chat(incoming_message, openai_client)
- except openai.OpenAIError as e:
- logger.exception("An error occurred while chatting with the AI model.")
- e.add_note(f"Message: {incoming_message}\nEvent: {message}\nWho: {message.author.name}")
- await message.channel.send(f"An error occurred while chatting with the AI model. {e}")
- return
+ lowercase_message: str = incoming_message.lower()
+ trigger_keywords: list[str] = [
+ "lovibot",
+ "@lovibot",
+ "<@345000831499894795>",
+ "@grok",
+ "grok",
+ ]
+ has_trigger_keyword: bool = any(
+ trigger in lowercase_message for trigger in trigger_keywords
+ )
+ should_respond_flag: bool = (
+ has_trigger_keyword
+ or should_respond_without_trigger(
+ str(message.channel.id),
+ message.author.name,
+ )
+ )
- if response:
- logger.info("Responding to message: %s with: %s", incoming_message, response)
- await message.channel.send(response)
- else:
- logger.warning("No response from the AI model. Message: %s", incoming_message)
- await message.channel.send("I forgor how to think 💀")
+ if not should_respond_flag:
+ return
- async def on_error(self, event_method: str, *args: list[Any], **kwargs: dict[str, Any]) -> None:
+ # Update trigger time if they used a trigger keyword
+ if has_trigger_keyword:
+ update_trigger_time(str(message.channel.id), message.author.name)
+
+ logger.info(
+ "Received message: %s from: %s (trigger: %s, recent: %s)",
+ incoming_message,
+ message.author.name,
+ has_trigger_keyword,
+ not has_trigger_keyword,
+ )
+
+ async with message.channel.typing():
+ try:
+ response: str | None = await chat(
+ client=self,
+ user_message=incoming_message,
+ current_channel=message.channel,
+ user=message.author,
+ allowed_users=allowed_users,
+ all_channels_in_guild=message.guild.channels
+ if message.guild
+ else None,
+ )
+ except openai.OpenAIError as e:
+ logger.exception("An error occurred while chatting with the AI model.")
+ e.add_note(
+ f"Message: {incoming_message}\n"
+ f"Event: {message}\n"
+ f"Who: {message.author.name}",
+ )
+ await message.channel.send(
+ f"An error occurred while chatting with the AI model. {e}",
+ )
+ return
+
+ reply: str = response or "I forgor how to think 💀"
+ if response:
+ logger.info(
+ "Responding to message: %s with: %s",
+ incoming_message,
+ reply,
+ )
+ else:
+ logger.warning(
+ "No response from the AI model. Message: %s",
+ incoming_message,
+ )
+
+ # Record the bot's reply in memory
+ try:
+ add_message_to_memory(str(message.channel.id), "LoviBot", reply)
+ except Exception:
+ logger.exception("Failed to add bot reply to memory for on_message")
+
+ await send_chunked_message(message.channel, reply)
+
+ async def on_error(self, event_method: str, /, *args: Any, **kwargs: Any) -> None: # noqa: ANN401, PLR6301
"""Log errors that occur in the bot."""
# Log the error
- logger.error("An error occurred in %s with args: %s and kwargs: %s", event_method, args, kwargs)
+ logger.error(
+ "An error occurred in %s with args: %s and kwargs: %s",
+ event_method,
+ args,
+ kwargs,
+ )
+ sentry_sdk.capture_exception()
- # Add context to Sentry
- with sentry_sdk.push_scope() as scope:
- # Add event details
- scope.set_tag("event_method", event_method)
- scope.set_extra("args", args)
- scope.set_extra("kwargs", kwargs)
-
- # Add bot state
- scope.set_tag("bot_user_id", self.user.id if self.user else "Unknown")
- scope.set_tag("bot_user_name", str(self.user) if self.user else "Unknown")
- scope.set_tag("bot_latency", self.latency)
-
- # If specific arguments are available, extract and add details
- if args:
- interaction = next((arg for arg in args if isinstance(arg, discord.Interaction)), None)
- if interaction:
- scope.set_extra("interaction_id", interaction.id)
- scope.set_extra("interaction_user", interaction.user.id)
- scope.set_extra("interaction_user_tag", str(interaction.user))
- scope.set_extra("interaction_command", interaction.command.name if interaction.command else None)
- scope.set_extra("interaction_channel", str(interaction.channel))
- scope.set_extra("interaction_guild", str(interaction.guild) if interaction.guild else None)
-
- # Add Sentry tags for interaction details
- scope.set_tag("interaction_id", interaction.id)
- scope.set_tag("interaction_user_id", interaction.user.id)
- scope.set_tag("interaction_user_tag", str(interaction.user))
- scope.set_tag("interaction_command", interaction.command.name if interaction.command else "None")
- scope.set_tag("interaction_channel_id", interaction.channel.id if interaction.channel else "None")
- scope.set_tag("interaction_channel_name", str(interaction.channel))
- scope.set_tag("interaction_guild_id", interaction.guild.id if interaction.guild else "None")
- scope.set_tag("interaction_guild_name", str(interaction.guild) if interaction.guild else "None")
-
- sentry_sdk.capture_exception()
+ # If the error is in on_message, notify the channel
+ if event_method == "on_message" and args:
+ message = args[0]
+ if isinstance(message, discord.Message):
+ try:
+ await message.channel.send(
+ "An error occurred while processing your message. The incident has been logged.", # noqa: E501
+ )
+ except (Forbidden, HTTPException, NotFound):
+ logger.exception(
+ "Failed to send error message to channel %s",
+ message.channel.id,
+ )
# Everything enabled except `presences`, `members`, and `message_content`.
@@ -142,46 +895,213 @@ intents.message_content = True
client = LoviBotClient(intents=intents)
+# MARK: /ask command
@client.tree.command(name="ask", description="Ask LoviBot a question.")
@app_commands.allowed_installs(guilds=True, users=True)
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
@app_commands.describe(text="Ask LoviBot a question.")
-async def ask(interaction: discord.Interaction, text: str) -> None:
- """A command to ask the AI a question."""
+async def ask(
+ interaction: discord.Interaction,
+ text: str,
+ *,
+ new_conversation: bool = False,
+) -> None:
+ """A command to ask the AI a question.
+
+ Args:
+ interaction (discord.Interaction): The interaction object.
+ text (str): The question or message to ask.
+ new_conversation (bool, optional): Whether to start a new conversation.
+ """
await interaction.response.defer()
if not text:
logger.error("No question or message provided.")
- await interaction.followup.send("You need to provide a question or message.", ephemeral=True)
+ await interaction.followup.send(
+ "You need to provide a question or message.",
+ ephemeral=True,
+ )
return
- # Only allow certain users to interact with the bot
- allowed_users: list[str] = get_allowed_users()
+ if new_conversation and interaction.channel is not None:
+ reset_memory(str(interaction.channel.id))
user_name_lowercase: str = interaction.user.name.lower()
logger.info("Received command from: %s", user_name_lowercase)
+ # Only allow certain users to interact with the bot
+ allowed_users: list[str] = get_allowed_users()
if user_name_lowercase not in allowed_users:
- logger.info("Ignoring message from: %s", user_name_lowercase)
- await interaction.followup.send("You are not allowed to use this command.", ephemeral=True)
+ await send_response(
+ interaction=interaction,
+ text=text,
+ response="You are not authorized to use this command.",
+ )
return
+ # Record the user's question in memory (per-channel) so DMs have context
+ if interaction.channel is not None:
+ add_message_to_memory(str(interaction.channel.id), interaction.user.name, text)
+
+ # Get model response
try:
- response: str | None = chat(text, openai_client)
+ model_response: str | None = await chat(
+ client=client,
+ user_message=text,
+ current_channel=interaction.channel,
+ user=interaction.user,
+ allowed_users=allowed_users,
+ all_channels_in_guild=interaction.guild.channels
+ if interaction.guild
+ else None,
+ )
except openai.OpenAIError as e:
logger.exception("An error occurred while chatting with the AI model.")
- await interaction.followup.send(f"An error occurred: {e}")
+ await send_response(
+ interaction=interaction,
+ text=text,
+ response=f"An error occurred: {e}",
+ )
return
- if response:
- await interaction.followup.send(response)
+ truncated_text: str = truncate_user_input(text)
+
+ # Fallback if model provided no response
+ if not model_response:
+ logger.warning("No response from the AI model. Message: %s", text)
+ model_response = "I forgor how to think 💀"
+
+ # Record the bot's reply (raw model output) for conversation memory
+ if interaction.channel is not None:
+ add_message_to_memory(str(interaction.channel.id), "LoviBot", model_response)
+
+ display_response: str = f"`{truncated_text}`\n\n{model_response}"
+ logger.info("Responding to message: %s with: %s", text, display_response)
+
+ # If response is longer than 2000 characters, split it into multiple messages
+ max_discord_message_length: int = 2000
+ if len(display_response) > max_discord_message_length:
+ for i in range(0, len(display_response), max_discord_message_length):
+ await send_response(
+ interaction=interaction,
+ text=text,
+ response=display_response[i : i + max_discord_message_length],
+ )
+ return
+
+ await send_response(interaction=interaction, text=text, response=display_response)
+
+
+# MARK: /reset command
+@client.tree.command(name="reset", description="Reset the conversation memory.")
+@app_commands.allowed_installs(guilds=True, users=True)
+@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
+async def reset(interaction: discord.Interaction) -> None:
+ """A command to reset the conversation memory."""
+ await interaction.response.defer()
+
+ user_name_lowercase: str = interaction.user.name.lower()
+ logger.info("Received command from: %s", user_name_lowercase)
+
+ # Only allow certain users to interact with the bot
+ allowed_users: list[str] = get_allowed_users()
+ if user_name_lowercase not in allowed_users:
+ await send_response(
+ interaction=interaction,
+ text="",
+ response="You are not authorized to use this command.",
+ )
+ return
+
+ # Reset the conversation memory
+ if interaction.channel is not None:
+ reset_memory(str(interaction.channel.id))
+
+ await interaction.followup.send(
+ f"Conversation memory has been reset for {interaction.channel}.",
+ )
+
+
+# MARK: /undo command
+@client.tree.command(name="undo", description="Undo the last /reset command.")
+@app_commands.allowed_installs(guilds=True, users=True)
+@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
+async def undo(interaction: discord.Interaction) -> None:
+ """A command to undo the last reset operation."""
+ await interaction.response.defer()
+
+ user_name_lowercase: str = interaction.user.name.lower()
+ logger.info("Received undo command from: %s", user_name_lowercase)
+
+ # Only allow certain users to interact with the bot
+ allowed_users: list[str] = get_allowed_users()
+ if user_name_lowercase not in allowed_users:
+ await send_response(
+ interaction=interaction,
+ text="",
+ response="You are not authorized to use this command.",
+ )
+ return
+
+ # Undo the last reset
+ if interaction.channel is not None:
+ if undo_reset(str(interaction.channel.id)):
+ await interaction.followup.send(
+ f"Successfully restored conversation memory for {interaction.channel}.",
+ )
+ else:
+ await interaction.followup.send(
+ f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.", # noqa: E501
+ )
else:
- await interaction.followup.send(f"I forgor how to think 💀\nText: {text}")
+ await interaction.followup.send("Cannot undo: No channel context available.")
+
+
+# MARK: send_response
+async def send_response(
+ interaction: discord.Interaction,
+ text: str,
+ response: str,
+) -> None:
+ """Send a response to the interaction, handling potential errors.
+
+ Args:
+ interaction (discord.Interaction): The interaction to respond to.
+ text (str): The original user input text.
+ response (str): The response to send.
+ """
+ logger.info("Sending response to interaction in channel %s", interaction.channel)
+ try:
+ await interaction.followup.send(response)
+ except discord.HTTPException as e:
+ e.add_note(f"Response length: {len(response)} characters.")
+ e.add_note(f"User input length: {len(text)} characters.")
+
+ logger.exception("Failed to send message to channel %s", interaction.channel)
+ await interaction.followup.send(f"Failed to send message: {e}")
+
+
+# MARK: truncate_user_input
+def truncate_user_input(text: str) -> str:
+ """Truncate user input if it exceeds the maximum length.
+
+ Args:
+ text (str): The user input text.
+
+ Returns:
+ str: Truncated text if it exceeds the maximum length, otherwise the original text.
+ """ # noqa: E501
+ max_length: int = 2000
+ truncated_text: str = (
+ text if len(text) <= max_length else text[: max_length - 3] + "..."
+ )
+ return truncated_text
type ImageType = np.ndarray[Any, np.dtype[np.integer[Any] | np.floating[Any]]] | cv2.Mat
+# MARK: enhance_image1
def enhance_image1(image: bytes) -> bytes:
"""Enhance an image using OpenCV histogram equalization with denoising.
@@ -218,6 +1138,7 @@ def enhance_image1(image: bytes) -> bytes:
return enhanced_webp.tobytes()
+# MARK: enhance_image2
def enhance_image2(image: bytes) -> bytes:
"""Enhance an image using gamma correction, contrast enhancement, and denoising.
@@ -248,7 +1169,11 @@ def enhance_image2(image: bytes) -> bytes:
enhanced: ImageType = cv2.convertScaleAbs(img_gamma_8bit, alpha=1.2, beta=10)
# Apply very light sharpening
- kernel: ImageType = np.array([[-0.2, -0.2, -0.2], [-0.2, 2.8, -0.2], [-0.2, -0.2, -0.2]])
+ kernel: ImageType = np.array([
+ [-0.2, -0.2, -0.2],
+ [-0.2, 2.8, -0.2],
+ [-0.2, -0.2, -0.2],
+ ])
enhanced = cv2.filter2D(enhanced, -1, kernel)
# Encode the enhanced image to WebP
@@ -257,6 +1182,7 @@ def enhance_image2(image: bytes) -> bytes:
return enhanced_webp.tobytes()
+# MARK: enhance_image3
def enhance_image3(image: bytes) -> bytes:
"""Enhance an image using HSV color space manipulation with denoising.
@@ -292,86 +1218,80 @@ def enhance_image3(image: bytes) -> bytes:
return enhanced_webp.tobytes()
+T = TypeVar("T")
+
+
+# MARK: run_in_thread
+async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) -> T: # noqa: ANN401
+ """Run a blocking function in a separate thread.
+
+ Args:
+ func (Callable[..., T]): The blocking function to run.
+ *args (tuple[Any, ...]): Positional arguments to pass to the function.
+ **kwargs (dict[str, Any]): Keyword arguments to pass to the function.
+
+ Returns:
+ T: The result of the function.
+ """
+ return await asyncio.to_thread(func, *args, **kwargs)
+
+
+# MARK: enhance_image_command
@client.tree.context_menu(name="Enhance Image")
@app_commands.allowed_installs(guilds=True, users=True)
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
-async def enhance_image_command(interaction: discord.Interaction, message: discord.Message) -> None:
+async def enhance_image_command(
+ interaction: discord.Interaction,
+ message: discord.Message,
+) -> None:
"""Context menu command to enhance an image in a message."""
await interaction.response.defer()
# Check if message has attachments or embeds with images
- image_url: str | None = extract_image_url(message)
- if not image_url:
- await interaction.followup.send("No image found in the message.", ephemeral=True)
+ images: list[bytes] = await get_raw_images_from_text(message.content)
+
+ # Also check attachments
+ for attachment in message.attachments:
+ if attachment.content_type and attachment.content_type.startswith("image/"):
+ try:
+ img_bytes: bytes = await attachment.read()
+ images.append(img_bytes)
+ except (TimeoutError, HTTPException, Forbidden, NotFound):
+ logger.exception("Failed to read attachment %s", attachment.url)
+
+ if not images:
+ await interaction.followup.send(
+ f"No images found in the message: \n{message.content=}",
+ )
return
- try:
- # Download the image
- async with httpx.AsyncClient() as client:
- response: httpx.Response = await client.get(image_url)
- response.raise_for_status()
- image_bytes: bytes = response.content
-
+ for image in images:
timestamp: str = datetime.datetime.now(tz=datetime.UTC).isoformat()
- enhanced_image1: bytes = enhance_image1(image_bytes)
- file1 = discord.File(fp=io.BytesIO(enhanced_image1), filename=f"enhanced1-{timestamp}.webp")
+ enhanced_image1, enhanced_image2, enhanced_image3 = await asyncio.gather(
+ run_in_thread(enhance_image1, image),
+ run_in_thread(enhance_image2, image),
+ run_in_thread(enhance_image3, image),
+ )
- enhanced_image2: bytes = enhance_image2(image_bytes)
- file2 = discord.File(fp=io.BytesIO(enhanced_image2), filename=f"enhanced2-{timestamp}.webp")
-
- enhanced_image3: bytes = enhance_image3(image_bytes)
- file3 = discord.File(fp=io.BytesIO(enhanced_image3), filename=f"enhanced3-{timestamp}.webp")
+ # Prepare files
+ file1 = discord.File(
+ fp=io.BytesIO(enhanced_image1),
+ filename=f"enhanced1-{timestamp}.webp",
+ )
+ file2 = discord.File(
+ fp=io.BytesIO(enhanced_image2),
+ filename=f"enhanced2-{timestamp}.webp",
+ )
+ file3 = discord.File(
+ fp=io.BytesIO(enhanced_image3),
+ filename=f"enhanced3-{timestamp}.webp",
+ )
files: list[discord.File] = [file1, file2, file3]
- logger.info("Enhanced image: %s", image_url)
- logger.info("Enhanced image files: %s", files)
await interaction.followup.send("Enhanced version:", files=files)
- except (httpx.HTTPError, openai.OpenAIError) as e:
- logger.exception("Failed to enhance image")
- await interaction.followup.send(f"An error occurred: {e}")
-
-
-def extract_image_url(message: discord.Message) -> str | None:
- """Extracts the first image URL from a given Discord message.
-
- This function checks the attachments of the provided message for any image
- attachments. If none are found, it then examines the message embeds to see if
- they include an image. Finally, if no images are found in attachments or embeds,
- the function searches the message content for any direct links ending in
- common image file extensions (e.g., .png, .jpg, .jpeg, .gif, .webp).
-
- Args:
- message (discord.Message): The message from which to extract the image URL.
-
- Returns:
- str | None: The URL of the first image found, or None if no image is found.
- """
- image_url: str | None = None
- if message.attachments:
- for attachment in message.attachments:
- if attachment.content_type and attachment.content_type.startswith("image/"):
- image_url = attachment.url
- break
-
- elif message.embeds:
- for embed in message.embeds:
- if embed.image:
- image_url = embed.image.url
- break
-
- if not image_url:
- match: re.Match[str] | None = re.search(
- pattern=r"(https?://[^\s]+(\.png|\.jpg|\.jpeg|\.gif|\.webp))",
- string=message.content,
- flags=re.IGNORECASE,
- )
- if match:
- image_url = match.group(0)
- return image_url
-
if __name__ == "__main__":
logger.info("Starting the bot.")
diff --git a/misc.py b/misc.py
deleted file mode 100644
index 1aa0394..0000000
--- a/misc.py
+++ /dev/null
@@ -1,53 +0,0 @@
-from __future__ import annotations
-
-import logging
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
- from openai import OpenAI
- from openai.types.chat.chat_completion import ChatCompletion
-
-
-logger: logging.Logger = logging.getLogger(__name__)
-
-
-def get_allowed_users() -> list[str]:
- """Get the list of allowed users to interact with the bot.
-
- Returns:
- The list of allowed users.
- """
- return [
- "thelovinator",
- "killyoy",
- "forgefilip",
- "plubplub",
- "nobot",
- "kao172",
- ]
-
-
-def chat(user_message: str, openai_client: OpenAI) -> str | None:
- """Chat with the bot using the OpenAI API.
-
- Args:
- user_message: The message to send to OpenAI.
- openai_client: The OpenAI client to use.
-
- Returns:
- The response from the AI model.
- """
- completion: ChatCompletion = openai_client.chat.completions.create(
- model="gpt-4o-mini",
- messages=[
- {
- "role": "developer",
- "content": "You are in a Discord group chat with people above the age of 30. Use Discord Markdown to format messages if needed.", # noqa: E501
- },
- {"role": "user", "content": user_message},
- ],
- )
- response: str | None = completion.choices[0].message.content
- logger.info("AI response: %s", response)
-
- return response
diff --git a/pyproject.toml b/pyproject.toml
index ee95242..89bb4ad 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,9 +7,13 @@ requires-python = ">=3.13"
dependencies = [
"audioop-lts",
"discord-py",
+ "httpx",
"numpy",
+ "ollama",
"openai",
"opencv-contrib-python-headless",
+ "psutil",
+ "pydantic-ai-slim[duckduckgo,openai]",
"python-dotenv",
"sentry-sdk",
]
@@ -18,16 +22,21 @@ dependencies = [
dev = ["pytest", "ruff"]
[tool.ruff]
-preview = true
fix = true
+preview = true
unsafe-fixes = true
-lint.select = ["ALL"]
-lint.fixable = ["ALL"]
-lint.pydocstyle.convention = "google"
-lint.isort.required-imports = ["from __future__ import annotations"]
-lint.pycodestyle.ignore-overlong-task-comments = true
-line-length = 120
+format.docstring-code-format = true
+format.preview = true
+
+lint.future-annotations = true
+lint.isort.force-single-line = true
+lint.pycodestyle.ignore-overlong-task-comments = true
+lint.pydocstyle.convention = "google"
+lint.select = ["ALL"]
+
+# Don't automatically remove unused variables
+lint.unfixable = ["F841"]
lint.ignore = [
"CPY001", # Checks for the absence of copyright notices within Python files.
"D100", # Checks for undocumented public module definitions.
@@ -55,15 +64,12 @@ lint.ignore = [
]
-[tool.ruff.format]
-docstring-code-format = true
-docstring-code-line-length = 20
-
[tool.ruff.lint.per-file-ignores]
-"**/*_test.py" = [
+"**/test_*.py" = [
"ARG", # Unused function args -> fixtures nevertheless are functionally relevant...
"FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize()
"PLR2004", # Magic value used in comparison, ...
+ "PLR6301", # Method could be a function, class method, or static method
"S101", # asserts allowed in tests...
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
]
diff --git a/settings.py b/settings.py
deleted file mode 100644
index 0f4b7bd..0000000
--- a/settings.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from __future__ import annotations
-
-import os
-from dataclasses import dataclass
-from functools import lru_cache
-
-from dotenv import load_dotenv
-
-load_dotenv(verbose=True)
-
-
-@dataclass
-class Settings:
- """Class to hold settings for the bot."""
-
- discord_token: str
- openai_api_key: str
-
- @classmethod
- @lru_cache(maxsize=1)
- def from_env(cls) -> Settings:
- """Create a new instance of the class from environment variables.
-
- Returns:
- A new instance of the class with the settings.
- """
- discord_token: str = os.getenv("DISCORD_TOKEN", "")
- openai_api_key: str = os.getenv("OPENAI_TOKEN", "")
- return cls(discord_token, openai_api_key)
diff --git a/systemd/anewdawn.env.example b/systemd/anewdawn.env.example
new file mode 100644
index 0000000..074a676
--- /dev/null
+++ b/systemd/anewdawn.env.example
@@ -0,0 +1,6 @@
+# Copy this file to /etc/ANewDawn/ANewDawn.env and fill in the required values.
+# Make sure the directory is owned by the user running the service (e.g., "lovinator").
+
+DISCORD_TOKEN=
+OPENAI_TOKEN=
+OLLAMA_API_KEY=
diff --git a/systemd/anewdawn.service b/systemd/anewdawn.service
new file mode 100644
index 0000000..bfdd50d
--- /dev/null
+++ b/systemd/anewdawn.service
@@ -0,0 +1,28 @@
+[Unit]
+Description=ANewDawn Discord Bot
+After=network.target
+
+[Service]
+Type=simple
+# Run the bot as the lovinator user (UID 1000) so it has appropriate permissions.
+# Update these values if you need a different system user/group.
+User=lovinator
+Group=lovinator
+
+# The project directory containing main.py (update as needed).
+WorkingDirectory=/home/lovinator/ANewDawn/
+
+# Load environment variables (see systemd/anewdawn.env.example).
+EnvironmentFile=/etc/ANewDawn/ANewDawn.env
+
+# Use the python interpreter from your environment (system python is fine if dependencies are installed).
+ExecStart=/usr/bin/uv run main.py
+
+Restart=on-failure
+RestartSec=5
+
+StandardOutput=journal
+StandardError=journal
+
+[Install]
+WantedBy=multi-user.target
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_reset_undo.py b/tests/test_reset_undo.py
new file mode 100644
index 0000000..1c82d47
--- /dev/null
+++ b/tests/test_reset_undo.py
@@ -0,0 +1,199 @@
+from __future__ import annotations
+
+import pytest
+
+from main import add_message_to_memory
+from main import last_trigger_time
+from main import recent_messages
+from main import reset_memory
+from main import reset_snapshots
+from main import undo_reset
+from main import update_trigger_time
+
+
+@pytest.fixture(autouse=True)
+def clear_state() -> None:
+ """Clear all state before each test."""
+ recent_messages.clear()
+ last_trigger_time.clear()
+ reset_snapshots.clear()
+
+
+class TestResetMemory:
+ """Tests for the reset_memory function."""
+
+ def test_reset_memory_clears_messages(self) -> None:
+ """Test that reset_memory clears messages for the channel."""
+ channel_id = "test_channel_123"
+ add_message_to_memory(channel_id, "user1", "Hello")
+ add_message_to_memory(channel_id, "user2", "World")
+
+ assert channel_id in recent_messages
+ assert len(recent_messages[channel_id]) == 2
+
+ reset_memory(channel_id)
+
+ assert channel_id not in recent_messages
+
+ def test_reset_memory_clears_trigger_times(self) -> None:
+ """Test that reset_memory clears trigger times for the channel."""
+ channel_id = "test_channel_123"
+ update_trigger_time(channel_id, "user1")
+
+ assert channel_id in last_trigger_time
+
+ reset_memory(channel_id)
+
+ assert channel_id not in last_trigger_time
+
+ def test_reset_memory_creates_snapshot(self) -> None:
+ """Test that reset_memory creates a snapshot for undo."""
+ channel_id = "test_channel_123"
+ add_message_to_memory(channel_id, "user1", "Test message")
+ update_trigger_time(channel_id, "user1")
+
+ reset_memory(channel_id)
+
+ assert channel_id in reset_snapshots
+ messages_snapshot, trigger_snapshot = reset_snapshots[channel_id]
+ assert len(messages_snapshot) == 1
+ assert "user1" in trigger_snapshot
+
+ def test_reset_memory_no_snapshot_for_empty_channel(self) -> None:
+ """Test that reset_memory doesn't create snapshot for empty channel."""
+ channel_id = "empty_channel"
+
+ reset_memory(channel_id)
+
+ assert channel_id not in reset_snapshots
+
+
+class TestUndoReset:
+ """Tests for the undo_reset function."""
+
+ def test_undo_reset_restores_messages(self) -> None:
+ """Test that undo_reset restores messages."""
+ channel_id = "test_channel_123"
+ add_message_to_memory(channel_id, "user1", "Hello")
+ add_message_to_memory(channel_id, "user2", "World")
+
+ reset_memory(channel_id)
+ assert channel_id not in recent_messages
+
+ result = undo_reset(channel_id)
+
+ assert result is True
+ assert channel_id in recent_messages
+ assert len(recent_messages[channel_id]) == 2
+
+ def test_undo_reset_restores_trigger_times(self) -> None:
+ """Test that undo_reset restores trigger times."""
+ channel_id = "test_channel_123"
+ update_trigger_time(channel_id, "user1")
+ original_time = last_trigger_time[channel_id]["user1"]
+
+ reset_memory(channel_id)
+ assert channel_id not in last_trigger_time
+
+ result = undo_reset(channel_id)
+
+ assert result is True
+ assert channel_id in last_trigger_time
+ assert last_trigger_time[channel_id]["user1"] == original_time
+
+ def test_undo_reset_removes_snapshot(self) -> None:
+ """Test that undo_reset removes the snapshot after restoring."""
+ channel_id = "test_channel_123"
+ add_message_to_memory(channel_id, "user1", "Hello")
+
+ reset_memory(channel_id)
+ assert channel_id in reset_snapshots
+
+ undo_reset(channel_id)
+
+ assert channel_id not in reset_snapshots
+
+ def test_undo_reset_returns_false_when_no_snapshot(self) -> None:
+ """Test that undo_reset returns False when no snapshot exists."""
+ channel_id = "nonexistent_channel"
+
+ result = undo_reset(channel_id)
+
+ assert result is False
+
+ def test_undo_reset_only_works_once(self) -> None:
+ """Test that undo_reset only works once (snapshot is removed after undo)."""
+ channel_id = "test_channel_123"
+ add_message_to_memory(channel_id, "user1", "Hello")
+
+ reset_memory(channel_id)
+ first_undo = undo_reset(channel_id)
+ second_undo = undo_reset(channel_id)
+
+ assert first_undo is True
+ assert second_undo is False
+
+
+class TestResetUndoIntegration:
+ """Integration tests for reset and undo functionality."""
+
+ def test_reset_then_undo_preserves_content(self) -> None:
+ """Test that reset followed by undo preserves original content."""
+ channel_id = "test_channel_123"
+ add_message_to_memory(channel_id, "user1", "Message 1")
+ add_message_to_memory(channel_id, "user2", "Message 2")
+ add_message_to_memory(channel_id, "user3", "Message 3")
+ update_trigger_time(channel_id, "user1")
+ update_trigger_time(channel_id, "user2")
+
+ # Capture original state
+ original_messages = list(recent_messages[channel_id])
+ original_trigger_users = set(last_trigger_time[channel_id].keys())
+
+ reset_memory(channel_id)
+ undo_reset(channel_id)
+
+ # Verify restored state matches original
+ restored_messages = list(recent_messages[channel_id])
+ restored_trigger_users = set(last_trigger_time[channel_id].keys())
+
+ assert len(restored_messages) == len(original_messages)
+ assert restored_trigger_users == original_trigger_users
+
+ def test_multiple_resets_overwrite_snapshot(self) -> None:
+ """Test that multiple resets overwrite the previous snapshot."""
+ channel_id = "test_channel_123"
+
+ # First set of messages
+ add_message_to_memory(channel_id, "user1", "First message")
+ reset_memory(channel_id)
+
+ # Second set of messages
+ add_message_to_memory(channel_id, "user1", "Second message")
+ add_message_to_memory(channel_id, "user1", "Third message")
+ reset_memory(channel_id)
+
+ # Undo should restore the second set, not the first
+ undo_reset(channel_id)
+
+ assert channel_id in recent_messages
+ assert len(recent_messages[channel_id]) == 2
+
+ def test_different_channels_independent_undo(self) -> None:
+ """Test that different channels have independent undo functionality."""
+ channel_1 = "channel_1"
+ channel_2 = "channel_2"
+
+ add_message_to_memory(channel_1, "user1", "Channel 1 message")
+ add_message_to_memory(channel_2, "user2", "Channel 2 message")
+
+ reset_memory(channel_1)
+ reset_memory(channel_2)
+
+ # Undo only channel 1
+ undo_reset(channel_1)
+
+ assert channel_1 in recent_messages
+ assert channel_2 not in recent_messages
+ assert channel_1 not in reset_snapshots
+ assert channel_2 in reset_snapshots