Add MARK comments for better code organization and readability

This commit is contained in:
Joakim Hellsén 2025-09-26 17:43:26 +02:00
commit ddf9e636f4

30
main.py
View file

@ -79,6 +79,7 @@ agent: Agent[BotDependencies, str] = Agent(
)
# MARK: reset_memory
def reset_memory(channel_id: str) -> None:
"""Reset the conversation memory for a specific channel.
@ -141,6 +142,7 @@ def compact_message_history(
return kept
# MARK: fetch_user_info
@agent.instructions
def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
"""Fetches detailed information about the user who sent the message, including their roles, status, and activity.
@ -161,6 +163,7 @@ def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
return str(details)
# MARK: get_system_performance_stats
@agent.instructions
def get_system_performance_stats() -> str:
"""Retrieves current system performance metrics, including CPU, memory, and disk usage.
@ -177,6 +180,7 @@ def get_system_performance_stats() -> str:
return str(stats)
# MARK: get_channels
@agent.instructions
def get_channels(ctx: RunContext[BotDependencies]) -> str:
"""Retrieves a list of all channels the bot is currently in.
@ -196,6 +200,7 @@ def get_channels(ctx: RunContext[BotDependencies]) -> str:
return context
# MARK: do_web_search
def do_web_search(query: str) -> ollama.WebSearchResponse | None:
"""Perform a web search using the Ollama API.
@ -214,6 +219,7 @@ def do_web_search(query: str) -> ollama.WebSearchResponse | None:
return response
# MARK: get_time_and_timezone
@agent.instructions
def get_time_and_timezone() -> str:
"""Retrieves the current time and timezone information.
@ -225,6 +231,7 @@ def get_time_and_timezone() -> str:
return f"Current time: {current_time.strftime('%Y-%m-%d %H:%M:%S')}, current timezone: {current_time.tzname()}"
# MARK: get_latency
@agent.instructions
def get_latency(ctx: RunContext[BotDependencies]) -> str:
"""Retrieves the current latency information.
@ -236,6 +243,7 @@ def get_latency(ctx: RunContext[BotDependencies]) -> str:
return f"Current latency: {latency} seconds"
# MARK: added_information_from_web_search
@agent.instructions
def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str:
"""Adds information from a web search to the system prompt.
@ -253,6 +261,7 @@ def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str:
return ""
# MARK: get_sticker_instructions
@agent.instructions
def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str:
"""Provides instructions for using stickers in the chat.
@ -281,6 +290,7 @@ def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str:
return context + ("- Only send the sticker URL itself. Never add text to sticker combos.\n")
# MARK: get_emoji_instructions
@agent.instructions
def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
"""Provides instructions for using emojis in the chat.
@ -329,6 +339,7 @@ def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
return context
# MARK: get_system_prompt
@agent.instructions
def get_system_prompt() -> str:
"""Generate the core system prompt.
@ -353,6 +364,7 @@ def get_system_prompt() -> str:
)
# MARK: chat
async def chat( # noqa: PLR0913, PLR0917
client: discord.Client,
user_message: str,
@ -413,6 +425,7 @@ async def chat( # noqa: PLR0913, PLR0917
return result.output
# MARK: get_recent_messages
def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tuple[str, str]]:
"""Retrieve messages from the last `threshold_minutes` minutes for a specific channel.
@ -430,6 +443,7 @@ def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tu
return [(user, message) for user, message, timestamp in recent_messages[str(channel_id)] if timestamp > threshold]
# MARK: get_images_from_text
async def get_images_from_text(text: str) -> list[str]:
"""Extract all image URLs from text and return their URLs.
@ -457,6 +471,7 @@ async def get_images_from_text(text: str) -> list[str]:
return images
# MARK: get_raw_images_from_text
async def get_raw_images_from_text(text: str) -> list[bytes]:
"""Extract all image URLs from text and return their bytes.
@ -483,6 +498,7 @@ async def get_raw_images_from_text(text: str) -> list[bytes]:
return images
# MARK: get_allowed_users
def get_allowed_users() -> list[str]:
"""Get the list of allowed users to interact with the bot.
@ -499,6 +515,7 @@ def get_allowed_users() -> list[str]:
]
# MARK: should_respond_without_trigger
def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds: int = 40) -> bool:
"""Check if the bot should respond to a user without requiring trigger keywords.
@ -522,6 +539,7 @@ def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds
return should_respond
# MARK: add_message_to_memory
def add_message_to_memory(channel_id: str, user: str, message: str) -> None:
"""Add a message to the memory for a specific channel.
@ -539,6 +557,7 @@ def add_message_to_memory(channel_id: str, user: str, message: str) -> None:
logger.debug("Added message to memory in channel %s", channel_id)
# MARK: update_trigger_time
def update_trigger_time(channel_id: str, user: str) -> None:
"""Update the last trigger time for a user in a specific channel.
@ -553,6 +572,7 @@ def update_trigger_time(channel_id: str, user: str) -> None:
logger.info("Updated trigger time for user %s in channel %s", user, channel_id)
# MARK: send_chunked_message
async def send_chunked_message(channel: DiscordMessageable, text: str, max_len: int = 2000) -> None:
"""Send a message to a channel, splitting into chunks if it exceeds Discord's limit."""
if len(text) <= max_len:
@ -562,6 +582,7 @@ async def send_chunked_message(channel: DiscordMessageable, text: str, max_len:
await channel.send(text[i : i + max_len])
# MARK: LoviBotClient
class LoviBotClient(discord.Client):
"""The main bot client."""
@ -671,6 +692,7 @@ intents.message_content = True
client = LoviBotClient(intents=intents)
# MARK: /ask command
@client.tree.command(name="ask", description="Ask LoviBot a question.")
@app_commands.allowed_installs(guilds=True, users=True)
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
@ -745,6 +767,7 @@ async def ask(interaction: discord.Interaction, text: str, new_conversation: boo
await send_response(interaction=interaction, text=text, response=display_response)
# MARK: /reset command
@client.tree.command(name="reset", description="Reset the conversation memory.")
@app_commands.allowed_installs(guilds=True, users=True)
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
@ -768,6 +791,7 @@ async def reset(interaction: discord.Interaction) -> None:
await interaction.followup.send(f"Conversation memory has been reset for {interaction.channel}.")
# MARK: send_response
async def send_response(interaction: discord.Interaction, text: str, response: str) -> None:
"""Send a response to the interaction, handling potential errors.
@ -787,6 +811,7 @@ async def send_response(interaction: discord.Interaction, text: str, response: s
await interaction.followup.send(f"Failed to send message: {e}")
# MARK: truncate_user_input
def truncate_user_input(text: str) -> str:
"""Truncate user input if it exceeds the maximum length.
@ -804,6 +829,7 @@ def truncate_user_input(text: str) -> str:
type ImageType = np.ndarray[Any, np.dtype[np.integer[Any] | np.floating[Any]]] | cv2.Mat
# MARK: enhance_image1
def enhance_image1(image: bytes) -> bytes:
"""Enhance an image using OpenCV histogram equalization with denoising.
@ -840,6 +866,7 @@ def enhance_image1(image: bytes) -> bytes:
return enhanced_webp.tobytes()
# MARK: enhance_image2
def enhance_image2(image: bytes) -> bytes:
"""Enhance an image using gamma correction, contrast enhancement, and denoising.
@ -879,6 +906,7 @@ def enhance_image2(image: bytes) -> bytes:
return enhanced_webp.tobytes()
# MARK: enhance_image3
def enhance_image3(image: bytes) -> bytes:
"""Enhance an image using HSV color space manipulation with denoising.
@ -917,6 +945,7 @@ def enhance_image3(image: bytes) -> bytes:
T = TypeVar("T")
# MARK: run_in_thread
async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) -> T: # noqa: ANN401
"""Run a blocking function in a separate thread.
@ -931,6 +960,7 @@ async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) ->
return await asyncio.to_thread(func, *args, **kwargs)
# MARK: enhance_image_command
@client.tree.context_menu(name="Enhance Image")
@app_commands.allowed_installs(guilds=True, users=True)
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)