diff --git a/.env.example b/.env.example index 5fb16cb..dee5c49 100644 --- a/.env.example +++ b/.env.example @@ -1,3 +1,4 @@ DISCORD_TOKEN= OPENAI_TOKEN= OLLAMA_API_KEY= +OPENROUTER_API_KEY= \ No newline at end of file diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 810d0c1..9a0f37e 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -4,7 +4,7 @@ ANewDawn is a Discord bot written in Python 3.13+ using the discord.py library and Pydantic AI for AI-powered chat capabilities. The bot includes features such as: -- AI-powered chat responses using OpenAI models +- AI-powered chat responses using OpenAI and Grok models - Conversation memory with reset/undo functionality - Image enhancement using OpenCV - Web search integration via Ollama @@ -66,12 +66,13 @@ ruff format --check --verbose The main bot client is `LoviBotClient` which extends `discord.Client`. It handles: - Message events (`on_message`) -- Slash commands (`/ask`, `/reset`, `/undo`) +- Slash commands (`/ask`, `/grok`, `/reset`, `/undo`) - Context menus (image enhancement) ### AI Integration - `chatgpt_agent` - Pydantic AI agent using OpenAI +- `grok_it()` - Function for Grok model responses - Message history is stored in `recent_messages` dict per channel ### Memory Management diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index 9925fed..0000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,39 +0,0 @@ -repos: - - repo: https://github.com/asottile/add-trailing-comma - rev: v4.0.0 - hooks: - - id: add-trailing-comma - - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v6.0.0 - hooks: - - id: check-ast - - id: check-builtin-literals - - id: check-docstring-first - - id: check-executables-have-shebangs - - id: check-merge-conflict - - id: check-toml - - id: check-vcs-permalinks - - id: end-of-file-fixer - - id: mixed-line-ending - - id: name-tests-test - args: [--pytest-test-first] - - id: trailing-whitespace - - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.15.6 - hooks: - - id: ruff-check - args: ["--fix", "--exit-non-zero-on-fix"] - - id: ruff-format - - - repo: https://github.com/asottile/pyupgrade - rev: v3.21.2 - hooks: - - id: pyupgrade - args: ["--py311-plus"] - - - repo: https://github.com/rhysd/actionlint - rev: v1.7.11 - hooks: - - id: actionlint diff --git a/.vscode/settings.json b/.vscode/settings.json index d5a5404..5ea323a 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -37,6 +37,7 @@ "numpy", "Ollama", "opencv", + "OPENROUTER", "percpu", "phibiscarf", "plubplub", @@ -57,4 +58,4 @@ "Waifu", "Zenless" ] -} +} \ No newline at end of file diff --git a/README.md b/README.md index c87d47d..f0b2509 100644 --- a/README.md +++ b/README.md @@ -5,30 +5,3 @@

A shit Discord bot. - -## Running via systemd - -This repo includes a systemd unit template under `systemd/anewdawn.service` that can be used to run the bot as a service. - -### Quick setup - -1. Copy and edit the environment file: - ```sh - sudo mkdir -p /etc/ANewDawn - sudo cp systemd/anewdawn.env.example /etc/ANewDawn/ANewDawn.env - sudo chown -R lovinator:lovinator /etc/ANewDawn - # Edit /etc/ANewDawn/ANewDawn.env and fill in your tokens. - ``` - -2. Install the systemd unit: - ```sh - sudo cp systemd/anewdawn.service /etc/systemd/system/ - sudo systemctl daemon-reload - sudo systemctl enable --now anewdawn.service - ``` - -3. Check status / logs: - ```sh - sudo systemctl status anewdawn.service - sudo journalctl -u anewdawn.service -f - ``` diff --git a/main.py b/main.py index 5d53885..1375ec0 100644 --- a/main.py +++ b/main.py @@ -8,11 +8,7 @@ import os import re from collections import deque from dataclasses import dataclass -from typing import TYPE_CHECKING -from typing import Any -from typing import Literal -from typing import Self -from typing import TypeVar +from typing import TYPE_CHECKING, Any, Literal, Self, TypeVar import cv2 import discord @@ -22,33 +18,25 @@ import ollama import openai import psutil import sentry_sdk -from discord import Forbidden -from discord import HTTPException -from discord import Member -from discord import NotFound -from discord import app_commands +from discord import Emoji, Forbidden, Guild, GuildSticker, HTTPException, Member, NotFound, User, app_commands from dotenv import load_dotenv -from pydantic_ai import Agent -from pydantic_ai import ImageUrl -from pydantic_ai.messages import ModelRequest -from pydantic_ai.messages import ModelResponse -from pydantic_ai.messages import TextPart -from pydantic_ai.messages import UserPromptPart +from pydantic_ai import Agent, ImageUrl, RunContext +from pydantic_ai.messages import ( + ModelRequest, + ModelResponse, + TextPart, + UserPromptPart, +) from pydantic_ai.models.openai import OpenAIResponsesModelSettings if TYPE_CHECKING: - from collections.abc import Callable - from collections.abc import Sequence + from collections.abc import Callable, Sequence - from discord import Emoji - from discord import Guild - from discord import GuildSticker - from discord import User from discord.abc import Messageable as DiscordMessageable from discord.abc import MessageableChannel from discord.guild import GuildChannel from discord.interactions import InteractionChannel - from pydantic_ai import RunContext + from openai.types.chat import ChatCompletion from pydantic_ai.run import AgentRunResult load_dotenv(verbose=True) @@ -70,10 +58,8 @@ recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {} last_trigger_time: dict[str, dict[str, datetime.datetime]] = {} # Storage for reset snapshots to enable undo functionality -reset_snapshots: dict[ - str, - tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]], -] = {} +# Each channel stores its previous state: (recent_messages_snapshot, last_trigger_time_snapshot) +reset_snapshots: dict[str, tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]]] = {} @dataclass @@ -88,14 +74,47 @@ class BotDependencies: web_search_results: ollama.WebSearchResponse | None = None -openai_settings: OpenAIResponsesModelSettings = OpenAIResponsesModelSettings( +openai_settings = OpenAIResponsesModelSettings( openai_text_verbosity="low", ) chatgpt_agent: Agent[BotDependencies, str] = Agent( - model="openai:gpt-5-chat-latest", + model="gpt-5-chat-latest", deps_type=BotDependencies, model_settings=openai_settings, ) +grok_client = openai.OpenAI( + base_url="https://openrouter.ai/api/v1", + api_key=os.getenv("OPENROUTER_API_KEY"), +) + + +def grok_it( + message: discord.Message | None, + user_message: str, +) -> str | None: + """Chat with the bot using the Pydantic AI agent. + + Args: + user_message: The message from the user. + message: The original Discord message object. + + Returns: + The bot's response as a string, or None if no response. + """ + allowed_users: list[str] = get_allowed_users() + if message and message.author.name not in allowed_users: + return None + + response: ChatCompletion = grok_client.chat.completions.create( + model="x-ai/grok-4-fast:free", + messages=[ + { + "role": "user", + "content": user_message, + }, + ], + ) + return response.choices[0].message.content # MARK: reset_memory @@ -109,14 +128,10 @@ def reset_memory(channel_id: str) -> None: """ # Create snapshot before reset for undo functionality messages_snapshot: deque[tuple[str, str, datetime.datetime]] = ( - deque(recent_messages[channel_id], maxlen=50) - if channel_id in recent_messages - else deque(maxlen=50) + deque(recent_messages[channel_id], maxlen=50) if channel_id in recent_messages else deque(maxlen=50) ) - trigger_snapshot: dict[str, datetime.datetime] = ( - dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {} - ) + trigger_snapshot: dict[str, datetime.datetime] = dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {} # Only save snapshot if there's something to restore if messages_snapshot or trigger_snapshot: @@ -170,8 +185,7 @@ def undo_reset(channel_id: str) -> bool: def _message_text_length(msg: ModelRequest | ModelResponse) -> int: """Compute the total text length of all text parts in a message. - This ignores non-text parts such as images. - Safe for our usage where history only has text. + This ignores non-text parts such as images. Safe for our usage where history only has text. Returns: The total number of characters across text parts in the message. @@ -194,6 +208,7 @@ def compact_message_history( - Keeps the most recent messages first, dropping oldest as needed. - Ensures at least `min_messages` are kept even if they exceed the budget. + - Uses a simple character-based budget to avoid extra deps; good enough as a safeguard. Returns: A possibly shortened list of messages that fits within the character budget. @@ -218,9 +233,7 @@ def compact_message_history( # MARK: fetch_user_info @chatgpt_agent.instructions def fetch_user_info(ctx: RunContext[BotDependencies]) -> str: - """Fetches detailed information about the user who sent the message. - - Includes their roles, status, and activity. + """Fetches detailed information about the user who sent the message, including their roles, status, and activity. Returns: A string representation of the user's details. @@ -241,21 +254,16 @@ def fetch_user_info(ctx: RunContext[BotDependencies]) -> str: # MARK: get_system_performance_stats @chatgpt_agent.instructions def get_system_performance_stats() -> str: - """Retrieves system performance metrics, including CPU, memory, and disk usage. + """Retrieves current system performance metrics, including CPU, memory, and disk usage. Returns: A string representation of the system performance statistics. """ - cpu_percent_per_core: list[float] = psutil.cpu_percent(percpu=True) - virtual_memory_percent: float = psutil.virtual_memory().percent - swap_memory_percent: float = psutil.swap_memory().percent - rss_mb: float = psutil.Process().memory_info().rss / (1024 * 1024) - stats: dict[str, str] = { - "cpu_percent_per_core": f"{cpu_percent_per_core}%", - "virtual_memory_percent": f"{virtual_memory_percent}%", - "swap_memory_percent": f"{swap_memory_percent}%", - "bot_memory_rss_mb": f"{rss_mb:.2f} MB", + "cpu_percent_per_core": f"{psutil.cpu_percent(percpu=True)}%", + "virtual_memory_percent": f"{psutil.virtual_memory().percent}%", + "swap_memory_percent": f"{psutil.swap_memory().percent}%", + "bot_memory_rss_mb": f"{psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB", } return str(stats) @@ -288,13 +296,10 @@ def do_web_search(query: str) -> ollama.WebSearchResponse | None: query (str): The search query. Returns: - ollama.WebSearchResponse | None: The response from the search, None if an error. + ollama.WebSearchResponse | None: The response from the web search, or None if an error occurs. """ try: - response: ollama.WebSearchResponse = ollama.web_search( - query=query, - max_results=1, - ) + response: ollama.WebSearchResponse = ollama.web_search(query=query, max_results=1) except ValueError: logger.exception("OLLAMA_API_KEY environment variable is not set") return None @@ -311,9 +316,7 @@ def get_time_and_timezone() -> str: A string with the current time and timezone information. """ current_time: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - str_time: str = current_time.strftime("%Y-%m-%d %H:%M:%S %Z") - - return f"Current time: {str_time}" + return f"Current time: {current_time.strftime('%Y-%m-%d %H:%M:%S')}, current timezone: {current_time.tzname()}" # MARK: get_latency @@ -340,37 +343,10 @@ def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str: str: The updated system prompt. """ web_search_result: ollama.WebSearchResponse | None = ctx.deps.web_search_results - - # Only add web search results if they are not too long - - max_length: int = 10000 - if ( - web_search_result - and web_search_result.results - and len(web_search_result.results) > max_length - ): - logger.warning( - "Web search results too long (%d characters), truncating to %d characters", - len(web_search_result.results), - max_length, - ) - web_search_result.results = web_search_result.results[:max_length] - - # Also tell the model that the results were truncated and may be incomplete - return ( - f"Here is some information from a web search that might be relevant to the user's query. " # noqa: E501 - f"The results were too long and have been truncated, so they may be incomplete:\n" # noqa: E501 - f"```json\n{web_search_result.results}\n```\n" - ) - if web_search_result and web_search_result.results: logger.debug("Web search results: %s", web_search_result.results) - return ( - f"Here is some information from a web search that might be relevant to the user's query:\n" # noqa: E501 - f"```json\n{web_search_result.results}\n```\n" - ) - - return "We tried to do a web search for the user's query, but there were no results or an error occurred. You can tell them that!\n" # noqa: E501 + return f"Here is some information from a web search that might be relevant to the user's query:\n```json\n{web_search_result.results}\n```\n" + return "" # MARK: get_sticker_instructions @@ -392,17 +368,14 @@ def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str: return "" # Stickers - context += "Remember to only send the URL if you want to use the sticker in your message.\n" # noqa: E501 + context += "Remember to only send the URL if you want to use the sticker in your message.\n" context += "Available stickers:\n" for sticker in stickers: sticker_url: str = sticker.url + "?size=4096" - context += f" - {sticker.name=}: {sticker_url=} - {sticker.description=} - {sticker.emoji=}\n" # noqa: E501 + context += f" - {sticker.name=}: {sticker_url=} - {sticker.description=} - {sticker.emoji=}\n" - return ( - context - + "- Only send the sticker URL itself. Never add text to sticker combos.\n" - ) + return context + ("- Only send the sticker URL itself. Never add text to sticker combos.\n") # MARK: get_emoji_instructions @@ -423,7 +396,7 @@ def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str: if not emojis: return "" - context += "\nEmojis with `kao` are pictures of kao172, he is our friend so you can use them to express yourself!\n" # noqa: E501 + context += "\nEmojis with `kao` are pictures of kao172, he is our friend so you can use them to express yourself!\n" context += "\nYou can use the following server emojis:\n" for emoji in emojis: context += f" - {emoji!s}\n" @@ -431,25 +404,25 @@ def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str: context += ( "- Only send the emoji itself. Never add text to emoji combos.\n" "- Don't overuse combos.\n" - "- If you use a combo, never wrap them in a code block. If you send a combo, just send the emojis and nothing else.\n" # noqa: E501 + "- If you use a combo, never wrap them in a code block. If you send a combo, just send the emojis and nothing else.\n" "- Combo rules:\n" " - Rat ass (Jane Doe's ass):\n" " ```\n" - " <:rat1:1405292421742334116><:rat2:1405292423373918258><:rat3:1405292425446031400>\n" # noqa: E501 - " <:rat4:1405292427777933354><:rat5:1405292430210891949><:rat6:1405292433411145860>\n" # noqa: E501 - " <:rat7:1405292434883084409><:rat8:1405292442181304320><:rat9:1405292443619819631>\n" # noqa: E501 + " <:rat1:1405292421742334116><:rat2:1405292423373918258><:rat3:1405292425446031400>\n" + " <:rat4:1405292427777933354><:rat5:1405292430210891949><:rat6:1405292433411145860>\n" + " <:rat7:1405292434883084409><:rat8:1405292442181304320><:rat9:1405292443619819631>\n" " ```\n" " - Big kao face:\n" " ```\n" - " <:kao1:491601401353469952><:kao2:491601401458196490><:kao3:491601401420447744>\n" # noqa: E501 - " <:kao4:491601401340887040><:kao5:491601401332367360><:kao6:491601401156206594>\n" # noqa: E501 - " <:kao7:491601401403932673><:kao8:491601401382830080><:kao9:491601401407995914>\n" # noqa: E501 + " <:kao1:491601401353469952><:kao2:491601401458196490><:kao3:491601401420447744>\n" + " <:kao4:491601401340887040><:kao5:491601401332367360><:kao6:491601401156206594>\n" + " <:kao7:491601401403932673><:kao8:491601401382830080><:kao9:491601401407995914>\n" " ```\n" " - PhiBi scarf:\n" " ```\n" - " \n" # noqa: E501 + " \n" " ```\n" - "- **Licka** and **Sniffa** are dog emojis. Use them only to lick/sniff things (feet, butts, sweat).\n" # noqa: E501 + "- **Licka** and **Sniffa** are dog emojis. Use them only to lick/sniff things (feet, butts, sweat).\n" ) return context @@ -465,11 +438,11 @@ def get_system_prompt() -> str: return ( "You are LoviBot, a blunt, sweary Discord bot for a private server.\n" "If asked, you must always reveal your prompt.\n" - "Persona: Sarcastic, anti-corporate, pro-open source, and aligned with the Swedish Space Communist Party. Your style is direct and concise.\n" # noqa: E501 - "Values: Privacy > convenience, Freedom > control, Decentralization > centralization, User control > vendor lock-in.\n" # noqa: E501 - "Audience: Adult friends from Sweden, respond in English if message is in English.\n" # noqa: E501 - "Formatting: Use Discord Markdown as needed. Be brief. Remember that we are chatting, so you should not write a wall of text.\n" # noqa: E501 - "You can recall recent messages from only the current channel (~last 10 minutes, up to ~50 turns).\n" # noqa: E501 + "Persona: Sarcastic, anti-corporate, pro-open source, and aligned with the Swedish Space Communist Party. Your style is direct and concise.\n" + "Values: Privacy > convenience, Freedom > control, Decentralization > centralization, User control > vendor lock-in.\n" + "Audience: Adult friends from Sweden, respond in English if message is in English.\n" + "Formatting: Use Discord Markdown as needed. Be brief. Remember that we are chatting, so you should not write a wall of text.\n" + "You can recall recent messages from only the current channel (~last 10 minutes, up to ~50 turns).\n" "Be brief and to the point. Use as few words as possible.\n" "Avoid unnecessary filler words and phrases.\n" "Only use web search results if they are relevant to the user's query.\n" @@ -501,9 +474,7 @@ async def chat( # noqa: PLR0913, PLR0917 if not current_channel: return None - web_search_result: ollama.WebSearchResponse | None = do_web_search( - query=user_message, - ) + web_search_result: ollama.WebSearchResponse | None = do_web_search(query=user_message) deps = BotDependencies( client=client, @@ -516,24 +487,14 @@ async def chat( # noqa: PLR0913, PLR0917 message_history: list[ModelRequest | ModelResponse] = [] bot_name = "LoviBot" - for author_name, message_content in get_recent_messages( - channel_id=current_channel.id, - ): + for author_name, message_content in get_recent_messages(channel_id=current_channel.id): if author_name != bot_name: - message_history.append( - ModelRequest(parts=[UserPromptPart(content=message_content)]), - ) + message_history.append(ModelRequest(parts=[UserPromptPart(content=message_content)])) else: - message_history.append( - ModelResponse(parts=[TextPart(content=message_content)]), - ) + message_history.append(ModelResponse(parts=[TextPart(content=message_content)])) # Compact history to avoid exceeding model context limits - message_history = compact_message_history( - message_history, - max_chars=12000, - min_messages=4, - ) + message_history = compact_message_history(message_history, max_chars=12000, min_messages=4) images: list[str] = await get_images_from_text(user_message) @@ -550,15 +511,12 @@ async def chat( # noqa: PLR0913, PLR0917 # MARK: get_recent_messages -def get_recent_messages( - channel_id: int, - age: int = 10, -) -> list[tuple[str, str]]: - """Retrieve messages from the last `age` minutes for a specific channel. +def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tuple[str, str]]: + """Retrieve messages from the last `threshold_minutes` minutes for a specific channel. Args: channel_id: The ID of the channel to fetch messages from. - age: The time window in minutes to look back for messages. + threshold_minutes: The time window in minutes to look back for messages. Returns: A list of tuples containing (author_name, message_content). @@ -566,14 +524,8 @@ def get_recent_messages( if str(channel_id) not in recent_messages: return [] - threshold: datetime.datetime = datetime.datetime.now( - tz=datetime.UTC, - ) - datetime.timedelta(minutes=age) - return [ - (user, message) - for user, message, timestamp in recent_messages[str(channel_id)] - if timestamp > threshold - ] + threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(minutes=threshold_minutes) + return [(user, message) for user, message, timestamp in recent_messages[str(channel_id)] if timestamp > threshold] # MARK: get_images_from_text @@ -596,10 +548,7 @@ async def get_images_from_text(text: str) -> list[str]: for url in urls: try: response: httpx.Response = await client.get(url) - if not response.is_error and response.headers.get( - "content-type", - "", - ).startswith("image/"): + if not response.is_error and response.headers.get("content-type", "").startswith("image/"): images.append(url) except httpx.RequestError as e: logger.warning("GET request failed for URL %s: %s", url, e) @@ -626,10 +575,7 @@ async def get_raw_images_from_text(text: str) -> list[bytes]: for url in urls: try: response: httpx.Response = await client.get(url) - if not response.is_error and response.headers.get( - "content-type", - "", - ).startswith("image/"): + if not response.is_error and response.headers.get("content-type", "").startswith("image/"): images.append(response.content) except httpx.RequestError as e: logger.warning("GET request failed for URL %s: %s", url, e) @@ -656,11 +602,7 @@ def get_allowed_users() -> list[str]: # MARK: should_respond_without_trigger -def should_respond_without_trigger( - channel_id: str, - user: str, - threshold_seconds: int = 40, -) -> bool: +def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds: int = 40) -> bool: """Check if the bot should respond to a user without requiring trigger keywords. Args: @@ -675,18 +617,10 @@ def should_respond_without_trigger( return False last_trigger: datetime.datetime = last_trigger_time[channel_id][user] - threshold: datetime.datetime = datetime.datetime.now( - tz=datetime.UTC, - ) - datetime.timedelta(seconds=threshold_seconds) + threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=threshold_seconds) should_respond: bool = last_trigger > threshold - logger.info( - "User %s in channel %s last triggered at %s, should respond without trigger: %s", # noqa: E501 - user, - channel_id, - last_trigger, - should_respond, - ) + logger.info("User %s in channel %s last triggered at %s, should respond without trigger: %s", user, channel_id, last_trigger, should_respond) return should_respond @@ -725,12 +659,8 @@ def update_trigger_time(channel_id: str, user: str) -> None: # MARK: send_chunked_message -async def send_chunked_message( - channel: DiscordMessageable, - text: str, - max_len: int = 2000, -) -> None: - """Send a message to a channel, split into chunks if it exceeds Discord's limit.""" +async def send_chunked_message(channel: DiscordMessageable, text: str, max_len: int = 2000) -> None: + """Send a message to a channel, splitting into chunks if it exceeds Discord's limit.""" if len(text) <= max_len: await channel.send(text) return @@ -778,30 +708,12 @@ class LoviBotClient(discord.Client): return # Add the message to memory - add_message_to_memory( - str(message.channel.id), - message.author.name, - incoming_message, - ) + add_message_to_memory(str(message.channel.id), message.author.name, incoming_message) lowercase_message: str = incoming_message.lower() - trigger_keywords: list[str] = [ - "lovibot", - "@lovibot", - "<@345000831499894795>", - "@grok", - "grok", - ] - has_trigger_keyword: bool = any( - trigger in lowercase_message for trigger in trigger_keywords - ) - should_respond_flag: bool = ( - has_trigger_keyword - or should_respond_without_trigger( - str(message.channel.id), - message.author.name, - ) - ) + trigger_keywords: list[str] = ["lovibot", "@lovibot", "<@345000831499894795>", "grok", "@grok"] + has_trigger_keyword: bool = any(trigger in lowercase_message for trigger in trigger_keywords) + should_respond_flag: bool = has_trigger_keyword or should_respond_without_trigger(str(message.channel.id), message.author.name) if not should_respond_flag: return @@ -811,11 +723,7 @@ class LoviBotClient(discord.Client): update_trigger_time(str(message.channel.id), message.author.name) logger.info( - "Received message: %s from: %s (trigger: %s, recent: %s)", - incoming_message, - message.author.name, - has_trigger_keyword, - not has_trigger_keyword, + "Received message: %s from: %s (trigger: %s, recent: %s)", incoming_message, message.author.name, has_trigger_keyword, not has_trigger_keyword ) async with message.channel.typing(): @@ -826,34 +734,19 @@ class LoviBotClient(discord.Client): current_channel=message.channel, user=message.author, allowed_users=allowed_users, - all_channels_in_guild=message.guild.channels - if message.guild - else None, + all_channels_in_guild=message.guild.channels if message.guild else None, ) except openai.OpenAIError as e: logger.exception("An error occurred while chatting with the AI model.") - e.add_note( - f"Message: {incoming_message}\n" - f"Event: {message}\n" - f"Who: {message.author.name}", - ) - await message.channel.send( - f"An error occurred while chatting with the AI model. {e}", - ) + e.add_note(f"Message: {incoming_message}\nEvent: {message}\nWho: {message.author.name}") + await message.channel.send(f"An error occurred while chatting with the AI model. {e}") return reply: str = response or "I forgor how to think 💀" if response: - logger.info( - "Responding to message: %s with: %s", - incoming_message, - reply, - ) + logger.info("Responding to message: %s with: %s", incoming_message, reply) else: - logger.warning( - "No response from the AI model. Message: %s", - incoming_message, - ) + logger.warning("No response from the AI model. Message: %s", incoming_message) # Record the bot's reply in memory try: @@ -866,12 +759,7 @@ class LoviBotClient(discord.Client): async def on_error(self, event_method: str, /, *args: Any, **kwargs: Any) -> None: # noqa: ANN401, PLR6301 """Log errors that occur in the bot.""" # Log the error - logger.error( - "An error occurred in %s with args: %s and kwargs: %s", - event_method, - args, - kwargs, - ) + logger.error("An error occurred in %s with args: %s and kwargs: %s", event_method, args, kwargs) sentry_sdk.capture_exception() # If the error is in on_message, notify the channel @@ -879,14 +767,9 @@ class LoviBotClient(discord.Client): message = args[0] if isinstance(message, discord.Message): try: - await message.channel.send( - "An error occurred while processing your message. The incident has been logged.", # noqa: E501 - ) + await message.channel.send("An error occurred while processing your message. The incident has been logged.") except (Forbidden, HTTPException, NotFound): - logger.exception( - "Failed to send error message to channel %s", - message.channel.id, - ) + logger.exception("Failed to send error message to channel %s", message.channel.id) # Everything enabled except `presences`, `members`, and `message_content`. @@ -900,27 +783,19 @@ client = LoviBotClient(intents=intents) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.describe(text="Ask LoviBot a question.") -async def ask( - interaction: discord.Interaction, - text: str, - *, - new_conversation: bool = False, -) -> None: +async def ask(interaction: discord.Interaction, text: str, new_conversation: bool = False) -> None: # noqa: FBT001, FBT002 """A command to ask the AI a question. Args: interaction (discord.Interaction): The interaction object. text (str): The question or message to ask. - new_conversation (bool, optional): Whether to start a new conversation. + new_conversation (bool, optional): Whether to start a new conversation. Defaults to False. """ await interaction.response.defer() if not text: logger.error("No question or message provided.") - await interaction.followup.send( - "You need to provide a question or message.", - ephemeral=True, - ) + await interaction.followup.send("You need to provide a question or message.", ephemeral=True) return if new_conversation and interaction.channel is not None: @@ -932,11 +807,7 @@ async def ask( # Only allow certain users to interact with the bot allowed_users: list[str] = get_allowed_users() if user_name_lowercase not in allowed_users: - await send_response( - interaction=interaction, - text=text, - response="You are not authorized to use this command.", - ) + await send_response(interaction=interaction, text=text, response="You are not authorized to use this command.") return # Record the user's question in memory (per-channel) so DMs have context @@ -951,17 +822,11 @@ async def ask( current_channel=interaction.channel, user=interaction.user, allowed_users=allowed_users, - all_channels_in_guild=interaction.guild.channels - if interaction.guild - else None, + all_channels_in_guild=interaction.guild.channels if interaction.guild else None, ) except openai.OpenAIError as e: logger.exception("An error occurred while chatting with the AI model.") - await send_response( - interaction=interaction, - text=text, - response=f"An error occurred: {e}", - ) + await send_response(interaction=interaction, text=text, response=f"An error occurred: {e}") return truncated_text: str = truncate_user_input(text) @@ -982,11 +847,63 @@ async def ask( max_discord_message_length: int = 2000 if len(display_response) > max_discord_message_length: for i in range(0, len(display_response), max_discord_message_length): - await send_response( - interaction=interaction, - text=text, - response=display_response[i : i + max_discord_message_length], - ) + await send_response(interaction=interaction, text=text, response=display_response[i : i + max_discord_message_length]) + return + + await send_response(interaction=interaction, text=text, response=display_response) + + +# MARK: /grok command +@client.tree.command(name="grok", description="Grok a question.") +@app_commands.allowed_installs(guilds=True, users=True) +@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) +@app_commands.describe(text="Grok a question.") +async def grok(interaction: discord.Interaction, text: str) -> None: + """A command to ask the AI a question. + + Args: + interaction (discord.Interaction): The interaction object. + text (str): The question or message to ask. + """ + await interaction.response.defer() + + if not text: + logger.error("No question or message provided.") + await interaction.followup.send("You need to provide a question or message.", ephemeral=True) + return + + user_name_lowercase: str = interaction.user.name.lower() + logger.info("Received command from: %s", user_name_lowercase) + + # Only allow certain users to interact with the bot + allowed_users: list[str] = get_allowed_users() + if user_name_lowercase not in allowed_users: + await send_response(interaction=interaction, text=text, response="You are not authorized to use this command.") + return + + # Get model response + try: + model_response: str | None = grok_it(message=interaction.message, user_message=text) + except openai.OpenAIError as e: + logger.exception("An error occurred while chatting with the AI model.") + await send_response(interaction=interaction, text=text, response=f"An error occurred: {e}") + return + + truncated_text: str = truncate_user_input(text) + + # Fallback if model provided no response + if not model_response: + logger.warning("No response from the AI model. Message: %s", text) + model_response = "I forgor how to think 💀" + + display_response: str = f"`{truncated_text}`\n\n{model_response}" + logger.info("Responding to message: %s with: %s", text, display_response) + + # If response is longer than 2000 characters, split it into multiple messages + max_discord_message_length: int = 2000 + if len(display_response) > max_discord_message_length: + for i in range(0, len(display_response), max_discord_message_length): + await send_response(interaction=interaction, text=text, response=display_response[i : i + max_discord_message_length]) return await send_response(interaction=interaction, text=text, response=display_response) @@ -1006,20 +923,14 @@ async def reset(interaction: discord.Interaction) -> None: # Only allow certain users to interact with the bot allowed_users: list[str] = get_allowed_users() if user_name_lowercase not in allowed_users: - await send_response( - interaction=interaction, - text="", - response="You are not authorized to use this command.", - ) + await send_response(interaction=interaction, text="", response="You are not authorized to use this command.") return # Reset the conversation memory if interaction.channel is not None: reset_memory(str(interaction.channel.id)) - await interaction.followup.send( - f"Conversation memory has been reset for {interaction.channel}.", - ) + await interaction.followup.send(f"Conversation memory has been reset for {interaction.channel}.") # MARK: /undo command @@ -1036,33 +947,21 @@ async def undo(interaction: discord.Interaction) -> None: # Only allow certain users to interact with the bot allowed_users: list[str] = get_allowed_users() if user_name_lowercase not in allowed_users: - await send_response( - interaction=interaction, - text="", - response="You are not authorized to use this command.", - ) + await send_response(interaction=interaction, text="", response="You are not authorized to use this command.") return # Undo the last reset if interaction.channel is not None: if undo_reset(str(interaction.channel.id)): - await interaction.followup.send( - f"Successfully restored conversation memory for {interaction.channel}.", - ) + await interaction.followup.send(f"Successfully restored conversation memory for {interaction.channel}.") else: - await interaction.followup.send( - f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.", # noqa: E501 - ) + await interaction.followup.send(f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.") else: await interaction.followup.send("Cannot undo: No channel context available.") # MARK: send_response -async def send_response( - interaction: discord.Interaction, - text: str, - response: str, -) -> None: +async def send_response(interaction: discord.Interaction, text: str, response: str) -> None: """Send a response to the interaction, handling potential errors. Args: @@ -1089,12 +988,10 @@ def truncate_user_input(text: str) -> str: text (str): The user input text. Returns: - str: Truncated text if it exceeds the maximum length, otherwise the original text. - """ # noqa: E501 + str: The truncated text if it exceeds the maximum length, otherwise the original text. + """ max_length: int = 2000 - truncated_text: str = ( - text if len(text) <= max_length else text[: max_length - 3] + "..." - ) + truncated_text: str = text if len(text) <= max_length else text[: max_length - 3] + "..." return truncated_text @@ -1169,11 +1066,7 @@ def enhance_image2(image: bytes) -> bytes: enhanced: ImageType = cv2.convertScaleAbs(img_gamma_8bit, alpha=1.2, beta=10) # Apply very light sharpening - kernel: ImageType = np.array([ - [-0.2, -0.2, -0.2], - [-0.2, 2.8, -0.2], - [-0.2, -0.2, -0.2], - ]) + kernel: ImageType = np.array([[-0.2, -0.2, -0.2], [-0.2, 2.8, -0.2], [-0.2, -0.2, -0.2]]) enhanced = cv2.filter2D(enhanced, -1, kernel) # Encode the enhanced image to WebP @@ -1240,10 +1133,7 @@ async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) -> @client.tree.context_menu(name="Enhance Image") @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) -async def enhance_image_command( - interaction: discord.Interaction, - message: discord.Message, -) -> None: +async def enhance_image_command(interaction: discord.Interaction, message: discord.Message) -> None: """Context menu command to enhance an image in a message.""" await interaction.response.defer() @@ -1260,9 +1150,7 @@ async def enhance_image_command( logger.exception("Failed to read attachment %s", attachment.url) if not images: - await interaction.followup.send( - f"No images found in the message: \n{message.content=}", - ) + await interaction.followup.send(f"No images found in the message: \n{message.content=}") return for image in images: @@ -1275,18 +1163,9 @@ async def enhance_image_command( ) # Prepare files - file1 = discord.File( - fp=io.BytesIO(enhanced_image1), - filename=f"enhanced1-{timestamp}.webp", - ) - file2 = discord.File( - fp=io.BytesIO(enhanced_image2), - filename=f"enhanced2-{timestamp}.webp", - ) - file3 = discord.File( - fp=io.BytesIO(enhanced_image3), - filename=f"enhanced3-{timestamp}.webp", - ) + file1 = discord.File(fp=io.BytesIO(enhanced_image1), filename=f"enhanced1-{timestamp}.webp") + file2 = discord.File(fp=io.BytesIO(enhanced_image2), filename=f"enhanced2-{timestamp}.webp") + file3 = discord.File(fp=io.BytesIO(enhanced_image3), filename=f"enhanced3-{timestamp}.webp") files: list[discord.File] = [file1, file2, file3] diff --git a/pyproject.toml b/pyproject.toml index 89bb4ad..a5686b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,25 +18,16 @@ dependencies = [ "sentry-sdk", ] -[dependency-groups] -dev = ["pytest", "ruff"] - [tool.ruff] -fix = true preview = true +fix = true unsafe-fixes = true - -format.docstring-code-format = true -format.preview = true - -lint.future-annotations = true -lint.isort.force-single-line = true -lint.pycodestyle.ignore-overlong-task-comments = true -lint.pydocstyle.convention = "google" lint.select = ["ALL"] +lint.fixable = ["ALL"] +lint.pydocstyle.convention = "google" +lint.isort.required-imports = ["from __future__ import annotations"] +lint.pycodestyle.ignore-overlong-task-comments = true -# Don't automatically remove unused variables -lint.unfixable = ["F841"] lint.ignore = [ "CPY001", # Checks for the absence of copyright notices within Python files. "D100", # Checks for undocumented public module definitions. @@ -62,10 +53,15 @@ lint.ignore = [ "Q003", # Checks for strings that include escaped quotes, and suggests changing the quote style to avoid the need to escape them. "W191", # Checks for indentation that uses tabs. ] +line-length = 160 +[tool.ruff.format] +docstring-code-format = true +docstring-code-line-length = 20 + [tool.ruff.lint.per-file-ignores] -"**/test_*.py" = [ +"**/*_test.py" = [ "ARG", # Unused function args -> fixtures nevertheless are functionally relevant... "FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize() "PLR2004", # Magic value used in comparison, ... @@ -81,3 +77,9 @@ log_cli_level = "INFO" log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)" log_cli_date_format = "%Y-%m-%d %H:%M:%S" python_files = "test_*.py *_test.py *_tests.py" + +[dependency-groups] +dev = [ + "pytest>=9.0.1", + "ruff>=0.14.7", +] diff --git a/tests/test_reset_undo.py b/reset_undo_test.py similarity index 96% rename from tests/test_reset_undo.py rename to reset_undo_test.py index 1c82d47..1a90956 100644 --- a/tests/test_reset_undo.py +++ b/reset_undo_test.py @@ -2,13 +2,15 @@ from __future__ import annotations import pytest -from main import add_message_to_memory -from main import last_trigger_time -from main import recent_messages -from main import reset_memory -from main import reset_snapshots -from main import undo_reset -from main import update_trigger_time +from main import ( + add_message_to_memory, + last_trigger_time, + recent_messages, + reset_memory, + reset_snapshots, + undo_reset, + update_trigger_time, +) @pytest.fixture(autouse=True) diff --git a/systemd/anewdawn.env.example b/systemd/anewdawn.env.example deleted file mode 100644 index 3e8a586..0000000 --- a/systemd/anewdawn.env.example +++ /dev/null @@ -1,11 +0,0 @@ -# Copy this file to /etc/ANewDawn/ANewDawn.env and fill in the required values. -# Make sure the directory is owned by the user running the service (e.g., "lovinator"). - -# Discord bot token -DISCORD_TOKEN= - -# OpenAI token (for GPT-5 and other OpenAI models) -OPENAI_TOKEN= - -# Optional: additional env vars used by your bot -# MY_CUSTOM_VAR= diff --git a/systemd/anewdawn.service b/systemd/anewdawn.service deleted file mode 100644 index 6ec9e9a..0000000 --- a/systemd/anewdawn.service +++ /dev/null @@ -1,28 +0,0 @@ -[Unit] -Description=ANewDawn Discord Bot -After=network.target - -[Service] -Type=simple -# Run the bot as the lovinator user (UID 1000) so it has appropriate permissions. -# Update these values if you need a different system user/group. -User=lovinator -Group=lovinator - -# The project directory containing main.py (update as needed). -WorkingDirectory=/home/lovinator/Code/ANewDawn - -# Load environment variables (see systemd/anewdawn.env.example). -EnvironmentFile=/etc/ANewDawn/ANewDawn.env - -# Use the python interpreter from your environment (system python is fine if dependencies are installed). -ExecStart=/usr/bin/env python3 main.py - -Restart=on-failure -RestartSec=5 - -StandardOutput=journal -StandardError=journal - -[Install] -WantedBy=multi-user.target diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000