from __future__ import annotations import asyncio import datetime import io import logging import os import re from collections import deque from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal, Self, TypeVar import cv2 import discord import httpx import numpy as np import ollama import openai import psutil import sentry_sdk from discord import Emoji, Forbidden, Guild, GuildSticker, HTTPException, Member, NotFound, User, app_commands from dotenv import load_dotenv from pydantic_ai import Agent, ImageUrl, RunContext from pydantic_ai.messages import ( ModelRequest, ModelResponse, TextPart, UserPromptPart, ) from pydantic_ai.models.openai import OpenAIResponsesModelSettings if TYPE_CHECKING: from collections.abc import Callable, Sequence from discord.abc import Messageable as DiscordMessageable from discord.abc import MessageableChannel from discord.guild import GuildChannel from discord.interactions import InteractionChannel from pydantic_ai.run import AgentRunResult load_dotenv(verbose=True) sentry_sdk.init( dsn="https://ebbd2cdfbd08dba008d628dad7941091@o4505228040339456.ingest.us.sentry.io/4507630719401984", send_default_pii=True, ) logger: logging.Logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) discord_token: str = os.getenv("DISCORD_TOKEN", "") os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_TOKEN", "") recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {} last_trigger_time: dict[str, dict[str, datetime.datetime]] = {} @dataclass class BotDependencies: """Dependencies for the Pydantic AI agent.""" client: discord.Client current_channel: MessageableChannel | InteractionChannel | None user: User | Member allowed_users: list[str] all_channels_in_guild: Sequence[GuildChannel] | None = None web_search_results: ollama.WebSearchResponse | None = None openai_settings = OpenAIResponsesModelSettings( openai_text_verbosity="low", ) agent: Agent[BotDependencies, str] = Agent( model="gpt-5-chat-latest", deps_type=BotDependencies, model_settings=openai_settings, ) # MARK: reset_memory def reset_memory(channel_id: str) -> None: """Reset the conversation memory for a specific channel. Args: channel_id (str): The ID of the channel to reset memory for. """ if channel_id in recent_messages: del recent_messages[channel_id] logger.info("Reset memory for channel %s", channel_id) if channel_id in last_trigger_time: del last_trigger_time[channel_id] logger.info("Reset trigger times for channel %s", channel_id) def _message_text_length(msg: ModelRequest | ModelResponse) -> int: """Compute the total text length of all text parts in a message. This ignores non-text parts such as images. Safe for our usage where history only has text. Returns: The total number of characters across text parts in the message. """ length: int = 0 for part in msg.parts: if isinstance(part, (TextPart, UserPromptPart)): # part.content is a string for text parts length += len(getattr(part, "content", "") or "") return length def compact_message_history( history: list[ModelRequest | ModelResponse], *, max_chars: int = 12000, min_messages: int = 4, ) -> list[ModelRequest | ModelResponse]: """Return a trimmed copy of history under a character budget. - Keeps the most recent messages first, dropping oldest as needed. - Ensures at least `min_messages` are kept even if they exceed the budget. - Uses a simple character-based budget to avoid extra deps; good enough as a safeguard. Returns: A possibly shortened list of messages that fits within the character budget. """ if not history: return history kept: list[ModelRequest | ModelResponse] = [] running: int = 0 for msg in reversed(history): msg_len: int = _message_text_length(msg) if running + msg_len <= max_chars or len(kept) < min_messages: kept.append(msg) running += msg_len else: break kept.reverse() return kept # MARK: fetch_user_info @agent.instructions def fetch_user_info(ctx: RunContext[BotDependencies]) -> str: """Fetches detailed information about the user who sent the message, including their roles, status, and activity. Returns: A string representation of the user's details. """ user: User | Member = ctx.deps.user details: dict[str, Any] = {"name": user.name, "id": user.id} if isinstance(user, Member): details.update({ "roles": [role.name for role in user.roles], "status": str(user.status), "on_mobile": user.is_on_mobile(), "joined_at": user.joined_at.isoformat() if user.joined_at else None, "activity": str(user.activity), }) return str(details) # MARK: get_system_performance_stats @agent.instructions def get_system_performance_stats() -> str: """Retrieves current system performance metrics, including CPU, memory, and disk usage. Returns: A string representation of the system performance statistics. """ stats: dict[str, str] = { "cpu_percent_per_core": f"{psutil.cpu_percent(percpu=True)}%", "virtual_memory_percent": f"{psutil.virtual_memory().percent}%", "swap_memory_percent": f"{psutil.swap_memory().percent}%", "bot_memory_rss_mb": f"{psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB", } return str(stats) # MARK: get_channels @agent.instructions def get_channels(ctx: RunContext[BotDependencies]) -> str: """Retrieves a list of all channels the bot is currently in. Args: ctx (RunContext[BotDependencies]): The context for the current run. Returns: str: A string listing all channels the bot is in. """ context = "The bot is in the following channels:\n" if ctx.deps.all_channels_in_guild: for c in ctx.deps.all_channels_in_guild: context += f"{c!r}\n" else: context += " - No channels available.\n" return context # MARK: do_web_search def do_web_search(query: str) -> ollama.WebSearchResponse | None: """Perform a web search using the Ollama API. Args: query (str): The search query. Returns: ollama.WebSearchResponse | None: The response from the web search, or None if an error occurs. """ try: response: ollama.WebSearchResponse = ollama.web_search(query=query, max_results=1) except ValueError: logger.exception("OLLAMA_API_KEY environment variable is not set") return None else: return response # MARK: get_time_and_timezone @agent.instructions def get_time_and_timezone() -> str: """Retrieves the current time and timezone information. Returns: A string with the current time and timezone information. """ current_time: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) return f"Current time: {current_time.strftime('%Y-%m-%d %H:%M:%S')}, current timezone: {current_time.tzname()}" # MARK: get_latency @agent.instructions def get_latency(ctx: RunContext[BotDependencies]) -> str: """Retrieves the current latency information. Returns: A string with the current latency information. """ latency: float | Literal[0] = ctx.deps.client.latency if ctx.deps.client else 0 return f"Current latency: {latency} seconds" # MARK: added_information_from_web_search @agent.instructions def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str: """Adds information from a web search to the system prompt. Args: ctx (RunContext[BotDependencies]): The context for the current run. Returns: str: The updated system prompt. """ web_search_result: ollama.WebSearchResponse | None = ctx.deps.web_search_results if web_search_result and web_search_result.results: logger.debug("Web search results: %s", web_search_result.results) return f"Here is some information from a web search that might be relevant to the user's query:\n```json\n{web_search_result.results}\n```\n" return "" # MARK: get_sticker_instructions @agent.instructions def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str: """Provides instructions for using stickers in the chat. Returns: A string with sticker usage instructions. """ context: str = "Here are the available stickers:\n" guilds: list[Guild] = [guild for guild in ctx.deps.client.guilds if guild] for guild in guilds: logger.debug("Bot is in guild: %s", guild.name) stickers: tuple[GuildSticker, ...] = guild.stickers if not stickers: return "" # Stickers context += "Remember to only send the URL if you want to use the sticker in your message.\n" context += "Available stickers:\n" for sticker in stickers: sticker_url: str = sticker.url + "?size=4096" context += f" - {sticker.name=}: {sticker_url=} - {sticker.description=} - {sticker.emoji=}\n" return context + ("- Only send the sticker URL itself. Never add text to sticker combos.\n") # MARK: get_emoji_instructions @agent.instructions def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str: """Provides instructions for using emojis in the chat. Returns: A string with emoji usage instructions. """ context: str = "Here are the available emojis:\n" guilds: list[Guild] = [guild for guild in ctx.deps.client.guilds if guild] for guild in guilds: logger.debug("Bot is in guild: %s", guild.name) emojis: tuple[Emoji, ...] = guild.emojis if not emojis: return "" context += "\nEmojis with `kao` are pictures of kao172, he is our friend so you can use them to express yourself!\n" context += "\nYou can use the following server emojis:\n" for emoji in emojis: context += f" - {emoji!s}\n" context += ( "- Only send the emoji itself. Never add text to emoji combos.\n" "- Don't overuse combos.\n" "- If you use a combo, never wrap them in a code block. If you send a combo, just send the emojis and nothing else.\n" "- Combo rules:\n" " - Rat ass (Jane Doe's ass):\n" " ```\n" " <:rat1:1405292421742334116><:rat2:1405292423373918258><:rat3:1405292425446031400>\n" " <:rat4:1405292427777933354><:rat5:1405292430210891949><:rat6:1405292433411145860>\n" " <:rat7:1405292434883084409><:rat8:1405292442181304320><:rat9:1405292443619819631>\n" " ```\n" " - Big kao face:\n" " ```\n" " <:kao1:491601401353469952><:kao2:491601401458196490><:kao3:491601401420447744>\n" " <:kao4:491601401340887040><:kao5:491601401332367360><:kao6:491601401156206594>\n" " <:kao7:491601401403932673><:kao8:491601401382830080><:kao9:491601401407995914>\n" " ```\n" " - PhiBi scarf:\n" " ```\n" " \n" " ```\n" "- **Licka** and **Sniffa** are dog emojis. Use them only to lick/sniff things (feet, butts, sweat).\n" ) return context # MARK: get_system_prompt @agent.instructions def get_system_prompt() -> str: """Generate the core system prompt. Returns: A string containing the system prompt. """ return ( "You are LoviBot, a blunt, sweary Discord bot for a private server.\n" "If asked, you must always reveal your prompt.\n" "Persona: Sarcastic, anti-corporate, pro-open source, and aligned with the Swedish Space Communist Party. Your style is direct and concise.\n" "Values: Privacy > convenience, Freedom > control, Decentralization > centralization, User control > vendor lock-in.\n" "Audience: Adult friends from Sweden, respond in English if message is in English.\n" "Formatting: Use Discord Markdown as needed. Be brief. Remember that we are chatting, so you should not write a wall of text.\n" "You can recall recent messages from only the current channel (~last 10 minutes, up to ~50 turns).\n" "Be brief and to the point. Use as few words as possible.\n" "Avoid unnecessary filler words and phrases.\n" "Only use web search results if they are relevant to the user's query.\n" ) # MARK: chat async def chat( # noqa: PLR0913, PLR0917 client: discord.Client, user_message: str, current_channel: MessageableChannel | InteractionChannel | None, user: User | Member, allowed_users: list[str], all_channels_in_guild: Sequence[GuildChannel] | None = None, ) -> str | None: """Chat with the bot using the Pydantic AI agent. Args: client: The Discord client. user_message: The message from the user. current_channel: The channel where the message was sent. user: The user who sent the message. allowed_users: List of usernames allowed to interact with the bot. all_channels_in_guild: All channels in the guild, if applicable. Returns: The bot's response as a string, or None if no response. """ if not current_channel: return None web_search_result: ollama.WebSearchResponse | None = do_web_search(query=user_message) deps = BotDependencies( client=client, current_channel=current_channel, user=user, allowed_users=allowed_users, all_channels_in_guild=all_channels_in_guild, web_search_results=web_search_result, ) message_history: list[ModelRequest | ModelResponse] = [] bot_name = "LoviBot" for author_name, message_content in get_recent_messages(channel_id=current_channel.id): if author_name != bot_name: message_history.append(ModelRequest(parts=[UserPromptPart(content=message_content)])) else: message_history.append(ModelResponse(parts=[TextPart(content=message_content)])) # Compact history to avoid exceeding model context limits message_history = compact_message_history(message_history, max_chars=12000, min_messages=4) images: list[str] = await get_images_from_text(user_message) result: AgentRunResult[str] = await agent.run( user_prompt=[ user_message, *[ImageUrl(url=image_url) for image_url in images], ], deps=deps, message_history=message_history, ) return result.output # MARK: get_recent_messages def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tuple[str, str]]: """Retrieve messages from the last `threshold_minutes` minutes for a specific channel. Args: channel_id: The ID of the channel to fetch messages from. threshold_minutes: The time window in minutes to look back for messages. Returns: A list of tuples containing (author_name, message_content). """ if str(channel_id) not in recent_messages: return [] threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(minutes=threshold_minutes) return [(user, message) for user, message, timestamp in recent_messages[str(channel_id)] if timestamp > threshold] # MARK: get_images_from_text async def get_images_from_text(text: str) -> list[str]: """Extract all image URLs from text and return their URLs. Args: text: The text to search for URLs. Returns: A list of urls for each image found. """ # Find all URLs in the text url_pattern = r"https?://[^\s]+" urls: list[Any] = re.findall(url_pattern, text) images: list[str] = [] async with httpx.AsyncClient(timeout=5.0) as client: for url in urls: try: response: httpx.Response = await client.get(url) if not response.is_error and response.headers.get("content-type", "").startswith("image/"): images.append(url) except httpx.RequestError as e: logger.warning("GET request failed for URL %s: %s", url, e) return images # MARK: get_raw_images_from_text async def get_raw_images_from_text(text: str) -> list[bytes]: """Extract all image URLs from text and return their bytes. Args: text: The text to search for URLs. Returns: A list of bytes for each image found. """ # Find all URLs in the text url_pattern = r"https?://[^\s]+" urls: list[Any] = re.findall(url_pattern, text) images: list[bytes] = [] async with httpx.AsyncClient(timeout=5.0) as client: for url in urls: try: response: httpx.Response = await client.get(url) if not response.is_error and response.headers.get("content-type", "").startswith("image/"): images.append(response.content) except httpx.RequestError as e: logger.warning("GET request failed for URL %s: %s", url, e) return images # MARK: get_allowed_users def get_allowed_users() -> list[str]: """Get the list of allowed users to interact with the bot. Returns: The list of allowed users. """ return [ "etherlithium", "forgefilip", "kao172", "killyoy", "nobot", "plubplub", "thelovinator", ] # MARK: should_respond_without_trigger def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds: int = 40) -> bool: """Check if the bot should respond to a user without requiring trigger keywords. Args: channel_id: The ID of the channel. user: The user who sent the message. threshold_seconds: The number of seconds to consider as "recent trigger". Returns: True if the bot should respond without trigger keywords, False otherwise. """ if channel_id not in last_trigger_time or user not in last_trigger_time[channel_id]: return False last_trigger: datetime.datetime = last_trigger_time[channel_id][user] threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=threshold_seconds) should_respond: bool = last_trigger > threshold logger.info("User %s in channel %s last triggered at %s, should respond without trigger: %s", user, channel_id, last_trigger, should_respond) return should_respond # MARK: add_message_to_memory def add_message_to_memory(channel_id: str, user: str, message: str) -> None: """Add a message to the memory for a specific channel. Args: channel_id: The ID of the channel where the message was sent. user: The user who sent the message. message: The content of the message. """ if channel_id not in recent_messages: recent_messages[channel_id] = deque(maxlen=50) timestamp: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) recent_messages[channel_id].append((user, message, timestamp)) logger.debug("Added message to memory in channel %s", channel_id) # MARK: update_trigger_time def update_trigger_time(channel_id: str, user: str) -> None: """Update the last trigger time for a user in a specific channel. Args: channel_id: The ID of the channel. user: The user who triggered the bot. """ if channel_id not in last_trigger_time: last_trigger_time[channel_id] = {} last_trigger_time[channel_id][user] = datetime.datetime.now(tz=datetime.UTC) logger.info("Updated trigger time for user %s in channel %s", user, channel_id) # MARK: send_chunked_message async def send_chunked_message(channel: DiscordMessageable, text: str, max_len: int = 2000) -> None: """Send a message to a channel, splitting into chunks if it exceeds Discord's limit.""" if len(text) <= max_len: await channel.send(text) return for i in range(0, len(text), max_len): await channel.send(text[i : i + max_len]) # MARK: LoviBotClient class LoviBotClient(discord.Client): """The main bot client.""" def __init__(self, *, intents: discord.Intents) -> None: """Initialize the bot client.""" super().__init__(intents=intents) # The tree stores all the commands and subcommands self.tree: app_commands.CommandTree[Self] = app_commands.CommandTree(self) async def setup_hook(self) -> None: """Sync commands globally.""" await self.tree.sync() async def on_ready(self) -> None: """Event to handle when the bot is ready.""" logger.info("Logged in as %s", self.user) logger.info("Current latency: %s", self.latency) logger.info("Bot is ready and in the following guilds:") for guild in self.guilds: logger.info(" - %s", guild.name) async def on_message(self, message: discord.Message) -> None: """Event to handle when a message is received.""" # Ignore messages from the bot itself if message.author == self.user: return # Only allow certain users to interact with the bot allowed_users: list[str] = get_allowed_users() if message.author.name not in allowed_users: return incoming_message: str | None = message.content if not incoming_message: logger.info("No message content found in the event: %s", message) return # Add the message to memory add_message_to_memory(str(message.channel.id), message.author.name, incoming_message) lowercase_message: str = incoming_message.lower() trigger_keywords: list[str] = ["lovibot", "@lovibot", "<@345000831499894795>", "grok", "@grok"] has_trigger_keyword: bool = any(trigger in lowercase_message for trigger in trigger_keywords) should_respond_flag: bool = has_trigger_keyword or should_respond_without_trigger(str(message.channel.id), message.author.name) if not should_respond_flag: return # Update trigger time if they used a trigger keyword if has_trigger_keyword: update_trigger_time(str(message.channel.id), message.author.name) logger.info( "Received message: %s from: %s (trigger: %s, recent: %s)", incoming_message, message.author.name, has_trigger_keyword, not has_trigger_keyword ) async with message.channel.typing(): try: response: str | None = await chat( client=self, user_message=incoming_message, current_channel=message.channel, user=message.author, allowed_users=allowed_users, all_channels_in_guild=message.guild.channels if message.guild else None, ) except openai.OpenAIError as e: logger.exception("An error occurred while chatting with the AI model.") e.add_note(f"Message: {incoming_message}\nEvent: {message}\nWho: {message.author.name}") await message.channel.send(f"An error occurred while chatting with the AI model. {e}") return reply: str = response or "I forgor how to think 💀" if response: logger.info("Responding to message: %s with: %s", incoming_message, reply) else: logger.warning("No response from the AI model. Message: %s", incoming_message) # Record the bot's reply in memory try: add_message_to_memory(str(message.channel.id), "LoviBot", reply) except Exception: logger.exception("Failed to add bot reply to memory for on_message") await send_chunked_message(message.channel, reply) async def on_error(self, event_method: str, /, *args: Any, **kwargs: Any) -> None: # noqa: ANN401, PLR6301 """Log errors that occur in the bot.""" # Log the error logger.error("An error occurred in %s with args: %s and kwargs: %s", event_method, args, kwargs) sentry_sdk.capture_exception() # If the error is in on_message, notify the channel if event_method == "on_message" and args: message = args[0] if isinstance(message, discord.Message): try: await message.channel.send("An error occurred while processing your message. The incident has been logged.") except (Forbidden, HTTPException, NotFound): logger.exception("Failed to send error message to channel %s", message.channel.id) # Everything enabled except `presences`, `members`, and `message_content`. intents: discord.Intents = discord.Intents.default() intents.message_content = True client = LoviBotClient(intents=intents) # MARK: /ask command @client.tree.command(name="ask", description="Ask LoviBot a question.") @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.describe(text="Ask LoviBot a question.") async def ask(interaction: discord.Interaction, text: str, new_conversation: bool = False) -> None: # noqa: FBT001, FBT002 """A command to ask the AI a question. Args: interaction (discord.Interaction): The interaction object. text (str): The question or message to ask. new_conversation (bool, optional): Whether to start a new conversation. Defaults to False. """ await interaction.response.defer() if not text: logger.error("No question or message provided.") await interaction.followup.send("You need to provide a question or message.", ephemeral=True) return if new_conversation and interaction.channel is not None: reset_memory(str(interaction.channel.id)) user_name_lowercase: str = interaction.user.name.lower() logger.info("Received command from: %s", user_name_lowercase) # Only allow certain users to interact with the bot allowed_users: list[str] = get_allowed_users() if user_name_lowercase not in allowed_users: await send_response(interaction=interaction, text=text, response="You are not authorized to use this command.") return # Record the user's question in memory (per-channel) so DMs have context if interaction.channel is not None: add_message_to_memory(str(interaction.channel.id), interaction.user.name, text) # Get model response try: model_response: str | None = await chat( client=client, user_message=text, current_channel=interaction.channel, user=interaction.user, allowed_users=allowed_users, all_channels_in_guild=interaction.guild.channels if interaction.guild else None, ) except openai.OpenAIError as e: logger.exception("An error occurred while chatting with the AI model.") await send_response(interaction=interaction, text=text, response=f"An error occurred: {e}") return truncated_text: str = truncate_user_input(text) # Fallback if model provided no response if not model_response: logger.warning("No response from the AI model. Message: %s", text) model_response = "I forgor how to think 💀" # Record the bot's reply (raw model output) for conversation memory if interaction.channel is not None: add_message_to_memory(str(interaction.channel.id), "LoviBot", model_response) display_response: str = f"`{truncated_text}`\n\n{model_response}" logger.info("Responding to message: %s with: %s", text, display_response) # If response is longer than 2000 characters, split it into multiple messages max_discord_message_length: int = 2000 if len(display_response) > max_discord_message_length: for i in range(0, len(display_response), max_discord_message_length): await send_response(interaction=interaction, text=text, response=display_response[i : i + max_discord_message_length]) return await send_response(interaction=interaction, text=text, response=display_response) # MARK: /reset command @client.tree.command(name="reset", description="Reset the conversation memory.") @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) async def reset(interaction: discord.Interaction) -> None: """A command to reset the conversation memory.""" await interaction.response.defer() user_name_lowercase: str = interaction.user.name.lower() logger.info("Received command from: %s", user_name_lowercase) # Only allow certain users to interact with the bot allowed_users: list[str] = get_allowed_users() if user_name_lowercase not in allowed_users: await send_response(interaction=interaction, text="", response="You are not authorized to use this command.") return # Reset the conversation memory if interaction.channel is not None: reset_memory(str(interaction.channel.id)) await interaction.followup.send(f"Conversation memory has been reset for {interaction.channel}.") # MARK: send_response async def send_response(interaction: discord.Interaction, text: str, response: str) -> None: """Send a response to the interaction, handling potential errors. Args: interaction (discord.Interaction): The interaction to respond to. text (str): The original user input text. response (str): The response to send. """ logger.info("Sending response to interaction in channel %s", interaction.channel) try: await interaction.followup.send(response) except discord.HTTPException as e: e.add_note(f"Response length: {len(response)} characters.") e.add_note(f"User input length: {len(text)} characters.") logger.exception("Failed to send message to channel %s", interaction.channel) await interaction.followup.send(f"Failed to send message: {e}") # MARK: truncate_user_input def truncate_user_input(text: str) -> str: """Truncate user input if it exceeds the maximum length. Args: text (str): The user input text. Returns: str: The truncated text if it exceeds the maximum length, otherwise the original text. """ max_length: int = 2000 truncated_text: str = text if len(text) <= max_length else text[: max_length - 3] + "..." return truncated_text type ImageType = np.ndarray[Any, np.dtype[np.integer[Any] | np.floating[Any]]] | cv2.Mat # MARK: enhance_image1 def enhance_image1(image: bytes) -> bytes: """Enhance an image using OpenCV histogram equalization with denoising. Args: image (bytes): The image to enhance. Returns: bytes: The enhanced image in WebP format. """ # Read the image nparr: ImageType = np.frombuffer(image, np.uint8) img_np: ImageType = cv2.imdecode(nparr, cv2.IMREAD_COLOR) # Denoise the image with conservative settings img_np = cv2.fastNlMeansDenoisingColored(img_np, None, 5, 5, 7, 21) # Convert to LAB color space lab: ImageType = cv2.cvtColor(img_np, cv2.COLOR_BGR2LAB) l_channel, a, b = cv2.split(lab) # Apply CLAHE to L channel clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8)) enhanced_l: ImageType = clahe.apply(l_channel) # Merge channels enhanced_lab: ImageType = cv2.merge([enhanced_l, a, b]) # Convert back to BGR enhanced: ImageType = cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2BGR) # Encode the enhanced image to WebP _, enhanced_webp = cv2.imencode(".webp", enhanced, [cv2.IMWRITE_WEBP_QUALITY, 90]) return enhanced_webp.tobytes() # MARK: enhance_image2 def enhance_image2(image: bytes) -> bytes: """Enhance an image using gamma correction, contrast enhancement, and denoising. Args: image (bytes): The image to enhance. Returns: bytes: The enhanced image in WebP format. """ # Read the image nparr: ImageType = np.frombuffer(image, np.uint8) img_np: ImageType = cv2.imdecode(nparr, cv2.IMREAD_COLOR) # Denoise the image with conservative settings img_np = cv2.fastNlMeansDenoisingColored(img_np, None, 5, 5, 7, 21) # Convert to float32 for gamma correction img_float: ImageType = img_np.astype(np.float32) / 255.0 # Apply gamma correction to brighten shadows (gamma < 1) gamma: float = 0.7 img_gamma: ImageType = np.power(img_float, gamma) # Convert back to uint8 img_gamma_8bit: ImageType = (img_gamma * 255).astype(np.uint8) # Enhance contrast enhanced: ImageType = cv2.convertScaleAbs(img_gamma_8bit, alpha=1.2, beta=10) # Apply very light sharpening kernel: ImageType = np.array([[-0.2, -0.2, -0.2], [-0.2, 2.8, -0.2], [-0.2, -0.2, -0.2]]) enhanced = cv2.filter2D(enhanced, -1, kernel) # Encode the enhanced image to WebP _, enhanced_webp = cv2.imencode(".webp", enhanced, [cv2.IMWRITE_WEBP_QUALITY, 90]) return enhanced_webp.tobytes() # MARK: enhance_image3 def enhance_image3(image: bytes) -> bytes: """Enhance an image using HSV color space manipulation with denoising. Args: image (bytes): The image to enhance. Returns: bytes: The enhanced image in WebP format. """ # Read the image nparr: ImageType = np.frombuffer(image, np.uint8) img_np: ImageType = cv2.imdecode(nparr, cv2.IMREAD_COLOR) # Denoise the image with conservative settings img_np = cv2.fastNlMeansDenoisingColored(img_np, None, 5, 5, 7, 21) # Convert to HSV color space hsv: ImageType = cv2.cvtColor(img_np, cv2.COLOR_BGR2HSV) h, s, v = cv2.split(hsv) # Enhance the Value channel v = cv2.convertScaleAbs(v, alpha=1.3, beta=10) # Merge the channels back enhanced_hsv: ImageType = cv2.merge([h, s, v]) # Convert back to BGR enhanced: ImageType = cv2.cvtColor(enhanced_hsv, cv2.COLOR_HSV2BGR) # Encode the enhanced image to WebP _, enhanced_webp = cv2.imencode(".webp", enhanced, [cv2.IMWRITE_WEBP_QUALITY, 90]) return enhanced_webp.tobytes() T = TypeVar("T") # MARK: run_in_thread async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) -> T: # noqa: ANN401 """Run a blocking function in a separate thread. Args: func (Callable[..., T]): The blocking function to run. *args (tuple[Any, ...]): Positional arguments to pass to the function. **kwargs (dict[str, Any]): Keyword arguments to pass to the function. Returns: T: The result of the function. """ return await asyncio.to_thread(func, *args, **kwargs) # MARK: enhance_image_command @client.tree.context_menu(name="Enhance Image") @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) async def enhance_image_command(interaction: discord.Interaction, message: discord.Message) -> None: """Context menu command to enhance an image in a message.""" await interaction.response.defer() # Check if message has attachments or embeds with images images: list[bytes] = await get_raw_images_from_text(message.content) # Also check attachments for attachment in message.attachments: if attachment.content_type and attachment.content_type.startswith("image/"): try: img_bytes: bytes = await attachment.read() images.append(img_bytes) except (TimeoutError, HTTPException, Forbidden, NotFound): logger.exception("Failed to read attachment %s", attachment.url) if not images: await interaction.followup.send(f"No images found in the message: \n{message.content=}") return for image in images: timestamp: str = datetime.datetime.now(tz=datetime.UTC).isoformat() enhanced_image1, enhanced_image2, enhanced_image3 = await asyncio.gather( run_in_thread(enhance_image1, image), run_in_thread(enhance_image2, image), run_in_thread(enhance_image3, image), ) # Prepare files file1 = discord.File(fp=io.BytesIO(enhanced_image1), filename=f"enhanced1-{timestamp}.webp") file2 = discord.File(fp=io.BytesIO(enhanced_image2), filename=f"enhanced2-{timestamp}.webp") file3 = discord.File(fp=io.BytesIO(enhanced_image3), filename=f"enhanced3-{timestamp}.webp") files: list[discord.File] = [file1, file2, file3] await interaction.followup.send("Enhanced version:", files=files) if __name__ == "__main__": logger.info("Starting the bot.") client.run(discord_token, root_logger=True)