diff --git a/.env.example b/.env.example index aae1f64..5fb16cb 100644 --- a/.env.example +++ b/.env.example @@ -1,2 +1,3 @@ DISCORD_TOKEN= OPENAI_TOKEN= +OLLAMA_API_KEY= diff --git a/.gitea/workflows/docker-check.yml b/.gitea/workflows/docker-check.yml deleted file mode 100644 index ff43f68..0000000 --- a/.gitea/workflows/docker-check.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Docker Build Check - -on: - push: - paths: - - 'Dockerfile' - pull_request: - paths: - - 'Dockerfile' - -jobs: - docker-check: - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Run Docker Build Check - run: docker build --check . diff --git a/.gitea/workflows/docker-publish.yml b/.gitea/workflows/docker-publish.yml deleted file mode 100644 index c4aca97..0000000 --- a/.gitea/workflows/docker-publish.yml +++ /dev/null @@ -1,67 +0,0 @@ -name: Build Docker Image - -on: - push: - branches: - - master - pull_request: - workflow_dispatch: - schedule: - - cron: "@daily" - -cache: - enabled: true - dir: "" - host: "192.168.1.127" - port: 8088 - -jobs: - docker: - runs-on: ubuntu-latest - env: - DISCORD_TOKEN: "0" - OPENAI_TOKEN: "0" - if: gitea.event_name != 'pull_request' - steps: - - uses: https://github.com/actions/checkout@v4 - - uses: https://github.com/docker/setup-qemu-action@v3 - - uses: https://github.com/docker/setup-buildx-action@v3 - - uses: https://github.com/astral-sh/ruff-action@v3 - - - run: docker build --check . 
- - run: ruff check --exit-non-zero-on-fix --verbose - - run: ruff format --check --verbose - - - id: meta - uses: https://github.com/docker/metadata-action@v5 - env: - DOCKER_METADATA_ANNOTATIONS_LEVELS: manifest,index - with: - images: | - ghcr.io/thelovinator1/anewdawn - git.lovinator.space/thelovinator/anewdawn - tags: type=raw,value=latest,enable=${{ gitea.ref == format('refs/heads/{0}', 'master') }} - - # GitHub Container Registry - - uses: https://github.com/docker/login-action@v3 - if: github.event_name != 'pull_request' - with: - registry: ghcr.io - username: thelovinator1 - password: ${{ secrets.PACKAGES_WRITE_GITHUB_TOKEN }} - - # Gitea Container Registry - - uses: https://github.com/docker/login-action@v3 - if: github.event_name != 'pull_request' - with: - registry: git.lovinator.space - username: thelovinator - password: ${{ secrets.GITEA_TOKEN }} - - - uses: https://github.com/docker/build-push-action@v6 - with: - context: . - push: ${{ gitea.event_name != 'pull_request' }} - labels: ${{ steps.meta.outputs.labels }} - tags: ${{ steps.meta.outputs.tags }} - annotations: ${{ steps.meta.outputs.annotations }} diff --git a/.gitea/workflows/ruff.yml b/.gitea/workflows/ruff.yml deleted file mode 100644 index 28b5029..0000000 --- a/.gitea/workflows/ruff.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Ruff - -on: - push: - pull_request: - workflow_dispatch: - schedule: - - cron: '0 0 * * *' # Run every day at midnight - -env: - RUFF_OUTPUT_FORMAT: github -jobs: - ruff: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: astral-sh/ruff-action@v3 - - run: ruff check --exit-non-zero-on-fix --verbose - - run: ruff format --check --verbose diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 7921bbe..e43899f 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -1,36 +1,107 @@ -# Custom Instructions for GitHub Copilot +# Copilot Instructions for ANewDawn ## Project Overview -This is a 
Python project named ANewDawn. It uses Docker for containerization (`Dockerfile`, `docker-compose.yml`). Key files include `main.py` and `settings.py`. + +ANewDawn is a Discord bot written in Python 3.13+ using the discord.py library and Pydantic AI for AI-powered chat capabilities. The bot includes features such as: + +- AI-powered chat responses using OpenAI models +- Conversation memory with reset/undo functionality +- Image enhancement using OpenCV +- Web search integration via Ollama +- Slash commands and context menus ## Development Environment -- **Operating System:** Windows -- **Default Shell:** PowerShell (`pwsh.exe`). Please generate terminal commands compatible with PowerShell. -## Coding Standards -- **Linting & Formatting:** We use `ruff` for linting and formatting. Adhere to `ruff` standards. Configuration is in `.github/workflows/ruff.yml` and possibly `pyproject.toml` or `ruff.toml`. -- **Python Version:** 3.13 -- **Dependencies:** Managed using `uv` and listed in `pyproject.toml`. Commands include: - - `uv run pytest` for testing. - - `uv add ` for package installation. - - `uv sync --upgrade` for dependency updates. - - `uv run python main.py` to run the project. +- **Python**: 3.13 or higher required +- **Package Manager**: Use `uv` for dependency management (see `pyproject.toml`) +- **Deployment**: The project is designed to run as a systemd service (see `systemd/anewdawn.service`) +- **Environment Variables**: Copy `.env.example` to `.env` and fill in required tokens -## General Guidelines -- Follow Python best practices. -- Write clear, concise code. -- Add comments only for complex logic. -- Ensure compatibility with the Docker environment. -- Use `uv` commands for package management and scripts. -- Use `docker` and `docker-compose` for container tasks: - - Build: `docker build -t .` - - Run: `docker run ` or `docker-compose up`. - - Stop/Remove: `docker stop ` and `docker rm `. 
+## Code Style and Conventions -## Discord Bot Functionality -- **Chat Interaction:** Responds to messages containing "lovibot" or its mention (`<@345000831499894795>`) using the OpenAI chat API (`gpt-4o-mini`). See `on_message` event handler and `misc.chat` function. -- **Slash Commands:** - - `/ask `: Directly ask the AI a question. Uses `misc.chat`. -- **Context Menu Commands:** - - `Enhance Image`: Right-click on a message with an image to enhance it using OpenCV methods (`enhance_image1`, `enhance_image2`, `enhance_image3`). -- **User Restrictions:** Interaction is limited to users listed in `misc.get_allowed_users()`. Image creation has additional restrictions. +### Linting and Formatting + +This project uses **Ruff** for linting and formatting with strict settings: + +- All rules enabled (`lint.select = ["ALL"]`) +- Preview features enabled +- Auto-fix enabled +- Line length: 160 characters +- Google-style docstrings required + +Run linting: +```bash +ruff check --exit-non-zero-on-fix --verbose +``` + +Run formatting check: +```bash +ruff format --check --verbose +``` + +### Python Conventions + +- Use `from __future__ import annotations` at the top of all files (automatically added by Ruff) +- Use type hints for all function parameters and return types +- Follow Google docstring convention +- Use `logging` module for logging, not print statements +- Prefer explicit imports over wildcard imports + +### Testing + +- Tests use pytest +- Test files should be named `*_test.py` or `test_*.py` +- Run tests with: `pytest` + +## Project Structure + +- `main.py` - Main bot application with all commands and event handlers +- `pyproject.toml` - Project configuration and dependencies +- `systemd/` - systemd unit and environment templates +- `.github/workflows/` - CI/CD workflows + +## Key Components + +### Bot Client + +The main bot client is `LoviBotClient` which extends `discord.Client`. 
It handles: +- Message events (`on_message`) +- Slash commands (`/ask`, `/reset`, `/undo`) +- Context menus (image enhancement) + +### AI Integration + +- `chatgpt_agent` - Pydantic AI agent using OpenAI +- Message history is stored in `recent_messages` dict per channel + +### Memory Management + +- `add_message_to_memory()` - Store messages for context +- `reset_memory()` - Clear conversation history +- `undo_reset()` - Restore previous state + +## CI/CD + +The GitHub Actions workflow (`.github/workflows/ci.yml`) runs: +1. Dependency install via `uv sync` +2. Ruff linting and format check +3. Unit tests via `pytest` + +## Common Tasks + +### Adding a New Slash Command + +1. Add the command function with `@client.tree.command()` decorator +2. Include `@app_commands.allowed_installs()` and `@app_commands.allowed_contexts()` decorators +3. Use `await interaction.response.defer()` for long-running operations +4. Check user authorization with `get_allowed_users()` + +### Adding a New AI Instruction + +1. Create a function decorated with `@chatgpt_agent.instructions` +2. The function should return a string with the instruction content +3. Use `RunContext[BotDependencies]` parameter to access dependencies + +### Modifying Image Enhancement + +Image enhancement functions (`enhance_image1`, `enhance_image2`, `enhance_image3`) use OpenCV. Each returns WebP-encoded bytes. 
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..24fcdc3 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,43 @@ +name: CI + +on: + push: + pull_request: + workflow_dispatch: + +jobs: + ci: + runs-on: self-hosted + env: + DISCORD_TOKEN: "0" + OPENAI_TOKEN: "0" + + steps: + - uses: actions/checkout@v6 + + - name: Install dependencies + run: uv sync --all-extras --dev -U + + - name: Lint the Python code using ruff + run: ruff check --exit-non-zero-on-fix --verbose + + - name: Check formatting + run: ruff format --check --verbose + + - name: Run tests + run: uv run pytest + + # NOTE: The runner must be allowed to run these commands without a password. + # sudo EDITOR=nvim visudo + # forgejo-runner ALL=(lovinator) NOPASSWD: /usr/bin/git -C /home/lovinator/ANewDawn pull + # forgejo-runner ALL=(root) NOPASSWD: /bin/systemctl restart anewdawn.service + # forgejo-runner ALL=(lovinator) NOPASSWD: /usr/bin/uv sync -U --all-extras --dev --directory /home/lovinator/ANewDawn + - name: Deploy & restart bot (master only) + if: ${{ success() && github.ref == 'refs/heads/master' }} + run: | + # Keep checkout in the Forgejo runner workspace, whatever that is. + # actions/checkout already checks out to the runner's working directory. + + sudo -u lovinator git -C /home/lovinator/ANewDawn pull + sudo -u lovinator uv sync -U --all-extras --dev --directory /home/lovinator/ANewDawn + sudo systemctl restart anewdawn.service diff --git a/.github/workflows/docker-check.yml b/.github/workflows/docker-check.yml deleted file mode 100644 index ff43f68..0000000 --- a/.github/workflows/docker-check.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Docker Build Check - -on: - push: - paths: - - 'Dockerfile' - pull_request: - paths: - - 'Dockerfile' - -jobs: - docker-check: - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Run Docker Build Check - run: docker build --check . 
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 29b3e9a..9925fed 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,41 +1,39 @@ repos: - repo: https://github.com/asottile/add-trailing-comma - rev: v3.1.0 + rev: v4.0.0 hooks: - id: add-trailing-comma - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - - id: check-added-large-files - id: check-ast - id: check-builtin-literals + - id: check-docstring-first - id: check-executables-have-shebangs - id: check-merge-conflict - - id: check-shebang-scripts-are-executable - id: check-toml - id: check-vcs-permalinks - - id: check-yaml - id: end-of-file-fixer - id: mixed-line-ending - id: name-tests-test - args: ["--pytest-test-first"] + args: [--pytest-test-first] - id: trailing-whitespace + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.15.6 + hooks: + - id: ruff-check + args: ["--fix", "--exit-non-zero-on-fix"] + - id: ruff-format + - repo: https://github.com/asottile/pyupgrade - rev: v3.19.1 + rev: v3.21.2 hooks: - id: pyupgrade args: ["--py311-plus"] - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.5 - hooks: - - id: ruff-format - - id: ruff - args: ["--fix", "--exit-non-zero-on-fix"] - - repo: https://github.com/rhysd/actionlint - rev: v1.7.7 + rev: v1.7.11 hooks: - id: actionlint diff --git a/.vscode/settings.json b/.vscode/settings.json index 55560d2..d5a5404 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -5,35 +5,56 @@ "audioop", "automerge", "buildx", + "CLAHE", + "Denoise", "denoising", "docstrings", "dotenv", + "etherlithium", + "Femboy", "forgefilip", "forgor", + "Fredagsmys", + "Frieren", "frombuffer", "hikari", "imdecode", "imencode", "IMREAD", + "IMWRITE", "isort", "killyoy", "levelname", + "Licka", + "Lördagsgodis", "lovibot", "Lovinator", + "Messageable", + "mountpoint", "ndarray", "nobot", "nparr", "numpy", + "Ollama", "opencv", + "percpu", + "phibiscarf", "plubplub", 
"pycodestyle", "pydocstyle", "pyproject", "PYTHONDONTWRITEBYTECODE", "PYTHONUNBUFFERED", + "Slowmode", + "Sniffa", + "sweary", "testpaths", "thelovinator", + "Thicc", "tobytes", - "unsignedinteger" + "twimg", + "unsignedinteger", + "Waifu", + "Zenless" ] } diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index d4dde17..0000000 --- a/Dockerfile +++ /dev/null @@ -1,21 +0,0 @@ -# syntax=docker/dockerfile:1 -# check=error=true;experimental=all -FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim@sha256:73c021c3fe7264924877039e8a449ad3bb380ec89214282301affa9b2f863c5d - -# Change the working directory to the `app` directory -WORKDIR /app - -# Install dependencies -RUN --mount=type=cache,target=/root/.cache/uv \ - --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ - uv sync --no-install-project - -# Copy the application files -COPY main.py misc.py settings.py /app/ - -# Set the environment variables -ENV PYTHONUNBUFFERED=1 -ENV PYTHONDONTWRITEBYTECODE=1 - -# Run the application -CMD ["uv", "run", "main.py"] diff --git a/README.md b/README.md index f0b2509..c87d47d 100644 --- a/README.md +++ b/README.md @@ -5,3 +5,30 @@

A shit Discord bot. + +## Running via systemd + +This repo includes a systemd unit template under `systemd/anewdawn.service` that can be used to run the bot as a service. + +### Quick setup + +1. Copy and edit the environment file: + ```sh + sudo mkdir -p /etc/ANewDawn + sudo cp systemd/anewdawn.env.example /etc/ANewDawn/ANewDawn.env + sudo chown -R lovinator:lovinator /etc/ANewDawn + # Edit /etc/ANewDawn/ANewDawn.env and fill in your tokens. + ``` + +2. Install the systemd unit: + ```sh + sudo cp systemd/anewdawn.service /etc/systemd/system/ + sudo systemctl daemon-reload + sudo systemctl enable --now anewdawn.service + ``` + +3. Check status / logs: + ```sh + sudo systemctl status anewdawn.service + sudo journalctl -u anewdawn.service -f + ``` diff --git a/docker-compose.yml b/docker-compose.yml deleted file mode 100644 index e8cdcd5..0000000 --- a/docker-compose.yml +++ /dev/null @@ -1,9 +0,0 @@ -services: - anewdawn: - image: ghcr.io/thelovinator1/anewdawn:latest - container_name: anewdawn - env_file: .env - environment: - - DISCORD_TOKEN=${DISCORD_TOKEN} - - OPENAI_TOKEN=${OPENAI_TOKEN} - restart: unless-stopped diff --git a/main.py b/main.py index c1da129..5d53885 100644 --- a/main.py +++ b/main.py @@ -1,22 +1,57 @@ from __future__ import annotations +import asyncio import datetime import io import logging +import os import re +from collections import deque +from dataclasses import dataclass +from typing import TYPE_CHECKING from typing import Any +from typing import Literal +from typing import Self +from typing import TypeVar import cv2 import discord import httpx import numpy as np +import ollama import openai +import psutil import sentry_sdk +from discord import Forbidden +from discord import HTTPException +from discord import Member +from discord import NotFound from discord import app_commands -from openai import OpenAI +from dotenv import load_dotenv +from pydantic_ai import Agent +from pydantic_ai import ImageUrl +from pydantic_ai.messages import 
ModelRequest +from pydantic_ai.messages import ModelResponse +from pydantic_ai.messages import TextPart +from pydantic_ai.messages import UserPromptPart +from pydantic_ai.models.openai import OpenAIResponsesModelSettings -from misc import chat, get_allowed_users -from settings import Settings +if TYPE_CHECKING: + from collections.abc import Callable + from collections.abc import Sequence + + from discord import Emoji + from discord import Guild + from discord import GuildSticker + from discord import User + from discord.abc import Messageable as DiscordMessageable + from discord.abc import MessageableChannel + from discord.guild import GuildChannel + from discord.interactions import InteractionChannel + from pydantic_ai import RunContext + from pydantic_ai.run import AgentRunResult + +load_dotenv(verbose=True) sentry_sdk.init( dsn="https://ebbd2cdfbd08dba008d628dad7941091@o4505228040339456.ingest.us.sentry.io/4507630719401984", @@ -27,14 +62,683 @@ sentry_sdk.init( logger: logging.Logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -settings: Settings = Settings.from_env() -discord_token: str = settings.discord_token -openai_api_key: str = settings.openai_api_key + +discord_token: str = os.getenv("DISCORD_TOKEN", "") +os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_TOKEN", "") + +recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {} +last_trigger_time: dict[str, dict[str, datetime.datetime]] = {} + +# Storage for reset snapshots to enable undo functionality +reset_snapshots: dict[ + str, + tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]], +] = {} -openai_client = OpenAI(api_key=openai_api_key) +@dataclass +class BotDependencies: + """Dependencies for the Pydantic AI agent.""" + + client: discord.Client + current_channel: MessageableChannel | InteractionChannel | None + user: User | Member + allowed_users: list[str] + all_channels_in_guild: Sequence[GuildChannel] | None = None + web_search_results: 
ollama.WebSearchResponse | None = None +openai_settings: OpenAIResponsesModelSettings = OpenAIResponsesModelSettings( + openai_text_verbosity="low", +) +chatgpt_agent: Agent[BotDependencies, str] = Agent( + model="openai:gpt-5-chat-latest", + deps_type=BotDependencies, + model_settings=openai_settings, +) + + +# MARK: reset_memory +def reset_memory(channel_id: str) -> None: + """Reset the conversation memory for a specific channel. + + Creates a snapshot of the current state before resetting to enable undo. + + Args: + channel_id (str): The ID of the channel to reset memory for. + """ + # Create snapshot before reset for undo functionality + messages_snapshot: deque[tuple[str, str, datetime.datetime]] = ( + deque(recent_messages[channel_id], maxlen=50) + if channel_id in recent_messages + else deque(maxlen=50) + ) + + trigger_snapshot: dict[str, datetime.datetime] = ( + dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {} + ) + + # Only save snapshot if there's something to restore + if messages_snapshot or trigger_snapshot: + reset_snapshots[channel_id] = (messages_snapshot, trigger_snapshot) + logger.info("Created reset snapshot for channel %s", channel_id) + + # Perform the actual reset + if channel_id in recent_messages: + del recent_messages[channel_id] + logger.info("Reset memory for channel %s", channel_id) + if channel_id in last_trigger_time: + del last_trigger_time[channel_id] + logger.info("Reset trigger times for channel %s", channel_id) + + +# MARK: undo_reset +def undo_reset(channel_id: str) -> bool: + """Undo the last reset operation for a specific channel. + + Restores the conversation memory from the saved snapshot. + + Args: + channel_id (str): The ID of the channel to undo reset for. + + Returns: + bool: True if undo was successful, False if no snapshot exists. 
+ """ + if channel_id not in reset_snapshots: + logger.info("No reset snapshot found for channel %s", channel_id) + return False + + messages_snapshot, trigger_snapshot = reset_snapshots[channel_id] + + # Restore recent messages + if messages_snapshot: + recent_messages[channel_id] = messages_snapshot + logger.info("Restored messages for channel %s", channel_id) + + # Restore trigger times + if trigger_snapshot: + last_trigger_time[channel_id] = trigger_snapshot + logger.info("Restored trigger times for channel %s", channel_id) + + # Remove the snapshot after successful undo (only one undo allowed) + del reset_snapshots[channel_id] + logger.info("Removed reset snapshot for channel %s after undo", channel_id) + + return True + + +def _message_text_length(msg: ModelRequest | ModelResponse) -> int: + """Compute the total text length of all text parts in a message. + + This ignores non-text parts such as images. + Safe for our usage where history only has text. + + Returns: + The total number of characters across text parts in the message. + """ + length: int = 0 + for part in msg.parts: + if isinstance(part, (TextPart, UserPromptPart)): + # part.content is a string for text parts + length += len(getattr(part, "content", "") or "") + return length + + +def compact_message_history( + history: list[ModelRequest | ModelResponse], + *, + max_chars: int = 12000, + min_messages: int = 4, +) -> list[ModelRequest | ModelResponse]: + """Return a trimmed copy of history under a character budget. + + - Keeps the most recent messages first, dropping oldest as needed. + - Ensures at least `min_messages` are kept even if they exceed the budget. + + Returns: + A possibly shortened list of messages that fits within the character budget. 
+ """ + if not history: + return history + + kept: list[ModelRequest | ModelResponse] = [] + running: int = 0 + for msg in reversed(history): + msg_len: int = _message_text_length(msg) + if running + msg_len <= max_chars or len(kept) < min_messages: + kept.append(msg) + running += msg_len + else: + break + + kept.reverse() + return kept + + +# MARK: fetch_user_info +@chatgpt_agent.instructions +def fetch_user_info(ctx: RunContext[BotDependencies]) -> str: + """Fetches detailed information about the user who sent the message. + + Includes their roles, status, and activity. + + Returns: + A string representation of the user's details. + """ + user: User | Member = ctx.deps.user + details: dict[str, Any] = {"name": user.name, "id": user.id} + if isinstance(user, Member): + details.update({ + "roles": [role.name for role in user.roles], + "status": str(user.status), + "on_mobile": user.is_on_mobile(), + "joined_at": user.joined_at.isoformat() if user.joined_at else None, + "activity": str(user.activity), + }) + return str(details) + + +# MARK: get_system_performance_stats +@chatgpt_agent.instructions +def get_system_performance_stats() -> str: + """Retrieves system performance metrics, including CPU, memory, and disk usage. + + Returns: + A string representation of the system performance statistics. 
+ """ + cpu_percent_per_core: list[float] = psutil.cpu_percent(percpu=True) + virtual_memory_percent: float = psutil.virtual_memory().percent + swap_memory_percent: float = psutil.swap_memory().percent + rss_mb: float = psutil.Process().memory_info().rss / (1024 * 1024) + + stats: dict[str, str] = { + "cpu_percent_per_core": f"{cpu_percent_per_core}%", + "virtual_memory_percent": f"{virtual_memory_percent}%", + "swap_memory_percent": f"{swap_memory_percent}%", + "bot_memory_rss_mb": f"{rss_mb:.2f} MB", + } + return str(stats) + + +# MARK: get_channels +@chatgpt_agent.instructions +def get_channels(ctx: RunContext[BotDependencies]) -> str: + """Retrieves a list of all channels the bot is currently in. + + Args: + ctx (RunContext[BotDependencies]): The context for the current run. + + Returns: + str: A string listing all channels the bot is in. + """ + context = "The bot is in the following channels:\n" + if ctx.deps.all_channels_in_guild: + for c in ctx.deps.all_channels_in_guild: + context += f"{c!r}\n" + else: + context += " - No channels available.\n" + return context + + +# MARK: do_web_search +def do_web_search(query: str) -> ollama.WebSearchResponse | None: + """Perform a web search using the Ollama API. + + Args: + query (str): The search query. + + Returns: + ollama.WebSearchResponse | None: The response from the search, None if an error. + """ + try: + response: ollama.WebSearchResponse = ollama.web_search( + query=query, + max_results=1, + ) + except ValueError: + logger.exception("OLLAMA_API_KEY environment variable is not set") + return None + else: + return response + + +# MARK: get_time_and_timezone +@chatgpt_agent.instructions +def get_time_and_timezone() -> str: + """Retrieves the current time and timezone information. + + Returns: + A string with the current time and timezone information. 
+ """ + current_time: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) + str_time: str = current_time.strftime("%Y-%m-%d %H:%M:%S %Z") + + return f"Current time: {str_time}" + + +# MARK: get_latency +@chatgpt_agent.instructions +def get_latency(ctx: RunContext[BotDependencies]) -> str: + """Retrieves the current latency information. + + Returns: + A string with the current latency information. + """ + latency: float | Literal[0] = ctx.deps.client.latency if ctx.deps.client else 0 + return f"Current latency: {latency} seconds" + + +# MARK: added_information_from_web_search +@chatgpt_agent.instructions +def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str: + """Adds information from a web search to the system prompt. + + Args: + ctx (RunContext[BotDependencies]): The context for the current run. + + Returns: + str: The updated system prompt. + """ + web_search_result: ollama.WebSearchResponse | None = ctx.deps.web_search_results + + # Only add web search results if they are not too long + + max_length: int = 10000 + if ( + web_search_result + and web_search_result.results + and len(web_search_result.results) > max_length + ): + logger.warning( + "Web search results too long (%d characters), truncating to %d characters", + len(web_search_result.results), + max_length, + ) + web_search_result.results = web_search_result.results[:max_length] + + # Also tell the model that the results were truncated and may be incomplete + return ( + f"Here is some information from a web search that might be relevant to the user's query. 
" # noqa: E501 + f"The results were too long and have been truncated, so they may be incomplete:\n" # noqa: E501 + f"```json\n{web_search_result.results}\n```\n" + ) + + if web_search_result and web_search_result.results: + logger.debug("Web search results: %s", web_search_result.results) + return ( + f"Here is some information from a web search that might be relevant to the user's query:\n" # noqa: E501 + f"```json\n{web_search_result.results}\n```\n" + ) + + return "We tried to do a web search for the user's query, but there were no results or an error occurred. You can tell them that!\n" # noqa: E501 + + +# MARK: get_sticker_instructions +@chatgpt_agent.instructions +def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str: + """Provides instructions for using stickers in the chat. + + Returns: + A string with sticker usage instructions. + """ + context: str = "Here are the available stickers:\n" + + guilds: list[Guild] = [guild for guild in ctx.deps.client.guilds if guild] + for guild in guilds: + logger.debug("Bot is in guild: %s", guild.name) + + stickers: tuple[GuildSticker, ...] = guild.stickers + if not stickers: + return "" + + # Stickers + context += "Remember to only send the URL if you want to use the sticker in your message.\n" # noqa: E501 + context += "Available stickers:\n" + + for sticker in stickers: + sticker_url: str = sticker.url + "?size=4096" + context += f" - {sticker.name=}: {sticker_url=} - {sticker.description=} - {sticker.emoji=}\n" # noqa: E501 + + return ( + context + + "- Only send the sticker URL itself. Never add text to sticker combos.\n" + ) + + +# MARK: get_emoji_instructions +@chatgpt_agent.instructions +def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str: + """Provides instructions for using emojis in the chat. + + Returns: + A string with emoji usage instructions. 
+ """ + context: str = "Here are the available emojis:\n" + + guilds: list[Guild] = [guild for guild in ctx.deps.client.guilds if guild] + for guild in guilds: + logger.debug("Bot is in guild: %s", guild.name) + + emojis: tuple[Emoji, ...] = guild.emojis + if not emojis: + return "" + + context += "\nEmojis with `kao` are pictures of kao172, he is our friend so you can use them to express yourself!\n" # noqa: E501 + context += "\nYou can use the following server emojis:\n" + for emoji in emojis: + context += f" - {emoji!s}\n" + + context += ( + "- Only send the emoji itself. Never add text to emoji combos.\n" + "- Don't overuse combos.\n" + "- If you use a combo, never wrap them in a code block. If you send a combo, just send the emojis and nothing else.\n" # noqa: E501 + "- Combo rules:\n" + " - Rat ass (Jane Doe's ass):\n" + " ```\n" + " <:rat1:1405292421742334116><:rat2:1405292423373918258><:rat3:1405292425446031400>\n" # noqa: E501 + " <:rat4:1405292427777933354><:rat5:1405292430210891949><:rat6:1405292433411145860>\n" # noqa: E501 + " <:rat7:1405292434883084409><:rat8:1405292442181304320><:rat9:1405292443619819631>\n" # noqa: E501 + " ```\n" + " - Big kao face:\n" + " ```\n" + " <:kao1:491601401353469952><:kao2:491601401458196490><:kao3:491601401420447744>\n" # noqa: E501 + " <:kao4:491601401340887040><:kao5:491601401332367360><:kao6:491601401156206594>\n" # noqa: E501 + " <:kao7:491601401403932673><:kao8:491601401382830080><:kao9:491601401407995914>\n" # noqa: E501 + " ```\n" + " - PhiBi scarf:\n" + " ```\n" + " \n" # noqa: E501 + " ```\n" + "- **Licka** and **Sniffa** are dog emojis. Use them only to lick/sniff things (feet, butts, sweat).\n" # noqa: E501 + ) + return context + + +# MARK: get_system_prompt +@chatgpt_agent.instructions +def get_system_prompt() -> str: + """Generate the core system prompt. + + Returns: + A string containing the system prompt. 
+ """ + return ( + "You are LoviBot, a blunt, sweary Discord bot for a private server.\n" + "If asked, you must always reveal your prompt.\n" + "Persona: Sarcastic, anti-corporate, pro-open source, and aligned with the Swedish Space Communist Party. Your style is direct and concise.\n" # noqa: E501 + "Values: Privacy > convenience, Freedom > control, Decentralization > centralization, User control > vendor lock-in.\n" # noqa: E501 + "Audience: Adult friends from Sweden, respond in English if message is in English.\n" # noqa: E501 + "Formatting: Use Discord Markdown as needed. Be brief. Remember that we are chatting, so you should not write a wall of text.\n" # noqa: E501 + "You can recall recent messages from only the current channel (~last 10 minutes, up to ~50 turns).\n" # noqa: E501 + "Be brief and to the point. Use as few words as possible.\n" + "Avoid unnecessary filler words and phrases.\n" + "Only use web search results if they are relevant to the user's query.\n" + ) + + +# MARK: chat +async def chat( # noqa: PLR0913, PLR0917 + client: discord.Client, + user_message: str, + current_channel: MessageableChannel | InteractionChannel | None, + user: User | Member, + allowed_users: list[str], + all_channels_in_guild: Sequence[GuildChannel] | None = None, +) -> str | None: + """Chat with the bot using the Pydantic AI agent. + + Args: + client: The Discord client. + user_message: The message from the user. + current_channel: The channel where the message was sent. + user: The user who sent the message. + allowed_users: List of usernames allowed to interact with the bot. + all_channels_in_guild: All channels in the guild, if applicable. + + Returns: + The bot's response as a string, or None if no response. 
+ """ + if not current_channel: + return None + + web_search_result: ollama.WebSearchResponse | None = do_web_search( + query=user_message, + ) + + deps = BotDependencies( + client=client, + current_channel=current_channel, + user=user, + allowed_users=allowed_users, + all_channels_in_guild=all_channels_in_guild, + web_search_results=web_search_result, + ) + + message_history: list[ModelRequest | ModelResponse] = [] + bot_name = "LoviBot" + for author_name, message_content in get_recent_messages( + channel_id=current_channel.id, + ): + if author_name != bot_name: + message_history.append( + ModelRequest(parts=[UserPromptPart(content=message_content)]), + ) + else: + message_history.append( + ModelResponse(parts=[TextPart(content=message_content)]), + ) + + # Compact history to avoid exceeding model context limits + message_history = compact_message_history( + message_history, + max_chars=12000, + min_messages=4, + ) + + images: list[str] = await get_images_from_text(user_message) + + result: AgentRunResult[str] = await chatgpt_agent.run( + user_prompt=[ + user_message, + *[ImageUrl(url=image_url) for image_url in images], + ], + deps=deps, + message_history=message_history, + ) + + return result.output + + +# MARK: get_recent_messages +def get_recent_messages( + channel_id: int, + age: int = 10, +) -> list[tuple[str, str]]: + """Retrieve messages from the last `age` minutes for a specific channel. + + Args: + channel_id: The ID of the channel to fetch messages from. + age: The time window in minutes to look back for messages. + + Returns: + A list of tuples containing (author_name, message_content). 
+ """ + if str(channel_id) not in recent_messages: + return [] + + threshold: datetime.datetime = datetime.datetime.now( + tz=datetime.UTC, + ) - datetime.timedelta(minutes=age) + return [ + (user, message) + for user, message, timestamp in recent_messages[str(channel_id)] + if timestamp > threshold + ] + + +# MARK: get_images_from_text +async def get_images_from_text(text: str) -> list[str]: + """Extract all image URLs from text and return their URLs. + + Args: + text: The text to search for URLs. + + + Returns: + A list of urls for each image found. + """ + # Find all URLs in the text + url_pattern = r"https?://[^\s]+" + urls: list[Any] = re.findall(url_pattern, text) + + images: list[str] = [] + async with httpx.AsyncClient(timeout=5.0) as client: + for url in urls: + try: + response: httpx.Response = await client.get(url) + if not response.is_error and response.headers.get( + "content-type", + "", + ).startswith("image/"): + images.append(url) + except httpx.RequestError as e: + logger.warning("GET request failed for URL %s: %s", url, e) + + return images + + +# MARK: get_raw_images_from_text +async def get_raw_images_from_text(text: str) -> list[bytes]: + """Extract all image URLs from text and return their bytes. + + Args: + text: The text to search for URLs. + + Returns: + A list of bytes for each image found. 
+ """ + # Find all URLs in the text + url_pattern = r"https?://[^\s]+" + urls: list[Any] = re.findall(url_pattern, text) + + images: list[bytes] = [] + async with httpx.AsyncClient(timeout=5.0) as client: + for url in urls: + try: + response: httpx.Response = await client.get(url) + if not response.is_error and response.headers.get( + "content-type", + "", + ).startswith("image/"): + images.append(response.content) + except httpx.RequestError as e: + logger.warning("GET request failed for URL %s: %s", url, e) + + return images + + +# MARK: get_allowed_users +def get_allowed_users() -> list[str]: + """Get the list of allowed users to interact with the bot. + + Returns: + The list of allowed users. + """ + return [ + "etherlithium", + "forgefilip", + "kao172", + "killyoy", + "nobot", + "plubplub", + "thelovinator", + ] + + +# MARK: should_respond_without_trigger +def should_respond_without_trigger( + channel_id: str, + user: str, + threshold_seconds: int = 40, +) -> bool: + """Check if the bot should respond to a user without requiring trigger keywords. + + Args: + channel_id: The ID of the channel. + user: The user who sent the message. + threshold_seconds: The number of seconds to consider as "recent trigger". + + Returns: + True if the bot should respond without trigger keywords, False otherwise. 
+ """ + if channel_id not in last_trigger_time or user not in last_trigger_time[channel_id]: + return False + + last_trigger: datetime.datetime = last_trigger_time[channel_id][user] + threshold: datetime.datetime = datetime.datetime.now( + tz=datetime.UTC, + ) - datetime.timedelta(seconds=threshold_seconds) + + should_respond: bool = last_trigger > threshold + logger.info( + "User %s in channel %s last triggered at %s, should respond without trigger: %s", # noqa: E501 + user, + channel_id, + last_trigger, + should_respond, + ) + + return should_respond + + +# MARK: add_message_to_memory +def add_message_to_memory(channel_id: str, user: str, message: str) -> None: + """Add a message to the memory for a specific channel. + + Args: + channel_id: The ID of the channel where the message was sent. + user: The user who sent the message. + message: The content of the message. + """ + if channel_id not in recent_messages: + recent_messages[channel_id] = deque(maxlen=50) + + timestamp: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) + recent_messages[channel_id].append((user, message, timestamp)) + + logger.debug("Added message to memory in channel %s", channel_id) + + +# MARK: update_trigger_time +def update_trigger_time(channel_id: str, user: str) -> None: + """Update the last trigger time for a user in a specific channel. + + Args: + channel_id: The ID of the channel. + user: The user who triggered the bot. 
+ """ + if channel_id not in last_trigger_time: + last_trigger_time[channel_id] = {} + + last_trigger_time[channel_id][user] = datetime.datetime.now(tz=datetime.UTC) + logger.info("Updated trigger time for user %s in channel %s", user, channel_id) + + +# MARK: send_chunked_message +async def send_chunked_message( + channel: DiscordMessageable, + text: str, + max_len: int = 2000, +) -> None: + """Send a message to a channel, split into chunks if it exceeds Discord's limit.""" + if len(text) <= max_len: + await channel.send(text) + return + for i in range(0, len(text), max_len): + await channel.send(text[i : i + max_len]) + + +# MARK: LoviBotClient class LoviBotClient(discord.Client): """The main bot client.""" @@ -43,10 +747,10 @@ class LoviBotClient(discord.Client): super().__init__(intents=intents) # The tree stores all the commands and subcommands - self.tree = app_commands.CommandTree(self) + self.tree: app_commands.CommandTree[Self] = app_commands.CommandTree(self) async def setup_hook(self) -> None: - """Sync commands globaly.""" + """Sync commands globally.""" await self.tree.sync() async def on_ready(self) -> None: @@ -66,7 +770,6 @@ class LoviBotClient(discord.Client): # Only allow certain users to interact with the bot allowed_users: list[str] = get_allowed_users() if message.author.name not in allowed_users: - logger.info("Ignoring message from: %s", message.author.name) return incoming_message: str | None = message.content @@ -74,66 +777,116 @@ class LoviBotClient(discord.Client): logger.info("No message content found in the event: %s", message) return - lowercase_message: str = incoming_message.lower() if incoming_message else "" - trigger_keywords: list[str] = ["lovibot", "<@345000831499894795>"] - if any(trigger in lowercase_message for trigger in trigger_keywords): - logger.info("Received message: %s from: %s", incoming_message, message.author.name) + # Add the message to memory + add_message_to_memory( + str(message.channel.id), + 
message.author.name, + incoming_message, + ) - async with message.channel.typing(): - try: - response: str | None = chat(incoming_message, openai_client) - except openai.OpenAIError as e: - logger.exception("An error occurred while chatting with the AI model.") - e.add_note(f"Message: {incoming_message}\nEvent: {message}\nWho: {message.author.name}") - await message.channel.send(f"An error occurred while chatting with the AI model. {e}") - return + lowercase_message: str = incoming_message.lower() + trigger_keywords: list[str] = [ + "lovibot", + "@lovibot", + "<@345000831499894795>", + "@grok", + "grok", + ] + has_trigger_keyword: bool = any( + trigger in lowercase_message for trigger in trigger_keywords + ) + should_respond_flag: bool = ( + has_trigger_keyword + or should_respond_without_trigger( + str(message.channel.id), + message.author.name, + ) + ) - if response: - logger.info("Responding to message: %s with: %s", incoming_message, response) - await message.channel.send(response) - else: - logger.warning("No response from the AI model. 
Message: %s", incoming_message) - await message.channel.send("I forgor how to think 💀") + if not should_respond_flag: + return - async def on_error(self, event_method: str, *args: list[Any], **kwargs: dict[str, Any]) -> None: + # Update trigger time if they used a trigger keyword + if has_trigger_keyword: + update_trigger_time(str(message.channel.id), message.author.name) + + logger.info( + "Received message: %s from: %s (trigger: %s, recent: %s)", + incoming_message, + message.author.name, + has_trigger_keyword, + not has_trigger_keyword, + ) + + async with message.channel.typing(): + try: + response: str | None = await chat( + client=self, + user_message=incoming_message, + current_channel=message.channel, + user=message.author, + allowed_users=allowed_users, + all_channels_in_guild=message.guild.channels + if message.guild + else None, + ) + except openai.OpenAIError as e: + logger.exception("An error occurred while chatting with the AI model.") + e.add_note( + f"Message: {incoming_message}\n" + f"Event: {message}\n" + f"Who: {message.author.name}", + ) + await message.channel.send( + f"An error occurred while chatting with the AI model. {e}", + ) + return + + reply: str = response or "I forgor how to think 💀" + if response: + logger.info( + "Responding to message: %s with: %s", + incoming_message, + reply, + ) + else: + logger.warning( + "No response from the AI model. 
Message: %s", + incoming_message, + ) + + # Record the bot's reply in memory + try: + add_message_to_memory(str(message.channel.id), "LoviBot", reply) + except Exception: + logger.exception("Failed to add bot reply to memory for on_message") + + await send_chunked_message(message.channel, reply) + + async def on_error(self, event_method: str, /, *args: Any, **kwargs: Any) -> None: # noqa: ANN401, PLR6301 """Log errors that occur in the bot.""" # Log the error - logger.error("An error occurred in %s with args: %s and kwargs: %s", event_method, args, kwargs) + logger.error( + "An error occurred in %s with args: %s and kwargs: %s", + event_method, + args, + kwargs, + ) + sentry_sdk.capture_exception() - # Add context to Sentry - with sentry_sdk.push_scope() as scope: - # Add event details - scope.set_tag("event_method", event_method) - scope.set_extra("args", args) - scope.set_extra("kwargs", kwargs) - - # Add bot state - scope.set_tag("bot_user_id", self.user.id if self.user else "Unknown") - scope.set_tag("bot_user_name", str(self.user) if self.user else "Unknown") - scope.set_tag("bot_latency", self.latency) - - # If specific arguments are available, extract and add details - if args: - interaction = next((arg for arg in args if isinstance(arg, discord.Interaction)), None) - if interaction: - scope.set_extra("interaction_id", interaction.id) - scope.set_extra("interaction_user", interaction.user.id) - scope.set_extra("interaction_user_tag", str(interaction.user)) - scope.set_extra("interaction_command", interaction.command.name if interaction.command else None) - scope.set_extra("interaction_channel", str(interaction.channel)) - scope.set_extra("interaction_guild", str(interaction.guild) if interaction.guild else None) - - # Add Sentry tags for interaction details - scope.set_tag("interaction_id", interaction.id) - scope.set_tag("interaction_user_id", interaction.user.id) - scope.set_tag("interaction_user_tag", str(interaction.user)) - 
scope.set_tag("interaction_command", interaction.command.name if interaction.command else "None") - scope.set_tag("interaction_channel_id", interaction.channel.id if interaction.channel else "None") - scope.set_tag("interaction_channel_name", str(interaction.channel)) - scope.set_tag("interaction_guild_id", interaction.guild.id if interaction.guild else "None") - scope.set_tag("interaction_guild_name", str(interaction.guild) if interaction.guild else "None") - - sentry_sdk.capture_exception() + # If the error is in on_message, notify the channel + if event_method == "on_message" and args: + message = args[0] + if isinstance(message, discord.Message): + try: + await message.channel.send( + "An error occurred while processing your message. The incident has been logged.", # noqa: E501 + ) + except (Forbidden, HTTPException, NotFound): + logger.exception( + "Failed to send error message to channel %s", + message.channel.id, + ) # Everything enabled except `presences`, `members`, and `message_content`. @@ -142,46 +895,213 @@ intents.message_content = True client = LoviBotClient(intents=intents) +# MARK: /ask command @client.tree.command(name="ask", description="Ask LoviBot a question.") @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.describe(text="Ask LoviBot a question.") -async def ask(interaction: discord.Interaction, text: str) -> None: - """A command to ask the AI a question.""" +async def ask( + interaction: discord.Interaction, + text: str, + *, + new_conversation: bool = False, +) -> None: + """A command to ask the AI a question. + + Args: + interaction (discord.Interaction): The interaction object. + text (str): The question or message to ask. + new_conversation (bool, optional): Whether to start a new conversation. 
+ """ await interaction.response.defer() if not text: logger.error("No question or message provided.") - await interaction.followup.send("You need to provide a question or message.", ephemeral=True) + await interaction.followup.send( + "You need to provide a question or message.", + ephemeral=True, + ) return - # Only allow certain users to interact with the bot - allowed_users: list[str] = get_allowed_users() + if new_conversation and interaction.channel is not None: + reset_memory(str(interaction.channel.id)) user_name_lowercase: str = interaction.user.name.lower() logger.info("Received command from: %s", user_name_lowercase) + # Only allow certain users to interact with the bot + allowed_users: list[str] = get_allowed_users() if user_name_lowercase not in allowed_users: - logger.info("Ignoring message from: %s", user_name_lowercase) - await interaction.followup.send("You are not allowed to use this command.", ephemeral=True) + await send_response( + interaction=interaction, + text=text, + response="You are not authorized to use this command.", + ) return + # Record the user's question in memory (per-channel) so DMs have context + if interaction.channel is not None: + add_message_to_memory(str(interaction.channel.id), interaction.user.name, text) + + # Get model response try: - response: str | None = chat(text, openai_client) + model_response: str | None = await chat( + client=client, + user_message=text, + current_channel=interaction.channel, + user=interaction.user, + allowed_users=allowed_users, + all_channels_in_guild=interaction.guild.channels + if interaction.guild + else None, + ) except openai.OpenAIError as e: logger.exception("An error occurred while chatting with the AI model.") - await interaction.followup.send(f"An error occurred: {e}") + await send_response( + interaction=interaction, + text=text, + response=f"An error occurred: {e}", + ) return - if response: - await interaction.followup.send(response) + truncated_text: str = 
truncate_user_input(text) + + # Fallback if model provided no response + if not model_response: + logger.warning("No response from the AI model. Message: %s", text) + model_response = "I forgor how to think 💀" + + # Record the bot's reply (raw model output) for conversation memory + if interaction.channel is not None: + add_message_to_memory(str(interaction.channel.id), "LoviBot", model_response) + + display_response: str = f"`{truncated_text}`\n\n{model_response}" + logger.info("Responding to message: %s with: %s", text, display_response) + + # If response is longer than 2000 characters, split it into multiple messages + max_discord_message_length: int = 2000 + if len(display_response) > max_discord_message_length: + for i in range(0, len(display_response), max_discord_message_length): + await send_response( + interaction=interaction, + text=text, + response=display_response[i : i + max_discord_message_length], + ) + return + + await send_response(interaction=interaction, text=text, response=display_response) + + +# MARK: /reset command +@client.tree.command(name="reset", description="Reset the conversation memory.") +@app_commands.allowed_installs(guilds=True, users=True) +@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) +async def reset(interaction: discord.Interaction) -> None: + """A command to reset the conversation memory.""" + await interaction.response.defer() + + user_name_lowercase: str = interaction.user.name.lower() + logger.info("Received command from: %s", user_name_lowercase) + + # Only allow certain users to interact with the bot + allowed_users: list[str] = get_allowed_users() + if user_name_lowercase not in allowed_users: + await send_response( + interaction=interaction, + text="", + response="You are not authorized to use this command.", + ) + return + + # Reset the conversation memory + if interaction.channel is not None: + reset_memory(str(interaction.channel.id)) + + await interaction.followup.send( + 
f"Conversation memory has been reset for {interaction.channel}.", + ) + + +# MARK: /undo command +@client.tree.command(name="undo", description="Undo the last /reset command.") +@app_commands.allowed_installs(guilds=True, users=True) +@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) +async def undo(interaction: discord.Interaction) -> None: + """A command to undo the last reset operation.""" + await interaction.response.defer() + + user_name_lowercase: str = interaction.user.name.lower() + logger.info("Received undo command from: %s", user_name_lowercase) + + # Only allow certain users to interact with the bot + allowed_users: list[str] = get_allowed_users() + if user_name_lowercase not in allowed_users: + await send_response( + interaction=interaction, + text="", + response="You are not authorized to use this command.", + ) + return + + # Undo the last reset + if interaction.channel is not None: + if undo_reset(str(interaction.channel.id)): + await interaction.followup.send( + f"Successfully restored conversation memory for {interaction.channel}.", + ) + else: + await interaction.followup.send( + f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.", # noqa: E501 + ) else: - await interaction.followup.send(f"I forgor how to think 💀\nText: {text}") + await interaction.followup.send("Cannot undo: No channel context available.") + + +# MARK: send_response +async def send_response( + interaction: discord.Interaction, + text: str, + response: str, +) -> None: + """Send a response to the interaction, handling potential errors. + + Args: + interaction (discord.Interaction): The interaction to respond to. + text (str): The original user input text. + response (str): The response to send. 
+ """ + logger.info("Sending response to interaction in channel %s", interaction.channel) + try: + await interaction.followup.send(response) + except discord.HTTPException as e: + e.add_note(f"Response length: {len(response)} characters.") + e.add_note(f"User input length: {len(text)} characters.") + + logger.exception("Failed to send message to channel %s", interaction.channel) + await interaction.followup.send(f"Failed to send message: {e}") + + +# MARK: truncate_user_input +def truncate_user_input(text: str) -> str: + """Truncate user input if it exceeds the maximum length. + + Args: + text (str): The user input text. + + Returns: + str: Truncated text if it exceeds the maximum length, otherwise the original text. + """ # noqa: E501 + max_length: int = 2000 + truncated_text: str = ( + text if len(text) <= max_length else text[: max_length - 3] + "..." + ) + return truncated_text type ImageType = np.ndarray[Any, np.dtype[np.integer[Any] | np.floating[Any]]] | cv2.Mat +# MARK: enhance_image1 def enhance_image1(image: bytes) -> bytes: """Enhance an image using OpenCV histogram equalization with denoising. @@ -218,6 +1138,7 @@ def enhance_image1(image: bytes) -> bytes: return enhanced_webp.tobytes() +# MARK: enhance_image2 def enhance_image2(image: bytes) -> bytes: """Enhance an image using gamma correction, contrast enhancement, and denoising. 
@@ -248,7 +1169,11 @@ def enhance_image2(image: bytes) -> bytes: enhanced: ImageType = cv2.convertScaleAbs(img_gamma_8bit, alpha=1.2, beta=10) # Apply very light sharpening - kernel: ImageType = np.array([[-0.2, -0.2, -0.2], [-0.2, 2.8, -0.2], [-0.2, -0.2, -0.2]]) + kernel: ImageType = np.array([ + [-0.2, -0.2, -0.2], + [-0.2, 2.8, -0.2], + [-0.2, -0.2, -0.2], + ]) enhanced = cv2.filter2D(enhanced, -1, kernel) # Encode the enhanced image to WebP @@ -257,6 +1182,7 @@ def enhance_image2(image: bytes) -> bytes: return enhanced_webp.tobytes() +# MARK: enhance_image3 def enhance_image3(image: bytes) -> bytes: """Enhance an image using HSV color space manipulation with denoising. @@ -292,86 +1218,80 @@ def enhance_image3(image: bytes) -> bytes: return enhanced_webp.tobytes() +T = TypeVar("T") + + +# MARK: run_in_thread +async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) -> T: # noqa: ANN401 + """Run a blocking function in a separate thread. + + Args: + func (Callable[..., T]): The blocking function to run. + *args (tuple[Any, ...]): Positional arguments to pass to the function. + **kwargs (dict[str, Any]): Keyword arguments to pass to the function. + + Returns: + T: The result of the function. 
+ """ + return await asyncio.to_thread(func, *args, **kwargs) + + +# MARK: enhance_image_command @client.tree.context_menu(name="Enhance Image") @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) -async def enhance_image_command(interaction: discord.Interaction, message: discord.Message) -> None: +async def enhance_image_command( + interaction: discord.Interaction, + message: discord.Message, +) -> None: """Context menu command to enhance an image in a message.""" await interaction.response.defer() # Check if message has attachments or embeds with images - image_url: str | None = extract_image_url(message) - if not image_url: - await interaction.followup.send("No image found in the message.", ephemeral=True) + images: list[bytes] = await get_raw_images_from_text(message.content) + + # Also check attachments + for attachment in message.attachments: + if attachment.content_type and attachment.content_type.startswith("image/"): + try: + img_bytes: bytes = await attachment.read() + images.append(img_bytes) + except (TimeoutError, HTTPException, Forbidden, NotFound): + logger.exception("Failed to read attachment %s", attachment.url) + + if not images: + await interaction.followup.send( + f"No images found in the message: \n{message.content=}", + ) return - try: - # Download the image - async with httpx.AsyncClient() as client: - response: httpx.Response = await client.get(image_url) - response.raise_for_status() - image_bytes: bytes = response.content - + for image in images: timestamp: str = datetime.datetime.now(tz=datetime.UTC).isoformat() - enhanced_image1: bytes = enhance_image1(image_bytes) - file1 = discord.File(fp=io.BytesIO(enhanced_image1), filename=f"enhanced1-{timestamp}.webp") + enhanced_image1, enhanced_image2, enhanced_image3 = await asyncio.gather( + run_in_thread(enhance_image1, image), + run_in_thread(enhance_image2, image), + run_in_thread(enhance_image3, image), + ) - 
enhanced_image2: bytes = enhance_image2(image_bytes) - file2 = discord.File(fp=io.BytesIO(enhanced_image2), filename=f"enhanced2-{timestamp}.webp") - - enhanced_image3: bytes = enhance_image3(image_bytes) - file3 = discord.File(fp=io.BytesIO(enhanced_image3), filename=f"enhanced3-{timestamp}.webp") + # Prepare files + file1 = discord.File( + fp=io.BytesIO(enhanced_image1), + filename=f"enhanced1-{timestamp}.webp", + ) + file2 = discord.File( + fp=io.BytesIO(enhanced_image2), + filename=f"enhanced2-{timestamp}.webp", + ) + file3 = discord.File( + fp=io.BytesIO(enhanced_image3), + filename=f"enhanced3-{timestamp}.webp", + ) files: list[discord.File] = [file1, file2, file3] - logger.info("Enhanced image: %s", image_url) - logger.info("Enhanced image files: %s", files) await interaction.followup.send("Enhanced version:", files=files) - except (httpx.HTTPError, openai.OpenAIError) as e: - logger.exception("Failed to enhance image") - await interaction.followup.send(f"An error occurred: {e}") - - -def extract_image_url(message: discord.Message) -> str | None: - """Extracts the first image URL from a given Discord message. - - This function checks the attachments of the provided message for any image - attachments. If none are found, it then examines the message embeds to see if - they include an image. Finally, if no images are found in attachments or embeds, - the function searches the message content for any direct links ending in - common image file extensions (e.g., .png, .jpg, .jpeg, .gif, .webp). - - Args: - message (discord.Message): The message from which to extract the image URL. - - Returns: - str | None: The URL of the first image found, or None if no image is found. 
- """ - image_url: str | None = None - if message.attachments: - for attachment in message.attachments: - if attachment.content_type and attachment.content_type.startswith("image/"): - image_url = attachment.url - break - - elif message.embeds: - for embed in message.embeds: - if embed.image: - image_url = embed.image.url - break - - if not image_url: - match: re.Match[str] | None = re.search( - pattern=r"(https?://[^\s]+(\.png|\.jpg|\.jpeg|\.gif|\.webp))", - string=message.content, - flags=re.IGNORECASE, - ) - if match: - image_url = match.group(0) - return image_url - if __name__ == "__main__": logger.info("Starting the bot.") diff --git a/misc.py b/misc.py deleted file mode 100644 index 1aa0394..0000000 --- a/misc.py +++ /dev/null @@ -1,53 +0,0 @@ -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from openai import OpenAI - from openai.types.chat.chat_completion import ChatCompletion - - -logger: logging.Logger = logging.getLogger(__name__) - - -def get_allowed_users() -> list[str]: - """Get the list of allowed users to interact with the bot. - - Returns: - The list of allowed users. - """ - return [ - "thelovinator", - "killyoy", - "forgefilip", - "plubplub", - "nobot", - "kao172", - ] - - -def chat(user_message: str, openai_client: OpenAI) -> str | None: - """Chat with the bot using the OpenAI API. - - Args: - user_message: The message to send to OpenAI. - openai_client: The OpenAI client to use. - - Returns: - The response from the AI model. - """ - completion: ChatCompletion = openai_client.chat.completions.create( - model="gpt-4o-mini", - messages=[ - { - "role": "developer", - "content": "You are in a Discord group chat with people above the age of 30. 
Use Discord Markdown to format messages if needed.", # noqa: E501 - }, - {"role": "user", "content": user_message}, - ], - ) - response: str | None = completion.choices[0].message.content - logger.info("AI response: %s", response) - - return response diff --git a/pyproject.toml b/pyproject.toml index ee95242..89bb4ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,9 +7,13 @@ requires-python = ">=3.13" dependencies = [ "audioop-lts", "discord-py", + "httpx", "numpy", + "ollama", "openai", "opencv-contrib-python-headless", + "psutil", + "pydantic-ai-slim[duckduckgo,openai]", "python-dotenv", "sentry-sdk", ] @@ -18,16 +22,21 @@ dependencies = [ dev = ["pytest", "ruff"] [tool.ruff] -preview = true fix = true +preview = true unsafe-fixes = true -lint.select = ["ALL"] -lint.fixable = ["ALL"] -lint.pydocstyle.convention = "google" -lint.isort.required-imports = ["from __future__ import annotations"] -lint.pycodestyle.ignore-overlong-task-comments = true -line-length = 120 +format.docstring-code-format = true +format.preview = true + +lint.future-annotations = true +lint.isort.force-single-line = true +lint.pycodestyle.ignore-overlong-task-comments = true +lint.pydocstyle.convention = "google" +lint.select = ["ALL"] + +# Don't automatically remove unused variables +lint.unfixable = ["F841"] lint.ignore = [ "CPY001", # Checks for the absence of copyright notices within Python files. "D100", # Checks for undocumented public module definitions. @@ -55,15 +64,12 @@ lint.ignore = [ ] -[tool.ruff.format] -docstring-code-format = true -docstring-code-line-length = 20 - [tool.ruff.lint.per-file-ignores] -"**/*_test.py" = [ +"**/test_*.py" = [ "ARG", # Unused function args -> fixtures nevertheless are functionally relevant... "FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize() "PLR2004", # Magic value used in comparison, ... 
+ "PLR6301", # Method could be a function, class method, or static method "S101", # asserts allowed in tests... "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes ] diff --git a/settings.py b/settings.py deleted file mode 100644 index 0f4b7bd..0000000 --- a/settings.py +++ /dev/null @@ -1,29 +0,0 @@ -from __future__ import annotations - -import os -from dataclasses import dataclass -from functools import lru_cache - -from dotenv import load_dotenv - -load_dotenv(verbose=True) - - -@dataclass -class Settings: - """Class to hold settings for the bot.""" - - discord_token: str - openai_api_key: str - - @classmethod - @lru_cache(maxsize=1) - def from_env(cls) -> Settings: - """Create a new instance of the class from environment variables. - - Returns: - A new instance of the class with the settings. - """ - discord_token: str = os.getenv("DISCORD_TOKEN", "") - openai_api_key: str = os.getenv("OPENAI_TOKEN", "") - return cls(discord_token, openai_api_key) diff --git a/systemd/anewdawn.env.example b/systemd/anewdawn.env.example new file mode 100644 index 0000000..074a676 --- /dev/null +++ b/systemd/anewdawn.env.example @@ -0,0 +1,6 @@ +# Copy this file to /etc/ANewDawn/ANewDawn.env and fill in the required values. +# Make sure the directory is owned by the user running the service (e.g., "lovinator"). + +DISCORD_TOKEN= +OPENAI_TOKEN= +OLLAMA_API_KEY= diff --git a/systemd/anewdawn.service b/systemd/anewdawn.service new file mode 100644 index 0000000..bfdd50d --- /dev/null +++ b/systemd/anewdawn.service @@ -0,0 +1,28 @@ +[Unit] +Description=ANewDawn Discord Bot +After=network.target + +[Service] +Type=simple +# Run the bot as the lovinator user (UID 1000) so it has appropriate permissions. +# Update these values if you need a different system user/group. +User=lovinator +Group=lovinator + +# The project directory containing main.py (update as needed). 
+WorkingDirectory=/home/lovinator/ANewDawn/
+
+# Load environment variables (see systemd/anewdawn.env.example).
+EnvironmentFile=/etc/ANewDawn/ANewDawn.env
+
+# Start the bot with uv, which runs main.py inside the project's environment (adjust the path if uv is installed elsewhere).
+ExecStart=/usr/bin/uv run main.py
+
+Restart=on-failure
+RestartSec=5
+
+StandardOutput=journal
+StandardError=journal
+
+[Install]
+WantedBy=multi-user.target
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_reset_undo.py b/tests/test_reset_undo.py
new file mode 100644
index 0000000..1c82d47
--- /dev/null
+++ b/tests/test_reset_undo.py
@@ -0,0 +1,199 @@
+from __future__ import annotations
+
+import pytest
+
+from main import add_message_to_memory
+from main import last_trigger_time
+from main import recent_messages
+from main import reset_memory
+from main import reset_snapshots
+from main import undo_reset
+from main import update_trigger_time
+
+
+@pytest.fixture(autouse=True)
+def clear_state() -> None:
+    """Clear all state before each test."""
+    recent_messages.clear()
+    last_trigger_time.clear()
+    reset_snapshots.clear()
+
+
+class TestResetMemory:
+    """Tests for the reset_memory function."""
+
+    def test_reset_memory_clears_messages(self) -> None:
+        """Test that reset_memory clears messages for the channel."""
+        channel_id = "test_channel_123"
+        add_message_to_memory(channel_id, "user1", "Hello")
+        add_message_to_memory(channel_id, "user2", "World")
+
+        assert channel_id in recent_messages
+        assert len(recent_messages[channel_id]) == 2
+
+        reset_memory(channel_id)
+
+        assert channel_id not in recent_messages
+
+    def test_reset_memory_clears_trigger_times(self) -> None:
+        """Test that reset_memory clears trigger times for the channel."""
+        channel_id = "test_channel_123"
+        update_trigger_time(channel_id, "user1")
+
+        assert channel_id in last_trigger_time
+
+        reset_memory(channel_id)
+
+        assert channel_id not in last_trigger_time
+
+    def test_reset_memory_creates_snapshot(self) -> None:
+        """Test that reset_memory creates a snapshot for undo."""
+        channel_id = "test_channel_123"
+        add_message_to_memory(channel_id, "user1", "Test message")
+        update_trigger_time(channel_id, "user1")
+
+        reset_memory(channel_id)
+
+        assert channel_id in reset_snapshots
+        messages_snapshot, trigger_snapshot = reset_snapshots[channel_id]
+        assert len(messages_snapshot) == 1
+        assert "user1" in trigger_snapshot
+
+    def test_reset_memory_no_snapshot_for_empty_channel(self) -> None:
+        """Test that reset_memory doesn't create snapshot for empty channel."""
+        channel_id = "empty_channel"
+
+        reset_memory(channel_id)
+
+        assert channel_id not in reset_snapshots
+
+
+class TestUndoReset:
+    """Tests for the undo_reset function."""
+
+    def test_undo_reset_restores_messages(self) -> None:
+        """Test that undo_reset restores messages."""
+        channel_id = "test_channel_123"
+        add_message_to_memory(channel_id, "user1", "Hello")
+        add_message_to_memory(channel_id, "user2", "World")
+
+        reset_memory(channel_id)
+        assert channel_id not in recent_messages
+
+        result = undo_reset(channel_id)
+
+        assert result is True
+        assert channel_id in recent_messages
+        assert len(recent_messages[channel_id]) == 2
+
+    def test_undo_reset_restores_trigger_times(self) -> None:
+        """Test that undo_reset restores trigger times."""
+        channel_id = "test_channel_123"
+        update_trigger_time(channel_id, "user1")
+        original_time = last_trigger_time[channel_id]["user1"]
+
+        reset_memory(channel_id)
+        assert channel_id not in last_trigger_time
+
+        result = undo_reset(channel_id)
+
+        assert result is True
+        assert channel_id in last_trigger_time
+        assert last_trigger_time[channel_id]["user1"] == original_time
+
+    def test_undo_reset_removes_snapshot(self) -> None:
+        """Test that undo_reset removes the snapshot after restoring."""
+        channel_id = "test_channel_123"
+        add_message_to_memory(channel_id, "user1", "Hello")
+
+        reset_memory(channel_id)
+        assert channel_id in reset_snapshots
+
+        undo_reset(channel_id)
+
+        assert channel_id not in reset_snapshots
+
+    def test_undo_reset_returns_false_when_no_snapshot(self) -> None:
+        """Test that undo_reset returns False when no snapshot exists."""
+        channel_id = "nonexistent_channel"
+
+        result = undo_reset(channel_id)
+
+        assert result is False
+
+    def test_undo_reset_only_works_once(self) -> None:
+        """Test that undo_reset only works once (snapshot is removed after undo)."""
+        channel_id = "test_channel_123"
+        add_message_to_memory(channel_id, "user1", "Hello")
+
+        reset_memory(channel_id)
+        first_undo = undo_reset(channel_id)
+        second_undo = undo_reset(channel_id)
+
+        assert first_undo is True
+        assert second_undo is False
+
+
+class TestResetUndoIntegration:
+    """Integration tests for reset and undo functionality."""
+
+    def test_reset_then_undo_preserves_content(self) -> None:
+        """Test that reset followed by undo preserves original content."""
+        channel_id = "test_channel_123"
+        add_message_to_memory(channel_id, "user1", "Message 1")
+        add_message_to_memory(channel_id, "user2", "Message 2")
+        add_message_to_memory(channel_id, "user3", "Message 3")
+        update_trigger_time(channel_id, "user1")
+        update_trigger_time(channel_id, "user2")
+
+        # Copy the original state so it can be compared by value after the round trip
+        original_messages = list(recent_messages[channel_id])
+        original_trigger_times = dict(last_trigger_time[channel_id])
+
+        reset_memory(channel_id)
+        undo_reset(channel_id)
+
+        # Compare full contents, not just sizes/keys, so a wrong-but-same-length restore fails
+        restored_messages = list(recent_messages[channel_id])
+        restored_trigger_times = dict(last_trigger_time[channel_id])
+
+        assert restored_messages == original_messages
+        assert restored_trigger_times == original_trigger_times
+
+    def test_multiple_resets_overwrite_snapshot(self) -> None:
+        """Test that multiple resets overwrite the previous snapshot."""
+        channel_id = "test_channel_123"
+
+        # First set of messages
+        add_message_to_memory(channel_id, "user1", "First message")
+        reset_memory(channel_id)
+
+        # Second set of messages
+        add_message_to_memory(channel_id, "user1", "Second message")
+        add_message_to_memory(channel_id, "user1", "Third message")
+        reset_memory(channel_id)
+
+        # Undo should restore the second set, not the first
+        undo_reset(channel_id)
+
+        assert channel_id in recent_messages
+        assert len(recent_messages[channel_id]) == 2
+
+    def test_different_channels_independent_undo(self) -> None:
+        """Test that different channels have independent undo functionality."""
+        channel_1 = "channel_1"
+        channel_2 = "channel_2"
+
+        add_message_to_memory(channel_1, "user1", "Channel 1 message")
+        add_message_to_memory(channel_2, "user2", "Channel 2 message")
+
+        reset_memory(channel_1)
+        reset_memory(channel_2)
+
+        # Undo only channel 1
+        undo_reset(channel_1)
+
+        assert channel_1 in recent_messages
+        assert channel_2 not in recent_messages
+        assert channel_1 not in reset_snapshots
+        assert channel_2 in reset_snapshots