Move settings to a function and add tests
This commit is contained in:
@ -12,9 +12,9 @@ from apscheduler.job import Job
|
||||
from discord.abc import PrivateChannel
|
||||
from discord_webhook import DiscordWebhook
|
||||
|
||||
from discord_reminder_bot import settings
|
||||
from discord_reminder_bot.misc import calculate
|
||||
from discord_reminder_bot.parser import parse_time
|
||||
from discord_reminder_bot.settings import get_bot_token, get_scheduler, get_webhook_url
|
||||
from discord_reminder_bot.ui import JobManagementView, create_job_embed
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -22,11 +22,15 @@ if TYPE_CHECKING:
|
||||
from discord.guild import GuildChannel
|
||||
from discord.interactions import InteractionChannel
|
||||
|
||||
from discord_reminder_bot import settings
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
GUILD_ID = discord.Object(id=341001473661992962)
|
||||
|
||||
scheduler: settings.AsyncIOScheduler = get_scheduler()
|
||||
|
||||
|
||||
class RemindBotClient(discord.Client):
|
||||
"""Custom client class for the bot."""
|
||||
@ -46,8 +50,8 @@ class RemindBotClient(discord.Client):
|
||||
|
||||
async def setup_hook(self) -> None:
|
||||
"""Setup the bot."""
|
||||
settings.scheduler.start()
|
||||
jobs: list[Job] = settings.scheduler.get_jobs()
|
||||
scheduler.start()
|
||||
jobs: list[Job] = scheduler.get_jobs()
|
||||
if not jobs:
|
||||
logger.info("No jobs available.")
|
||||
return
|
||||
@ -129,7 +133,7 @@ class RemindGroup(discord.app_commands.Group):
|
||||
await interaction.followup.send(content=f"Failed to parse time: {time}.", ephemeral=True)
|
||||
return
|
||||
|
||||
user_reminder: Job = settings.scheduler.add_job(
|
||||
user_reminder: Job = scheduler.add_job(
|
||||
func=send_to_user,
|
||||
trigger="date",
|
||||
run_date=parsed_time,
|
||||
@ -152,7 +156,7 @@ class RemindGroup(discord.app_commands.Group):
|
||||
return
|
||||
|
||||
# Create channel reminder job
|
||||
channel_job: Job = settings.scheduler.add_job(
|
||||
channel_job: Job = scheduler.add_job(
|
||||
func=send_to_discord,
|
||||
job_kwargs={
|
||||
"channel_id": channel_id,
|
||||
@ -180,13 +184,13 @@ class RemindGroup(discord.app_commands.Group):
|
||||
"""
|
||||
await interaction.response.defer()
|
||||
|
||||
jobs: list[Job] = settings.scheduler.get_jobs()
|
||||
jobs: list[Job] = scheduler.get_jobs()
|
||||
if not jobs:
|
||||
await interaction.followup.send(content="No scheduled jobs found in the database.", ephemeral=True)
|
||||
return
|
||||
|
||||
embed: discord.Embed = create_job_embed(job=jobs[0])
|
||||
view = JobManagementView(job=jobs[0], scheduler=settings.scheduler)
|
||||
view = JobManagementView(job=jobs[0], scheduler=scheduler)
|
||||
|
||||
await interaction.followup.send(embed=embed, view=view)
|
||||
|
||||
@ -263,7 +267,7 @@ class RemindGroup(discord.app_commands.Group):
|
||||
# Create user DM reminder job if user is specified
|
||||
dm_message: str = ""
|
||||
if user:
|
||||
user_reminder: Job = settings.scheduler.add_job(
|
||||
user_reminder: Job = scheduler.add_job(
|
||||
func=send_to_user,
|
||||
trigger="cron",
|
||||
year=year,
|
||||
@ -295,7 +299,7 @@ class RemindGroup(discord.app_commands.Group):
|
||||
return
|
||||
|
||||
# Create channel reminder job
|
||||
channel_job: Job = settings.scheduler.add_job(
|
||||
channel_job: Job = scheduler.add_job(
|
||||
func=send_to_discord,
|
||||
trigger="cron",
|
||||
year=year,
|
||||
@ -396,7 +400,7 @@ class RemindGroup(discord.app_commands.Group):
|
||||
# Create user DM reminder job if user is specified
|
||||
dm_message: str = ""
|
||||
if user:
|
||||
dm_job: Job = settings.scheduler.add_job(
|
||||
dm_job: Job = scheduler.add_job(
|
||||
func=send_to_user,
|
||||
trigger="interval",
|
||||
weeks=weeks,
|
||||
@ -424,7 +428,7 @@ class RemindGroup(discord.app_commands.Group):
|
||||
)
|
||||
|
||||
# Create channel reminder job
|
||||
channel_job: Job = settings.scheduler.add_job(
|
||||
channel_job: Job = scheduler.add_job(
|
||||
func=send_to_discord,
|
||||
trigger="interval",
|
||||
weeks=weeks,
|
||||
@ -463,7 +467,7 @@ class RemindGroup(discord.app_commands.Group):
|
||||
# Retrieve all jobs
|
||||
with tempfile.NamedTemporaryFile(mode="r+", delete=False, encoding="utf-8", suffix=".json") as temp_file:
|
||||
# Export jobs to a temporary file
|
||||
settings.scheduler.export_jobs(temp_file.name)
|
||||
scheduler.export_jobs(temp_file.name)
|
||||
|
||||
# Load the exported jobs
|
||||
temp_file.seek(0)
|
||||
@ -513,7 +517,7 @@ class RemindGroup(discord.app_commands.Group):
|
||||
|
||||
# Write the data to a new file
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False, encoding="utf-8", suffix=".json") as output_file:
|
||||
file_name = f"reminders-backup-{datetime.datetime.now(tz=settings.scheduler.timezone)}.json"
|
||||
file_name = f"reminders-backup-{datetime.datetime.now(tz=scheduler.timezone)}.json"
|
||||
json.dump(jobs_data, output_file, indent=4)
|
||||
output_file.seek(0)
|
||||
|
||||
@ -532,7 +536,7 @@ class RemindGroup(discord.app_commands.Group):
|
||||
logger.info("Restoring reminders from file for %s (%s) in %s", interaction.user, interaction.user.id, interaction.channel)
|
||||
|
||||
# Get the old jobs
|
||||
old_jobs: list[Job] = settings.scheduler.get_jobs()
|
||||
old_jobs: list[Job] = scheduler.get_jobs()
|
||||
|
||||
# Tell to reply with the file to this message
|
||||
await interaction.followup.send(content="Please reply to this message with the backup file.")
|
||||
@ -548,6 +552,10 @@ class RemindGroup(discord.app_commands.Group):
|
||||
)
|
||||
return
|
||||
|
||||
if not reply.channel:
|
||||
await interaction.followup.send(content="No channel found. Please try again.")
|
||||
continue
|
||||
|
||||
# Fetch the message by its ID to ensure we have the latest data
|
||||
reply = await reply.channel.fetch_message(reply.id)
|
||||
|
||||
@ -580,8 +588,8 @@ class RemindGroup(discord.app_commands.Group):
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w+", delete=False, encoding="utf-8", suffix=".json") as temp_import_file:
|
||||
# We can't import jobs with the same ID so remove them from the JSON
|
||||
jobs = [job for job in jobs_data.get("jobs", []) if not settings.scheduler.get_job(job.get("id"))]
|
||||
jobs_already_exist = [job.get("id") for job in jobs_data.get("jobs", []) if settings.scheduler.get_job(job.get("id"))]
|
||||
jobs = [job for job in jobs_data.get("jobs", []) if not scheduler.get_job(job.get("id"))]
|
||||
jobs_already_exist = [job.get("id") for job in jobs_data.get("jobs", []) if scheduler.get_job(job.get("id"))]
|
||||
jobs_data["jobs"] = jobs
|
||||
for job_id in jobs_already_exist:
|
||||
logger.debug("Removed job: %s because it already exists.", job_id)
|
||||
@ -594,10 +602,10 @@ class RemindGroup(discord.app_commands.Group):
|
||||
temp_import_file.seek(0)
|
||||
|
||||
# Import the jobs
|
||||
settings.scheduler.import_jobs(temp_import_file.name)
|
||||
scheduler.import_jobs(temp_import_file.name)
|
||||
|
||||
# Get the new jobs
|
||||
new_jobs: list[Job] = settings.scheduler.get_jobs()
|
||||
new_jobs: list[Job] = scheduler.get_jobs()
|
||||
|
||||
# Get the difference
|
||||
added_jobs: list[Job] = [job for job in new_jobs if job not in old_jobs]
|
||||
@ -621,20 +629,26 @@ remind_group = RemindGroup()
|
||||
bot.tree.add_command(remind_group)
|
||||
|
||||
|
||||
def send_webhook(
|
||||
url: str = settings.webhook_url,
|
||||
message: str = "discord-reminder-bot: Empty message.",
|
||||
) -> None:
|
||||
def send_webhook(url: str = "", message: str = "") -> None:
|
||||
"""Send a webhook to Discord.
|
||||
|
||||
Args:
|
||||
url: Our webhook url, defaults to the one from settings.
|
||||
message: The message that will be sent to Discord.
|
||||
"""
|
||||
if not message:
|
||||
logger.error("No message provided.")
|
||||
message = "No message provided."
|
||||
|
||||
if not url:
|
||||
msg = "ERROR: Tried to send a webhook but you have no webhook url configured."
|
||||
logger.error(msg)
|
||||
webhook: DiscordWebhook = DiscordWebhook(url=settings.webhook_url, content=msg, rate_limit_retry=True)
|
||||
url = get_webhook_url()
|
||||
logger.error("No webhook URL provided. Using the one from settings.")
|
||||
webhook: DiscordWebhook = DiscordWebhook(
|
||||
url=url,
|
||||
username="discord-reminder-bot",
|
||||
content="No webhook URL provided. Using the one from settings.",
|
||||
rate_limit_retry=True,
|
||||
)
|
||||
webhook.execute()
|
||||
return
|
||||
|
||||
@ -683,4 +697,5 @@ async def send_to_user(user_id: int, guild_id: int, message: str) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bot.run(settings.bot_token, root_logger=True)
|
||||
bot_token: str = get_bot_token()
|
||||
bot.run(bot_token, root_logger=True)
|
||||
|
@ -2,21 +2,22 @@ from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
from zoneinfo import ZoneInfo
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
import dateparser
|
||||
|
||||
from discord_reminder_bot import settings
|
||||
from discord_reminder_bot.settings import get_timezone
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_time(date_to_parse: str, timezone: str | None = None) -> datetime.datetime | None:
|
||||
def parse_time(date_to_parse: str, timezone: str | None = None, use_dotenv: bool = True) -> datetime.datetime | None: # noqa: FBT001, FBT002
|
||||
"""Parse a date string into a datetime object.
|
||||
|
||||
Args:
|
||||
date_to_parse(str): The date string to parse.
|
||||
timezone(str, optional): The timezone to use. Defaults timezone from settings.
|
||||
use_dotenv(bool, optional): Whether to load environment variables from a .env file. Defaults to True
|
||||
|
||||
Returns:
|
||||
datetime.datetime: The parsed datetime object.
|
||||
@ -28,7 +29,14 @@ def parse_time(date_to_parse: str, timezone: str | None = None) -> datetime.date
|
||||
return None
|
||||
|
||||
if not timezone:
|
||||
timezone = settings.config_timezone
|
||||
timezone = get_timezone(use_dotenv)
|
||||
|
||||
# Check if the timezone is valid
|
||||
try:
|
||||
tz = ZoneInfo(timezone)
|
||||
except (ZoneInfoNotFoundError, ModuleNotFoundError):
|
||||
logger.error("Invalid timezone provided: '%s'. Using default timezone: '%s'", timezone, get_timezone(use_dotenv)) # noqa: TRY400
|
||||
tz = ZoneInfo("UTC")
|
||||
|
||||
try:
|
||||
parsed_date: datetime.datetime | None = dateparser.parse(
|
||||
@ -37,7 +45,7 @@ def parse_time(date_to_parse: str, timezone: str | None = None) -> datetime.date
|
||||
"PREFER_DATES_FROM": "future",
|
||||
"TIMEZONE": f"{timezone}",
|
||||
"RETURN_AS_TIMEZONE_AWARE": True,
|
||||
"RELATIVE_BASE": datetime.datetime.now(tz=ZoneInfo(timezone)),
|
||||
"RELATIVE_BASE": datetime.datetime.now(tz=tz),
|
||||
},
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
|
@ -7,22 +7,150 @@ from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(verbose=True)
|
||||
sqlite_location: str = os.getenv("SQLITE_LOCATION", default="/jobs.sqlite")
|
||||
config_timezone: str = os.getenv("TIMEZONE", default="UTC")
|
||||
bot_token: str = os.getenv("BOT_TOKEN", default="")
|
||||
log_level: str = os.getenv("LOG_LEVEL", default="INFO")
|
||||
webhook_url: str = os.getenv("WEBHOOK_URL", default="")
|
||||
|
||||
if not bot_token:
|
||||
err_msg = "Missing bot token"
|
||||
raise ValueError(err_msg)
|
||||
def get_settings(use_dotenv: bool = True) -> dict[str, str | dict[str, SQLAlchemyJobStore] | dict[str, bool] | AsyncIOScheduler]: # noqa: FBT001, FBT002
|
||||
"""Load environment variables and return the settings.
|
||||
|
||||
# Advanced Python Scheduler
|
||||
jobstores: dict[str, SQLAlchemyJobStore] = {"default": SQLAlchemyJobStore(url=f"sqlite://{sqlite_location}")}
|
||||
job_defaults: dict[str, bool] = {"coalesce": True}
|
||||
scheduler = AsyncIOScheduler(
|
||||
jobstores=jobstores,
|
||||
timezone=pytz.timezone(config_timezone),
|
||||
job_defaults=job_defaults,
|
||||
)
|
||||
Args:
|
||||
use_dotenv (bool, optional): Whether to load environment variables from a .env file. Defaults to True.
|
||||
|
||||
Raises:
|
||||
ValueError: If the bot token is missing.
|
||||
|
||||
Returns:
|
||||
dict: The settings.
|
||||
"""
|
||||
if use_dotenv:
|
||||
load_dotenv(verbose=True)
|
||||
|
||||
sqlite_location: str = os.getenv("SQLITE_LOCATION", default="/jobs.sqlite")
|
||||
config_timezone: str = os.getenv("TIMEZONE", default="UTC")
|
||||
bot_token: str = os.getenv("BOT_TOKEN", default="")
|
||||
log_level: str = os.getenv("LOG_LEVEL", default="INFO")
|
||||
webhook_url: str = os.getenv("WEBHOOK_URL", default="")
|
||||
|
||||
if not bot_token:
|
||||
msg = "Missing bot token. Please set the BOT_TOKEN environment variable."
|
||||
raise ValueError(msg)
|
||||
|
||||
jobstores: dict[str, SQLAlchemyJobStore] = {"default": SQLAlchemyJobStore(url=f"sqlite://{sqlite_location}")}
|
||||
job_defaults: dict[str, bool] = {"coalesce": True}
|
||||
scheduler = AsyncIOScheduler(
|
||||
jobstores=jobstores,
|
||||
timezone=pytz.timezone(config_timezone),
|
||||
job_defaults=job_defaults,
|
||||
)
|
||||
return {
|
||||
"sqlite_location": sqlite_location,
|
||||
"config_timezone": config_timezone,
|
||||
"bot_token": bot_token,
|
||||
"log_level": log_level,
|
||||
"webhook_url": webhook_url,
|
||||
"jobstores": jobstores,
|
||||
"job_defaults": job_defaults,
|
||||
"scheduler": scheduler,
|
||||
}
|
||||
|
||||
|
||||
def get_scheduler(use_dotenv: bool = True) -> AsyncIOScheduler: # noqa: FBT001, FBT002
|
||||
"""Return the scheduler instance.
|
||||
|
||||
Args:
|
||||
use_dotenv (bool, optional): Whether to load environment variables from a .env file. Defaults to True
|
||||
|
||||
Raises:
|
||||
TypeError: If the scheduler is not an instance of AsyncIOScheduler.
|
||||
KeyError: If the scheduler is missing from the settings.
|
||||
|
||||
Returns:
|
||||
AsyncIOScheduler: The scheduler instance.
|
||||
"""
|
||||
settings: dict[str, str | dict[str, SQLAlchemyJobStore] | dict[str, bool] | AsyncIOScheduler] = get_settings(use_dotenv)
|
||||
|
||||
if scheduler := settings.get("scheduler"):
|
||||
if not isinstance(scheduler, AsyncIOScheduler):
|
||||
msg = "The scheduler is not an instance of AsyncIOScheduler."
|
||||
raise TypeError(msg)
|
||||
|
||||
return scheduler
|
||||
|
||||
msg = "The scheduler is missing from the settings."
|
||||
raise KeyError(msg)
|
||||
|
||||
|
||||
def get_bot_token(use_dotenv: bool = True) -> str: # noqa: FBT001, FBT002
|
||||
"""Return the bot token.
|
||||
|
||||
Args:
|
||||
use_dotenv (bool, optional): Whether to load environment variables from a .env file. Defaults to True
|
||||
|
||||
Raises:
|
||||
TypeError: If the bot token is not a string.
|
||||
KeyError: If the bot token is missing from the settings.
|
||||
|
||||
Returns:
|
||||
str: The bot token.
|
||||
"""
|
||||
settings: dict[str, str | dict[str, SQLAlchemyJobStore] | dict[str, bool] | AsyncIOScheduler] = get_settings(use_dotenv)
|
||||
|
||||
if bot_token := settings.get("bot_token"):
|
||||
if not isinstance(bot_token, str):
|
||||
msg = "The bot token is not a string."
|
||||
raise TypeError(msg)
|
||||
|
||||
return bot_token
|
||||
|
||||
msg = "The bot token is missing from the settings."
|
||||
raise KeyError(msg)
|
||||
|
||||
|
||||
def get_webhook_url(use_dotenv: bool = True) -> str: # noqa: FBT001, FBT002
|
||||
"""Return the webhook URL.
|
||||
|
||||
Args:
|
||||
use_dotenv (bool, optional): Whether to load environment variables from a .env file. Defaults to True
|
||||
|
||||
Raises:
|
||||
TypeError: If the webhook URL is not a string.
|
||||
KeyError: If the webhook URL is missing from the settings.
|
||||
|
||||
Returns:
|
||||
str: The webhook URL.
|
||||
"""
|
||||
settings: dict[str, str | dict[str, SQLAlchemyJobStore] | dict[str, bool] | AsyncIOScheduler] = get_settings(use_dotenv)
|
||||
|
||||
if webhook_url := settings.get("webhook_url"):
|
||||
if not isinstance(webhook_url, str):
|
||||
msg = "The webhook URL is not a string."
|
||||
raise TypeError(msg)
|
||||
|
||||
return webhook_url
|
||||
|
||||
msg = "The webhook URL is missing from the settings."
|
||||
raise KeyError(msg)
|
||||
|
||||
|
||||
def get_timezone(use_dotenv: bool = True) -> str: # noqa: FBT001, FBT002
|
||||
"""Return the timezone.
|
||||
|
||||
Args:
|
||||
use_dotenv (bool, optional): Whether to load environment variables from a .env file. Defaults to True
|
||||
|
||||
Raises:
|
||||
TypeError: If the timezone is not a string.
|
||||
KeyError: If the timezone is missing from the settings.
|
||||
|
||||
Returns:
|
||||
str: The timezone.
|
||||
"""
|
||||
settings: dict[str, str | dict[str, SQLAlchemyJobStore] | dict[str, bool] | AsyncIOScheduler] = get_settings(use_dotenv)
|
||||
|
||||
if config_timezone := settings.get("config_timezone"):
|
||||
if not isinstance(config_timezone, str):
|
||||
msg = "The timezone is not a string."
|
||||
raise TypeError(msg)
|
||||
|
||||
return config_timezone
|
||||
|
||||
msg = "The timezone is missing from the settings."
|
||||
raise KeyError(msg)
|
||||
|
@ -10,7 +10,6 @@ from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from discord.ui import Button, Select
|
||||
|
||||
from discord_reminder_bot import settings
|
||||
from discord_reminder_bot.misc import DateTrigger, calc_time, calculate
|
||||
from discord_reminder_bot.parser import parse_time
|
||||
|
||||
@ -20,8 +19,6 @@ if TYPE_CHECKING:
|
||||
from apscheduler.job import Job
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
|
||||
from discord_reminder_bot import settings
|
||||
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
@ -41,7 +38,7 @@ class ModifyJobModal(discord.ui.Modal, title="Modify Job"):
|
||||
"""
|
||||
super().__init__()
|
||||
self.job: Job = job
|
||||
self.scheduler: settings.AsyncIOScheduler = scheduler
|
||||
self.scheduler: AsyncIOScheduler = scheduler
|
||||
|
||||
# Use "Name" as label if the message is too long, otherwise use the old message
|
||||
job_name_label: str = f"Name ({self.job.kwargs.get('message', 'X' * 46)})"
|
||||
@ -167,7 +164,7 @@ class JobSelector(Select):
|
||||
Args:
|
||||
scheduler: The scheduler to get the jobs from.
|
||||
"""
|
||||
self.scheduler: settings.AsyncIOScheduler = scheduler
|
||||
self.scheduler: AsyncIOScheduler = scheduler
|
||||
options: list[discord.SelectOption] = []
|
||||
jobs: list[Job] = scheduler.get_jobs()
|
||||
|
||||
@ -217,7 +214,7 @@ class JobManagementView(discord.ui.View):
|
||||
"""
|
||||
super().__init__(timeout=None)
|
||||
self.job: Job = job
|
||||
self.scheduler: settings.AsyncIOScheduler = scheduler
|
||||
self.scheduler: AsyncIOScheduler = scheduler
|
||||
self.add_item(JobSelector(scheduler))
|
||||
self.update_buttons()
|
||||
|
||||
|
Reference in New Issue
Block a user