From 8b1636fbccd82579edfb13d6f908734fec191154 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Helle=C5=9Ben?= Date: Tue, 17 Mar 2026 20:32:34 +0100 Subject: [PATCH] Update ruff config and fix its errors --- main.py | 399 +++++++++++++++++++++++++++++---------- pyproject.toml | 23 +-- tests/test_reset_undo.py | 16 +- 3 files changed, 322 insertions(+), 116 deletions(-) diff --git a/main.py b/main.py index e045542..5d53885 100644 --- a/main.py +++ b/main.py @@ -8,7 +8,11 @@ import os import re from collections import deque from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal, Self, TypeVar +from typing import TYPE_CHECKING +from typing import Any +from typing import Literal +from typing import Self +from typing import TypeVar import cv2 import discord @@ -18,24 +22,33 @@ import ollama import openai import psutil import sentry_sdk -from discord import Emoji, Forbidden, Guild, GuildSticker, HTTPException, Member, NotFound, User, app_commands +from discord import Forbidden +from discord import HTTPException +from discord import Member +from discord import NotFound +from discord import app_commands from dotenv import load_dotenv -from pydantic_ai import Agent, ImageUrl, RunContext -from pydantic_ai.messages import ( - ModelRequest, - ModelResponse, - TextPart, - UserPromptPart, -) +from pydantic_ai import Agent +from pydantic_ai import ImageUrl +from pydantic_ai.messages import ModelRequest +from pydantic_ai.messages import ModelResponse +from pydantic_ai.messages import TextPart +from pydantic_ai.messages import UserPromptPart from pydantic_ai.models.openai import OpenAIResponsesModelSettings if TYPE_CHECKING: - from collections.abc import Callable, Sequence + from collections.abc import Callable + from collections.abc import Sequence + from discord import Emoji + from discord import Guild + from discord import GuildSticker + from discord import User from discord.abc import Messageable as DiscordMessageable from discord.abc import MessageableChannel from discord.guild import GuildChannel from discord.interactions import InteractionChannel + from pydantic_ai import RunContext from pydantic_ai.run import AgentRunResult load_dotenv(verbose=True) @@ -57,8 +70,10 @@ recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {} last_trigger_time: dict[str, dict[str, datetime.datetime]] = {} # Storage for reset snapshots to enable undo functionality -# Each channel stores its previous state: (recent_messages_snapshot, last_trigger_time_snapshot) -reset_snapshots: dict[str, tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]]] = {} +reset_snapshots: dict[ + str, + tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]], +] = {} @dataclass @@ -94,10 +109,14 @@ def reset_memory(channel_id: str) -> None: """ # Create snapshot before reset for undo functionality messages_snapshot: deque[tuple[str, str, datetime.datetime]] = ( - deque(recent_messages[channel_id], maxlen=50) if channel_id in recent_messages else deque(maxlen=50) + deque(recent_messages[channel_id], maxlen=50) + if channel_id in recent_messages + else deque(maxlen=50) ) - trigger_snapshot: dict[str, datetime.datetime] = dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {} + trigger_snapshot: dict[str, datetime.datetime] = ( + dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {} + ) # Only save snapshot if there's something to restore if messages_snapshot or trigger_snapshot: @@ -151,7 +170,8 @@ def undo_reset(channel_id: str) -> bool: def _message_text_length(msg: ModelRequest | ModelResponse) -> int: """Compute the total text length of all text parts in a message. - This ignores non-text parts such as images. Safe for our usage where history only has text. + This ignores non-text parts such as images. + Safe for our usage where history only has text. Returns: The total number of characters across text parts in the message. @@ -174,7 +194,6 @@ def compact_message_history( - Keeps the most recent messages first, dropping oldest as needed. - Ensures at least `min_messages` are kept even if they exceed the budget. - - Uses a simple character-based budget to avoid extra deps; good enough as a safeguard. Returns: A possibly shortened list of messages that fits within the character budget. @@ -199,7 +218,9 @@ def compact_message_history( # MARK: fetch_user_info @chatgpt_agent.instructions def fetch_user_info(ctx: RunContext[BotDependencies]) -> str: - """Fetches detailed information about the user who sent the message, including their roles, status, and activity. + """Fetches detailed information about the user who sent the message. + + Includes their roles, status, and activity. Returns: A string representation of the user's details. @@ -220,16 +241,21 @@ def fetch_user_info(ctx: RunContext[BotDependencies]) -> str: # MARK: get_system_performance_stats @chatgpt_agent.instructions def get_system_performance_stats() -> str: - """Retrieves current system performance metrics, including CPU, memory, and disk usage. + """Retrieves system performance metrics, including CPU, memory, and disk usage. Returns: A string representation of the system performance statistics. """ + cpu_percent_per_core: list[float] = psutil.cpu_percent(percpu=True) + virtual_memory_percent: float = psutil.virtual_memory().percent + swap_memory_percent: float = psutil.swap_memory().percent + rss_mb: float = psutil.Process().memory_info().rss / (1024 * 1024) + stats: dict[str, str] = { - "cpu_percent_per_core": f"{psutil.cpu_percent(percpu=True)}%", - "virtual_memory_percent": f"{psutil.virtual_memory().percent}%", - "swap_memory_percent": f"{psutil.swap_memory().percent}%", - "bot_memory_rss_mb": f"{psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB", + "cpu_percent_per_core": f"{cpu_percent_per_core}%", + "virtual_memory_percent": f"{virtual_memory_percent}%", + "swap_memory_percent": f"{swap_memory_percent}%", + "bot_memory_rss_mb": f"{rss_mb:.2f} MB", } return str(stats) @@ -262,10 +288,13 @@ def do_web_search(query: str) -> ollama.WebSearchResponse | None: query (str): The search query. Returns: - ollama.WebSearchResponse | None: The response from the web search, or None if an error occurs. + ollama.WebSearchResponse | None: The response from the search, None if an error. """ try: - response: ollama.WebSearchResponse = ollama.web_search(query=query, max_results=1) + response: ollama.WebSearchResponse = ollama.web_search( + query=query, + max_results=1, + ) except ValueError: logger.exception("OLLAMA_API_KEY environment variable is not set") return None @@ -282,7 +311,9 @@ def get_time_and_timezone() -> str: A string with the current time and timezone information. """ current_time: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - return f"Current time: {current_time.strftime('%Y-%m-%d %H:%M:%S')}, current timezone: {current_time.tzname()}" + str_time: str = current_time.strftime("%Y-%m-%d %H:%M:%S %Z") + + return f"Current time: {str_time}" # MARK: get_latency @@ -309,10 +340,37 @@ def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str: str: The updated system prompt. """ web_search_result: ollama.WebSearchResponse | None = ctx.deps.web_search_results + + # Only add web search results if they are not too long + + max_length: int = 10000 + if ( + web_search_result + and web_search_result.results + and len(web_search_result.results) > max_length + ): + logger.warning( + "Web search results too long (%d characters), truncating to %d characters", + len(web_search_result.results), + max_length, + ) + web_search_result.results = web_search_result.results[:max_length] + + # Also tell the model that the results were truncated and may be incomplete + return ( + f"Here is some information from a web search that might be relevant to the user's query. " # noqa: E501 + f"The results were too long and have been truncated, so they may be incomplete:\n" # noqa: E501 + f"```json\n{web_search_result.results}\n```\n" + ) + if web_search_result and web_search_result.results: logger.debug("Web search results: %s", web_search_result.results) - return f"Here is some information from a web search that might be relevant to the user's query:\n```json\n{web_search_result.results}\n```\n" - return "" + return ( + f"Here is some information from a web search that might be relevant to the user's query:\n" # noqa: E501 + f"```json\n{web_search_result.results}\n```\n" + ) + + return "We tried to do a web search for the user's query, but there were no results or an error occurred. You can tell them that!\n" # noqa: E501 # MARK: get_sticker_instructions @@ -334,14 +392,17 @@ def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str: return "" # Stickers - context += "Remember to only send the URL if you want to use the sticker in your message.\n" + context += "Remember to only send the URL if you want to use the sticker in your message.\n" # noqa: E501 context += "Available stickers:\n" for sticker in stickers: sticker_url: str = sticker.url + "?size=4096" - context += f" - {sticker.name=}: {sticker_url=} - {sticker.description=} - {sticker.emoji=}\n" + context += f" - {sticker.name=}: {sticker_url=} - {sticker.description=} - {sticker.emoji=}\n" # noqa: E501 - return context + ("- Only send the sticker URL itself. Never add text to sticker combos.\n") + return ( + context + + "- Only send the sticker URL itself. Never add text to sticker combos.\n" + ) # MARK: get_emoji_instructions @@ -362,7 +423,7 @@ def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str: if not emojis: return "" - context += "\nEmojis with `kao` are pictures of kao172, he is our friend so you can use them to express yourself!\n" + context += "\nEmojis with `kao` are pictures of kao172, he is our friend so you can use them to express yourself!\n" # noqa: E501 context += "\nYou can use the following server emojis:\n" for emoji in emojis: context += f" - {emoji!s}\n" @@ -370,25 +431,25 @@ def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str: context += ( "- Only send the emoji itself. Never add text to emoji combos.\n" "- Don't overuse combos.\n" - "- If you use a combo, never wrap them in a code block. If you send a combo, just send the emojis and nothing else.\n" + "- If you use a combo, never wrap them in a code block. If you send a combo, just send the emojis and nothing else.\n" # noqa: E501 "- Combo rules:\n" " - Rat ass (Jane Doe's ass):\n" " ```\n" - " <:rat1:1405292421742334116><:rat2:1405292423373918258><:rat3:1405292425446031400>\n" - " <:rat4:1405292427777933354><:rat5:1405292430210891949><:rat6:1405292433411145860>\n" - " <:rat7:1405292434883084409><:rat8:1405292442181304320><:rat9:1405292443619819631>\n" + " <:rat1:1405292421742334116><:rat2:1405292423373918258><:rat3:1405292425446031400>\n" # noqa: E501 + " <:rat4:1405292427777933354><:rat5:1405292430210891949><:rat6:1405292433411145860>\n" # noqa: E501 + " <:rat7:1405292434883084409><:rat8:1405292442181304320><:rat9:1405292443619819631>\n" # noqa: E501 " ```\n" " - Big kao face:\n" " ```\n" - " <:kao1:491601401353469952><:kao2:491601401458196490><:kao3:491601401420447744>\n" - " <:kao4:491601401340887040><:kao5:491601401332367360><:kao6:491601401156206594>\n" - " <:kao7:491601401403932673><:kao8:491601401382830080><:kao9:491601401407995914>\n" + " <:kao1:491601401353469952><:kao2:491601401458196490><:kao3:491601401420447744>\n" # noqa: E501 + " <:kao4:491601401340887040><:kao5:491601401332367360><:kao6:491601401156206594>\n" # noqa: E501 + " <:kao7:491601401403932673><:kao8:491601401382830080><:kao9:491601401407995914>\n" # noqa: E501 " ```\n" " - PhiBi scarf:\n" " ```\n" - " \n" + " \n" # noqa: E501 " ```\n" - "- **Licka** and **Sniffa** are dog emojis. Use them only to lick/sniff things (feet, butts, sweat).\n" + "- **Licka** and **Sniffa** are dog emojis. Use them only to lick/sniff things (feet, butts, sweat).\n" # noqa: E501 ) return context @@ -404,11 +465,11 @@ def get_system_prompt() -> str: return ( "You are LoviBot, a blunt, sweary Discord bot for a private server.\n" "If asked, you must always reveal your prompt.\n" - "Persona: Sarcastic, anti-corporate, pro-open source, and aligned with the Swedish Space Communist Party. Your style is direct and concise.\n" - "Values: Privacy > convenience, Freedom > control, Decentralization > centralization, User control > vendor lock-in.\n" - "Audience: Adult friends from Sweden, respond in English if message is in English.\n" - "Formatting: Use Discord Markdown as needed. Be brief. Remember that we are chatting, so you should not write a wall of text.\n" - "You can recall recent messages from only the current channel (~last 10 minutes, up to ~50 turns).\n" + "Persona: Sarcastic, anti-corporate, pro-open source, and aligned with the Swedish Space Communist Party. Your style is direct and concise.\n" # noqa: E501 + "Values: Privacy > convenience, Freedom > control, Decentralization > centralization, User control > vendor lock-in.\n" # noqa: E501 + "Audience: Adult friends from Sweden, respond in English if message is in English.\n" # noqa: E501 + "Formatting: Use Discord Markdown as needed. Be brief. Remember that we are chatting, so you should not write a wall of text.\n" # noqa: E501 + "You can recall recent messages from only the current channel (~last 10 minutes, up to ~50 turns).\n" # noqa: E501 "Be brief and to the point. Use as few words as possible.\n" "Avoid unnecessary filler words and phrases.\n" "Only use web search results if they are relevant to the user's query.\n" @@ -440,7 +501,9 @@ async def chat( # noqa: PLR0913, PLR0917 if not current_channel: return None - web_search_result: ollama.WebSearchResponse | None = do_web_search(query=user_message) + web_search_result: ollama.WebSearchResponse | None = do_web_search( + query=user_message, + ) deps = BotDependencies( client=client, @@ -453,14 +516,24 @@ async def chat( # noqa: PLR0913, PLR0917 message_history: list[ModelRequest | ModelResponse] = [] bot_name = "LoviBot" - for author_name, message_content in get_recent_messages(channel_id=current_channel.id): + for author_name, message_content in get_recent_messages( + channel_id=current_channel.id, + ): if author_name != bot_name: - message_history.append(ModelRequest(parts=[UserPromptPart(content=message_content)])) + message_history.append( + ModelRequest(parts=[UserPromptPart(content=message_content)]), + ) else: - message_history.append(ModelResponse(parts=[TextPart(content=message_content)])) + message_history.append( + ModelResponse(parts=[TextPart(content=message_content)]), + ) # Compact history to avoid exceeding model context limits - message_history = compact_message_history(message_history, max_chars=12000, min_messages=4) + message_history = compact_message_history( + message_history, + max_chars=12000, + min_messages=4, + ) images: list[str] = await get_images_from_text(user_message) @@ -477,12 +550,15 @@ async def chat( # noqa: PLR0913, PLR0917 # MARK: get_recent_messages -def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tuple[str, str]]: - """Retrieve messages from the last `threshold_minutes` minutes for a specific channel. +def get_recent_messages( + channel_id: int, + age: int = 10, +) -> list[tuple[str, str]]: + """Retrieve messages from the last `age` minutes for a specific channel. Args: channel_id: The ID of the channel to fetch messages from. - threshold_minutes: The time window in minutes to look back for messages. + age: The time window in minutes to look back for messages. Returns: A list of tuples containing (author_name, message_content). @@ -490,8 +566,14 @@ def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tu if str(channel_id) not in recent_messages: return [] - threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(minutes=threshold_minutes) - return [(user, message) for user, message, timestamp in recent_messages[str(channel_id)] if timestamp > threshold] + threshold: datetime.datetime = datetime.datetime.now( + tz=datetime.UTC, + ) - datetime.timedelta(minutes=age) + return [ + (user, message) + for user, message, timestamp in recent_messages[str(channel_id)] + if timestamp > threshold + ] # MARK: get_images_from_text @@ -514,7 +596,10 @@ async def get_images_from_text(text: str) -> list[str]: for url in urls: try: response: httpx.Response = await client.get(url) - if not response.is_error and response.headers.get("content-type", "").startswith("image/"): + if not response.is_error and response.headers.get( + "content-type", + "", + ).startswith("image/"): images.append(url) except httpx.RequestError as e: logger.warning("GET request failed for URL %s: %s", url, e) @@ -541,7 +626,10 @@ async def get_raw_images_from_text(text: str) -> list[bytes]: for url in urls: try: response: httpx.Response = await client.get(url) - if not response.is_error and response.headers.get("content-type", "").startswith("image/"): + if not response.is_error and response.headers.get( + "content-type", + "", + ).startswith("image/"): images.append(response.content) except httpx.RequestError as e: logger.warning("GET request failed for URL %s: %s", url, e) @@ -568,7 +656,11 @@ def get_allowed_users() -> list[str]: # MARK: should_respond_without_trigger -def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds: int = 40) -> bool: +def should_respond_without_trigger( + channel_id: str, + user: str, + threshold_seconds: int = 40, +) -> bool: """Check if the bot should respond to a user without requiring trigger keywords. Args: @@ -583,10 +675,18 @@ def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds return False last_trigger: datetime.datetime = last_trigger_time[channel_id][user] - threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=threshold_seconds) + threshold: datetime.datetime = datetime.datetime.now( + tz=datetime.UTC, + ) - datetime.timedelta(seconds=threshold_seconds) should_respond: bool = last_trigger > threshold - logger.info("User %s in channel %s last triggered at %s, should respond without trigger: %s", user, channel_id, last_trigger, should_respond) + logger.info( + "User %s in channel %s last triggered at %s, should respond without trigger: %s", # noqa: E501 + user, + channel_id, + last_trigger, + should_respond, + ) return should_respond @@ -625,8 +725,12 @@ def update_trigger_time(channel_id: str, user: str) -> None: # MARK: send_chunked_message -async def send_chunked_message(channel: DiscordMessageable, text: str, max_len: int = 2000) -> None: - """Send a message to a channel, splitting into chunks if it exceeds Discord's limit.""" +async def send_chunked_message( + channel: DiscordMessageable, + text: str, + max_len: int = 2000, +) -> None: + """Send a message to a channel, split into chunks if it exceeds Discord's limit.""" if len(text) <= max_len: await channel.send(text) return @@ -674,12 +778,30 @@ class LoviBotClient(discord.Client): return # Add the message to memory - add_message_to_memory(str(message.channel.id), message.author.name, incoming_message) + add_message_to_memory( + str(message.channel.id), + message.author.name, + incoming_message, + ) lowercase_message: str = incoming_message.lower() - trigger_keywords: list[str] = ["lovibot", "@lovibot", "<@345000831499894795>", "@grok", "grok"] - has_trigger_keyword: bool = any(trigger in lowercase_message for trigger in trigger_keywords) - should_respond_flag: bool = has_trigger_keyword or should_respond_without_trigger(str(message.channel.id), message.author.name) + trigger_keywords: list[str] = [ + "lovibot", + "@lovibot", + "<@345000831499894795>", + "@grok", + "grok", + ] + has_trigger_keyword: bool = any( + trigger in lowercase_message for trigger in trigger_keywords + ) + should_respond_flag: bool = ( + has_trigger_keyword + or should_respond_without_trigger( + str(message.channel.id), + message.author.name, + ) + ) if not should_respond_flag: return @@ -704,19 +826,34 @@ class LoviBotClient(discord.Client): current_channel=message.channel, user=message.author, allowed_users=allowed_users, - all_channels_in_guild=message.guild.channels if message.guild else None, + all_channels_in_guild=message.guild.channels + if message.guild + else None, ) except openai.OpenAIError as e: logger.exception("An error occurred while chatting with the AI model.") - e.add_note(f"Message: {incoming_message}\nEvent: {message}\nWho: {message.author.name}") - await message.channel.send(f"An error occurred while chatting with the AI model. {e}") + e.add_note( + f"Message: {incoming_message}\n" + f"Event: {message}\n" + f"Who: {message.author.name}", + ) + await message.channel.send( + f"An error occurred while chatting with the AI model. {e}", + ) return reply: str = response or "I forgor how to think 💀" if response: - logger.info("Responding to message: %s with: %s", incoming_message, reply) + logger.info( + "Responding to message: %s with: %s", + incoming_message, + reply, + ) else: - logger.warning("No response from the AI model. Message: %s", incoming_message) + logger.warning( + "No response from the AI model. Message: %s", + incoming_message, + ) # Record the bot's reply in memory try: @@ -729,7 +866,12 @@ class LoviBotClient(discord.Client): async def on_error(self, event_method: str, /, *args: Any, **kwargs: Any) -> None: # noqa: ANN401, PLR6301 """Log errors that occur in the bot.""" # Log the error - logger.error("An error occurred in %s with args: %s and kwargs: %s", event_method, args, kwargs) + logger.error( + "An error occurred in %s with args: %s and kwargs: %s", + event_method, + args, + kwargs, + ) sentry_sdk.capture_exception() # If the error is in on_message, notify the channel @@ -737,9 +879,14 @@ class LoviBotClient(discord.Client): message = args[0] if isinstance(message, discord.Message): try: - await message.channel.send("An error occurred while processing your message. The incident has been logged.") + await message.channel.send( + "An error occurred while processing your message. The incident has been logged.", # noqa: E501 + ) except (Forbidden, HTTPException, NotFound): - logger.exception("Failed to send error message to channel %s", message.channel.id) + logger.exception( + "Failed to send error message to channel %s", + message.channel.id, + ) # Everything enabled except `presences`, `members`, and `message_content`. @@ -753,19 +900,27 @@ client = LoviBotClient(intents=intents) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.describe(text="Ask LoviBot a question.") -async def ask(interaction: discord.Interaction, text: str, new_conversation: bool = False) -> None: # noqa: FBT001, FBT002 +async def ask( + interaction: discord.Interaction, + text: str, + *, + new_conversation: bool = False, +) -> None: """A command to ask the AI a question. Args: interaction (discord.Interaction): The interaction object. text (str): The question or message to ask. - new_conversation (bool, optional): Whether to start a new conversation. Defaults to False. + new_conversation (bool, optional): Whether to start a new conversation. """ await interaction.response.defer() if not text: logger.error("No question or message provided.") - await interaction.followup.send("You need to provide a question or message.", ephemeral=True) + await interaction.followup.send( + "You need to provide a question or message.", + ephemeral=True, + ) return if new_conversation and interaction.channel is not None: @@ -777,7 +932,11 @@ async def ask(interaction: discord.Interaction, text: str, new_conversation: boo # Only allow certain users to interact with the bot allowed_users: list[str] = get_allowed_users() if user_name_lowercase not in allowed_users: - await send_response(interaction=interaction, text=text, response="You are not authorized to use this command.") + await send_response( + interaction=interaction, + text=text, + response="You are not authorized to use this command.", + ) return # Record the user's question in memory (per-channel) so DMs have context @@ -792,11 +951,17 @@ async def ask(interaction: discord.Interaction, text: str, new_conversation: boo current_channel=interaction.channel, user=interaction.user, allowed_users=allowed_users, - all_channels_in_guild=interaction.guild.channels if interaction.guild else None, + all_channels_in_guild=interaction.guild.channels + if interaction.guild + else None, ) except openai.OpenAIError as e: logger.exception("An error occurred while chatting with the AI model.") - await send_response(interaction=interaction, text=text, response=f"An error occurred: {e}") + await send_response( + interaction=interaction, + text=text, + response=f"An error occurred: {e}", + ) return truncated_text: str = truncate_user_input(text) @@ -817,7 +982,11 @@ async def ask(interaction: discord.Interaction, text: str, new_conversation: boo max_discord_message_length: int = 2000 if len(display_response) > max_discord_message_length: for i in range(0, len(display_response), max_discord_message_length): - await send_response(interaction=interaction, text=text, response=display_response[i : i + max_discord_message_length]) + await send_response( + interaction=interaction, + text=text, + response=display_response[i : i + max_discord_message_length], + ) return await send_response(interaction=interaction, text=text, response=display_response) @@ -837,14 +1006,20 @@ async def reset(interaction: discord.Interaction) -> None: # Only allow certain users to interact with the bot allowed_users: list[str] = get_allowed_users() if user_name_lowercase not in allowed_users: - await send_response(interaction=interaction, text="", response="You are not authorized to use this command.") + await send_response( + interaction=interaction, + text="", + response="You are not authorized to use this command.", + ) return # Reset the conversation memory if interaction.channel is not None: reset_memory(str(interaction.channel.id)) - await interaction.followup.send(f"Conversation memory has been reset for {interaction.channel}.") + await interaction.followup.send( + f"Conversation memory has been reset for {interaction.channel}.", + ) # MARK: /undo command @@ -861,21 +1036,33 @@ async def undo(interaction: discord.Interaction) -> None: # Only allow certain users to interact with the bot allowed_users: list[str] = get_allowed_users() if user_name_lowercase not in allowed_users: - await send_response(interaction=interaction, text="", response="You are not authorized to use this command.") + await send_response( + interaction=interaction, + text="", + response="You are not authorized to use this command.", + ) return # Undo the last reset if interaction.channel is not None: if undo_reset(str(interaction.channel.id)): - await interaction.followup.send(f"Successfully restored conversation memory for {interaction.channel}.") + await interaction.followup.send( + f"Successfully restored conversation memory for {interaction.channel}.", + ) else: - await interaction.followup.send(f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.") + await interaction.followup.send( + f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.", # noqa: E501 + ) else: await interaction.followup.send("Cannot undo: No channel context available.") # MARK: send_response -async def send_response(interaction: discord.Interaction, text: str, response: str) -> None: +async def send_response( + interaction: discord.Interaction, + text: str, + response: str, +) -> None: """Send a response to the interaction, handling potential errors. Args: @@ -902,10 +1089,12 @@ def truncate_user_input(text: str) -> str: text (str): The user input text. Returns: - str: The truncated text if it exceeds the maximum length, otherwise the original text. - """ + str: Truncated text if it exceeds the maximum length, otherwise the original text. + """ # noqa: E501 max_length: int = 2000 - truncated_text: str = text if len(text) <= max_length else text[: max_length - 3] + "..." + truncated_text: str = ( + text if len(text) <= max_length else text[: max_length - 3] + "..." + ) return truncated_text @@ -980,7 +1169,11 @@ def enhance_image2(image: bytes) -> bytes: enhanced: ImageType = cv2.convertScaleAbs(img_gamma_8bit, alpha=1.2, beta=10) # Apply very light sharpening - kernel: ImageType = np.array([[-0.2, -0.2, -0.2], [-0.2, 2.8, -0.2], [-0.2, -0.2, -0.2]]) + kernel: ImageType = np.array([ + [-0.2, -0.2, -0.2], + [-0.2, 2.8, -0.2], + [-0.2, -0.2, -0.2], + ]) enhanced = cv2.filter2D(enhanced, -1, kernel) # Encode the enhanced image to WebP @@ -1047,7 +1240,10 @@ async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) -> @client.tree.context_menu(name="Enhance Image") @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) -async def enhance_image_command(interaction: discord.Interaction, message: discord.Message) -> None: +async def enhance_image_command( + interaction: discord.Interaction, + message: discord.Message, +) -> None: """Context menu command to enhance an image in a message.""" await interaction.response.defer() @@ -1064,7 +1260,9 @@ async def enhance_image_command(interaction: discord.Interaction, message: disco logger.exception("Failed to read attachment %s", attachment.url) if not images: - await interaction.followup.send(f"No images found in the message: \n{message.content=}") + await interaction.followup.send( + f"No images found in the message: \n{message.content=}", + ) return for image in images: @@ -1077,9 +1275,18 @@ async def enhance_image_command(interaction: discord.Interaction, message: disco ) # Prepare files - file1 = discord.File(fp=io.BytesIO(enhanced_image1), filename=f"enhanced1-{timestamp}.webp") - file2 = discord.File(fp=io.BytesIO(enhanced_image2), filename=f"enhanced2-{timestamp}.webp") - file3 = discord.File(fp=io.BytesIO(enhanced_image3), filename=f"enhanced3-{timestamp}.webp") + file1 = discord.File( + fp=io.BytesIO(enhanced_image1), + filename=f"enhanced1-{timestamp}.webp", + ) + file2 = discord.File( + fp=io.BytesIO(enhanced_image2), + filename=f"enhanced2-{timestamp}.webp", + ) + file3 = discord.File( + fp=io.BytesIO(enhanced_image3), + filename=f"enhanced3-{timestamp}.webp", + ) files: list[discord.File] = [file1, file2, file3] diff --git a/pyproject.toml b/pyproject.toml index 45ad980..89bb4ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,15 +22,21 @@ dependencies = [ dev = ["pytest", "ruff"] [tool.ruff] -preview = true fix = true +preview = true unsafe-fixes = true -lint.select = ["ALL"] -lint.fixable = ["ALL"] -lint.pydocstyle.convention = "google" -lint.isort.required-imports = ["from __future__ import annotations"] -lint.pycodestyle.ignore-overlong-task-comments = true +format.docstring-code-format = true +format.preview = true + +lint.future-annotations = true +lint.isort.force-single-line = true +lint.pycodestyle.ignore-overlong-task-comments = true +lint.pydocstyle.convention = "google" +lint.select = ["ALL"] + +# Don't automatically remove unused variables +lint.unfixable = ["F841"] lint.ignore = [ "CPY001", # Checks for the absence of copyright notices within Python files. "D100", # Checks for undocumented public module definitions. @@ -56,13 +62,8 @@ lint.ignore = [ "Q003", # Checks for strings that include escaped quotes, and suggests changing the quote style to avoid the need to escape them. "W191", # Checks for indentation that uses tabs. ] -line-length = 160 -[tool.ruff.format] -docstring-code-format = true -docstring-code-line-length = 20 - [tool.ruff.lint.per-file-ignores] "**/test_*.py" = [ "ARG", # Unused function args -> fixtures nevertheless are functionally relevant... diff --git a/tests/test_reset_undo.py b/tests/test_reset_undo.py index 1a90956..1c82d47 100644 --- a/tests/test_reset_undo.py +++ b/tests/test_reset_undo.py @@ -2,15 +2,13 @@ from __future__ import annotations import pytest -from main import ( - add_message_to_memory, - last_trigger_time, - recent_messages, - reset_memory, - reset_snapshots, - undo_reset, - update_trigger_time, -) +from main import add_message_to_memory +from main import last_trigger_time +from main import recent_messages +from main import reset_memory +from main import reset_snapshots +from main import undo_reset +from main import update_trigger_time @pytest.fixture(autouse=True)