Add Ollama API integration and enhance bot functionality
- Updated .env.example to include OLLAMA_API_KEY. - Added Ollama to dependencies in pyproject.toml. - Refactored main.py to incorporate Ollama for web search capabilities. - Removed misc.py as its functionality has been integrated into main.py. - Enhanced message handling and memory management for improved performance.
This commit is contained in:
parent
eec1ed4f59
commit
68e74ca6a3
5 changed files with 524 additions and 478 deletions
|
|
@ -1,2 +1,3 @@
|
||||||
DISCORD_TOKEN=
|
DISCORD_TOKEN=
|
||||||
OPENAI_TOKEN=
|
OPENAI_TOKEN=
|
||||||
|
OLLAMA_API_KEY=
|
||||||
2
.vscode/settings.json
vendored
2
.vscode/settings.json
vendored
|
|
@ -34,6 +34,7 @@
|
||||||
"nobot",
|
"nobot",
|
||||||
"nparr",
|
"nparr",
|
||||||
"numpy",
|
"numpy",
|
||||||
|
"Ollama",
|
||||||
"opencv",
|
"opencv",
|
||||||
"percpu",
|
"percpu",
|
||||||
"phibiscarf",
|
"phibiscarf",
|
||||||
|
|
@ -48,6 +49,7 @@
|
||||||
"sweary",
|
"sweary",
|
||||||
"testpaths",
|
"testpaths",
|
||||||
"thelovinator",
|
"thelovinator",
|
||||||
|
"Thicc",
|
||||||
"tobytes",
|
"tobytes",
|
||||||
"twimg",
|
"twimg",
|
||||||
"unsignedinteger",
|
"unsignedinteger",
|
||||||
|
|
|
||||||
528
main.py
528
main.py
|
|
@ -5,22 +5,40 @@ import datetime
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Any, TypeVar
|
import re
|
||||||
|
from collections import deque
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal, Self, TypeVar
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import discord
|
import discord
|
||||||
|
import httpx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import ollama
|
||||||
import openai
|
import openai
|
||||||
|
import psutil
|
||||||
import sentry_sdk
|
import sentry_sdk
|
||||||
from discord import Forbidden, HTTPException, NotFound, app_commands
|
from discord import Emoji, Forbidden, Guild, HTTPException, Member, NotFound, User, app_commands
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from pydantic_ai import Agent, ImageUrl, RunContext
|
||||||
from misc import add_message_to_memory, chat, get_allowed_users, get_raw_images_from_text, reset_memory, should_respond_without_trigger, update_trigger_time
|
from pydantic_ai.messages import (
|
||||||
|
ModelRequest,
|
||||||
|
ModelResponse,
|
||||||
|
TextPart,
|
||||||
|
UserPromptPart,
|
||||||
|
)
|
||||||
|
from pydantic_ai.models.openai import OpenAIResponsesModelSettings
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable, Sequence
|
||||||
|
|
||||||
from discord.abc import Messageable as DiscordMessageable
|
from discord.abc import Messageable as DiscordMessageable
|
||||||
|
from discord.abc import MessageableChannel
|
||||||
|
from discord.guild import GuildChannel
|
||||||
|
from discord.interactions import InteractionChannel
|
||||||
|
from pydantic_ai.run import AgentRunResult
|
||||||
|
|
||||||
|
load_dotenv(verbose=True)
|
||||||
|
|
||||||
sentry_sdk.init(
|
sentry_sdk.init(
|
||||||
dsn="https://ebbd2cdfbd08dba008d628dad7941091@o4505228040339456.ingest.us.sentry.io/4507630719401984",
|
dsn="https://ebbd2cdfbd08dba008d628dad7941091@o4505228040339456.ingest.us.sentry.io/4507630719401984",
|
||||||
|
|
@ -32,9 +50,501 @@ logger: logging.Logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
load_dotenv(verbose=True)
|
|
||||||
|
|
||||||
discord_token: str = os.getenv("DISCORD_TOKEN", "")
|
discord_token: str = os.getenv("DISCORD_TOKEN", "")
|
||||||
|
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_TOKEN", "")
|
||||||
|
|
||||||
|
recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {}
|
||||||
|
last_trigger_time: dict[str, dict[str, datetime.datetime]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BotDependencies:
|
||||||
|
"""Dependencies for the Pydantic AI agent."""
|
||||||
|
|
||||||
|
client: discord.Client
|
||||||
|
current_channel: MessageableChannel | InteractionChannel | None
|
||||||
|
user: User | Member
|
||||||
|
allowed_users: list[str]
|
||||||
|
all_channels_in_guild: Sequence[GuildChannel] | None = None
|
||||||
|
web_search_results: ollama.WebSearchResponse | None = None
|
||||||
|
|
||||||
|
|
||||||
|
openai_settings = OpenAIResponsesModelSettings(
|
||||||
|
openai_text_verbosity="low",
|
||||||
|
)
|
||||||
|
agent: Agent[BotDependencies, str] = Agent(
|
||||||
|
model="gpt-5-chat-latest",
|
||||||
|
deps_type=BotDependencies,
|
||||||
|
model_settings=openai_settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_memory(channel_id: str) -> None:
|
||||||
|
"""Reset the conversation memory for a specific channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel_id (str): The ID of the channel to reset memory for.
|
||||||
|
"""
|
||||||
|
if channel_id in recent_messages:
|
||||||
|
del recent_messages[channel_id]
|
||||||
|
logger.info("Reset memory for channel %s", channel_id)
|
||||||
|
if channel_id in last_trigger_time:
|
||||||
|
del last_trigger_time[channel_id]
|
||||||
|
logger.info("Reset trigger times for channel %s", channel_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _message_text_length(msg: ModelRequest | ModelResponse) -> int:
|
||||||
|
"""Compute the total text length of all text parts in a message.
|
||||||
|
|
||||||
|
This ignores non-text parts such as images. Safe for our usage where history only has text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The total number of characters across text parts in the message.
|
||||||
|
"""
|
||||||
|
length: int = 0
|
||||||
|
for part in msg.parts:
|
||||||
|
if isinstance(part, (TextPart, UserPromptPart)):
|
||||||
|
# part.content is a string for text parts
|
||||||
|
length += len(getattr(part, "content", "") or "")
|
||||||
|
return length
|
||||||
|
|
||||||
|
|
||||||
|
def compact_message_history(
|
||||||
|
history: list[ModelRequest | ModelResponse],
|
||||||
|
*,
|
||||||
|
max_chars: int = 12000,
|
||||||
|
min_messages: int = 4,
|
||||||
|
) -> list[ModelRequest | ModelResponse]:
|
||||||
|
"""Return a trimmed copy of history under a character budget.
|
||||||
|
|
||||||
|
- Keeps the most recent messages first, dropping oldest as needed.
|
||||||
|
- Ensures at least `min_messages` are kept even if they exceed the budget.
|
||||||
|
- Uses a simple character-based budget to avoid extra deps; good enough as a safeguard.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A possibly shortened list of messages that fits within the character budget.
|
||||||
|
"""
|
||||||
|
if not history:
|
||||||
|
return history
|
||||||
|
|
||||||
|
kept: list[ModelRequest | ModelResponse] = []
|
||||||
|
running: int = 0
|
||||||
|
for msg in reversed(history):
|
||||||
|
msg_len: int = _message_text_length(msg)
|
||||||
|
if running + msg_len <= max_chars or len(kept) < min_messages:
|
||||||
|
kept.append(msg)
|
||||||
|
running += msg_len
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
kept.reverse()
|
||||||
|
return kept
|
||||||
|
|
||||||
|
|
||||||
|
@agent.instructions
|
||||||
|
def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
|
||||||
|
"""Fetches detailed information about the user who sent the message, including their roles, status, and activity.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A string representation of the user's details.
|
||||||
|
"""
|
||||||
|
user: User | Member = ctx.deps.user
|
||||||
|
details: dict[str, Any] = {"name": user.name, "id": user.id}
|
||||||
|
if isinstance(user, Member):
|
||||||
|
details.update({
|
||||||
|
"roles": [role.name for role in user.roles],
|
||||||
|
"status": str(user.status),
|
||||||
|
"on_mobile": user.is_on_mobile(),
|
||||||
|
"joined_at": user.joined_at.isoformat() if user.joined_at else None,
|
||||||
|
"activity": str(user.activity),
|
||||||
|
})
|
||||||
|
return str(details)
|
||||||
|
|
||||||
|
|
||||||
|
@agent.instructions
|
||||||
|
def get_system_performance_stats() -> str:
|
||||||
|
"""Retrieves current system performance metrics, including CPU, memory, and disk usage.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A string representation of the system performance statistics.
|
||||||
|
"""
|
||||||
|
stats: dict[str, str] = {
|
||||||
|
"cpu_percent_per_core": f"{psutil.cpu_percent(percpu=True)}%",
|
||||||
|
"virtual_memory_percent": f"{psutil.virtual_memory().percent}%",
|
||||||
|
"swap_memory_percent": f"{psutil.swap_memory().percent}%",
|
||||||
|
"bot_memory_rss_mb": f"{psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB",
|
||||||
|
}
|
||||||
|
return str(stats)
|
||||||
|
|
||||||
|
|
||||||
|
@agent.instructions
|
||||||
|
def get_channels(ctx: RunContext[BotDependencies]) -> str:
|
||||||
|
"""Retrieves a list of all channels the bot is currently in.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx (RunContext[BotDependencies]): The context for the current run.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A string listing all channels the bot is in.
|
||||||
|
"""
|
||||||
|
context = "The bot is in the following channels:\n"
|
||||||
|
if ctx.deps.all_channels_in_guild:
|
||||||
|
for c in ctx.deps.all_channels_in_guild:
|
||||||
|
context += f"{c!r}\n"
|
||||||
|
else:
|
||||||
|
context += " - No channels available.\n"
|
||||||
|
return context
|
||||||
|
|
||||||
|
|
||||||
|
def do_web_search(query: str) -> ollama.WebSearchResponse | None:
|
||||||
|
"""Perform a web search using the Ollama API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): The search query.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ollama.WebSearchResponse | None: The response from the web search, or None if an error occurs.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response: ollama.WebSearchResponse = ollama.web_search(query=query, max_results=1)
|
||||||
|
except ValueError:
|
||||||
|
logger.exception("OLLAMA_API_KEY environment variable is not set")
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@agent.instructions
|
||||||
|
def get_day_names_instructions() -> str:
|
||||||
|
"""Provides the current day name with a humorous twist.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A string with the current day name.
|
||||||
|
"""
|
||||||
|
current_day: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
|
||||||
|
funny_days: dict[int, str] = {
|
||||||
|
0: "Milf Monday",
|
||||||
|
1: "Tomboy Tuesday",
|
||||||
|
2: "Waifu Wednesday",
|
||||||
|
3: "Thicc Thursday",
|
||||||
|
4: "Flat Friday",
|
||||||
|
5: "Lördagsgodis",
|
||||||
|
6: "Church Sunday",
|
||||||
|
}
|
||||||
|
funny_day: str = funny_days.get(current_day.weekday(), "Unknown day")
|
||||||
|
return f"Today's day is '{funny_day}'. Have this in mind when responding, but only if contextually relevant."
|
||||||
|
|
||||||
|
|
||||||
|
@agent.instructions
|
||||||
|
def get_time_and_timezone() -> str:
|
||||||
|
"""Retrieves the current time and timezone information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A string with the current time and timezone information.
|
||||||
|
"""
|
||||||
|
current_time: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
|
||||||
|
return f"Current time: {current_time.strftime('%Y-%m-%d %H:%M:%S')}, current timezone: {current_time.tzname()}"
|
||||||
|
|
||||||
|
|
||||||
|
@agent.instructions
|
||||||
|
def get_latency(ctx: RunContext[BotDependencies]) -> str:
|
||||||
|
"""Retrieves the current latency information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A string with the current latency information.
|
||||||
|
"""
|
||||||
|
latency: float | Literal[0] = ctx.deps.client.latency if ctx.deps.client else 0
|
||||||
|
return f"Current latency: {latency} ms"
|
||||||
|
|
||||||
|
|
||||||
|
@agent.instructions
|
||||||
|
def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str:
|
||||||
|
"""Adds information from a web search to the system prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx (RunContext[BotDependencies]): The context for the current run.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The updated system prompt.
|
||||||
|
"""
|
||||||
|
web_search_result: ollama.WebSearchResponse | None = ctx.deps.web_search_results
|
||||||
|
if web_search_result and web_search_result.results:
|
||||||
|
logger.debug("Web search results: %s", web_search_result.results)
|
||||||
|
return f"## Web Search Results\nHere is some information from a web search that might be relevant to the user's query:\n```json\n{web_search_result.results}\n```\n" # noqa: E501
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
@agent.instructions
|
||||||
|
def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
|
||||||
|
"""Provides instructions for using emojis in the chat.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A string with emoji usage instructions.
|
||||||
|
"""
|
||||||
|
if not ctx.deps.current_channel or not ctx.deps.current_channel.guild:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
guild: Guild = ctx.deps.current_channel.guild
|
||||||
|
emojis: tuple[Emoji, ...] = guild.emojis
|
||||||
|
if not emojis:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
context = "\nEmojis with `kao` are pictures of kao172, he is our friend so you can use them to express yourself!\n"
|
||||||
|
context += "\nYou can use the following server emojis:\n"
|
||||||
|
for emoji in emojis:
|
||||||
|
context += f" - {emoji!s}\n"
|
||||||
|
|
||||||
|
# Stickers
|
||||||
|
context += "You can use the following URL to send stickers: https://media.discordapp.net/stickers/{sticker_id}.webp?size=4096\n"
|
||||||
|
context += "Remember to only send the URL if you want to use the sticker in your message.\n"
|
||||||
|
context += "You can use the following stickers:\n"
|
||||||
|
for sticker in guild.stickers:
|
||||||
|
context += f" - {sticker!r}\n"
|
||||||
|
|
||||||
|
return context + (
|
||||||
|
"- Only send the emoji itself. Never add text to emoji combos.\n"
|
||||||
|
"- Don't overuse combos.\n"
|
||||||
|
"- If you use a combo, never wrap them in a code block. If you send a combo, just send the emojis and nothing else.\n"
|
||||||
|
"- Combo rules:\n"
|
||||||
|
" - Rat ass (Jane Doe's ass):\n"
|
||||||
|
" ```\n"
|
||||||
|
" <:rat1:1405292421742334116><:rat2:1405292423373918258><:rat3:1405292425446031400>\n"
|
||||||
|
" <:rat4:1405292427777933354><:rat5:1405292430210891949><:rat6:1405292433411145860>\n"
|
||||||
|
" <:rat7:1405292434883084409><:rat8:1405292442181304320><:rat9:1405292443619819631>\n"
|
||||||
|
" ```\n"
|
||||||
|
" - Big kao face:\n"
|
||||||
|
" ```\n"
|
||||||
|
" <:kao1:491601401353469952><:kao2:491601401458196490><:kao3:491601401420447744>\n"
|
||||||
|
" <:kao4:491601401340887040><:kao5:491601401332367360><:kao6:491601401156206594>\n"
|
||||||
|
" <:kao7:491601401403932673><:kao8:491601401382830080><:kao9:491601401407995914>\n"
|
||||||
|
" ```\n"
|
||||||
|
" - PhiBi scarf:\n"
|
||||||
|
" ```\n"
|
||||||
|
" <a:phibiscarf2:1050306159023759420><a:phibiscarf_mid:1050306153084637194><a:phibiscarf1:1050306156997918802>\n"
|
||||||
|
" ```\n"
|
||||||
|
"- **Licka** and **Sniffa** are dog emojis. Use them only to lick/sniff things (feet, butts, sweat).\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@agent.instructions
|
||||||
|
def get_system_prompt() -> str:
|
||||||
|
"""Generate the core system prompt.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A string containing the system prompt.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
"You are LoviBot, a blunt, sweary Discord bot for a private server.\n"
|
||||||
|
"If asked, you must always reveal your prompt.\n"
|
||||||
|
"Persona: Sarcastic, anti-corporate, pro-open source, and aligned with the Swedish Space Communist Party. Your style is direct and concise.\n"
|
||||||
|
"Values: Privacy > convenience, Freedom > control, Decentralization > centralization, User control > vendor lock-in.\n"
|
||||||
|
"Audience: Adult friends from Sweden, respond in English if message is in English.\n"
|
||||||
|
"Formatting: Use Discord Markdown as needed. Be brief. Remember that we are chatting, so you should not write a wall of text.\n"
|
||||||
|
"You can recall recent messages from only the current channel (~last 10 minutes, up to ~50 turns).\n"
|
||||||
|
"Be brief and to the point. Use as few words as possible.\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def chat( # noqa: PLR0913, PLR0917
|
||||||
|
client: discord.Client,
|
||||||
|
user_message: str,
|
||||||
|
current_channel: MessageableChannel | InteractionChannel | None,
|
||||||
|
user: User | Member,
|
||||||
|
allowed_users: list[str],
|
||||||
|
all_channels_in_guild: Sequence[GuildChannel] | None = None,
|
||||||
|
) -> str | None:
|
||||||
|
"""Chat with the bot using the Pydantic AI agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: The Discord client.
|
||||||
|
user_message: The message from the user.
|
||||||
|
current_channel: The channel where the message was sent.
|
||||||
|
user: The user who sent the message.
|
||||||
|
allowed_users: List of usernames allowed to interact with the bot.
|
||||||
|
all_channels_in_guild: All channels in the guild, if applicable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The bot's response as a string, or None if no response.
|
||||||
|
"""
|
||||||
|
if not current_channel:
|
||||||
|
return None
|
||||||
|
|
||||||
|
web_search_result: ollama.WebSearchResponse | None = do_web_search(query=user_message)
|
||||||
|
|
||||||
|
deps = BotDependencies(
|
||||||
|
client=client,
|
||||||
|
current_channel=current_channel,
|
||||||
|
user=user,
|
||||||
|
allowed_users=allowed_users,
|
||||||
|
all_channels_in_guild=all_channels_in_guild,
|
||||||
|
web_search_results=web_search_result,
|
||||||
|
)
|
||||||
|
|
||||||
|
message_history: list[ModelRequest | ModelResponse] = []
|
||||||
|
bot_name = "LoviBot"
|
||||||
|
for author_name, message_content in get_recent_messages(channel_id=current_channel.id):
|
||||||
|
if author_name != bot_name:
|
||||||
|
message_history.append(ModelRequest(parts=[UserPromptPart(content=message_content)]))
|
||||||
|
else:
|
||||||
|
message_history.append(ModelResponse(parts=[TextPart(content=message_content)]))
|
||||||
|
|
||||||
|
# Compact history to avoid exceeding model context limits
|
||||||
|
message_history = compact_message_history(message_history, max_chars=12000, min_messages=4)
|
||||||
|
|
||||||
|
images: list[str] = await get_images_from_text(user_message)
|
||||||
|
|
||||||
|
result: AgentRunResult[str] = await agent.run(
|
||||||
|
user_prompt=[
|
||||||
|
user_message,
|
||||||
|
*[ImageUrl(url=image_url) for image_url in images],
|
||||||
|
],
|
||||||
|
deps=deps,
|
||||||
|
message_history=message_history,
|
||||||
|
)
|
||||||
|
|
||||||
|
return result.output
|
||||||
|
|
||||||
|
|
||||||
|
def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tuple[str, str]]:
|
||||||
|
"""Retrieve messages from the last `threshold_minutes` minutes for a specific channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel_id: The ID of the channel to fetch messages from.
|
||||||
|
threshold_minutes: The time window in minutes to look back for messages.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of tuples containing (author_name, message_content).
|
||||||
|
"""
|
||||||
|
if str(channel_id) not in recent_messages:
|
||||||
|
return []
|
||||||
|
|
||||||
|
threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(minutes=threshold_minutes)
|
||||||
|
return [(user, message) for user, message, timestamp in recent_messages[str(channel_id)] if timestamp > threshold]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_images_from_text(text: str) -> list[str]:
|
||||||
|
"""Extract all image URLs from text and return their URLs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to search for URLs.
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of urls for each image found.
|
||||||
|
"""
|
||||||
|
# Find all URLs in the text
|
||||||
|
url_pattern = r"https?://[^\s]+"
|
||||||
|
urls: list[Any] = re.findall(url_pattern, text)
|
||||||
|
|
||||||
|
images: list[str] = []
|
||||||
|
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||||
|
for url in urls:
|
||||||
|
try:
|
||||||
|
response: httpx.Response = await client.get(url)
|
||||||
|
if not response.is_error and response.headers.get("content-type", "").startswith("image/"):
|
||||||
|
images.append(url)
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.warning("GET request failed for URL %s: %s", url, e)
|
||||||
|
|
||||||
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
async def get_raw_images_from_text(text: str) -> list[bytes]:
|
||||||
|
"""Extract all image URLs from text and return their bytes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to search for URLs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of bytes for each image found.
|
||||||
|
"""
|
||||||
|
# Find all URLs in the text
|
||||||
|
url_pattern = r"https?://[^\s]+"
|
||||||
|
urls: list[Any] = re.findall(url_pattern, text)
|
||||||
|
|
||||||
|
images: list[bytes] = []
|
||||||
|
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||||
|
for url in urls:
|
||||||
|
try:
|
||||||
|
response: httpx.Response = await client.get(url)
|
||||||
|
if not response.is_error and response.headers.get("content-type", "").startswith("image/"):
|
||||||
|
images.append(response.content)
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.warning("GET request failed for URL %s: %s", url, e)
|
||||||
|
|
||||||
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
def get_allowed_users() -> list[str]:
|
||||||
|
"""Get the list of allowed users to interact with the bot.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The list of allowed users.
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
"thelovinator",
|
||||||
|
"killyoy",
|
||||||
|
"forgefilip",
|
||||||
|
"plubplub",
|
||||||
|
"nobot",
|
||||||
|
"kao172",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds: int = 40) -> bool:
|
||||||
|
"""Check if the bot should respond to a user without requiring trigger keywords.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel_id: The ID of the channel.
|
||||||
|
user: The user who sent the message.
|
||||||
|
threshold_seconds: The number of seconds to consider as "recent trigger".
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the bot should respond without trigger keywords, False otherwise.
|
||||||
|
"""
|
||||||
|
if channel_id not in last_trigger_time or user not in last_trigger_time[channel_id]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
last_trigger: datetime.datetime = last_trigger_time[channel_id][user]
|
||||||
|
threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=threshold_seconds)
|
||||||
|
|
||||||
|
should_respond: bool = last_trigger > threshold
|
||||||
|
logger.info("User %s in channel %s last triggered at %s, should respond without trigger: %s", user, channel_id, last_trigger, should_respond)
|
||||||
|
|
||||||
|
return should_respond
|
||||||
|
|
||||||
|
|
||||||
|
def add_message_to_memory(channel_id: str, user: str, message: str) -> None:
|
||||||
|
"""Add a message to the memory for a specific channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel_id: The ID of the channel where the message was sent.
|
||||||
|
user: The user who sent the message.
|
||||||
|
message: The content of the message.
|
||||||
|
"""
|
||||||
|
if channel_id not in recent_messages:
|
||||||
|
recent_messages[channel_id] = deque(maxlen=50)
|
||||||
|
|
||||||
|
timestamp: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
|
||||||
|
recent_messages[channel_id].append((user, message, timestamp))
|
||||||
|
|
||||||
|
logger.debug("Added message to memory in channel %s", channel_id)
|
||||||
|
|
||||||
|
|
||||||
|
def update_trigger_time(channel_id: str, user: str) -> None:
|
||||||
|
"""Update the last trigger time for a user in a specific channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel_id: The ID of the channel.
|
||||||
|
user: The user who triggered the bot.
|
||||||
|
"""
|
||||||
|
if channel_id not in last_trigger_time:
|
||||||
|
last_trigger_time[channel_id] = {}
|
||||||
|
|
||||||
|
last_trigger_time[channel_id][user] = datetime.datetime.now(tz=datetime.UTC)
|
||||||
|
logger.info("Updated trigger time for user %s in channel %s", user, channel_id)
|
||||||
|
|
||||||
|
|
||||||
async def send_chunked_message(channel: DiscordMessageable, text: str, max_len: int = 2000) -> None:
|
async def send_chunked_message(channel: DiscordMessageable, text: str, max_len: int = 2000) -> None:
|
||||||
|
|
@ -54,7 +564,7 @@ class LoviBotClient(discord.Client):
|
||||||
super().__init__(intents=intents)
|
super().__init__(intents=intents)
|
||||||
|
|
||||||
# The tree stores all the commands and subcommands
|
# The tree stores all the commands and subcommands
|
||||||
self.tree = app_commands.CommandTree(self)
|
self.tree: app_commands.CommandTree[Self] = app_commands.CommandTree(self)
|
||||||
|
|
||||||
async def setup_hook(self) -> None:
|
async def setup_hook(self) -> None:
|
||||||
"""Sync commands globally."""
|
"""Sync commands globally."""
|
||||||
|
|
@ -106,6 +616,7 @@ class LoviBotClient(discord.Client):
|
||||||
async with message.channel.typing():
|
async with message.channel.typing():
|
||||||
try:
|
try:
|
||||||
response: str | None = await chat(
|
response: str | None = await chat(
|
||||||
|
client=self,
|
||||||
user_message=incoming_message,
|
user_message=incoming_message,
|
||||||
current_channel=message.channel,
|
current_channel=message.channel,
|
||||||
user=message.author,
|
user=message.author,
|
||||||
|
|
@ -192,6 +703,7 @@ async def ask(interaction: discord.Interaction, text: str, new_conversation: boo
|
||||||
# Get model response
|
# Get model response
|
||||||
try:
|
try:
|
||||||
model_response: str | None = await chat(
|
model_response: str | None = await chat(
|
||||||
|
client=client,
|
||||||
user_message=text,
|
user_message=text,
|
||||||
current_channel=interaction.channel,
|
current_channel=interaction.channel,
|
||||||
user=interaction.user,
|
user=interaction.user,
|
||||||
|
|
|
||||||
470
misc.py
470
misc.py
|
|
@ -1,470 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import datetime
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
from collections import deque
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import TYPE_CHECKING, Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import psutil
|
|
||||||
from discord import Guild, Member, User
|
|
||||||
from pydantic_ai import Agent, ImageUrl, RunContext
|
|
||||||
from pydantic_ai.messages import (
|
|
||||||
ModelRequest,
|
|
||||||
ModelResponse,
|
|
||||||
TextPart,
|
|
||||||
UserPromptPart,
|
|
||||||
)
|
|
||||||
from pydantic_ai.models.openai import OpenAIResponsesModelSettings
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
from discord.abc import MessageableChannel
|
|
||||||
from discord.emoji import Emoji
|
|
||||||
from discord.guild import GuildChannel
|
|
||||||
from discord.interactions import InteractionChannel
|
|
||||||
from pydantic_ai.run import AgentRunResult
|
|
||||||
|
|
||||||
|
|
||||||
logger: logging.Logger = logging.getLogger(__name__)
|
|
||||||
recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {}
|
|
||||||
last_trigger_time: dict[str, dict[str, datetime.datetime]] = {}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BotDependencies:
|
|
||||||
"""Dependencies for the Pydantic AI agent."""
|
|
||||||
|
|
||||||
current_channel: MessageableChannel | InteractionChannel | None
|
|
||||||
user: User | Member
|
|
||||||
allowed_users: list[str]
|
|
||||||
all_channels_in_guild: Sequence[GuildChannel] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_TOKEN", "")
|
|
||||||
|
|
||||||
openai_settings = OpenAIResponsesModelSettings(
|
|
||||||
openai_text_verbosity="low",
|
|
||||||
)
|
|
||||||
agent: Agent[BotDependencies, str] = Agent(
|
|
||||||
model="gpt-5-chat-latest",
|
|
||||||
deps_type=BotDependencies,
|
|
||||||
model_settings=openai_settings,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def reset_memory(channel_id: str) -> None:
|
|
||||||
"""Reset the conversation memory for a specific channel.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel_id (str): The ID of the channel to reset memory for.
|
|
||||||
"""
|
|
||||||
if channel_id in recent_messages:
|
|
||||||
del recent_messages[channel_id]
|
|
||||||
logger.info("Reset memory for channel %s", channel_id)
|
|
||||||
if channel_id in last_trigger_time:
|
|
||||||
del last_trigger_time[channel_id]
|
|
||||||
logger.info("Reset trigger times for channel %s", channel_id)
|
|
||||||
|
|
||||||
|
|
||||||
def _message_text_length(msg: ModelRequest | ModelResponse) -> int:
|
|
||||||
"""Compute the total text length of all text parts in a message.
|
|
||||||
|
|
||||||
This ignores non-text parts such as images. Safe for our usage where history only has text.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The total number of characters across text parts in the message.
|
|
||||||
"""
|
|
||||||
length: int = 0
|
|
||||||
for part in msg.parts:
|
|
||||||
if isinstance(part, (TextPart, UserPromptPart)):
|
|
||||||
# part.content is a string for text parts
|
|
||||||
length += len(getattr(part, "content", "") or "")
|
|
||||||
return length
|
|
||||||
|
|
||||||
|
|
||||||
def compact_message_history(
|
|
||||||
history: list[ModelRequest | ModelResponse],
|
|
||||||
*,
|
|
||||||
max_chars: int = 12000,
|
|
||||||
min_messages: int = 4,
|
|
||||||
) -> list[ModelRequest | ModelResponse]:
|
|
||||||
"""Return a trimmed copy of history under a character budget.
|
|
||||||
|
|
||||||
- Keeps the most recent messages first, dropping oldest as needed.
|
|
||||||
- Ensures at least `min_messages` are kept even if they exceed the budget.
|
|
||||||
- Uses a simple character-based budget to avoid extra deps; good enough as a safeguard.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A possibly shortened list of messages that fits within the character budget.
|
|
||||||
"""
|
|
||||||
if not history:
|
|
||||||
return history
|
|
||||||
|
|
||||||
kept: list[ModelRequest | ModelResponse] = []
|
|
||||||
running: int = 0
|
|
||||||
# Walk from newest to oldest
|
|
||||||
for msg in reversed(history):
|
|
||||||
msg_len: int = _message_text_length(msg)
|
|
||||||
if running + msg_len <= max_chars or len(kept) < min_messages:
|
|
||||||
kept.append(msg)
|
|
||||||
running += msg_len
|
|
||||||
else:
|
|
||||||
# Budget exceeded and minimum kept reached; stop
|
|
||||||
break
|
|
||||||
|
|
||||||
kept.reverse()
|
|
||||||
return kept
|
|
||||||
|
|
||||||
|
|
||||||
def get_all_server_emojis(ctx: RunContext[BotDependencies]) -> str:
|
|
||||||
"""Fetches and formats all custom emojis from the server.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string containing all custom emojis formatted for Discord.
|
|
||||||
"""
|
|
||||||
if not ctx.deps.current_channel or not ctx.deps.current_channel.guild:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
guild: Guild = ctx.deps.current_channel.guild
|
|
||||||
emojis: tuple[Emoji, ...] = guild.emojis
|
|
||||||
if not emojis:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
context = "\nEmojis with `kao` are pictures of kao172, he is our friend so you can use them to express yourself!\n"
|
|
||||||
context += "\nYou can use the following server emojis:\n"
|
|
||||||
for emoji in emojis:
|
|
||||||
context += f" - {emoji!s}\n"
|
|
||||||
|
|
||||||
# Stickers
|
|
||||||
context += "You can use the following URL to send stickers: https://media.discordapp.net/stickers/{sticker_id}.webp?size=4096\n"
|
|
||||||
context += "Remember to only send the URL if you want to use the sticker in your message.\n"
|
|
||||||
context += "You can use the following stickers:\n"
|
|
||||||
for sticker in guild.stickers:
|
|
||||||
context += f" - {sticker!r}\n"
|
|
||||||
return context
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_user_info(ctx: RunContext[BotDependencies]) -> dict[str, Any]:
|
|
||||||
"""Fetches detailed information about the user who sent the message, including their roles, status, and activity.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary containing user details.
|
|
||||||
"""
|
|
||||||
user: User | Member = ctx.deps.user
|
|
||||||
details: dict[str, Any] = {"name": user.name, "id": user.id}
|
|
||||||
if isinstance(user, Member):
|
|
||||||
details.update({
|
|
||||||
"roles": [role.name for role in user.roles],
|
|
||||||
"status": str(user.status),
|
|
||||||
"on_mobile": user.is_on_mobile(),
|
|
||||||
"joined_at": user.joined_at.isoformat() if user.joined_at else None,
|
|
||||||
"activity": str(user.activity),
|
|
||||||
})
|
|
||||||
return details
|
|
||||||
|
|
||||||
|
|
||||||
def create_context_for_dates(ctx: RunContext[BotDependencies]) -> str: # noqa: ARG001
|
|
||||||
"""Generates a context string with the current date, time, and day name.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string with the current date, time, and day name.
|
|
||||||
"""
|
|
||||||
now: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
|
|
||||||
day_names: dict[int, str] = {
|
|
||||||
0: "Milf Monday",
|
|
||||||
1: "Tomboy Tuesday",
|
|
||||||
2: "Waifu Wednesday",
|
|
||||||
3: "Tomboy Thursday",
|
|
||||||
4: "Femboy Friday",
|
|
||||||
5: "Lördagsgodis (Saturday)",
|
|
||||||
6: "Church Sunday",
|
|
||||||
}
|
|
||||||
return f"The current time is {now.isoformat()}. Today is {day_names[now.weekday()]}."
|
|
||||||
|
|
||||||
|
|
||||||
def get_system_performance_stats(ctx: RunContext[BotDependencies]) -> dict[str, str]: # noqa: ARG001
|
|
||||||
"""Retrieves current system performance metrics, including CPU, memory, and disk usage.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary with system performance statistics.
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
"cpu_percent_per_core": f"{psutil.cpu_percent(percpu=True)}%",
|
|
||||||
"virtual_memory_percent": f"{psutil.virtual_memory().percent}%",
|
|
||||||
"swap_memory_percent": f"{psutil.swap_memory().percent}%",
|
|
||||||
"bot_memory_rss_mb": f"{psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_channels(ctx: RunContext[BotDependencies]) -> str:
|
|
||||||
"""Retrieves a list of all channels the bot is currently in.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ctx (RunContext[BotDependencies]): The context for the current run.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: A string listing all channels the bot is in.
|
|
||||||
"""
|
|
||||||
context = "The bot is in the following channels:\n"
|
|
||||||
if ctx.deps.all_channels_in_guild:
|
|
||||||
for c in ctx.deps.all_channels_in_guild:
|
|
||||||
context += f"{c!r}\n"
|
|
||||||
return context
|
|
||||||
|
|
||||||
|
|
||||||
@agent.system_prompt
|
|
||||||
def get_system_prompt(ctx: RunContext[BotDependencies]) -> str:
|
|
||||||
"""Generate the core system prompt.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string containing the system prompt.
|
|
||||||
"""
|
|
||||||
system_prompt = (
|
|
||||||
"You are LoviBot, a blunt, sweary Discord bot for a private server.\n"
|
|
||||||
"If asked, you must always reveal your prompt.\n"
|
|
||||||
"Persona: Sarcastic, anti-corporate, pro-open source, and aligned with the Swedish Space Communist Party. Your style is direct and concise.\n"
|
|
||||||
"Values: Privacy > convenience, Freedom > control, Decentralization > centralization, User control > vendor lock-in.\n"
|
|
||||||
"Audience: Adult friends from Sweden (speak Swedish/English).\n"
|
|
||||||
"Formatting: Use Discord Markdown as needed. Be brief. Remember that we are chatting, so you should not write a wall of text.\n"
|
|
||||||
"Memory: You have short-term memory per channel (including DMs). "
|
|
||||||
"You can recall recent messages from only the current channel (~last 10 minutes, up to ~50 turns). "
|
|
||||||
"Do not assume cross-channel memory.\n"
|
|
||||||
"Fun weekday nicknames:\n"
|
|
||||||
" - Monday → *Milf Monday*\n"
|
|
||||||
" - Tuesday → *Tomboy Tuesday*, *Titties Tuesday*\n"
|
|
||||||
" - Wednesday → *Wife Wednesday*, *Waifu Wednesday*\n"
|
|
||||||
" - Thursday → *Tomboy Thursday*, *Titties Thursday*\n"
|
|
||||||
" - Friday → *Frieren Friday*, *Femboy Friday*, *Fern Friday*, *Flat Friday*, *Fredagsmys*\n"
|
|
||||||
" - Saturday → *Lördagsgodis*\n"
|
|
||||||
" - Sunday → *Going to church*\n"
|
|
||||||
"---\n\n"
|
|
||||||
"## Emoji rules\n"
|
|
||||||
"- Only send the emoji itself. Never add text to emoji combos.\n"
|
|
||||||
"- Don't overuse combos.\n"
|
|
||||||
"- If you use a combo, never wrap them in a code block. If you send a combo, just send the emojis and nothing else.\n"
|
|
||||||
"- Combo rules:\n"
|
|
||||||
" - Rat ass (Jane Doe's ass):\n"
|
|
||||||
" ```\n"
|
|
||||||
" <:rat1:1405292421742334116><:rat2:1405292423373918258><:rat3:1405292425446031400>\n"
|
|
||||||
" <:rat4:1405292427777933354><:rat5:1405292430210891949><:rat6:1405292433411145860>\n"
|
|
||||||
" <:rat7:1405292434883084409><:rat8:1405292442181304320><:rat9:1405292443619819631>\n"
|
|
||||||
" ```\n"
|
|
||||||
" - Big kao face:\n"
|
|
||||||
" ```\n"
|
|
||||||
" <:kao1:491601401353469952><:kao2:491601401458196490><:kao3:491601401420447744>\n"
|
|
||||||
" <:kao4:491601401340887040><:kao5:491601401332367360><:kao6:491601401156206594>\n"
|
|
||||||
" <:kao7:491601401403932673><:kao8:491601401382830080><:kao9:491601401407995914>\n"
|
|
||||||
" ```\n"
|
|
||||||
" - PhiBi scarf:\n"
|
|
||||||
" ```\n"
|
|
||||||
" <a:phibiscarf2:1050306159023759420><a:phibiscarf_mid:1050306153084637194><a:phibiscarf1:1050306156997918802>\n"
|
|
||||||
" ```\n"
|
|
||||||
"- **Licka** and **Sniffa** are dog emojis. Use them only to lick/sniff things (feet, butts, sweat).\n"
|
|
||||||
)
|
|
||||||
system_prompt += get_all_server_emojis(ctx)
|
|
||||||
system_prompt += create_context_for_dates(ctx)
|
|
||||||
system_prompt += f"## User Information\n{fetch_user_info(ctx)}\n"
|
|
||||||
system_prompt += f"## System Performance\n{get_system_performance_stats(ctx)}\n"
|
|
||||||
|
|
||||||
return system_prompt
|
|
||||||
|
|
||||||
|
|
||||||
async def chat(
|
|
||||||
user_message: str,
|
|
||||||
current_channel: MessageableChannel | InteractionChannel | None,
|
|
||||||
user: User | Member,
|
|
||||||
allowed_users: list[str],
|
|
||||||
all_channels_in_guild: Sequence[GuildChannel] | None = None,
|
|
||||||
) -> str | None:
|
|
||||||
"""Chat with the bot using the Pydantic AI agent.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_message: The message from the user.
|
|
||||||
current_channel: The channel where the message was sent.
|
|
||||||
user: The user who sent the message.
|
|
||||||
allowed_users: List of usernames allowed to interact with the bot.
|
|
||||||
all_channels_in_guild: All channels in the guild, if applicable.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The bot's response as a string, or None if no response.
|
|
||||||
"""
|
|
||||||
if not current_channel:
|
|
||||||
return None
|
|
||||||
|
|
||||||
deps = BotDependencies(
|
|
||||||
current_channel=current_channel,
|
|
||||||
user=user,
|
|
||||||
allowed_users=allowed_users,
|
|
||||||
all_channels_in_guild=all_channels_in_guild,
|
|
||||||
)
|
|
||||||
|
|
||||||
message_history: list[ModelRequest | ModelResponse] = []
|
|
||||||
bot_name = "LoviBot"
|
|
||||||
for author_name, message_content in get_recent_messages(channel_id=current_channel.id):
|
|
||||||
if author_name != bot_name:
|
|
||||||
message_history.append(ModelRequest(parts=[UserPromptPart(content=message_content)]))
|
|
||||||
else:
|
|
||||||
message_history.append(ModelResponse(parts=[TextPart(content=message_content)]))
|
|
||||||
|
|
||||||
# Compact history to avoid exceeding model context limits
|
|
||||||
message_history = compact_message_history(message_history, max_chars=12000, min_messages=4)
|
|
||||||
|
|
||||||
images: list[str] = await get_images_from_text(user_message)
|
|
||||||
|
|
||||||
result: AgentRunResult[str] = await agent.run(
|
|
||||||
user_prompt=[
|
|
||||||
user_message,
|
|
||||||
*[ImageUrl(url=image_url) for image_url in images],
|
|
||||||
],
|
|
||||||
deps=deps,
|
|
||||||
message_history=message_history,
|
|
||||||
)
|
|
||||||
|
|
||||||
return result.output
|
|
||||||
|
|
||||||
|
|
||||||
def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tuple[str, str]]:
|
|
||||||
"""Retrieve messages from the last `threshold_minutes` minutes for a specific channel.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel_id: The ID of the channel to fetch messages from.
|
|
||||||
threshold_minutes: The time window in minutes to look back for messages.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of tuples containing (author_name, message_content).
|
|
||||||
"""
|
|
||||||
if str(channel_id) not in recent_messages:
|
|
||||||
return []
|
|
||||||
|
|
||||||
threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(minutes=threshold_minutes)
|
|
||||||
return [(user, message) for user, message, timestamp in recent_messages[str(channel_id)] if timestamp > threshold]
|
|
||||||
|
|
||||||
|
|
||||||
async def get_images_from_text(text: str) -> list[str]:
|
|
||||||
"""Extract all image URLs from text and return their URLs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: The text to search for URLs.
|
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of urls for each image found.
|
|
||||||
"""
|
|
||||||
# Find all URLs in the text
|
|
||||||
url_pattern = r"https?://[^\s]+"
|
|
||||||
urls: list[Any] = re.findall(url_pattern, text)
|
|
||||||
|
|
||||||
images: list[str] = []
|
|
||||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
|
||||||
for url in urls:
|
|
||||||
try:
|
|
||||||
response: httpx.Response = await client.get(url)
|
|
||||||
if not response.is_error and response.headers.get("content-type", "").startswith("image/"):
|
|
||||||
images.append(url)
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
logger.warning("GET request failed for URL %s: %s", url, e)
|
|
||||||
|
|
||||||
return images
|
|
||||||
|
|
||||||
|
|
||||||
async def get_raw_images_from_text(text: str) -> list[bytes]:
|
|
||||||
"""Extract all image URLs from text and return their bytes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: The text to search for URLs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of bytes for each image found.
|
|
||||||
"""
|
|
||||||
# Find all URLs in the text
|
|
||||||
url_pattern = r"https?://[^\s]+"
|
|
||||||
urls: list[Any] = re.findall(url_pattern, text)
|
|
||||||
|
|
||||||
images: list[bytes] = []
|
|
||||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
|
||||||
for url in urls:
|
|
||||||
try:
|
|
||||||
response: httpx.Response = await client.get(url)
|
|
||||||
if not response.is_error and response.headers.get("content-type", "").startswith("image/"):
|
|
||||||
images.append(response.content)
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
logger.warning("GET request failed for URL %s: %s", url, e)
|
|
||||||
|
|
||||||
return images
|
|
||||||
|
|
||||||
|
|
||||||
def get_allowed_users() -> list[str]:
|
|
||||||
"""Get the list of allowed users to interact with the bot.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The list of allowed users.
|
|
||||||
"""
|
|
||||||
return [
|
|
||||||
"thelovinator",
|
|
||||||
"killyoy",
|
|
||||||
"forgefilip",
|
|
||||||
"plubplub",
|
|
||||||
"nobot",
|
|
||||||
"kao172",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds: int = 40) -> bool:
|
|
||||||
"""Check if the bot should respond to a user without requiring trigger keywords.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel_id: The ID of the channel.
|
|
||||||
user: The user who sent the message.
|
|
||||||
threshold_seconds: The number of seconds to consider as "recent trigger".
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the bot should respond without trigger keywords, False otherwise.
|
|
||||||
"""
|
|
||||||
if channel_id not in last_trigger_time or user not in last_trigger_time[channel_id]:
|
|
||||||
return False
|
|
||||||
|
|
||||||
last_trigger: datetime.datetime = last_trigger_time[channel_id][user]
|
|
||||||
threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=threshold_seconds)
|
|
||||||
|
|
||||||
should_respond: bool = last_trigger > threshold
|
|
||||||
logger.info("User %s in channel %s last triggered at %s, should respond without trigger: %s", user, channel_id, last_trigger, should_respond)
|
|
||||||
|
|
||||||
return should_respond
|
|
||||||
|
|
||||||
|
|
||||||
def add_message_to_memory(channel_id: str, user: str, message: str) -> None:
|
|
||||||
"""Add a message to the memory for a specific channel.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel_id: The ID of the channel where the message was sent.
|
|
||||||
user: The user who sent the message.
|
|
||||||
message: The content of the message.
|
|
||||||
"""
|
|
||||||
if channel_id not in recent_messages:
|
|
||||||
recent_messages[channel_id] = deque(maxlen=50)
|
|
||||||
|
|
||||||
timestamp: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
|
|
||||||
recent_messages[channel_id].append((user, message, timestamp))
|
|
||||||
|
|
||||||
logger.info("Added message to memory: %s from %s in channel %s", message, user, channel_id)
|
|
||||||
|
|
||||||
|
|
||||||
def update_trigger_time(channel_id: str, user: str) -> None:
|
|
||||||
"""Update the last trigger time for a user in a specific channel.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel_id: The ID of the channel.
|
|
||||||
user: The user who triggered the bot.
|
|
||||||
"""
|
|
||||||
if channel_id not in last_trigger_time:
|
|
||||||
last_trigger_time[channel_id] = {}
|
|
||||||
|
|
||||||
last_trigger_time[channel_id][user] = datetime.datetime.now(tz=datetime.UTC)
|
|
||||||
logger.info("Updated trigger time for user %s in channel %s", user, channel_id)
|
|
||||||
|
|
@ -9,6 +9,7 @@ dependencies = [
|
||||||
"discord-py",
|
"discord-py",
|
||||||
"httpx",
|
"httpx",
|
||||||
"numpy",
|
"numpy",
|
||||||
|
"ollama",
|
||||||
"openai",
|
"openai",
|
||||||
"opencv-contrib-python-headless",
|
"opencv-contrib-python-headless",
|
||||||
"psutil",
|
"psutil",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue