diff --git a/main.py b/main.py index cba205c..5bf6ffa 100644 --- a/main.py +++ b/main.py @@ -15,7 +15,7 @@ import sentry_sdk from discord import app_commands from openai import OpenAI -from misc import add_message_to_memory, chat, get_allowed_users +from misc import add_message_to_memory, chat, get_allowed_users, should_respond_without_trigger, update_trigger_time from settings import Settings sentry_sdk.init( @@ -79,8 +79,17 @@ class LoviBotClient(discord.Client): lowercase_message: str = incoming_message.lower() if incoming_message else "" trigger_keywords: list[str] = ["lovibot", "@lovibot", "<@345000831499894795>", "grok", "@grok"] - if any(trigger in lowercase_message for trigger in trigger_keywords): - logger.info("Received message: %s from: %s", incoming_message, message.author.name) + has_trigger_keyword: bool = any(trigger in lowercase_message for trigger in trigger_keywords) + should_respond: bool = has_trigger_keyword or should_respond_without_trigger(str(message.channel.id), message.author.name) + + if should_respond: + # Update trigger time if they used a trigger keyword + if has_trigger_keyword: + update_trigger_time(str(message.channel.id), message.author.name) + + logger.info( + "Received message: %s from: %s (trigger: %s, recent: %s)", incoming_message, message.author.name, has_trigger_keyword, not has_trigger_keyword + ) async with message.channel.typing(): try: diff --git a/misc.py b/misc.py index b85e2ee..9ab2395 100644 --- a/misc.py +++ b/misc.py @@ -23,6 +23,9 @@ logger: logging.Logger = logging.getLogger(__name__) # A dictionary to store recent messages per channel with a maximum length per channel recent_messages: dict[str, deque[tuple[str, str, datetime.datetime]]] = {} +# A dictionary to track the last time each user triggered the bot in each channel +last_trigger_time: dict[str, dict[str, datetime.datetime]] = {} + def get_allowed_users() -> list[str]: """Get the list of allowed users to interact with the bot. @@ -74,6 +77,43 @@ def get_recent_messages(channel_id: int, threshold_minutes: int = 10) -> list[tu return [(user, message) for user, message, timestamp in recent_messages[str(channel_id)] if timestamp > threshold] +def update_trigger_time(channel_id: str, user: str) -> None: + """Update the last trigger time for a user in a specific channel. + + Args: + channel_id: The ID of the channel. + user: The user who triggered the bot. + """ + if channel_id not in last_trigger_time: + last_trigger_time[channel_id] = {} + + last_trigger_time[channel_id][user] = datetime.datetime.now(tz=datetime.UTC) + logger.info("Updated trigger time for user %s in channel %s", user, channel_id) + + +def should_respond_without_trigger(channel_id: str, user: str, threshold_seconds: int = 40) -> bool: + """Check if the bot should respond to a user without requiring trigger keywords. + + Args: + channel_id: The ID of the channel. + user: The user who sent the message. + threshold_seconds: The number of seconds to consider as "recent trigger". + + Returns: + True if the bot should respond without trigger keywords, False otherwise. + """ + if channel_id not in last_trigger_time: + return False + + last_trigger: datetime.datetime = last_trigger_time[channel_id][user] + threshold: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=threshold_seconds) + + should_respond: bool = last_trigger > threshold + logger.info("User %s in channel %s last triggered at %s, should respond without trigger: %s", user, channel_id, last_trigger, should_respond) + + return should_respond + + def extra_context(current_channel: MessageableChannel | InteractionChannel | None, user: User | Member) -> str: """Add extra context to the chat prompt.