Add MARK comments for better code organization and readability

This commit is contained in:
Joakim Hellsén 2025-09-26 17:43:26 +02:00
commit ddf9e636f4

30
main.py
View file

@ -79,6 +79,7 @@ agent: Agent[BotDependencies, str] = Agent(
) )
# MARK: reset_memory
def reset_memory(channel_id: str) -> None: def reset_memory(channel_id: str) -> None:
"""Reset the conversation memory for a specific channel. """Reset the conversation memory for a specific channel.
@ -141,6 +142,7 @@ def compact_message_history(
return kept return kept
# MARK: fetch_user_info
@agent.instructions @agent.instructions
def fetch_user_info(ctx: RunContext[BotDependencies]) -> str: def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
"""Fetches detailed information about the user who sent the message, including their roles, status, and activity. """Fetches detailed information about the user who sent the message, including their roles, status, and activity.
@ -161,6 +163,7 @@ def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
return str(details) return str(details)
# MARK: get_system_performance_stats
@agent.instructions @agent.instructions
def get_system_performance_stats() -> str: def get_system_performance_stats() -> str:
"""Retrieves current system performance metrics, including CPU, memory, and disk usage. """Retrieves current system performance metrics, including CPU, memory, and disk usage.
@ -177,6 +180,7 @@ def get_system_performance_stats() -> str:
return str(stats) return str(stats)
# MARK: get_channels
@agent.instructions @agent.instructions
def get_channels(ctx: RunContext[BotDependencies]) -> str: def get_channels(ctx: RunContext[BotDependencies]) -> str:
"""Retrieves a list of all channels the bot is currently in. """Retrieves a list of all channels the bot is currently in.
@ -196,6 +200,7 @@ def get_channels(ctx: RunContext[BotDependencies]) -> str:
return context return context
# MARK: do_web_search
def do_web_search(query: str) -> ollama.WebSearchResponse | None: def do_web_search(query: str) -> ollama.WebSearchResponse | None:
"""Perform a web search using the Ollama API. """Perform a web search using the Ollama API.
@ -214,6 +219,7 @@ def do_web_search(query: str) -> ollama.WebSearchResponse | None:
return response return response
# MARK: get_time_and_timezone
@agent.instructions @agent.instructions
def get_time_and_timezone() -> str: def get_time_and_timezone() -> str:
"""Retrieves the current time and timezone information. """Retrieves the current time and timezone information.
@ -225,6 +231,7 @@ def get_time_and_timezone() -> str:
return f"Current time: {current_time.strftime('%Y-%m-%d %H:%M:%S')}, current timezone: {current_time.tzname()}" return f"Current time: {current_time.strftime('%Y-%m-%d %H:%M:%S')}, current timezone: {current_time.tzname()}"
# MARK: get_latency
@agent.instructions @agent.instructions
def get_latency(ctx: RunContext[BotDependencies]) -> str: def get_latency(ctx: RunContext[BotDependencies]) -> str:
"""Retrieves the current latency information. """Retrieves the current latency information.
@ -236,6 +243,7 @@ def get_latency(ctx: RunContext[BotDependencies]) -> str:
return f"Current latency: {latency} seconds" return f"Current latency: {latency} seconds"
# MARK: added_information_from_web_search
@agent.instructions @agent.instructions
def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str: def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str:
"""Adds information from a web search to the system prompt. """Adds information from a web search to the system prompt.
@ -253,6 +261,7 @@ def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str:
return "" return ""
# MARK: get_sticker_instructions
@agent.instructions @agent.instructions
def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str: def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str:
"""Provides instructions for using stickers in the chat. """Provides instructions for using stickers in the chat.
@ -281,6 +290,7 @@ def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str:
return context + ("- Only send the sticker URL itself. Never add text to sticker combos.\n") return context + ("- Only send the sticker URL itself. Never add text to sticker combos.\n")
# MARK: get_emoji_instructions
@agent.instructions @agent.instructions
def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str: def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
"""Provides instructions for using emojis in the chat. """Provides instructions for using emojis in the chat.
@ -329,6 +339,7 @@ def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
return context return context
# MARK: get_system_prompt
@agent.instructions @agent.instructions
def get_system_prompt() -> str: def get_system_prompt() -> str:
"""Generate the core system prompt. """Generate the core system prompt.
@ -353,6 +364,7 @@ def get_system_prompt() -> str:
) )
# MARK: chat
async def chat( # noqa: PLR0913, PLR0917 async def chat( # noqa: PLR0913, PLR0917
client: discord.Client, client: discord.Client,
user_message: str, user_message: str,
@ -413,6 +425,7 @@ async def chat( # noqa: PLR0913, PLR0917
return result.output return result.output
# MARK: get_recent_messages
def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tuple[str, str]]: def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tuple[str, str]]:
"""Retrieve messages from the last `threshold_minutes` minutes for a specific channel. """Retrieve messages from the last `threshold_minutes` minutes for a specific channel.
@ -430,6 +443,7 @@ def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tu
return [(user, message) for user, message, timestamp in recent_messages[str(channel_id)] if timestamp > threshold] return [(user, message) for user, message, timestamp in recent_messages[str(channel_id)] if timestamp > threshold]
# MARK: get_images_from_text
async def get_images_from_text(text: str) -> list[str]: async def get_images_from_text(text: str) -> list[str]:
"""Extract all image URLs from text and return their URLs. """Extract all image URLs from text and return their URLs.
@ -457,6 +471,7 @@ async def get_images_from_text(text: str) -> list[str]:
return images return images
# MARK: get_raw_images_from_text
async def get_raw_images_from_text(text: str) -> list[bytes]: async def get_raw_images_from_text(text: str) -> list[bytes]:
"""Extract all image URLs from text and return their bytes. """Extract all image URLs from text and return their bytes.
@ -483,6 +498,7 @@ async def get_raw_images_from_text(text: str) -> list[bytes]:
return images return images
# MARK: get_allowed_users
def get_allowed_users() -> list[str]: def get_allowed_users() -> list[str]:
"""Get the list of allowed users to interact with the bot. """Get the list of allowed users to interact with the bot.
@ -499,6 +515,7 @@ def get_allowed_users() -> list[str]:
] ]
# MARK: should_respond_without_trigger
def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds: int = 40) -> bool: def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds: int = 40) -> bool:
"""Check if the bot should respond to a user without requiring trigger keywords. """Check if the bot should respond to a user without requiring trigger keywords.
@ -522,6 +539,7 @@ def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds
return should_respond return should_respond
# MARK: add_message_to_memory
def add_message_to_memory(channel_id: str, user: str, message: str) -> None: def add_message_to_memory(channel_id: str, user: str, message: str) -> None:
"""Add a message to the memory for a specific channel. """Add a message to the memory for a specific channel.
@ -539,6 +557,7 @@ def add_message_to_memory(channel_id: str, user: str, message: str) -> None:
logger.debug("Added message to memory in channel %s", channel_id) logger.debug("Added message to memory in channel %s", channel_id)
# MARK: update_trigger_time
def update_trigger_time(channel_id: str, user: str) -> None: def update_trigger_time(channel_id: str, user: str) -> None:
"""Update the last trigger time for a user in a specific channel. """Update the last trigger time for a user in a specific channel.
@ -553,6 +572,7 @@ def update_trigger_time(channel_id: str, user: str) -> None:
logger.info("Updated trigger time for user %s in channel %s", user, channel_id) logger.info("Updated trigger time for user %s in channel %s", user, channel_id)
# MARK: send_chunked_message
async def send_chunked_message(channel: DiscordMessageable, text: str, max_len: int = 2000) -> None: async def send_chunked_message(channel: DiscordMessageable, text: str, max_len: int = 2000) -> None:
"""Send a message to a channel, splitting into chunks if it exceeds Discord's limit.""" """Send a message to a channel, splitting into chunks if it exceeds Discord's limit."""
if len(text) <= max_len: if len(text) <= max_len:
@ -562,6 +582,7 @@ async def send_chunked_message(channel: DiscordMessageable, text: str, max_len:
await channel.send(text[i : i + max_len]) await channel.send(text[i : i + max_len])
# MARK: LoviBotClient
class LoviBotClient(discord.Client): class LoviBotClient(discord.Client):
"""The main bot client.""" """The main bot client."""
@ -671,6 +692,7 @@ intents.message_content = True
client = LoviBotClient(intents=intents) client = LoviBotClient(intents=intents)
# MARK: /ask command
@client.tree.command(name="ask", description="Ask LoviBot a question.") @client.tree.command(name="ask", description="Ask LoviBot a question.")
@app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_installs(guilds=True, users=True)
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
@ -745,6 +767,7 @@ async def ask(interaction: discord.Interaction, text: str, new_conversation: boo
await send_response(interaction=interaction, text=text, response=display_response) await send_response(interaction=interaction, text=text, response=display_response)
# MARK: /reset command
@client.tree.command(name="reset", description="Reset the conversation memory.") @client.tree.command(name="reset", description="Reset the conversation memory.")
@app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_installs(guilds=True, users=True)
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
@ -768,6 +791,7 @@ async def reset(interaction: discord.Interaction) -> None:
await interaction.followup.send(f"Conversation memory has been reset for {interaction.channel}.") await interaction.followup.send(f"Conversation memory has been reset for {interaction.channel}.")
# MARK: send_response
async def send_response(interaction: discord.Interaction, text: str, response: str) -> None: async def send_response(interaction: discord.Interaction, text: str, response: str) -> None:
"""Send a response to the interaction, handling potential errors. """Send a response to the interaction, handling potential errors.
@ -787,6 +811,7 @@ async def send_response(interaction: discord.Interaction, text: str, response: s
await interaction.followup.send(f"Failed to send message: {e}") await interaction.followup.send(f"Failed to send message: {e}")
# MARK: truncate_user_input
def truncate_user_input(text: str) -> str: def truncate_user_input(text: str) -> str:
"""Truncate user input if it exceeds the maximum length. """Truncate user input if it exceeds the maximum length.
@ -804,6 +829,7 @@ def truncate_user_input(text: str) -> str:
type ImageType = np.ndarray[Any, np.dtype[np.integer[Any] | np.floating[Any]]] | cv2.Mat type ImageType = np.ndarray[Any, np.dtype[np.integer[Any] | np.floating[Any]]] | cv2.Mat
# MARK: enhance_image1
def enhance_image1(image: bytes) -> bytes: def enhance_image1(image: bytes) -> bytes:
"""Enhance an image using OpenCV histogram equalization with denoising. """Enhance an image using OpenCV histogram equalization with denoising.
@ -840,6 +866,7 @@ def enhance_image1(image: bytes) -> bytes:
return enhanced_webp.tobytes() return enhanced_webp.tobytes()
# MARK: enhance_image2
def enhance_image2(image: bytes) -> bytes: def enhance_image2(image: bytes) -> bytes:
"""Enhance an image using gamma correction, contrast enhancement, and denoising. """Enhance an image using gamma correction, contrast enhancement, and denoising.
@ -879,6 +906,7 @@ def enhance_image2(image: bytes) -> bytes:
return enhanced_webp.tobytes() return enhanced_webp.tobytes()
# MARK: enhance_image3
def enhance_image3(image: bytes) -> bytes: def enhance_image3(image: bytes) -> bytes:
"""Enhance an image using HSV color space manipulation with denoising. """Enhance an image using HSV color space manipulation with denoising.
@ -917,6 +945,7 @@ def enhance_image3(image: bytes) -> bytes:
T = TypeVar("T") T = TypeVar("T")
# MARK: run_in_thread
async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) -> T: # noqa: ANN401 async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) -> T: # noqa: ANN401
"""Run a blocking function in a separate thread. """Run a blocking function in a separate thread.
@ -931,6 +960,7 @@ async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) ->
return await asyncio.to_thread(func, *args, **kwargs) return await asyncio.to_thread(func, *args, **kwargs)
# MARK: enhance_image_command
@client.tree.context_menu(name="Enhance Image") @client.tree.context_menu(name="Enhance Image")
@app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_installs(guilds=True, users=True)
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)