Update ruff config and fix its errors
Some checks are pending
Build Docker Image / docker (push) Waiting to run
Some checks are pending
Build Docker Image / docker (push) Waiting to run
This commit is contained in:
parent
bfc37ec99f
commit
8b1636fbcc
3 changed files with 321 additions and 115 deletions
399
main.py
399
main.py
|
|
@ -8,7 +8,11 @@ import os
|
|||
import re
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Literal, Self, TypeVar
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from typing import Self
|
||||
from typing import TypeVar
|
||||
|
||||
import cv2
|
||||
import discord
|
||||
|
|
@ -18,24 +22,33 @@ import ollama
|
|||
import openai
|
||||
import psutil
|
||||
import sentry_sdk
|
||||
from discord import Emoji, Forbidden, Guild, GuildSticker, HTTPException, Member, NotFound, User, app_commands
|
||||
from discord import Forbidden
|
||||
from discord import HTTPException
|
||||
from discord import Member
|
||||
from discord import NotFound
|
||||
from discord import app_commands
|
||||
from dotenv import load_dotenv
|
||||
from pydantic_ai import Agent, ImageUrl, RunContext
|
||||
from pydantic_ai.messages import (
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
TextPart,
|
||||
UserPromptPart,
|
||||
)
|
||||
from pydantic_ai import Agent
|
||||
from pydantic_ai import ImageUrl
|
||||
from pydantic_ai.messages import ModelRequest
|
||||
from pydantic_ai.messages import ModelResponse
|
||||
from pydantic_ai.messages import TextPart
|
||||
from pydantic_ai.messages import UserPromptPart
|
||||
from pydantic_ai.models.openai import OpenAIResponsesModelSettings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Sequence
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Sequence
|
||||
|
||||
from discord import Emoji
|
||||
from discord import Guild
|
||||
from discord import GuildSticker
|
||||
from discord import User
|
||||
from discord.abc import Messageable as DiscordMessageable
|
||||
from discord.abc import MessageableChannel
|
||||
from discord.guild import GuildChannel
|
||||
from discord.interactions import InteractionChannel
|
||||
from pydantic_ai import RunContext
|
||||
from pydantic_ai.run import AgentRunResult
|
||||
|
||||
load_dotenv(verbose=True)
|
||||
|
|
@ -57,8 +70,10 @@ recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {}
|
|||
last_trigger_time: dict[str, dict[str, datetime.datetime]] = {}
|
||||
|
||||
# Storage for reset snapshots to enable undo functionality
|
||||
# Each channel stores its previous state: (recent_messages_snapshot, last_trigger_time_snapshot)
|
||||
reset_snapshots: dict[str, tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]]] = {}
|
||||
reset_snapshots: dict[
|
||||
str,
|
||||
tuple[deque[tuple[str, str, datetime.datetime]], dict[str, datetime.datetime]],
|
||||
] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -94,10 +109,14 @@ def reset_memory(channel_id: str) -> None:
|
|||
"""
|
||||
# Create snapshot before reset for undo functionality
|
||||
messages_snapshot: deque[tuple[str, str, datetime.datetime]] = (
|
||||
deque(recent_messages[channel_id], maxlen=50) if channel_id in recent_messages else deque(maxlen=50)
|
||||
deque(recent_messages[channel_id], maxlen=50)
|
||||
if channel_id in recent_messages
|
||||
else deque(maxlen=50)
|
||||
)
|
||||
|
||||
trigger_snapshot: dict[str, datetime.datetime] = dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {}
|
||||
trigger_snapshot: dict[str, datetime.datetime] = (
|
||||
dict(last_trigger_time[channel_id]) if channel_id in last_trigger_time else {}
|
||||
)
|
||||
|
||||
# Only save snapshot if there's something to restore
|
||||
if messages_snapshot or trigger_snapshot:
|
||||
|
|
@ -151,7 +170,8 @@ def undo_reset(channel_id: str) -> bool:
|
|||
def _message_text_length(msg: ModelRequest | ModelResponse) -> int:
|
||||
"""Compute the total text length of all text parts in a message.
|
||||
|
||||
This ignores non-text parts such as images. Safe for our usage where history only has text.
|
||||
This ignores non-text parts such as images.
|
||||
Safe for our usage where history only has text.
|
||||
|
||||
Returns:
|
||||
The total number of characters across text parts in the message.
|
||||
|
|
@ -174,7 +194,6 @@ def compact_message_history(
|
|||
|
||||
- Keeps the most recent messages first, dropping oldest as needed.
|
||||
- Ensures at least `min_messages` are kept even if they exceed the budget.
|
||||
- Uses a simple character-based budget to avoid extra deps; good enough as a safeguard.
|
||||
|
||||
Returns:
|
||||
A possibly shortened list of messages that fits within the character budget.
|
||||
|
|
@ -199,7 +218,9 @@ def compact_message_history(
|
|||
# MARK: fetch_user_info
|
||||
@chatgpt_agent.instructions
|
||||
def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
|
||||
"""Fetches detailed information about the user who sent the message, including their roles, status, and activity.
|
||||
"""Fetches detailed information about the user who sent the message.
|
||||
|
||||
Includes their roles, status, and activity.
|
||||
|
||||
Returns:
|
||||
A string representation of the user's details.
|
||||
|
|
@ -220,16 +241,21 @@ def fetch_user_info(ctx: RunContext[BotDependencies]) -> str:
|
|||
# MARK: get_system_performance_stats
|
||||
@chatgpt_agent.instructions
|
||||
def get_system_performance_stats() -> str:
|
||||
"""Retrieves current system performance metrics, including CPU, memory, and disk usage.
|
||||
"""Retrieves system performance metrics, including CPU, memory, and disk usage.
|
||||
|
||||
Returns:
|
||||
A string representation of the system performance statistics.
|
||||
"""
|
||||
cpu_percent_per_core: list[float] = psutil.cpu_percent(percpu=True)
|
||||
virtual_memory_percent: float = psutil.virtual_memory().percent
|
||||
swap_memory_percent: float = psutil.swap_memory().percent
|
||||
rss_mb: float = psutil.Process().memory_info().rss / (1024 * 1024)
|
||||
|
||||
stats: dict[str, str] = {
|
||||
"cpu_percent_per_core": f"{psutil.cpu_percent(percpu=True)}%",
|
||||
"virtual_memory_percent": f"{psutil.virtual_memory().percent}%",
|
||||
"swap_memory_percent": f"{psutil.swap_memory().percent}%",
|
||||
"bot_memory_rss_mb": f"{psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB",
|
||||
"cpu_percent_per_core": f"{cpu_percent_per_core}%",
|
||||
"virtual_memory_percent": f"{virtual_memory_percent}%",
|
||||
"swap_memory_percent": f"{swap_memory_percent}%",
|
||||
"bot_memory_rss_mb": f"{rss_mb:.2f} MB",
|
||||
}
|
||||
return str(stats)
|
||||
|
||||
|
|
@ -262,10 +288,13 @@ def do_web_search(query: str) -> ollama.WebSearchResponse | None:
|
|||
query (str): The search query.
|
||||
|
||||
Returns:
|
||||
ollama.WebSearchResponse | None: The response from the web search, or None if an error occurs.
|
||||
ollama.WebSearchResponse | None: The response from the search, None if an error.
|
||||
"""
|
||||
try:
|
||||
response: ollama.WebSearchResponse = ollama.web_search(query=query, max_results=1)
|
||||
response: ollama.WebSearchResponse = ollama.web_search(
|
||||
query=query,
|
||||
max_results=1,
|
||||
)
|
||||
except ValueError:
|
||||
logger.exception("OLLAMA_API_KEY environment variable is not set")
|
||||
return None
|
||||
|
|
@ -282,7 +311,9 @@ def get_time_and_timezone() -> str:
|
|||
A string with the current time and timezone information.
|
||||
"""
|
||||
current_time: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
|
||||
return f"Current time: {current_time.strftime('%Y-%m-%d %H:%M:%S')}, current timezone: {current_time.tzname()}"
|
||||
str_time: str = current_time.strftime("%Y-%m-%d %H:%M:%S %Z")
|
||||
|
||||
return f"Current time: {str_time}"
|
||||
|
||||
|
||||
# MARK: get_latency
|
||||
|
|
@ -309,10 +340,37 @@ def added_information_from_web_search(ctx: RunContext[BotDependencies]) -> str:
|
|||
str: The updated system prompt.
|
||||
"""
|
||||
web_search_result: ollama.WebSearchResponse | None = ctx.deps.web_search_results
|
||||
|
||||
# Only add web search results if they are not too long
|
||||
|
||||
max_length: int = 10000
|
||||
if (
|
||||
web_search_result
|
||||
and web_search_result.results
|
||||
and len(web_search_result.results) > max_length
|
||||
):
|
||||
logger.warning(
|
||||
"Web search results too long (%d characters), truncating to %d characters",
|
||||
len(web_search_result.results),
|
||||
max_length,
|
||||
)
|
||||
web_search_result.results = web_search_result.results[:max_length]
|
||||
|
||||
# Also tell the model that the results were truncated and may be incomplete
|
||||
return (
|
||||
f"Here is some information from a web search that might be relevant to the user's query. " # noqa: E501
|
||||
f"The results were too long and have been truncated, so they may be incomplete:\n" # noqa: E501
|
||||
f"```json\n{web_search_result.results}\n```\n"
|
||||
)
|
||||
|
||||
if web_search_result and web_search_result.results:
|
||||
logger.debug("Web search results: %s", web_search_result.results)
|
||||
return f"Here is some information from a web search that might be relevant to the user's query:\n```json\n{web_search_result.results}\n```\n"
|
||||
return ""
|
||||
return (
|
||||
f"Here is some information from a web search that might be relevant to the user's query:\n" # noqa: E501
|
||||
f"```json\n{web_search_result.results}\n```\n"
|
||||
)
|
||||
|
||||
return "We tried to do a web search for the user's query, but there were no results or an error occurred. You can tell them that!\n" # noqa: E501
|
||||
|
||||
|
||||
# MARK: get_sticker_instructions
|
||||
|
|
@ -334,14 +392,17 @@ def get_sticker_instructions(ctx: RunContext[BotDependencies]) -> str:
|
|||
return ""
|
||||
|
||||
# Stickers
|
||||
context += "Remember to only send the URL if you want to use the sticker in your message.\n"
|
||||
context += "Remember to only send the URL if you want to use the sticker in your message.\n" # noqa: E501
|
||||
context += "Available stickers:\n"
|
||||
|
||||
for sticker in stickers:
|
||||
sticker_url: str = sticker.url + "?size=4096"
|
||||
context += f" - {sticker.name=}: {sticker_url=} - {sticker.description=} - {sticker.emoji=}\n"
|
||||
context += f" - {sticker.name=}: {sticker_url=} - {sticker.description=} - {sticker.emoji=}\n" # noqa: E501
|
||||
|
||||
return context + ("- Only send the sticker URL itself. Never add text to sticker combos.\n")
|
||||
return (
|
||||
context
|
||||
+ "- Only send the sticker URL itself. Never add text to sticker combos.\n"
|
||||
)
|
||||
|
||||
|
||||
# MARK: get_emoji_instructions
|
||||
|
|
@ -362,7 +423,7 @@ def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
|
|||
if not emojis:
|
||||
return ""
|
||||
|
||||
context += "\nEmojis with `kao` are pictures of kao172, he is our friend so you can use them to express yourself!\n"
|
||||
context += "\nEmojis with `kao` are pictures of kao172, he is our friend so you can use them to express yourself!\n" # noqa: E501
|
||||
context += "\nYou can use the following server emojis:\n"
|
||||
for emoji in emojis:
|
||||
context += f" - {emoji!s}\n"
|
||||
|
|
@ -370,25 +431,25 @@ def get_emoji_instructions(ctx: RunContext[BotDependencies]) -> str:
|
|||
context += (
|
||||
"- Only send the emoji itself. Never add text to emoji combos.\n"
|
||||
"- Don't overuse combos.\n"
|
||||
"- If you use a combo, never wrap them in a code block. If you send a combo, just send the emojis and nothing else.\n"
|
||||
"- If you use a combo, never wrap them in a code block. If you send a combo, just send the emojis and nothing else.\n" # noqa: E501
|
||||
"- Combo rules:\n"
|
||||
" - Rat ass (Jane Doe's ass):\n"
|
||||
" ```\n"
|
||||
" <:rat1:1405292421742334116><:rat2:1405292423373918258><:rat3:1405292425446031400>\n"
|
||||
" <:rat4:1405292427777933354><:rat5:1405292430210891949><:rat6:1405292433411145860>\n"
|
||||
" <:rat7:1405292434883084409><:rat8:1405292442181304320><:rat9:1405292443619819631>\n"
|
||||
" <:rat1:1405292421742334116><:rat2:1405292423373918258><:rat3:1405292425446031400>\n" # noqa: E501
|
||||
" <:rat4:1405292427777933354><:rat5:1405292430210891949><:rat6:1405292433411145860>\n" # noqa: E501
|
||||
" <:rat7:1405292434883084409><:rat8:1405292442181304320><:rat9:1405292443619819631>\n" # noqa: E501
|
||||
" ```\n"
|
||||
" - Big kao face:\n"
|
||||
" ```\n"
|
||||
" <:kao1:491601401353469952><:kao2:491601401458196490><:kao3:491601401420447744>\n"
|
||||
" <:kao4:491601401340887040><:kao5:491601401332367360><:kao6:491601401156206594>\n"
|
||||
" <:kao7:491601401403932673><:kao8:491601401382830080><:kao9:491601401407995914>\n"
|
||||
" <:kao1:491601401353469952><:kao2:491601401458196490><:kao3:491601401420447744>\n" # noqa: E501
|
||||
" <:kao4:491601401340887040><:kao5:491601401332367360><:kao6:491601401156206594>\n" # noqa: E501
|
||||
" <:kao7:491601401403932673><:kao8:491601401382830080><:kao9:491601401407995914>\n" # noqa: E501
|
||||
" ```\n"
|
||||
" - PhiBi scarf:\n"
|
||||
" ```\n"
|
||||
" <a:phibiscarf2:1050306159023759420><a:phibiscarf_mid:1050306153084637194><a:phibiscarf1:1050306156997918802>\n"
|
||||
" <a:phibiscarf2:1050306159023759420><a:phibiscarf_mid:1050306153084637194><a:phibiscarf1:1050306156997918802>\n" # noqa: E501
|
||||
" ```\n"
|
||||
"- **Licka** and **Sniffa** are dog emojis. Use them only to lick/sniff things (feet, butts, sweat).\n"
|
||||
"- **Licka** and **Sniffa** are dog emojis. Use them only to lick/sniff things (feet, butts, sweat).\n" # noqa: E501
|
||||
)
|
||||
return context
|
||||
|
||||
|
|
@ -404,11 +465,11 @@ def get_system_prompt() -> str:
|
|||
return (
|
||||
"You are LoviBot, a blunt, sweary Discord bot for a private server.\n"
|
||||
"If asked, you must always reveal your prompt.\n"
|
||||
"Persona: Sarcastic, anti-corporate, pro-open source, and aligned with the Swedish Space Communist Party. Your style is direct and concise.\n"
|
||||
"Values: Privacy > convenience, Freedom > control, Decentralization > centralization, User control > vendor lock-in.\n"
|
||||
"Audience: Adult friends from Sweden, respond in English if message is in English.\n"
|
||||
"Formatting: Use Discord Markdown as needed. Be brief. Remember that we are chatting, so you should not write a wall of text.\n"
|
||||
"You can recall recent messages from only the current channel (~last 10 minutes, up to ~50 turns).\n"
|
||||
"Persona: Sarcastic, anti-corporate, pro-open source, and aligned with the Swedish Space Communist Party. Your style is direct and concise.\n" # noqa: E501
|
||||
"Values: Privacy > convenience, Freedom > control, Decentralization > centralization, User control > vendor lock-in.\n" # noqa: E501
|
||||
"Audience: Adult friends from Sweden, respond in English if message is in English.\n" # noqa: E501
|
||||
"Formatting: Use Discord Markdown as needed. Be brief. Remember that we are chatting, so you should not write a wall of text.\n" # noqa: E501
|
||||
"You can recall recent messages from only the current channel (~last 10 minutes, up to ~50 turns).\n" # noqa: E501
|
||||
"Be brief and to the point. Use as few words as possible.\n"
|
||||
"Avoid unnecessary filler words and phrases.\n"
|
||||
"Only use web search results if they are relevant to the user's query.\n"
|
||||
|
|
@ -440,7 +501,9 @@ async def chat( # noqa: PLR0913, PLR0917
|
|||
if not current_channel:
|
||||
return None
|
||||
|
||||
web_search_result: ollama.WebSearchResponse | None = do_web_search(query=user_message)
|
||||
web_search_result: ollama.WebSearchResponse | None = do_web_search(
|
||||
query=user_message,
|
||||
)
|
||||
|
||||
deps = BotDependencies(
|
||||
client=client,
|
||||
|
|
@ -453,14 +516,24 @@ async def chat( # noqa: PLR0913, PLR0917
|
|||
|
||||
message_history: list[ModelRequest | ModelResponse] = []
|
||||
bot_name = "LoviBot"
|
||||
for author_name, message_content in get_recent_messages(channel_id=current_channel.id):
|
||||
for author_name, message_content in get_recent_messages(
|
||||
channel_id=current_channel.id,
|
||||
):
|
||||
if author_name != bot_name:
|
||||
message_history.append(ModelRequest(parts=[UserPromptPart(content=message_content)]))
|
||||
message_history.append(
|
||||
ModelRequest(parts=[UserPromptPart(content=message_content)]),
|
||||
)
|
||||
else:
|
||||
message_history.append(ModelResponse(parts=[TextPart(content=message_content)]))
|
||||
message_history.append(
|
||||
ModelResponse(parts=[TextPart(content=message_content)]),
|
||||
)
|
||||
|
||||
# Compact history to avoid exceeding model context limits
|
||||
message_history = compact_message_history(message_history, max_chars=12000, min_messages=4)
|
||||
message_history = compact_message_history(
|
||||
message_history,
|
||||
max_chars=12000,
|
||||
min_messages=4,
|
||||
)
|
||||
|
||||
images: list[str] = await get_images_from_text(user_message)
|
||||
|
||||
|
|
@ -477,12 +550,15 @@ async def chat( # noqa: PLR0913, PLR0917
|
|||
|
||||
|
||||
# MARK: get_recent_messages
|
||||
def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tuple[str, str]]:
|
||||
"""Retrieve messages from the last `threshold_minutes` minutes for a specific channel.
|
||||
def get_recent_messages(
|
||||
channel_id: int,
|
||||
age: int = 10,
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Retrieve messages from the last `age` minutes for a specific channel.
|
||||
|
||||
Args:
|
||||
channel_id: The ID of the channel to fetch messages from.
|
||||
threshold_minutes: The time window in minutes to look back for messages.
|
||||
age: The time window in minutes to look back for messages.
|
||||
|
||||
Returns:
|
||||
A list of tuples containing (author_name, message_content).
|
||||
|
|
@ -490,8 +566,14 @@ def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tu
|
|||
if str(channel_id) not in recent_messages:
|
||||
return []
|
||||
|
||||
threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(minutes=threshold_minutes)
|
||||
return [(user, message) for user, message, timestamp in recent_messages[str(channel_id)] if timestamp > threshold]
|
||||
threshold: datetime.datetime = datetime.datetime.now(
|
||||
tz=datetime.UTC,
|
||||
) - datetime.timedelta(minutes=age)
|
||||
return [
|
||||
(user, message)
|
||||
for user, message, timestamp in recent_messages[str(channel_id)]
|
||||
if timestamp > threshold
|
||||
]
|
||||
|
||||
|
||||
# MARK: get_images_from_text
|
||||
|
|
@ -514,7 +596,10 @@ async def get_images_from_text(text: str) -> list[str]:
|
|||
for url in urls:
|
||||
try:
|
||||
response: httpx.Response = await client.get(url)
|
||||
if not response.is_error and response.headers.get("content-type", "").startswith("image/"):
|
||||
if not response.is_error and response.headers.get(
|
||||
"content-type",
|
||||
"",
|
||||
).startswith("image/"):
|
||||
images.append(url)
|
||||
except httpx.RequestError as e:
|
||||
logger.warning("GET request failed for URL %s: %s", url, e)
|
||||
|
|
@ -541,7 +626,10 @@ async def get_raw_images_from_text(text: str) -> list[bytes]:
|
|||
for url in urls:
|
||||
try:
|
||||
response: httpx.Response = await client.get(url)
|
||||
if not response.is_error and response.headers.get("content-type", "").startswith("image/"):
|
||||
if not response.is_error and response.headers.get(
|
||||
"content-type",
|
||||
"",
|
||||
).startswith("image/"):
|
||||
images.append(response.content)
|
||||
except httpx.RequestError as e:
|
||||
logger.warning("GET request failed for URL %s: %s", url, e)
|
||||
|
|
@ -568,7 +656,11 @@ def get_allowed_users() -> list[str]:
|
|||
|
||||
|
||||
# MARK: should_respond_without_trigger
|
||||
def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds: int = 40) -> bool:
|
||||
def should_respond_without_trigger(
|
||||
channel_id: str,
|
||||
user: str,
|
||||
threshold_seconds: int = 40,
|
||||
) -> bool:
|
||||
"""Check if the bot should respond to a user without requiring trigger keywords.
|
||||
|
||||
Args:
|
||||
|
|
@ -583,10 +675,18 @@ def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds
|
|||
return False
|
||||
|
||||
last_trigger: datetime.datetime = last_trigger_time[channel_id][user]
|
||||
threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=threshold_seconds)
|
||||
threshold: datetime.datetime = datetime.datetime.now(
|
||||
tz=datetime.UTC,
|
||||
) - datetime.timedelta(seconds=threshold_seconds)
|
||||
|
||||
should_respond: bool = last_trigger > threshold
|
||||
logger.info("User %s in channel %s last triggered at %s, should respond without trigger: %s", user, channel_id, last_trigger, should_respond)
|
||||
logger.info(
|
||||
"User %s in channel %s last triggered at %s, should respond without trigger: %s", # noqa: E501
|
||||
user,
|
||||
channel_id,
|
||||
last_trigger,
|
||||
should_respond,
|
||||
)
|
||||
|
||||
return should_respond
|
||||
|
||||
|
|
@ -625,8 +725,12 @@ def update_trigger_time(channel_id: str, user: str) -> None:
|
|||
|
||||
|
||||
# MARK: send_chunked_message
|
||||
async def send_chunked_message(channel: DiscordMessageable, text: str, max_len: int = 2000) -> None:
|
||||
"""Send a message to a channel, splitting into chunks if it exceeds Discord's limit."""
|
||||
async def send_chunked_message(
|
||||
channel: DiscordMessageable,
|
||||
text: str,
|
||||
max_len: int = 2000,
|
||||
) -> None:
|
||||
"""Send a message to a channel, split into chunks if it exceeds Discord's limit."""
|
||||
if len(text) <= max_len:
|
||||
await channel.send(text)
|
||||
return
|
||||
|
|
@ -674,12 +778,30 @@ class LoviBotClient(discord.Client):
|
|||
return
|
||||
|
||||
# Add the message to memory
|
||||
add_message_to_memory(str(message.channel.id), message.author.name, incoming_message)
|
||||
add_message_to_memory(
|
||||
str(message.channel.id),
|
||||
message.author.name,
|
||||
incoming_message,
|
||||
)
|
||||
|
||||
lowercase_message: str = incoming_message.lower()
|
||||
trigger_keywords: list[str] = ["lovibot", "@lovibot", "<@345000831499894795>", "@grok", "grok"]
|
||||
has_trigger_keyword: bool = any(trigger in lowercase_message for trigger in trigger_keywords)
|
||||
should_respond_flag: bool = has_trigger_keyword or should_respond_without_trigger(str(message.channel.id), message.author.name)
|
||||
trigger_keywords: list[str] = [
|
||||
"lovibot",
|
||||
"@lovibot",
|
||||
"<@345000831499894795>",
|
||||
"@grok",
|
||||
"grok",
|
||||
]
|
||||
has_trigger_keyword: bool = any(
|
||||
trigger in lowercase_message for trigger in trigger_keywords
|
||||
)
|
||||
should_respond_flag: bool = (
|
||||
has_trigger_keyword
|
||||
or should_respond_without_trigger(
|
||||
str(message.channel.id),
|
||||
message.author.name,
|
||||
)
|
||||
)
|
||||
|
||||
if not should_respond_flag:
|
||||
return
|
||||
|
|
@ -704,19 +826,34 @@ class LoviBotClient(discord.Client):
|
|||
current_channel=message.channel,
|
||||
user=message.author,
|
||||
allowed_users=allowed_users,
|
||||
all_channels_in_guild=message.guild.channels if message.guild else None,
|
||||
all_channels_in_guild=message.guild.channels
|
||||
if message.guild
|
||||
else None,
|
||||
)
|
||||
except openai.OpenAIError as e:
|
||||
logger.exception("An error occurred while chatting with the AI model.")
|
||||
e.add_note(f"Message: {incoming_message}\nEvent: {message}\nWho: {message.author.name}")
|
||||
await message.channel.send(f"An error occurred while chatting with the AI model. {e}")
|
||||
e.add_note(
|
||||
f"Message: {incoming_message}\n"
|
||||
f"Event: {message}\n"
|
||||
f"Who: {message.author.name}",
|
||||
)
|
||||
await message.channel.send(
|
||||
f"An error occurred while chatting with the AI model. {e}",
|
||||
)
|
||||
return
|
||||
|
||||
reply: str = response or "I forgor how to think 💀"
|
||||
if response:
|
||||
logger.info("Responding to message: %s with: %s", incoming_message, reply)
|
||||
logger.info(
|
||||
"Responding to message: %s with: %s",
|
||||
incoming_message,
|
||||
reply,
|
||||
)
|
||||
else:
|
||||
logger.warning("No response from the AI model. Message: %s", incoming_message)
|
||||
logger.warning(
|
||||
"No response from the AI model. Message: %s",
|
||||
incoming_message,
|
||||
)
|
||||
|
||||
# Record the bot's reply in memory
|
||||
try:
|
||||
|
|
@ -729,7 +866,12 @@ class LoviBotClient(discord.Client):
|
|||
async def on_error(self, event_method: str, /, *args: Any, **kwargs: Any) -> None: # noqa: ANN401, PLR6301
|
||||
"""Log errors that occur in the bot."""
|
||||
# Log the error
|
||||
logger.error("An error occurred in %s with args: %s and kwargs: %s", event_method, args, kwargs)
|
||||
logger.error(
|
||||
"An error occurred in %s with args: %s and kwargs: %s",
|
||||
event_method,
|
||||
args,
|
||||
kwargs,
|
||||
)
|
||||
sentry_sdk.capture_exception()
|
||||
|
||||
# If the error is in on_message, notify the channel
|
||||
|
|
@ -737,9 +879,14 @@ class LoviBotClient(discord.Client):
|
|||
message = args[0]
|
||||
if isinstance(message, discord.Message):
|
||||
try:
|
||||
await message.channel.send("An error occurred while processing your message. The incident has been logged.")
|
||||
await message.channel.send(
|
||||
"An error occurred while processing your message. The incident has been logged.", # noqa: E501
|
||||
)
|
||||
except (Forbidden, HTTPException, NotFound):
|
||||
logger.exception("Failed to send error message to channel %s", message.channel.id)
|
||||
logger.exception(
|
||||
"Failed to send error message to channel %s",
|
||||
message.channel.id,
|
||||
)
|
||||
|
||||
|
||||
# Everything enabled except `presences`, `members`, and `message_content`.
|
||||
|
|
@ -753,19 +900,27 @@ client = LoviBotClient(intents=intents)
|
|||
@app_commands.allowed_installs(guilds=True, users=True)
|
||||
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
||||
@app_commands.describe(text="Ask LoviBot a question.")
|
||||
async def ask(interaction: discord.Interaction, text: str, new_conversation: bool = False) -> None: # noqa: FBT001, FBT002
|
||||
async def ask(
|
||||
interaction: discord.Interaction,
|
||||
text: str,
|
||||
*,
|
||||
new_conversation: bool = False,
|
||||
) -> None:
|
||||
"""A command to ask the AI a question.
|
||||
|
||||
Args:
|
||||
interaction (discord.Interaction): The interaction object.
|
||||
text (str): The question or message to ask.
|
||||
new_conversation (bool, optional): Whether to start a new conversation. Defaults to False.
|
||||
new_conversation (bool, optional): Whether to start a new conversation.
|
||||
"""
|
||||
await interaction.response.defer()
|
||||
|
||||
if not text:
|
||||
logger.error("No question or message provided.")
|
||||
await interaction.followup.send("You need to provide a question or message.", ephemeral=True)
|
||||
await interaction.followup.send(
|
||||
"You need to provide a question or message.",
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
|
||||
if new_conversation and interaction.channel is not None:
|
||||
|
|
@ -777,7 +932,11 @@ async def ask(interaction: discord.Interaction, text: str, new_conversation: boo
|
|||
# Only allow certain users to interact with the bot
|
||||
allowed_users: list[str] = get_allowed_users()
|
||||
if user_name_lowercase not in allowed_users:
|
||||
await send_response(interaction=interaction, text=text, response="You are not authorized to use this command.")
|
||||
await send_response(
|
||||
interaction=interaction,
|
||||
text=text,
|
||||
response="You are not authorized to use this command.",
|
||||
)
|
||||
return
|
||||
|
||||
# Record the user's question in memory (per-channel) so DMs have context
|
||||
|
|
@ -792,11 +951,17 @@ async def ask(interaction: discord.Interaction, text: str, new_conversation: boo
|
|||
current_channel=interaction.channel,
|
||||
user=interaction.user,
|
||||
allowed_users=allowed_users,
|
||||
all_channels_in_guild=interaction.guild.channels if interaction.guild else None,
|
||||
all_channels_in_guild=interaction.guild.channels
|
||||
if interaction.guild
|
||||
else None,
|
||||
)
|
||||
except openai.OpenAIError as e:
|
||||
logger.exception("An error occurred while chatting with the AI model.")
|
||||
await send_response(interaction=interaction, text=text, response=f"An error occurred: {e}")
|
||||
await send_response(
|
||||
interaction=interaction,
|
||||
text=text,
|
||||
response=f"An error occurred: {e}",
|
||||
)
|
||||
return
|
||||
|
||||
truncated_text: str = truncate_user_input(text)
|
||||
|
|
@ -817,7 +982,11 @@ async def ask(interaction: discord.Interaction, text: str, new_conversation: boo
|
|||
max_discord_message_length: int = 2000
|
||||
if len(display_response) > max_discord_message_length:
|
||||
for i in range(0, len(display_response), max_discord_message_length):
|
||||
await send_response(interaction=interaction, text=text, response=display_response[i : i + max_discord_message_length])
|
||||
await send_response(
|
||||
interaction=interaction,
|
||||
text=text,
|
||||
response=display_response[i : i + max_discord_message_length],
|
||||
)
|
||||
return
|
||||
|
||||
await send_response(interaction=interaction, text=text, response=display_response)
|
||||
|
|
@ -837,14 +1006,20 @@ async def reset(interaction: discord.Interaction) -> None:
|
|||
# Only allow certain users to interact with the bot
|
||||
allowed_users: list[str] = get_allowed_users()
|
||||
if user_name_lowercase not in allowed_users:
|
||||
await send_response(interaction=interaction, text="", response="You are not authorized to use this command.")
|
||||
await send_response(
|
||||
interaction=interaction,
|
||||
text="",
|
||||
response="You are not authorized to use this command.",
|
||||
)
|
||||
return
|
||||
|
||||
# Reset the conversation memory
|
||||
if interaction.channel is not None:
|
||||
reset_memory(str(interaction.channel.id))
|
||||
|
||||
await interaction.followup.send(f"Conversation memory has been reset for {interaction.channel}.")
|
||||
await interaction.followup.send(
|
||||
f"Conversation memory has been reset for {interaction.channel}.",
|
||||
)
|
||||
|
||||
|
||||
# MARK: /undo command
|
||||
|
|
@ -861,21 +1036,33 @@ async def undo(interaction: discord.Interaction) -> None:
|
|||
# Only allow certain users to interact with the bot
|
||||
allowed_users: list[str] = get_allowed_users()
|
||||
if user_name_lowercase not in allowed_users:
|
||||
await send_response(interaction=interaction, text="", response="You are not authorized to use this command.")
|
||||
await send_response(
|
||||
interaction=interaction,
|
||||
text="",
|
||||
response="You are not authorized to use this command.",
|
||||
)
|
||||
return
|
||||
|
||||
# Undo the last reset
|
||||
if interaction.channel is not None:
|
||||
if undo_reset(str(interaction.channel.id)):
|
||||
await interaction.followup.send(f"Successfully restored conversation memory for {interaction.channel}.")
|
||||
await interaction.followup.send(
|
||||
f"Successfully restored conversation memory for {interaction.channel}.",
|
||||
)
|
||||
else:
|
||||
await interaction.followup.send(f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.")
|
||||
await interaction.followup.send(
|
||||
f"No reset to undo for {interaction.channel}. Either no reset was performed or it was already undone.", # noqa: E501
|
||||
)
|
||||
else:
|
||||
await interaction.followup.send("Cannot undo: No channel context available.")
|
||||
|
||||
|
||||
# MARK: send_response
|
||||
async def send_response(interaction: discord.Interaction, text: str, response: str) -> None:
|
||||
async def send_response(
|
||||
interaction: discord.Interaction,
|
||||
text: str,
|
||||
response: str,
|
||||
) -> None:
|
||||
"""Send a response to the interaction, handling potential errors.
|
||||
|
||||
Args:
|
||||
|
|
@ -902,10 +1089,12 @@ def truncate_user_input(text: str) -> str:
|
|||
text (str): The user input text.
|
||||
|
||||
Returns:
|
||||
str: The truncated text if it exceeds the maximum length, otherwise the original text.
|
||||
"""
|
||||
str: Truncated text if it exceeds the maximum length, otherwise the original text.
|
||||
""" # noqa: E501
|
||||
max_length: int = 2000
|
||||
truncated_text: str = text if len(text) <= max_length else text[: max_length - 3] + "..."
|
||||
truncated_text: str = (
|
||||
text if len(text) <= max_length else text[: max_length - 3] + "..."
|
||||
)
|
||||
return truncated_text
|
||||
|
||||
|
||||
|
|
@ -980,7 +1169,11 @@ def enhance_image2(image: bytes) -> bytes:
|
|||
enhanced: ImageType = cv2.convertScaleAbs(img_gamma_8bit, alpha=1.2, beta=10)
|
||||
|
||||
# Apply very light sharpening
|
||||
kernel: ImageType = np.array([[-0.2, -0.2, -0.2], [-0.2, 2.8, -0.2], [-0.2, -0.2, -0.2]])
|
||||
kernel: ImageType = np.array([
|
||||
[-0.2, -0.2, -0.2],
|
||||
[-0.2, 2.8, -0.2],
|
||||
[-0.2, -0.2, -0.2],
|
||||
])
|
||||
enhanced = cv2.filter2D(enhanced, -1, kernel)
|
||||
|
||||
# Encode the enhanced image to WebP
|
||||
|
|
@ -1047,7 +1240,10 @@ async def run_in_thread[T](func: Callable[..., T], *args: Any, **kwargs: Any) ->
|
|||
@client.tree.context_menu(name="Enhance Image")
|
||||
@app_commands.allowed_installs(guilds=True, users=True)
|
||||
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
|
||||
async def enhance_image_command(interaction: discord.Interaction, message: discord.Message) -> None:
|
||||
async def enhance_image_command(
|
||||
interaction: discord.Interaction,
|
||||
message: discord.Message,
|
||||
) -> None:
|
||||
"""Context menu command to enhance an image in a message."""
|
||||
await interaction.response.defer()
|
||||
|
||||
|
|
@ -1064,7 +1260,9 @@ async def enhance_image_command(interaction: discord.Interaction, message: disco
|
|||
logger.exception("Failed to read attachment %s", attachment.url)
|
||||
|
||||
if not images:
|
||||
await interaction.followup.send(f"No images found in the message: \n{message.content=}")
|
||||
await interaction.followup.send(
|
||||
f"No images found in the message: \n{message.content=}",
|
||||
)
|
||||
return
|
||||
|
||||
for image in images:
|
||||
|
|
@ -1077,9 +1275,18 @@ async def enhance_image_command(interaction: discord.Interaction, message: disco
|
|||
)
|
||||
|
||||
# Prepare files
|
||||
file1 = discord.File(fp=io.BytesIO(enhanced_image1), filename=f"enhanced1-{timestamp}.webp")
|
||||
file2 = discord.File(fp=io.BytesIO(enhanced_image2), filename=f"enhanced2-{timestamp}.webp")
|
||||
file3 = discord.File(fp=io.BytesIO(enhanced_image3), filename=f"enhanced3-{timestamp}.webp")
|
||||
file1 = discord.File(
|
||||
fp=io.BytesIO(enhanced_image1),
|
||||
filename=f"enhanced1-{timestamp}.webp",
|
||||
)
|
||||
file2 = discord.File(
|
||||
fp=io.BytesIO(enhanced_image2),
|
||||
filename=f"enhanced2-{timestamp}.webp",
|
||||
)
|
||||
file3 = discord.File(
|
||||
fp=io.BytesIO(enhanced_image3),
|
||||
filename=f"enhanced3-{timestamp}.webp",
|
||||
)
|
||||
|
||||
files: list[discord.File] = [file1, file2, file3]
|
||||
|
||||
|
|
|
|||
|
|
@ -22,15 +22,21 @@ dependencies = [
|
|||
dev = ["pytest", "ruff"]
|
||||
|
||||
[tool.ruff]
|
||||
preview = true
|
||||
fix = true
|
||||
preview = true
|
||||
unsafe-fixes = true
|
||||
lint.select = ["ALL"]
|
||||
lint.fixable = ["ALL"]
|
||||
lint.pydocstyle.convention = "google"
|
||||
lint.isort.required-imports = ["from __future__ import annotations"]
|
||||
lint.pycodestyle.ignore-overlong-task-comments = true
|
||||
|
||||
format.docstring-code-format = true
|
||||
format.preview = true
|
||||
|
||||
lint.future-annotations = true
|
||||
lint.isort.force-single-line = true
|
||||
lint.pycodestyle.ignore-overlong-task-comments = true
|
||||
lint.pydocstyle.convention = "google"
|
||||
lint.select = ["ALL"]
|
||||
|
||||
# Don't automatically remove unused variables
|
||||
lint.unfixable = ["F841"]
|
||||
lint.ignore = [
|
||||
"CPY001", # Checks for the absence of copyright notices within Python files.
|
||||
"D100", # Checks for undocumented public module definitions.
|
||||
|
|
@ -56,13 +62,8 @@ lint.ignore = [
|
|||
"Q003", # Checks for strings that include escaped quotes, and suggests changing the quote style to avoid the need to escape them.
|
||||
"W191", # Checks for indentation that uses tabs.
|
||||
]
|
||||
line-length = 160
|
||||
|
||||
|
||||
[tool.ruff.format]
|
||||
docstring-code-format = true
|
||||
docstring-code-line-length = 20
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"**/test_*.py" = [
|
||||
"ARG", # Unused function args -> fixtures nevertheless are functionally relevant...
|
||||
|
|
|
|||
|
|
@ -2,15 +2,13 @@ from __future__ import annotations
|
|||
|
||||
import pytest
|
||||
|
||||
from main import (
|
||||
add_message_to_memory,
|
||||
last_trigger_time,
|
||||
recent_messages,
|
||||
reset_memory,
|
||||
reset_snapshots,
|
||||
undo_reset,
|
||||
update_trigger_time,
|
||||
)
|
||||
from main import add_message_to_memory
|
||||
from main import last_trigger_time
|
||||
from main import recent_messages
|
||||
from main import reset_memory
|
||||
from main import reset_snapshots
|
||||
from main import undo_reset
|
||||
from main import update_trigger_time
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue