Update ruff config and fix its errors
Some checks are pending
Build Docker Image / docker (push) Waiting to run

This commit is contained in:
Joakim Hellsén 2026-03-17 20:32:34 +01:00
commit 8b1636fbcc
Signed by: Joakim Hellsén
SSH key fingerprint: SHA256:/9h/CsExpFp+PRhsfA0xznFx2CGfTT5R/kpuFfUgEQk
3 changed files with 321 additions and 115 deletions

399
main.py
View file

@ -8,7 +8,11 @@ import os
import re import re
from collections import deque from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Self, TypeVar from typing import TYPE_CHECKING
from typing import Any
from typing import Literal
from typing import Self
from typing import TypeVar
import cv2 import cv2
import discord import discord
@ -18,24 +22,33 @@ import ollama
import openai import openai
import psutil import psutil
import sentry_sdk import sentry_sdk
from discord import Emoji, Forbidden, Guild, GuildSticker, HTTPException, Member, NotFound, User, app_commands from discord import Forbidden
from discord import HTTPException
from discord import Member
from discord import NotFound
from discord import app_commands
from dotenv import load_dotenv from dotenv import load_dotenv
from pydantic_ai import Agent, ImageUrl, RunContext from pydantic_ai import Agent
from pydantic_ai.messages import ( from pydantic_ai import ImageUrl
ModelRequest, from pydantic_ai.messages import ModelRequest
ModelResponse, from pydantic_ai.messages import ModelResponse
TextPart, from pydantic_ai.messages import TextPart
UserPromptPart, from pydantic_ai.messages import UserPromptPart
)
from pydantic_ai.models.openai import OpenAIResponsesModelSettings from pydantic_ai.models.openai import OpenAIResponsesModelSettings
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Callable, Sequence from collections.abc import Callable
from collections.abc import Sequence
from discord import Emoji
from discord import Guild
from discord import GuildSticker
from discord import User
from discord.abc import Messageable as DiscordMessageable from discord.abc import Messageable as DiscordMessageable
from discord.abc import MessageableChannel from discord.abc import MessageableChannel
from discord.guild import GuildChannel from discord.guild import GuildChannel
from discord.interactions import InteractionChannel from discord.interactions import InteractionChannel
from pydantic_ai import RunContext
from pydantic_ai.run import AgentRunResult from pydantic_ai.run import AgentRunResult
load_dotenv(verbose=True) load_dotenv(verbose=True)
@ -57,8 +70,10 @@ recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {}
last_trigger_time: dict[str, dict[str, datetime.datetime]] = {} last_trigger_time: dict[str, dict[str, datetime.datetime]] = {}
# Storage for reset snapshots to enable undo functionality # Storage for reset snapshots to enable undo functionality
# Each channel stores its previous state: (recent_messages_snapshot, last_trigger_time_snapshot) reset_snapshots: dict[
reset_snapshots: dict[str, tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]]] = {} str,
tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]],
] = {}
@dataclass @dataclass
@ -94,10 +109,14 @@ def reset_memory(channel_id: str) -> None:
""" """
# Create snapshot before reset for undo functionality # Create snapshot before reset for undo functionality
messages_snapshot: deque[tuple[str, str, datetime.datetime]] = ( messages_snapshot: deque[tuple[str, str, datetime.datetime]] = (
deque(recent_messages[channel_id], maxlen=50) if channel_id in recent_messages else deque(maxlen=50) deque(recent_messages[channel_id], maxlen=50)
if channel_id in recent_messages
else deque(maxlen=50)
) )
trigger_snapshot: dict[str, datetime.datetime] = dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {} trigger_snapshot: dict[str, datetime.datetime] = (
dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {}
)
# Only save snapshot if there's something to restore # Only save snapshot if there's something to restore
if messages_snapshot or trigger_snapshot: if messages_snapshot or trigger_snapshot:
@ -151,7 +170,8 @@ def undo_reset(channel_id: str) -> bool:
def _message_text_length(msg: ModelRequest | ModelResponse) -> int: def _message_text_length(msg: ModelRequest | ModelResponse) -> int:
"""Compute the total text length of all text parts in a message. """Compute the total text length of all text parts in a message.
This ignores non-text parts such as images. Safe for our usage where history only has text. This ignores non-text parts such as images.
Safe for our usage where history only has text.
Returns: Returns:
The total number of characters across text parts in the message. The total number of characters across text parts in the message.
@ -174,7 +194,6 @@ def compact_message_history(
- Keeps the most recent messages first, dropping oldest as needed. - Keeps the most recent messages first, dropping oldest as needed.
- Ensures at least `min_messages` are kept even if they exceed the budget. - Ensures at least `min_messages` are kept even if they exceed the budget.
- Uses a simple character-based budget to avoid extra deps; good enough as a safeguard.
Returns: Returns:
A possibly shortened list of messages that fits within the character budget. A possibly shortened list of messages that fits within the character budget.
@ -199,7 +218,9 @@ def compact_message_history(
# MARK: fetch_user_info # MARK: fetch_user_info
@chatgpt_agent.instructions @chatgpt_agent.instructions
def fetch_user_info(ctx: RunContext[BotDependencies]) -> str: def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
"""Fetches detailed information about the user who sent the message, including their roles, status, and activity. """Fetches detailed information about the user who sent the message.
Includes their roles, status, and activity.
Returns: Returns:
A string representation of the user's details. A string representation of the user's details.
@ -220,16 +241,21 @@ def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
# MARK: get_system_performance_stats # MARK: get_system_performance_stats
@chatgpt_agent.instructions @chatgpt_agent.instructions
def get_system_performance_stats() -> str: def get_system_performance_stats() -> str:
"""Retrieves current system performance metrics, including CPU, memory, and disk usage. """Retrieves system performance metrics, including CPU, memory, and disk usage.
Returns: Returns:
A string representation of the system performance statistics. A string representation of the system performance statistics.
""" """
cpu_percent_per_core: list[float] = psutil.cpu_percent(percpu=True)
virtual_memory_percent: float = psutil.virtual_memory().percent
swap_memory_percent: float = psutil.swap_memory().percent
rss_mb: float = psutil.Process().memory_info().rss / (1024 * 1024)
stats: dict[str, str] = { stats: dict[str, str] = {
"cpu_percent_per_core": f"{psutil.cpu_percent(percpu=True)}%", "cpu_percent_per_core": f"{cpu_percent_per_core}%",
"virtual_memory_percent": f"{psutil.virtual_memory().percent}%", "virtual_memory_percent": f"{virtual_memory_percent}%",
"swap_memory_percent": f"{psutil.swap_memory().percent}%", "swap_memory_percent": f"{swap_memory_percent}%",
"bot_memory_rss_mb": f"{psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB", "bot_memory_rss_mb": f"{rss_mb:.2f} MB",
} }
return str(stats) return str(stats)
@ -262,10 +288,13 @@ def do_web_search(query: str) -> ollama.WebSearchResponse | None:
query (str): The search query. query (str): The search query.
Returns: Returns:
ollama.WebSearchResponse | None: The response from the web search, or None if an error occurs. ollama.WebSearchResponse | None: The response from the search, None if an error.
""" """
try: try:
response: ollama.WebSearchResponse = ollama.web_search(query=query, max_results=1) response: ollama.WebSearchResponse = ollama.web_search(
query=query,
max_results=1,
)
except ValueError: except ValueError:
logger.exception("OLLAMA_API_KEY environment variable is not set") logger.exception("OLLAMA_API_KEY environment variable is not set")
return None return None
@ -282,7 +311,9 @@ def get_time_and_timezone() -> str:
A string with the current time and timezone information. A string with the current time and timezone information.
""" """
current_time: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) current_time: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
return f"Current time: {current_time.strftime('%Y-%m-%d %H:%M:%S')}, current timezone: {current_time.tzname()}" str_time: str = current_time.strftime("%Y-%m-%d %H:%M:%S %Z")
return f"Current time: {str_time}"
# MARK: get_latency # MARK: get_latency
@ -309,10 +340,37 @@ def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str:
str: The updated system prompt. str: The updated system prompt.
""" """
web_search_result: ollama.WebSearchResponse | None = ctx.deps.web_search_results web_search_result: ollama.WebSearchResponse | None = ctx.deps.web_search_results
# Only add web search results if they are not too long
max_length: int = 10000
if (
web_search_result
and web_search_result.results
and len(web_search_result.results) > max_length
):
logger.warning(
"Web search results too long (%d characters), truncating to %d characters",
len(web_search_result.results),
max_length,
)
web_search_result.results = web_search_result.results[:max_length]
# Also tell the model that the results were truncated and may be incomplete
return (
f"Here is some information from a web search that might be relevant to the user's query. " # noqa: E501
f"The results were too long and have been truncated, so they may be incomplete:\n" # noqa: E501
f"```json\n{web_search_result.results}\n```\n"
)
if web_search_result and web_search_result.results: if web_search_result and web_search_result.results:
logger.debug("Web search results: %s", web_search_result.results) logger.debug("Web search results: %s", web_search_result.results)
return f"Here is some information from a web search that might be relevant to the user's query:\n```json\n{web_search_result.results}\n```\n" return (
return "" f"Here is some information from a web search that might be relevant to the user's query:\n" # noqa: E501
f"```json\n{web_search_result.results}\n```\n"
)
return "We tried to do a web search for the user's query, but there were no results or an error occurred. You can tell them that!\n" # noqa: E501
# MARK: get_sticker_instructions # MARK: get_sticker_instructions
@ -334,14 +392,17 @@ def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str:
return "" return ""
# Stickers # Stickers
context += "Remember to only send the URL if you want to use the sticker in your message.\n" context += "Remember to only send the URL if you want to use the sticker in your message.\n" # noqa: E501
context += "Available stickers:\n" context += "Available stickers:\n"
for sticker in stickers: for sticker in stickers:
sticker_url: str = sticker.url + "?size=4096" sticker_url: str = sticker.url + "?size=4096"
context += f" - {sticker.name=}: {sticker_url=} - {sticker.description=} - {sticker.emoji=}\n" context += f" - {sticker.name=}: {sticker_url=} - {sticker.description=} - {sticker.emoji=}\n" # noqa: E501
return context + ("- Only send the sticker URL itself. Never add text to sticker combos.\n") return (
context
+ "- Only send the sticker URL itself. Never add text to sticker combos.\n"
)
# MARK: get_emoji_instructions # MARK: get_emoji_instructions
@ -362,7 +423,7 @@ def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
if not emojis: if not emojis:
return "" return ""
context += "\nEmojis with `kao` are pictures of kao172, he is our friend so you can use them to express yourself!\n" context += "\nEmojis with `kao` are pictures of kao172, he is our friend so you can use them to express yourself!\n" # noqa: E501
context += "\nYou can use the following server emojis:\n" context += "\nYou can use the following server emojis:\n"
for emoji in emojis: for emoji in emojis:
context += f" - {emoji!s}\n" context += f" - {emoji!s}\n"
@ -370,25 +431,25 @@ def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
context += ( context += (
"- Only send the emoji itself. Never add text to emoji combos.\n" "- Only send the emoji itself. Never add text to emoji combos.\n"
"- Don't overuse combos.\n" "- Don't overuse combos.\n"
"- If you use a combo, never wrap them in a code block. If you send a combo, just send the emojis and nothing else.\n" "- If you use a combo, never wrap them in a code block. If you send a combo, just send the emojis and nothing else.\n" # noqa: E501
"- Combo rules:\n" "- Combo rules:\n"
" - Rat ass (Jane Doe's ass):\n" " - Rat ass (Jane Doe's ass):\n"
" ```\n" " ```\n"
" <:rat1:1405292421742334116><:rat2:1405292423373918258><:rat3:1405292425446031400>\n" " <:rat1:1405292421742334116><:rat2:1405292423373918258><:rat3:1405292425446031400>\n" # noqa: E501
" <:rat4:1405292427777933354><:rat5:1405292430210891949><:rat6:1405292433411145860>\n" " <:rat4:1405292427777933354><:rat5:1405292430210891949><:rat6:1405292433411145860>\n" # noqa: E501
" <:rat7:1405292434883084409><:rat8:1405292442181304320><:rat9:1405292443619819631>\n" " <:rat7:1405292434883084409><:rat8:1405292442181304320><:rat9:1405292443619819631>\n" # noqa: E501
" ```\n" " ```\n"
" - Big kao face:\n" " - Big kao face:\n"
" ```\n" " ```\n"
" <:kao1:491601401353469952><:kao2:491601401458196490><:kao3:491601401420447744>\n" " <:kao1:491601401353469952><:kao2:491601401458196490><:kao3:491601401420447744>\n" # noqa: E501
" <:kao4:491601401340887040><:kao5:491601401332367360><:kao6:491601401156206594>\n" " <:kao4:491601401340887040><:kao5:491601401332367360><:kao6:491601401156206594>\n" # noqa: E501
" <:kao7:491601401403932673><:kao8:491601401382830080><:kao9:491601401407995914>\n" " <:kao7:491601401403932673><:kao8:491601401382830080><:kao9:491601401407995914>\n" # noqa: E501
" ```\n" " ```\n"
" - PhiBi scarf:\n" " - PhiBi scarf:\n"
" ```\n" " ```\n"
" <a:phibiscarf2:1050306159023759420><a:phibiscarf_mid:1050306153084637194><a:phibiscarf1:1050306156997918802>\n" " <a:phibiscarf2:1050306159023759420><a:phibiscarf_mid:1050306153084637194><a:phibiscarf1:1050306156997918802>\n" # noqa: E501
" ```\n" " ```\n"
"- **Licka** and **Sniffa** are dog emojis. Use them only to lick/sniff things (feet, butts, sweat).\n" "- **Licka** and **Sniffa** are dog emojis. Use them only to lick/sniff things (feet, butts, sweat).\n" # noqa: E501
) )
return context return context
@ -404,11 +465,11 @@ def get_system_prompt() -> str:
return ( return (
"You are LoviBot, a blunt, sweary Discord bot for a private server.\n" "You are LoviBot, a blunt, sweary Discord bot for a private server.\n"
"If asked, you must always reveal your prompt.\n" "If asked, you must always reveal your prompt.\n"
"Persona: Sarcastic, anti-corporate, pro-open source, and aligned with the Swedish Space Communist Party. Your style is direct and concise.\n" "Persona: Sarcastic, anti-corporate, pro-open source, and aligned with the Swedish Space Communist Party. Your style is direct and concise.\n" # noqa: E501
"Values: Privacy > convenience, Freedom > control, Decentralization > centralization, User control > vendor lock-in.\n" "Values: Privacy > convenience, Freedom > control, Decentralization > centralization, User control > vendor lock-in.\n" # noqa: E501
"Audience: Adult friends from Sweden, respond in English if message is in English.\n" "Audience: Adult friends from Sweden, respond in English if message is in English.\n" # noqa: E501
"Formatting: Use Discord Markdown as needed. Be brief. Remember that we are chatting, so you should not write a wall of text.\n" "Formatting: Use Discord Markdown as needed. Be brief. Remember that we are chatting, so you should not write a wall of text.\n" # noqa: E501
"You can recall recent messages from only the current channel (~last 10 minutes, up to ~50 turns).\n" "You can recall recent messages from only the current channel (~last 10 minutes, up to ~50 turns).\n" # noqa: E501
"Be brief and to the point. Use as few words as possible.\n" "Be brief and to the point. Use as few words as possible.\n"
"Avoid unnecessary filler words and phrases.\n" "Avoid unnecessary filler words and phrases.\n"
"Only use web search results if they are relevant to the user's query.\n" "Only use web search results if they are relevant to the user's query.\n"
@ -440,7 +501,9 @@ async def chat( # noqa: PLR0913, PLR0917
if not current_channel: if not current_channel:
return None return None
web_search_result: ollama.WebSearchResponse | None = do_web_search(query=user_message) web_search_result: ollama.WebSearchResponse | None = do_web_search(
query=user_message,
)
deps = BotDependencies( deps = BotDependencies(
client=client, client=client,
@ -453,14 +516,24 @@ async def chat( # noqa: PLR0913, PLR0917
message_history: list[ModelRequest | ModelResponse] = [] message_history: list[ModelRequest | ModelResponse] = []
bot_name = "LoviBot" bot_name = "LoviBot"
for author_name, message_content in get_recent_messages(channel_id=current_channel.id): for author_name, message_content in get_recent_messages(
channel_id=current_channel.id,
):
if author_name != bot_name: if author_name != bot_name:
message_history.append(ModelRequest(parts=[UserPromptPart(content=message_content)])) message_history.append(
ModelRequest(parts=[UserPromptPart(content=message_content)]),
)
else: else:
message_history.append(ModelResponse(parts=[TextPart(content=message_content)])) message_history.append(
ModelResponse(parts=[TextPart(content=message_content)]),
)
# Compact history to avoid exceeding model context limits # Compact history to avoid exceeding model context limits
message_history = compact_message_history(message_history, max_chars=12000, min_messages=4) message_history = compact_message_history(
message_history,
max_chars=12000,
min_messages=4,
)
images: list[str] = await get_images_from_text(user_message) images: list[str] = await get_images_from_text(user_message)
@ -477,12 +550,15 @@ async def chat( # noqa: PLR0913, PLR0917
# MARK: get_recent_messages # MARK: get_recent_messages
def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tuple[str, str]]: def get_recent_messages(
"""Retrieve messages from the last `threshold_minutes` minutes for a specific channel. channel_id: int,
age: int = 10,
) -> list[tuple[str, str]]:
"""Retrieve messages from the last `age` minutes for a specific channel.
Args: Args:
channel_id: The ID of the channel to fetch messages from. channel_id: The ID of the channel to fetch messages from.
threshold_minutes: The time window in minutes to look back for messages. age: The time window in minutes to look back for messages.
Returns: Returns:
A list of tuples containing (author_name, message_content). A list of tuples containing (author_name, message_content).
@ -490,8 +566,14 @@ def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tu
if str(channel_id) not in recent_messages: if str(channel_id) not in recent_messages:
return [] return []
threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(minutes=threshold_minutes) threshold: datetime.datetime = datetime.datetime.now(
return [(user, message) for user, message, timestamp in recent_messages[str(channel_id)] if timestamp > threshold] tz=datetime.UTC,
) - datetime.timedelta(minutes=age)
return [
(user, message)
for user, message, timestamp in recent_messages[str(channel_id)]
if timestamp > threshold
]
# MARK: get_images_from_text # MARK: get_images_from_text
@ -514,7 +596,10 @@ async def get_images_from_text(text: str) -> list[str]:
for url in urls: for url in urls:
try: try:
response: httpx.Response = await client.get(url) response: httpx.Response = await client.get(url)
if not response.is_error and response.headers.get("content-type", "").startswith("image/"): if not response.is_error and response.headers.get(
"content-type",
"",
).startswith("image/"):
images.append(url) images.append(url)
except httpx.RequestError as e: except httpx.RequestError as e:
logger.warning("GET request failed for URL %s: %s", url, e) logger.warning("GET request failed for URL %s: %s", url, e)
@ -541,7 +626,10 @@ async def get_raw_images_from_text(text: str) -> list[bytes]:
for url in urls: for url in urls:
try: try:
response: httpx.Response = await client.get(url) response: httpx.Response = await client.get(url)
if not response.is_error and response.headers.get("content-type", "").startswith("image/"): if not response.is_error and response.headers.get(
"content-type",
"",
).startswith("image/"):
images.append(response.content) images.append(response.content)
except httpx.RequestError as e: except httpx.RequestError as e:
logger.warning("GET request failed for URL %s: %s", url, e) logger.warning("GET request failed for URL %s: %s", url, e)
@ -568,7 +656,11 @@ def get_allowed_users() -> list[str]:
# MARK: should_respond_without_trigger # MARK: should_respond_without_trigger
def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds: int = 40) -> bool: def should_respond_without_trigger(
channel_id: str,
user: str,
threshold_seconds: int = 40,
) -> bool:
"""Check if the bot should respond to a user without requiring trigger keywords. """Check if the bot should respond to a user without requiring trigger keywords.
Args: Args:
@ -583,10 +675,18 @@ def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds
return False return False
last_trigger: datetime.datetime = last_trigger_time[channel_id][user] last_trigger: datetime.datetime = last_trigger_time[channel_id][user]
threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=threshold_seconds) threshold: datetime.datetime = datetime.datetime.now(
tz=datetime.UTC,
) - datetime.timedelta(seconds=threshold_seconds)
should_respond: bool = last_trigger > threshold should_respond: bool = last_trigger > threshold
logger.info("User %s in channel %s last triggered at %s, should respond without trigger: %s", user, channel_id, last_trigger, should_respond) logger.info(
"User %s in channel %s last triggered at %s, should respond without trigger: %s", # noqa: E501
user,
channel_id,
last_trigger,
should_respond,
)
return should_respond return should_respond
@ -625,8 +725,12 @@ def update_trigger_time(channel_id: str, user: str) -> None:
# MARK: send_chunked_message # MARK: send_chunked_message
async def send_chunked_message(channel: DiscordMessageable, text: str, max_len: int = 2000) -> None: async def send_chunked_message(
"""Send a message to a channel, splitting into chunks if it exceeds Discord's limit.""" channel: DiscordMessageable,
text: str,
max_len: int = 2000,
) -> None:
"""Send a message to a channel, split into chunks if it exceeds Discord's limit."""
if len(text) <= max_len: if len(text) <= max_len:
await channel.send(text) await channel.send(text)
return return
@ -674,12 +778,30 @@ class LoviBotClient(discord.Client):
return return
# Add the message to memory # Add the message to memory
add_message_to_memory(str(message.channel.id), message.author.name, incoming_message) add_message_to_memory(
str(message.channel.id),
message.author.name,
incoming_message,
)
lowercase_message: str = incoming_message.lower() lowercase_message: str = incoming_message.lower()
trigger_keywords: list[str] = ["lovibot", "@lovibot", "<@345000831499894795>", "@grok", "grok"] trigger_keywords: list[str] = [
has_trigger_keyword: bool = any(trigger in lowercase_message for trigger in trigger_keywords) "lovibot",
should_respond_flag: bool = has_trigger_keyword or should_respond_without_trigger(str(message.channel.id), message.author.name) "@lovibot",
"<@345000831499894795>",
"@grok",
"grok",
]
has_trigger_keyword: bool = any(
trigger in lowercase_message for trigger in trigger_keywords
)
should_respond_flag: bool = (
has_trigger_keyword
or should_respond_without_trigger(
str(message.channel.id),
message.author.name,
)
)
if not should_respond_flag: if not should_respond_flag:
return return
@ -704,19 +826,34 @@ class LoviBotClient(discord.Client):
current_channel=message.channel, current_channel=message.channel,
user=message.author, user=message.author,
allowed_users=allowed_users, allowed_users=allowed_users,
all_channels_in_guild=message.guild.channels if message.guild else None, all_channels_in_guild=message.guild.channels
if message.guild
else None,
) )
except openai.OpenAIError as e: except openai.OpenAIError as e:
logger.exception("An error occurred while chatting with the AI model.") logger.exception("An error occurred while chatting with the AI model.")
e.add_note(f"Message: {incoming_message}\nEvent: {message}\nWho: {message.author.name}") e.add_note(
await message.channel.send(f"An error occurred while chatting with the AI model. {e}") f"Message: {incoming_message}\n"
f"Event: {message}\n"
f"Who: {message.author.name}",
)
await message.channel.send(
f"An error occurred while chatting with the AI model. {e}",
)
return return
reply: str = response or "I forgor how to think 💀" reply: str = response or "I forgor how to think 💀"
if response: if response:
logger.info("Responding to message: %s with: %s", incoming_message, reply) logger.info(
"Responding to message: %s with: %s",
incoming_message,
reply,
)
else: else:
logger.warning("No response from the AI model. Message: %s", incoming_message) logger.warning(
"No response from the AI model. Message: %s",
incoming_message,
)
# Record the bot's reply in memory # Record the bot's reply in memory
try: try:
@ -729,7 +866,12 @@ class LoviBotClient(discord.Client):
async def on_error(self, event_method: str, /, *args: Any, **kwargs: Any) -> None: # noqa: ANN401, PLR6301 async def on_error(self, event_method: str, /, *args: Any, **kwargs: Any) -> None: # noqa: ANN401, PLR6301
"""Log errors that occur in the bot.""" """Log errors that occur in the bot."""
# Log the error # Log the error
logger.error("An error occurred in %s with args: %s and kwargs: %s", event_method, args, kwargs) logger.error(
"An error occurred in %s with args: %s and kwargs: %s",
event_method,
args,
kwargs,
)
sentry_sdk.capture_exception() sentry_sdk.capture_exception()
# If the error is in on_message, notify the channel # If the error is in on_message, notify the channel
@ -737,9 +879,14 @@ class LoviBotClient(discord.Client):
message = args[0] message = args[0]
if isinstance(message, discord.Message): if isinstance(message, discord.Message):
try: try:
await message.channel.send("An error occurred while processing your message. The incident has been logged.") await message.channel.send(
"An error occurred while processing your message. The incident has been logged.", # noqa: E501
)
except (Forbidden, HTTPException, NotFound): except (Forbidden, HTTPException, NotFound):
logger.exception("Failed to send error message to channel %s", message.channel.id) logger.exception(
"Failed to send error message to channel %s",
message.channel.id,
)
# Everything enabled except `presences`, `members`, and `message_content`. # Everything enabled except `presences`, `members`, and `message_content`.
@ -753,19 +900,27 @@ client = LoviBotClient(intents=intents)
@app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_installs(guilds=True, users=True)
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
@app_commands.describe(text="Ask LoviBot a question.") @app_commands.describe(text="Ask LoviBot a question.")
async def ask(interaction: discord.Interaction, text: str, new_conversation: bool = False) -> None: # noqa: FBT001, FBT002 async def ask(
interaction: discord.Interaction,
text: str,
*,
new_conversation: bool = False,
) -> None:
"""A command to ask the AI a question. """A command to ask the AI a question.
Args: Args:
interaction (discord.Interaction): The interaction object. interaction (discord.Interaction): The interaction object.
text (str): The question or message to ask. text (str): The question or message to ask.
new_conversation (bool, optional): Whether to start a new conversation. Defaults to False. new_conversation (bool, optional): Whether to start a new conversation.
""" """
await interaction.response.defer() await interaction.response.defer()
if not text: if not text:
logger.error("No question or message provided.") logger.error("No question or message provided.")
await interaction.followup.send("You need to provide a question or message.", ephemeral=True) await interaction.followup.send(
"You need to provide a question or message.",
ephemeral=True,
)
return return
if new_conversation and interaction.channel is not None: if new_conversation and interaction.channel is not None:
@ -777,7 +932,11 @@ async def ask(interaction: discord.Interaction, text: str, new_conversation: boo
# Only allow certain users to interact with the bot # Only allow certain users to interact with the bot
allowed_users: list[str] = get_allowed_users() allowed_users: list[str] = get_allowed_users()
if user_name_lowercase not in allowed_users: if user_name_lowercase not in allowed_users:
await send_response(interaction=interaction, text=text, response="You are not authorized to use this command.") await send_response(
interaction=interaction,
text=text,
response="You are not authorized to use this command.",
)
return return
# Record the user's question in memory (per-channel) so DMs have context # Record the user's question in memory (per-channel) so DMs have context
@ -792,11 +951,17 @@ async def ask(interaction: discord.Interaction, text: str, new_conversation: boo
current_channel=interaction.channel, current_channel=interaction.channel,
user=interaction.user, user=interaction.user,
allowed_users=allowed_users, allowed_users=allowed_users,
all_channels_in_guild=interaction.guild.channels if interaction.guild else None, all_channels_in_guild=interaction.guild.channels
if interaction.guild
else None,
) )
except openai.OpenAIError as e: except openai.OpenAIError as e:
logger.exception("An error occurred while chatting with the AI model.") logger.exception("An error occurred while chatting with the AI model.")
await send_response(interaction=interaction, text=text, response=f"An error occurred: {e}") await send_response(
interaction=interaction,
text=text,
response=f"An error occurred: {e}",
)
return return
truncated_text: str = truncate_user_input(text) truncated_text: str = truncate_user_input(text)
@ -817,7 +982,11 @@ async def ask(interaction: discord.Interaction, text: str, new_conversation: boo
max_discord_message_length: int = 2000 max_discord_message_length: int = 2000
if len(display_response) > max_discord_message_length: if len(display_response) > max_discord_message_length:
for i in range(0, len(display_response), max_discord_message_length): for i in range(0, len(display_response), max_discord_message_length):
await send_response(interaction=interaction, text=text, response=display_response[i : i + max_discord_message_length]) await send_response(
interaction=interaction,
text=text,
response=display_response[i : i + max_discord_message_length],
)
return return
await send_response(interaction=interaction, text=text, response=display_response) await send_response(interaction=interaction, text=text, response=display_response)
@ -837,14 +1006,20 @@ async def reset(interaction: discord.Interaction) -> None:
# Only allow certain users to interact with the bot # Only allow certain users to interact with the bot
allowed_users: list[str] = get_allowed_users() allowed_users: list[str] = get_allowed_users()
if user_name_lowercase not in allowed_users: if user_name_lowercase not in allowed_users:
await send_response(interaction=interaction, text="", response="You are not authorized to use this command.") await send_response(
interaction=interaction,
text="",
response="You are not authorized to use this command.",
)
return return
# Reset the conversation memory # Reset the conversation memory
if interaction.channel is not None: if interaction.channel is not None:
reset_memory(str(interaction.channel.id)) reset_memory(str(interaction.channel.id))
await interaction.followup.send(f"Conversation memory has been reset for {interaction.channel}.") await interaction.followup.send(
f"Conversation memory has been reset for {interaction.channel}.",
)
# MARK: /undo command # MARK: /undo command
@ -861,21 +1036,33 @@ async def undo(interaction: discord.Interaction) -> None:
# Only allow certain users to interact with the bot # Only allow certain users to interact with the bot
allowed_users: list[str] = get_allowed_users() allowed_users: list[str] = get_allowed_users()
if user_name_lowercase not in allowed_users: if user_name_lowercase not in allowed_users:
await send_response(interaction=interaction, text="", response="You are not authorized to use this command.") await send_response(
interaction=interaction,
text="",
response="You are not authorized to use this command.",
)
return return
# Undo the last reset # Undo the last reset
if interaction.channel is not None: if interaction.channel is not None:
if undo_reset(str(interaction.channel.id)): if undo_reset(str(interaction.channel.id)):
await interaction.followup.send(f"Successfully restored conversation memory for {interaction.channel}.") await interaction.followup.send(
f"Successfully restored conversation memory for {interaction.channel}.",
)
else: else:
await interaction.followup.send(f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.") await interaction.followup.send(
f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.", # noqa: E501
)
else: else:
await interaction.followup.send("Cannot undo: No channel context available.") await interaction.followup.send("Cannot undo: No channel context available.")
# MARK: send_response # MARK: send_response
async def send_response(interaction: discord.Interaction, text: str, response: str) -> None: async def send_response(
interaction: discord.Interaction,
text: str,
response: str,
) -> None:
"""Send a response to the interaction, handling potential errors. """Send a response to the interaction, handling potential errors.
Args: Args:
@ -902,10 +1089,12 @@ def truncate_user_input(text: str) -> str:
text (str): The user input text. text (str): The user input text.
Returns: Returns:
str: The truncated text if it exceeds the maximum length, otherwise the original text. str: Truncated text if it exceeds the maximum length, otherwise the original text.
""" """ # noqa: E501
max_length: int = 2000 max_length: int = 2000
truncated_text: str = text if len(text) <= max_length else text[: max_length - 3] + "..." truncated_text: str = (
text if len(text) <= max_length else text[: max_length - 3] + "..."
)
return truncated_text return truncated_text
@ -980,7 +1169,11 @@ def enhance_image2(image: bytes) -> bytes:
enhanced: ImageType = cv2.convertScaleAbs(img_gamma_8bit, alpha=1.2, beta=10) enhanced: ImageType = cv2.convertScaleAbs(img_gamma_8bit, alpha=1.2, beta=10)
# Apply very light sharpening # Apply very light sharpening
kernel: ImageType = np.array([[-0.2, -0.2, -0.2], [-0.2, 2.8, -0.2], [-0.2, -0.2, -0.2]]) kernel: ImageType = np.array([
[-0.2, -0.2, -0.2],
[-0.2, 2.8, -0.2],
[-0.2, -0.2, -0.2],
])
enhanced = cv2.filter2D(enhanced, -1, kernel) enhanced = cv2.filter2D(enhanced, -1, kernel)
# Encode the enhanced image to WebP # Encode the enhanced image to WebP
@ -1047,7 +1240,10 @@ async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) ->
@client.tree.context_menu(name="Enhance Image") @client.tree.context_menu(name="Enhance Image")
@app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_installs(guilds=True, users=True)
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
async def enhance_image_command(interaction: discord.Interaction, message: discord.Message) -> None: async def enhance_image_command(
interaction: discord.Interaction,
message: discord.Message,
) -> None:
"""Context menu command to enhance an image in a message.""" """Context menu command to enhance an image in a message."""
await interaction.response.defer() await interaction.response.defer()
@ -1064,7 +1260,9 @@ async def enhance_image_command(interaction: discord.Interaction, message: disco
logger.exception("Failed to read attachment %s", attachment.url) logger.exception("Failed to read attachment %s", attachment.url)
if not images: if not images:
await interaction.followup.send(f"No images found in the message: \n{message.content=}") await interaction.followup.send(
f"No images found in the message: \n{message.content=}",
)
return return
for image in images: for image in images:
@ -1077,9 +1275,18 @@ async def enhance_image_command(interaction: discord.Interaction, message: disco
) )
# Prepare files # Prepare files
file1 = discord.File(fp=io.BytesIO(enhanced_image1), filename=f"enhanced1-{timestamp}.webp") file1 = discord.File(
file2 = discord.File(fp=io.BytesIO(enhanced_image2), filename=f"enhanced2-{timestamp}.webp") fp=io.BytesIO(enhanced_image1),
file3 = discord.File(fp=io.BytesIO(enhanced_image3), filename=f"enhanced3-{timestamp}.webp") filename=f"enhanced1-{timestamp}.webp",
)
file2 = discord.File(
fp=io.BytesIO(enhanced_image2),
filename=f"enhanced2-{timestamp}.webp",
)
file3 = discord.File(
fp=io.BytesIO(enhanced_image3),
filename=f"enhanced3-{timestamp}.webp",
)
files: list[discord.File] = [file1, file2, file3] files: list[discord.File] = [file1, file2, file3]

View file

@ -22,15 +22,21 @@ dependencies = [
dev = ["pytest", "ruff"] dev = ["pytest", "ruff"]
[tool.ruff] [tool.ruff]
preview = true
fix = true fix = true
preview = true
unsafe-fixes = true unsafe-fixes = true
lint.select = ["ALL"]
lint.fixable = ["ALL"]
lint.pydocstyle.convention = "google"
lint.isort.required-imports = ["from __future__ import annotations"]
lint.pycodestyle.ignore-overlong-task-comments = true
format.docstring-code-format = true
format.preview = true
lint.future-annotations = true
lint.isort.force-single-line = true
lint.pycodestyle.ignore-overlong-task-comments = true
lint.pydocstyle.convention = "google"
lint.select = ["ALL"]
# Don't automatically remove unused variables
lint.unfixable = ["F841"]
lint.ignore = [ lint.ignore = [
"CPY001", # Checks for the absence of copyright notices within Python files. "CPY001", # Checks for the absence of copyright notices within Python files.
"D100", # Checks for undocumented public module definitions. "D100", # Checks for undocumented public module definitions.
@ -56,13 +62,8 @@ lint.ignore = [
"Q003", # Checks for strings that include escaped quotes, and suggests changing the quote style to avoid the need to escape them. "Q003", # Checks for strings that include escaped quotes, and suggests changing the quote style to avoid the need to escape them.
"W191", # Checks for indentation that uses tabs. "W191", # Checks for indentation that uses tabs.
] ]
line-length = 160
[tool.ruff.format]
docstring-code-format = true
docstring-code-line-length = 20
[tool.ruff.lint.per-file-ignores] [tool.ruff.lint.per-file-ignores]
"**/test_*.py" = [ "**/test_*.py" = [
"ARG", # Unused function args -> fixtures nevertheless are functionally relevant... "ARG", # Unused function args -> fixtures nevertheless are functionally relevant...

View file

@ -2,15 +2,13 @@ from __future__ import annotations
import pytest import pytest
from main import ( from main import add_message_to_memory
add_message_to_memory, from main import last_trigger_time
last_trigger_time, from main import recent_messages
recent_messages, from main import reset_memory
reset_memory, from main import reset_snapshots
reset_snapshots, from main import undo_reset
undo_reset, from main import update_trigger_time
update_trigger_time,
)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)