|
|
|
@ -8,11 +8,7 @@ import os
|
|
|
|
import re
|
|
|
|
import re
|
|
|
|
from collections import deque
|
|
|
|
from collections import deque
|
|
|
|
from dataclasses import dataclass
|
|
|
|
from dataclasses import dataclass
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
from typing import TYPE_CHECKING, Any, Literal, Self, TypeVar
|
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
from typing import Literal
|
|
|
|
|
|
|
|
from typing import Self
|
|
|
|
|
|
|
|
from typing import TypeVar
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import cv2
|
|
|
|
import cv2
|
|
|
|
import discord
|
|
|
|
import discord
|
|
|
|
@ -22,33 +18,25 @@ import ollama
|
|
|
|
import openai
|
|
|
|
import openai
|
|
|
|
import psutil
|
|
|
|
import psutil
|
|
|
|
import sentry_sdk
|
|
|
|
import sentry_sdk
|
|
|
|
from discord import Forbidden
|
|
|
|
from discord import Emoji, Forbidden, Guild, GuildSticker, HTTPException, Member, NotFound, User, app_commands
|
|
|
|
from discord import HTTPException
|
|
|
|
|
|
|
|
from discord import Member
|
|
|
|
|
|
|
|
from discord import NotFound
|
|
|
|
|
|
|
|
from discord import app_commands
|
|
|
|
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
from pydantic_ai import Agent
|
|
|
|
from pydantic_ai import Agent, ImageUrl, RunContext
|
|
|
|
from pydantic_ai import ImageUrl
|
|
|
|
from pydantic_ai.messages import (
|
|
|
|
from pydantic_ai.messages import ModelRequest
|
|
|
|
ModelRequest,
|
|
|
|
from pydantic_ai.messages import ModelResponse
|
|
|
|
ModelResponse,
|
|
|
|
from pydantic_ai.messages import TextPart
|
|
|
|
TextPart,
|
|
|
|
from pydantic_ai.messages import UserPromptPart
|
|
|
|
UserPromptPart,
|
|
|
|
|
|
|
|
)
|
|
|
|
from pydantic_ai.models.openai import OpenAIResponsesModelSettings
|
|
|
|
from pydantic_ai.models.openai import OpenAIResponsesModelSettings
|
|
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
from collections.abc import Callable
|
|
|
|
from collections.abc import Callable, Sequence
|
|
|
|
from collections.abc import Sequence
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from discord import Emoji
|
|
|
|
|
|
|
|
from discord import Guild
|
|
|
|
|
|
|
|
from discord import GuildSticker
|
|
|
|
|
|
|
|
from discord import User
|
|
|
|
|
|
|
|
from discord.abc import Messageable as DiscordMessageable
|
|
|
|
from discord.abc import Messageable as DiscordMessageable
|
|
|
|
from discord.abc import MessageableChannel
|
|
|
|
from discord.abc import MessageableChannel
|
|
|
|
from discord.guild import GuildChannel
|
|
|
|
from discord.guild import GuildChannel
|
|
|
|
from discord.interactions import InteractionChannel
|
|
|
|
from discord.interactions import InteractionChannel
|
|
|
|
from pydantic_ai import RunContext
|
|
|
|
from openai.types.chat import ChatCompletion
|
|
|
|
from pydantic_ai.run import AgentRunResult
|
|
|
|
from pydantic_ai.run import AgentRunResult
|
|
|
|
|
|
|
|
|
|
|
|
load_dotenv(verbose=True)
|
|
|
|
load_dotenv(verbose=True)
|
|
|
|
@ -70,10 +58,8 @@ recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {}
|
|
|
|
last_trigger_time: dict[str, dict[str, datetime.datetime]] = {}
|
|
|
|
last_trigger_time: dict[str, dict[str, datetime.datetime]] = {}
|
|
|
|
|
|
|
|
|
|
|
|
# Storage for reset snapshots to enable undo functionality
|
|
|
|
# Storage for reset snapshots to enable undo functionality
|
|
|
|
reset_snapshots: dict[
|
|
|
|
# Each channel stores its previous state: (recent_messages_snapshot, last_trigger_time_snapshot)
|
|
|
|
str,
|
|
|
|
reset_snapshots: dict[str, tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]]] = {}
|
|
|
|
tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]],
|
|
|
|
|
|
|
|
] = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
@dataclass
|
|
|
|
@ -88,14 +74,47 @@ class BotDependencies:
|
|
|
|
web_search_results: ollama.WebSearchResponse | None = None
|
|
|
|
web_search_results: ollama.WebSearchResponse | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
openai_settings: OpenAIResponsesModelSettings = OpenAIResponsesModelSettings(
|
|
|
|
openai_settings = OpenAIResponsesModelSettings(
|
|
|
|
openai_text_verbosity="low",
|
|
|
|
openai_text_verbosity="low",
|
|
|
|
)
|
|
|
|
)
|
|
|
|
chatgpt_agent: Agent[BotDependencies, str] = Agent(
|
|
|
|
chatgpt_agent: Agent[BotDependencies, str] = Agent(
|
|
|
|
model="openai:gpt-5-chat-latest",
|
|
|
|
model="gpt-5-chat-latest",
|
|
|
|
deps_type=BotDependencies,
|
|
|
|
deps_type=BotDependencies,
|
|
|
|
model_settings=openai_settings,
|
|
|
|
model_settings=openai_settings,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
grok_client = openai.OpenAI(
|
|
|
|
|
|
|
|
base_url="https://openrouter.ai/api/v1",
|
|
|
|
|
|
|
|
api_key=os.getenv("OPENROUTER_API_KEY"),
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def grok_it(
|
|
|
|
|
|
|
|
message: discord.Message | None,
|
|
|
|
|
|
|
|
user_message: str,
|
|
|
|
|
|
|
|
) -> str | None:
|
|
|
|
|
|
|
|
"""Chat with the bot using the Pydantic AI agent.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
user_message: The message from the user.
|
|
|
|
|
|
|
|
message: The original Discord message object.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
|
|
The bot's response as a string, or None if no response.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
allowed_users: list[str] = get_allowed_users()
|
|
|
|
|
|
|
|
if message and message.author.name not in allowed_users:
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
response: ChatCompletion = grok_client.chat.completions.create(
|
|
|
|
|
|
|
|
model="x-ai/grok-4-fast:free",
|
|
|
|
|
|
|
|
messages=[
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
"role": "user",
|
|
|
|
|
|
|
|
"content": user_message,
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
],
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return response.choices[0].message.content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# MARK: reset_memory
|
|
|
|
# MARK: reset_memory
|
|
|
|
@ -109,14 +128,10 @@ def reset_memory(channel_id: str) -> None:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
# Create snapshot before reset for undo functionality
|
|
|
|
# Create snapshot before reset for undo functionality
|
|
|
|
messages_snapshot: deque[tuple[str, str, datetime.datetime]] = (
|
|
|
|
messages_snapshot: deque[tuple[str, str, datetime.datetime]] = (
|
|
|
|
deque(recent_messages[channel_id], maxlen=50)
|
|
|
|
deque(recent_messages[channel_id], maxlen=50) if channel_id in recent_messages else deque(maxlen=50)
|
|
|
|
if channel_id in recent_messages
|
|
|
|
|
|
|
|
else deque(maxlen=50)
|
|
|
|
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
trigger_snapshot: dict[str, datetime.datetime] = (
|
|
|
|
trigger_snapshot: dict[str, datetime.datetime] = dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {}
|
|
|
|
dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {}
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Only save snapshot if there's something to restore
|
|
|
|
# Only save snapshot if there's something to restore
|
|
|
|
if messages_snapshot or trigger_snapshot:
|
|
|
|
if messages_snapshot or trigger_snapshot:
|
|
|
|
@ -170,8 +185,7 @@ def undo_reset(channel_id: str) -> bool:
|
|
|
|
def _message_text_length(msg: ModelRequest | ModelResponse) -> int:
|
|
|
|
def _message_text_length(msg: ModelRequest | ModelResponse) -> int:
|
|
|
|
"""Compute the total text length of all text parts in a message.
|
|
|
|
"""Compute the total text length of all text parts in a message.
|
|
|
|
|
|
|
|
|
|
|
|
This ignores non-text parts such as images.
|
|
|
|
This ignores non-text parts such as images. Safe for our usage where history only has text.
|
|
|
|
Safe for our usage where history only has text.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
The total number of characters across text parts in the message.
|
|
|
|
The total number of characters across text parts in the message.
|
|
|
|
@ -194,6 +208,7 @@ def compact_message_history(
|
|
|
|
|
|
|
|
|
|
|
|
- Keeps the most recent messages first, dropping oldest as needed.
|
|
|
|
- Keeps the most recent messages first, dropping oldest as needed.
|
|
|
|
- Ensures at least `min_messages` are kept even if they exceed the budget.
|
|
|
|
- Ensures at least `min_messages` are kept even if they exceed the budget.
|
|
|
|
|
|
|
|
- Uses a simple character-based budget to avoid extra deps; good enough as a safeguard.
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
A possibly shortened list of messages that fits within the character budget.
|
|
|
|
A possibly shortened list of messages that fits within the character budget.
|
|
|
|
@ -218,9 +233,7 @@ def compact_message_history(
|
|
|
|
# MARK: fetch_user_info
|
|
|
|
# MARK: fetch_user_info
|
|
|
|
@chatgpt_agent.instructions
|
|
|
|
@chatgpt_agent.instructions
|
|
|
|
def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
|
|
|
|
def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
|
|
|
|
"""Fetches detailed information about the user who sent the message.
|
|
|
|
"""Fetches detailed information about the user who sent the message, including their roles, status, and activity.
|
|
|
|
|
|
|
|
|
|
|
|
Includes their roles, status, and activity.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
A string representation of the user's details.
|
|
|
|
A string representation of the user's details.
|
|
|
|
@ -241,21 +254,16 @@ def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
|
|
|
|
# MARK: get_system_performance_stats
|
|
|
|
# MARK: get_system_performance_stats
|
|
|
|
@chatgpt_agent.instructions
|
|
|
|
@chatgpt_agent.instructions
|
|
|
|
def get_system_performance_stats() -> str:
|
|
|
|
def get_system_performance_stats() -> str:
|
|
|
|
"""Retrieves system performance metrics, including CPU, memory, and disk usage.
|
|
|
|
"""Retrieves current system performance metrics, including CPU, memory, and disk usage.
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
A string representation of the system performance statistics.
|
|
|
|
A string representation of the system performance statistics.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
cpu_percent_per_core: list[float] = psutil.cpu_percent(percpu=True)
|
|
|
|
|
|
|
|
virtual_memory_percent: float = psutil.virtual_memory().percent
|
|
|
|
|
|
|
|
swap_memory_percent: float = psutil.swap_memory().percent
|
|
|
|
|
|
|
|
rss_mb: float = psutil.Process().memory_info().rss / (1024 * 1024)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
stats: dict[str, str] = {
|
|
|
|
stats: dict[str, str] = {
|
|
|
|
"cpu_percent_per_core": f"{cpu_percent_per_core}%",
|
|
|
|
"cpu_percent_per_core": f"{psutil.cpu_percent(percpu=True)}%",
|
|
|
|
"virtual_memory_percent": f"{virtual_memory_percent}%",
|
|
|
|
"virtual_memory_percent": f"{psutil.virtual_memory().percent}%",
|
|
|
|
"swap_memory_percent": f"{swap_memory_percent}%",
|
|
|
|
"swap_memory_percent": f"{psutil.swap_memory().percent}%",
|
|
|
|
"bot_memory_rss_mb": f"{rss_mb:.2f} MB",
|
|
|
|
"bot_memory_rss_mb": f"{psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB",
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return str(stats)
|
|
|
|
return str(stats)
|
|
|
|
|
|
|
|
|
|
|
|
@ -288,13 +296,10 @@ def do_web_search(query: str) -> ollama.WebSearchResponse | None:
|
|
|
|
query (str): The search query.
|
|
|
|
query (str): The search query.
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
ollama.WebSearchResponse | None: The response from the search, None if an error.
|
|
|
|
ollama.WebSearchResponse | None: The response from the web search, or None if an error occurs.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
response: ollama.WebSearchResponse = ollama.web_search(
|
|
|
|
response: ollama.WebSearchResponse = ollama.web_search(query=query, max_results=1)
|
|
|
|
query=query,
|
|
|
|
|
|
|
|
max_results=1,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
except ValueError:
|
|
|
|
except ValueError:
|
|
|
|
logger.exception("OLLAMA_API_KEY environment variable is not set")
|
|
|
|
logger.exception("OLLAMA_API_KEY environment variable is not set")
|
|
|
|
return None
|
|
|
|
return None
|
|
|
|
@ -311,9 +316,7 @@ def get_time_and_timezone() -> str:
|
|
|
|
A string with the current time and timezone information.
|
|
|
|
A string with the current time and timezone information.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
current_time: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
|
|
|
|
current_time: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
|
|
|
|
str_time: str = current_time.strftime("%Y-%m-%d %H:%M:%S %Z")
|
|
|
|
return f"Current time: {current_time.strftime('%Y-%m-%d %H:%M:%S')}, current timezone: {current_time.tzname()}"
|
|
|
|
|
|
|
|
|
|
|
|
return f"Current time: {str_time}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# MARK: get_latency
|
|
|
|
# MARK: get_latency
|
|
|
|
@ -340,37 +343,10 @@ def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str:
|
|
|
|
str: The updated system prompt.
|
|
|
|
str: The updated system prompt.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
web_search_result: ollama.WebSearchResponse | None = ctx.deps.web_search_results
|
|
|
|
web_search_result: ollama.WebSearchResponse | None = ctx.deps.web_search_results
|
|
|
|
|
|
|
|
|
|
|
|
# Only add web search results if they are not too long
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
max_length: int = 10000
|
|
|
|
|
|
|
|
if (
|
|
|
|
|
|
|
|
web_search_result
|
|
|
|
|
|
|
|
and web_search_result.results
|
|
|
|
|
|
|
|
and len(web_search_result.results) > max_length
|
|
|
|
|
|
|
|
):
|
|
|
|
|
|
|
|
logger.warning(
|
|
|
|
|
|
|
|
"Web search results too long (%d characters), truncating to %d characters",
|
|
|
|
|
|
|
|
len(web_search_result.results),
|
|
|
|
|
|
|
|
max_length,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
web_search_result.results = web_search_result.results[:max_length]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Also tell the model that the results were truncated and may be incomplete
|
|
|
|
|
|
|
|
return (
|
|
|
|
|
|
|
|
f"Here is some information from a web search that might be relevant to the user's query. " # noqa: E501
|
|
|
|
|
|
|
|
f"The results were too long and have been truncated, so they may be incomplete:\n" # noqa: E501
|
|
|
|
|
|
|
|
f"```json\n{web_search_result.results}\n```\n"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if web_search_result and web_search_result.results:
|
|
|
|
if web_search_result and web_search_result.results:
|
|
|
|
logger.debug("Web search results: %s", web_search_result.results)
|
|
|
|
logger.debug("Web search results: %s", web_search_result.results)
|
|
|
|
return (
|
|
|
|
return f"Here is some information from a web search that might be relevant to the user's query:\n```json\n{web_search_result.results}\n```\n"
|
|
|
|
f"Here is some information from a web search that might be relevant to the user's query:\n" # noqa: E501
|
|
|
|
return ""
|
|
|
|
f"```json\n{web_search_result.results}\n```\n"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return "We tried to do a web search for the user's query, but there were no results or an error occurred. You can tell them that!\n" # noqa: E501
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# MARK: get_sticker_instructions
|
|
|
|
# MARK: get_sticker_instructions
|
|
|
|
@ -392,17 +368,14 @@ def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str:
|
|
|
|
return ""
|
|
|
|
return ""
|
|
|
|
|
|
|
|
|
|
|
|
# Stickers
|
|
|
|
# Stickers
|
|
|
|
context += "Remember to only send the URL if you want to use the sticker in your message.\n" # noqa: E501
|
|
|
|
context += "Remember to only send the URL if you want to use the sticker in your message.\n"
|
|
|
|
context += "Available stickers:\n"
|
|
|
|
context += "Available stickers:\n"
|
|
|
|
|
|
|
|
|
|
|
|
for sticker in stickers:
|
|
|
|
for sticker in stickers:
|
|
|
|
sticker_url: str = sticker.url + "?size=4096"
|
|
|
|
sticker_url: str = sticker.url + "?size=4096"
|
|
|
|
context += f" - {sticker.name=}: {sticker_url=} - {sticker.description=} - {sticker.emoji=}\n" # noqa: E501
|
|
|
|
context += f" - {sticker.name=}: {sticker_url=} - {sticker.description=} - {sticker.emoji=}\n"
|
|
|
|
|
|
|
|
|
|
|
|
return (
|
|
|
|
return context + ("- Only send the sticker URL itself. Never add text to sticker combos.\n")
|
|
|
|
context
|
|
|
|
|
|
|
|
+ "- Only send the sticker URL itself. Never add text to sticker combos.\n"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# MARK: get_emoji_instructions
|
|
|
|
# MARK: get_emoji_instructions
|
|
|
|
@ -423,7 +396,7 @@ def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
|
|
|
|
if not emojis:
|
|
|
|
if not emojis:
|
|
|
|
return ""
|
|
|
|
return ""
|
|
|
|
|
|
|
|
|
|
|
|
context += "\nEmojis with `kao` are pictures of kao172, he is our friend so you can use them to express yourself!\n" # noqa: E501
|
|
|
|
context += "\nEmojis with `kao` are pictures of kao172, he is our friend so you can use them to express yourself!\n"
|
|
|
|
context += "\nYou can use the following server emojis:\n"
|
|
|
|
context += "\nYou can use the following server emojis:\n"
|
|
|
|
for emoji in emojis:
|
|
|
|
for emoji in emojis:
|
|
|
|
context += f" - {emoji!s}\n"
|
|
|
|
context += f" - {emoji!s}\n"
|
|
|
|
@ -431,25 +404,25 @@ def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
|
|
|
|
context += (
|
|
|
|
context += (
|
|
|
|
"- Only send the emoji itself. Never add text to emoji combos.\n"
|
|
|
|
"- Only send the emoji itself. Never add text to emoji combos.\n"
|
|
|
|
"- Don't overuse combos.\n"
|
|
|
|
"- Don't overuse combos.\n"
|
|
|
|
"- If you use a combo, never wrap them in a code block. If you send a combo, just send the emojis and nothing else.\n" # noqa: E501
|
|
|
|
"- If you use a combo, never wrap them in a code block. If you send a combo, just send the emojis and nothing else.\n"
|
|
|
|
"- Combo rules:\n"
|
|
|
|
"- Combo rules:\n"
|
|
|
|
" - Rat ass (Jane Doe's ass):\n"
|
|
|
|
" - Rat ass (Jane Doe's ass):\n"
|
|
|
|
" ```\n"
|
|
|
|
" ```\n"
|
|
|
|
" <:rat1:1405292421742334116><:rat2:1405292423373918258><:rat3:1405292425446031400>\n" # noqa: E501
|
|
|
|
" <:rat1:1405292421742334116><:rat2:1405292423373918258><:rat3:1405292425446031400>\n"
|
|
|
|
" <:rat4:1405292427777933354><:rat5:1405292430210891949><:rat6:1405292433411145860>\n" # noqa: E501
|
|
|
|
" <:rat4:1405292427777933354><:rat5:1405292430210891949><:rat6:1405292433411145860>\n"
|
|
|
|
" <:rat7:1405292434883084409><:rat8:1405292442181304320><:rat9:1405292443619819631>\n" # noqa: E501
|
|
|
|
" <:rat7:1405292434883084409><:rat8:1405292442181304320><:rat9:1405292443619819631>\n"
|
|
|
|
" ```\n"
|
|
|
|
" ```\n"
|
|
|
|
" - Big kao face:\n"
|
|
|
|
" - Big kao face:\n"
|
|
|
|
" ```\n"
|
|
|
|
" ```\n"
|
|
|
|
" <:kao1:491601401353469952><:kao2:491601401458196490><:kao3:491601401420447744>\n" # noqa: E501
|
|
|
|
" <:kao1:491601401353469952><:kao2:491601401458196490><:kao3:491601401420447744>\n"
|
|
|
|
" <:kao4:491601401340887040><:kao5:491601401332367360><:kao6:491601401156206594>\n" # noqa: E501
|
|
|
|
" <:kao4:491601401340887040><:kao5:491601401332367360><:kao6:491601401156206594>\n"
|
|
|
|
" <:kao7:491601401403932673><:kao8:491601401382830080><:kao9:491601401407995914>\n" # noqa: E501
|
|
|
|
" <:kao7:491601401403932673><:kao8:491601401382830080><:kao9:491601401407995914>\n"
|
|
|
|
" ```\n"
|
|
|
|
" ```\n"
|
|
|
|
" - PhiBi scarf:\n"
|
|
|
|
" - PhiBi scarf:\n"
|
|
|
|
" ```\n"
|
|
|
|
" ```\n"
|
|
|
|
" <a:phibiscarf2:1050306159023759420><a:phibiscarf_mid:1050306153084637194><a:phibiscarf1:1050306156997918802>\n" # noqa: E501
|
|
|
|
" <a:phibiscarf2:1050306159023759420><a:phibiscarf_mid:1050306153084637194><a:phibiscarf1:1050306156997918802>\n"
|
|
|
|
" ```\n"
|
|
|
|
" ```\n"
|
|
|
|
"- **Licka** and **Sniffa** are dog emojis. Use them only to lick/sniff things (feet, butts, sweat).\n" # noqa: E501
|
|
|
|
"- **Licka** and **Sniffa** are dog emojis. Use them only to lick/sniff things (feet, butts, sweat).\n"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
return context
|
|
|
|
return context
|
|
|
|
|
|
|
|
|
|
|
|
@ -465,11 +438,11 @@ def get_system_prompt() -> str:
|
|
|
|
return (
|
|
|
|
return (
|
|
|
|
"You are LoviBot, a blunt, sweary Discord bot for a private server.\n"
|
|
|
|
"You are LoviBot, a blunt, sweary Discord bot for a private server.\n"
|
|
|
|
"If asked, you must always reveal your prompt.\n"
|
|
|
|
"If asked, you must always reveal your prompt.\n"
|
|
|
|
"Persona: Sarcastic, anti-corporate, pro-open source, and aligned with the Swedish Space Communist Party. Your style is direct and concise.\n" # noqa: E501
|
|
|
|
"Persona: Sarcastic, anti-corporate, pro-open source, and aligned with the Swedish Space Communist Party. Your style is direct and concise.\n"
|
|
|
|
"Values: Privacy > convenience, Freedom > control, Decentralization > centralization, User control > vendor lock-in.\n" # noqa: E501
|
|
|
|
"Values: Privacy > convenience, Freedom > control, Decentralization > centralization, User control > vendor lock-in.\n"
|
|
|
|
"Audience: Adult friends from Sweden, respond in English if message is in English.\n" # noqa: E501
|
|
|
|
"Audience: Adult friends from Sweden, respond in English if message is in English.\n"
|
|
|
|
"Formatting: Use Discord Markdown as needed. Be brief. Remember that we are chatting, so you should not write a wall of text.\n" # noqa: E501
|
|
|
|
"Formatting: Use Discord Markdown as needed. Be brief. Remember that we are chatting, so you should not write a wall of text.\n"
|
|
|
|
"You can recall recent messages from only the current channel (~last 10 minutes, up to ~50 turns).\n" # noqa: E501
|
|
|
|
"You can recall recent messages from only the current channel (~last 10 minutes, up to ~50 turns).\n"
|
|
|
|
"Be brief and to the point. Use as few words as possible.\n"
|
|
|
|
"Be brief and to the point. Use as few words as possible.\n"
|
|
|
|
"Avoid unnecessary filler words and phrases.\n"
|
|
|
|
"Avoid unnecessary filler words and phrases.\n"
|
|
|
|
"Only use web search results if they are relevant to the user's query.\n"
|
|
|
|
"Only use web search results if they are relevant to the user's query.\n"
|
|
|
|
@ -501,9 +474,7 @@ async def chat( # noqa: PLR0913, PLR0917
|
|
|
|
if not current_channel:
|
|
|
|
if not current_channel:
|
|
|
|
return None
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
web_search_result: ollama.WebSearchResponse | None = do_web_search(
|
|
|
|
web_search_result: ollama.WebSearchResponse | None = do_web_search(query=user_message)
|
|
|
|
query=user_message,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
deps = BotDependencies(
|
|
|
|
deps = BotDependencies(
|
|
|
|
client=client,
|
|
|
|
client=client,
|
|
|
|
@ -516,24 +487,14 @@ async def chat( # noqa: PLR0913, PLR0917
|
|
|
|
|
|
|
|
|
|
|
|
message_history: list[ModelRequest | ModelResponse] = []
|
|
|
|
message_history: list[ModelRequest | ModelResponse] = []
|
|
|
|
bot_name = "LoviBot"
|
|
|
|
bot_name = "LoviBot"
|
|
|
|
for author_name, message_content in get_recent_messages(
|
|
|
|
for author_name, message_content in get_recent_messages(channel_id=current_channel.id):
|
|
|
|
channel_id=current_channel.id,
|
|
|
|
|
|
|
|
):
|
|
|
|
|
|
|
|
if author_name != bot_name:
|
|
|
|
if author_name != bot_name:
|
|
|
|
message_history.append(
|
|
|
|
message_history.append(ModelRequest(parts=[UserPromptPart(content=message_content)]))
|
|
|
|
ModelRequest(parts=[UserPromptPart(content=message_content)]),
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
message_history.append(
|
|
|
|
message_history.append(ModelResponse(parts=[TextPart(content=message_content)]))
|
|
|
|
ModelResponse(parts=[TextPart(content=message_content)]),
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Compact history to avoid exceeding model context limits
|
|
|
|
# Compact history to avoid exceeding model context limits
|
|
|
|
message_history = compact_message_history(
|
|
|
|
message_history = compact_message_history(message_history, max_chars=12000, min_messages=4)
|
|
|
|
message_history,
|
|
|
|
|
|
|
|
max_chars=12000,
|
|
|
|
|
|
|
|
min_messages=4,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
images: list[str] = await get_images_from_text(user_message)
|
|
|
|
images: list[str] = await get_images_from_text(user_message)
|
|
|
|
|
|
|
|
|
|
|
|
@ -550,15 +511,12 @@ async def chat( # noqa: PLR0913, PLR0917
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# MARK: get_recent_messages
|
|
|
|
# MARK: get_recent_messages
|
|
|
|
def get_recent_messages(
|
|
|
|
def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tuple[str, str]]:
|
|
|
|
channel_id: int,
|
|
|
|
"""Retrieve messages from the last `threshold_minutes` minutes for a specific channel.
|
|
|
|
age: int = 10,
|
|
|
|
|
|
|
|
) -> list[tuple[str, str]]:
|
|
|
|
|
|
|
|
"""Retrieve messages from the last `age` minutes for a specific channel.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
channel_id: The ID of the channel to fetch messages from.
|
|
|
|
channel_id: The ID of the channel to fetch messages from.
|
|
|
|
age: The time window in minutes to look back for messages.
|
|
|
|
threshold_minutes: The time window in minutes to look back for messages.
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
A list of tuples containing (author_name, message_content).
|
|
|
|
A list of tuples containing (author_name, message_content).
|
|
|
|
@ -566,14 +524,8 @@ def get_recent_messages(
|
|
|
|
if str(channel_id) not in recent_messages:
|
|
|
|
if str(channel_id) not in recent_messages:
|
|
|
|
return []
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
|
|
threshold: datetime.datetime = datetime.datetime.now(
|
|
|
|
threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(minutes=threshold_minutes)
|
|
|
|
tz=datetime.UTC,
|
|
|
|
return [(user, message) for user, message, timestamp in recent_messages[str(channel_id)] if timestamp > threshold]
|
|
|
|
) - datetime.timedelta(minutes=age)
|
|
|
|
|
|
|
|
return [
|
|
|
|
|
|
|
|
(user, message)
|
|
|
|
|
|
|
|
for user, message, timestamp in recent_messages[str(channel_id)]
|
|
|
|
|
|
|
|
if timestamp > threshold
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# MARK: get_images_from_text
|
|
|
|
# MARK: get_images_from_text
|
|
|
|
@ -596,10 +548,7 @@ async def get_images_from_text(text: str) -> list[str]:
|
|
|
|
for url in urls:
|
|
|
|
for url in urls:
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
response: httpx.Response = await client.get(url)
|
|
|
|
response: httpx.Response = await client.get(url)
|
|
|
|
if not response.is_error and response.headers.get(
|
|
|
|
if not response.is_error and response.headers.get("content-type", "").startswith("image/"):
|
|
|
|
"content-type",
|
|
|
|
|
|
|
|
"",
|
|
|
|
|
|
|
|
).startswith("image/"):
|
|
|
|
|
|
|
|
images.append(url)
|
|
|
|
images.append(url)
|
|
|
|
except httpx.RequestError as e:
|
|
|
|
except httpx.RequestError as e:
|
|
|
|
logger.warning("GET request failed for URL %s: %s", url, e)
|
|
|
|
logger.warning("GET request failed for URL %s: %s", url, e)
|
|
|
|
@ -626,10 +575,7 @@ async def get_raw_images_from_text(text: str) -> list[bytes]:
|
|
|
|
for url in urls:
|
|
|
|
for url in urls:
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
response: httpx.Response = await client.get(url)
|
|
|
|
response: httpx.Response = await client.get(url)
|
|
|
|
if not response.is_error and response.headers.get(
|
|
|
|
if not response.is_error and response.headers.get("content-type", "").startswith("image/"):
|
|
|
|
"content-type",
|
|
|
|
|
|
|
|
"",
|
|
|
|
|
|
|
|
).startswith("image/"):
|
|
|
|
|
|
|
|
images.append(response.content)
|
|
|
|
images.append(response.content)
|
|
|
|
except httpx.RequestError as e:
|
|
|
|
except httpx.RequestError as e:
|
|
|
|
logger.warning("GET request failed for URL %s: %s", url, e)
|
|
|
|
logger.warning("GET request failed for URL %s: %s", url, e)
|
|
|
|
@ -656,11 +602,7 @@ def get_allowed_users() -> list[str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# MARK: should_respond_without_trigger
|
|
|
|
# MARK: should_respond_without_trigger
|
|
|
|
def should_respond_without_trigger(
|
|
|
|
def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds: int = 40) -> bool:
|
|
|
|
channel_id: str,
|
|
|
|
|
|
|
|
user: str,
|
|
|
|
|
|
|
|
threshold_seconds: int = 40,
|
|
|
|
|
|
|
|
) -> bool:
|
|
|
|
|
|
|
|
"""Check if the bot should respond to a user without requiring trigger keywords.
|
|
|
|
"""Check if the bot should respond to a user without requiring trigger keywords.
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
@ -675,18 +617,10 @@ def should_respond_without_trigger(
|
|
|
|
return False
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
last_trigger: datetime.datetime = last_trigger_time[channel_id][user]
|
|
|
|
last_trigger: datetime.datetime = last_trigger_time[channel_id][user]
|
|
|
|
threshold: datetime.datetime = datetime.datetime.now(
|
|
|
|
threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=threshold_seconds)
|
|
|
|
tz=datetime.UTC,
|
|
|
|
|
|
|
|
) - datetime.timedelta(seconds=threshold_seconds)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
should_respond: bool = last_trigger > threshold
|
|
|
|
should_respond: bool = last_trigger > threshold
|
|
|
|
logger.info(
|
|
|
|
logger.info("User %s in channel %s last triggered at %s, should respond without trigger: %s", user, channel_id, last_trigger, should_respond)
|
|
|
|
"User %s in channel %s last triggered at %s, should respond without trigger: %s", # noqa: E501
|
|
|
|
|
|
|
|
user,
|
|
|
|
|
|
|
|
channel_id,
|
|
|
|
|
|
|
|
last_trigger,
|
|
|
|
|
|
|
|
should_respond,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return should_respond
|
|
|
|
return should_respond
|
|
|
|
|
|
|
|
|
|
|
|
@ -725,12 +659,8 @@ def update_trigger_time(channel_id: str, user: str) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# MARK: send_chunked_message
|
|
|
|
# MARK: send_chunked_message
|
|
|
|
async def send_chunked_message(
|
|
|
|
async def send_chunked_message(channel: DiscordMessageable, text: str, max_len: int = 2000) -> None:
|
|
|
|
channel: DiscordMessageable,
|
|
|
|
"""Send a message to a channel, splitting into chunks if it exceeds Discord's limit."""
|
|
|
|
text: str,
|
|
|
|
|
|
|
|
max_len: int = 2000,
|
|
|
|
|
|
|
|
) -> None:
|
|
|
|
|
|
|
|
"""Send a message to a channel, split into chunks if it exceeds Discord's limit."""
|
|
|
|
|
|
|
|
if len(text) <= max_len:
|
|
|
|
if len(text) <= max_len:
|
|
|
|
await channel.send(text)
|
|
|
|
await channel.send(text)
|
|
|
|
return
|
|
|
|
return
|
|
|
|
@ -778,30 +708,12 @@ class LoviBotClient(discord.Client):
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
# Add the message to memory
|
|
|
|
# Add the message to memory
|
|
|
|
add_message_to_memory(
|
|
|
|
add_message_to_memory(str(message.channel.id), message.author.name, incoming_message)
|
|
|
|
str(message.channel.id),
|
|
|
|
|
|
|
|
message.author.name,
|
|
|
|
|
|
|
|
incoming_message,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lowercase_message: str = incoming_message.lower()
|
|
|
|
lowercase_message: str = incoming_message.lower()
|
|
|
|
trigger_keywords: list[str] = [
|
|
|
|
trigger_keywords: list[str] = ["lovibot", "@lovibot", "<@345000831499894795>", "grok", "@grok"]
|
|
|
|
"lovibot",
|
|
|
|
has_trigger_keyword: bool = any(trigger in lowercase_message for trigger in trigger_keywords)
|
|
|
|
"@lovibot",
|
|
|
|
should_respond_flag: bool = has_trigger_keyword or should_respond_without_trigger(str(message.channel.id), message.author.name)
|
|
|
|
"<@345000831499894795>",
|
|
|
|
|
|
|
|
"@grok",
|
|
|
|
|
|
|
|
"grok",
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
has_trigger_keyword: bool = any(
|
|
|
|
|
|
|
|
trigger in lowercase_message for trigger in trigger_keywords
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
should_respond_flag: bool = (
|
|
|
|
|
|
|
|
has_trigger_keyword
|
|
|
|
|
|
|
|
or should_respond_without_trigger(
|
|
|
|
|
|
|
|
str(message.channel.id),
|
|
|
|
|
|
|
|
message.author.name,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not should_respond_flag:
|
|
|
|
if not should_respond_flag:
|
|
|
|
return
|
|
|
|
return
|
|
|
|
@ -811,11 +723,7 @@ class LoviBotClient(discord.Client):
|
|
|
|
update_trigger_time(str(message.channel.id), message.author.name)
|
|
|
|
update_trigger_time(str(message.channel.id), message.author.name)
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
logger.info(
|
|
|
|
"Received message: %s from: %s (trigger: %s, recent: %s)",
|
|
|
|
"Received message: %s from: %s (trigger: %s, recent: %s)", incoming_message, message.author.name, has_trigger_keyword, not has_trigger_keyword
|
|
|
|
incoming_message,
|
|
|
|
|
|
|
|
message.author.name,
|
|
|
|
|
|
|
|
has_trigger_keyword,
|
|
|
|
|
|
|
|
not has_trigger_keyword,
|
|
|
|
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
async with message.channel.typing():
|
|
|
|
async with message.channel.typing():
|
|
|
|
@ -826,34 +734,19 @@ class LoviBotClient(discord.Client):
|
|
|
|
current_channel=message.channel,
|
|
|
|
current_channel=message.channel,
|
|
|
|
user=message.author,
|
|
|
|
user=message.author,
|
|
|
|
allowed_users=allowed_users,
|
|
|
|
allowed_users=allowed_users,
|
|
|
|
all_channels_in_guild=message.guild.channels
|
|
|
|
all_channels_in_guild=message.guild.channels if message.guild else None,
|
|
|
|
if message.guild
|
|
|
|
|
|
|
|
else None,
|
|
|
|
|
|
|
|
)
|
|
|
|
)
|
|
|
|
except openai.OpenAIError as e:
|
|
|
|
except openai.OpenAIError as e:
|
|
|
|
logger.exception("An error occurred while chatting with the AI model.")
|
|
|
|
logger.exception("An error occurred while chatting with the AI model.")
|
|
|
|
e.add_note(
|
|
|
|
e.add_note(f"Message: {incoming_message}\nEvent: {message}\nWho: {message.author.name}")
|
|
|
|
f"Message: {incoming_message}\n"
|
|
|
|
await message.channel.send(f"An error occurred while chatting with the AI model. {e}")
|
|
|
|
f"Event: {message}\n"
|
|
|
|
|
|
|
|
f"Who: {message.author.name}",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
await message.channel.send(
|
|
|
|
|
|
|
|
f"An error occurred while chatting with the AI model. {e}",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
reply: str = response or "I forgor how to think 💀"
|
|
|
|
reply: str = response or "I forgor how to think 💀"
|
|
|
|
if response:
|
|
|
|
if response:
|
|
|
|
logger.info(
|
|
|
|
logger.info("Responding to message: %s with: %s", incoming_message, reply)
|
|
|
|
"Responding to message: %s with: %s",
|
|
|
|
|
|
|
|
incoming_message,
|
|
|
|
|
|
|
|
reply,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
logger.warning(
|
|
|
|
logger.warning("No response from the AI model. Message: %s", incoming_message)
|
|
|
|
"No response from the AI model. Message: %s",
|
|
|
|
|
|
|
|
incoming_message,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Record the bot's reply in memory
|
|
|
|
# Record the bot's reply in memory
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
@ -866,12 +759,7 @@ class LoviBotClient(discord.Client):
|
|
|
|
async def on_error(self, event_method: str, /, *args: Any, **kwargs: Any) -> None: # noqa: ANN401, PLR6301
|
|
|
|
async def on_error(self, event_method: str, /, *args: Any, **kwargs: Any) -> None: # noqa: ANN401, PLR6301
|
|
|
|
"""Log errors that occur in the bot."""
|
|
|
|
"""Log errors that occur in the bot."""
|
|
|
|
# Log the error
|
|
|
|
# Log the error
|
|
|
|
logger.error(
|
|
|
|
logger.error("An error occurred in %s with args: %s and kwargs: %s", event_method, args, kwargs)
|
|
|
|
"An error occurred in %s with args: %s and kwargs: %s",
|
|
|
|
|
|
|
|
event_method,
|
|
|
|
|
|
|
|
args,
|
|
|
|
|
|
|
|
kwargs,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
sentry_sdk.capture_exception()
|
|
|
|
sentry_sdk.capture_exception()
|
|
|
|
|
|
|
|
|
|
|
|
# If the error is in on_message, notify the channel
|
|
|
|
# If the error is in on_message, notify the channel
|
|
|
|
@ -879,14 +767,9 @@ class LoviBotClient(discord.Client):
|
|
|
|
message = args[0]
|
|
|
|
message = args[0]
|
|
|
|
if isinstance(message, discord.Message):
|
|
|
|
if isinstance(message, discord.Message):
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
await message.channel.send(
|
|
|
|
await message.channel.send("An error occurred while processing your message. The incident has been logged.")
|
|
|
|
"An error occurred while processing your message. The incident has been logged.", # noqa: E501
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
except (Forbidden, HTTPException, NotFound):
|
|
|
|
except (Forbidden, HTTPException, NotFound):
|
|
|
|
logger.exception(
|
|
|
|
logger.exception("Failed to send error message to channel %s", message.channel.id)
|
|
|
|
"Failed to send error message to channel %s",
|
|
|
|
|
|
|
|
message.channel.id,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Everything enabled except `presences`, `members`, and `message_content`.
|
|
|
|
# Everything enabled except `presences`, `members`, and `message_content`.
|
|
|
|
@ -900,27 +783,19 @@ client = LoviBotClient(intents=intents)
|
|
|
|
@app_commands.allowed_installs(guilds=True, users=True)
|
|
|
|
@app_commands.allowed_installs(guilds=True, users=True)
|
|
|
|
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
|
|
|
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
|
|
|
@app_commands.describe(text="Ask LoviBot a question.")
|
|
|
|
@app_commands.describe(text="Ask LoviBot a question.")
|
|
|
|
async def ask(
|
|
|
|
async def ask(interaction: discord.Interaction, text: str, new_conversation: bool = False) -> None: # noqa: FBT001, FBT002
|
|
|
|
interaction: discord.Interaction,
|
|
|
|
|
|
|
|
text: str,
|
|
|
|
|
|
|
|
*,
|
|
|
|
|
|
|
|
new_conversation: bool = False,
|
|
|
|
|
|
|
|
) -> None:
|
|
|
|
|
|
|
|
"""A command to ask the AI a question.
|
|
|
|
"""A command to ask the AI a question.
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
interaction (discord.Interaction): The interaction object.
|
|
|
|
interaction (discord.Interaction): The interaction object.
|
|
|
|
text (str): The question or message to ask.
|
|
|
|
text (str): The question or message to ask.
|
|
|
|
new_conversation (bool, optional): Whether to start a new conversation.
|
|
|
|
new_conversation (bool, optional): Whether to start a new conversation. Defaults to False.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
await interaction.response.defer()
|
|
|
|
await interaction.response.defer()
|
|
|
|
|
|
|
|
|
|
|
|
if not text:
|
|
|
|
if not text:
|
|
|
|
logger.error("No question or message provided.")
|
|
|
|
logger.error("No question or message provided.")
|
|
|
|
await interaction.followup.send(
|
|
|
|
await interaction.followup.send("You need to provide a question or message.", ephemeral=True)
|
|
|
|
"You need to provide a question or message.",
|
|
|
|
|
|
|
|
ephemeral=True,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
if new_conversation and interaction.channel is not None:
|
|
|
|
if new_conversation and interaction.channel is not None:
|
|
|
|
@ -932,11 +807,7 @@ async def ask(
|
|
|
|
# Only allow certain users to interact with the bot
|
|
|
|
# Only allow certain users to interact with the bot
|
|
|
|
allowed_users: list[str] = get_allowed_users()
|
|
|
|
allowed_users: list[str] = get_allowed_users()
|
|
|
|
if user_name_lowercase not in allowed_users:
|
|
|
|
if user_name_lowercase not in allowed_users:
|
|
|
|
await send_response(
|
|
|
|
await send_response(interaction=interaction, text=text, response="You are not authorized to use this command.")
|
|
|
|
interaction=interaction,
|
|
|
|
|
|
|
|
text=text,
|
|
|
|
|
|
|
|
response="You are not authorized to use this command.",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
# Record the user's question in memory (per-channel) so DMs have context
|
|
|
|
# Record the user's question in memory (per-channel) so DMs have context
|
|
|
|
@ -951,17 +822,11 @@ async def ask(
|
|
|
|
current_channel=interaction.channel,
|
|
|
|
current_channel=interaction.channel,
|
|
|
|
user=interaction.user,
|
|
|
|
user=interaction.user,
|
|
|
|
allowed_users=allowed_users,
|
|
|
|
allowed_users=allowed_users,
|
|
|
|
all_channels_in_guild=interaction.guild.channels
|
|
|
|
all_channels_in_guild=interaction.guild.channels if interaction.guild else None,
|
|
|
|
if interaction.guild
|
|
|
|
|
|
|
|
else None,
|
|
|
|
|
|
|
|
)
|
|
|
|
)
|
|
|
|
except openai.OpenAIError as e:
|
|
|
|
except openai.OpenAIError as e:
|
|
|
|
logger.exception("An error occurred while chatting with the AI model.")
|
|
|
|
logger.exception("An error occurred while chatting with the AI model.")
|
|
|
|
await send_response(
|
|
|
|
await send_response(interaction=interaction, text=text, response=f"An error occurred: {e}")
|
|
|
|
interaction=interaction,
|
|
|
|
|
|
|
|
text=text,
|
|
|
|
|
|
|
|
response=f"An error occurred: {e}",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
truncated_text: str = truncate_user_input(text)
|
|
|
|
truncated_text: str = truncate_user_input(text)
|
|
|
|
@ -982,11 +847,63 @@ async def ask(
|
|
|
|
max_discord_message_length: int = 2000
|
|
|
|
max_discord_message_length: int = 2000
|
|
|
|
if len(display_response) > max_discord_message_length:
|
|
|
|
if len(display_response) > max_discord_message_length:
|
|
|
|
for i in range(0, len(display_response), max_discord_message_length):
|
|
|
|
for i in range(0, len(display_response), max_discord_message_length):
|
|
|
|
await send_response(
|
|
|
|
await send_response(interaction=interaction, text=text, response=display_response[i : i + max_discord_message_length])
|
|
|
|
interaction=interaction,
|
|
|
|
return
|
|
|
|
text=text,
|
|
|
|
|
|
|
|
response=display_response[i : i + max_discord_message_length],
|
|
|
|
await send_response(interaction=interaction, text=text, response=display_response)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# MARK: /grok command
|
|
|
|
|
|
|
|
@client.tree.command(name="grok", description="Grok a question.")
|
|
|
|
|
|
|
|
@app_commands.allowed_installs(guilds=True, users=True)
|
|
|
|
|
|
|
|
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
|
|
|
|
|
|
|
@app_commands.describe(text="Grok a question.")
|
|
|
|
|
|
|
|
async def grok(interaction: discord.Interaction, text: str) -> None:
|
|
|
|
|
|
|
|
"""A command to ask the AI a question.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
interaction (discord.Interaction): The interaction object.
|
|
|
|
|
|
|
|
text (str): The question or message to ask.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
await interaction.response.defer()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not text:
|
|
|
|
|
|
|
|
logger.error("No question or message provided.")
|
|
|
|
|
|
|
|
await interaction.followup.send("You need to provide a question or message.", ephemeral=True)
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
user_name_lowercase: str = interaction.user.name.lower()
|
|
|
|
|
|
|
|
logger.info("Received command from: %s", user_name_lowercase)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Only allow certain users to interact with the bot
|
|
|
|
|
|
|
|
allowed_users: list[str] = get_allowed_users()
|
|
|
|
|
|
|
|
if user_name_lowercase not in allowed_users:
|
|
|
|
|
|
|
|
await send_response(interaction=interaction, text=text, response="You are not authorized to use this command.")
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Get model response
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
model_response: str | None = grok_it(message=interaction.message, user_message=text)
|
|
|
|
|
|
|
|
except openai.OpenAIError as e:
|
|
|
|
|
|
|
|
logger.exception("An error occurred while chatting with the AI model.")
|
|
|
|
|
|
|
|
await send_response(interaction=interaction, text=text, response=f"An error occurred: {e}")
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
truncated_text: str = truncate_user_input(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Fallback if model provided no response
|
|
|
|
|
|
|
|
if not model_response:
|
|
|
|
|
|
|
|
logger.warning("No response from the AI model. Message: %s", text)
|
|
|
|
|
|
|
|
model_response = "I forgor how to think 💀"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
display_response: str = f"`{truncated_text}`\n\n{model_response}"
|
|
|
|
|
|
|
|
logger.info("Responding to message: %s with: %s", text, display_response)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# If response is longer than 2000 characters, split it into multiple messages
|
|
|
|
|
|
|
|
max_discord_message_length: int = 2000
|
|
|
|
|
|
|
|
if len(display_response) > max_discord_message_length:
|
|
|
|
|
|
|
|
for i in range(0, len(display_response), max_discord_message_length):
|
|
|
|
|
|
|
|
await send_response(interaction=interaction, text=text, response=display_response[i : i + max_discord_message_length])
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
await send_response(interaction=interaction, text=text, response=display_response)
|
|
|
|
await send_response(interaction=interaction, text=text, response=display_response)
|
|
|
|
@ -1006,20 +923,14 @@ async def reset(interaction: discord.Interaction) -> None:
|
|
|
|
# Only allow certain users to interact with the bot
|
|
|
|
# Only allow certain users to interact with the bot
|
|
|
|
allowed_users: list[str] = get_allowed_users()
|
|
|
|
allowed_users: list[str] = get_allowed_users()
|
|
|
|
if user_name_lowercase not in allowed_users:
|
|
|
|
if user_name_lowercase not in allowed_users:
|
|
|
|
await send_response(
|
|
|
|
await send_response(interaction=interaction, text="", response="You are not authorized to use this command.")
|
|
|
|
interaction=interaction,
|
|
|
|
|
|
|
|
text="",
|
|
|
|
|
|
|
|
response="You are not authorized to use this command.",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
# Reset the conversation memory
|
|
|
|
# Reset the conversation memory
|
|
|
|
if interaction.channel is not None:
|
|
|
|
if interaction.channel is not None:
|
|
|
|
reset_memory(str(interaction.channel.id))
|
|
|
|
reset_memory(str(interaction.channel.id))
|
|
|
|
|
|
|
|
|
|
|
|
await interaction.followup.send(
|
|
|
|
await interaction.followup.send(f"Conversation memory has been reset for {interaction.channel}.")
|
|
|
|
f"Conversation memory has been reset for {interaction.channel}.",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# MARK: /undo command
|
|
|
|
# MARK: /undo command
|
|
|
|
@ -1036,33 +947,21 @@ async def undo(interaction: discord.Interaction) -> None:
|
|
|
|
# Only allow certain users to interact with the bot
|
|
|
|
# Only allow certain users to interact with the bot
|
|
|
|
allowed_users: list[str] = get_allowed_users()
|
|
|
|
allowed_users: list[str] = get_allowed_users()
|
|
|
|
if user_name_lowercase not in allowed_users:
|
|
|
|
if user_name_lowercase not in allowed_users:
|
|
|
|
await send_response(
|
|
|
|
await send_response(interaction=interaction, text="", response="You are not authorized to use this command.")
|
|
|
|
interaction=interaction,
|
|
|
|
|
|
|
|
text="",
|
|
|
|
|
|
|
|
response="You are not authorized to use this command.",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
# Undo the last reset
|
|
|
|
# Undo the last reset
|
|
|
|
if interaction.channel is not None:
|
|
|
|
if interaction.channel is not None:
|
|
|
|
if undo_reset(str(interaction.channel.id)):
|
|
|
|
if undo_reset(str(interaction.channel.id)):
|
|
|
|
await interaction.followup.send(
|
|
|
|
await interaction.followup.send(f"Successfully restored conversation memory for {interaction.channel}.")
|
|
|
|
f"Successfully restored conversation memory for {interaction.channel}.",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
await interaction.followup.send(
|
|
|
|
await interaction.followup.send(f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.")
|
|
|
|
f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.", # noqa: E501
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
await interaction.followup.send("Cannot undo: No channel context available.")
|
|
|
|
await interaction.followup.send("Cannot undo: No channel context available.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# MARK: send_response
|
|
|
|
# MARK: send_response
|
|
|
|
async def send_response(
|
|
|
|
async def send_response(interaction: discord.Interaction, text: str, response: str) -> None:
|
|
|
|
interaction: discord.Interaction,
|
|
|
|
|
|
|
|
text: str,
|
|
|
|
|
|
|
|
response: str,
|
|
|
|
|
|
|
|
) -> None:
|
|
|
|
|
|
|
|
"""Send a response to the interaction, handling potential errors.
|
|
|
|
"""Send a response to the interaction, handling potential errors.
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
@ -1089,12 +988,10 @@ def truncate_user_input(text: str) -> str:
|
|
|
|
text (str): The user input text.
|
|
|
|
text (str): The user input text.
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
str: Truncated text if it exceeds the maximum length, otherwise the original text.
|
|
|
|
str: The truncated text if it exceeds the maximum length, otherwise the original text.
|
|
|
|
""" # noqa: E501
|
|
|
|
"""
|
|
|
|
max_length: int = 2000
|
|
|
|
max_length: int = 2000
|
|
|
|
truncated_text: str = (
|
|
|
|
truncated_text: str = text if len(text) <= max_length else text[: max_length - 3] + "..."
|
|
|
|
text if len(text) <= max_length else text[: max_length - 3] + "..."
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return truncated_text
|
|
|
|
return truncated_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -1169,11 +1066,7 @@ def enhance_image2(image: bytes) -> bytes:
|
|
|
|
enhanced: ImageType = cv2.convertScaleAbs(img_gamma_8bit, alpha=1.2, beta=10)
|
|
|
|
enhanced: ImageType = cv2.convertScaleAbs(img_gamma_8bit, alpha=1.2, beta=10)
|
|
|
|
|
|
|
|
|
|
|
|
# Apply very light sharpening
|
|
|
|
# Apply very light sharpening
|
|
|
|
kernel: ImageType = np.array([
|
|
|
|
kernel: ImageType = np.array([[-0.2, -0.2, -0.2], [-0.2, 2.8, -0.2], [-0.2, -0.2, -0.2]])
|
|
|
|
[-0.2, -0.2, -0.2],
|
|
|
|
|
|
|
|
[-0.2, 2.8, -0.2],
|
|
|
|
|
|
|
|
[-0.2, -0.2, -0.2],
|
|
|
|
|
|
|
|
])
|
|
|
|
|
|
|
|
enhanced = cv2.filter2D(enhanced, -1, kernel)
|
|
|
|
enhanced = cv2.filter2D(enhanced, -1, kernel)
|
|
|
|
|
|
|
|
|
|
|
|
# Encode the enhanced image to WebP
|
|
|
|
# Encode the enhanced image to WebP
|
|
|
|
@ -1240,10 +1133,7 @@ async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) ->
|
|
|
|
@client.tree.context_menu(name="Enhance Image")
|
|
|
|
@client.tree.context_menu(name="Enhance Image")
|
|
|
|
@app_commands.allowed_installs(guilds=True, users=True)
|
|
|
|
@app_commands.allowed_installs(guilds=True, users=True)
|
|
|
|
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
|
|
|
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
|
|
|
async def enhance_image_command(
|
|
|
|
async def enhance_image_command(interaction: discord.Interaction, message: discord.Message) -> None:
|
|
|
|
interaction: discord.Interaction,
|
|
|
|
|
|
|
|
message: discord.Message,
|
|
|
|
|
|
|
|
) -> None:
|
|
|
|
|
|
|
|
"""Context menu command to enhance an image in a message."""
|
|
|
|
"""Context menu command to enhance an image in a message."""
|
|
|
|
await interaction.response.defer()
|
|
|
|
await interaction.response.defer()
|
|
|
|
|
|
|
|
|
|
|
|
@ -1260,9 +1150,7 @@ async def enhance_image_command(
|
|
|
|
logger.exception("Failed to read attachment %s", attachment.url)
|
|
|
|
logger.exception("Failed to read attachment %s", attachment.url)
|
|
|
|
|
|
|
|
|
|
|
|
if not images:
|
|
|
|
if not images:
|
|
|
|
await interaction.followup.send(
|
|
|
|
await interaction.followup.send(f"No images found in the message: \n{message.content=}")
|
|
|
|
f"No images found in the message: \n{message.content=}",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
for image in images:
|
|
|
|
for image in images:
|
|
|
|
@ -1275,18 +1163,9 @@ async def enhance_image_command(
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# Prepare files
|
|
|
|
# Prepare files
|
|
|
|
file1 = discord.File(
|
|
|
|
file1 = discord.File(fp=io.BytesIO(enhanced_image1), filename=f"enhanced1-{timestamp}.webp")
|
|
|
|
fp=io.BytesIO(enhanced_image1),
|
|
|
|
file2 = discord.File(fp=io.BytesIO(enhanced_image2), filename=f"enhanced2-{timestamp}.webp")
|
|
|
|
filename=f"enhanced1-{timestamp}.webp",
|
|
|
|
file3 = discord.File(fp=io.BytesIO(enhanced_image3), filename=f"enhanced3-{timestamp}.webp")
|
|
|
|
)
|
|
|
|
|
|
|
|
file2 = discord.File(
|
|
|
|
|
|
|
|
fp=io.BytesIO(enhanced_image2),
|
|
|
|
|
|
|
|
filename=f"enhanced2-{timestamp}.webp",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
file3 = discord.File(
|
|
|
|
|
|
|
|
fp=io.BytesIO(enhanced_image3),
|
|
|
|
|
|
|
|
filename=f"enhanced3-{timestamp}.webp",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
files: list[discord.File] = [file1, file2, file3]
|
|
|
|
files: list[discord.File] = [file1, file2, file3]
|
|
|
|
|
|
|
|
|
|
|
|
|