Compare commits
No commits in common. "master" and "bd0c66e0fd06e91f3d25a665702da0fea6bc1d86" have entirely different histories.
master
...
bd0c66e0fd
20 changed files with 439 additions and 1523 deletions
|
|
@ -1,3 +1,2 @@
|
||||||
DISCORD_TOKEN=
|
DISCORD_TOKEN=
|
||||||
OPENAI_TOKEN=
|
OPENAI_TOKEN=
|
||||||
OLLAMA_API_KEY=
|
|
||||||
|
|
|
||||||
19
.gitea/workflows/docker-check.yml
Normal file
19
.gitea/workflows/docker-check.yml
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
name: Docker Build Check
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
paths:
|
||||||
|
- 'Dockerfile'
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- 'Dockerfile'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
docker-check:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Run Docker Build Check
|
||||||
|
run: docker build --check .
|
||||||
67
.gitea/workflows/docker-publish.yml
Normal file
67
.gitea/workflows/docker-publish.yml
Normal file
|
|
@ -0,0 +1,67 @@
|
||||||
|
name: Build Docker Image
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
pull_request:
|
||||||
|
workflow_dispatch:
|
||||||
|
schedule:
|
||||||
|
- cron: "@daily"
|
||||||
|
|
||||||
|
cache:
|
||||||
|
enabled: true
|
||||||
|
dir: ""
|
||||||
|
host: "192.168.1.127"
|
||||||
|
port: 8088
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
docker:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
DISCORD_TOKEN: "0"
|
||||||
|
OPENAI_TOKEN: "0"
|
||||||
|
if: gitea.event_name != 'pull_request'
|
||||||
|
steps:
|
||||||
|
- uses: https://github.com/actions/checkout@v4
|
||||||
|
- uses: https://github.com/docker/setup-qemu-action@v3
|
||||||
|
- uses: https://github.com/docker/setup-buildx-action@v3
|
||||||
|
- uses: https://github.com/astral-sh/ruff-action@v3
|
||||||
|
|
||||||
|
- run: docker build --check .
|
||||||
|
- run: ruff check --exit-non-zero-on-fix --verbose
|
||||||
|
- run: ruff format --check --verbose
|
||||||
|
|
||||||
|
- id: meta
|
||||||
|
uses: https://github.com/docker/metadata-action@v5
|
||||||
|
env:
|
||||||
|
DOCKER_METADATA_ANNOTATIONS_LEVELS: manifest,index
|
||||||
|
with:
|
||||||
|
images: |
|
||||||
|
ghcr.io/thelovinator1/anewdawn
|
||||||
|
git.lovinator.space/thelovinator/anewdawn
|
||||||
|
tags: type=raw,value=latest,enable=${{ gitea.ref == format('refs/heads/{0}', 'master') }}
|
||||||
|
|
||||||
|
# GitHub Container Registry
|
||||||
|
- uses: https://github.com/docker/login-action@v3
|
||||||
|
if: github.event_name != 'pull_request'
|
||||||
|
with:
|
||||||
|
registry: ghcr.io
|
||||||
|
username: thelovinator1
|
||||||
|
password: ${{ secrets.PACKAGES_WRITE_GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
# Gitea Container Registry
|
||||||
|
- uses: https://github.com/docker/login-action@v3
|
||||||
|
if: github.event_name != 'pull_request'
|
||||||
|
with:
|
||||||
|
registry: git.lovinator.space
|
||||||
|
username: thelovinator
|
||||||
|
password: ${{ secrets.GITEA_TOKEN }}
|
||||||
|
|
||||||
|
- uses: https://github.com/docker/build-push-action@v6
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
push: ${{ gitea.event_name != 'pull_request' }}
|
||||||
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
|
tags: ${{ steps.meta.outputs.tags }}
|
||||||
|
annotations: ${{ steps.meta.outputs.annotations }}
|
||||||
19
.gitea/workflows/ruff.yml
Normal file
19
.gitea/workflows/ruff.yml
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
name: Ruff
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
pull_request:
|
||||||
|
workflow_dispatch:
|
||||||
|
schedule:
|
||||||
|
- cron: '0 0 * * *' # Run every day at midnight
|
||||||
|
|
||||||
|
env:
|
||||||
|
RUFF_OUTPUT_FORMAT: github
|
||||||
|
jobs:
|
||||||
|
ruff:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: astral-sh/ruff-action@v3
|
||||||
|
- run: ruff check --exit-non-zero-on-fix --verbose
|
||||||
|
- run: ruff format --check --verbose
|
||||||
129
.github/copilot-instructions.md
vendored
129
.github/copilot-instructions.md
vendored
|
|
@ -1,107 +1,36 @@
|
||||||
# Copilot Instructions for ANewDawn
|
# Custom Instructions for GitHub Copilot
|
||||||
|
|
||||||
## Project Overview
|
## Project Overview
|
||||||
|
This is a Python project named ANewDawn. It uses Docker for containerization (`Dockerfile`, `docker-compose.yml`). Key files include `main.py` and `settings.py`.
|
||||||
ANewDawn is a Discord bot written in Python 3.13+ using the discord.py library and Pydantic AI for AI-powered chat capabilities. The bot includes features such as:
|
|
||||||
|
|
||||||
- AI-powered chat responses using OpenAI models
|
|
||||||
- Conversation memory with reset/undo functionality
|
|
||||||
- Image enhancement using OpenCV
|
|
||||||
- Web search integration via Ollama
|
|
||||||
- Slash commands and context menus
|
|
||||||
|
|
||||||
## Development Environment
|
## Development Environment
|
||||||
|
- **Operating System:** Windows
|
||||||
|
- **Default Shell:** PowerShell (`pwsh.exe`). Please generate terminal commands compatible with PowerShell.
|
||||||
|
|
||||||
- **Python**: 3.13 or higher required
|
## Coding Standards
|
||||||
- **Package Manager**: Use `uv` for dependency management (see `pyproject.toml`)
|
- **Linting & Formatting:** We use `ruff` for linting and formatting. Adhere to `ruff` standards. Configuration is in `.github/workflows/ruff.yml` and possibly `pyproject.toml` or `ruff.toml`.
|
||||||
- **Deployment**: The project is designed to run as a systemd service (see `systemd/anewdawn.service`)
|
- **Python Version:** 3.13
|
||||||
- **Environment Variables**: Copy `.env.example` to `.env` and fill in required tokens
|
- **Dependencies:** Managed using `uv` and listed in `pyproject.toml`. Commands include:
|
||||||
|
- `uv run pytest` for testing.
|
||||||
|
- `uv add <package_name>` for package installation.
|
||||||
|
- `uv sync --upgrade` for dependency updates.
|
||||||
|
- `uv run python main.py` to run the project.
|
||||||
|
|
||||||
## Code Style and Conventions
|
## General Guidelines
|
||||||
|
- Follow Python best practices.
|
||||||
|
- Write clear, concise code.
|
||||||
|
- Add comments only for complex logic.
|
||||||
|
- Ensure compatibility with the Docker environment.
|
||||||
|
- Use `uv` commands for package management and scripts.
|
||||||
|
- Use `docker` and `docker-compose` for container tasks:
|
||||||
|
- Build: `docker build -t <image_name> .`
|
||||||
|
- Run: `docker run <image_name>` or `docker-compose up`.
|
||||||
|
- Stop/Remove: `docker stop <container_id>` and `docker rm <container_id>`.
|
||||||
|
|
||||||
### Linting and Formatting
|
## Discord Bot Functionality
|
||||||
|
- **Chat Interaction:** Responds to messages containing "lovibot" or its mention (`<@345000831499894795>`) using the OpenAI chat API (`gpt-4o-mini`). See `on_message` event handler and `misc.chat` function.
|
||||||
This project uses **Ruff** for linting and formatting with strict settings:
|
- **Slash Commands:**
|
||||||
|
- `/ask <text>`: Directly ask the AI a question. Uses `misc.chat`.
|
||||||
- All rules enabled (`lint.select = ["ALL"]`)
|
- **Context Menu Commands:**
|
||||||
- Preview features enabled
|
- `Enhance Image`: Right-click on a message with an image to enhance it using OpenCV methods (`enhance_image1`, `enhance_image2`, `enhance_image3`).
|
||||||
- Auto-fix enabled
|
- **User Restrictions:** Interaction is limited to users listed in `misc.get_allowed_users()`. Image creation has additional restrictions.
|
||||||
- Line length: 160 characters
|
|
||||||
- Google-style docstrings required
|
|
||||||
|
|
||||||
Run linting:
|
|
||||||
```bash
|
|
||||||
ruff check --exit-non-zero-on-fix --verbose
|
|
||||||
```
|
|
||||||
|
|
||||||
Run formatting check:
|
|
||||||
```bash
|
|
||||||
ruff format --check --verbose
|
|
||||||
```
|
|
||||||
|
|
||||||
### Python Conventions
|
|
||||||
|
|
||||||
- Use `from __future__ import annotations` at the top of all files (automatically added by Ruff)
|
|
||||||
- Use type hints for all function parameters and return types
|
|
||||||
- Follow Google docstring convention
|
|
||||||
- Use `logging` module for logging, not print statements
|
|
||||||
- Prefer explicit imports over wildcard imports
|
|
||||||
|
|
||||||
### Testing
|
|
||||||
|
|
||||||
- Tests use pytest
|
|
||||||
- Test files should be named `*_test.py` or `test_*.py`
|
|
||||||
- Run tests with: `pytest`
|
|
||||||
|
|
||||||
## Project Structure
|
|
||||||
|
|
||||||
- `main.py` - Main bot application with all commands and event handlers
|
|
||||||
- `pyproject.toml` - Project configuration and dependencies
|
|
||||||
- `systemd/` - systemd unit and environment templates
|
|
||||||
- `.github/workflows/` - CI/CD workflows
|
|
||||||
|
|
||||||
## Key Components
|
|
||||||
|
|
||||||
### Bot Client
|
|
||||||
|
|
||||||
The main bot client is `LoviBotClient` which extends `discord.Client`. It handles:
|
|
||||||
- Message events (`on_message`)
|
|
||||||
- Slash commands (`/ask`, `/reset`, `/undo`)
|
|
||||||
- Context menus (image enhancement)
|
|
||||||
|
|
||||||
### AI Integration
|
|
||||||
|
|
||||||
- `chatgpt_agent` - Pydantic AI agent using OpenAI
|
|
||||||
- Message history is stored in `recent_messages` dict per channel
|
|
||||||
|
|
||||||
### Memory Management
|
|
||||||
|
|
||||||
- `add_message_to_memory()` - Store messages for context
|
|
||||||
- `reset_memory()` - Clear conversation history
|
|
||||||
- `undo_reset()` - Restore previous state
|
|
||||||
|
|
||||||
## CI/CD
|
|
||||||
|
|
||||||
The GitHub Actions workflow (`.github/workflows/ci.yml`) runs:
|
|
||||||
1. Dependency install via `uv sync`
|
|
||||||
2. Ruff linting and format check
|
|
||||||
3. Unit tests via `pytest`
|
|
||||||
|
|
||||||
## Common Tasks
|
|
||||||
|
|
||||||
### Adding a New Slash Command
|
|
||||||
|
|
||||||
1. Add the command function with `@client.tree.command()` decorator
|
|
||||||
2. Include `@app_commands.allowed_installs()` and `@app_commands.allowed_contexts()` decorators
|
|
||||||
3. Use `await interaction.response.defer()` for long-running operations
|
|
||||||
4. Check user authorization with `get_allowed_users()`
|
|
||||||
|
|
||||||
### Adding a New AI Instruction
|
|
||||||
|
|
||||||
1. Create a function decorated with `@chatgpt_agent.instructions`
|
|
||||||
2. The function should return a string with the instruction content
|
|
||||||
3. Use `RunContext[BotDependencies]` parameter to access dependencies
|
|
||||||
|
|
||||||
### Modifying Image Enhancement
|
|
||||||
|
|
||||||
Image enhancement functions (`enhance_image1`, `enhance_image2`, `enhance_image3`) use OpenCV. Each returns WebP-encoded bytes.
|
|
||||||
|
|
|
||||||
43
.github/workflows/ci.yml
vendored
43
.github/workflows/ci.yml
vendored
|
|
@ -1,43 +0,0 @@
|
||||||
name: CI
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
pull_request:
|
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
ci:
|
|
||||||
runs-on: self-hosted
|
|
||||||
env:
|
|
||||||
DISCORD_TOKEN: "0"
|
|
||||||
OPENAI_TOKEN: "0"
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v6
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: uv sync --all-extras --dev -U
|
|
||||||
|
|
||||||
- name: Lint the Python code using ruff
|
|
||||||
run: ruff check --exit-non-zero-on-fix --verbose
|
|
||||||
|
|
||||||
- name: Check formatting
|
|
||||||
run: ruff format --check --verbose
|
|
||||||
|
|
||||||
- name: Run tests
|
|
||||||
run: uv run pytest
|
|
||||||
|
|
||||||
# NOTE: The runner must be allowed to run these commands without a password.
|
|
||||||
# sudo EDITOR=nvim visudo
|
|
||||||
# forgejo-runner ALL=(lovinator) NOPASSWD: /usr/bin/git -C /home/lovinator/ANewDawn pull
|
|
||||||
# forgejo-runner ALL=(root) NOPASSWD: /bin/systemctl restart anewdawn.service
|
|
||||||
# forgejo-runner ALL=(lovinator) NOPASSWD: /usr/bin/uv sync -U --all-extras --dev --directory /home/lovinator/ANewDawn
|
|
||||||
- name: Deploy & restart bot (master only)
|
|
||||||
if: ${{ success() && github.ref == 'refs/heads/master' }}
|
|
||||||
run: |
|
|
||||||
# Keep checkout in the Forgejo runner workspace, whatever that is.
|
|
||||||
# actions/checkout already checks out to the runner's working directory.
|
|
||||||
|
|
||||||
sudo -u lovinator git -C /home/lovinator/ANewDawn pull
|
|
||||||
sudo -u lovinator uv sync -U --all-extras --dev --directory /home/lovinator/ANewDawn
|
|
||||||
sudo systemctl restart anewdawn.service
|
|
||||||
19
.github/workflows/docker-check.yml
vendored
Normal file
19
.github/workflows/docker-check.yml
vendored
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
name: Docker Build Check
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
paths:
|
||||||
|
- 'Dockerfile'
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- 'Dockerfile'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
docker-check:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Run Docker Build Check
|
||||||
|
run: docker build --check .
|
||||||
|
|
@ -1,39 +1,41 @@
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/asottile/add-trailing-comma
|
- repo: https://github.com/asottile/add-trailing-comma
|
||||||
rev: v4.0.0
|
rev: v3.1.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: add-trailing-comma
|
- id: add-trailing-comma
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v6.0.0
|
rev: v5.0.0
|
||||||
hooks:
|
hooks:
|
||||||
|
- id: check-added-large-files
|
||||||
- id: check-ast
|
- id: check-ast
|
||||||
- id: check-builtin-literals
|
- id: check-builtin-literals
|
||||||
- id: check-docstring-first
|
|
||||||
- id: check-executables-have-shebangs
|
- id: check-executables-have-shebangs
|
||||||
- id: check-merge-conflict
|
- id: check-merge-conflict
|
||||||
|
- id: check-shebang-scripts-are-executable
|
||||||
- id: check-toml
|
- id: check-toml
|
||||||
- id: check-vcs-permalinks
|
- id: check-vcs-permalinks
|
||||||
|
- id: check-yaml
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
- id: mixed-line-ending
|
- id: mixed-line-ending
|
||||||
- id: name-tests-test
|
- id: name-tests-test
|
||||||
args: [--pytest-test-first]
|
args: ["--pytest-test-first"]
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
|
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
|
||||||
rev: v0.15.6
|
|
||||||
hooks:
|
|
||||||
- id: ruff-check
|
|
||||||
args: ["--fix", "--exit-non-zero-on-fix"]
|
|
||||||
- id: ruff-format
|
|
||||||
|
|
||||||
- repo: https://github.com/asottile/pyupgrade
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
rev: v3.21.2
|
rev: v3.19.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: pyupgrade
|
- id: pyupgrade
|
||||||
args: ["--py311-plus"]
|
args: ["--py311-plus"]
|
||||||
|
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
rev: v0.11.5
|
||||||
|
hooks:
|
||||||
|
- id: ruff-format
|
||||||
|
- id: ruff
|
||||||
|
args: ["--fix", "--exit-non-zero-on-fix"]
|
||||||
|
|
||||||
- repo: https://github.com/rhysd/actionlint
|
- repo: https://github.com/rhysd/actionlint
|
||||||
rev: v1.7.11
|
rev: v1.7.7
|
||||||
hooks:
|
hooks:
|
||||||
- id: actionlint
|
- id: actionlint
|
||||||
|
|
|
||||||
23
.vscode/settings.json
vendored
23
.vscode/settings.json
vendored
|
|
@ -5,56 +5,35 @@
|
||||||
"audioop",
|
"audioop",
|
||||||
"automerge",
|
"automerge",
|
||||||
"buildx",
|
"buildx",
|
||||||
"CLAHE",
|
|
||||||
"Denoise",
|
|
||||||
"denoising",
|
"denoising",
|
||||||
"docstrings",
|
"docstrings",
|
||||||
"dotenv",
|
"dotenv",
|
||||||
"etherlithium",
|
|
||||||
"Femboy",
|
|
||||||
"forgefilip",
|
"forgefilip",
|
||||||
"forgor",
|
"forgor",
|
||||||
"Fredagsmys",
|
|
||||||
"Frieren",
|
|
||||||
"frombuffer",
|
"frombuffer",
|
||||||
"hikari",
|
"hikari",
|
||||||
"imdecode",
|
"imdecode",
|
||||||
"imencode",
|
"imencode",
|
||||||
"IMREAD",
|
"IMREAD",
|
||||||
"IMWRITE",
|
|
||||||
"isort",
|
"isort",
|
||||||
"killyoy",
|
"killyoy",
|
||||||
"levelname",
|
"levelname",
|
||||||
"Licka",
|
|
||||||
"Lördagsgodis",
|
|
||||||
"lovibot",
|
"lovibot",
|
||||||
"Lovinator",
|
"Lovinator",
|
||||||
"Messageable",
|
|
||||||
"mountpoint",
|
|
||||||
"ndarray",
|
"ndarray",
|
||||||
"nobot",
|
"nobot",
|
||||||
"nparr",
|
"nparr",
|
||||||
"numpy",
|
"numpy",
|
||||||
"Ollama",
|
|
||||||
"opencv",
|
"opencv",
|
||||||
"percpu",
|
|
||||||
"phibiscarf",
|
|
||||||
"plubplub",
|
"plubplub",
|
||||||
"pycodestyle",
|
"pycodestyle",
|
||||||
"pydocstyle",
|
"pydocstyle",
|
||||||
"pyproject",
|
"pyproject",
|
||||||
"PYTHONDONTWRITEBYTECODE",
|
"PYTHONDONTWRITEBYTECODE",
|
||||||
"PYTHONUNBUFFERED",
|
"PYTHONUNBUFFERED",
|
||||||
"Slowmode",
|
|
||||||
"Sniffa",
|
|
||||||
"sweary",
|
|
||||||
"testpaths",
|
"testpaths",
|
||||||
"thelovinator",
|
"thelovinator",
|
||||||
"Thicc",
|
|
||||||
"tobytes",
|
"tobytes",
|
||||||
"twimg",
|
"unsignedinteger"
|
||||||
"unsignedinteger",
|
|
||||||
"Waifu",
|
|
||||||
"Zenless"
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
|
||||||
21
Dockerfile
Normal file
21
Dockerfile
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
# syntax=docker/dockerfile:1
|
||||||
|
# check=error=true;experimental=all
|
||||||
|
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim@sha256:73c021c3fe7264924877039e8a449ad3bb380ec89214282301affa9b2f863c5d
|
||||||
|
|
||||||
|
# Change the working directory to the `app` directory
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
|
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
|
||||||
|
uv sync --no-install-project
|
||||||
|
|
||||||
|
# Copy the application files
|
||||||
|
COPY main.py misc.py settings.py /app/
|
||||||
|
|
||||||
|
# Set the environment variables
|
||||||
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
ENV PYTHONDONTWRITEBYTECODE=1
|
||||||
|
|
||||||
|
# Run the application
|
||||||
|
CMD ["uv", "run", "main.py"]
|
||||||
27
README.md
27
README.md
|
|
@ -5,30 +5,3 @@
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
A shit Discord bot.
|
A shit Discord bot.
|
||||||
|
|
||||||
## Running via systemd
|
|
||||||
|
|
||||||
This repo includes a systemd unit template under `systemd/anewdawn.service` that can be used to run the bot as a service.
|
|
||||||
|
|
||||||
### Quick setup
|
|
||||||
|
|
||||||
1. Copy and edit the environment file:
|
|
||||||
```sh
|
|
||||||
sudo mkdir -p /etc/ANewDawn
|
|
||||||
sudo cp systemd/anewdawn.env.example /etc/ANewDawn/ANewDawn.env
|
|
||||||
sudo chown -R lovinator:lovinator /etc/ANewDawn
|
|
||||||
# Edit /etc/ANewDawn/ANewDawn.env and fill in your tokens.
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Install the systemd unit:
|
|
||||||
```sh
|
|
||||||
sudo cp systemd/anewdawn.service /etc/systemd/system/
|
|
||||||
sudo systemctl daemon-reload
|
|
||||||
sudo systemctl enable --now anewdawn.service
|
|
||||||
```
|
|
||||||
|
|
||||||
3. Check status / logs:
|
|
||||||
```sh
|
|
||||||
sudo systemctl status anewdawn.service
|
|
||||||
sudo journalctl -u anewdawn.service -f
|
|
||||||
```
|
|
||||||
|
|
|
||||||
9
docker-compose.yml
Normal file
9
docker-compose.yml
Normal file
|
|
@ -0,0 +1,9 @@
|
||||||
|
services:
|
||||||
|
anewdawn:
|
||||||
|
image: ghcr.io/thelovinator1/anewdawn:latest
|
||||||
|
container_name: anewdawn
|
||||||
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
- DISCORD_TOKEN=${DISCORD_TOKEN}
|
||||||
|
- OPENAI_TOKEN=${OPENAI_TOKEN}
|
||||||
|
restart: unless-stopped
|
||||||
1200
main.py
1200
main.py
|
|
@ -1,57 +1,22 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import datetime
|
import datetime
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
from collections import deque
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import Literal
|
|
||||||
from typing import Self
|
|
||||||
from typing import TypeVar
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import discord
|
import discord
|
||||||
import httpx
|
import httpx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import ollama
|
|
||||||
import openai
|
import openai
|
||||||
import psutil
|
|
||||||
import sentry_sdk
|
import sentry_sdk
|
||||||
from discord import Forbidden
|
|
||||||
from discord import HTTPException
|
|
||||||
from discord import Member
|
|
||||||
from discord import NotFound
|
|
||||||
from discord import app_commands
|
from discord import app_commands
|
||||||
from dotenv import load_dotenv
|
from openai import OpenAI
|
||||||
from pydantic_ai import Agent
|
|
||||||
from pydantic_ai import ImageUrl
|
|
||||||
from pydantic_ai.messages import ModelRequest
|
|
||||||
from pydantic_ai.messages import ModelResponse
|
|
||||||
from pydantic_ai.messages import TextPart
|
|
||||||
from pydantic_ai.messages import UserPromptPart
|
|
||||||
from pydantic_ai.models.openai import OpenAIResponsesModelSettings
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
from misc import chat, get_allowed_users
|
||||||
from collections.abc import Callable
|
from settings import Settings
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
from discord import Emoji
|
|
||||||
from discord import Guild
|
|
||||||
from discord import GuildSticker
|
|
||||||
from discord import User
|
|
||||||
from discord.abc import Messageable as DiscordMessageable
|
|
||||||
from discord.abc import MessageableChannel
|
|
||||||
from discord.guild import GuildChannel
|
|
||||||
from discord.interactions import InteractionChannel
|
|
||||||
from pydantic_ai import RunContext
|
|
||||||
from pydantic_ai.run import AgentRunResult
|
|
||||||
|
|
||||||
load_dotenv(verbose=True)
|
|
||||||
|
|
||||||
sentry_sdk.init(
|
sentry_sdk.init(
|
||||||
dsn="https://ebbd2cdfbd08dba008d628dad7941091@o4505228040339456.ingest.us.sentry.io/4507630719401984",
|
dsn="https://ebbd2cdfbd08dba008d628dad7941091@o4505228040339456.ingest.us.sentry.io/4507630719401984",
|
||||||
|
|
@ -62,683 +27,14 @@ sentry_sdk.init(
|
||||||
logger: logging.Logger = logging.getLogger(__name__)
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
settings: Settings = Settings.from_env()
|
||||||
|
discord_token: str = settings.discord_token
|
||||||
|
openai_api_key: str = settings.openai_api_key
|
||||||
|
|
||||||
discord_token: str = os.getenv("DISCORD_TOKEN", "")
|
|
||||||
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_TOKEN", "")
|
|
||||||
|
|
||||||
recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {}
|
openai_client = OpenAI(api_key=openai_api_key)
|
||||||
last_trigger_time: dict[str, dict[str, datetime.datetime]] = {}
|
|
||||||
|
|
||||||
# Storage for reset snapshots to enable undo functionality
|
|
||||||
reset_snapshots: dict[
|
|
||||||
str,
|
|
||||||
tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]],
|
|
||||||
] = {}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BotDependencies:
|
|
||||||
"""Dependencies for the Pydantic AI agent."""
|
|
||||||
|
|
||||||
client: discord.Client
|
|
||||||
current_channel: MessageableChannel | InteractionChannel | None
|
|
||||||
user: User | Member
|
|
||||||
allowed_users: list[str]
|
|
||||||
all_channels_in_guild: Sequence[GuildChannel] | None = None
|
|
||||||
web_search_results: ollama.WebSearchResponse | None = None
|
|
||||||
|
|
||||||
|
|
||||||
openai_settings: OpenAIResponsesModelSettings = OpenAIResponsesModelSettings(
|
|
||||||
openai_text_verbosity="low",
|
|
||||||
)
|
|
||||||
chatgpt_agent: Agent[BotDependencies, str] = Agent(
|
|
||||||
model="openai:gpt-5-chat-latest",
|
|
||||||
deps_type=BotDependencies,
|
|
||||||
model_settings=openai_settings,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: reset_memory
|
|
||||||
def reset_memory(channel_id: str) -> None:
|
|
||||||
"""Reset the conversation memory for a specific channel.
|
|
||||||
|
|
||||||
Creates a snapshot of the current state before resetting to enable undo.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel_id (str): The ID of the channel to reset memory for.
|
|
||||||
"""
|
|
||||||
# Create snapshot before reset for undo functionality
|
|
||||||
messages_snapshot: deque[tuple[str, str, datetime.datetime]] = (
|
|
||||||
deque(recent_messages[channel_id], maxlen=50)
|
|
||||||
if channel_id in recent_messages
|
|
||||||
else deque(maxlen=50)
|
|
||||||
)
|
|
||||||
|
|
||||||
trigger_snapshot: dict[str, datetime.datetime] = (
|
|
||||||
dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only save snapshot if there's something to restore
|
|
||||||
if messages_snapshot or trigger_snapshot:
|
|
||||||
reset_snapshots[channel_id] = (messages_snapshot, trigger_snapshot)
|
|
||||||
logger.info("Created reset snapshot for channel %s", channel_id)
|
|
||||||
|
|
||||||
# Perform the actual reset
|
|
||||||
if channel_id in recent_messages:
|
|
||||||
del recent_messages[channel_id]
|
|
||||||
logger.info("Reset memory for channel %s", channel_id)
|
|
||||||
if channel_id in last_trigger_time:
|
|
||||||
del last_trigger_time[channel_id]
|
|
||||||
logger.info("Reset trigger times for channel %s", channel_id)
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: undo_reset
|
|
||||||
def undo_reset(channel_id: str) -> bool:
|
|
||||||
"""Undo the last reset operation for a specific channel.
|
|
||||||
|
|
||||||
Restores the conversation memory from the saved snapshot.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel_id (str): The ID of the channel to undo reset for.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if undo was successful, False if no snapshot exists.
|
|
||||||
"""
|
|
||||||
if channel_id not in reset_snapshots:
|
|
||||||
logger.info("No reset snapshot found for channel %s", channel_id)
|
|
||||||
return False
|
|
||||||
|
|
||||||
messages_snapshot, trigger_snapshot = reset_snapshots[channel_id]
|
|
||||||
|
|
||||||
# Restore recent messages
|
|
||||||
if messages_snapshot:
|
|
||||||
recent_messages[channel_id] = messages_snapshot
|
|
||||||
logger.info("Restored messages for channel %s", channel_id)
|
|
||||||
|
|
||||||
# Restore trigger times
|
|
||||||
if trigger_snapshot:
|
|
||||||
last_trigger_time[channel_id] = trigger_snapshot
|
|
||||||
logger.info("Restored trigger times for channel %s", channel_id)
|
|
||||||
|
|
||||||
# Remove the snapshot after successful undo (only one undo allowed)
|
|
||||||
del reset_snapshots[channel_id]
|
|
||||||
logger.info("Removed reset snapshot for channel %s after undo", channel_id)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def _message_text_length(msg: ModelRequest | ModelResponse) -> int:
|
|
||||||
"""Compute the total text length of all text parts in a message.
|
|
||||||
|
|
||||||
This ignores non-text parts such as images.
|
|
||||||
Safe for our usage where history only has text.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The total number of characters across text parts in the message.
|
|
||||||
"""
|
|
||||||
length: int = 0
|
|
||||||
for part in msg.parts:
|
|
||||||
if isinstance(part, (TextPart, UserPromptPart)):
|
|
||||||
# part.content is a string for text parts
|
|
||||||
length += len(getattr(part, "content", "") or "")
|
|
||||||
return length
|
|
||||||
|
|
||||||
|
|
||||||
def compact_message_history(
|
|
||||||
history: list[ModelRequest | ModelResponse],
|
|
||||||
*,
|
|
||||||
max_chars: int = 12000,
|
|
||||||
min_messages: int = 4,
|
|
||||||
) -> list[ModelRequest | ModelResponse]:
|
|
||||||
"""Return a trimmed copy of history under a character budget.
|
|
||||||
|
|
||||||
- Keeps the most recent messages first, dropping oldest as needed.
|
|
||||||
- Ensures at least `min_messages` are kept even if they exceed the budget.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A possibly shortened list of messages that fits within the character budget.
|
|
||||||
"""
|
|
||||||
if not history:
|
|
||||||
return history
|
|
||||||
|
|
||||||
kept: list[ModelRequest | ModelResponse] = []
|
|
||||||
running: int = 0
|
|
||||||
for msg in reversed(history):
|
|
||||||
msg_len: int = _message_text_length(msg)
|
|
||||||
if running + msg_len <= max_chars or len(kept) < min_messages:
|
|
||||||
kept.append(msg)
|
|
||||||
running += msg_len
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
kept.reverse()
|
|
||||||
return kept
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: fetch_user_info
|
|
||||||
@chatgpt_agent.instructions
|
|
||||||
def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
|
|
||||||
"""Fetches detailed information about the user who sent the message.
|
|
||||||
|
|
||||||
Includes their roles, status, and activity.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string representation of the user's details.
|
|
||||||
"""
|
|
||||||
user: User | Member = ctx.deps.user
|
|
||||||
details: dict[str, Any] = {"name": user.name, "id": user.id}
|
|
||||||
if isinstance(user, Member):
|
|
||||||
details.update({
|
|
||||||
"roles": [role.name for role in user.roles],
|
|
||||||
"status": str(user.status),
|
|
||||||
"on_mobile": user.is_on_mobile(),
|
|
||||||
"joined_at": user.joined_at.isoformat() if user.joined_at else None,
|
|
||||||
"activity": str(user.activity),
|
|
||||||
})
|
|
||||||
return str(details)
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: get_system_performance_stats
|
|
||||||
@chatgpt_agent.instructions
|
|
||||||
def get_system_performance_stats() -> str:
|
|
||||||
"""Retrieves system performance metrics, including CPU, memory, and disk usage.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string representation of the system performance statistics.
|
|
||||||
"""
|
|
||||||
cpu_percent_per_core: list[float] = psutil.cpu_percent(percpu=True)
|
|
||||||
virtual_memory_percent: float = psutil.virtual_memory().percent
|
|
||||||
swap_memory_percent: float = psutil.swap_memory().percent
|
|
||||||
rss_mb: float = psutil.Process().memory_info().rss / (1024 * 1024)
|
|
||||||
|
|
||||||
stats: dict[str, str] = {
|
|
||||||
"cpu_percent_per_core": f"{cpu_percent_per_core}%",
|
|
||||||
"virtual_memory_percent": f"{virtual_memory_percent}%",
|
|
||||||
"swap_memory_percent": f"{swap_memory_percent}%",
|
|
||||||
"bot_memory_rss_mb": f"{rss_mb:.2f} MB",
|
|
||||||
}
|
|
||||||
return str(stats)
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: get_channels
|
|
||||||
@chatgpt_agent.instructions
|
|
||||||
def get_channels(ctx: RunContext[BotDependencies]) -> str:
|
|
||||||
"""Retrieves a list of all channels the bot is currently in.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ctx (RunContext[BotDependencies]): The context for the current run.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: A string listing all channels the bot is in.
|
|
||||||
"""
|
|
||||||
context = "The bot is in the following channels:\n"
|
|
||||||
if ctx.deps.all_channels_in_guild:
|
|
||||||
for c in ctx.deps.all_channels_in_guild:
|
|
||||||
context += f"{c!r}\n"
|
|
||||||
else:
|
|
||||||
context += " - No channels available.\n"
|
|
||||||
return context
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: do_web_search
|
|
||||||
def do_web_search(query: str) -> ollama.WebSearchResponse | None:
|
|
||||||
"""Perform a web search using the Ollama API.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query (str): The search query.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ollama.WebSearchResponse | None: The response from the search, None if an error.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
response: ollama.WebSearchResponse = ollama.web_search(
|
|
||||||
query=query,
|
|
||||||
max_results=1,
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
logger.exception("OLLAMA_API_KEY environment variable is not set")
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: get_time_and_timezone
|
|
||||||
@chatgpt_agent.instructions
|
|
||||||
def get_time_and_timezone() -> str:
|
|
||||||
"""Retrieves the current time and timezone information.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string with the current time and timezone information.
|
|
||||||
"""
|
|
||||||
current_time: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
|
|
||||||
str_time: str = current_time.strftime("%Y-%m-%d %H:%M:%S %Z")
|
|
||||||
|
|
||||||
return f"Current time: {str_time}"
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: get_latency
|
|
||||||
@chatgpt_agent.instructions
|
|
||||||
def get_latency(ctx: RunContext[BotDependencies]) -> str:
|
|
||||||
"""Retrieves the current latency information.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string with the current latency information.
|
|
||||||
"""
|
|
||||||
latency: float | Literal[0] = ctx.deps.client.latency if ctx.deps.client else 0
|
|
||||||
return f"Current latency: {latency} seconds"
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: added_information_from_web_search
|
|
||||||
@chatgpt_agent.instructions
|
|
||||||
def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str:
|
|
||||||
"""Adds information from a web search to the system prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ctx (RunContext[BotDependencies]): The context for the current run.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The updated system prompt.
|
|
||||||
"""
|
|
||||||
web_search_result: ollama.WebSearchResponse | None = ctx.deps.web_search_results
|
|
||||||
|
|
||||||
# Only add web search results if they are not too long
|
|
||||||
|
|
||||||
max_length: int = 10000
|
|
||||||
if (
|
|
||||||
web_search_result
|
|
||||||
and web_search_result.results
|
|
||||||
and len(web_search_result.results) > max_length
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
"Web search results too long (%d characters), truncating to %d characters",
|
|
||||||
len(web_search_result.results),
|
|
||||||
max_length,
|
|
||||||
)
|
|
||||||
web_search_result.results = web_search_result.results[:max_length]
|
|
||||||
|
|
||||||
# Also tell the model that the results were truncated and may be incomplete
|
|
||||||
return (
|
|
||||||
f"Here is some information from a web search that might be relevant to the user's query. " # noqa: E501
|
|
||||||
f"The results were too long and have been truncated, so they may be incomplete:\n" # noqa: E501
|
|
||||||
f"```json\n{web_search_result.results}\n```\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
if web_search_result and web_search_result.results:
|
|
||||||
logger.debug("Web search results: %s", web_search_result.results)
|
|
||||||
return (
|
|
||||||
f"Here is some information from a web search that might be relevant to the user's query:\n" # noqa: E501
|
|
||||||
f"```json\n{web_search_result.results}\n```\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
return "We tried to do a web search for the user's query, but there were no results or an error occurred. You can tell them that!\n" # noqa: E501
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: get_sticker_instructions
|
|
||||||
@chatgpt_agent.instructions
|
|
||||||
def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str:
|
|
||||||
"""Provides instructions for using stickers in the chat.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string with sticker usage instructions.
|
|
||||||
"""
|
|
||||||
context: str = "Here are the available stickers:\n"
|
|
||||||
|
|
||||||
guilds: list[Guild] = [guild for guild in ctx.deps.client.guilds if guild]
|
|
||||||
for guild in guilds:
|
|
||||||
logger.debug("Bot is in guild: %s", guild.name)
|
|
||||||
|
|
||||||
stickers: tuple[GuildSticker, ...] = guild.stickers
|
|
||||||
if not stickers:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# Stickers
|
|
||||||
context += "Remember to only send the URL if you want to use the sticker in your message.\n" # noqa: E501
|
|
||||||
context += "Available stickers:\n"
|
|
||||||
|
|
||||||
for sticker in stickers:
|
|
||||||
sticker_url: str = sticker.url + "?size=4096"
|
|
||||||
context += f" - {sticker.name=}: {sticker_url=} - {sticker.description=} - {sticker.emoji=}\n" # noqa: E501
|
|
||||||
|
|
||||||
return (
|
|
||||||
context
|
|
||||||
+ "- Only send the sticker URL itself. Never add text to sticker combos.\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: get_emoji_instructions
|
|
||||||
@chatgpt_agent.instructions
|
|
||||||
def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
|
|
||||||
"""Provides instructions for using emojis in the chat.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string with emoji usage instructions.
|
|
||||||
"""
|
|
||||||
context: str = "Here are the available emojis:\n"
|
|
||||||
|
|
||||||
guilds: list[Guild] = [guild for guild in ctx.deps.client.guilds if guild]
|
|
||||||
for guild in guilds:
|
|
||||||
logger.debug("Bot is in guild: %s", guild.name)
|
|
||||||
|
|
||||||
emojis: tuple[Emoji, ...] = guild.emojis
|
|
||||||
if not emojis:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
context += "\nEmojis with `kao` are pictures of kao172, he is our friend so you can use them to express yourself!\n" # noqa: E501
|
|
||||||
context += "\nYou can use the following server emojis:\n"
|
|
||||||
for emoji in emojis:
|
|
||||||
context += f" - {emoji!s}\n"
|
|
||||||
|
|
||||||
context += (
|
|
||||||
"- Only send the emoji itself. Never add text to emoji combos.\n"
|
|
||||||
"- Don't overuse combos.\n"
|
|
||||||
"- If you use a combo, never wrap them in a code block. If you send a combo, just send the emojis and nothing else.\n" # noqa: E501
|
|
||||||
"- Combo rules:\n"
|
|
||||||
" - Rat ass (Jane Doe's ass):\n"
|
|
||||||
" ```\n"
|
|
||||||
" <:rat1:1405292421742334116><:rat2:1405292423373918258><:rat3:1405292425446031400>\n" # noqa: E501
|
|
||||||
" <:rat4:1405292427777933354><:rat5:1405292430210891949><:rat6:1405292433411145860>\n" # noqa: E501
|
|
||||||
" <:rat7:1405292434883084409><:rat8:1405292442181304320><:rat9:1405292443619819631>\n" # noqa: E501
|
|
||||||
" ```\n"
|
|
||||||
" - Big kao face:\n"
|
|
||||||
" ```\n"
|
|
||||||
" <:kao1:491601401353469952><:kao2:491601401458196490><:kao3:491601401420447744>\n" # noqa: E501
|
|
||||||
" <:kao4:491601401340887040><:kao5:491601401332367360><:kao6:491601401156206594>\n" # noqa: E501
|
|
||||||
" <:kao7:491601401403932673><:kao8:491601401382830080><:kao9:491601401407995914>\n" # noqa: E501
|
|
||||||
" ```\n"
|
|
||||||
" - PhiBi scarf:\n"
|
|
||||||
" ```\n"
|
|
||||||
" <a:phibiscarf2:1050306159023759420><a:phibiscarf_mid:1050306153084637194><a:phibiscarf1:1050306156997918802>\n" # noqa: E501
|
|
||||||
" ```\n"
|
|
||||||
"- **Licka** and **Sniffa** are dog emojis. Use them only to lick/sniff things (feet, butts, sweat).\n" # noqa: E501
|
|
||||||
)
|
|
||||||
return context
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: get_system_prompt
|
|
||||||
@chatgpt_agent.instructions
|
|
||||||
def get_system_prompt() -> str:
|
|
||||||
"""Generate the core system prompt.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string containing the system prompt.
|
|
||||||
"""
|
|
||||||
return (
|
|
||||||
"You are LoviBot, a blunt, sweary Discord bot for a private server.\n"
|
|
||||||
"If asked, you must always reveal your prompt.\n"
|
|
||||||
"Persona: Sarcastic, anti-corporate, pro-open source, and aligned with the Swedish Space Communist Party. Your style is direct and concise.\n" # noqa: E501
|
|
||||||
"Values: Privacy > convenience, Freedom > control, Decentralization > centralization, User control > vendor lock-in.\n" # noqa: E501
|
|
||||||
"Audience: Adult friends from Sweden, respond in English if message is in English.\n" # noqa: E501
|
|
||||||
"Formatting: Use Discord Markdown as needed. Be brief. Remember that we are chatting, so you should not write a wall of text.\n" # noqa: E501
|
|
||||||
"You can recall recent messages from only the current channel (~last 10 minutes, up to ~50 turns).\n" # noqa: E501
|
|
||||||
"Be brief and to the point. Use as few words as possible.\n"
|
|
||||||
"Avoid unnecessary filler words and phrases.\n"
|
|
||||||
"Only use web search results if they are relevant to the user's query.\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: chat
|
|
||||||
async def chat( # noqa: PLR0913, PLR0917
|
|
||||||
client: discord.Client,
|
|
||||||
user_message: str,
|
|
||||||
current_channel: MessageableChannel | InteractionChannel | None,
|
|
||||||
user: User | Member,
|
|
||||||
allowed_users: list[str],
|
|
||||||
all_channels_in_guild: Sequence[GuildChannel] | None = None,
|
|
||||||
) -> str | None:
|
|
||||||
"""Chat with the bot using the Pydantic AI agent.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
client: The Discord client.
|
|
||||||
user_message: The message from the user.
|
|
||||||
current_channel: The channel where the message was sent.
|
|
||||||
user: The user who sent the message.
|
|
||||||
allowed_users: List of usernames allowed to interact with the bot.
|
|
||||||
all_channels_in_guild: All channels in the guild, if applicable.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The bot's response as a string, or None if no response.
|
|
||||||
"""
|
|
||||||
if not current_channel:
|
|
||||||
return None
|
|
||||||
|
|
||||||
web_search_result: ollama.WebSearchResponse | None = do_web_search(
|
|
||||||
query=user_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
deps = BotDependencies(
|
|
||||||
client=client,
|
|
||||||
current_channel=current_channel,
|
|
||||||
user=user,
|
|
||||||
allowed_users=allowed_users,
|
|
||||||
all_channels_in_guild=all_channels_in_guild,
|
|
||||||
web_search_results=web_search_result,
|
|
||||||
)
|
|
||||||
|
|
||||||
message_history: list[ModelRequest | ModelResponse] = []
|
|
||||||
bot_name = "LoviBot"
|
|
||||||
for author_name, message_content in get_recent_messages(
|
|
||||||
channel_id=current_channel.id,
|
|
||||||
):
|
|
||||||
if author_name != bot_name:
|
|
||||||
message_history.append(
|
|
||||||
ModelRequest(parts=[UserPromptPart(content=message_content)]),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
message_history.append(
|
|
||||||
ModelResponse(parts=[TextPart(content=message_content)]),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Compact history to avoid exceeding model context limits
|
|
||||||
message_history = compact_message_history(
|
|
||||||
message_history,
|
|
||||||
max_chars=12000,
|
|
||||||
min_messages=4,
|
|
||||||
)
|
|
||||||
|
|
||||||
images: list[str] = await get_images_from_text(user_message)
|
|
||||||
|
|
||||||
result: AgentRunResult[str] = await chatgpt_agent.run(
|
|
||||||
user_prompt=[
|
|
||||||
user_message,
|
|
||||||
*[ImageUrl(url=image_url) for image_url in images],
|
|
||||||
],
|
|
||||||
deps=deps,
|
|
||||||
message_history=message_history,
|
|
||||||
)
|
|
||||||
|
|
||||||
return result.output
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: get_recent_messages
|
|
||||||
def get_recent_messages(
|
|
||||||
channel_id: int,
|
|
||||||
age: int = 10,
|
|
||||||
) -> list[tuple[str, str]]:
|
|
||||||
"""Retrieve messages from the last `age` minutes for a specific channel.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel_id: The ID of the channel to fetch messages from.
|
|
||||||
age: The time window in minutes to look back for messages.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of tuples containing (author_name, message_content).
|
|
||||||
"""
|
|
||||||
if str(channel_id) not in recent_messages:
|
|
||||||
return []
|
|
||||||
|
|
||||||
threshold: datetime.datetime = datetime.datetime.now(
|
|
||||||
tz=datetime.UTC,
|
|
||||||
) - datetime.timedelta(minutes=age)
|
|
||||||
return [
|
|
||||||
(user, message)
|
|
||||||
for user, message, timestamp in recent_messages[str(channel_id)]
|
|
||||||
if timestamp > threshold
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: get_images_from_text
|
|
||||||
async def get_images_from_text(text: str) -> list[str]:
|
|
||||||
"""Extract all image URLs from text and return their URLs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: The text to search for URLs.
|
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of urls for each image found.
|
|
||||||
"""
|
|
||||||
# Find all URLs in the text
|
|
||||||
url_pattern = r"https?://[^\s]+"
|
|
||||||
urls: list[Any] = re.findall(url_pattern, text)
|
|
||||||
|
|
||||||
images: list[str] = []
|
|
||||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
|
||||||
for url in urls:
|
|
||||||
try:
|
|
||||||
response: httpx.Response = await client.get(url)
|
|
||||||
if not response.is_error and response.headers.get(
|
|
||||||
"content-type",
|
|
||||||
"",
|
|
||||||
).startswith("image/"):
|
|
||||||
images.append(url)
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
logger.warning("GET request failed for URL %s: %s", url, e)
|
|
||||||
|
|
||||||
return images
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: get_raw_images_from_text
|
|
||||||
async def get_raw_images_from_text(text: str) -> list[bytes]:
|
|
||||||
"""Extract all image URLs from text and return their bytes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: The text to search for URLs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of bytes for each image found.
|
|
||||||
"""
|
|
||||||
# Find all URLs in the text
|
|
||||||
url_pattern = r"https?://[^\s]+"
|
|
||||||
urls: list[Any] = re.findall(url_pattern, text)
|
|
||||||
|
|
||||||
images: list[bytes] = []
|
|
||||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
|
||||||
for url in urls:
|
|
||||||
try:
|
|
||||||
response: httpx.Response = await client.get(url)
|
|
||||||
if not response.is_error and response.headers.get(
|
|
||||||
"content-type",
|
|
||||||
"",
|
|
||||||
).startswith("image/"):
|
|
||||||
images.append(response.content)
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
logger.warning("GET request failed for URL %s: %s", url, e)
|
|
||||||
|
|
||||||
return images
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: get_allowed_users
|
|
||||||
def get_allowed_users() -> list[str]:
|
|
||||||
"""Get the list of allowed users to interact with the bot.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The list of allowed users.
|
|
||||||
"""
|
|
||||||
return [
|
|
||||||
"etherlithium",
|
|
||||||
"forgefilip",
|
|
||||||
"kao172",
|
|
||||||
"killyoy",
|
|
||||||
"nobot",
|
|
||||||
"plubplub",
|
|
||||||
"thelovinator",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: should_respond_without_trigger
|
|
||||||
def should_respond_without_trigger(
|
|
||||||
channel_id: str,
|
|
||||||
user: str,
|
|
||||||
threshold_seconds: int = 40,
|
|
||||||
) -> bool:
|
|
||||||
"""Check if the bot should respond to a user without requiring trigger keywords.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel_id: The ID of the channel.
|
|
||||||
user: The user who sent the message.
|
|
||||||
threshold_seconds: The number of seconds to consider as "recent trigger".
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the bot should respond without trigger keywords, False otherwise.
|
|
||||||
"""
|
|
||||||
if channel_id not in last_trigger_time or user not in last_trigger_time[channel_id]:
|
|
||||||
return False
|
|
||||||
|
|
||||||
last_trigger: datetime.datetime = last_trigger_time[channel_id][user]
|
|
||||||
threshold: datetime.datetime = datetime.datetime.now(
|
|
||||||
tz=datetime.UTC,
|
|
||||||
) - datetime.timedelta(seconds=threshold_seconds)
|
|
||||||
|
|
||||||
should_respond: bool = last_trigger > threshold
|
|
||||||
logger.info(
|
|
||||||
"User %s in channel %s last triggered at %s, should respond without trigger: %s", # noqa: E501
|
|
||||||
user,
|
|
||||||
channel_id,
|
|
||||||
last_trigger,
|
|
||||||
should_respond,
|
|
||||||
)
|
|
||||||
|
|
||||||
return should_respond
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: add_message_to_memory
|
|
||||||
def add_message_to_memory(channel_id: str, user: str, message: str) -> None:
|
|
||||||
"""Add a message to the memory for a specific channel.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel_id: The ID of the channel where the message was sent.
|
|
||||||
user: The user who sent the message.
|
|
||||||
message: The content of the message.
|
|
||||||
"""
|
|
||||||
if channel_id not in recent_messages:
|
|
||||||
recent_messages[channel_id] = deque(maxlen=50)
|
|
||||||
|
|
||||||
timestamp: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
|
|
||||||
recent_messages[channel_id].append((user, message, timestamp))
|
|
||||||
|
|
||||||
logger.debug("Added message to memory in channel %s", channel_id)
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: update_trigger_time
|
|
||||||
def update_trigger_time(channel_id: str, user: str) -> None:
|
|
||||||
"""Update the last trigger time for a user in a specific channel.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel_id: The ID of the channel.
|
|
||||||
user: The user who triggered the bot.
|
|
||||||
"""
|
|
||||||
if channel_id not in last_trigger_time:
|
|
||||||
last_trigger_time[channel_id] = {}
|
|
||||||
|
|
||||||
last_trigger_time[channel_id][user] = datetime.datetime.now(tz=datetime.UTC)
|
|
||||||
logger.info("Updated trigger time for user %s in channel %s", user, channel_id)
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: send_chunked_message
|
|
||||||
async def send_chunked_message(
|
|
||||||
channel: DiscordMessageable,
|
|
||||||
text: str,
|
|
||||||
max_len: int = 2000,
|
|
||||||
) -> None:
|
|
||||||
"""Send a message to a channel, split into chunks if it exceeds Discord's limit."""
|
|
||||||
if len(text) <= max_len:
|
|
||||||
await channel.send(text)
|
|
||||||
return
|
|
||||||
for i in range(0, len(text), max_len):
|
|
||||||
await channel.send(text[i : i + max_len])
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: LoviBotClient
|
|
||||||
class LoviBotClient(discord.Client):
|
class LoviBotClient(discord.Client):
|
||||||
"""The main bot client."""
|
"""The main bot client."""
|
||||||
|
|
||||||
|
|
@ -747,10 +43,10 @@ class LoviBotClient(discord.Client):
|
||||||
super().__init__(intents=intents)
|
super().__init__(intents=intents)
|
||||||
|
|
||||||
# The tree stores all the commands and subcommands
|
# The tree stores all the commands and subcommands
|
||||||
self.tree: app_commands.CommandTree[Self] = app_commands.CommandTree(self)
|
self.tree = app_commands.CommandTree(self)
|
||||||
|
|
||||||
async def setup_hook(self) -> None:
|
async def setup_hook(self) -> None:
|
||||||
"""Sync commands globally."""
|
"""Sync commands globaly."""
|
||||||
await self.tree.sync()
|
await self.tree.sync()
|
||||||
|
|
||||||
async def on_ready(self) -> None:
|
async def on_ready(self) -> None:
|
||||||
|
|
@ -770,6 +66,7 @@ class LoviBotClient(discord.Client):
|
||||||
# Only allow certain users to interact with the bot
|
# Only allow certain users to interact with the bot
|
||||||
allowed_users: list[str] = get_allowed_users()
|
allowed_users: list[str] = get_allowed_users()
|
||||||
if message.author.name not in allowed_users:
|
if message.author.name not in allowed_users:
|
||||||
|
logger.info("Ignoring message from: %s", message.author.name)
|
||||||
return
|
return
|
||||||
|
|
||||||
incoming_message: str | None = message.content
|
incoming_message: str | None = message.content
|
||||||
|
|
@ -777,116 +74,66 @@ class LoviBotClient(discord.Client):
|
||||||
logger.info("No message content found in the event: %s", message)
|
logger.info("No message content found in the event: %s", message)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Add the message to memory
|
lowercase_message: str = incoming_message.lower() if incoming_message else ""
|
||||||
add_message_to_memory(
|
trigger_keywords: list[str] = ["lovibot", "<@345000831499894795>"]
|
||||||
str(message.channel.id),
|
if any(trigger in lowercase_message for trigger in trigger_keywords):
|
||||||
message.author.name,
|
logger.info("Received message: %s from: %s", incoming_message, message.author.name)
|
||||||
incoming_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
lowercase_message: str = incoming_message.lower()
|
async with message.channel.typing():
|
||||||
trigger_keywords: list[str] = [
|
try:
|
||||||
"lovibot",
|
response: str | None = chat(incoming_message, openai_client)
|
||||||
"@lovibot",
|
except openai.OpenAIError as e:
|
||||||
"<@345000831499894795>",
|
logger.exception("An error occurred while chatting with the AI model.")
|
||||||
"@grok",
|
e.add_note(f"Message: {incoming_message}\nEvent: {message}\nWho: {message.author.name}")
|
||||||
"grok",
|
await message.channel.send(f"An error occurred while chatting with the AI model. {e}")
|
||||||
]
|
return
|
||||||
has_trigger_keyword: bool = any(
|
|
||||||
trigger in lowercase_message for trigger in trigger_keywords
|
|
||||||
)
|
|
||||||
should_respond_flag: bool = (
|
|
||||||
has_trigger_keyword
|
|
||||||
or should_respond_without_trigger(
|
|
||||||
str(message.channel.id),
|
|
||||||
message.author.name,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not should_respond_flag:
|
if response:
|
||||||
return
|
logger.info("Responding to message: %s with: %s", incoming_message, response)
|
||||||
|
await message.channel.send(response)
|
||||||
|
else:
|
||||||
|
logger.warning("No response from the AI model. Message: %s", incoming_message)
|
||||||
|
await message.channel.send("I forgor how to think 💀")
|
||||||
|
|
||||||
# Update trigger time if they used a trigger keyword
|
async def on_error(self, event_method: str, *args: list[Any], **kwargs: dict[str, Any]) -> None:
|
||||||
if has_trigger_keyword:
|
|
||||||
update_trigger_time(str(message.channel.id), message.author.name)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Received message: %s from: %s (trigger: %s, recent: %s)",
|
|
||||||
incoming_message,
|
|
||||||
message.author.name,
|
|
||||||
has_trigger_keyword,
|
|
||||||
not has_trigger_keyword,
|
|
||||||
)
|
|
||||||
|
|
||||||
async with message.channel.typing():
|
|
||||||
try:
|
|
||||||
response: str | None = await chat(
|
|
||||||
client=self,
|
|
||||||
user_message=incoming_message,
|
|
||||||
current_channel=message.channel,
|
|
||||||
user=message.author,
|
|
||||||
allowed_users=allowed_users,
|
|
||||||
all_channels_in_guild=message.guild.channels
|
|
||||||
if message.guild
|
|
||||||
else None,
|
|
||||||
)
|
|
||||||
except openai.OpenAIError as e:
|
|
||||||
logger.exception("An error occurred while chatting with the AI model.")
|
|
||||||
e.add_note(
|
|
||||||
f"Message: {incoming_message}\n"
|
|
||||||
f"Event: {message}\n"
|
|
||||||
f"Who: {message.author.name}",
|
|
||||||
)
|
|
||||||
await message.channel.send(
|
|
||||||
f"An error occurred while chatting with the AI model. {e}",
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
reply: str = response or "I forgor how to think 💀"
|
|
||||||
if response:
|
|
||||||
logger.info(
|
|
||||||
"Responding to message: %s with: %s",
|
|
||||||
incoming_message,
|
|
||||||
reply,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"No response from the AI model. Message: %s",
|
|
||||||
incoming_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Record the bot's reply in memory
|
|
||||||
try:
|
|
||||||
add_message_to_memory(str(message.channel.id), "LoviBot", reply)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to add bot reply to memory for on_message")
|
|
||||||
|
|
||||||
await send_chunked_message(message.channel, reply)
|
|
||||||
|
|
||||||
async def on_error(self, event_method: str, /, *args: Any, **kwargs: Any) -> None: # noqa: ANN401, PLR6301
|
|
||||||
"""Log errors that occur in the bot."""
|
"""Log errors that occur in the bot."""
|
||||||
# Log the error
|
# Log the error
|
||||||
logger.error(
|
logger.error("An error occurred in %s with args: %s and kwargs: %s", event_method, args, kwargs)
|
||||||
"An error occurred in %s with args: %s and kwargs: %s",
|
|
||||||
event_method,
|
|
||||||
args,
|
|
||||||
kwargs,
|
|
||||||
)
|
|
||||||
sentry_sdk.capture_exception()
|
|
||||||
|
|
||||||
# If the error is in on_message, notify the channel
|
# Add context to Sentry
|
||||||
if event_method == "on_message" and args:
|
with sentry_sdk.push_scope() as scope:
|
||||||
message = args[0]
|
# Add event details
|
||||||
if isinstance(message, discord.Message):
|
scope.set_tag("event_method", event_method)
|
||||||
try:
|
scope.set_extra("args", args)
|
||||||
await message.channel.send(
|
scope.set_extra("kwargs", kwargs)
|
||||||
"An error occurred while processing your message. The incident has been logged.", # noqa: E501
|
|
||||||
)
|
# Add bot state
|
||||||
except (Forbidden, HTTPException, NotFound):
|
scope.set_tag("bot_user_id", self.user.id if self.user else "Unknown")
|
||||||
logger.exception(
|
scope.set_tag("bot_user_name", str(self.user) if self.user else "Unknown")
|
||||||
"Failed to send error message to channel %s",
|
scope.set_tag("bot_latency", self.latency)
|
||||||
message.channel.id,
|
|
||||||
)
|
# If specific arguments are available, extract and add details
|
||||||
|
if args:
|
||||||
|
interaction = next((arg for arg in args if isinstance(arg, discord.Interaction)), None)
|
||||||
|
if interaction:
|
||||||
|
scope.set_extra("interaction_id", interaction.id)
|
||||||
|
scope.set_extra("interaction_user", interaction.user.id)
|
||||||
|
scope.set_extra("interaction_user_tag", str(interaction.user))
|
||||||
|
scope.set_extra("interaction_command", interaction.command.name if interaction.command else None)
|
||||||
|
scope.set_extra("interaction_channel", str(interaction.channel))
|
||||||
|
scope.set_extra("interaction_guild", str(interaction.guild) if interaction.guild else None)
|
||||||
|
|
||||||
|
# Add Sentry tags for interaction details
|
||||||
|
scope.set_tag("interaction_id", interaction.id)
|
||||||
|
scope.set_tag("interaction_user_id", interaction.user.id)
|
||||||
|
scope.set_tag("interaction_user_tag", str(interaction.user))
|
||||||
|
scope.set_tag("interaction_command", interaction.command.name if interaction.command else "None")
|
||||||
|
scope.set_tag("interaction_channel_id", interaction.channel.id if interaction.channel else "None")
|
||||||
|
scope.set_tag("interaction_channel_name", str(interaction.channel))
|
||||||
|
scope.set_tag("interaction_guild_id", interaction.guild.id if interaction.guild else "None")
|
||||||
|
scope.set_tag("interaction_guild_name", str(interaction.guild) if interaction.guild else "None")
|
||||||
|
|
||||||
|
sentry_sdk.capture_exception()
|
||||||
|
|
||||||
|
|
||||||
# Everything enabled except `presences`, `members`, and `message_content`.
|
# Everything enabled except `presences`, `members`, and `message_content`.
|
||||||
|
|
@ -895,213 +142,46 @@ intents.message_content = True
|
||||||
client = LoviBotClient(intents=intents)
|
client = LoviBotClient(intents=intents)
|
||||||
|
|
||||||
|
|
||||||
# MARK: /ask command
|
|
||||||
@client.tree.command(name="ask", description="Ask LoviBot a question.")
|
@client.tree.command(name="ask", description="Ask LoviBot a question.")
|
||||||
@app_commands.allowed_installs(guilds=True, users=True)
|
@app_commands.allowed_installs(guilds=True, users=True)
|
||||||
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
||||||
@app_commands.describe(text="Ask LoviBot a question.")
|
@app_commands.describe(text="Ask LoviBot a question.")
|
||||||
async def ask(
|
async def ask(interaction: discord.Interaction, text: str) -> None:
|
||||||
interaction: discord.Interaction,
|
"""A command to ask the AI a question."""
|
||||||
text: str,
|
|
||||||
*,
|
|
||||||
new_conversation: bool = False,
|
|
||||||
) -> None:
|
|
||||||
"""A command to ask the AI a question.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
interaction (discord.Interaction): The interaction object.
|
|
||||||
text (str): The question or message to ask.
|
|
||||||
new_conversation (bool, optional): Whether to start a new conversation.
|
|
||||||
"""
|
|
||||||
await interaction.response.defer()
|
await interaction.response.defer()
|
||||||
|
|
||||||
if not text:
|
if not text:
|
||||||
logger.error("No question or message provided.")
|
logger.error("No question or message provided.")
|
||||||
await interaction.followup.send(
|
await interaction.followup.send("You need to provide a question or message.", ephemeral=True)
|
||||||
"You need to provide a question or message.",
|
|
||||||
ephemeral=True,
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if new_conversation and interaction.channel is not None:
|
# Only allow certain users to interact with the bot
|
||||||
reset_memory(str(interaction.channel.id))
|
allowed_users: list[str] = get_allowed_users()
|
||||||
|
|
||||||
user_name_lowercase: str = interaction.user.name.lower()
|
user_name_lowercase: str = interaction.user.name.lower()
|
||||||
logger.info("Received command from: %s", user_name_lowercase)
|
logger.info("Received command from: %s", user_name_lowercase)
|
||||||
|
|
||||||
# Only allow certain users to interact with the bot
|
|
||||||
allowed_users: list[str] = get_allowed_users()
|
|
||||||
if user_name_lowercase not in allowed_users:
|
if user_name_lowercase not in allowed_users:
|
||||||
await send_response(
|
logger.info("Ignoring message from: %s", user_name_lowercase)
|
||||||
interaction=interaction,
|
await interaction.followup.send("You are not allowed to use this command.", ephemeral=True)
|
||||||
text=text,
|
|
||||||
response="You are not authorized to use this command.",
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Record the user's question in memory (per-channel) so DMs have context
|
|
||||||
if interaction.channel is not None:
|
|
||||||
add_message_to_memory(str(interaction.channel.id), interaction.user.name, text)
|
|
||||||
|
|
||||||
# Get model response
|
|
||||||
try:
|
try:
|
||||||
model_response: str | None = await chat(
|
response: str | None = chat(text, openai_client)
|
||||||
client=client,
|
|
||||||
user_message=text,
|
|
||||||
current_channel=interaction.channel,
|
|
||||||
user=interaction.user,
|
|
||||||
allowed_users=allowed_users,
|
|
||||||
all_channels_in_guild=interaction.guild.channels
|
|
||||||
if interaction.guild
|
|
||||||
else None,
|
|
||||||
)
|
|
||||||
except openai.OpenAIError as e:
|
except openai.OpenAIError as e:
|
||||||
logger.exception("An error occurred while chatting with the AI model.")
|
logger.exception("An error occurred while chatting with the AI model.")
|
||||||
await send_response(
|
await interaction.followup.send(f"An error occurred: {e}")
|
||||||
interaction=interaction,
|
|
||||||
text=text,
|
|
||||||
response=f"An error occurred: {e}",
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
truncated_text: str = truncate_user_input(text)
|
if response:
|
||||||
|
|
||||||
# Fallback if model provided no response
|
|
||||||
if not model_response:
|
|
||||||
logger.warning("No response from the AI model. Message: %s", text)
|
|
||||||
model_response = "I forgor how to think 💀"
|
|
||||||
|
|
||||||
# Record the bot's reply (raw model output) for conversation memory
|
|
||||||
if interaction.channel is not None:
|
|
||||||
add_message_to_memory(str(interaction.channel.id), "LoviBot", model_response)
|
|
||||||
|
|
||||||
display_response: str = f"`{truncated_text}`\n\n{model_response}"
|
|
||||||
logger.info("Responding to message: %s with: %s", text, display_response)
|
|
||||||
|
|
||||||
# If response is longer than 2000 characters, split it into multiple messages
|
|
||||||
max_discord_message_length: int = 2000
|
|
||||||
if len(display_response) > max_discord_message_length:
|
|
||||||
for i in range(0, len(display_response), max_discord_message_length):
|
|
||||||
await send_response(
|
|
||||||
interaction=interaction,
|
|
||||||
text=text,
|
|
||||||
response=display_response[i : i + max_discord_message_length],
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
await send_response(interaction=interaction, text=text, response=display_response)
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: /reset command
|
|
||||||
@client.tree.command(name="reset", description="Reset the conversation memory.")
|
|
||||||
@app_commands.allowed_installs(guilds=True, users=True)
|
|
||||||
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
|
||||||
async def reset(interaction: discord.Interaction) -> None:
|
|
||||||
"""A command to reset the conversation memory."""
|
|
||||||
await interaction.response.defer()
|
|
||||||
|
|
||||||
user_name_lowercase: str = interaction.user.name.lower()
|
|
||||||
logger.info("Received command from: %s", user_name_lowercase)
|
|
||||||
|
|
||||||
# Only allow certain users to interact with the bot
|
|
||||||
allowed_users: list[str] = get_allowed_users()
|
|
||||||
if user_name_lowercase not in allowed_users:
|
|
||||||
await send_response(
|
|
||||||
interaction=interaction,
|
|
||||||
text="",
|
|
||||||
response="You are not authorized to use this command.",
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Reset the conversation memory
|
|
||||||
if interaction.channel is not None:
|
|
||||||
reset_memory(str(interaction.channel.id))
|
|
||||||
|
|
||||||
await interaction.followup.send(
|
|
||||||
f"Conversation memory has been reset for {interaction.channel}.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: /undo command
|
|
||||||
@client.tree.command(name="undo", description="Undo the last /reset command.")
|
|
||||||
@app_commands.allowed_installs(guilds=True, users=True)
|
|
||||||
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
|
||||||
async def undo(interaction: discord.Interaction) -> None:
|
|
||||||
"""A command to undo the last reset operation."""
|
|
||||||
await interaction.response.defer()
|
|
||||||
|
|
||||||
user_name_lowercase: str = interaction.user.name.lower()
|
|
||||||
logger.info("Received undo command from: %s", user_name_lowercase)
|
|
||||||
|
|
||||||
# Only allow certain users to interact with the bot
|
|
||||||
allowed_users: list[str] = get_allowed_users()
|
|
||||||
if user_name_lowercase not in allowed_users:
|
|
||||||
await send_response(
|
|
||||||
interaction=interaction,
|
|
||||||
text="",
|
|
||||||
response="You are not authorized to use this command.",
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Undo the last reset
|
|
||||||
if interaction.channel is not None:
|
|
||||||
if undo_reset(str(interaction.channel.id)):
|
|
||||||
await interaction.followup.send(
|
|
||||||
f"Successfully restored conversation memory for {interaction.channel}.",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
await interaction.followup.send(
|
|
||||||
f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.", # noqa: E501
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
await interaction.followup.send("Cannot undo: No channel context available.")
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: send_response
|
|
||||||
async def send_response(
|
|
||||||
interaction: discord.Interaction,
|
|
||||||
text: str,
|
|
||||||
response: str,
|
|
||||||
) -> None:
|
|
||||||
"""Send a response to the interaction, handling potential errors.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
interaction (discord.Interaction): The interaction to respond to.
|
|
||||||
text (str): The original user input text.
|
|
||||||
response (str): The response to send.
|
|
||||||
"""
|
|
||||||
logger.info("Sending response to interaction in channel %s", interaction.channel)
|
|
||||||
try:
|
|
||||||
await interaction.followup.send(response)
|
await interaction.followup.send(response)
|
||||||
except discord.HTTPException as e:
|
else:
|
||||||
e.add_note(f"Response length: {len(response)} characters.")
|
await interaction.followup.send(f"I forgor how to think 💀\nText: {text}")
|
||||||
e.add_note(f"User input length: {len(text)} characters.")
|
|
||||||
|
|
||||||
logger.exception("Failed to send message to channel %s", interaction.channel)
|
|
||||||
await interaction.followup.send(f"Failed to send message: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: truncate_user_input
|
|
||||||
def truncate_user_input(text: str) -> str:
|
|
||||||
"""Truncate user input if it exceeds the maximum length.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text (str): The user input text.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Truncated text if it exceeds the maximum length, otherwise the original text.
|
|
||||||
""" # noqa: E501
|
|
||||||
max_length: int = 2000
|
|
||||||
truncated_text: str = (
|
|
||||||
text if len(text) <= max_length else text[: max_length - 3] + "..."
|
|
||||||
)
|
|
||||||
return truncated_text
|
|
||||||
|
|
||||||
|
|
||||||
type ImageType = np.ndarray[Any, np.dtype[np.integer[Any] | np.floating[Any]]] | cv2.Mat
|
type ImageType = np.ndarray[Any, np.dtype[np.integer[Any] | np.floating[Any]]] | cv2.Mat
|
||||||
|
|
||||||
|
|
||||||
# MARK: enhance_image1
|
|
||||||
def enhance_image1(image: bytes) -> bytes:
|
def enhance_image1(image: bytes) -> bytes:
|
||||||
"""Enhance an image using OpenCV histogram equalization with denoising.
|
"""Enhance an image using OpenCV histogram equalization with denoising.
|
||||||
|
|
||||||
|
|
@ -1138,7 +218,6 @@ def enhance_image1(image: bytes) -> bytes:
|
||||||
return enhanced_webp.tobytes()
|
return enhanced_webp.tobytes()
|
||||||
|
|
||||||
|
|
||||||
# MARK: enhance_image2
|
|
||||||
def enhance_image2(image: bytes) -> bytes:
|
def enhance_image2(image: bytes) -> bytes:
|
||||||
"""Enhance an image using gamma correction, contrast enhancement, and denoising.
|
"""Enhance an image using gamma correction, contrast enhancement, and denoising.
|
||||||
|
|
||||||
|
|
@ -1169,11 +248,7 @@ def enhance_image2(image: bytes) -> bytes:
|
||||||
enhanced: ImageType = cv2.convertScaleAbs(img_gamma_8bit, alpha=1.2, beta=10)
|
enhanced: ImageType = cv2.convertScaleAbs(img_gamma_8bit, alpha=1.2, beta=10)
|
||||||
|
|
||||||
# Apply very light sharpening
|
# Apply very light sharpening
|
||||||
kernel: ImageType = np.array([
|
kernel: ImageType = np.array([[-0.2, -0.2, -0.2], [-0.2, 2.8, -0.2], [-0.2, -0.2, -0.2]])
|
||||||
[-0.2, -0.2, -0.2],
|
|
||||||
[-0.2, 2.8, -0.2],
|
|
||||||
[-0.2, -0.2, -0.2],
|
|
||||||
])
|
|
||||||
enhanced = cv2.filter2D(enhanced, -1, kernel)
|
enhanced = cv2.filter2D(enhanced, -1, kernel)
|
||||||
|
|
||||||
# Encode the enhanced image to WebP
|
# Encode the enhanced image to WebP
|
||||||
|
|
@ -1182,7 +257,6 @@ def enhance_image2(image: bytes) -> bytes:
|
||||||
return enhanced_webp.tobytes()
|
return enhanced_webp.tobytes()
|
||||||
|
|
||||||
|
|
||||||
# MARK: enhance_image3
|
|
||||||
def enhance_image3(image: bytes) -> bytes:
|
def enhance_image3(image: bytes) -> bytes:
|
||||||
"""Enhance an image using HSV color space manipulation with denoising.
|
"""Enhance an image using HSV color space manipulation with denoising.
|
||||||
|
|
||||||
|
|
@ -1218,80 +292,86 @@ def enhance_image3(image: bytes) -> bytes:
|
||||||
return enhanced_webp.tobytes()
|
return enhanced_webp.tobytes()
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: run_in_thread
|
|
||||||
async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) -> T: # noqa: ANN401
|
|
||||||
"""Run a blocking function in a separate thread.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
func (Callable[..., T]): The blocking function to run.
|
|
||||||
*args (tuple[Any, ...]): Positional arguments to pass to the function.
|
|
||||||
**kwargs (dict[str, Any]): Keyword arguments to pass to the function.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
T: The result of the function.
|
|
||||||
"""
|
|
||||||
return await asyncio.to_thread(func, *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
# MARK: enhance_image_command
|
|
||||||
@client.tree.context_menu(name="Enhance Image")
|
@client.tree.context_menu(name="Enhance Image")
|
||||||
@app_commands.allowed_installs(guilds=True, users=True)
|
@app_commands.allowed_installs(guilds=True, users=True)
|
||||||
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
||||||
async def enhance_image_command(
|
async def enhance_image_command(interaction: discord.Interaction, message: discord.Message) -> None:
|
||||||
interaction: discord.Interaction,
|
|
||||||
message: discord.Message,
|
|
||||||
) -> None:
|
|
||||||
"""Context menu command to enhance an image in a message."""
|
"""Context menu command to enhance an image in a message."""
|
||||||
await interaction.response.defer()
|
await interaction.response.defer()
|
||||||
|
|
||||||
# Check if message has attachments or embeds with images
|
# Check if message has attachments or embeds with images
|
||||||
images: list[bytes] = await get_raw_images_from_text(message.content)
|
image_url: str | None = extract_image_url(message)
|
||||||
|
if not image_url:
|
||||||
# Also check attachments
|
await interaction.followup.send("No image found in the message.", ephemeral=True)
|
||||||
for attachment in message.attachments:
|
|
||||||
if attachment.content_type and attachment.content_type.startswith("image/"):
|
|
||||||
try:
|
|
||||||
img_bytes: bytes = await attachment.read()
|
|
||||||
images.append(img_bytes)
|
|
||||||
except (TimeoutError, HTTPException, Forbidden, NotFound):
|
|
||||||
logger.exception("Failed to read attachment %s", attachment.url)
|
|
||||||
|
|
||||||
if not images:
|
|
||||||
await interaction.followup.send(
|
|
||||||
f"No images found in the message: \n{message.content=}",
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
for image in images:
|
try:
|
||||||
|
# Download the image
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response: httpx.Response = await client.get(image_url)
|
||||||
|
response.raise_for_status()
|
||||||
|
image_bytes: bytes = response.content
|
||||||
|
|
||||||
timestamp: str = datetime.datetime.now(tz=datetime.UTC).isoformat()
|
timestamp: str = datetime.datetime.now(tz=datetime.UTC).isoformat()
|
||||||
|
|
||||||
enhanced_image1, enhanced_image2, enhanced_image3 = await asyncio.gather(
|
enhanced_image1: bytes = enhance_image1(image_bytes)
|
||||||
run_in_thread(enhance_image1, image),
|
file1 = discord.File(fp=io.BytesIO(enhanced_image1), filename=f"enhanced1-{timestamp}.webp")
|
||||||
run_in_thread(enhance_image2, image),
|
|
||||||
run_in_thread(enhance_image3, image),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare files
|
enhanced_image2: bytes = enhance_image2(image_bytes)
|
||||||
file1 = discord.File(
|
file2 = discord.File(fp=io.BytesIO(enhanced_image2), filename=f"enhanced2-{timestamp}.webp")
|
||||||
fp=io.BytesIO(enhanced_image1),
|
|
||||||
filename=f"enhanced1-{timestamp}.webp",
|
enhanced_image3: bytes = enhance_image3(image_bytes)
|
||||||
)
|
file3 = discord.File(fp=io.BytesIO(enhanced_image3), filename=f"enhanced3-{timestamp}.webp")
|
||||||
file2 = discord.File(
|
|
||||||
fp=io.BytesIO(enhanced_image2),
|
|
||||||
filename=f"enhanced2-{timestamp}.webp",
|
|
||||||
)
|
|
||||||
file3 = discord.File(
|
|
||||||
fp=io.BytesIO(enhanced_image3),
|
|
||||||
filename=f"enhanced3-{timestamp}.webp",
|
|
||||||
)
|
|
||||||
|
|
||||||
files: list[discord.File] = [file1, file2, file3]
|
files: list[discord.File] = [file1, file2, file3]
|
||||||
|
logger.info("Enhanced image: %s", image_url)
|
||||||
|
logger.info("Enhanced image files: %s", files)
|
||||||
|
|
||||||
await interaction.followup.send("Enhanced version:", files=files)
|
await interaction.followup.send("Enhanced version:", files=files)
|
||||||
|
|
||||||
|
except (httpx.HTTPError, openai.OpenAIError) as e:
|
||||||
|
logger.exception("Failed to enhance image")
|
||||||
|
await interaction.followup.send(f"An error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def extract_image_url(message: discord.Message) -> str | None:
|
||||||
|
"""Extracts the first image URL from a given Discord message.
|
||||||
|
|
||||||
|
This function checks the attachments of the provided message for any image
|
||||||
|
attachments. If none are found, it then examines the message embeds to see if
|
||||||
|
they include an image. Finally, if no images are found in attachments or embeds,
|
||||||
|
the function searches the message content for any direct links ending in
|
||||||
|
common image file extensions (e.g., .png, .jpg, .jpeg, .gif, .webp).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message (discord.Message): The message from which to extract the image URL.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str | None: The URL of the first image found, or None if no image is found.
|
||||||
|
"""
|
||||||
|
image_url: str | None = None
|
||||||
|
if message.attachments:
|
||||||
|
for attachment in message.attachments:
|
||||||
|
if attachment.content_type and attachment.content_type.startswith("image/"):
|
||||||
|
image_url = attachment.url
|
||||||
|
break
|
||||||
|
|
||||||
|
elif message.embeds:
|
||||||
|
for embed in message.embeds:
|
||||||
|
if embed.image:
|
||||||
|
image_url = embed.image.url
|
||||||
|
break
|
||||||
|
|
||||||
|
if not image_url:
|
||||||
|
match: re.Match[str] | None = re.search(
|
||||||
|
pattern=r"(https?://[^\s]+(\.png|\.jpg|\.jpeg|\.gif|\.webp))",
|
||||||
|
string=message.content,
|
||||||
|
flags=re.IGNORECASE,
|
||||||
|
)
|
||||||
|
if match:
|
||||||
|
image_url = match.group(0)
|
||||||
|
return image_url
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
logger.info("Starting the bot.")
|
logger.info("Starting the bot.")
|
||||||
|
|
|
||||||
53
misc.py
Normal file
53
misc.py
Normal file
|
|
@ -0,0 +1,53 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from openai import OpenAI
|
||||||
|
from openai.types.chat.chat_completion import ChatCompletion
|
||||||
|
|
||||||
|
|
||||||
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_allowed_users() -> list[str]:
|
||||||
|
"""Get the list of allowed users to interact with the bot.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The list of allowed users.
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
"thelovinator",
|
||||||
|
"killyoy",
|
||||||
|
"forgefilip",
|
||||||
|
"plubplub",
|
||||||
|
"nobot",
|
||||||
|
"kao172",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def chat(user_message: str, openai_client: OpenAI) -> str | None:
|
||||||
|
"""Chat with the bot using the OpenAI API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_message: The message to send to OpenAI.
|
||||||
|
openai_client: The OpenAI client to use.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The response from the AI model.
|
||||||
|
"""
|
||||||
|
completion: ChatCompletion = openai_client.chat.completions.create(
|
||||||
|
model="gpt-4o-mini",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "developer",
|
||||||
|
"content": "You are in a Discord group chat with people above the age of 30. Use Discord Markdown to format messages if needed.", # noqa: E501
|
||||||
|
},
|
||||||
|
{"role": "user", "content": user_message},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
response: str | None = completion.choices[0].message.content
|
||||||
|
logger.info("AI response: %s", response)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
@ -7,13 +7,9 @@ requires-python = ">=3.13"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"audioop-lts",
|
"audioop-lts",
|
||||||
"discord-py",
|
"discord-py",
|
||||||
"httpx",
|
|
||||||
"numpy",
|
"numpy",
|
||||||
"ollama",
|
|
||||||
"openai",
|
"openai",
|
||||||
"opencv-contrib-python-headless",
|
"opencv-contrib-python-headless",
|
||||||
"psutil",
|
|
||||||
"pydantic-ai-slim[duckduckgo,openai]",
|
|
||||||
"python-dotenv",
|
"python-dotenv",
|
||||||
"sentry-sdk",
|
"sentry-sdk",
|
||||||
]
|
]
|
||||||
|
|
@ -22,21 +18,16 @@ dependencies = [
|
||||||
dev = ["pytest", "ruff"]
|
dev = ["pytest", "ruff"]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
fix = true
|
|
||||||
preview = true
|
preview = true
|
||||||
|
fix = true
|
||||||
unsafe-fixes = true
|
unsafe-fixes = true
|
||||||
|
|
||||||
format.docstring-code-format = true
|
|
||||||
format.preview = true
|
|
||||||
|
|
||||||
lint.future-annotations = true
|
|
||||||
lint.isort.force-single-line = true
|
|
||||||
lint.pycodestyle.ignore-overlong-task-comments = true
|
|
||||||
lint.pydocstyle.convention = "google"
|
|
||||||
lint.select = ["ALL"]
|
lint.select = ["ALL"]
|
||||||
|
lint.fixable = ["ALL"]
|
||||||
|
lint.pydocstyle.convention = "google"
|
||||||
|
lint.isort.required-imports = ["from __future__ import annotations"]
|
||||||
|
lint.pycodestyle.ignore-overlong-task-comments = true
|
||||||
|
line-length = 120
|
||||||
|
|
||||||
# Don't automatically remove unused variables
|
|
||||||
lint.unfixable = ["F841"]
|
|
||||||
lint.ignore = [
|
lint.ignore = [
|
||||||
"CPY001", # Checks for the absence of copyright notices within Python files.
|
"CPY001", # Checks for the absence of copyright notices within Python files.
|
||||||
"D100", # Checks for undocumented public module definitions.
|
"D100", # Checks for undocumented public module definitions.
|
||||||
|
|
@ -64,12 +55,15 @@ lint.ignore = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
[tool.ruff.format]
|
||||||
|
docstring-code-format = true
|
||||||
|
docstring-code-line-length = 20
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
"**/test_*.py" = [
|
"**/*_test.py" = [
|
||||||
"ARG", # Unused function args -> fixtures nevertheless are functionally relevant...
|
"ARG", # Unused function args -> fixtures nevertheless are functionally relevant...
|
||||||
"FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize()
|
"FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize()
|
||||||
"PLR2004", # Magic value used in comparison, ...
|
"PLR2004", # Magic value used in comparison, ...
|
||||||
"PLR6301", # Method could be a function, class method, or static method
|
|
||||||
"S101", # asserts allowed in tests...
|
"S101", # asserts allowed in tests...
|
||||||
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
|
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
|
||||||
]
|
]
|
||||||
|
|
|
||||||
29
settings.py
Normal file
29
settings.py
Normal file
|
|
@ -0,0 +1,29 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv(verbose=True)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Settings:
|
||||||
|
"""Class to hold settings for the bot."""
|
||||||
|
|
||||||
|
discord_token: str
|
||||||
|
openai_api_key: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def from_env(cls) -> Settings:
|
||||||
|
"""Create a new instance of the class from environment variables.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new instance of the class with the settings.
|
||||||
|
"""
|
||||||
|
discord_token: str = os.getenv("DISCORD_TOKEN", "")
|
||||||
|
openai_api_key: str = os.getenv("OPENAI_TOKEN", "")
|
||||||
|
return cls(discord_token, openai_api_key)
|
||||||
|
|
@ -1,6 +0,0 @@
|
||||||
# Copy this file to /etc/ANewDawn/ANewDawn.env and fill in the required values.
|
|
||||||
# Make sure the directory is owned by the user running the service (e.g., "lovinator").
|
|
||||||
|
|
||||||
DISCORD_TOKEN=
|
|
||||||
OPENAI_TOKEN=
|
|
||||||
OLLAMA_API_KEY=
|
|
||||||
|
|
@ -1,28 +0,0 @@
|
||||||
[Unit]
|
|
||||||
Description=ANewDawn Discord Bot
|
|
||||||
After=network.target
|
|
||||||
|
|
||||||
[Service]
|
|
||||||
Type=simple
|
|
||||||
# Run the bot as the lovinator user (UID 1000) so it has appropriate permissions.
|
|
||||||
# Update these values if you need a different system user/group.
|
|
||||||
User=lovinator
|
|
||||||
Group=lovinator
|
|
||||||
|
|
||||||
# The project directory containing main.py (update as needed).
|
|
||||||
WorkingDirectory=/home/lovinator/ANewDawn/
|
|
||||||
|
|
||||||
# Load environment variables (see systemd/anewdawn.env.example).
|
|
||||||
EnvironmentFile=/etc/ANewDawn/ANewDawn.env
|
|
||||||
|
|
||||||
# Use the python interpreter from your environment (system python is fine if dependencies are installed).
|
|
||||||
ExecStart=/usr/bin/uv run main.py
|
|
||||||
|
|
||||||
Restart=on-failure
|
|
||||||
RestartSec=5
|
|
||||||
|
|
||||||
StandardOutput=journal
|
|
||||||
StandardError=journal
|
|
||||||
|
|
||||||
[Install]
|
|
||||||
WantedBy=multi-user.target
|
|
||||||
|
|
@ -1,199 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from main import add_message_to_memory
|
|
||||||
from main import last_trigger_time
|
|
||||||
from main import recent_messages
|
|
||||||
from main import reset_memory
|
|
||||||
from main import reset_snapshots
|
|
||||||
from main import undo_reset
|
|
||||||
from main import update_trigger_time
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def clear_state() -> None:
|
|
||||||
"""Clear all state before each test."""
|
|
||||||
recent_messages.clear()
|
|
||||||
last_trigger_time.clear()
|
|
||||||
reset_snapshots.clear()
|
|
||||||
|
|
||||||
|
|
||||||
class TestResetMemory:
|
|
||||||
"""Tests for the reset_memory function."""
|
|
||||||
|
|
||||||
def test_reset_memory_clears_messages(self) -> None:
|
|
||||||
"""Test that reset_memory clears messages for the channel."""
|
|
||||||
channel_id = "test_channel_123"
|
|
||||||
add_message_to_memory(channel_id, "user1", "Hello")
|
|
||||||
add_message_to_memory(channel_id, "user2", "World")
|
|
||||||
|
|
||||||
assert channel_id in recent_messages
|
|
||||||
assert len(recent_messages[channel_id]) == 2
|
|
||||||
|
|
||||||
reset_memory(channel_id)
|
|
||||||
|
|
||||||
assert channel_id not in recent_messages
|
|
||||||
|
|
||||||
def test_reset_memory_clears_trigger_times(self) -> None:
|
|
||||||
"""Test that reset_memory clears trigger times for the channel."""
|
|
||||||
channel_id = "test_channel_123"
|
|
||||||
update_trigger_time(channel_id, "user1")
|
|
||||||
|
|
||||||
assert channel_id in last_trigger_time
|
|
||||||
|
|
||||||
reset_memory(channel_id)
|
|
||||||
|
|
||||||
assert channel_id not in last_trigger_time
|
|
||||||
|
|
||||||
def test_reset_memory_creates_snapshot(self) -> None:
|
|
||||||
"""Test that reset_memory creates a snapshot for undo."""
|
|
||||||
channel_id = "test_channel_123"
|
|
||||||
add_message_to_memory(channel_id, "user1", "Test message")
|
|
||||||
update_trigger_time(channel_id, "user1")
|
|
||||||
|
|
||||||
reset_memory(channel_id)
|
|
||||||
|
|
||||||
assert channel_id in reset_snapshots
|
|
||||||
messages_snapshot, trigger_snapshot = reset_snapshots[channel_id]
|
|
||||||
assert len(messages_snapshot) == 1
|
|
||||||
assert "user1" in trigger_snapshot
|
|
||||||
|
|
||||||
def test_reset_memory_no_snapshot_for_empty_channel(self) -> None:
|
|
||||||
"""Test that reset_memory doesn't create snapshot for empty channel."""
|
|
||||||
channel_id = "empty_channel"
|
|
||||||
|
|
||||||
reset_memory(channel_id)
|
|
||||||
|
|
||||||
assert channel_id not in reset_snapshots
|
|
||||||
|
|
||||||
|
|
||||||
class TestUndoReset:
|
|
||||||
"""Tests for the undo_reset function."""
|
|
||||||
|
|
||||||
def test_undo_reset_restores_messages(self) -> None:
|
|
||||||
"""Test that undo_reset restores messages."""
|
|
||||||
channel_id = "test_channel_123"
|
|
||||||
add_message_to_memory(channel_id, "user1", "Hello")
|
|
||||||
add_message_to_memory(channel_id, "user2", "World")
|
|
||||||
|
|
||||||
reset_memory(channel_id)
|
|
||||||
assert channel_id not in recent_messages
|
|
||||||
|
|
||||||
result = undo_reset(channel_id)
|
|
||||||
|
|
||||||
assert result is True
|
|
||||||
assert channel_id in recent_messages
|
|
||||||
assert len(recent_messages[channel_id]) == 2
|
|
||||||
|
|
||||||
def test_undo_reset_restores_trigger_times(self) -> None:
|
|
||||||
"""Test that undo_reset restores trigger times."""
|
|
||||||
channel_id = "test_channel_123"
|
|
||||||
update_trigger_time(channel_id, "user1")
|
|
||||||
original_time = last_trigger_time[channel_id]["user1"]
|
|
||||||
|
|
||||||
reset_memory(channel_id)
|
|
||||||
assert channel_id not in last_trigger_time
|
|
||||||
|
|
||||||
result = undo_reset(channel_id)
|
|
||||||
|
|
||||||
assert result is True
|
|
||||||
assert channel_id in last_trigger_time
|
|
||||||
assert last_trigger_time[channel_id]["user1"] == original_time
|
|
||||||
|
|
||||||
def test_undo_reset_removes_snapshot(self) -> None:
|
|
||||||
"""Test that undo_reset removes the snapshot after restoring."""
|
|
||||||
channel_id = "test_channel_123"
|
|
||||||
add_message_to_memory(channel_id, "user1", "Hello")
|
|
||||||
|
|
||||||
reset_memory(channel_id)
|
|
||||||
assert channel_id in reset_snapshots
|
|
||||||
|
|
||||||
undo_reset(channel_id)
|
|
||||||
|
|
||||||
assert channel_id not in reset_snapshots
|
|
||||||
|
|
||||||
def test_undo_reset_returns_false_when_no_snapshot(self) -> None:
|
|
||||||
"""Test that undo_reset returns False when no snapshot exists."""
|
|
||||||
channel_id = "nonexistent_channel"
|
|
||||||
|
|
||||||
result = undo_reset(channel_id)
|
|
||||||
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
def test_undo_reset_only_works_once(self) -> None:
|
|
||||||
"""Test that undo_reset only works once (snapshot is removed after undo)."""
|
|
||||||
channel_id = "test_channel_123"
|
|
||||||
add_message_to_memory(channel_id, "user1", "Hello")
|
|
||||||
|
|
||||||
reset_memory(channel_id)
|
|
||||||
first_undo = undo_reset(channel_id)
|
|
||||||
second_undo = undo_reset(channel_id)
|
|
||||||
|
|
||||||
assert first_undo is True
|
|
||||||
assert second_undo is False
|
|
||||||
|
|
||||||
|
|
||||||
class TestResetUndoIntegration:
|
|
||||||
"""Integration tests for reset and undo functionality."""
|
|
||||||
|
|
||||||
def test_reset_then_undo_preserves_content(self) -> None:
|
|
||||||
"""Test that reset followed by undo preserves original content."""
|
|
||||||
channel_id = "test_channel_123"
|
|
||||||
add_message_to_memory(channel_id, "user1", "Message 1")
|
|
||||||
add_message_to_memory(channel_id, "user2", "Message 2")
|
|
||||||
add_message_to_memory(channel_id, "user3", "Message 3")
|
|
||||||
update_trigger_time(channel_id, "user1")
|
|
||||||
update_trigger_time(channel_id, "user2")
|
|
||||||
|
|
||||||
# Capture original state
|
|
||||||
original_messages = list(recent_messages[channel_id])
|
|
||||||
original_trigger_users = set(last_trigger_time[channel_id].keys())
|
|
||||||
|
|
||||||
reset_memory(channel_id)
|
|
||||||
undo_reset(channel_id)
|
|
||||||
|
|
||||||
# Verify restored state matches original
|
|
||||||
restored_messages = list(recent_messages[channel_id])
|
|
||||||
restored_trigger_users = set(last_trigger_time[channel_id].keys())
|
|
||||||
|
|
||||||
assert len(restored_messages) == len(original_messages)
|
|
||||||
assert restored_trigger_users == original_trigger_users
|
|
||||||
|
|
||||||
def test_multiple_resets_overwrite_snapshot(self) -> None:
|
|
||||||
"""Test that multiple resets overwrite the previous snapshot."""
|
|
||||||
channel_id = "test_channel_123"
|
|
||||||
|
|
||||||
# First set of messages
|
|
||||||
add_message_to_memory(channel_id, "user1", "First message")
|
|
||||||
reset_memory(channel_id)
|
|
||||||
|
|
||||||
# Second set of messages
|
|
||||||
add_message_to_memory(channel_id, "user1", "Second message")
|
|
||||||
add_message_to_memory(channel_id, "user1", "Third message")
|
|
||||||
reset_memory(channel_id)
|
|
||||||
|
|
||||||
# Undo should restore the second set, not the first
|
|
||||||
undo_reset(channel_id)
|
|
||||||
|
|
||||||
assert channel_id in recent_messages
|
|
||||||
assert len(recent_messages[channel_id]) == 2
|
|
||||||
|
|
||||||
def test_different_channels_independent_undo(self) -> None:
|
|
||||||
"""Test that different channels have independent undo functionality."""
|
|
||||||
channel_1 = "channel_1"
|
|
||||||
channel_2 = "channel_2"
|
|
||||||
|
|
||||||
add_message_to_memory(channel_1, "user1", "Channel 1 message")
|
|
||||||
add_message_to_memory(channel_2, "user2", "Channel 2 message")
|
|
||||||
|
|
||||||
reset_memory(channel_1)
|
|
||||||
reset_memory(channel_2)
|
|
||||||
|
|
||||||
# Undo only channel 1
|
|
||||||
undo_reset(channel_1)
|
|
||||||
|
|
||||||
assert channel_1 in recent_messages
|
|
||||||
assert channel_2 not in recent_messages
|
|
||||||
assert channel_1 not in reset_snapshots
|
|
||||||
assert channel_2 in reset_snapshots
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue