Add helper functions and modal for modifying APScheduler jobs
This commit is contained in:
160
discord_reminder_bot/helpers.py
Normal file
160
discord_reminder_bot/helpers.py
Normal file
@ -0,0 +1,160 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import dateparser
|
||||
from apscheduler.job import Job
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.triggers.date import DateTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from loguru import logger
|
||||
|
||||
from interactions.api.models.misc import Snowflake
|
||||
|
||||
|
||||
def calculate(job: Job) -> str:
|
||||
"""Calculate the time left for a job.
|
||||
|
||||
Args:
|
||||
job: The job to calculate the time for.
|
||||
|
||||
Returns:
|
||||
str: The time left for the job or "Paused" if the job is paused or has no next run time.
|
||||
"""
|
||||
trigger_time = None
|
||||
if isinstance(job.trigger, DateTrigger | IntervalTrigger):
|
||||
trigger_time = job.next_run_time or None
|
||||
|
||||
elif isinstance(job.trigger, CronTrigger):
|
||||
if not job.next_run_time:
|
||||
logger.debug(f"No next run time found for '{job.id}', probably paused? {job.__getstate__()}")
|
||||
return "Paused"
|
||||
|
||||
trigger_time = job.trigger.get_next_fire_time(None, datetime.datetime.now(tz=job._scheduler.timezone)) # noqa: SLF001
|
||||
|
||||
logger.debug(f"{type(job.trigger)=}, {trigger_time=}")
|
||||
|
||||
if not trigger_time:
|
||||
logger.debug("No trigger time found")
|
||||
return "Paused"
|
||||
|
||||
return f"<t:{int(trigger_time.timestamp())}:R>"
|
||||
|
||||
|
||||
def get_human_readable_time(job: Job) -> str:
|
||||
"""Get the human-readable time for a job.
|
||||
|
||||
Args:
|
||||
job: The job to get the time for.
|
||||
|
||||
Returns:
|
||||
str: The human-readable time.
|
||||
"""
|
||||
trigger_time = None
|
||||
if isinstance(job.trigger, DateTrigger | IntervalTrigger):
|
||||
trigger_time = job.next_run_time or None
|
||||
|
||||
elif isinstance(job.trigger, CronTrigger):
|
||||
if not job.next_run_time:
|
||||
logger.debug(f"No next run time found for '{job.id}', probably paused? {job.__getstate__()}")
|
||||
return "Paused"
|
||||
|
||||
trigger_time = job.trigger.get_next_fire_time(None, datetime.datetime.now(tz=job._scheduler.timezone)) # noqa: SLF001
|
||||
|
||||
if not trigger_time:
|
||||
logger.debug("No trigger time found")
|
||||
return "Paused"
|
||||
|
||||
return trigger_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def parse_time(date_to_parse: str | None, timezone: str | None = os.getenv("TIMEZONE")) -> datetime.datetime | None:
|
||||
"""Parse a date string into a datetime object.
|
||||
|
||||
Args:
|
||||
date_to_parse(str): The date string to parse.
|
||||
timezone(str, optional): The timezone to use. Defaults to the TIMEZONE environment variable.
|
||||
|
||||
Returns:
|
||||
datetime.datetime: The parsed datetime object.
|
||||
"""
|
||||
if not date_to_parse:
|
||||
logger.error("No date provided to parse.")
|
||||
return None
|
||||
|
||||
if not timezone:
|
||||
logger.error("No timezone provided to parse date.")
|
||||
return None
|
||||
|
||||
logger.info(f"Parsing date: '{date_to_parse}' with timezone: '{timezone}'")
|
||||
|
||||
try:
|
||||
parsed_date: datetime.datetime | None = dateparser.parse(
|
||||
date_string=date_to_parse,
|
||||
settings={
|
||||
"PREFER_DATES_FROM": "future",
|
||||
"TIMEZONE": f"{timezone}",
|
||||
"RETURN_AS_TIMEZONE_AWARE": True,
|
||||
"RELATIVE_BASE": datetime.datetime.now(tz=ZoneInfo(str(timezone))),
|
||||
},
|
||||
)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.error(f"Failed to parse date: '{date_to_parse}' with timezone: '{timezone}'. Error: {e}")
|
||||
return None
|
||||
|
||||
logger.debug(f"Parsed date: {parsed_date} from '{date_to_parse}'")
|
||||
|
||||
return parsed_date
|
||||
|
||||
|
||||
def generate_state(state: dict[str, Any], job: Job) -> str:
|
||||
"""Format the __getstate__ dictionary for Discord markdown.
|
||||
|
||||
Args:
|
||||
state (dict): The __getstate__ dictionary.
|
||||
job (Job): The APScheduler job.
|
||||
|
||||
Returns:
|
||||
str: The formatted string.
|
||||
"""
|
||||
if not state:
|
||||
logger.error(f"No state found for {job.id}")
|
||||
return "No state found.\n"
|
||||
|
||||
for key, value in state.items():
|
||||
if isinstance(value, IntervalTrigger):
|
||||
state[key] = "IntervalTrigger"
|
||||
elif isinstance(value, DateTrigger):
|
||||
state[key] = "DateTrigger"
|
||||
elif isinstance(value, Job):
|
||||
state[key] = "Job"
|
||||
elif isinstance(value, Snowflake):
|
||||
state[key] = str(value)
|
||||
|
||||
try:
|
||||
msg: str = json.dumps(state, indent=4, default=str)
|
||||
except TypeError as e:
|
||||
e.add_note("This is likely due to a non-serializable object in the state. Please check the state for any non-serializable objects.")
|
||||
e.add_note(f"{state=}")
|
||||
logger.error(f"Failed to serialize state: {e}")
|
||||
return "Failed to serialize state."
|
||||
|
||||
return msg
|
||||
|
||||
|
||||
def generate_markdown_state(state: dict[str, Any], job: Job) -> str:
|
||||
"""Format the __getstate__ dictionary for Discord markdown.
|
||||
|
||||
Args:
|
||||
state (dict): The __getstate__ dictionary.
|
||||
job (Job): The APScheduler job.
|
||||
|
||||
Returns:
|
||||
str: The formatted string.
|
||||
"""
|
||||
msg: str = generate_state(state=state, job=job)
|
||||
return "```json\n" + msg + "\n```"
|
@ -3,237 +3,36 @@ from __future__ import annotations
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
import tempfile
|
||||
import traceback
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
import dateparser
|
||||
import discord
|
||||
import pytz
|
||||
import sentry_sdk
|
||||
from apscheduler import events
|
||||
from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_MISSED, JobExecutionEvent
|
||||
from apscheduler.job import Job
|
||||
from apscheduler.jobstores.base import JobLookupError
|
||||
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.triggers.date import DateTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from discord.abc import PrivateChannel
|
||||
from discord.utils import escape_markdown
|
||||
from discord_webhook import DiscordWebhook
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from interactions.api.models.misc import Snowflake
|
||||
from discord_reminder_bot.helpers import calculate, generate_markdown_state, generate_state, get_human_readable_time, parse_time
|
||||
from discord_reminder_bot.modals import ReminderModifyModal
|
||||
from discord_reminder_bot.settings import scheduler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from types import CoroutineType
|
||||
|
||||
from apscheduler.job import Job
|
||||
from discord.guild import GuildChannel
|
||||
from discord.interactions import InteractionChannel
|
||||
from requests import Response
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
default_sentry_dsn: str = "https://c4c61a52838be9b5042144420fba5aaa@o4505228040339456.ingest.us.sentry.io/4508707268984832"
|
||||
sentry_sdk.init(
|
||||
dsn=os.getenv("SENTRY_DSN", default_sentry_dsn),
|
||||
environment=platform.node() or "Unknown",
|
||||
traces_sample_rate=1.0,
|
||||
send_default_pii=True,
|
||||
)
|
||||
|
||||
|
||||
def generate_state(state: dict[str, Any], job: Job) -> str:
|
||||
"""Format the __getstate__ dictionary for Discord markdown.
|
||||
|
||||
Args:
|
||||
state (dict): The __getstate__ dictionary.
|
||||
job (Job): The APScheduler job.
|
||||
|
||||
Returns:
|
||||
str: The formatted string.
|
||||
"""
|
||||
if not state:
|
||||
logger.error(f"No state found for {job.id}")
|
||||
return "No state found.\n"
|
||||
|
||||
for key, value in state.items():
|
||||
if isinstance(value, IntervalTrigger):
|
||||
state[key] = "IntervalTrigger"
|
||||
elif isinstance(value, DateTrigger):
|
||||
state[key] = "DateTrigger"
|
||||
elif isinstance(value, Job):
|
||||
state[key] = "Job"
|
||||
elif isinstance(value, Snowflake):
|
||||
state[key] = str(value)
|
||||
|
||||
try:
|
||||
msg: str = json.dumps(state, indent=4, default=str)
|
||||
except TypeError as e:
|
||||
e.add_note("This is likely due to a non-serializable object in the state. Please check the state for any non-serializable objects.")
|
||||
e.add_note(f"{state=}")
|
||||
logger.error(f"Failed to serialize state: {e}")
|
||||
return "Failed to serialize state."
|
||||
|
||||
return msg
|
||||
|
||||
|
||||
def generate_markdown_state(state: dict[str, Any], job: Job) -> str:
|
||||
"""Format the __getstate__ dictionary for Discord markdown.
|
||||
|
||||
Args:
|
||||
state (dict): The __getstate__ dictionary.
|
||||
job (Job): The APScheduler job.
|
||||
|
||||
Returns:
|
||||
str: The formatted string.
|
||||
"""
|
||||
msg: str = generate_state(state=state, job=job)
|
||||
return "```json\n" + msg + "\n```"
|
||||
|
||||
|
||||
def parse_time(date_to_parse: str | None, timezone: str | None = os.getenv("TIMEZONE")) -> datetime.datetime | None:
|
||||
"""Parse a date string into a datetime object.
|
||||
|
||||
Args:
|
||||
date_to_parse(str): The date string to parse.
|
||||
timezone(str, optional): The timezone to use. Defaults to the TIMEZONE environment variable.
|
||||
|
||||
Returns:
|
||||
datetime.datetime: The parsed datetime object.
|
||||
"""
|
||||
if not date_to_parse:
|
||||
logger.error("No date provided to parse.")
|
||||
return None
|
||||
|
||||
if not timezone:
|
||||
logger.error("No timezone provided to parse date.")
|
||||
return None
|
||||
|
||||
logger.info(f"Parsing date: '{date_to_parse}' with timezone: '{timezone}'")
|
||||
|
||||
try:
|
||||
parsed_date: datetime.datetime | None = dateparser.parse(
|
||||
date_string=date_to_parse,
|
||||
settings={
|
||||
"PREFER_DATES_FROM": "future",
|
||||
"TIMEZONE": f"{timezone}",
|
||||
"RETURN_AS_TIMEZONE_AWARE": True,
|
||||
"RELATIVE_BASE": datetime.datetime.now(tz=ZoneInfo(str(timezone))),
|
||||
},
|
||||
)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.error(f"Failed to parse date: '{date_to_parse}' with timezone: '{timezone}'. Error: {e}")
|
||||
return None
|
||||
|
||||
logger.debug(f"Parsed date: {parsed_date} from '{date_to_parse}'")
|
||||
|
||||
return parsed_date
|
||||
|
||||
|
||||
def calculate(job: Job) -> str:
|
||||
"""Calculate the time left for a job.
|
||||
|
||||
Args:
|
||||
job: The job to calculate the time for.
|
||||
|
||||
Returns:
|
||||
str: The time left for the job or "Paused" if the job is paused or has no next run time.
|
||||
"""
|
||||
trigger_time = None
|
||||
if isinstance(job.trigger, DateTrigger | IntervalTrigger):
|
||||
trigger_time = job.next_run_time or None
|
||||
|
||||
elif isinstance(job.trigger, CronTrigger):
|
||||
if not job.next_run_time:
|
||||
logger.debug(f"No next run time found for '{job.id}', probably paused? {job.__getstate__()}")
|
||||
return "Paused"
|
||||
|
||||
trigger_time = job.trigger.get_next_fire_time(None, datetime.datetime.now(tz=job._scheduler.timezone)) # noqa: SLF001
|
||||
|
||||
logger.debug(f"{type(job.trigger)=}, {trigger_time=}")
|
||||
|
||||
if not trigger_time:
|
||||
logger.debug("No trigger time found")
|
||||
return "Paused"
|
||||
|
||||
return f"<t:{int(trigger_time.timestamp())}:R>"
|
||||
|
||||
|
||||
def get_human_readable_time(job: Job) -> str:
|
||||
"""Get the human-readable time for a job.
|
||||
|
||||
Args:
|
||||
job: The job to get the time for.
|
||||
|
||||
Returns:
|
||||
str: The human-readable time.
|
||||
"""
|
||||
trigger_time = None
|
||||
if isinstance(job.trigger, DateTrigger | IntervalTrigger):
|
||||
trigger_time = job.next_run_time or None
|
||||
|
||||
elif isinstance(job.trigger, CronTrigger):
|
||||
if not job.next_run_time:
|
||||
logger.debug(f"No next run time found for '{job.id}', probably paused? {job.__getstate__()}")
|
||||
return "Paused"
|
||||
|
||||
trigger_time = job.trigger.get_next_fire_time(None, datetime.datetime.now(tz=job._scheduler.timezone)) # noqa: SLF001
|
||||
|
||||
if not trigger_time:
|
||||
logger.debug("No trigger time found")
|
||||
return "Paused"
|
||||
|
||||
return trigger_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def get_scheduler() -> AsyncIOScheduler:
|
||||
"""Return the scheduler instance.
|
||||
|
||||
Uses the SQLITE_LOCATION environment variable for the SQLite database location.
|
||||
|
||||
Raises:
|
||||
ValueError: If the timezone is missing or invalid.
|
||||
|
||||
Returns:
|
||||
AsyncIOScheduler: The scheduler instance.
|
||||
"""
|
||||
config_timezone: str | None = os.getenv("TIMEZONE")
|
||||
if not config_timezone:
|
||||
msg = "Missing timezone. Please set the TIMEZONE environment variable."
|
||||
raise ValueError(msg)
|
||||
|
||||
# Test if the timezone is valid
|
||||
try:
|
||||
ZoneInfo(config_timezone)
|
||||
except (ZoneInfoNotFoundError, ModuleNotFoundError) as e:
|
||||
msg: str = f"Invalid timezone: {config_timezone}. Error: {e}"
|
||||
raise ValueError(msg) from e
|
||||
|
||||
logger.info(f"Using timezone: {config_timezone}. If this is incorrect, please set the TIMEZONE environment variable.")
|
||||
|
||||
sqlite_location: str = os.getenv("SQLITE_LOCATION", default="/jobs.sqlite")
|
||||
logger.info(f"Using SQLite database at: {sqlite_location}")
|
||||
|
||||
jobstores: dict[str, SQLAlchemyJobStore] = {"default": SQLAlchemyJobStore(url=f"sqlite://{sqlite_location}")}
|
||||
job_defaults: dict[str, bool] = {"coalesce": True}
|
||||
timezone = pytz.timezone(config_timezone)
|
||||
return AsyncIOScheduler(jobstores=jobstores, timezone=timezone, job_defaults=job_defaults)
|
||||
|
||||
|
||||
scheduler: AsyncIOScheduler = get_scheduler()
|
||||
|
||||
|
||||
def my_listener(event: JobExecutionEvent) -> None:
|
||||
"""Listener for job events.
|
||||
|
||||
@ -499,127 +298,12 @@ class ReminderListView(discord.ui.View):
|
||||
interaction (discord.Interaction): The interaction that triggered this modification.
|
||||
job_id (str): The ID of the job to modify.
|
||||
"""
|
||||
|
||||
class ReminderModifyModal(discord.ui.Modal, title="Modify reminder"):
|
||||
"""Modal for modifying a APScheduler job."""
|
||||
|
||||
def __init__(self, job: Job) -> None:
|
||||
"""Initialize the modal for modifying a reminder.
|
||||
|
||||
Args:
|
||||
job (Job): The APScheduler job to modify.
|
||||
"""
|
||||
super().__init__(title="Modify Reminder")
|
||||
self.job = job
|
||||
self.job_id = job.id
|
||||
|
||||
self.message_input = discord.ui.TextInput(
|
||||
label="Reminder message",
|
||||
default=job.kwargs.get("message", ""),
|
||||
placeholder="What do you want to be reminded of?",
|
||||
max_length=200,
|
||||
)
|
||||
self.time_input = discord.ui.TextInput(
|
||||
label="New time",
|
||||
placeholder="e.g. tomorrow at 3 PM",
|
||||
required=True,
|
||||
)
|
||||
|
||||
self.add_item(self.message_input)
|
||||
self.add_item(self.time_input)
|
||||
|
||||
async def on_submit(self, interaction: discord.Interaction) -> None:
|
||||
"""Called when the modal is submitted.
|
||||
|
||||
Args:
|
||||
interaction (discord.Interaction): The Discord interaction where this modal was triggered from.
|
||||
"""
|
||||
old_message: str = self.job.kwargs.get("message", "")
|
||||
old_time: datetime.datetime = self.job.next_run_time
|
||||
old_time_countdown: str = calculate(self.job)
|
||||
|
||||
new_message: str = self.message_input.value
|
||||
new_time_str: str = self.time_input.value
|
||||
|
||||
parsed_time: datetime.datetime | None = parse_time(new_time_str)
|
||||
if not parsed_time:
|
||||
await interaction.response.send_message(f"Invalid time format: `{new_time_str}`", ephemeral=True)
|
||||
return
|
||||
|
||||
job_to_modify: Job | None = scheduler.get_job(self.job_id)
|
||||
if not job_to_modify:
|
||||
await interaction.response.send_message(
|
||||
f"Failed to get job.\n{new_message=}\n{new_time_str=}\n{parsed_time=}",
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
|
||||
# Defer now that we've validated the input to avoid timeout
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
|
||||
job = scheduler.modify_job(self.job_id)
|
||||
msg: str = f"Modified job `{escape_markdown(self.job_id)}`:\n"
|
||||
changes_made = False
|
||||
|
||||
if parsed_time != old_time:
|
||||
logger.info(f"Rescheduling job {self.job_id}")
|
||||
rescheduled_job = scheduler.reschedule_job(self.job_id, trigger="date", run_date=parsed_time)
|
||||
|
||||
logger.debug(f"Rescheduled job {self.job_id} from {old_time} to {parsed_time}")
|
||||
|
||||
msg += (
|
||||
f"Old time: `{old_time.strftime('%Y-%m-%d %H:%M:%S')}` (In {old_time_countdown})\n"
|
||||
f"New time: `{parsed_time.strftime('%Y-%m-%d %H:%M:%S')}`. (In {calculate(rescheduled_job)})\n"
|
||||
)
|
||||
changes_made = True
|
||||
|
||||
if new_message != old_message:
|
||||
old_kwargs = job.kwargs.copy()
|
||||
scheduler.modify_job(
|
||||
self.job_id,
|
||||
kwargs={
|
||||
**old_kwargs,
|
||||
"message": new_message,
|
||||
},
|
||||
)
|
||||
|
||||
logger.debug(f"Modified job {self.job_id} with new message: {new_message}")
|
||||
logger.debug(f"Old kwargs: {old_kwargs}, New kwargs: {job.kwargs}")
|
||||
|
||||
msg += f"Old message: `{escape_markdown(old_message)}`\n"
|
||||
msg += f"New message: `{escape_markdown(new_message)}`.\n"
|
||||
changes_made = True
|
||||
|
||||
if changes_made:
|
||||
await interaction.followup.send(content=msg)
|
||||
else:
|
||||
await interaction.followup.send(content=f"No changes made to job `{escape_markdown(self.job_id)}`.", ephemeral=True)
|
||||
|
||||
async def on_error(self, interaction: discord.Interaction, error: Exception) -> None:
|
||||
"""A callback that is called when on_submit fails with an error.
|
||||
|
||||
Args:
|
||||
interaction (discord.Interaction): The Discord interaction where this modal was triggered from.
|
||||
error (Exception): The raised exception.
|
||||
"""
|
||||
# Check if the interaction has already been responded to
|
||||
if not interaction.response.is_done():
|
||||
await interaction.response.send_message("Oops! Something went wrong.", ephemeral=True)
|
||||
else:
|
||||
try:
|
||||
await interaction.followup.send("Oops! Something went wrong.", ephemeral=True)
|
||||
except discord.HTTPException:
|
||||
logger.warning("Failed to send error message via followup")
|
||||
|
||||
logger.exception(f"Error in ReminderModifyModal: {error}")
|
||||
traceback.print_exception(type(error), error, error.__traceback__)
|
||||
|
||||
job: Job | None = scheduler.get_job(job_id)
|
||||
if not job:
|
||||
await interaction.response.send_message(f"Failed to get job for '{job_id}'", ephemeral=True)
|
||||
return
|
||||
|
||||
await interaction.response.send_modal(ReminderModifyModal(job)) # pyright: ignore[reportArgumentType]
|
||||
await interaction.response.send_modal(ReminderModifyModal(job))
|
||||
|
||||
async def handle_pause_unpause(self, interaction: discord.Interaction, job_id: str) -> None:
|
||||
"""Handle pausing or unpausing a reminder job.
|
||||
|
386
discord_reminder_bot/modals.py
Normal file
386
discord_reminder_bot/modals.py
Normal file
@ -0,0 +1,386 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import discord
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.triggers.date import DateTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from discord.utils import escape_markdown
|
||||
from loguru import logger
|
||||
|
||||
from discord_reminder_bot.helpers import calculate, parse_time
|
||||
from discord_reminder_bot.settings import scheduler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import datetime
|
||||
|
||||
from apscheduler.job import Job
|
||||
|
||||
|
||||
class ReminderModifyModal(discord.ui.Modal, title="Modify reminder"):
|
||||
"""Modal for modifying a APScheduler job."""
|
||||
|
||||
def __init__(self, job: Job) -> None:
|
||||
"""Initialize the modal for modifying a reminder.
|
||||
|
||||
Args:
|
||||
job (Job): The APScheduler job to modify.
|
||||
"""
|
||||
super().__init__(title="Modify Reminder")
|
||||
self.job = job
|
||||
self.job_id = job.id
|
||||
self.trigger_type = self._get_trigger_type(job.trigger)
|
||||
|
||||
self.message_input = discord.ui.TextInput(
|
||||
label="Reminder message",
|
||||
default=job.kwargs.get("message", ""),
|
||||
placeholder="What do you want to be reminded of?",
|
||||
max_length=200,
|
||||
)
|
||||
|
||||
# Different input fields based on trigger type
|
||||
if self.trigger_type == "date":
|
||||
self.time_input = discord.ui.TextInput(
|
||||
label="New time",
|
||||
placeholder="e.g. tomorrow at 3 PM",
|
||||
required=True,
|
||||
)
|
||||
elif self.trigger_type == "interval":
|
||||
interval_text = self._format_interval_from_trigger(job.trigger)
|
||||
self.time_input = discord.ui.TextInput(
|
||||
label="New interval",
|
||||
placeholder="e.g. 1d 2h 30m (days, hours, minutes)",
|
||||
required=True,
|
||||
default=interval_text,
|
||||
)
|
||||
elif self.trigger_type == "cron":
|
||||
cron_text = self._format_cron_from_trigger(job.trigger)
|
||||
self.time_input = discord.ui.TextInput(
|
||||
label="New cron expression",
|
||||
placeholder="e.g. 0 9 * * 1-5 (min hour day month day_of_week)",
|
||||
required=True,
|
||||
default=cron_text,
|
||||
)
|
||||
else:
|
||||
# Fallback to date input for unknown trigger types
|
||||
self.time_input = discord.ui.TextInput(
|
||||
label="New time",
|
||||
placeholder="e.g. tomorrow at 3 PM",
|
||||
required=True,
|
||||
)
|
||||
|
||||
self.add_item(self.message_input)
|
||||
self.add_item(self.time_input)
|
||||
|
||||
def _get_trigger_type(self, trigger: DateTrigger | IntervalTrigger | CronTrigger) -> str:
|
||||
"""Determine the type of trigger.
|
||||
|
||||
Args:
|
||||
trigger: The APScheduler trigger.
|
||||
|
||||
Returns:
|
||||
str: The type of trigger ("date", "interval", "cron", or "unknown").
|
||||
"""
|
||||
if isinstance(trigger, DateTrigger):
|
||||
return "date"
|
||||
if isinstance(trigger, IntervalTrigger):
|
||||
return "interval"
|
||||
if isinstance(trigger, CronTrigger):
|
||||
return "cron"
|
||||
return "unknown"
|
||||
|
||||
def _format_interval_from_trigger(self, trigger: IntervalTrigger) -> str:
|
||||
"""Format an interval trigger into a human-readable string.
|
||||
|
||||
Args:
|
||||
trigger (IntervalTrigger): The interval trigger.
|
||||
|
||||
Returns:
|
||||
str: Formatted interval string (e.g., "1d 2h 30m").
|
||||
"""
|
||||
parts = []
|
||||
|
||||
# Get interval values from the trigger.__getstate__() dictionary
|
||||
trigger_state = trigger.__getstate__()
|
||||
|
||||
if trigger_state.get("weeks", 0):
|
||||
parts.append(f"{trigger_state['weeks']}w")
|
||||
if trigger_state.get("days", 0):
|
||||
parts.append(f"{trigger_state['days']}d")
|
||||
if trigger_state.get("hours", 0):
|
||||
parts.append(f"{trigger_state['hours']}h")
|
||||
if trigger_state.get("minutes", 0):
|
||||
parts.append(f"{trigger_state['minutes']}m")
|
||||
|
||||
seconds = trigger_state.get("seconds", 0)
|
||||
if seconds and seconds % 60 != 0: # Only show seconds if not even minutes
|
||||
parts.append(f"{seconds % 60}s")
|
||||
|
||||
return " ".join(parts) if parts else "0m"
|
||||
|
||||
def _format_cron_from_trigger(self, trigger: CronTrigger) -> str:
|
||||
"""Format a cron trigger into a string representation.
|
||||
|
||||
Args:
|
||||
trigger (CronTrigger): The cron trigger.
|
||||
|
||||
Returns:
|
||||
str: Formatted cron string.
|
||||
"""
|
||||
fields = []
|
||||
|
||||
# Get the fields in standard cron order
|
||||
for field in ["second", "minute", "hour", "day", "month", "day_of_week", "year"]:
|
||||
if hasattr(trigger, field) and getattr(trigger, field) is not None:
|
||||
expr = getattr(trigger, field).expression
|
||||
fields.append(expr if expr != "*" else "*")
|
||||
|
||||
# Return only the standard 5 cron fields by default
|
||||
return " ".join(fields[:5])
|
||||
|
||||
def _parse_interval_string(self, interval_str: str) -> dict[str, int]:
|
||||
"""Parse an interval string into component parts.
|
||||
|
||||
Args:
|
||||
interval_str (str): String like "1w 2d 3h 4m 5s"
|
||||
|
||||
Returns:
|
||||
dict[str, int]: Dictionary with interval components.
|
||||
"""
|
||||
interval_dict = {"weeks": 0, "days": 0, "hours": 0, "minutes": 0, "seconds": 0}
|
||||
|
||||
# Define regex patterns for each time unit
|
||||
patterns = {r"(\d+)w": "weeks", r"(\d+)d": "days", r"(\d+)h": "hours", r"(\d+)m": "minutes", r"(\d+)s": "seconds"}
|
||||
|
||||
# Extract values for each unit
|
||||
for pattern, key in patterns.items():
|
||||
match = re.search(pattern, interval_str)
|
||||
if match:
|
||||
interval_dict[key] = int(match.group(1))
|
||||
|
||||
# Ensure at least 30 seconds total interval
|
||||
total_seconds = (
|
||||
interval_dict["weeks"] * 604800
|
||||
+ interval_dict["days"] * 86400
|
||||
+ interval_dict["hours"] * 3600
|
||||
+ interval_dict["minutes"] * 60
|
||||
+ interval_dict["seconds"]
|
||||
)
|
||||
|
||||
if total_seconds < 30:
|
||||
interval_dict["seconds"] = 30
|
||||
|
||||
return interval_dict
|
||||
|
||||
def _parse_cron_string(self, cron_str: str) -> dict[str, str]:
|
||||
"""Parse a cron string into its components.
|
||||
|
||||
Args:
|
||||
cron_str (str): Cron string like "0 9 * * 1-5"
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Dictionary with cron components.
|
||||
"""
|
||||
parts = cron_str.strip().split()
|
||||
cron_dict = {}
|
||||
|
||||
# Map position to field name
|
||||
field_names = ["second", "minute", "hour", "day", "month", "day_of_week", "year"]
|
||||
|
||||
# Handle standard 5-part cron string (minute hour day month day_of_week)
|
||||
if len(parts) == 5:
|
||||
# Add a 0 for seconds (as APScheduler expects it)
|
||||
parts.insert(0, "0")
|
||||
|
||||
# Map parts to field names
|
||||
for i, part in enumerate(parts):
|
||||
if i < len(field_names) and part != "*": # Only add non-default values
|
||||
cron_dict[field_names[i]] = part
|
||||
|
||||
return cron_dict
|
||||
|
||||
def _process_date_trigger(self, new_time_str: str, old_time: datetime.datetime | None) -> tuple[bool, str, Job | None]:
|
||||
"""Process date trigger modification.
|
||||
|
||||
Args:
|
||||
new_time_str (str): The new time string to parse.
|
||||
old_time (datetime.datetime | None): The old scheduled time.
|
||||
|
||||
Returns:
|
||||
tuple[bool, str, Job | None]: Success flag, error message, and rescheduled job.
|
||||
"""
|
||||
parsed_time: datetime.datetime | None = parse_time(new_time_str)
|
||||
if not parsed_time:
|
||||
return False, f"Invalid time format: `{new_time_str}`", None
|
||||
|
||||
if old_time and parsed_time == old_time:
|
||||
return True, "", None # No change needed
|
||||
|
||||
logger.info(f"Rescheduling date-based job {self.job_id}")
|
||||
try:
|
||||
rescheduled_job = scheduler.reschedule_job(self.job_id, trigger="date", run_date=parsed_time)
|
||||
except (ValueError, TypeError, AttributeError) as e:
|
||||
logger.exception(f"Failed to reschedule date-based job: {e}")
|
||||
return False, f"Failed to reschedule job: {e}", None
|
||||
else:
|
||||
return True, "", rescheduled_job
|
||||
|
||||
def _process_interval_trigger(self, new_time_str: str) -> tuple[bool, str, Job | None]:
|
||||
"""Process interval trigger modification.
|
||||
|
||||
Args:
|
||||
new_time_str (str): The new interval string to parse.
|
||||
|
||||
Returns:
|
||||
tuple[bool, str, Job | None]: Success flag, error message, and rescheduled job.
|
||||
"""
|
||||
try:
|
||||
interval_dict = self._parse_interval_string(new_time_str)
|
||||
logger.info(f"Rescheduling interval job {self.job_id} with {interval_dict}")
|
||||
rescheduled_job = scheduler.reschedule_job(self.job_id, trigger="interval", **interval_dict)
|
||||
except (ValueError, TypeError, AttributeError) as e:
|
||||
error_msg = f"Invalid interval format: `{new_time_str}`"
|
||||
logger.exception(f"Failed to parse interval: {e}")
|
||||
return False, error_msg, None
|
||||
else:
|
||||
return True, "", rescheduled_job
|
||||
|
||||
def _process_cron_trigger(self, new_time_str: str) -> tuple[bool, str, Job | None]:
|
||||
"""Process cron trigger modification.
|
||||
|
||||
Args:
|
||||
new_time_str (str): The new cron string to parse.
|
||||
|
||||
Returns:
|
||||
tuple[bool, str, Job | None]: Success flag, error message, and rescheduled job.
|
||||
"""
|
||||
try:
|
||||
cron_dict = self._parse_cron_string(new_time_str)
|
||||
logger.info(f"Rescheduling cron job {self.job_id} with {cron_dict}")
|
||||
rescheduled_job = scheduler.reschedule_job(self.job_id, trigger="cron", **cron_dict)
|
||||
except (ValueError, TypeError, AttributeError) as e:
|
||||
error_msg = f"Invalid cron format: `{new_time_str}`"
|
||||
logger.exception(f"Failed to parse cron: {e}")
|
||||
return False, error_msg, None
|
||||
else:
|
||||
return True, "", rescheduled_job
|
||||
|
||||
async def _update_message(self, old_message: str, new_message: str) -> bool:
|
||||
"""Update the message of a job.
|
||||
|
||||
Args:
|
||||
old_message (str): The old message.
|
||||
new_message (str): The new message.
|
||||
|
||||
Returns:
|
||||
bool: Whether the message was changed.
|
||||
"""
|
||||
if new_message == old_message:
|
||||
return False
|
||||
|
||||
job = scheduler.get_job(self.job_id)
|
||||
if not job:
|
||||
return False
|
||||
|
||||
old_kwargs = job.kwargs.copy()
|
||||
scheduler.modify_job(
|
||||
self.job_id,
|
||||
kwargs={
|
||||
**old_kwargs,
|
||||
"message": new_message,
|
||||
},
|
||||
)
|
||||
|
||||
logger.debug(f"Modified job {self.job_id} with new message: {new_message}")
|
||||
logger.debug(f"Old kwargs: {old_kwargs}, New kwargs: {job.kwargs}")
|
||||
return True
|
||||
|
||||
async def on_submit(self, interaction: discord.Interaction) -> None:
|
||||
"""Called when the modal is submitted.
|
||||
|
||||
Args:
|
||||
interaction (discord.Interaction): The Discord interaction where this modal was triggered from.
|
||||
"""
|
||||
old_message: str = self.job.kwargs.get("message", "")
|
||||
old_time: datetime.datetime | None = self.job.next_run_time
|
||||
old_time_countdown: str = calculate(self.job)
|
||||
|
||||
new_message: str = self.message_input.value
|
||||
new_time_str: str = self.time_input.value
|
||||
|
||||
# Get the job to modify
|
||||
job_to_modify: Job | None = scheduler.get_job(self.job_id)
|
||||
if not job_to_modify:
|
||||
await interaction.response.send_message(
|
||||
f"Failed to get job.\n{new_message=}\n{new_time_str=}",
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
|
||||
# Defer early for long operations
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
|
||||
# Process time/schedule changes based on trigger type
|
||||
success, error_msg, rescheduled_job = False, "", None
|
||||
|
||||
if self.trigger_type == "date":
|
||||
success, error_msg, rescheduled_job = self._process_date_trigger(new_time_str, old_time)
|
||||
elif self.trigger_type == "interval":
|
||||
success, error_msg, rescheduled_job = self._process_interval_trigger(new_time_str)
|
||||
elif self.trigger_type == "cron":
|
||||
success, error_msg, rescheduled_job = self._process_cron_trigger(new_time_str)
|
||||
|
||||
# If time input is invalid, send error message
|
||||
if not success and error_msg:
|
||||
await interaction.followup.send(error_msg, ephemeral=True)
|
||||
return
|
||||
|
||||
# Update the message if changed
|
||||
msg: str = f"Modified job `{escape_markdown(self.job_id)}`:\n"
|
||||
changes_made = False
|
||||
|
||||
# Add schedule change info to message
|
||||
if rescheduled_job:
|
||||
if old_time:
|
||||
msg += (
|
||||
f"Old time: `{old_time.strftime('%Y-%m-%d %H:%M:%S')}` (In {old_time_countdown})\n"
|
||||
f"New time: Next run in {calculate(rescheduled_job)}\n"
|
||||
)
|
||||
else:
|
||||
msg += f"Job unpaused. Next run in {calculate(rescheduled_job)}\n"
|
||||
changes_made = True
|
||||
|
||||
# Update message if changed
|
||||
message_changed = await self._update_message(old_message, new_message)
|
||||
if message_changed:
|
||||
msg += f"Old message: `{escape_markdown(old_message)}`\n"
|
||||
msg += f"New message: `{escape_markdown(new_message)}`.\n"
|
||||
changes_made = True
|
||||
|
||||
# Send confirmation message
|
||||
if changes_made:
|
||||
await interaction.followup.send(content=msg)
|
||||
else:
|
||||
await interaction.followup.send(content=f"No changes made to job `{escape_markdown(self.job_id)}`.", ephemeral=True)
|
||||
|
||||
async def on_error(self, interaction: discord.Interaction, error: Exception) -> None:
|
||||
"""A callback that is called when on_submit fails with an error.
|
||||
|
||||
Args:
|
||||
interaction (discord.Interaction): The Discord interaction where this modal was triggered from.
|
||||
error (Exception): The raised exception.
|
||||
"""
|
||||
# Check if the interaction has already been responded to
|
||||
if not interaction.response.is_done():
|
||||
await interaction.response.send_message("Oops! Something went wrong.", ephemeral=True)
|
||||
else:
|
||||
try:
|
||||
await interaction.followup.send("Oops! Something went wrong.", ephemeral=True)
|
||||
except discord.HTTPException:
|
||||
logger.warning("Failed to send error message via followup")
|
||||
|
||||
logger.exception(f"Error in ReminderModifyModal: {error}")
|
||||
traceback.print_exception(type(error), error, error.__traceback__)
|
59
discord_reminder_bot/settings.py
Normal file
59
discord_reminder_bot/settings.py
Normal file
@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import platform
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
import pytz
|
||||
import sentry_sdk
|
||||
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
load_dotenv(verbose=True)
|
||||
|
||||
default_sentry_dsn: str = "https://c4c61a52838be9b5042144420fba5aaa@o4505228040339456.ingest.us.sentry.io/4508707268984832"
|
||||
sentry_sdk.init(
|
||||
dsn=os.getenv("SENTRY_DSN", default_sentry_dsn),
|
||||
environment=platform.node() or "Unknown",
|
||||
traces_sample_rate=1.0,
|
||||
send_default_pii=True,
|
||||
)
|
||||
|
||||
|
||||
def get_scheduler() -> AsyncIOScheduler:
|
||||
"""Return the scheduler instance.
|
||||
|
||||
Uses the SQLITE_LOCATION environment variable for the SQLite database location.
|
||||
|
||||
Raises:
|
||||
ValueError: If the timezone is missing or invalid.
|
||||
|
||||
Returns:
|
||||
AsyncIOScheduler: The scheduler instance.
|
||||
"""
|
||||
config_timezone: str | None = os.getenv("TIMEZONE")
|
||||
if not config_timezone:
|
||||
msg = "Missing timezone. Please set the TIMEZONE environment variable."
|
||||
raise ValueError(msg)
|
||||
|
||||
# Test if the timezone is valid
|
||||
try:
|
||||
ZoneInfo(config_timezone)
|
||||
except (ZoneInfoNotFoundError, ModuleNotFoundError) as e:
|
||||
msg: str = f"Invalid timezone: {config_timezone}. Error: {e}"
|
||||
raise ValueError(msg) from e
|
||||
|
||||
logger.info(f"Using timezone: {config_timezone}. If this is incorrect, please set the TIMEZONE environment variable.")
|
||||
|
||||
sqlite_location: str = os.getenv("SQLITE_LOCATION", default="/jobs.sqlite")
|
||||
logger.info(f"Using SQLite database at: {sqlite_location}")
|
||||
|
||||
jobstores: dict[str, SQLAlchemyJobStore] = {"default": SQLAlchemyJobStore(url=f"sqlite://{sqlite_location}")}
|
||||
job_defaults: dict[str, bool] = {"coalesce": True}
|
||||
timezone = pytz.timezone(config_timezone)
|
||||
return AsyncIOScheduler(jobstores=jobstores, timezone=timezone, job_defaults=job_defaults)
|
||||
|
||||
|
||||
scheduler: AsyncIOScheduler = get_scheduler()
|
@ -12,7 +12,7 @@ from apscheduler.triggers.date import DateTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
|
||||
from discord_reminder_bot import main
|
||||
from discord_reminder_bot.main import calculate, parse_time
|
||||
from discord_reminder_bot.helpers import calculate, parse_time
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from apscheduler.job import Job
|
||||
|
12
uv.lock
generated
12
uv.lock
generated
@ -59,14 +59,14 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "aiosignal"
|
||||
version = "1.3.2"
|
||||
version = "1.4.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "frozenlist" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ba/b5/6d55e80f6d8a08ce22b982eafa278d823b541c925f11ee774b0b9c43473d/aiosignal-1.3.2.tar.gz", hash = "sha256:a8c255c66fafb1e499c9351d0bf32ff2d8a0321595ebac3b93713656d2436f54", size = 19424, upload-time = "2024-12-13T17:10:40.86Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/61/62/06741b579156360248d1ec624842ad0edf697050bbaf7c3e46394e106ad1/aiosignal-1.4.0.tar.gz", hash = "sha256:f47eecd9468083c2029cc99945502cb7708b082c232f9aca65da147157b251c7", size = 25007, upload-time = "2025-07-03T22:54:43.528Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ec/6a/bc7e17a3e87a2985d3e8f4da4cd0f481060eb78fb08596c42be62c90a4d9/aiosignal-1.3.2-py2.py3-none-any.whl", hash = "sha256:45cde58e409a301715980c2b01d0c28bdde3770d8290b5eb2173759d9acb31a5", size = 7597, upload-time = "2024-12-13T17:10:38.469Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -690,11 +690,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "4.14.0"
|
||||
version = "4.14.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d1/bc/51647cd02527e87d05cb083ccc402f93e441606ff1f01739a62c8ad09ba5/typing_extensions-4.14.0.tar.gz", hash = "sha256:8676b788e32f02ab42d9e7c61324048ae4c6d844a399eebace3d4979d75ceef4", size = 107423, upload-time = "2025-06-02T14:52:11.399Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/98/5a/da40306b885cc8c09109dc2e1abd358d5684b1425678151cdaed4731c822/typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36", size = 107673, upload-time = "2025-07-04T13:28:34.16Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/69/e0/552843e0d356fbb5256d21449fa957fa4eff3bbc135a74a691ee70c7c5da/typing_extensions-4.14.0-py3-none-any.whl", hash = "sha256:a1514509136dd0b477638fc68d6a91497af5076466ad0fa6c338e44e359944af", size = 43839, upload-time = "2025-06-02T14:52:10.026Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b5/00/d631e67a838026495268c2f6884f3711a15a9a2a96cd244fdaea53b823fb/typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76", size = 43906, upload-time = "2025-07-04T13:28:32.743Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
Reference in New Issue
Block a user