Move settings to a function and add tests
This commit is contained in:
6
.vscode/settings.json
vendored
6
.vscode/settings.json
vendored
@ -1,11 +1,17 @@
|
|||||||
{
|
{
|
||||||
"cSpell.words": [
|
"cSpell.words": [
|
||||||
|
"aiohttp",
|
||||||
"apscheduler",
|
"apscheduler",
|
||||||
"asctime",
|
"asctime",
|
||||||
|
"asyncio",
|
||||||
"audioop",
|
"audioop",
|
||||||
|
"cookiejar",
|
||||||
"dateparser",
|
"dateparser",
|
||||||
|
"delenv",
|
||||||
"docstrings",
|
"docstrings",
|
||||||
"dotenv",
|
"dotenv",
|
||||||
|
"filterwarnings",
|
||||||
|
"freezegun",
|
||||||
"hikari",
|
"hikari",
|
||||||
"isort",
|
"isort",
|
||||||
"jobstores",
|
"jobstores",
|
||||||
|
@ -12,9 +12,9 @@ from apscheduler.job import Job
|
|||||||
from discord.abc import PrivateChannel
|
from discord.abc import PrivateChannel
|
||||||
from discord_webhook import DiscordWebhook
|
from discord_webhook import DiscordWebhook
|
||||||
|
|
||||||
from discord_reminder_bot import settings
|
|
||||||
from discord_reminder_bot.misc import calculate
|
from discord_reminder_bot.misc import calculate
|
||||||
from discord_reminder_bot.parser import parse_time
|
from discord_reminder_bot.parser import parse_time
|
||||||
|
from discord_reminder_bot.settings import get_bot_token, get_scheduler, get_webhook_url
|
||||||
from discord_reminder_bot.ui import JobManagementView, create_job_embed
|
from discord_reminder_bot.ui import JobManagementView, create_job_embed
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -22,11 +22,15 @@ if TYPE_CHECKING:
|
|||||||
from discord.guild import GuildChannel
|
from discord.guild import GuildChannel
|
||||||
from discord.interactions import InteractionChannel
|
from discord.interactions import InteractionChannel
|
||||||
|
|
||||||
|
from discord_reminder_bot import settings
|
||||||
|
|
||||||
logger: logging.Logger = logging.getLogger(__name__)
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
GUILD_ID = discord.Object(id=341001473661992962)
|
GUILD_ID = discord.Object(id=341001473661992962)
|
||||||
|
|
||||||
|
scheduler: settings.AsyncIOScheduler = get_scheduler()
|
||||||
|
|
||||||
|
|
||||||
class RemindBotClient(discord.Client):
|
class RemindBotClient(discord.Client):
|
||||||
"""Custom client class for the bot."""
|
"""Custom client class for the bot."""
|
||||||
@ -46,8 +50,8 @@ class RemindBotClient(discord.Client):
|
|||||||
|
|
||||||
async def setup_hook(self) -> None:
|
async def setup_hook(self) -> None:
|
||||||
"""Setup the bot."""
|
"""Setup the bot."""
|
||||||
settings.scheduler.start()
|
scheduler.start()
|
||||||
jobs: list[Job] = settings.scheduler.get_jobs()
|
jobs: list[Job] = scheduler.get_jobs()
|
||||||
if not jobs:
|
if not jobs:
|
||||||
logger.info("No jobs available.")
|
logger.info("No jobs available.")
|
||||||
return
|
return
|
||||||
@ -129,7 +133,7 @@ class RemindGroup(discord.app_commands.Group):
|
|||||||
await interaction.followup.send(content=f"Failed to parse time: {time}.", ephemeral=True)
|
await interaction.followup.send(content=f"Failed to parse time: {time}.", ephemeral=True)
|
||||||
return
|
return
|
||||||
|
|
||||||
user_reminder: Job = settings.scheduler.add_job(
|
user_reminder: Job = scheduler.add_job(
|
||||||
func=send_to_user,
|
func=send_to_user,
|
||||||
trigger="date",
|
trigger="date",
|
||||||
run_date=parsed_time,
|
run_date=parsed_time,
|
||||||
@ -152,7 +156,7 @@ class RemindGroup(discord.app_commands.Group):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Create channel reminder job
|
# Create channel reminder job
|
||||||
channel_job: Job = settings.scheduler.add_job(
|
channel_job: Job = scheduler.add_job(
|
||||||
func=send_to_discord,
|
func=send_to_discord,
|
||||||
job_kwargs={
|
job_kwargs={
|
||||||
"channel_id": channel_id,
|
"channel_id": channel_id,
|
||||||
@ -180,13 +184,13 @@ class RemindGroup(discord.app_commands.Group):
|
|||||||
"""
|
"""
|
||||||
await interaction.response.defer()
|
await interaction.response.defer()
|
||||||
|
|
||||||
jobs: list[Job] = settings.scheduler.get_jobs()
|
jobs: list[Job] = scheduler.get_jobs()
|
||||||
if not jobs:
|
if not jobs:
|
||||||
await interaction.followup.send(content="No scheduled jobs found in the database.", ephemeral=True)
|
await interaction.followup.send(content="No scheduled jobs found in the database.", ephemeral=True)
|
||||||
return
|
return
|
||||||
|
|
||||||
embed: discord.Embed = create_job_embed(job=jobs[0])
|
embed: discord.Embed = create_job_embed(job=jobs[0])
|
||||||
view = JobManagementView(job=jobs[0], scheduler=settings.scheduler)
|
view = JobManagementView(job=jobs[0], scheduler=scheduler)
|
||||||
|
|
||||||
await interaction.followup.send(embed=embed, view=view)
|
await interaction.followup.send(embed=embed, view=view)
|
||||||
|
|
||||||
@ -263,7 +267,7 @@ class RemindGroup(discord.app_commands.Group):
|
|||||||
# Create user DM reminder job if user is specified
|
# Create user DM reminder job if user is specified
|
||||||
dm_message: str = ""
|
dm_message: str = ""
|
||||||
if user:
|
if user:
|
||||||
user_reminder: Job = settings.scheduler.add_job(
|
user_reminder: Job = scheduler.add_job(
|
||||||
func=send_to_user,
|
func=send_to_user,
|
||||||
trigger="cron",
|
trigger="cron",
|
||||||
year=year,
|
year=year,
|
||||||
@ -295,7 +299,7 @@ class RemindGroup(discord.app_commands.Group):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Create channel reminder job
|
# Create channel reminder job
|
||||||
channel_job: Job = settings.scheduler.add_job(
|
channel_job: Job = scheduler.add_job(
|
||||||
func=send_to_discord,
|
func=send_to_discord,
|
||||||
trigger="cron",
|
trigger="cron",
|
||||||
year=year,
|
year=year,
|
||||||
@ -396,7 +400,7 @@ class RemindGroup(discord.app_commands.Group):
|
|||||||
# Create user DM reminder job if user is specified
|
# Create user DM reminder job if user is specified
|
||||||
dm_message: str = ""
|
dm_message: str = ""
|
||||||
if user:
|
if user:
|
||||||
dm_job: Job = settings.scheduler.add_job(
|
dm_job: Job = scheduler.add_job(
|
||||||
func=send_to_user,
|
func=send_to_user,
|
||||||
trigger="interval",
|
trigger="interval",
|
||||||
weeks=weeks,
|
weeks=weeks,
|
||||||
@ -424,7 +428,7 @@ class RemindGroup(discord.app_commands.Group):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create channel reminder job
|
# Create channel reminder job
|
||||||
channel_job: Job = settings.scheduler.add_job(
|
channel_job: Job = scheduler.add_job(
|
||||||
func=send_to_discord,
|
func=send_to_discord,
|
||||||
trigger="interval",
|
trigger="interval",
|
||||||
weeks=weeks,
|
weeks=weeks,
|
||||||
@ -463,7 +467,7 @@ class RemindGroup(discord.app_commands.Group):
|
|||||||
# Retrieve all jobs
|
# Retrieve all jobs
|
||||||
with tempfile.NamedTemporaryFile(mode="r+", delete=False, encoding="utf-8", suffix=".json") as temp_file:
|
with tempfile.NamedTemporaryFile(mode="r+", delete=False, encoding="utf-8", suffix=".json") as temp_file:
|
||||||
# Export jobs to a temporary file
|
# Export jobs to a temporary file
|
||||||
settings.scheduler.export_jobs(temp_file.name)
|
scheduler.export_jobs(temp_file.name)
|
||||||
|
|
||||||
# Load the exported jobs
|
# Load the exported jobs
|
||||||
temp_file.seek(0)
|
temp_file.seek(0)
|
||||||
@ -513,7 +517,7 @@ class RemindGroup(discord.app_commands.Group):
|
|||||||
|
|
||||||
# Write the data to a new file
|
# Write the data to a new file
|
||||||
with tempfile.NamedTemporaryFile(mode="w", delete=False, encoding="utf-8", suffix=".json") as output_file:
|
with tempfile.NamedTemporaryFile(mode="w", delete=False, encoding="utf-8", suffix=".json") as output_file:
|
||||||
file_name = f"reminders-backup-{datetime.datetime.now(tz=settings.scheduler.timezone)}.json"
|
file_name = f"reminders-backup-{datetime.datetime.now(tz=scheduler.timezone)}.json"
|
||||||
json.dump(jobs_data, output_file, indent=4)
|
json.dump(jobs_data, output_file, indent=4)
|
||||||
output_file.seek(0)
|
output_file.seek(0)
|
||||||
|
|
||||||
@ -532,7 +536,7 @@ class RemindGroup(discord.app_commands.Group):
|
|||||||
logger.info("Restoring reminders from file for %s (%s) in %s", interaction.user, interaction.user.id, interaction.channel)
|
logger.info("Restoring reminders from file for %s (%s) in %s", interaction.user, interaction.user.id, interaction.channel)
|
||||||
|
|
||||||
# Get the old jobs
|
# Get the old jobs
|
||||||
old_jobs: list[Job] = settings.scheduler.get_jobs()
|
old_jobs: list[Job] = scheduler.get_jobs()
|
||||||
|
|
||||||
# Tell to reply with the file to this message
|
# Tell to reply with the file to this message
|
||||||
await interaction.followup.send(content="Please reply to this message with the backup file.")
|
await interaction.followup.send(content="Please reply to this message with the backup file.")
|
||||||
@ -548,6 +552,10 @@ class RemindGroup(discord.app_commands.Group):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if not reply.channel:
|
||||||
|
await interaction.followup.send(content="No channel found. Please try again.")
|
||||||
|
continue
|
||||||
|
|
||||||
# Fetch the message by its ID to ensure we have the latest data
|
# Fetch the message by its ID to ensure we have the latest data
|
||||||
reply = await reply.channel.fetch_message(reply.id)
|
reply = await reply.channel.fetch_message(reply.id)
|
||||||
|
|
||||||
@ -580,8 +588,8 @@ class RemindGroup(discord.app_commands.Group):
|
|||||||
|
|
||||||
with tempfile.NamedTemporaryFile(mode="w+", delete=False, encoding="utf-8", suffix=".json") as temp_import_file:
|
with tempfile.NamedTemporaryFile(mode="w+", delete=False, encoding="utf-8", suffix=".json") as temp_import_file:
|
||||||
# We can't import jobs with the same ID so remove them from the JSON
|
# We can't import jobs with the same ID so remove them from the JSON
|
||||||
jobs = [job for job in jobs_data.get("jobs", []) if not settings.scheduler.get_job(job.get("id"))]
|
jobs = [job for job in jobs_data.get("jobs", []) if not scheduler.get_job(job.get("id"))]
|
||||||
jobs_already_exist = [job.get("id") for job in jobs_data.get("jobs", []) if settings.scheduler.get_job(job.get("id"))]
|
jobs_already_exist = [job.get("id") for job in jobs_data.get("jobs", []) if scheduler.get_job(job.get("id"))]
|
||||||
jobs_data["jobs"] = jobs
|
jobs_data["jobs"] = jobs
|
||||||
for job_id in jobs_already_exist:
|
for job_id in jobs_already_exist:
|
||||||
logger.debug("Removed job: %s because it already exists.", job_id)
|
logger.debug("Removed job: %s because it already exists.", job_id)
|
||||||
@ -594,10 +602,10 @@ class RemindGroup(discord.app_commands.Group):
|
|||||||
temp_import_file.seek(0)
|
temp_import_file.seek(0)
|
||||||
|
|
||||||
# Import the jobs
|
# Import the jobs
|
||||||
settings.scheduler.import_jobs(temp_import_file.name)
|
scheduler.import_jobs(temp_import_file.name)
|
||||||
|
|
||||||
# Get the new jobs
|
# Get the new jobs
|
||||||
new_jobs: list[Job] = settings.scheduler.get_jobs()
|
new_jobs: list[Job] = scheduler.get_jobs()
|
||||||
|
|
||||||
# Get the difference
|
# Get the difference
|
||||||
added_jobs: list[Job] = [job for job in new_jobs if job not in old_jobs]
|
added_jobs: list[Job] = [job for job in new_jobs if job not in old_jobs]
|
||||||
@ -621,20 +629,26 @@ remind_group = RemindGroup()
|
|||||||
bot.tree.add_command(remind_group)
|
bot.tree.add_command(remind_group)
|
||||||
|
|
||||||
|
|
||||||
def send_webhook(
|
def send_webhook(url: str = "", message: str = "") -> None:
|
||||||
url: str = settings.webhook_url,
|
|
||||||
message: str = "discord-reminder-bot: Empty message.",
|
|
||||||
) -> None:
|
|
||||||
"""Send a webhook to Discord.
|
"""Send a webhook to Discord.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
url: Our webhook url, defaults to the one from settings.
|
url: Our webhook url, defaults to the one from settings.
|
||||||
message: The message that will be sent to Discord.
|
message: The message that will be sent to Discord.
|
||||||
"""
|
"""
|
||||||
|
if not message:
|
||||||
|
logger.error("No message provided.")
|
||||||
|
message = "No message provided."
|
||||||
|
|
||||||
if not url:
|
if not url:
|
||||||
msg = "ERROR: Tried to send a webhook but you have no webhook url configured."
|
url = get_webhook_url()
|
||||||
logger.error(msg)
|
logger.error("No webhook URL provided. Using the one from settings.")
|
||||||
webhook: DiscordWebhook = DiscordWebhook(url=settings.webhook_url, content=msg, rate_limit_retry=True)
|
webhook: DiscordWebhook = DiscordWebhook(
|
||||||
|
url=url,
|
||||||
|
username="discord-reminder-bot",
|
||||||
|
content="No webhook URL provided. Using the one from settings.",
|
||||||
|
rate_limit_retry=True,
|
||||||
|
)
|
||||||
webhook.execute()
|
webhook.execute()
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -683,4 +697,5 @@ async def send_to_user(user_id: int, guild_id: int, message: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
bot.run(settings.bot_token, root_logger=True)
|
bot_token: str = get_bot_token()
|
||||||
|
bot.run(bot_token, root_logger=True)
|
||||||
|
@ -2,21 +2,22 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||||
|
|
||||||
import dateparser
|
import dateparser
|
||||||
|
|
||||||
from discord_reminder_bot import settings
|
from discord_reminder_bot.settings import get_timezone
|
||||||
|
|
||||||
logger: logging.Logger = logging.getLogger(__name__)
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def parse_time(date_to_parse: str, timezone: str | None = None) -> datetime.datetime | None:
|
def parse_time(date_to_parse: str, timezone: str | None = None, use_dotenv: bool = True) -> datetime.datetime | None: # noqa: FBT001, FBT002
|
||||||
"""Parse a date string into a datetime object.
|
"""Parse a date string into a datetime object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
date_to_parse(str): The date string to parse.
|
date_to_parse(str): The date string to parse.
|
||||||
timezone(str, optional): The timezone to use. Defaults timezone from settings.
|
timezone(str, optional): The timezone to use. Defaults timezone from settings.
|
||||||
|
use_dotenv(bool, optional): Whether to load environment variables from a .env file. Defaults to True
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
datetime.datetime: The parsed datetime object.
|
datetime.datetime: The parsed datetime object.
|
||||||
@ -28,7 +29,14 @@ def parse_time(date_to_parse: str, timezone: str | None = None) -> datetime.date
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
if not timezone:
|
if not timezone:
|
||||||
timezone = settings.config_timezone
|
timezone = get_timezone(use_dotenv)
|
||||||
|
|
||||||
|
# Check if the timezone is valid
|
||||||
|
try:
|
||||||
|
tz = ZoneInfo(timezone)
|
||||||
|
except (ZoneInfoNotFoundError, ModuleNotFoundError):
|
||||||
|
logger.error("Invalid timezone provided: '%s'. Using default timezone: '%s'", timezone, get_timezone(use_dotenv)) # noqa: TRY400
|
||||||
|
tz = ZoneInfo("UTC")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
parsed_date: datetime.datetime | None = dateparser.parse(
|
parsed_date: datetime.datetime | None = dateparser.parse(
|
||||||
@ -37,7 +45,7 @@ def parse_time(date_to_parse: str, timezone: str | None = None) -> datetime.date
|
|||||||
"PREFER_DATES_FROM": "future",
|
"PREFER_DATES_FROM": "future",
|
||||||
"TIMEZONE": f"{timezone}",
|
"TIMEZONE": f"{timezone}",
|
||||||
"RETURN_AS_TIMEZONE_AWARE": True,
|
"RETURN_AS_TIMEZONE_AWARE": True,
|
||||||
"RELATIVE_BASE": datetime.datetime.now(tz=ZoneInfo(timezone)),
|
"RELATIVE_BASE": datetime.datetime.now(tz=tz),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
|
@ -7,22 +7,150 @@ from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
|
|||||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
load_dotenv(verbose=True)
|
|
||||||
sqlite_location: str = os.getenv("SQLITE_LOCATION", default="/jobs.sqlite")
|
|
||||||
config_timezone: str = os.getenv("TIMEZONE", default="UTC")
|
|
||||||
bot_token: str = os.getenv("BOT_TOKEN", default="")
|
|
||||||
log_level: str = os.getenv("LOG_LEVEL", default="INFO")
|
|
||||||
webhook_url: str = os.getenv("WEBHOOK_URL", default="")
|
|
||||||
|
|
||||||
if not bot_token:
|
def get_settings(use_dotenv: bool = True) -> dict[str, str | dict[str, SQLAlchemyJobStore] | dict[str, bool] | AsyncIOScheduler]: # noqa: FBT001, FBT002
|
||||||
err_msg = "Missing bot token"
|
"""Load environment variables and return the settings.
|
||||||
raise ValueError(err_msg)
|
|
||||||
|
|
||||||
# Advanced Python Scheduler
|
Args:
|
||||||
jobstores: dict[str, SQLAlchemyJobStore] = {"default": SQLAlchemyJobStore(url=f"sqlite://{sqlite_location}")}
|
use_dotenv (bool, optional): Whether to load environment variables from a .env file. Defaults to True.
|
||||||
job_defaults: dict[str, bool] = {"coalesce": True}
|
|
||||||
scheduler = AsyncIOScheduler(
|
Raises:
|
||||||
|
ValueError: If the bot token is missing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The settings.
|
||||||
|
"""
|
||||||
|
if use_dotenv:
|
||||||
|
load_dotenv(verbose=True)
|
||||||
|
|
||||||
|
sqlite_location: str = os.getenv("SQLITE_LOCATION", default="/jobs.sqlite")
|
||||||
|
config_timezone: str = os.getenv("TIMEZONE", default="UTC")
|
||||||
|
bot_token: str = os.getenv("BOT_TOKEN", default="")
|
||||||
|
log_level: str = os.getenv("LOG_LEVEL", default="INFO")
|
||||||
|
webhook_url: str = os.getenv("WEBHOOK_URL", default="")
|
||||||
|
|
||||||
|
if not bot_token:
|
||||||
|
msg = "Missing bot token. Please set the BOT_TOKEN environment variable."
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
jobstores: dict[str, SQLAlchemyJobStore] = {"default": SQLAlchemyJobStore(url=f"sqlite://{sqlite_location}")}
|
||||||
|
job_defaults: dict[str, bool] = {"coalesce": True}
|
||||||
|
scheduler = AsyncIOScheduler(
|
||||||
jobstores=jobstores,
|
jobstores=jobstores,
|
||||||
timezone=pytz.timezone(config_timezone),
|
timezone=pytz.timezone(config_timezone),
|
||||||
job_defaults=job_defaults,
|
job_defaults=job_defaults,
|
||||||
)
|
)
|
||||||
|
return {
|
||||||
|
"sqlite_location": sqlite_location,
|
||||||
|
"config_timezone": config_timezone,
|
||||||
|
"bot_token": bot_token,
|
||||||
|
"log_level": log_level,
|
||||||
|
"webhook_url": webhook_url,
|
||||||
|
"jobstores": jobstores,
|
||||||
|
"job_defaults": job_defaults,
|
||||||
|
"scheduler": scheduler,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_scheduler(use_dotenv: bool = True) -> AsyncIOScheduler: # noqa: FBT001, FBT002
|
||||||
|
"""Return the scheduler instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_dotenv (bool, optional): Whether to load environment variables from a .env file. Defaults to True
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If the scheduler is not an instance of AsyncIOScheduler.
|
||||||
|
KeyError: If the scheduler is missing from the settings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncIOScheduler: The scheduler instance.
|
||||||
|
"""
|
||||||
|
settings: dict[str, str | dict[str, SQLAlchemyJobStore] | dict[str, bool] | AsyncIOScheduler] = get_settings(use_dotenv)
|
||||||
|
|
||||||
|
if scheduler := settings.get("scheduler"):
|
||||||
|
if not isinstance(scheduler, AsyncIOScheduler):
|
||||||
|
msg = "The scheduler is not an instance of AsyncIOScheduler."
|
||||||
|
raise TypeError(msg)
|
||||||
|
|
||||||
|
return scheduler
|
||||||
|
|
||||||
|
msg = "The scheduler is missing from the settings."
|
||||||
|
raise KeyError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def get_bot_token(use_dotenv: bool = True) -> str: # noqa: FBT001, FBT002
|
||||||
|
"""Return the bot token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_dotenv (bool, optional): Whether to load environment variables from a .env file. Defaults to True
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If the bot token is not a string.
|
||||||
|
KeyError: If the bot token is missing from the settings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The bot token.
|
||||||
|
"""
|
||||||
|
settings: dict[str, str | dict[str, SQLAlchemyJobStore] | dict[str, bool] | AsyncIOScheduler] = get_settings(use_dotenv)
|
||||||
|
|
||||||
|
if bot_token := settings.get("bot_token"):
|
||||||
|
if not isinstance(bot_token, str):
|
||||||
|
msg = "The bot token is not a string."
|
||||||
|
raise TypeError(msg)
|
||||||
|
|
||||||
|
return bot_token
|
||||||
|
|
||||||
|
msg = "The bot token is missing from the settings."
|
||||||
|
raise KeyError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def get_webhook_url(use_dotenv: bool = True) -> str: # noqa: FBT001, FBT002
|
||||||
|
"""Return the webhook URL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_dotenv (bool, optional): Whether to load environment variables from a .env file. Defaults to True
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If the webhook URL is not a string.
|
||||||
|
KeyError: If the webhook URL is missing from the settings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The webhook URL.
|
||||||
|
"""
|
||||||
|
settings: dict[str, str | dict[str, SQLAlchemyJobStore] | dict[str, bool] | AsyncIOScheduler] = get_settings(use_dotenv)
|
||||||
|
|
||||||
|
if webhook_url := settings.get("webhook_url"):
|
||||||
|
if not isinstance(webhook_url, str):
|
||||||
|
msg = "The webhook URL is not a string."
|
||||||
|
raise TypeError(msg)
|
||||||
|
|
||||||
|
return webhook_url
|
||||||
|
|
||||||
|
msg = "The webhook URL is missing from the settings."
|
||||||
|
raise KeyError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def get_timezone(use_dotenv: bool = True) -> str: # noqa: FBT001, FBT002
|
||||||
|
"""Return the timezone.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_dotenv (bool, optional): Whether to load environment variables from a .env file. Defaults to True
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If the timezone is not a string.
|
||||||
|
KeyError: If the timezone is missing from the settings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The timezone.
|
||||||
|
"""
|
||||||
|
settings: dict[str, str | dict[str, SQLAlchemyJobStore] | dict[str, bool] | AsyncIOScheduler] = get_settings(use_dotenv)
|
||||||
|
|
||||||
|
if config_timezone := settings.get("config_timezone"):
|
||||||
|
if not isinstance(config_timezone, str):
|
||||||
|
msg = "The timezone is not a string."
|
||||||
|
raise TypeError(msg)
|
||||||
|
|
||||||
|
return config_timezone
|
||||||
|
|
||||||
|
msg = "The timezone is missing from the settings."
|
||||||
|
raise KeyError(msg)
|
||||||
|
@ -10,7 +10,6 @@ from apscheduler.triggers.cron import CronTrigger
|
|||||||
from apscheduler.triggers.interval import IntervalTrigger
|
from apscheduler.triggers.interval import IntervalTrigger
|
||||||
from discord.ui import Button, Select
|
from discord.ui import Button, Select
|
||||||
|
|
||||||
from discord_reminder_bot import settings
|
|
||||||
from discord_reminder_bot.misc import DateTrigger, calc_time, calculate
|
from discord_reminder_bot.misc import DateTrigger, calc_time, calculate
|
||||||
from discord_reminder_bot.parser import parse_time
|
from discord_reminder_bot.parser import parse_time
|
||||||
|
|
||||||
@ -20,8 +19,6 @@ if TYPE_CHECKING:
|
|||||||
from apscheduler.job import Job
|
from apscheduler.job import Job
|
||||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
|
|
||||||
from discord_reminder_bot import settings
|
|
||||||
|
|
||||||
|
|
||||||
logger: logging.Logger = logging.getLogger(__name__)
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -41,7 +38,7 @@ class ModifyJobModal(discord.ui.Modal, title="Modify Job"):
|
|||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.job: Job = job
|
self.job: Job = job
|
||||||
self.scheduler: settings.AsyncIOScheduler = scheduler
|
self.scheduler: AsyncIOScheduler = scheduler
|
||||||
|
|
||||||
# Use "Name" as label if the message is too long, otherwise use the old message
|
# Use "Name" as label if the message is too long, otherwise use the old message
|
||||||
job_name_label: str = f"Name ({self.job.kwargs.get('message', 'X' * 46)})"
|
job_name_label: str = f"Name ({self.job.kwargs.get('message', 'X' * 46)})"
|
||||||
@ -167,7 +164,7 @@ class JobSelector(Select):
|
|||||||
Args:
|
Args:
|
||||||
scheduler: The scheduler to get the jobs from.
|
scheduler: The scheduler to get the jobs from.
|
||||||
"""
|
"""
|
||||||
self.scheduler: settings.AsyncIOScheduler = scheduler
|
self.scheduler: AsyncIOScheduler = scheduler
|
||||||
options: list[discord.SelectOption] = []
|
options: list[discord.SelectOption] = []
|
||||||
jobs: list[Job] = scheduler.get_jobs()
|
jobs: list[Job] = scheduler.get_jobs()
|
||||||
|
|
||||||
@ -217,7 +214,7 @@ class JobManagementView(discord.ui.View):
|
|||||||
"""
|
"""
|
||||||
super().__init__(timeout=None)
|
super().__init__(timeout=None)
|
||||||
self.job: Job = job
|
self.job: Job = job
|
||||||
self.scheduler: settings.AsyncIOScheduler = scheduler
|
self.scheduler: AsyncIOScheduler = scheduler
|
||||||
self.add_item(JobSelector(scheduler))
|
self.add_item(JobSelector(scheduler))
|
||||||
self.update_buttons()
|
self.update_buttons()
|
||||||
|
|
||||||
|
@ -9,17 +9,26 @@ description = "Discord bot that allows you to set date, cron and interval remind
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"apscheduler<4.0.0",
|
# The Discord bot library, and legacy-cgi is because Python 3.13 removed cgi module
|
||||||
"dateparser",
|
"discord-py[speed]>=2.4.0,<3.0.0", # https://github.com/Rapptz/discord.py
|
||||||
"discord-py",
|
"legacy-cgi>=2.6.2,<3.0.0; python_version >= '3.13'", # https://github.com/jackrosenthal/legacy-cgi
|
||||||
"discord-webhook",
|
|
||||||
"legacy-cgi",
|
# For parsing dates and times in /remind commands
|
||||||
"python-dotenv",
|
"dateparser>=1.0.0", # https://github.com/scrapinghub/dateparser
|
||||||
"sqlalchemy",
|
|
||||||
|
# For sending webhook messages to Discord
|
||||||
|
"discord-webhook>=1.3.1,<2.0.0", # https://github.com/lovvskillz/python-discord-webhook
|
||||||
|
|
||||||
|
# For scheduling reminders, sqlalchemy is needed for storing reminders in a database
|
||||||
|
"apscheduler>=3.11.0,<4.0.0", # https://github.com/agronholm/apscheduler
|
||||||
|
"sqlalchemy>=2.0.37,<3.0.0", # https://github.com/sqlalchemy/sqlalchemy
|
||||||
|
|
||||||
|
# For loading environment variables from a .env file
|
||||||
|
"python-dotenv>=1.0.1,<2.0.0", # https://github.com/theskumar/python-dotenv
|
||||||
]
|
]
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = ["pytest", "ruff", "pre-commit"]
|
dev = ["pytest", "ruff", "pre-commit", "pytest-asyncio", "freezegun"]
|
||||||
|
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "discord-reminder-bot"
|
name = "discord-reminder-bot"
|
||||||
@ -33,16 +42,37 @@ bot = "discord_reminder_bot.main:start"
|
|||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.13"
|
python = "^3.13"
|
||||||
apscheduler = "<4.0.0"
|
|
||||||
dateparser = "*"
|
# https://github.com/agronholm/apscheduler
|
||||||
discord-py = {git = "https://github.com/Rapptz/discord.py"}
|
# https://github.com/sqlalchemy/sqlalchemy
|
||||||
python-dotenv = "*"
|
# For scheduling reminders, sqlalchemy is needed for storing reminders in a database
|
||||||
sqlalchemy = "*"
|
sqlalchemy = {version = ">=2.0.37,<3.0.0"}
|
||||||
|
apscheduler = {version = ">=3.11.0,<4.0.0"}
|
||||||
|
|
||||||
|
# https://github.com/scrapinghub/dateparser
|
||||||
|
# For parsing dates and times in /remind commands
|
||||||
|
dateparser = {version = ">=1.0.0"}
|
||||||
|
|
||||||
|
# https://github.com/Rapptz/discord.py
|
||||||
|
# https://github.com/jackrosenthal/legacy-cgi
|
||||||
|
# The Discord bot library, and legacy-cgi is because Python 3.13 removed cgi module
|
||||||
|
discord-py = {version = ">=2.4.0,<3.0.0", extras = ["speed"]}
|
||||||
|
legacy-cgi = {version = ">=2.6.2,<3.0.0", markers = "python_version >= '3.13'"}
|
||||||
|
|
||||||
|
# https://github.com/lovvskillz/python-discord-webhook
|
||||||
|
# For sending webhook messages to Discord
|
||||||
|
discord-webhook = {version = ">=1.3.1,<2.0.0"}
|
||||||
|
|
||||||
|
# https://github.com/theskumar/python-dotenv
|
||||||
|
# For loading environment variables from a .env file
|
||||||
|
python-dotenv = {version = ">=1.0.1,<2.0.0"}
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
pytest = "*"
|
pytest = "*"
|
||||||
pre-commit = "*"
|
pre-commit = "*"
|
||||||
ruff = "*"
|
ruff = "*"
|
||||||
|
pytest-asyncio = "*"
|
||||||
|
freezegun = "*"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core>=1.0.0"]
|
requires = ["poetry-core>=1.0.0"]
|
||||||
@ -89,7 +119,7 @@ docstring-code-format = true
|
|||||||
docstring-code-line-length = 20
|
docstring-code-line-length = 20
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
"**/*_test.py" = [
|
"**/test_*.py" = [
|
||||||
"ARG", # Unused function args -> fixtures nevertheless are functionally relevant...
|
"ARG", # Unused function args -> fixtures nevertheless are functionally relevant...
|
||||||
"FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize()
|
"FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize()
|
||||||
"PLR2004", # Magic value used in comparison, ...
|
"PLR2004", # Magic value used in comparison, ...
|
||||||
@ -102,6 +132,7 @@ log_cli = true
|
|||||||
log_cli_level = "INFO"
|
log_cli_level = "INFO"
|
||||||
log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
|
log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
|
||||||
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
|
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
|
||||||
|
filterwarnings = ["ignore::DeprecationWarning:aiohttp.cookiejar"]
|
||||||
|
|
||||||
[tool.uv.sources]
|
# [tool.uv.sources]
|
||||||
discord-py = {git = "https://github.com/Rapptz/discord.py"}
|
# discord-py = {git = "https://github.com/Rapptz/discord.py"}
|
||||||
|
77
tests/test_misc.py
Normal file
77
tests/test_misc.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from apscheduler.schedulers.background import BackgroundScheduler
|
||||||
|
from apscheduler.triggers.date import DateTrigger
|
||||||
|
|
||||||
|
from discord_reminder_bot.misc import calc_time, calculate, get_human_time
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from apscheduler.job import Job
|
||||||
|
|
||||||
|
|
||||||
|
def test_calc_time() -> None:
|
||||||
|
"""Test the calc_time function with various datetime inputs."""
|
||||||
|
test_datetime: datetime = datetime(2023, 10, 1, 12, 0, 0, tzinfo=UTC)
|
||||||
|
expected_timestamp: str = f"<t:{int(test_datetime.timestamp())}:R>"
|
||||||
|
assert calc_time(test_datetime) == expected_timestamp
|
||||||
|
|
||||||
|
now: datetime = datetime.now(tz=UTC)
|
||||||
|
expected_timestamp_now: str = f"<t:{int(now.timestamp())}:R>"
|
||||||
|
assert calc_time(now) == expected_timestamp_now
|
||||||
|
|
||||||
|
past_datetime: datetime = datetime(2000, 1, 1, 0, 0, 0, tzinfo=UTC)
|
||||||
|
expected_timestamp_past: str = f"<t:{int(past_datetime.timestamp())}:R>"
|
||||||
|
assert calc_time(past_datetime) == expected_timestamp_past
|
||||||
|
|
||||||
|
future_datetime: datetime = datetime(2100, 1, 1, 0, 0, 0, tzinfo=UTC)
|
||||||
|
expected_timestamp_future: str = f"<t:{int(future_datetime.timestamp())}:R>"
|
||||||
|
assert calc_time(future_datetime) == expected_timestamp_future
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_human_time() -> None:
|
||||||
|
"""Test the get_human_time function with various timedelta inputs."""
|
||||||
|
test_timedelta = timedelta(days=1, hours=2, minutes=3, seconds=4)
|
||||||
|
expected_output = "1d2h3m4s"
|
||||||
|
assert get_human_time(test_timedelta) == expected_output
|
||||||
|
|
||||||
|
test_timedelta = timedelta(hours=5, minutes=6, seconds=7)
|
||||||
|
expected_output = "5h6m7s"
|
||||||
|
assert get_human_time(test_timedelta) == expected_output
|
||||||
|
|
||||||
|
test_timedelta = timedelta(minutes=8, seconds=9)
|
||||||
|
expected_output = "8m9s"
|
||||||
|
assert get_human_time(test_timedelta) == expected_output
|
||||||
|
|
||||||
|
test_timedelta = timedelta(seconds=10)
|
||||||
|
expected_output = "10s"
|
||||||
|
assert get_human_time(test_timedelta) == expected_output
|
||||||
|
|
||||||
|
test_timedelta = timedelta(days=0, hours=0, minutes=0, seconds=0)
|
||||||
|
expected_output = ""
|
||||||
|
assert get_human_time(test_timedelta) == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_calculate() -> None:
|
||||||
|
"""Test the calculate function with various job inputs."""
|
||||||
|
scheduler = BackgroundScheduler()
|
||||||
|
scheduler.start()
|
||||||
|
|
||||||
|
# Create a job with a DateTrigger
|
||||||
|
run_date = datetime(2270, 10, 1, 12, 0, 0, tzinfo=UTC)
|
||||||
|
job: Job = scheduler.add_job(lambda: None, trigger=DateTrigger(run_date=run_date), id="test_job", name="Test Job")
|
||||||
|
|
||||||
|
expected_output = "<t:9490737600:R>"
|
||||||
|
assert calculate(job) == expected_output
|
||||||
|
|
||||||
|
# Modify the job to have a next_run_time
|
||||||
|
job.modify(next_run_time=run_date)
|
||||||
|
assert calculate(job) == expected_output
|
||||||
|
|
||||||
|
# Paused job should still return the same output
|
||||||
|
job.pause()
|
||||||
|
assert calculate(job) == expected_output
|
||||||
|
|
||||||
|
scheduler.shutdown()
|
53
tests/test_parser.py
Normal file
53
tests/test_parser.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
from freezegun import freeze_time
|
||||||
|
|
||||||
|
from discord_reminder_bot import settings
|
||||||
|
from discord_reminder_bot.parser import parse_time
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_time_valid_date() -> None:
|
||||||
|
"""Test the `parse_time` function with a valid date string."""
|
||||||
|
date_to_parse = "tomorrow at 5pm"
|
||||||
|
timezone = "UTC"
|
||||||
|
result: datetime.datetime | None = parse_time(date_to_parse, timezone, use_dotenv=False)
|
||||||
|
assert result is not None
|
||||||
|
assert result.tzinfo == ZoneInfo(timezone)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_time_no_date() -> None:
|
||||||
|
"""Test the `parse_time` function with no date string."""
|
||||||
|
date_to_parse: str = ""
|
||||||
|
timezone = "UTC"
|
||||||
|
result: datetime.datetime | None = parse_time(date_to_parse, timezone, use_dotenv=False)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_time_no_timezone() -> None:
|
||||||
|
"""Test the `parse_time` function with no timezone."""
|
||||||
|
date_to_parse = "tomorrow at 5pm"
|
||||||
|
result: datetime.datetime | None = parse_time(date_to_parse, use_dotenv=False)
|
||||||
|
assert result is not None
|
||||||
|
assert result.tzinfo == ZoneInfo(settings.get_timezone(use_dotenv=False))
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_time_invalid_date() -> None:
|
||||||
|
"""Test the `parse_time` function with an invalid date string."""
|
||||||
|
date_to_parse = "invalid date"
|
||||||
|
timezone = "UTC"
|
||||||
|
result: datetime.datetime | None = parse_time(date_to_parse, timezone, use_dotenv=False)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2023-01-01 12:00:00")
|
||||||
|
def test_parse_time_invalid_timezone() -> None:
|
||||||
|
"""Test the `parse_time` function with an invalid timezone."""
|
||||||
|
date_to_parse = "tomorrow at 5pm"
|
||||||
|
timezone = "Invalid/Timezone"
|
||||||
|
result: datetime.datetime | None = parse_time(date_to_parse, timezone, use_dotenv=False)
|
||||||
|
assert result is not None
|
||||||
|
assert result.tzinfo == ZoneInfo("UTC")
|
||||||
|
assert result == datetime.datetime(2023, 1, 2, 17, 0, tzinfo=ZoneInfo("UTC"))
|
53
tests/test_settings.py
Normal file
53
tests/test_settings.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
|
||||||
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
|
|
||||||
|
from discord_reminder_bot.settings import get_settings
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_settings(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
"""Test get_settings function with environment variables."""
|
||||||
|
monkeypatch.setenv("SQLITE_LOCATION", "/test_jobs.sqlite")
|
||||||
|
monkeypatch.setenv("TIMEZONE", "UTC")
|
||||||
|
monkeypatch.setenv("BOT_TOKEN", "test_token")
|
||||||
|
monkeypatch.setenv("LOG_LEVEL", "DEBUG")
|
||||||
|
monkeypatch.setenv("WEBHOOK_URL", "http://test_webhook_url")
|
||||||
|
|
||||||
|
settings: dict[str, str | dict[str, SQLAlchemyJobStore] | dict[str, bool] | AsyncIOScheduler] = get_settings(use_dotenv=False)
|
||||||
|
|
||||||
|
assert settings["sqlite_location"] == "/test_jobs.sqlite"
|
||||||
|
assert settings["config_timezone"] == "UTC"
|
||||||
|
assert settings["bot_token"] == "test_token" # noqa: S105
|
||||||
|
assert settings["log_level"] == "DEBUG"
|
||||||
|
assert settings["webhook_url"] == "http://test_webhook_url"
|
||||||
|
assert isinstance(settings["jobstores"]["default"], SQLAlchemyJobStore) # type: ignore # noqa: PGH003
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_settings_missing_bot_token(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
"""Test get_settings function with missing bot token."""
|
||||||
|
monkeypatch.delenv("BOT_TOKEN", raising=False)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Missing bot token"):
|
||||||
|
get_settings(use_dotenv=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_settings_default_values(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
"""Test get_settings function with default values."""
|
||||||
|
monkeypatch.delenv("SQLITE_LOCATION", raising=False)
|
||||||
|
monkeypatch.delenv("TIMEZONE", raising=False)
|
||||||
|
monkeypatch.delenv("BOT_TOKEN", raising=False)
|
||||||
|
monkeypatch.delenv("LOG_LEVEL", raising=False)
|
||||||
|
monkeypatch.delenv("WEBHOOK_URL", raising=False)
|
||||||
|
monkeypatch.setenv("BOT_TOKEN", "default_token")
|
||||||
|
|
||||||
|
settings: dict[str, str | dict[str, SQLAlchemyJobStore] | dict[str, bool] | AsyncIOScheduler] = get_settings(use_dotenv=False)
|
||||||
|
|
||||||
|
assert settings["sqlite_location"] == "/jobs.sqlite"
|
||||||
|
assert settings["config_timezone"] == "UTC"
|
||||||
|
assert settings["bot_token"] == "default_token" # noqa: S105
|
||||||
|
assert settings["log_level"] == "INFO"
|
||||||
|
assert not settings["webhook_url"]
|
||||||
|
assert isinstance(settings["jobstores"]["default"], SQLAlchemyJobStore) # type: ignore # noqa: PGH003
|
||||||
|
assert isinstance(settings["scheduler"], AsyncIOScheduler)
|
70
tests/test_ui.py
Normal file
70
tests/test_ui.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from apscheduler.triggers.interval import IntervalTrigger
|
||||||
|
|
||||||
|
from discord_reminder_bot.ui import create_job_embed
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateJobEmbed(unittest.TestCase):
|
||||||
|
"""Test the `create_job_embed` function in the `discord_reminder_bot.ui` module."""
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
"""Set up the mock job for testing."""
|
||||||
|
self.job = Mock()
|
||||||
|
self.job.id = "12345"
|
||||||
|
self.job.kwargs = {"channel_id": 67890, "message": "Test message", "author_id": 54321}
|
||||||
|
self.job.next_run_time = None
|
||||||
|
self.job.trigger = Mock(spec=IntervalTrigger)
|
||||||
|
self.job.trigger.interval = "1 day"
|
||||||
|
|
||||||
|
def test_create_job_embed_with_next_run_time(self) -> None:
|
||||||
|
"""Test the `create_job_embed` function to ensure it correctly creates a Discord embed for a job with the next run time."""
|
||||||
|
self.job.next_run_time = Mock()
|
||||||
|
self.job.next_run_time.strftime.return_value = "2023-10-10 10:00:00"
|
||||||
|
|
||||||
|
embed: discord.Embed = create_job_embed(self.job)
|
||||||
|
|
||||||
|
assert isinstance(embed, discord.Embed)
|
||||||
|
assert embed.title == "Test message"
|
||||||
|
assert embed.description is not None
|
||||||
|
assert "ID: 12345" in embed.description
|
||||||
|
assert "Next run: 2023-10-10 10:00:00" in embed.description
|
||||||
|
assert "Interval: 1 day" in embed.description
|
||||||
|
assert "Channel: <#67890>" in embed.description
|
||||||
|
assert "Author: <@54321>" in embed.description
|
||||||
|
|
||||||
|
def test_create_job_embed_without_next_run_time(self) -> None:
|
||||||
|
"""Test the `create_job_embed` function to ensure it correctly creates a Discord embed for a job without the next run time."""
|
||||||
|
embed: discord.Embed = create_job_embed(self.job)
|
||||||
|
|
||||||
|
assert isinstance(embed, discord.Embed)
|
||||||
|
assert embed.title == "Test message"
|
||||||
|
assert embed.description is not None
|
||||||
|
assert "ID: 12345" in embed.description
|
||||||
|
assert "Paused" in embed.description
|
||||||
|
assert "Interval: 1 day" in embed.description
|
||||||
|
assert "Channel: <#67890>" in embed.description
|
||||||
|
assert "Author: <@54321>" in embed.description
|
||||||
|
|
||||||
|
def test_create_job_embed_with_long_message(self) -> None:
|
||||||
|
"""Test the `create_job_embed` function to ensure it correctly truncates long messages."""
|
||||||
|
self.job.kwargs["message"] = "A" * 300
|
||||||
|
|
||||||
|
embed: discord.Embed = create_job_embed(self.job)
|
||||||
|
|
||||||
|
assert isinstance(embed, discord.Embed)
|
||||||
|
assert embed.title == "A" * 256 + "..."
|
||||||
|
assert embed.description is not None
|
||||||
|
assert "ID: 12345" in embed.description
|
||||||
|
assert "Paused" in embed.description
|
||||||
|
assert "Interval: 1 day" in embed.description
|
||||||
|
assert "Channel: <#67890>" in embed.description
|
||||||
|
assert "Author: <@54321>" in embed.description
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Reference in New Issue
Block a user