Compare commits
No commits in common. "80e0637e8ae6c3532c671e3f9fd708079e6940fe" and "1e515a839478dd680f54efcf4e09eb38987f6da9" have entirely different histories.
80e0637e8a
...
1e515a8394
7 changed files with 18 additions and 509 deletions
|
|
@ -1,4 +1,3 @@
|
||||||
DISCORD_TOKEN=
|
DISCORD_TOKEN=
|
||||||
OPENAI_TOKEN=
|
OPENAI_TOKEN=
|
||||||
OLLAMA_API_KEY=
|
OLLAMA_API_KEY=
|
||||||
OPENROUTER_API_KEY=
|
|
||||||
108
.github/copilot-instructions.md
vendored
108
.github/copilot-instructions.md
vendored
|
|
@ -1,108 +0,0 @@
|
||||||
# Copilot Instructions for ANewDawn
|
|
||||||
|
|
||||||
## Project Overview
|
|
||||||
|
|
||||||
ANewDawn is a Discord bot written in Python 3.13+ using the discord.py library and Pydantic AI for AI-powered chat capabilities. The bot includes features such as:
|
|
||||||
|
|
||||||
- AI-powered chat responses using OpenAI and Grok models
|
|
||||||
- Conversation memory with reset/undo functionality
|
|
||||||
- Image enhancement using OpenCV
|
|
||||||
- Web search integration via Ollama
|
|
||||||
- Slash commands and context menus
|
|
||||||
|
|
||||||
## Development Environment
|
|
||||||
|
|
||||||
- **Python**: 3.13 or higher required
|
|
||||||
- **Package Manager**: Use `uv` for dependency management (see `pyproject.toml`)
|
|
||||||
- **Docker**: The project uses Docker for deployment (see `Dockerfile` and `docker-compose.yml`)
|
|
||||||
- **Environment Variables**: Copy `.env.example` to `.env` and fill in required tokens
|
|
||||||
|
|
||||||
## Code Style and Conventions
|
|
||||||
|
|
||||||
### Linting and Formatting
|
|
||||||
|
|
||||||
This project uses **Ruff** for linting and formatting with strict settings:
|
|
||||||
|
|
||||||
- All rules enabled (`lint.select = ["ALL"]`)
|
|
||||||
- Preview features enabled
|
|
||||||
- Auto-fix enabled
|
|
||||||
- Line length: 160 characters
|
|
||||||
- Google-style docstrings required
|
|
||||||
|
|
||||||
Run linting:
|
|
||||||
```bash
|
|
||||||
ruff check --exit-non-zero-on-fix --verbose
|
|
||||||
```
|
|
||||||
|
|
||||||
Run formatting check:
|
|
||||||
```bash
|
|
||||||
ruff format --check --verbose
|
|
||||||
```
|
|
||||||
|
|
||||||
### Python Conventions
|
|
||||||
|
|
||||||
- Use `from __future__ import annotations` at the top of all files (automatically added by Ruff)
|
|
||||||
- Use type hints for all function parameters and return types
|
|
||||||
- Follow Google docstring convention
|
|
||||||
- Use `logging` module for logging, not print statements
|
|
||||||
- Prefer explicit imports over wildcard imports
|
|
||||||
|
|
||||||
### Testing
|
|
||||||
|
|
||||||
- Tests use pytest
|
|
||||||
- Test files should be named `*_test.py` or `test_*.py`
|
|
||||||
- Run tests with: `pytest`
|
|
||||||
|
|
||||||
## Project Structure
|
|
||||||
|
|
||||||
- `main.py` - Main bot application with all commands and event handlers
|
|
||||||
- `pyproject.toml` - Project configuration and dependencies
|
|
||||||
- `Dockerfile` / `docker-compose.yml` - Container configuration
|
|
||||||
- `.github/workflows/` - CI/CD workflows
|
|
||||||
|
|
||||||
## Key Components
|
|
||||||
|
|
||||||
### Bot Client
|
|
||||||
|
|
||||||
The main bot client is `LoviBotClient` which extends `discord.Client`. It handles:
|
|
||||||
- Message events (`on_message`)
|
|
||||||
- Slash commands (`/ask`, `/grok`, `/reset`, `/undo`)
|
|
||||||
- Context menus (image enhancement)
|
|
||||||
|
|
||||||
### AI Integration
|
|
||||||
|
|
||||||
- `chatgpt_agent` - Pydantic AI agent using OpenAI
|
|
||||||
- `grok_it()` - Function for Grok model responses
|
|
||||||
- Message history is stored in `recent_messages` dict per channel
|
|
||||||
|
|
||||||
### Memory Management
|
|
||||||
|
|
||||||
- `add_message_to_memory()` - Store messages for context
|
|
||||||
- `reset_memory()` - Clear conversation history
|
|
||||||
- `undo_reset()` - Restore previous state
|
|
||||||
|
|
||||||
## CI/CD
|
|
||||||
|
|
||||||
The GitHub Actions workflow (`.github/workflows/docker-publish.yml`) runs:
|
|
||||||
1. Ruff linting and format check
|
|
||||||
2. Dockerfile validation
|
|
||||||
3. Docker image build and push to GitHub Container Registry
|
|
||||||
|
|
||||||
## Common Tasks
|
|
||||||
|
|
||||||
### Adding a New Slash Command
|
|
||||||
|
|
||||||
1. Add the command function with `@client.tree.command()` decorator
|
|
||||||
2. Include `@app_commands.allowed_installs()` and `@app_commands.allowed_contexts()` decorators
|
|
||||||
3. Use `await interaction.response.defer()` for long-running operations
|
|
||||||
4. Check user authorization with `get_allowed_users()`
|
|
||||||
|
|
||||||
### Adding a New AI Instruction
|
|
||||||
|
|
||||||
1. Create a function decorated with `@chatgpt_agent.instructions`
|
|
||||||
2. The function should return a string with the instruction content
|
|
||||||
3. Use `RunContext[BotDependencies]` parameter to access dependencies
|
|
||||||
|
|
||||||
### Modifying Image Enhancement
|
|
||||||
|
|
||||||
Image enhancement functions (`enhance_image1`, `enhance_image2`, `enhance_image3`) use OpenCV. Each returns WebP-encoded bytes.
|
|
||||||
14
.github/workflows/docker-publish.yml
vendored
14
.github/workflows/docker-publish.yml
vendored
|
|
@ -13,7 +13,7 @@ jobs:
|
||||||
OPENAI_TOKEN: "0"
|
OPENAI_TOKEN: "0"
|
||||||
steps:
|
steps:
|
||||||
# GitHub Container Registry
|
# GitHub Container Registry
|
||||||
- uses: docker/login-action@v4
|
- uses: docker/login-action@v3
|
||||||
if: github.event_name != 'pull_request'
|
if: github.event_name != 'pull_request'
|
||||||
with:
|
with:
|
||||||
registry: ghcr.io
|
registry: ghcr.io
|
||||||
|
|
@ -21,7 +21,7 @@ jobs:
|
||||||
password: ${{ secrets.GITHUB_TOKEN }}
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
# Download the latest commit from the master branch
|
# Download the latest commit from the master branch
|
||||||
- uses: actions/checkout@v6
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
# Install the latest version of ruff
|
# Install the latest version of ruff
|
||||||
- uses: astral-sh/ruff-action@v3
|
- uses: astral-sh/ruff-action@v3
|
||||||
|
|
@ -39,17 +39,15 @@ jobs:
|
||||||
|
|
||||||
# Extract metadata (tags, labels) from Git reference and GitHub events for Docker
|
# Extract metadata (tags, labels) from Git reference and GitHub events for Docker
|
||||||
- id: meta
|
- id: meta
|
||||||
uses: docker/metadata-action@v6
|
uses: docker/metadata-action@v5
|
||||||
if: github.ref == 'refs/heads/master'
|
|
||||||
with:
|
with:
|
||||||
images: ghcr.io/thelovinator1/anewdawn
|
images: ghcr.io/thelovinator1/anewdawn
|
||||||
tags: type=raw,value=latest
|
tags: type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', 'master') }}
|
||||||
|
|
||||||
# Build and push the Docker image
|
# Build and push the Docker image
|
||||||
- uses: docker/build-push-action@v7
|
- uses: docker/build-push-action@v6
|
||||||
if: github.event_name != 'pull_request' && github.ref == 'refs/heads/master'
|
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
push: true
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
tags: ${{ steps.meta.outputs.tags }}
|
||||||
|
|
|
||||||
1
.vscode/settings.json
vendored
1
.vscode/settings.json
vendored
|
|
@ -37,7 +37,6 @@
|
||||||
"numpy",
|
"numpy",
|
||||||
"Ollama",
|
"Ollama",
|
||||||
"opencv",
|
"opencv",
|
||||||
"OPENROUTER",
|
|
||||||
"percpu",
|
"percpu",
|
||||||
"phibiscarf",
|
"phibiscarf",
|
||||||
"plubplub",
|
"plubplub",
|
||||||
|
|
|
||||||
193
main.py
193
main.py
|
|
@ -36,7 +36,6 @@ if TYPE_CHECKING:
|
||||||
from discord.abc import MessageableChannel
|
from discord.abc import MessageableChannel
|
||||||
from discord.guild import GuildChannel
|
from discord.guild import GuildChannel
|
||||||
from discord.interactions import InteractionChannel
|
from discord.interactions import InteractionChannel
|
||||||
from openai.types.chat import ChatCompletion
|
|
||||||
from pydantic_ai.run import AgentRunResult
|
from pydantic_ai.run import AgentRunResult
|
||||||
|
|
||||||
load_dotenv(verbose=True)
|
load_dotenv(verbose=True)
|
||||||
|
|
@ -57,10 +56,6 @@ os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_TOKEN", "")
|
||||||
recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {}
|
recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {}
|
||||||
last_trigger_time: dict[str, dict[str, datetime.datetime]] = {}
|
last_trigger_time: dict[str, dict[str, datetime.datetime]] = {}
|
||||||
|
|
||||||
# Storage for reset snapshots to enable undo functionality
|
|
||||||
# Each channel stores its previous state: (recent_messages_snapshot, last_trigger_time_snapshot)
|
|
||||||
reset_snapshots: dict[str, tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]]] = {}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BotDependencies:
|
class BotDependencies:
|
||||||
|
|
@ -77,68 +72,20 @@ class BotDependencies:
|
||||||
openai_settings = OpenAIResponsesModelSettings(
|
openai_settings = OpenAIResponsesModelSettings(
|
||||||
openai_text_verbosity="low",
|
openai_text_verbosity="low",
|
||||||
)
|
)
|
||||||
chatgpt_agent: Agent[BotDependencies, str] = Agent(
|
agent: Agent[BotDependencies, str] = Agent(
|
||||||
model="gpt-5-chat-latest",
|
model="gpt-5-chat-latest",
|
||||||
deps_type=BotDependencies,
|
deps_type=BotDependencies,
|
||||||
model_settings=openai_settings,
|
model_settings=openai_settings,
|
||||||
)
|
)
|
||||||
grok_client = openai.OpenAI(
|
|
||||||
base_url="https://openrouter.ai/api/v1",
|
|
||||||
api_key=os.getenv("OPENROUTER_API_KEY"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def grok_it(
|
|
||||||
message: discord.Message | None,
|
|
||||||
user_message: str,
|
|
||||||
) -> str | None:
|
|
||||||
"""Chat with the bot using the Pydantic AI agent.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_message: The message from the user.
|
|
||||||
message: The original Discord message object.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The bot's response as a string, or None if no response.
|
|
||||||
"""
|
|
||||||
allowed_users: list[str] = get_allowed_users()
|
|
||||||
if message and message.author.name not in allowed_users:
|
|
||||||
return None
|
|
||||||
|
|
||||||
response: ChatCompletion = grok_client.chat.completions.create(
|
|
||||||
model="x-ai/grok-4-fast:free",
|
|
||||||
messages=[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": user_message,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
return response.choices[0].message.content
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: reset_memory
|
# MARK: reset_memory
|
||||||
def reset_memory(channel_id: str) -> None:
|
def reset_memory(channel_id: str) -> None:
|
||||||
"""Reset the conversation memory for a specific channel.
|
"""Reset the conversation memory for a specific channel.
|
||||||
|
|
||||||
Creates a snapshot of the current state before resetting to enable undo.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
channel_id (str): The ID of the channel to reset memory for.
|
channel_id (str): The ID of the channel to reset memory for.
|
||||||
"""
|
"""
|
||||||
# Create snapshot before reset for undo functionality
|
|
||||||
messages_snapshot: deque[tuple[str, str, datetime.datetime]] = (
|
|
||||||
deque(recent_messages[channel_id], maxlen=50) if channel_id in recent_messages else deque(maxlen=50)
|
|
||||||
)
|
|
||||||
|
|
||||||
trigger_snapshot: dict[str, datetime.datetime] = dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {}
|
|
||||||
|
|
||||||
# Only save snapshot if there's something to restore
|
|
||||||
if messages_snapshot or trigger_snapshot:
|
|
||||||
reset_snapshots[channel_id] = (messages_snapshot, trigger_snapshot)
|
|
||||||
logger.info("Created reset snapshot for channel %s", channel_id)
|
|
||||||
|
|
||||||
# Perform the actual reset
|
|
||||||
if channel_id in recent_messages:
|
if channel_id in recent_messages:
|
||||||
del recent_messages[channel_id]
|
del recent_messages[channel_id]
|
||||||
logger.info("Reset memory for channel %s", channel_id)
|
logger.info("Reset memory for channel %s", channel_id)
|
||||||
|
|
@ -147,41 +94,6 @@ def reset_memory(channel_id: str) -> None:
|
||||||
logger.info("Reset trigger times for channel %s", channel_id)
|
logger.info("Reset trigger times for channel %s", channel_id)
|
||||||
|
|
||||||
|
|
||||||
# MARK: undo_reset
|
|
||||||
def undo_reset(channel_id: str) -> bool:
|
|
||||||
"""Undo the last reset operation for a specific channel.
|
|
||||||
|
|
||||||
Restores the conversation memory from the saved snapshot.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel_id (str): The ID of the channel to undo reset for.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if undo was successful, False if no snapshot exists.
|
|
||||||
"""
|
|
||||||
if channel_id not in reset_snapshots:
|
|
||||||
logger.info("No reset snapshot found for channel %s", channel_id)
|
|
||||||
return False
|
|
||||||
|
|
||||||
messages_snapshot, trigger_snapshot = reset_snapshots[channel_id]
|
|
||||||
|
|
||||||
# Restore recent messages
|
|
||||||
if messages_snapshot:
|
|
||||||
recent_messages[channel_id] = messages_snapshot
|
|
||||||
logger.info("Restored messages for channel %s", channel_id)
|
|
||||||
|
|
||||||
# Restore trigger times
|
|
||||||
if trigger_snapshot:
|
|
||||||
last_trigger_time[channel_id] = trigger_snapshot
|
|
||||||
logger.info("Restored trigger times for channel %s", channel_id)
|
|
||||||
|
|
||||||
# Remove the snapshot after successful undo (only one undo allowed)
|
|
||||||
del reset_snapshots[channel_id]
|
|
||||||
logger.info("Removed reset snapshot for channel %s after undo", channel_id)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def _message_text_length(msg: ModelRequest | ModelResponse) -> int:
|
def _message_text_length(msg: ModelRequest | ModelResponse) -> int:
|
||||||
"""Compute the total text length of all text parts in a message.
|
"""Compute the total text length of all text parts in a message.
|
||||||
|
|
||||||
|
|
@ -231,7 +143,7 @@ def compact_message_history(
|
||||||
|
|
||||||
|
|
||||||
# MARK: fetch_user_info
|
# MARK: fetch_user_info
|
||||||
@chatgpt_agent.instructions
|
@agent.instructions
|
||||||
def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
|
def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
|
||||||
"""Fetches detailed information about the user who sent the message, including their roles, status, and activity.
|
"""Fetches detailed information about the user who sent the message, including their roles, status, and activity.
|
||||||
|
|
||||||
|
|
@ -252,7 +164,7 @@ def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
|
||||||
|
|
||||||
|
|
||||||
# MARK: get_system_performance_stats
|
# MARK: get_system_performance_stats
|
||||||
@chatgpt_agent.instructions
|
@agent.instructions
|
||||||
def get_system_performance_stats() -> str:
|
def get_system_performance_stats() -> str:
|
||||||
"""Retrieves current system performance metrics, including CPU, memory, and disk usage.
|
"""Retrieves current system performance metrics, including CPU, memory, and disk usage.
|
||||||
|
|
||||||
|
|
@ -269,7 +181,7 @@ def get_system_performance_stats() -> str:
|
||||||
|
|
||||||
|
|
||||||
# MARK: get_channels
|
# MARK: get_channels
|
||||||
@chatgpt_agent.instructions
|
@agent.instructions
|
||||||
def get_channels(ctx: RunContext[BotDependencies]) -> str:
|
def get_channels(ctx: RunContext[BotDependencies]) -> str:
|
||||||
"""Retrieves a list of all channels the bot is currently in.
|
"""Retrieves a list of all channels the bot is currently in.
|
||||||
|
|
||||||
|
|
@ -308,7 +220,7 @@ def do_web_search(query: str) -> ollama.WebSearchResponse | None:
|
||||||
|
|
||||||
|
|
||||||
# MARK: get_time_and_timezone
|
# MARK: get_time_and_timezone
|
||||||
@chatgpt_agent.instructions
|
@agent.instructions
|
||||||
def get_time_and_timezone() -> str:
|
def get_time_and_timezone() -> str:
|
||||||
"""Retrieves the current time and timezone information.
|
"""Retrieves the current time and timezone information.
|
||||||
|
|
||||||
|
|
@ -320,7 +232,7 @@ def get_time_and_timezone() -> str:
|
||||||
|
|
||||||
|
|
||||||
# MARK: get_latency
|
# MARK: get_latency
|
||||||
@chatgpt_agent.instructions
|
@agent.instructions
|
||||||
def get_latency(ctx: RunContext[BotDependencies]) -> str:
|
def get_latency(ctx: RunContext[BotDependencies]) -> str:
|
||||||
"""Retrieves the current latency information.
|
"""Retrieves the current latency information.
|
||||||
|
|
||||||
|
|
@ -332,7 +244,7 @@ def get_latency(ctx: RunContext[BotDependencies]) -> str:
|
||||||
|
|
||||||
|
|
||||||
# MARK: added_information_from_web_search
|
# MARK: added_information_from_web_search
|
||||||
@chatgpt_agent.instructions
|
@agent.instructions
|
||||||
def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str:
|
def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str:
|
||||||
"""Adds information from a web search to the system prompt.
|
"""Adds information from a web search to the system prompt.
|
||||||
|
|
||||||
|
|
@ -350,7 +262,7 @@ def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str:
|
||||||
|
|
||||||
|
|
||||||
# MARK: get_sticker_instructions
|
# MARK: get_sticker_instructions
|
||||||
@chatgpt_agent.instructions
|
@agent.instructions
|
||||||
def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str:
|
def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str:
|
||||||
"""Provides instructions for using stickers in the chat.
|
"""Provides instructions for using stickers in the chat.
|
||||||
|
|
||||||
|
|
@ -379,7 +291,7 @@ def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str:
|
||||||
|
|
||||||
|
|
||||||
# MARK: get_emoji_instructions
|
# MARK: get_emoji_instructions
|
||||||
@chatgpt_agent.instructions
|
@agent.instructions
|
||||||
def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
|
def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
|
||||||
"""Provides instructions for using emojis in the chat.
|
"""Provides instructions for using emojis in the chat.
|
||||||
|
|
||||||
|
|
@ -428,7 +340,7 @@ def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
|
||||||
|
|
||||||
|
|
||||||
# MARK: get_system_prompt
|
# MARK: get_system_prompt
|
||||||
@chatgpt_agent.instructions
|
@agent.instructions
|
||||||
def get_system_prompt() -> str:
|
def get_system_prompt() -> str:
|
||||||
"""Generate the core system prompt.
|
"""Generate the core system prompt.
|
||||||
|
|
||||||
|
|
@ -498,7 +410,7 @@ async def chat( # noqa: PLR0913, PLR0917
|
||||||
|
|
||||||
images: list[str] = await get_images_from_text(user_message)
|
images: list[str] = await get_images_from_text(user_message)
|
||||||
|
|
||||||
result: AgentRunResult[str] = await chatgpt_agent.run(
|
result: AgentRunResult[str] = await agent.run(
|
||||||
user_prompt=[
|
user_prompt=[
|
||||||
user_message,
|
user_message,
|
||||||
*[ImageUrl(url=image_url) for image_url in images],
|
*[ImageUrl(url=image_url) for image_url in images],
|
||||||
|
|
@ -853,62 +765,6 @@ async def ask(interaction: discord.Interaction, text: str, new_conversation: boo
|
||||||
await send_response(interaction=interaction, text=text, response=display_response)
|
await send_response(interaction=interaction, text=text, response=display_response)
|
||||||
|
|
||||||
|
|
||||||
# MARK: /grok command
|
|
||||||
@client.tree.command(name="grok", description="Grok a question.")
|
|
||||||
@app_commands.allowed_installs(guilds=True, users=True)
|
|
||||||
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
|
||||||
@app_commands.describe(text="Grok a question.")
|
|
||||||
async def grok(interaction: discord.Interaction, text: str) -> None:
|
|
||||||
"""A command to ask the AI a question.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
interaction (discord.Interaction): The interaction object.
|
|
||||||
text (str): The question or message to ask.
|
|
||||||
"""
|
|
||||||
await interaction.response.defer()
|
|
||||||
|
|
||||||
if not text:
|
|
||||||
logger.error("No question or message provided.")
|
|
||||||
await interaction.followup.send("You need to provide a question or message.", ephemeral=True)
|
|
||||||
return
|
|
||||||
|
|
||||||
user_name_lowercase: str = interaction.user.name.lower()
|
|
||||||
logger.info("Received command from: %s", user_name_lowercase)
|
|
||||||
|
|
||||||
# Only allow certain users to interact with the bot
|
|
||||||
allowed_users: list[str] = get_allowed_users()
|
|
||||||
if user_name_lowercase not in allowed_users:
|
|
||||||
await send_response(interaction=interaction, text=text, response="You are not authorized to use this command.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Get model response
|
|
||||||
try:
|
|
||||||
model_response: str | None = grok_it(message=interaction.message, user_message=text)
|
|
||||||
except openai.OpenAIError as e:
|
|
||||||
logger.exception("An error occurred while chatting with the AI model.")
|
|
||||||
await send_response(interaction=interaction, text=text, response=f"An error occurred: {e}")
|
|
||||||
return
|
|
||||||
|
|
||||||
truncated_text: str = truncate_user_input(text)
|
|
||||||
|
|
||||||
# Fallback if model provided no response
|
|
||||||
if not model_response:
|
|
||||||
logger.warning("No response from the AI model. Message: %s", text)
|
|
||||||
model_response = "I forgor how to think 💀"
|
|
||||||
|
|
||||||
display_response: str = f"`{truncated_text}`\n\n{model_response}"
|
|
||||||
logger.info("Responding to message: %s with: %s", text, display_response)
|
|
||||||
|
|
||||||
# If response is longer than 2000 characters, split it into multiple messages
|
|
||||||
max_discord_message_length: int = 2000
|
|
||||||
if len(display_response) > max_discord_message_length:
|
|
||||||
for i in range(0, len(display_response), max_discord_message_length):
|
|
||||||
await send_response(interaction=interaction, text=text, response=display_response[i : i + max_discord_message_length])
|
|
||||||
return
|
|
||||||
|
|
||||||
await send_response(interaction=interaction, text=text, response=display_response)
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: /reset command
|
# MARK: /reset command
|
||||||
@client.tree.command(name="reset", description="Reset the conversation memory.")
|
@client.tree.command(name="reset", description="Reset the conversation memory.")
|
||||||
@app_commands.allowed_installs(guilds=True, users=True)
|
@app_commands.allowed_installs(guilds=True, users=True)
|
||||||
|
|
@ -933,33 +789,6 @@ async def reset(interaction: discord.Interaction) -> None:
|
||||||
await interaction.followup.send(f"Conversation memory has been reset for {interaction.channel}.")
|
await interaction.followup.send(f"Conversation memory has been reset for {interaction.channel}.")
|
||||||
|
|
||||||
|
|
||||||
# MARK: /undo command
|
|
||||||
@client.tree.command(name="undo", description="Undo the last /reset command.")
|
|
||||||
@app_commands.allowed_installs(guilds=True, users=True)
|
|
||||||
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
|
||||||
async def undo(interaction: discord.Interaction) -> None:
|
|
||||||
"""A command to undo the last reset operation."""
|
|
||||||
await interaction.response.defer()
|
|
||||||
|
|
||||||
user_name_lowercase: str = interaction.user.name.lower()
|
|
||||||
logger.info("Received undo command from: %s", user_name_lowercase)
|
|
||||||
|
|
||||||
# Only allow certain users to interact with the bot
|
|
||||||
allowed_users: list[str] = get_allowed_users()
|
|
||||||
if user_name_lowercase not in allowed_users:
|
|
||||||
await send_response(interaction=interaction, text="", response="You are not authorized to use this command.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Undo the last reset
|
|
||||||
if interaction.channel is not None:
|
|
||||||
if undo_reset(str(interaction.channel.id)):
|
|
||||||
await interaction.followup.send(f"Successfully restored conversation memory for {interaction.channel}.")
|
|
||||||
else:
|
|
||||||
await interaction.followup.send(f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.")
|
|
||||||
else:
|
|
||||||
await interaction.followup.send("Cannot undo: No channel context available.")
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: send_response
|
# MARK: send_response
|
||||||
async def send_response(interaction: discord.Interaction, text: str, response: str) -> None:
|
async def send_response(interaction: discord.Interaction, text: str, response: str) -> None:
|
||||||
"""Send a response to the interaction, handling potential errors.
|
"""Send a response to the interaction, handling potential errors.
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,6 @@ docstring-code-line-length = 20
|
||||||
"ARG", # Unused function args -> fixtures nevertheless are functionally relevant...
|
"ARG", # Unused function args -> fixtures nevertheless are functionally relevant...
|
||||||
"FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize()
|
"FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize()
|
||||||
"PLR2004", # Magic value used in comparison, ...
|
"PLR2004", # Magic value used in comparison, ...
|
||||||
"PLR6301", # Method could be a function, class method, or static method
|
|
||||||
"S101", # asserts allowed in tests...
|
"S101", # asserts allowed in tests...
|
||||||
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
|
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
|
||||||
]
|
]
|
||||||
|
|
@ -77,9 +76,3 @@ log_cli_level = "INFO"
|
||||||
log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
|
log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
|
||||||
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
|
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
|
||||||
python_files = "test_*.py *_test.py *_tests.py"
|
python_files = "test_*.py *_test.py *_tests.py"
|
||||||
|
|
||||||
[dependency-groups]
|
|
||||||
dev = [
|
|
||||||
"pytest>=9.0.1",
|
|
||||||
"ruff>=0.14.7",
|
|
||||||
]
|
|
||||||
|
|
|
||||||
|
|
@ -1,201 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from main import (
|
|
||||||
add_message_to_memory,
|
|
||||||
last_trigger_time,
|
|
||||||
recent_messages,
|
|
||||||
reset_memory,
|
|
||||||
reset_snapshots,
|
|
||||||
undo_reset,
|
|
||||||
update_trigger_time,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def clear_state() -> None:
|
|
||||||
"""Clear all state before each test."""
|
|
||||||
recent_messages.clear()
|
|
||||||
last_trigger_time.clear()
|
|
||||||
reset_snapshots.clear()
|
|
||||||
|
|
||||||
|
|
||||||
class TestResetMemory:
|
|
||||||
"""Tests for the reset_memory function."""
|
|
||||||
|
|
||||||
def test_reset_memory_clears_messages(self) -> None:
|
|
||||||
"""Test that reset_memory clears messages for the channel."""
|
|
||||||
channel_id = "test_channel_123"
|
|
||||||
add_message_to_memory(channel_id, "user1", "Hello")
|
|
||||||
add_message_to_memory(channel_id, "user2", "World")
|
|
||||||
|
|
||||||
assert channel_id in recent_messages
|
|
||||||
assert len(recent_messages[channel_id]) == 2
|
|
||||||
|
|
||||||
reset_memory(channel_id)
|
|
||||||
|
|
||||||
assert channel_id not in recent_messages
|
|
||||||
|
|
||||||
def test_reset_memory_clears_trigger_times(self) -> None:
|
|
||||||
"""Test that reset_memory clears trigger times for the channel."""
|
|
||||||
channel_id = "test_channel_123"
|
|
||||||
update_trigger_time(channel_id, "user1")
|
|
||||||
|
|
||||||
assert channel_id in last_trigger_time
|
|
||||||
|
|
||||||
reset_memory(channel_id)
|
|
||||||
|
|
||||||
assert channel_id not in last_trigger_time
|
|
||||||
|
|
||||||
def test_reset_memory_creates_snapshot(self) -> None:
|
|
||||||
"""Test that reset_memory creates a snapshot for undo."""
|
|
||||||
channel_id = "test_channel_123"
|
|
||||||
add_message_to_memory(channel_id, "user1", "Test message")
|
|
||||||
update_trigger_time(channel_id, "user1")
|
|
||||||
|
|
||||||
reset_memory(channel_id)
|
|
||||||
|
|
||||||
assert channel_id in reset_snapshots
|
|
||||||
messages_snapshot, trigger_snapshot = reset_snapshots[channel_id]
|
|
||||||
assert len(messages_snapshot) == 1
|
|
||||||
assert "user1" in trigger_snapshot
|
|
||||||
|
|
||||||
def test_reset_memory_no_snapshot_for_empty_channel(self) -> None:
|
|
||||||
"""Test that reset_memory doesn't create snapshot for empty channel."""
|
|
||||||
channel_id = "empty_channel"
|
|
||||||
|
|
||||||
reset_memory(channel_id)
|
|
||||||
|
|
||||||
assert channel_id not in reset_snapshots
|
|
||||||
|
|
||||||
|
|
||||||
class TestUndoReset:
|
|
||||||
"""Tests for the undo_reset function."""
|
|
||||||
|
|
||||||
def test_undo_reset_restores_messages(self) -> None:
|
|
||||||
"""Test that undo_reset restores messages."""
|
|
||||||
channel_id = "test_channel_123"
|
|
||||||
add_message_to_memory(channel_id, "user1", "Hello")
|
|
||||||
add_message_to_memory(channel_id, "user2", "World")
|
|
||||||
|
|
||||||
reset_memory(channel_id)
|
|
||||||
assert channel_id not in recent_messages
|
|
||||||
|
|
||||||
result = undo_reset(channel_id)
|
|
||||||
|
|
||||||
assert result is True
|
|
||||||
assert channel_id in recent_messages
|
|
||||||
assert len(recent_messages[channel_id]) == 2
|
|
||||||
|
|
||||||
def test_undo_reset_restores_trigger_times(self) -> None:
|
|
||||||
"""Test that undo_reset restores trigger times."""
|
|
||||||
channel_id = "test_channel_123"
|
|
||||||
update_trigger_time(channel_id, "user1")
|
|
||||||
original_time = last_trigger_time[channel_id]["user1"]
|
|
||||||
|
|
||||||
reset_memory(channel_id)
|
|
||||||
assert channel_id not in last_trigger_time
|
|
||||||
|
|
||||||
result = undo_reset(channel_id)
|
|
||||||
|
|
||||||
assert result is True
|
|
||||||
assert channel_id in last_trigger_time
|
|
||||||
assert last_trigger_time[channel_id]["user1"] == original_time
|
|
||||||
|
|
||||||
def test_undo_reset_removes_snapshot(self) -> None:
|
|
||||||
"""Test that undo_reset removes the snapshot after restoring."""
|
|
||||||
channel_id = "test_channel_123"
|
|
||||||
add_message_to_memory(channel_id, "user1", "Hello")
|
|
||||||
|
|
||||||
reset_memory(channel_id)
|
|
||||||
assert channel_id in reset_snapshots
|
|
||||||
|
|
||||||
undo_reset(channel_id)
|
|
||||||
|
|
||||||
assert channel_id not in reset_snapshots
|
|
||||||
|
|
||||||
def test_undo_reset_returns_false_when_no_snapshot(self) -> None:
|
|
||||||
"""Test that undo_reset returns False when no snapshot exists."""
|
|
||||||
channel_id = "nonexistent_channel"
|
|
||||||
|
|
||||||
result = undo_reset(channel_id)
|
|
||||||
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
def test_undo_reset_only_works_once(self) -> None:
|
|
||||||
"""Test that undo_reset only works once (snapshot is removed after undo)."""
|
|
||||||
channel_id = "test_channel_123"
|
|
||||||
add_message_to_memory(channel_id, "user1", "Hello")
|
|
||||||
|
|
||||||
reset_memory(channel_id)
|
|
||||||
first_undo = undo_reset(channel_id)
|
|
||||||
second_undo = undo_reset(channel_id)
|
|
||||||
|
|
||||||
assert first_undo is True
|
|
||||||
assert second_undo is False
|
|
||||||
|
|
||||||
|
|
||||||
class TestResetUndoIntegration:
|
|
||||||
"""Integration tests for reset and undo functionality."""
|
|
||||||
|
|
||||||
def test_reset_then_undo_preserves_content(self) -> None:
|
|
||||||
"""Test that reset followed by undo preserves original content."""
|
|
||||||
channel_id = "test_channel_123"
|
|
||||||
add_message_to_memory(channel_id, "user1", "Message 1")
|
|
||||||
add_message_to_memory(channel_id, "user2", "Message 2")
|
|
||||||
add_message_to_memory(channel_id, "user3", "Message 3")
|
|
||||||
update_trigger_time(channel_id, "user1")
|
|
||||||
update_trigger_time(channel_id, "user2")
|
|
||||||
|
|
||||||
# Capture original state
|
|
||||||
original_messages = list(recent_messages[channel_id])
|
|
||||||
original_trigger_users = set(last_trigger_time[channel_id].keys())
|
|
||||||
|
|
||||||
reset_memory(channel_id)
|
|
||||||
undo_reset(channel_id)
|
|
||||||
|
|
||||||
# Verify restored state matches original
|
|
||||||
restored_messages = list(recent_messages[channel_id])
|
|
||||||
restored_trigger_users = set(last_trigger_time[channel_id].keys())
|
|
||||||
|
|
||||||
assert len(restored_messages) == len(original_messages)
|
|
||||||
assert restored_trigger_users == original_trigger_users
|
|
||||||
|
|
||||||
def test_multiple_resets_overwrite_snapshot(self) -> None:
|
|
||||||
"""Test that multiple resets overwrite the previous snapshot."""
|
|
||||||
channel_id = "test_channel_123"
|
|
||||||
|
|
||||||
# First set of messages
|
|
||||||
add_message_to_memory(channel_id, "user1", "First message")
|
|
||||||
reset_memory(channel_id)
|
|
||||||
|
|
||||||
# Second set of messages
|
|
||||||
add_message_to_memory(channel_id, "user1", "Second message")
|
|
||||||
add_message_to_memory(channel_id, "user1", "Third message")
|
|
||||||
reset_memory(channel_id)
|
|
||||||
|
|
||||||
# Undo should restore the second set, not the first
|
|
||||||
undo_reset(channel_id)
|
|
||||||
|
|
||||||
assert channel_id in recent_messages
|
|
||||||
assert len(recent_messages[channel_id]) == 2
|
|
||||||
|
|
||||||
def test_different_channels_independent_undo(self) -> None:
|
|
||||||
"""Test that different channels have independent undo functionality."""
|
|
||||||
channel_1 = "channel_1"
|
|
||||||
channel_2 = "channel_2"
|
|
||||||
|
|
||||||
add_message_to_memory(channel_1, "user1", "Channel 1 message")
|
|
||||||
add_message_to_memory(channel_2, "user2", "Channel 2 message")
|
|
||||||
|
|
||||||
reset_memory(channel_1)
|
|
||||||
reset_memory(channel_2)
|
|
||||||
|
|
||||||
# Undo only channel 1
|
|
||||||
undo_reset(channel_1)
|
|
||||||
|
|
||||||
assert channel_1 in recent_messages
|
|
||||||
assert channel_2 not in recent_messages
|
|
||||||
assert channel_1 not in reset_snapshots
|
|
||||||
assert channel_2 in reset_snapshots
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue