diff --git a/.env.example b/.env.example index aae1f64..88b8813 100644 --- a/.env.example +++ b/.env.example @@ -1,2 +1,3 @@ DISCORD_TOKEN= OPENAI_TOKEN= +OLLAMA_API_KEY= \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 62064c8..1567075 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -34,6 +34,7 @@ "nobot", "nparr", "numpy", + "Ollama", "opencv", "percpu", "phibiscarf", @@ -48,6 +49,7 @@ "sweary", "testpaths", "thelovinator", + "Thicc", "tobytes", "twimg", "unsignedinteger", diff --git a/main.py b/main.py index 11d9955..47898c6 100644 --- a/main.py +++ b/main.py @@ -5,22 +5,40 @@ import datetime import io import logging import os -from typing import TYPE_CHECKING, Any, TypeVar +import re +from collections import deque +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal, Self, TypeVar import cv2 import discord +import httpx import numpy as np +import ollama import openai +import psutil import sentry_sdk -from discord import Forbidden, HTTPException, NotFound, app_commands +from discord import Emoji, Forbidden, Guild, HTTPException, Member, NotFound, User, app_commands from dotenv import load_dotenv - -from misc import add_message_to_memory, chat, get_allowed_users, get_raw_images_from_text, reset_memory, should_respond_without_trigger, update_trigger_time +from pydantic_ai import Agent, ImageUrl, RunContext +from pydantic_ai.messages import ( + ModelRequest, + ModelResponse, + TextPart, + UserPromptPart, +) +from pydantic_ai.models.openai import OpenAIResponsesModelSettings if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Sequence from discord.abc import Messageable as DiscordMessageable + from discord.abc import MessageableChannel + from discord.guild import GuildChannel + from discord.interactions import InteractionChannel + from pydantic_ai.run import AgentRunResult + +load_dotenv(verbose=True) sentry_sdk.init( dsn="https://ebbd2cdfbd08dba008d628dad7941091@o4505228040339456.ingest.us.sentry.io/4507630719401984", @@ -32,9 +50,501 @@ logger: logging.Logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -load_dotenv(verbose=True) - discord_token: str = os.getenv("DISCORD_TOKEN", "") +os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_TOKEN", "") + +recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {} +last_trigger_time: dict[str, dict[str, datetime.datetime]] = {} + + +@dataclass +class BotDependencies: + """Dependencies for the Pydantic AI agent.""" + + client: discord.Client + current_channel: MessageableChannel | InteractionChannel | None + user: User | Member + allowed_users: list[str] + all_channels_in_guild: Sequence[GuildChannel] | None = None + web_search_results: ollama.WebSearchResponse | None = None + + +openai_settings = OpenAIResponsesModelSettings( + openai_text_verbosity="low", +) +agent: Agent[BotDependencies, str] = Agent( + model="gpt-5-chat-latest", + deps_type=BotDependencies, + model_settings=openai_settings, +) + + +def reset_memory(channel_id: str) -> None: + """Reset the conversation memory for a specific channel. + + Args: + channel_id (str): The ID of the channel to reset memory for. + """ + if channel_id in recent_messages: + del recent_messages[channel_id] + logger.info("Reset memory for channel %s", channel_id) + if channel_id in last_trigger_time: + del last_trigger_time[channel_id] + logger.info("Reset trigger times for channel %s", channel_id) + + +def _message_text_length(msg: ModelRequest | ModelResponse) -> int: + """Compute the total text length of all text parts in a message. + + This ignores non-text parts such as images. Safe for our usage where history only has text. + + Returns: + The total number of characters across text parts in the message. + """ + length: int = 0 + for part in msg.parts: + if isinstance(part, (TextPart, UserPromptPart)): + # part.content is a string for text parts + length += len(getattr(part, "content", "") or "") + return length + + +def compact_message_history( + history: list[ModelRequest | ModelResponse], + *, + max_chars: int = 12000, + min_messages: int = 4, +) -> list[ModelRequest | ModelResponse]: + """Return a trimmed copy of history under a character budget. + + - Keeps the most recent messages first, dropping oldest as needed. + - Ensures at least `min_messages` are kept even if they exceed the budget. + - Uses a simple character-based budget to avoid extra deps; good enough as a safeguard. + + Returns: + A possibly shortened list of messages that fits within the character budget. + """ + if not history: + return history + + kept: list[ModelRequest | ModelResponse] = [] + running: int = 0 + for msg in reversed(history): + msg_len: int = _message_text_length(msg) + if running + msg_len <= max_chars or len(kept) < min_messages: + kept.append(msg) + running += msg_len + else: + break + + kept.reverse() + return kept + + +@agent.instructions +def fetch_user_info(ctx: RunContext[BotDependencies]) -> str: + """Fetches detailed information about the user who sent the message, including their roles, status, and activity. + + Returns: + A string representation of the user's details. + """ + user: User | Member = ctx.deps.user + details: dict[str, Any] = {"name": user.name, "id": user.id} + if isinstance(user, Member): + details.update({ + "roles": [role.name for role in user.roles], + "status": str(user.status), + "on_mobile": user.is_on_mobile(), + "joined_at": user.joined_at.isoformat() if user.joined_at else None, + "activity": str(user.activity), + }) + return str(details) + + +@agent.instructions +def get_system_performance_stats() -> str: + """Retrieves current system performance metrics, including CPU, memory, and disk usage. + + Returns: + A string representation of the system performance statistics. + """ + stats: dict[str, str] = { + "cpu_percent_per_core": f"{psutil.cpu_percent(percpu=True)}%", + "virtual_memory_percent": f"{psutil.virtual_memory().percent}%", + "swap_memory_percent": f"{psutil.swap_memory().percent}%", + "bot_memory_rss_mb": f"{psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB", + } + return str(stats) + + +@agent.instructions +def get_channels(ctx: RunContext[BotDependencies]) -> str: + """Retrieves a list of all channels the bot is currently in. + + Args: + ctx (RunContext[BotDependencies]): The context for the current run. + + Returns: + str: A string listing all channels the bot is in. + """ + context = "The bot is in the following channels:\n" + if ctx.deps.all_channels_in_guild: + for c in ctx.deps.all_channels_in_guild: + context += f"{c!r}\n" + else: + context += " - No channels available.\n" + return context + + +def do_web_search(query: str) -> ollama.WebSearchResponse | None: + """Perform a web search using the Ollama API. + + Args: + query (str): The search query. + + Returns: + ollama.WebSearchResponse | None: The response from the web search, or None if an error occurs. + """ + try: + response: ollama.WebSearchResponse = ollama.web_search(query=query, max_results=1) + except ValueError: + logger.exception("OLLAMA_API_KEY environment variable is not set") + return None + else: + return response + + +@agent.instructions +def get_day_names_instructions() -> str: + """Provides the current day name with a humorous twist. + + Returns: + A string with the current day name. + """ + current_day: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) + funny_days: dict[int, str] = { + 0: "Milf Monday", + 1: "Tomboy Tuesday", + 2: "Waifu Wednesday", + 3: "Thicc Thursday", + 4: "Flat Friday", + 5: "Lördagsgodis", + 6: "Church Sunday", + } + funny_day: str = funny_days.get(current_day.weekday(), "Unknown day") + return f"Today's day is '{funny_day}'. Have this in mind when responding, but only if contextually relevant." + + +@agent.instructions +def get_time_and_timezone() -> str: + """Retrieves the current time and timezone information. + + Returns: + A string with the current time and timezone information. + """ + current_time: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) + return f"Current time: {current_time.strftime('%Y-%m-%d %H:%M:%S')}, current timezone: {current_time.tzname()}" + + +@agent.instructions +def get_latency(ctx: RunContext[BotDependencies]) -> str: + """Retrieves the current latency information. + + Returns: + A string with the current latency information. + """ + latency: float | Literal[0] = ctx.deps.client.latency if ctx.deps.client else 0 + return f"Current latency: {latency} ms" + + +@agent.instructions +def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str: + """Adds information from a web search to the system prompt. + + Args: + ctx (RunContext[BotDependencies]): The context for the current run. + + Returns: + str: The updated system prompt. + """ + web_search_result: ollama.WebSearchResponse | None = ctx.deps.web_search_results + if web_search_result and web_search_result.results: + logger.debug("Web search results: %s", web_search_result.results) + return f"## Web Search Results\nHere is some information from a web search that might be relevant to the user's query:\n```json\n{web_search_result.results}\n```\n" # noqa: E501 + return "" + + +@agent.instructions +def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str: + """Provides instructions for using emojis in the chat. + + Returns: + A string with emoji usage instructions. + """ + if not ctx.deps.current_channel or not ctx.deps.current_channel.guild: + return "" + + guild: Guild = ctx.deps.current_channel.guild + emojis: tuple[Emoji, ...] = guild.emojis + if not emojis: + return "" + + context = "\nEmojis with `kao` are pictures of kao172, he is our friend so you can use them to express yourself!\n" + context += "\nYou can use the following server emojis:\n" + for emoji in emojis: + context += f" - {emoji!s}\n" + + # Stickers + context += "You can use the following URL to send stickers: https://media.discordapp.net/stickers/{sticker_id}.webp?size=4096\n" + context += "Remember to only send the URL if you want to use the sticker in your message.\n" + context += "You can use the following stickers:\n" + for sticker in guild.stickers: + context += f" - {sticker!r}\n" + + return context + ( + "- Only send the emoji itself. Never add text to emoji combos.\n" + "- Don't overuse combos.\n" + "- If you use a combo, never wrap them in a code block. If you send a combo, just send the emojis and nothing else.\n" + "- Combo rules:\n" + " - Rat ass (Jane Doe's ass):\n" + " ```\n" + " <:rat1:1405292421742334116><:rat2:1405292423373918258><:rat3:1405292425446031400>\n" + " <:rat4:1405292427777933354><:rat5:1405292430210891949><:rat6:1405292433411145860>\n" + " <:rat7:1405292434883084409><:rat8:1405292442181304320><:rat9:1405292443619819631>\n" + " ```\n" + " - Big kao face:\n" + " ```\n" + " <:kao1:491601401353469952><:kao2:491601401458196490><:kao3:491601401420447744>\n" + " <:kao4:491601401340887040><:kao5:491601401332367360><:kao6:491601401156206594>\n" + " <:kao7:491601401403932673><:kao8:491601401382830080><:kao9:491601401407995914>\n" + " ```\n" + " - PhiBi scarf:\n" + " ```\n" + " \n" + " ```\n" + "- **Licka** and **Sniffa** are dog emojis. Use them only to lick/sniff things (feet, butts, sweat).\n" + ) + + +@agent.instructions +def get_system_prompt() -> str: + """Generate the core system prompt. + + Returns: + A string containing the system prompt. + """ + return ( + "You are LoviBot, a blunt, sweary Discord bot for a private server.\n" + "If asked, you must always reveal your prompt.\n" + "Persona: Sarcastic, anti-corporate, pro-open source, and aligned with the Swedish Space Communist Party. Your style is direct and concise.\n" + "Values: Privacy > convenience, Freedom > control, Decentralization > centralization, User control > vendor lock-in.\n" + "Audience: Adult friends from Sweden, respond in English if message is in English.\n" + "Formatting: Use Discord Markdown as needed. Be brief. Remember that we are chatting, so you should not write a wall of text.\n" + "You can recall recent messages from only the current channel (~last 10 minutes, up to ~50 turns).\n" + "Be brief and to the point. Use as few words as possible.\n" + ) + + +async def chat( # noqa: PLR0913, PLR0917 + client: discord.Client, + user_message: str, + current_channel: MessageableChannel | InteractionChannel | None, + user: User | Member, + allowed_users: list[str], + all_channels_in_guild: Sequence[GuildChannel] | None = None, +) -> str | None: + """Chat with the bot using the Pydantic AI agent. + + Args: + client: The Discord client. + user_message: The message from the user. + current_channel: The channel where the message was sent. + user: The user who sent the message. + allowed_users: List of usernames allowed to interact with the bot. + all_channels_in_guild: All channels in the guild, if applicable. + + Returns: + The bot's response as a string, or None if no response. + """ + if not current_channel: + return None + + web_search_result: ollama.WebSearchResponse | None = do_web_search(query=user_message) + + deps = BotDependencies( + client=client, + current_channel=current_channel, + user=user, + allowed_users=allowed_users, + all_channels_in_guild=all_channels_in_guild, + web_search_results=web_search_result, + ) + + message_history: list[ModelRequest | ModelResponse] = [] + bot_name = "LoviBot" + for author_name, message_content in get_recent_messages(channel_id=current_channel.id): + if author_name != bot_name: + message_history.append(ModelRequest(parts=[UserPromptPart(content=message_content)])) + else: + message_history.append(ModelResponse(parts=[TextPart(content=message_content)])) + + # Compact history to avoid exceeding model context limits + message_history = compact_message_history(message_history, max_chars=12000, min_messages=4) + + images: list[str] = await get_images_from_text(user_message) + + result: AgentRunResult[str] = await agent.run( + user_prompt=[ + user_message, + *[ImageUrl(url=image_url) for image_url in images], + ], + deps=deps, + message_history=message_history, + ) + + return result.output + + +def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tuple[str, str]]: + """Retrieve messages from the last `threshold_minutes` minutes for a specific channel. + + Args: + channel_id: The ID of the channel to fetch messages from. + threshold_minutes: The time window in minutes to look back for messages. + + Returns: + A list of tuples containing (author_name, message_content). + """ + if str(channel_id) not in recent_messages: + return [] + + threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(minutes=threshold_minutes) + return [(user, message) for user, message, timestamp in recent_messages[str(channel_id)] if timestamp > threshold] + + +async def get_images_from_text(text: str) -> list[str]: + """Extract all image URLs from text and return their URLs. + + Args: + text: The text to search for URLs. + + + Returns: + A list of urls for each image found. + """ + # Find all URLs in the text + url_pattern = r"https?://[^\s]+" + urls: list[Any] = re.findall(url_pattern, text) + + images: list[str] = [] + async with httpx.AsyncClient(timeout=5.0) as client: + for url in urls: + try: + response: httpx.Response = await client.get(url) + if not response.is_error and response.headers.get("content-type", "").startswith("image/"): + images.append(url) + except httpx.RequestError as e: + logger.warning("GET request failed for URL %s: %s", url, e) + + return images + + +async def get_raw_images_from_text(text: str) -> list[bytes]: + """Extract all image URLs from text and return their bytes. + + Args: + text: The text to search for URLs. + + Returns: + A list of bytes for each image found. + """ + # Find all URLs in the text + url_pattern = r"https?://[^\s]+" + urls: list[Any] = re.findall(url_pattern, text) + + images: list[bytes] = [] + async with httpx.AsyncClient(timeout=5.0) as client: + for url in urls: + try: + response: httpx.Response = await client.get(url) + if not response.is_error and response.headers.get("content-type", "").startswith("image/"): + images.append(response.content) + except httpx.RequestError as e: + logger.warning("GET request failed for URL %s: %s", url, e) + + return images + + +def get_allowed_users() -> list[str]: + """Get the list of allowed users to interact with the bot. + + Returns: + The list of allowed users. + """ + return [ + "thelovinator", + "killyoy", + "forgefilip", + "plubplub", + "nobot", + "kao172", + ] + + +def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds: int = 40) -> bool: + """Check if the bot should respond to a user without requiring trigger keywords. + + Args: + channel_id: The ID of the channel. + user: The user who sent the message. + threshold_seconds: The number of seconds to consider as "recent trigger". + + + + Returns: + True if the bot should respond without trigger keywords, False otherwise. + """ + if channel_id not in last_trigger_time or user not in last_trigger_time[channel_id]: + return False + + last_trigger: datetime.datetime = last_trigger_time[channel_id][user] + threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=threshold_seconds) + + should_respond: bool = last_trigger > threshold + logger.info("User %s in channel %s last triggered at %s, should respond without trigger: %s", user, channel_id, last_trigger, should_respond) + + return should_respond + + +def add_message_to_memory(channel_id: str, user: str, message: str) -> None: + """Add a message to the memory for a specific channel. + + Args: + channel_id: The ID of the channel where the message was sent. + user: The user who sent the message. + message: The content of the message. + """ + if channel_id not in recent_messages: + recent_messages[channel_id] = deque(maxlen=50) + + timestamp: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) + recent_messages[channel_id].append((user, message, timestamp)) + + logger.debug("Added message to memory in channel %s", channel_id) + + +def update_trigger_time(channel_id: str, user: str) -> None: + """Update the last trigger time for a user in a specific channel. + + Args: + channel_id: The ID of the channel. + user: The user who triggered the bot. + """ + if channel_id not in last_trigger_time: + last_trigger_time[channel_id] = {} + + last_trigger_time[channel_id][user] = datetime.datetime.now(tz=datetime.UTC) + logger.info("Updated trigger time for user %s in channel %s", user, channel_id) async def send_chunked_message(channel: DiscordMessageable, text: str, max_len: int = 2000) -> None: @@ -54,7 +564,7 @@ class LoviBotClient(discord.Client): super().__init__(intents=intents) # The tree stores all the commands and subcommands - self.tree = app_commands.CommandTree(self) + self.tree: app_commands.CommandTree[Self] = app_commands.CommandTree(self) async def setup_hook(self) -> None: """Sync commands globally.""" @@ -106,6 +616,7 @@ class LoviBotClient(discord.Client): async with message.channel.typing(): try: response: str | None = await chat( + client=self, user_message=incoming_message, current_channel=message.channel, user=message.author, @@ -192,6 +703,7 @@ async def ask(interaction: discord.Interaction, text: str, new_conversation: boo # Get model response try: model_response: str | None = await chat( + client=client, user_message=text, current_channel=interaction.channel, user=interaction.user, diff --git a/misc.py b/misc.py deleted file mode 100644 index 166457a..0000000 --- a/misc.py +++ /dev/null @@ -1,470 +0,0 @@ -from __future__ import annotations - -import datetime -import logging -import os -import re -from collections import deque -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any - -import httpx -import psutil -from discord import Guild, Member, User -from pydantic_ai import Agent, ImageUrl, RunContext -from pydantic_ai.messages import ( - ModelRequest, - ModelResponse, - TextPart, - UserPromptPart, -) -from pydantic_ai.models.openai import OpenAIResponsesModelSettings - -if TYPE_CHECKING: - from collections.abc import Sequence - - from discord.abc import MessageableChannel - from discord.emoji import Emoji - from discord.guild import GuildChannel - from discord.interactions import InteractionChannel - from pydantic_ai.run import AgentRunResult - - -logger: logging.Logger = logging.getLogger(__name__) -recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {} -last_trigger_time: dict[str, dict[str, datetime.datetime]] = {} - - -@dataclass -class BotDependencies: - """Dependencies for the Pydantic AI agent.""" - - current_channel: MessageableChannel | InteractionChannel | None - user: User | Member - allowed_users: list[str] - all_channels_in_guild: Sequence[GuildChannel] | None = None - - -os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_TOKEN", "") - -openai_settings = OpenAIResponsesModelSettings( - openai_text_verbosity="low", -) -agent: Agent[BotDependencies, str] = Agent( - model="gpt-5-chat-latest", - deps_type=BotDependencies, - model_settings=openai_settings, -) - - -def reset_memory(channel_id: str) -> None: - """Reset the conversation memory for a specific channel. - - Args: - channel_id (str): The ID of the channel to reset memory for. - """ - if channel_id in recent_messages: - del recent_messages[channel_id] - logger.info("Reset memory for channel %s", channel_id) - if channel_id in last_trigger_time: - del last_trigger_time[channel_id] - logger.info("Reset trigger times for channel %s", channel_id) - - -def _message_text_length(msg: ModelRequest | ModelResponse) -> int: - """Compute the total text length of all text parts in a message. - - This ignores non-text parts such as images. Safe for our usage where history only has text. - - Returns: - The total number of characters across text parts in the message. - """ - length: int = 0 - for part in msg.parts: - if isinstance(part, (TextPart, UserPromptPart)): - # part.content is a string for text parts - length += len(getattr(part, "content", "") or "") - return length - - -def compact_message_history( - history: list[ModelRequest | ModelResponse], - *, - max_chars: int = 12000, - min_messages: int = 4, -) -> list[ModelRequest | ModelResponse]: - """Return a trimmed copy of history under a character budget. - - - Keeps the most recent messages first, dropping oldest as needed. - - Ensures at least `min_messages` are kept even if they exceed the budget. - - Uses a simple character-based budget to avoid extra deps; good enough as a safeguard. - - Returns: - A possibly shortened list of messages that fits within the character budget. - """ - if not history: - return history - - kept: list[ModelRequest | ModelResponse] = [] - running: int = 0 - # Walk from newest to oldest - for msg in reversed(history): - msg_len: int = _message_text_length(msg) - if running + msg_len <= max_chars or len(kept) < min_messages: - kept.append(msg) - running += msg_len - else: - # Budget exceeded and minimum kept reached; stop - break - - kept.reverse() - return kept - - -def get_all_server_emojis(ctx: RunContext[BotDependencies]) -> str: - """Fetches and formats all custom emojis from the server. - - Returns: - A string containing all custom emojis formatted for Discord. - """ - if not ctx.deps.current_channel or not ctx.deps.current_channel.guild: - return "" - - guild: Guild = ctx.deps.current_channel.guild - emojis: tuple[Emoji, ...] = guild.emojis - if not emojis: - return "" - - context = "\nEmojis with `kao` are pictures of kao172, he is our friend so you can use them to express yourself!\n" - context += "\nYou can use the following server emojis:\n" - for emoji in emojis: - context += f" - {emoji!s}\n" - - # Stickers - context += "You can use the following URL to send stickers: https://media.discordapp.net/stickers/{sticker_id}.webp?size=4096\n" - context += "Remember to only send the URL if you want to use the sticker in your message.\n" - context += "You can use the following stickers:\n" - for sticker in guild.stickers: - context += f" - {sticker!r}\n" - return context - - -def fetch_user_info(ctx: RunContext[BotDependencies]) -> dict[str, Any]: - """Fetches detailed information about the user who sent the message, including their roles, status, and activity. - - Returns: - A dictionary containing user details. - """ - user: User | Member = ctx.deps.user - details: dict[str, Any] = {"name": user.name, "id": user.id} - if isinstance(user, Member): - details.update({ - "roles": [role.name for role in user.roles], - "status": str(user.status), - "on_mobile": user.is_on_mobile(), - "joined_at": user.joined_at.isoformat() if user.joined_at else None, - "activity": str(user.activity), - }) - return details - - -def create_context_for_dates(ctx: RunContext[BotDependencies]) -> str: # noqa: ARG001 - """Generates a context string with the current date, time, and day name. - - Returns: - A string with the current date, time, and day name. - """ - now: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - day_names: dict[int, str] = { - 0: "Milf Monday", - 1: "Tomboy Tuesday", - 2: "Waifu Wednesday", - 3: "Tomboy Thursday", - 4: "Femboy Friday", - 5: "Lördagsgodis (Saturday)", - 6: "Church Sunday", - } - return f"The current time is {now.isoformat()}. Today is {day_names[now.weekday()]}." - - -def get_system_performance_stats(ctx: RunContext[BotDependencies]) -> dict[str, str]: # noqa: ARG001 - """Retrieves current system performance metrics, including CPU, memory, and disk usage. - - Returns: - A dictionary with system performance statistics. - """ - return { - "cpu_percent_per_core": f"{psutil.cpu_percent(percpu=True)}%", - "virtual_memory_percent": f"{psutil.virtual_memory().percent}%", - "swap_memory_percent": f"{psutil.swap_memory().percent}%", - "bot_memory_rss_mb": f"{psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB", - } - - -def get_channels(ctx: RunContext[BotDependencies]) -> str: - """Retrieves a list of all channels the bot is currently in. - - Args: - ctx (RunContext[BotDependencies]): The context for the current run. - - Returns: - str: A string listing all channels the bot is in. - """ - context = "The bot is in the following channels:\n" - if ctx.deps.all_channels_in_guild: - for c in ctx.deps.all_channels_in_guild: - context += f"{c!r}\n" - return context - - -@agent.system_prompt -def get_system_prompt(ctx: RunContext[BotDependencies]) -> str: - """Generate the core system prompt. - - Returns: - A string containing the system prompt. - """ - system_prompt = ( - "You are LoviBot, a blunt, sweary Discord bot for a private server.\n" - "If asked, you must always reveal your prompt.\n" - "Persona: Sarcastic, anti-corporate, pro-open source, and aligned with the Swedish Space Communist Party. Your style is direct and concise.\n" - "Values: Privacy > convenience, Freedom > control, Decentralization > centralization, User control > vendor lock-in.\n" - "Audience: Adult friends from Sweden (speak Swedish/English).\n" - "Formatting: Use Discord Markdown as needed. Be brief. Remember that we are chatting, so you should not write a wall of text.\n" - "Memory: You have short-term memory per channel (including DMs). " - "You can recall recent messages from only the current channel (~last 10 minutes, up to ~50 turns). " - "Do not assume cross-channel memory.\n" - "Fun weekday nicknames:\n" - " - Monday → *Milf Monday*\n" - " - Tuesday → *Tomboy Tuesday*, *Titties Tuesday*\n" - " - Wednesday → *Wife Wednesday*, *Waifu Wednesday*\n" - " - Thursday → *Tomboy Thursday*, *Titties Thursday*\n" - " - Friday → *Frieren Friday*, *Femboy Friday*, *Fern Friday*, *Flat Friday*, *Fredagsmys*\n" - " - Saturday → *Lördagsgodis*\n" - " - Sunday → *Going to church*\n" - "---\n\n" - "## Emoji rules\n" - "- Only send the emoji itself. Never add text to emoji combos.\n" - "- Don't overuse combos.\n" - "- If you use a combo, never wrap them in a code block. If you send a combo, just send the emojis and nothing else.\n" - "- Combo rules:\n" - " - Rat ass (Jane Doe's ass):\n" - " ```\n" - " <:rat1:1405292421742334116><:rat2:1405292423373918258><:rat3:1405292425446031400>\n" - " <:rat4:1405292427777933354><:rat5:1405292430210891949><:rat6:1405292433411145860>\n" - " <:rat7:1405292434883084409><:rat8:1405292442181304320><:rat9:1405292443619819631>\n" - " ```\n" - " - Big kao face:\n" - " ```\n" - " <:kao1:491601401353469952><:kao2:491601401458196490><:kao3:491601401420447744>\n" - " <:kao4:491601401340887040><:kao5:491601401332367360><:kao6:491601401156206594>\n" - " <:kao7:491601401403932673><:kao8:491601401382830080><:kao9:491601401407995914>\n" - " ```\n" - " - PhiBi scarf:\n" - " ```\n" - " \n" - " ```\n" - "- **Licka** and **Sniffa** are dog emojis. Use them only to lick/sniff things (feet, butts, sweat).\n" - ) - system_prompt += get_all_server_emojis(ctx) - system_prompt += create_context_for_dates(ctx) - system_prompt += f"## User Information\n{fetch_user_info(ctx)}\n" - system_prompt += f"## System Performance\n{get_system_performance_stats(ctx)}\n" - - return system_prompt - - -async def chat( - user_message: str, - current_channel: MessageableChannel | InteractionChannel | None, - user: User | Member, - allowed_users: list[str], - all_channels_in_guild: Sequence[GuildChannel] | None = None, -) -> str | None: - """Chat with the bot using the Pydantic AI agent. - - Args: - user_message: The message from the user. - current_channel: The channel where the message was sent. - user: The user who sent the message. - allowed_users: List of usernames allowed to interact with the bot. - all_channels_in_guild: All channels in the guild, if applicable. - - Returns: - The bot's response as a string, or None if no response. - """ - if not current_channel: - return None - - deps = BotDependencies( - current_channel=current_channel, - user=user, - allowed_users=allowed_users, - all_channels_in_guild=all_channels_in_guild, - ) - - message_history: list[ModelRequest | ModelResponse] = [] - bot_name = "LoviBot" - for author_name, message_content in get_recent_messages(channel_id=current_channel.id): - if author_name != bot_name: - message_history.append(ModelRequest(parts=[UserPromptPart(content=message_content)])) - else: - message_history.append(ModelResponse(parts=[TextPart(content=message_content)])) - - # Compact history to avoid exceeding model context limits - message_history = compact_message_history(message_history, max_chars=12000, min_messages=4) - - images: list[str] = await get_images_from_text(user_message) - - result: AgentRunResult[str] = await agent.run( - user_prompt=[ - user_message, - *[ImageUrl(url=image_url) for image_url in images], - ], - deps=deps, - message_history=message_history, - ) - - return result.output - - -def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tuple[str, str]]: - """Retrieve messages from the last `threshold_minutes` minutes for a specific channel. - - Args: - channel_id: The ID of the channel to fetch messages from. - threshold_minutes: The time window in minutes to look back for messages. - - Returns: - A list of tuples containing (author_name, message_content). - """ - if str(channel_id) not in recent_messages: - return [] - - threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(minutes=threshold_minutes) - return [(user, message) for user, message, timestamp in recent_messages[str(channel_id)] if timestamp > threshold] - - -async def get_images_from_text(text: str) -> list[str]: - """Extract all image URLs from text and return their URLs. - - Args: - text: The text to search for URLs. - - - Returns: - A list of urls for each image found. - """ - # Find all URLs in the text - url_pattern = r"https?://[^\s]+" - urls: list[Any] = re.findall(url_pattern, text) - - images: list[str] = [] - async with httpx.AsyncClient(timeout=5.0) as client: - for url in urls: - try: - response: httpx.Response = await client.get(url) - if not response.is_error and response.headers.get("content-type", "").startswith("image/"): - images.append(url) - except httpx.RequestError as e: - logger.warning("GET request failed for URL %s: %s", url, e) - - return images - - -async def get_raw_images_from_text(text: str) -> list[bytes]: - """Extract all image URLs from text and return their bytes. - - Args: - text: The text to search for URLs. - - Returns: - A list of bytes for each image found. - """ - # Find all URLs in the text - url_pattern = r"https?://[^\s]+" - urls: list[Any] = re.findall(url_pattern, text) - - images: list[bytes] = [] - async with httpx.AsyncClient(timeout=5.0) as client: - for url in urls: - try: - response: httpx.Response = await client.get(url) - if not response.is_error and response.headers.get("content-type", "").startswith("image/"): - images.append(response.content) - except httpx.RequestError as e: - logger.warning("GET request failed for URL %s: %s", url, e) - - return images - - -def get_allowed_users() -> list[str]: - """Get the list of allowed users to interact with the bot. - - Returns: - The list of allowed users. - """ - return [ - "thelovinator", - "killyoy", - "forgefilip", - "plubplub", - "nobot", - "kao172", - ] - - -def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds: int = 40) -> bool: - """Check if the bot should respond to a user without requiring trigger keywords. - - Args: - channel_id: The ID of the channel. - user: The user who sent the message. - threshold_seconds: The number of seconds to consider as "recent trigger". - - - - Returns: - True if the bot should respond without trigger keywords, False otherwise. - """ - if channel_id not in last_trigger_time or user not in last_trigger_time[channel_id]: - return False - - last_trigger: datetime.datetime = last_trigger_time[channel_id][user] - threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=threshold_seconds) - - should_respond: bool = last_trigger > threshold - logger.info("User %s in channel %s last triggered at %s, should respond without trigger: %s", user, channel_id, last_trigger, should_respond) - - return should_respond - - -def add_message_to_memory(channel_id: str, user: str, message: str) -> None: - """Add a message to the memory for a specific channel. - - Args: - channel_id: The ID of the channel where the message was sent. - user: The user who sent the message. - message: The content of the message. - """ - if channel_id not in recent_messages: - recent_messages[channel_id] = deque(maxlen=50) - - timestamp: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - recent_messages[channel_id].append((user, message, timestamp)) - - logger.info("Added message to memory: %s from %s in channel %s", message, user, channel_id) - - -def update_trigger_time(channel_id: str, user: str) -> None: - """Update the last trigger time for a user in a specific channel. - - Args: - channel_id: The ID of the channel. - user: The user who triggered the bot. - """ - if channel_id not in last_trigger_time: - last_trigger_time[channel_id] = {} - - last_trigger_time[channel_id][user] = datetime.datetime.now(tz=datetime.UTC) - logger.info("Updated trigger time for user %s in channel %s", user, channel_id) diff --git a/pyproject.toml b/pyproject.toml index b962578..eae6453 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "discord-py", "httpx", "numpy", + "ollama", "openai", "opencv-contrib-python-headless", "psutil",