Update ruff config and fix its errors
Some checks are pending
Build Docker Image / docker (push) Waiting to run

This commit is contained in:
Joakim Hellsén 2026-03-17 20:32:34 +01:00
commit 8b1636fbcc
Signed by: Joakim Hellsén
SSH key fingerprint: SHA256:/9h/CsExpFp+PRhsfA0xznFx2CGfTT5R/kpuFfUgEQk
3 changed files with 321 additions and 115 deletions

399
main.py
View file

@ -8,7 +8,11 @@ import os
import re
from collections import deque
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Self, TypeVar
from typing import TYPE_CHECKING
from typing import Any
from typing import Literal
from typing import Self
from typing import TypeVar
import cv2
import discord
@ -18,24 +22,33 @@ import ollama
import openai
import psutil
import sentry_sdk
from discord import Emoji, Forbidden, Guild, GuildSticker, HTTPException, Member, NotFound, User, app_commands
from discord import Forbidden
from discord import HTTPException
from discord import Member
from discord import NotFound
from discord import app_commands
from dotenv import load_dotenv
from pydantic_ai import Agent, ImageUrl, RunContext
from pydantic_ai.messages import (
ModelRequest,
ModelResponse,
TextPart,
UserPromptPart,
)
from pydantic_ai import Agent
from pydantic_ai import ImageUrl
from pydantic_ai.messages import ModelRequest
from pydantic_ai.messages import ModelResponse
from pydantic_ai.messages import TextPart
from pydantic_ai.messages import UserPromptPart
from pydantic_ai.models.openai import OpenAIResponsesModelSettings
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from collections.abc import Callable
from collections.abc import Sequence
from discord import Emoji
from discord import Guild
from discord import GuildSticker
from discord import User
from discord.abc import Messageable as DiscordMessageable
from discord.abc import MessageableChannel
from discord.guild import GuildChannel
from discord.interactions import InteractionChannel
from pydantic_ai import RunContext
from pydantic_ai.run import AgentRunResult
load_dotenv(verbose=True)
@ -57,8 +70,10 @@ recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {}
last_trigger_time: dict[str, dict[str, datetime.datetime]] = {}
# Storage for reset snapshots to enable undo functionality
# Each channel stores its previous state: (recent_messages_snapshot, last_trigger_time_snapshot)
reset_snapshots: dict[str, tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]]] = {}
reset_snapshots: dict[
str,
tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]],
] = {}
@dataclass
@ -94,10 +109,14 @@ def reset_memory(channel_id: str) -> None:
"""
# Create snapshot before reset for undo functionality
messages_snapshot: deque[tuple[str, str, datetime.datetime]] = (
deque(recent_messages[channel_id], maxlen=50) if channel_id in recent_messages else deque(maxlen=50)
deque(recent_messages[channel_id], maxlen=50)
if channel_id in recent_messages
else deque(maxlen=50)
)
trigger_snapshot: dict[str, datetime.datetime] = dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {}
trigger_snapshot: dict[str, datetime.datetime] = (
dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {}
)
# Only save snapshot if there's something to restore
if messages_snapshot or trigger_snapshot:
@ -151,7 +170,8 @@ def undo_reset(channel_id: str) -> bool:
def _message_text_length(msg: ModelRequest | ModelResponse) -> int:
"""Compute the total text length of all text parts in a message.
This ignores non-text parts such as images. Safe for our usage where history only has text.
This ignores non-text parts such as images.
Safe for our usage where history only has text.
Returns:
The total number of characters across text parts in the message.
@ -174,7 +194,6 @@ def compact_message_history(
- Keeps the most recent messages first, dropping oldest as needed.
- Ensures at least `min_messages` are kept even if they exceed the budget.
- Uses a simple character-based budget to avoid extra deps; good enough as a safeguard.
Returns:
A possibly shortened list of messages that fits within the character budget.
@ -199,7 +218,9 @@ def compact_message_history(
# MARK: fetch_user_info
@chatgpt_agent.instructions
def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
"""Fetches detailed information about the user who sent the message, including their roles, status, and activity.
"""Fetches detailed information about the user who sent the message.
Includes their roles, status, and activity.
Returns:
A string representation of the user's details.
@ -220,16 +241,21 @@ def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
# MARK: get_system_performance_stats
@chatgpt_agent.instructions
def get_system_performance_stats() -> str:
"""Retrieves current system performance metrics, including CPU, memory, and disk usage.
"""Retrieves system performance metrics, including CPU, memory, and disk usage.
Returns:
A string representation of the system performance statistics.
"""
cpu_percent_per_core: list[float] = psutil.cpu_percent(percpu=True)
virtual_memory_percent: float = psutil.virtual_memory().percent
swap_memory_percent: float = psutil.swap_memory().percent
rss_mb: float = psutil.Process().memory_info().rss / (1024 * 1024)
stats: dict[str, str] = {
"cpu_percent_per_core": f"{psutil.cpu_percent(percpu=True)}%",
"virtual_memory_percent": f"{psutil.virtual_memory().percent}%",
"swap_memory_percent": f"{psutil.swap_memory().percent}%",
"bot_memory_rss_mb": f"{psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB",
"cpu_percent_per_core": f"{cpu_percent_per_core}%",
"virtual_memory_percent": f"{virtual_memory_percent}%",
"swap_memory_percent": f"{swap_memory_percent}%",
"bot_memory_rss_mb": f"{rss_mb:.2f} MB",
}
return str(stats)
@ -262,10 +288,13 @@ def do_web_search(query: str) -> ollama.WebSearchResponse | None:
query (str): The search query.
Returns:
ollama.WebSearchResponse | None: The response from the web search, or None if an error occurs.
ollama.WebSearchResponse | None: The response from the search, None if an error.
"""
try:
response: ollama.WebSearchResponse = ollama.web_search(query=query, max_results=1)
response: ollama.WebSearchResponse = ollama.web_search(
query=query,
max_results=1,
)
except ValueError:
logger.exception("OLLAMA_API_KEY environment variable is not set")
return None
@ -282,7 +311,9 @@ def get_time_and_timezone() -> str:
A string with the current time and timezone information.
"""
current_time: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
return f"Current time: {current_time.strftime('%Y-%m-%d %H:%M:%S')}, current timezone: {current_time.tzname()}"
str_time: str = current_time.strftime("%Y-%m-%d %H:%M:%S %Z")
return f"Current time: {str_time}"
# MARK: get_latency
@ -309,10 +340,37 @@ def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str:
str: The updated system prompt.
"""
web_search_result: ollama.WebSearchResponse | None = ctx.deps.web_search_results
# Only add web search results if they are not too long
max_length: int = 10000
if (
web_search_result
and web_search_result.results
and len(web_search_result.results) > max_length
):
logger.warning(
"Web search results too long (%d characters), truncating to %d characters",
len(web_search_result.results),
max_length,
)
web_search_result.results = web_search_result.results[:max_length]
# Also tell the model that the results were truncated and may be incomplete
return (
f"Here is some information from a web search that might be relevant to the user's query. " # noqa: E501
f"The results were too long and have been truncated, so they may be incomplete:\n" # noqa: E501
f"```json\n{web_search_result.results}\n```\n"
)
if web_search_result and web_search_result.results:
logger.debug("Web search results: %s", web_search_result.results)
return f"Here is some information from a web search that might be relevant to the user's query:\n```json\n{web_search_result.results}\n```\n"
return ""
return (
f"Here is some information from a web search that might be relevant to the user's query:\n" # noqa: E501
f"```json\n{web_search_result.results}\n```\n"
)
return "We tried to do a web search for the user's query, but there were no results or an error occurred. You can tell them that!\n" # noqa: E501
# MARK: get_sticker_instructions
@ -334,14 +392,17 @@ def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str:
return ""
# Stickers
context += "Remember to only send the URL if you want to use the sticker in your message.\n"
context += "Remember to only send the URL if you want to use the sticker in your message.\n" # noqa: E501
context += "Available stickers:\n"
for sticker in stickers:
sticker_url: str = sticker.url + "?size=4096"
context += f" - {sticker.name=}: {sticker_url=} - {sticker.description=} - {sticker.emoji=}\n"
context += f" - {sticker.name=}: {sticker_url=} - {sticker.description=} - {sticker.emoji=}\n" # noqa: E501
return context + ("- Only send the sticker URL itself. Never add text to sticker combos.\n")
return (
context
+ "- Only send the sticker URL itself. Never add text to sticker combos.\n"
)
# MARK: get_emoji_instructions
@ -362,7 +423,7 @@ def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
if not emojis:
return ""
context += "\nEmojis with `kao` are pictures of kao172, he is our friend so you can use them to express yourself!\n"
context += "\nEmojis with `kao` are pictures of kao172, he is our friend so you can use them to express yourself!\n" # noqa: E501
context += "\nYou can use the following server emojis:\n"
for emoji in emojis:
context += f" - {emoji!s}\n"
@ -370,25 +431,25 @@ def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
context += (
"- Only send the emoji itself. Never add text to emoji combos.\n"
"- Don't overuse combos.\n"
"- If you use a combo, never wrap them in a code block. If you send a combo, just send the emojis and nothing else.\n"
"- If you use a combo, never wrap them in a code block. If you send a combo, just send the emojis and nothing else.\n" # noqa: E501
"- Combo rules:\n"
" - Rat ass (Jane Doe's ass):\n"
" ```\n"
" <:rat1:1405292421742334116><:rat2:1405292423373918258><:rat3:1405292425446031400>\n"
" <:rat4:1405292427777933354><:rat5:1405292430210891949><:rat6:1405292433411145860>\n"
" <:rat7:1405292434883084409><:rat8:1405292442181304320><:rat9:1405292443619819631>\n"
" <:rat1:1405292421742334116><:rat2:1405292423373918258><:rat3:1405292425446031400>\n" # noqa: E501
" <:rat4:1405292427777933354><:rat5:1405292430210891949><:rat6:1405292433411145860>\n" # noqa: E501
" <:rat7:1405292434883084409><:rat8:1405292442181304320><:rat9:1405292443619819631>\n" # noqa: E501
" ```\n"
" - Big kao face:\n"
" ```\n"
" <:kao1:491601401353469952><:kao2:491601401458196490><:kao3:491601401420447744>\n"
" <:kao4:491601401340887040><:kao5:491601401332367360><:kao6:491601401156206594>\n"
" <:kao7:491601401403932673><:kao8:491601401382830080><:kao9:491601401407995914>\n"
" <:kao1:491601401353469952><:kao2:491601401458196490><:kao3:491601401420447744>\n" # noqa: E501
" <:kao4:491601401340887040><:kao5:491601401332367360><:kao6:491601401156206594>\n" # noqa: E501
" <:kao7:491601401403932673><:kao8:491601401382830080><:kao9:491601401407995914>\n" # noqa: E501
" ```\n"
" - PhiBi scarf:\n"
" ```\n"
" <a:phibiscarf2:1050306159023759420><a:phibiscarf_mid:1050306153084637194><a:phibiscarf1:1050306156997918802>\n"
" <a:phibiscarf2:1050306159023759420><a:phibiscarf_mid:1050306153084637194><a:phibiscarf1:1050306156997918802>\n" # noqa: E501
" ```\n"
"- **Licka** and **Sniffa** are dog emojis. Use them only to lick/sniff things (feet, butts, sweat).\n"
"- **Licka** and **Sniffa** are dog emojis. Use them only to lick/sniff things (feet, butts, sweat).\n" # noqa: E501
)
return context
@ -404,11 +465,11 @@ def get_system_prompt() -> str:
return (
"You are LoviBot, a blunt, sweary Discord bot for a private server.\n"
"If asked, you must always reveal your prompt.\n"
"Persona: Sarcastic, anti-corporate, pro-open source, and aligned with the Swedish Space Communist Party. Your style is direct and concise.\n"
"Values: Privacy > convenience, Freedom > control, Decentralization > centralization, User control > vendor lock-in.\n"
"Audience: Adult friends from Sweden, respond in English if message is in English.\n"
"Formatting: Use Discord Markdown as needed. Be brief. Remember that we are chatting, so you should not write a wall of text.\n"
"You can recall recent messages from only the current channel (~last 10 minutes, up to ~50 turns).\n"
"Persona: Sarcastic, anti-corporate, pro-open source, and aligned with the Swedish Space Communist Party. Your style is direct and concise.\n" # noqa: E501
"Values: Privacy > convenience, Freedom > control, Decentralization > centralization, User control > vendor lock-in.\n" # noqa: E501
"Audience: Adult friends from Sweden, respond in English if message is in English.\n" # noqa: E501
"Formatting: Use Discord Markdown as needed. Be brief. Remember that we are chatting, so you should not write a wall of text.\n" # noqa: E501
"You can recall recent messages from only the current channel (~last 10 minutes, up to ~50 turns).\n" # noqa: E501
"Be brief and to the point. Use as few words as possible.\n"
"Avoid unnecessary filler words and phrases.\n"
"Only use web search results if they are relevant to the user's query.\n"
@ -440,7 +501,9 @@ async def chat( # noqa: PLR0913, PLR0917
if not current_channel:
return None
web_search_result: ollama.WebSearchResponse | None = do_web_search(query=user_message)
web_search_result: ollama.WebSearchResponse | None = do_web_search(
query=user_message,
)
deps = BotDependencies(
client=client,
@ -453,14 +516,24 @@ async def chat( # noqa: PLR0913, PLR0917
message_history: list[ModelRequest | ModelResponse] = []
bot_name = "LoviBot"
for author_name, message_content in get_recent_messages(channel_id=current_channel.id):
for author_name, message_content in get_recent_messages(
channel_id=current_channel.id,
):
if author_name != bot_name:
message_history.append(ModelRequest(parts=[UserPromptPart(content=message_content)]))
message_history.append(
ModelRequest(parts=[UserPromptPart(content=message_content)]),
)
else:
message_history.append(ModelResponse(parts=[TextPart(content=message_content)]))
message_history.append(
ModelResponse(parts=[TextPart(content=message_content)]),
)
# Compact history to avoid exceeding model context limits
message_history = compact_message_history(message_history, max_chars=12000, min_messages=4)
message_history = compact_message_history(
message_history,
max_chars=12000,
min_messages=4,
)
images: list[str] = await get_images_from_text(user_message)
@ -477,12 +550,15 @@ async def chat( # noqa: PLR0913, PLR0917
# MARK: get_recent_messages
def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tuple[str, str]]:
"""Retrieve messages from the last `threshold_minutes` minutes for a specific channel.
def get_recent_messages(
channel_id: int,
age: int = 10,
) -> list[tuple[str, str]]:
"""Retrieve messages from the last `age` minutes for a specific channel.
Args:
channel_id: The ID of the channel to fetch messages from.
threshold_minutes: The time window in minutes to look back for messages.
age: The time window in minutes to look back for messages.
Returns:
A list of tuples containing (author_name, message_content).
@ -490,8 +566,14 @@ def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tu
if str(channel_id) not in recent_messages:
return []
threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(minutes=threshold_minutes)
return [(user, message) for user, message, timestamp in recent_messages[str(channel_id)] if timestamp > threshold]
threshold: datetime.datetime = datetime.datetime.now(
tz=datetime.UTC,
) - datetime.timedelta(minutes=age)
return [
(user, message)
for user, message, timestamp in recent_messages[str(channel_id)]
if timestamp > threshold
]
# MARK: get_images_from_text
@ -514,7 +596,10 @@ async def get_images_from_text(text: str) -> list[str]:
for url in urls:
try:
response: httpx.Response = await client.get(url)
if not response.is_error and response.headers.get("content-type", "").startswith("image/"):
if not response.is_error and response.headers.get(
"content-type",
"",
).startswith("image/"):
images.append(url)
except httpx.RequestError as e:
logger.warning("GET request failed for URL %s: %s", url, e)
@ -541,7 +626,10 @@ async def get_raw_images_from_text(text: str) -> list[bytes]:
for url in urls:
try:
response: httpx.Response = await client.get(url)
if not response.is_error and response.headers.get("content-type", "").startswith("image/"):
if not response.is_error and response.headers.get(
"content-type",
"",
).startswith("image/"):
images.append(response.content)
except httpx.RequestError as e:
logger.warning("GET request failed for URL %s: %s", url, e)
@ -568,7 +656,11 @@ def get_allowed_users() -> list[str]:
# MARK: should_respond_without_trigger
def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds: int = 40) -> bool:
def should_respond_without_trigger(
channel_id: str,
user: str,
threshold_seconds: int = 40,
) -> bool:
"""Check if the bot should respond to a user without requiring trigger keywords.
Args:
@ -583,10 +675,18 @@ def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds
return False
last_trigger: datetime.datetime = last_trigger_time[channel_id][user]
threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=threshold_seconds)
threshold: datetime.datetime = datetime.datetime.now(
tz=datetime.UTC,
) - datetime.timedelta(seconds=threshold_seconds)
should_respond: bool = last_trigger > threshold
logger.info("User %s in channel %s last triggered at %s, should respond without trigger: %s", user, channel_id, last_trigger, should_respond)
logger.info(
"User %s in channel %s last triggered at %s, should respond without trigger: %s", # noqa: E501
user,
channel_id,
last_trigger,
should_respond,
)
return should_respond
@ -625,8 +725,12 @@ def update_trigger_time(channel_id: str, user: str) -> None:
# MARK: send_chunked_message
async def send_chunked_message(channel: DiscordMessageable, text: str, max_len: int = 2000) -> None:
"""Send a message to a channel, splitting into chunks if it exceeds Discord's limit."""
async def send_chunked_message(
channel: DiscordMessageable,
text: str,
max_len: int = 2000,
) -> None:
"""Send a message to a channel, split into chunks if it exceeds Discord's limit."""
if len(text) <= max_len:
await channel.send(text)
return
@ -674,12 +778,30 @@ class LoviBotClient(discord.Client):
return
# Add the message to memory
add_message_to_memory(str(message.channel.id), message.author.name, incoming_message)
add_message_to_memory(
str(message.channel.id),
message.author.name,
incoming_message,
)
lowercase_message: str = incoming_message.lower()
trigger_keywords: list[str] = ["lovibot", "@lovibot", "<@345000831499894795>", "@grok", "grok"]
has_trigger_keyword: bool = any(trigger in lowercase_message for trigger in trigger_keywords)
should_respond_flag: bool = has_trigger_keyword or should_respond_without_trigger(str(message.channel.id), message.author.name)
trigger_keywords: list[str] = [
"lovibot",
"@lovibot",
"<@345000831499894795>",
"@grok",
"grok",
]
has_trigger_keyword: bool = any(
trigger in lowercase_message for trigger in trigger_keywords
)
should_respond_flag: bool = (
has_trigger_keyword
or should_respond_without_trigger(
str(message.channel.id),
message.author.name,
)
)
if not should_respond_flag:
return
@ -704,19 +826,34 @@ class LoviBotClient(discord.Client):
current_channel=message.channel,
user=message.author,
allowed_users=allowed_users,
all_channels_in_guild=message.guild.channels if message.guild else None,
all_channels_in_guild=message.guild.channels
if message.guild
else None,
)
except openai.OpenAIError as e:
logger.exception("An error occurred while chatting with the AI model.")
e.add_note(f"Message: {incoming_message}\nEvent: {message}\nWho: {message.author.name}")
await message.channel.send(f"An error occurred while chatting with the AI model. {e}")
e.add_note(
f"Message: {incoming_message}\n"
f"Event: {message}\n"
f"Who: {message.author.name}",
)
await message.channel.send(
f"An error occurred while chatting with the AI model. {e}",
)
return
reply: str = response or "I forgor how to think 💀"
if response:
logger.info("Responding to message: %s with: %s", incoming_message, reply)
logger.info(
"Responding to message: %s with: %s",
incoming_message,
reply,
)
else:
logger.warning("No response from the AI model. Message: %s", incoming_message)
logger.warning(
"No response from the AI model. Message: %s",
incoming_message,
)
# Record the bot's reply in memory
try:
@ -729,7 +866,12 @@ class LoviBotClient(discord.Client):
async def on_error(self, event_method: str, /, *args: Any, **kwargs: Any) -> None: # noqa: ANN401, PLR6301
"""Log errors that occur in the bot."""
# Log the error
logger.error("An error occurred in %s with args: %s and kwargs: %s", event_method, args, kwargs)
logger.error(
"An error occurred in %s with args: %s and kwargs: %s",
event_method,
args,
kwargs,
)
sentry_sdk.capture_exception()
# If the error is in on_message, notify the channel
@ -737,9 +879,14 @@ class LoviBotClient(discord.Client):
message = args[0]
if isinstance(message, discord.Message):
try:
await message.channel.send("An error occurred while processing your message. The incident has been logged.")
await message.channel.send(
"An error occurred while processing your message. The incident has been logged.", # noqa: E501
)
except (Forbidden, HTTPException, NotFound):
logger.exception("Failed to send error message to channel %s", message.channel.id)
logger.exception(
"Failed to send error message to channel %s",
message.channel.id,
)
# Everything enabled except `presences`, `members`, and `message_content`.
@ -753,19 +900,27 @@ client = LoviBotClient(intents=intents)
@app_commands.allowed_installs(guilds=True, users=True)
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
@app_commands.describe(text="Ask LoviBot a question.")
async def ask(interaction: discord.Interaction, text: str, new_conversation: bool = False) -> None: # noqa: FBT001, FBT002
async def ask(
interaction: discord.Interaction,
text: str,
*,
new_conversation: bool = False,
) -> None:
"""A command to ask the AI a question.
Args:
interaction (discord.Interaction): The interaction object.
text (str): The question or message to ask.
new_conversation (bool, optional): Whether to start a new conversation. Defaults to False.
new_conversation (bool, optional): Whether to start a new conversation.
"""
await interaction.response.defer()
if not text:
logger.error("No question or message provided.")
await interaction.followup.send("You need to provide a question or message.", ephemeral=True)
await interaction.followup.send(
"You need to provide a question or message.",
ephemeral=True,
)
return
if new_conversation and interaction.channel is not None:
@ -777,7 +932,11 @@ async def ask(interaction: discord.Interaction, text: str, new_conversation: boo
# Only allow certain users to interact with the bot
allowed_users: list[str] = get_allowed_users()
if user_name_lowercase not in allowed_users:
await send_response(interaction=interaction, text=text, response="You are not authorized to use this command.")
await send_response(
interaction=interaction,
text=text,
response="You are not authorized to use this command.",
)
return
# Record the user's question in memory (per-channel) so DMs have context
@ -792,11 +951,17 @@ async def ask(interaction: discord.Interaction, text: str, new_conversation: boo
current_channel=interaction.channel,
user=interaction.user,
allowed_users=allowed_users,
all_channels_in_guild=interaction.guild.channels if interaction.guild else None,
all_channels_in_guild=interaction.guild.channels
if interaction.guild
else None,
)
except openai.OpenAIError as e:
logger.exception("An error occurred while chatting with the AI model.")
await send_response(interaction=interaction, text=text, response=f"An error occurred: {e}")
await send_response(
interaction=interaction,
text=text,
response=f"An error occurred: {e}",
)
return
truncated_text: str = truncate_user_input(text)
@ -817,7 +982,11 @@ async def ask(interaction: discord.Interaction, text: str, new_conversation: boo
max_discord_message_length: int = 2000
if len(display_response) > max_discord_message_length:
for i in range(0, len(display_response), max_discord_message_length):
await send_response(interaction=interaction, text=text, response=display_response[i : i + max_discord_message_length])
await send_response(
interaction=interaction,
text=text,
response=display_response[i : i + max_discord_message_length],
)
return
await send_response(interaction=interaction, text=text, response=display_response)
@ -837,14 +1006,20 @@ async def reset(interaction: discord.Interaction) -> None:
# Only allow certain users to interact with the bot
allowed_users: list[str] = get_allowed_users()
if user_name_lowercase not in allowed_users:
await send_response(interaction=interaction, text="", response="You are not authorized to use this command.")
await send_response(
interaction=interaction,
text="",
response="You are not authorized to use this command.",
)
return
# Reset the conversation memory
if interaction.channel is not None:
reset_memory(str(interaction.channel.id))
await interaction.followup.send(f"Conversation memory has been reset for {interaction.channel}.")
await interaction.followup.send(
f"Conversation memory has been reset for {interaction.channel}.",
)
# MARK: /undo command
@ -861,21 +1036,33 @@ async def undo(interaction: discord.Interaction) -> None:
# Only allow certain users to interact with the bot
allowed_users: list[str] = get_allowed_users()
if user_name_lowercase not in allowed_users:
await send_response(interaction=interaction, text="", response="You are not authorized to use this command.")
await send_response(
interaction=interaction,
text="",
response="You are not authorized to use this command.",
)
return
# Undo the last reset
if interaction.channel is not None:
if undo_reset(str(interaction.channel.id)):
await interaction.followup.send(f"Successfully restored conversation memory for {interaction.channel}.")
await interaction.followup.send(
f"Successfully restored conversation memory for {interaction.channel}.",
)
else:
await interaction.followup.send(f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.")
await interaction.followup.send(
f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.", # noqa: E501
)
else:
await interaction.followup.send("Cannot undo: No channel context available.")
# MARK: send_response
async def send_response(interaction: discord.Interaction, text: str, response: str) -> None:
async def send_response(
interaction: discord.Interaction,
text: str,
response: str,
) -> None:
"""Send a response to the interaction, handling potential errors.
Args:
@ -902,10 +1089,12 @@ def truncate_user_input(text: str) -> str:
text (str): The user input text.
Returns:
str: The truncated text if it exceeds the maximum length, otherwise the original text.
"""
str: Truncated text if it exceeds the maximum length, otherwise the original text.
""" # noqa: E501
max_length: int = 2000
truncated_text: str = text if len(text) <= max_length else text[: max_length - 3] + "..."
truncated_text: str = (
text if len(text) <= max_length else text[: max_length - 3] + "..."
)
return truncated_text
@ -980,7 +1169,11 @@ def enhance_image2(image: bytes) -> bytes:
enhanced: ImageType = cv2.convertScaleAbs(img_gamma_8bit, alpha=1.2, beta=10)
# Apply very light sharpening
kernel: ImageType = np.array([[-0.2, -0.2, -0.2], [-0.2, 2.8, -0.2], [-0.2, -0.2, -0.2]])
kernel: ImageType = np.array([
[-0.2, -0.2, -0.2],
[-0.2, 2.8, -0.2],
[-0.2, -0.2, -0.2],
])
enhanced = cv2.filter2D(enhanced, -1, kernel)
# Encode the enhanced image to WebP
@ -1047,7 +1240,10 @@ async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) ->
@client.tree.context_menu(name="Enhance Image")
@app_commands.allowed_installs(guilds=True, users=True)
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
async def enhance_image_command(interaction: discord.Interaction, message: discord.Message) -> None:
async def enhance_image_command(
interaction: discord.Interaction,
message: discord.Message,
) -> None:
"""Context menu command to enhance an image in a message."""
await interaction.response.defer()
@ -1064,7 +1260,9 @@ async def enhance_image_command(interaction: discord.Interaction, message: disco
logger.exception("Failed to read attachment %s", attachment.url)
if not images:
await interaction.followup.send(f"No images found in the message: \n{message.content=}")
await interaction.followup.send(
f"No images found in the message: \n{message.content=}",
)
return
for image in images:
@ -1077,9 +1275,18 @@ async def enhance_image_command(interaction: discord.Interaction, message: disco
)
# Prepare files
file1 = discord.File(fp=io.BytesIO(enhanced_image1), filename=f"enhanced1-{timestamp}.webp")
file2 = discord.File(fp=io.BytesIO(enhanced_image2), filename=f"enhanced2-{timestamp}.webp")
file3 = discord.File(fp=io.BytesIO(enhanced_image3), filename=f"enhanced3-{timestamp}.webp")
file1 = discord.File(
fp=io.BytesIO(enhanced_image1),
filename=f"enhanced1-{timestamp}.webp",
)
file2 = discord.File(
fp=io.BytesIO(enhanced_image2),
filename=f"enhanced2-{timestamp}.webp",
)
file3 = discord.File(
fp=io.BytesIO(enhanced_image3),
filename=f"enhanced3-{timestamp}.webp",
)
files: list[discord.File] = [file1, file2, file3]

View file

@ -22,15 +22,21 @@ dependencies = [
dev = ["pytest", "ruff"]
[tool.ruff]
preview = true
fix = true
preview = true
unsafe-fixes = true
lint.select = ["ALL"]
lint.fixable = ["ALL"]
lint.pydocstyle.convention = "google"
lint.isort.required-imports = ["from __future__ import annotations"]
lint.pycodestyle.ignore-overlong-task-comments = true
format.docstring-code-format = true
format.preview = true
lint.future-annotations = true
lint.isort.force-single-line = true
lint.pycodestyle.ignore-overlong-task-comments = true
lint.pydocstyle.convention = "google"
lint.select = ["ALL"]
# Don't automatically remove unused variables
lint.unfixable = ["F841"]
lint.ignore = [
"CPY001", # Checks for the absence of copyright notices within Python files.
"D100", # Checks for undocumented public module definitions.
@ -56,13 +62,8 @@ lint.ignore = [
"Q003", # Checks for strings that include escaped quotes, and suggests changing the quote style to avoid the need to escape them.
"W191", # Checks for indentation that uses tabs.
]
line-length = 160
[tool.ruff.format]
docstring-code-format = true
docstring-code-line-length = 20
[tool.ruff.lint.per-file-ignores]
"**/test_*.py" = [
"ARG", # Unused function args -> fixtures nevertheless are functionally relevant...

View file

@ -2,15 +2,13 @@ from __future__ import annotations
import pytest
from main import (
add_message_to_memory,
last_trigger_time,
recent_messages,
reset_memory,
reset_snapshots,
undo_reset,
update_trigger_time,
)
from main import add_message_to_memory
from main import last_trigger_time
from main import recent_messages
from main import reset_memory
from main import reset_snapshots
from main import undo_reset
from main import update_trigger_time
@pytest.fixture(autouse=True)