Refactor create_pages.py and write tests

This commit is contained in:
2023-03-20 02:18:04 +01:00
parent 2df11adb16
commit 40569cb91c
3 changed files with 499 additions and 139 deletions

View File

@ -1,8 +1,22 @@
from collections.abc import Generator
from typing import TYPE_CHECKING, Literal from typing import TYPE_CHECKING, Literal
import interactions import interactions
from apscheduler.job import Job
from apscheduler.schedulers.base import BaseScheduler
from apscheduler.triggers.date import DateTrigger from apscheduler.triggers.date import DateTrigger
from interactions import ActionRow, Button, CommandContext, ComponentContext, Embed, Message, Modal, TextInput from interactions import (
ActionRow,
Button,
ButtonStyle,
Channel,
CommandContext,
ComponentContext,
Embed,
Message,
Modal,
TextInput,
)
from interactions.ext.paginator import Page, Paginator, RowPosition from interactions.ext.paginator import Page, Paginator, RowPosition
from discord_reminder_bot.countdown import calculate from discord_reminder_bot.countdown import calculate
@ -11,132 +25,200 @@ from discord_reminder_bot.settings import scheduler
if TYPE_CHECKING: if TYPE_CHECKING:
from datetime import datetime from datetime import datetime
from apscheduler.job import Job
max_message_length: Literal[1010] = 1010 max_message_length: Literal[1010] = 1010
max_title_length: Literal[90] = 90 max_title_length: Literal[90] = 90
async def create_pages(ctx: CommandContext) -> list[Page]: def _get_trigger_text(job: Job) -> str:
"""Create pages for the paginator. """Get trigger time from a reminder and calculate how many days, hours and minutes till trigger.
Args: Args:
ctx: The context of the command. job: The job. Can be cron, interval or normal.
Returns: Returns:
list[Page]: A list of pages. str: The trigger time and countdown till trigger. If the job is paused, it will return "_Paused_".
""" """
pages: list[Page] = [] # TODO: Add support for cron jobs and interval jobs
trigger_time: datetime | None = job.trigger.run_date if type(job.trigger) is DateTrigger else job.next_run_time
jobs: list[Job] = scheduler.get_jobs() return "_Paused_" if trigger_time is None else f'{trigger_time.strftime("%Y-%m-%d %H:%M")} (in {calculate(job)})'
for job in jobs:
channel_id: int = job.kwargs.get("channel_id")
guild_id: int = job.kwargs.get("guild_id")
if ctx.guild is None:
await ctx.send("I can't find the server you're in. Are you sure you're in a server?", ephemeral=True)
return pages
if ctx.guild.channels is None:
await ctx.send("I can't find the channel you're in.", ephemeral=True)
return pages
# Only add reminders from channels in the server we run "/reminder list" in
# Check if channel is in the Discord server, if not, skip it.
for channel in ctx.guild.channels:
if int(channel.id) == channel_id or ctx.guild_id == guild_id:
trigger_time: datetime | None = (
job.trigger.run_date if type(job.trigger) is DateTrigger else job.next_run_time
)
# Paused reminders returns None
if trigger_time is None:
trigger_value: str | None = None
trigger_text: str = "Paused"
else:
trigger_value = f'{trigger_time.strftime("%Y-%m-%d %H:%M")} (in {calculate(job)})'
trigger_text = trigger_value
message: str = job.kwargs.get("message")
message = f"{message[:1000]}..." if len(message) > max_message_length else message
edit_button: Button = interactions.Button(
label="Edit",
style=interactions.ButtonStyle.PRIMARY,
custom_id="edit",
)
pause_button: Button = interactions.Button(
label="Pause",
style=interactions.ButtonStyle.PRIMARY,
custom_id="pause",
)
unpause_button: Button = interactions.Button(
label="Unpause",
style=interactions.ButtonStyle.PRIMARY,
custom_id="unpause",
)
remove_button: Button = interactions.Button(
label="Remove",
style=interactions.ButtonStyle.DANGER,
custom_id="remove",
)
embed: Embed = interactions.Embed(
title=f"{job.id}",
fields=[
interactions.EmbedField(
name="**Channel:**",
value=f"#{channel.name}",
),
interactions.EmbedField(
name="**Message:**",
value=f"{message}",
),
],
)
if trigger_value is not None:
embed.add_field(
name="**Trigger:**",
value=f"{trigger_text}",
)
else:
embed.add_field(
name="**Trigger:**",
value="_Paused_",
)
components: list[Button] = [
edit_button,
remove_button,
]
if type(job.trigger) is not DateTrigger:
# Get trigger time for cron and interval jobs
trigger_time = job.next_run_time
pause_or_unpause_button: Button = unpause_button if trigger_time is None else pause_button
components.insert(1, pause_or_unpause_button)
# Add a page to pages list
title: str = f"{message[:87]}..." if len(message) > max_title_length else message
pages.append(
Page(
embeds=embed,
title=title,
components=ActionRow(components=components), # type: ignore # noqa: PGH003
callback=callback,
position=RowPosition.BOTTOM,
),
)
return pages
async def callback(self: Paginator, ctx: ComponentContext) -> Message | None: # noqa: PLR0911 def _make_button(label: str, style: ButtonStyle) -> Button:
"""Make a button.
Args:
label: The label of the button.
style: The style of the button.
Returns:
Button: The button.
"""
return interactions.Button(
label=label,
style=style,
custom_id=label.lower(),
)
def _get_pause_or_unpause_button(job: Job) -> Button | None:
"""Get pause or unpause button.
If the job is paused, it will return the unpause button.
If the job is not paused, it will return the pause button.
If the job is not a cron or interval job, it will return None.
Args:
job: The job. Can be cron, interval or normal.
Returns:
Button | None: The pause or unpause button. If the job is not a cron or interval job, it will return None.
"""
if type(job.trigger) is not DateTrigger:
pause_button: Button = _make_button("Pause", interactions.ButtonStyle.PRIMARY)
unpause_button: Button = _make_button("Unpause", interactions.ButtonStyle.PRIMARY)
if not hasattr(job, "next_run_time"):
return pause_button
return unpause_button if job.next_run_time is None else pause_button
return None
def _get_row_of_buttons(job: Job) -> ActionRow:
"""Get components(buttons) for a page in /reminder list.
These buttons are below the embed.
Args:
job: The job. Can be cron, interval or normal.
Returns:
ActionRow: A row of buttons.
"""
components: list[Button] = [
_make_button("Edit", interactions.ButtonStyle.PRIMARY),
_make_button("Remove", interactions.ButtonStyle.DANGER),
]
# Add pause/unpause button as the second button if it's a cron or interval job
pause_or_unpause_button: Button | None = _get_pause_or_unpause_button(job=job)
if pause_or_unpause_button is not None:
components.insert(1, pause_or_unpause_button)
# TODO: Should fix the type error
return ActionRow(components=components) # type: ignore # noqa: PGH003
def _get_pages(job: Job, channel: Channel, ctx: CommandContext) -> Generator[Page, None, None]:
"""Get pages for a reminder.
Args:
job: The job. Can be cron, interval or normal.
channel: Check if the job kwargs channel ID is the same as the channel ID we looped through.
ctx: The context. Used to get the guild ID.
Yields:
Generator[Page, None, None]: A page.
"""
# Get channel ID and guild ID from job kwargs
channel_id: int = job.kwargs.get("channel_id")
guild_id: int = job.kwargs.get("guild_id")
if int(channel.id) == channel_id or ctx.guild_id == guild_id:
message: str = job.kwargs.get("message")
# If message is longer than 1000 characters, truncate it
message = f"{message[:1000]}..." if len(message) > max_message_length else message
# Create embed for the singular page
embed: Embed = interactions.Embed(
title=f"{job.id}", # Example: 593dcc18aab748faa571017454669eae
fields=[
interactions.EmbedField(
name="**Channel:**",
value=f"#{channel.name}", # Example: #general
),
interactions.EmbedField(
name="**Message:**",
value=f"{message}", # Example: Don't forget to feed the cat!
),
interactions.EmbedField(
name="**Trigger:**",
value=_get_trigger_text(job=job), # Example: 2023-08-24 00:06 (in 157 days, 23 hours, 49 minutes)
),
],
)
# Truncate title if it's longer than 90 characters
# This is the text that shows up in the dropdown menu
# Example: 2: Don't forget to feed the cat!
dropdown_title: str = f"{message[:87]}..." if len(message) > max_title_length else message
# Create a page and return it
yield Page(
embeds=embed,
title=dropdown_title,
components=_get_row_of_buttons(job),
callback=_callback,
position=RowPosition.BOTTOM,
)
def _remove_job(job: Job) -> str:
"""Remove a job.
Args:
job: The job to remove.
Returns:
str: The message to send to Discord.
"""
# TODO: Check if job exists before removing it?
# TODO: Add button to undo the removal?
channel_id: int = job.kwargs.get("channel_id")
old_message: str = job.kwargs.get("message")
trigger_time: datetime | None = job.trigger.run_date
scheduler.remove_job(job.id)
return f"Job {job.id} removed.\n**Message:** {old_message}\n**Channel:** {channel_id}\n**Time:** {trigger_time}"
def _unpause_job(job: Job, custom_scheduler: BaseScheduler = scheduler) -> str:
"""Unpause a job.
Args:
job: The job to unpause.
custom_scheduler: The scheduler to use. Defaults to the global scheduler.
Returns:
str: The message to send to Discord.
"""
# TODO: Should we check if the job is paused before unpause it?
custom_scheduler.resume_job(job.id)
return f"Job {job.id} unpaused."
def _pause_job(job: Job, custom_scheduler: BaseScheduler = scheduler) -> str:
"""Pause a job.
Args:
job: The job to pause.
custom_scheduler: The scheduler to use. Defaults to the global scheduler.
Returns:
str: The message to send to Discord.
"""
# TODO: Should we check if the job is unpaused before unpause it?
custom_scheduler.pause_job(job.id)
return f"Job {job.id} paused."
async def _callback(self: Paginator, ctx: ComponentContext) -> Message | None:
"""Callback for the paginator.""" """Callback for the paginator."""
if self.component_ctx is None: # TODO: Create a test for this
return await ctx.send("Something went wrong.", ephemeral=True) if self.component_ctx is None or self.component_ctx.message is None:
if self.component_ctx.message is None:
return await ctx.send("Something went wrong.", ephemeral=True) return await ctx.send("Something went wrong.", ephemeral=True)
job_id: str | None = self.component_ctx.message.embeds[0].title job_id: str | None = self.component_ctx.message.embeds[0].title
@ -145,7 +227,7 @@ async def callback(self: Paginator, ctx: ComponentContext) -> Message | None: #
if job is None: if job is None:
return await ctx.send("Job not found.", ephemeral=True) return await ctx.send("Job not found.", ephemeral=True)
channel_id: int = job.kwargs.get("channel_id") job.kwargs.get("channel_id")
old_message: str = job.kwargs.get("message") old_message: str = job.kwargs.get("message")
components: list[TextInput] = [ components: list[TextInput] = [
@ -158,6 +240,7 @@ async def callback(self: Paginator, ctx: ComponentContext) -> Message | None: #
), ),
] ]
job_type = "cron/interval"
if type(job.trigger) is DateTrigger: if type(job.trigger) is DateTrigger:
# Get trigger time for normal reminders # Get trigger time for normal reminders
trigger_time: datetime | None = job.trigger.run_date trigger_time: datetime | None = job.trigger.run_date
@ -172,37 +255,54 @@ async def callback(self: Paginator, ctx: ComponentContext) -> Message | None: #
), ),
) )
else: # Check what button was clicked and call the correct function
# Get trigger time for cron and interval jobs msg = "Something went wrong. I don't know what you clicked."
trigger_time = job.next_run_time
job_type = "cron/interval"
if ctx.custom_id == "edit": if ctx.custom_id == "edit":
# TODO: Add buttons to increase/decrease hour
modal: Modal = interactions.Modal( modal: Modal = interactions.Modal(
title=f"Edit {job_type} reminder.", title=f"Edit {job_type} reminder.",
custom_id="edit_modal", custom_id="edit_modal",
components=components, # type: ignore # noqa: PGH003 components=components, # type: ignore # noqa: PGH003
) )
await ctx.popup(modal) await ctx.popup(modal)
return None msg = f"You modified {job_id}"
elif ctx.custom_id == "pause":
msg: str = _pause_job(job)
elif ctx.custom_id == "unpause":
msg: str = _unpause_job(job)
elif ctx.custom_id == "remove":
msg: str = _remove_job(job)
if ctx.custom_id == "pause": return await ctx.send(msg, ephemeral=True)
scheduler.pause_job(job_id)
await ctx.send(f"Job {job_id} paused.")
return None
if ctx.custom_id == "unpause":
scheduler.resume_job(job_id)
await ctx.send(f"Job {job_id} unpaused.")
return None
if ctx.custom_id == "remove": async def create_pages(ctx: CommandContext) -> list[Page]:
scheduler.remove_job(job_id) """Create pages for the paginator.
await ctx.send(
f"Job {job_id} removed.\n" Args:
f"**Message:** {old_message}\n" ctx: The context of the command.
f"**Channel:** {channel_id}\n"
f"**Time:** {trigger_time}", Returns:
) list[Page]: A list of pages.
return None """
return None # TODO: Add tests for this
pages: list[Page] = []
jobs: list[Job] = scheduler.get_jobs()
for job in jobs:
# Check if we're in a server
if ctx.guild is None:
await ctx.send("I can't find the server you're in. Are you sure you're in a server?", ephemeral=True)
return []
# Check if we're in a channel
if ctx.guild.channels is None:
await ctx.send("I can't find the channel you're in.", ephemeral=True)
return []
# Only add reminders from channels in the server we run "/reminder list" in
# Check if channel is in the Discord server, if not, skip it.
for channel in ctx.guild.channels:
# Add a page for each reminder
pages.extend(iter(_get_pages(job=job, channel=channel, ctx=ctx)))
return pages

192
tests/test_create_pages.py Normal file
View File

@ -0,0 +1,192 @@
import re
from datetime import datetime
from typing import TYPE_CHECKING
import dateparser
import interactions
import pytz
from apscheduler.job import Job
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from interactions.ext.paginator import Page
from discord_reminder_bot.create_pages import (
_get_pages,
_get_pause_or_unpause_button,
_get_row_of_buttons,
_get_trigger_text,
_make_button,
_pause_job,
_unpause_job,
)
from discord_reminder_bot.main import send_to_discord
if TYPE_CHECKING:
from collections.abc import Generator
def _test_pause_unpause_button(job: Job, button_label: str) -> None:
button2: interactions.Button | None = _get_pause_or_unpause_button(job)
assert button2
assert button2.label == button_label
assert button2.style == interactions.ButtonStyle.PRIMARY
assert button2.type == interactions.ComponentType.BUTTON
assert button2.emoji is None
assert button2.custom_id == button_label.lower()
assert button2.url is None
assert button2.disabled is None
class TestCountdown:
jobstores: dict[str, SQLAlchemyJobStore] = {"default": SQLAlchemyJobStore(url="sqlite:///:memory")}
job_defaults: dict[str, bool] = {"coalesce": True}
scheduler = AsyncIOScheduler(
jobstores=jobstores,
timezone=pytz.timezone("Europe/Stockholm"),
job_defaults=job_defaults,
)
parsed_date: datetime | None = dateparser.parse(
"18 January 2040",
settings={
"PREFER_DATES_FROM": "future",
"TO_TIMEZONE": "Europe/Stockholm",
},
)
assert parsed_date
run_date: str = parsed_date.strftime("%Y-%m-%d %H:%M:%S")
normal_job: Job = scheduler.add_job(
send_to_discord,
run_date=run_date,
kwargs={
"channel_id": 865712621109772329,
"message": "Running PyTest",
"author_id": 126462229892694018,
},
)
cron_job: Job = scheduler.add_job(
send_to_discord,
"cron",
minute="0",
kwargs={
"channel_id": 865712621109772329,
"message": "Running PyTest",
"author_id": 126462229892694018,
},
)
interval_job: Job = scheduler.add_job(
send_to_discord,
"interval",
minutes=1,
kwargs={
"channel_id": 865712621109772329,
"message": "Running PyTest",
"author_id": 126462229892694018,
},
)
def test_get_trigger_text(self) -> None: # noqa: ANN101
# FIXME: This try except train should be replaced with a better solution lol
trigger_text: str = _get_trigger_text(self.normal_job)
try:
regex: str = r"2040-01-18 00:00 \(in \d+ days, \d+ hours, \d+ minutes\)"
assert re.match(regex, trigger_text)
except AssertionError:
try:
regex2: str = r"2040-01-18 00:00 \(in \d+ days, \d+ minutes\)"
assert re.match(regex2, trigger_text)
except AssertionError:
regex3: str = r"2040-01-18 00:00 \(in \d+ days\, \d+ minutes\)"
assert re.match(regex3, trigger_text)
def test_make_button(self) -> None: # noqa: ANN101
button_name: str = "Test"
button: interactions.Button = _make_button(label=button_name, style=interactions.ButtonStyle.PRIMARY)
assert button.label == button_name
assert button.style == interactions.ButtonStyle.PRIMARY
assert button.custom_id == button_name.lower()
assert button.disabled is None
assert button.emoji is None
def test_get_pause_or_unpause_button(self) -> None: # noqa: ANN101
button: interactions.Button | None = _get_pause_or_unpause_button(self.normal_job)
assert button is None
_test_pause_unpause_button(self.cron_job, "Pause")
self.cron_job.pause()
_test_pause_unpause_button(self.cron_job, "Unpause")
self.cron_job.resume()
_test_pause_unpause_button(self.interval_job, "Pause")
self.interval_job.pause()
_test_pause_unpause_button(self.interval_job, "Unpause")
self.interval_job.resume()
def test_get_row_of_buttons(self) -> None: # noqa: ANN101
row: interactions.ActionRow = _get_row_of_buttons(self.normal_job)
assert row
assert row.components
# A normal job should have 2 buttons, edit and delete
assert len(row.components) == 2 # noqa: PLR2004
row2: interactions.ActionRow = _get_row_of_buttons(self.cron_job)
assert row2
assert row2.components
# A cron job should have 3 buttons, edit, delete and pause/unpause
assert len(row2.components) == 3 # noqa: PLR2004
# A cron job should have 3 buttons, edit, delete and pause/unpause
assert len(row2.components) == 3 # noqa: PLR2004
def test_get_pages(self) -> None: # noqa: ANN101
ctx = None # TODO: We should check ctx as well and not only channel id
channel: interactions.Channel = interactions.Channel(id=interactions.Snowflake(865712621109772329))
pages: Generator[Page, None, None] = _get_pages(job=self.normal_job, channel=channel, ctx=ctx) # type: ignore # noqa: PGH003, E501
assert pages
for page in pages:
assert page
assert page.title == "Running PyTest"
assert page.components
assert page.embeds
assert page.embeds.fields is not None # type: ignore # noqa: PGH003
assert page.embeds.fields[0].name == "**Channel:**" # type: ignore # noqa: PGH003
assert page.embeds.fields[0].value == "#" # type: ignore # noqa: PGH003
assert page.embeds.fields[1].name == "**Message:**" # type: ignore # noqa: PGH003
assert page.embeds.fields[1].value == "Running PyTest" # type: ignore # noqa: PGH003
assert page.embeds.fields[2].name == "**Trigger:**" # type: ignore # noqa: PGH003
trigger_text: str = page.embeds.fields[2].value # type: ignore # noqa: PGH003
# FIXME: This try except train should be replaced with a better solution lol
try:
regex: str = r"2040-01-18 00:00 \(in \d+ days, \d+ hours, \d+ minutes\)"
assert re.match(regex, trigger_text)
except AssertionError:
try:
regex2: str = r"2040-01-18 00:00 \(in \d+ days, \d+ minutes\)"
assert re.match(regex2, trigger_text)
except AssertionError:
regex3: str = r"2040-01-18 00:00 \(in \d+ days\, \d+ minutes\)"
assert re.match(regex3, trigger_text)
# Check if type is Page
assert isinstance(page, Page)
def test_pause_job(self) -> None: # noqa: ANN101
assert _pause_job(self.interval_job, self.scheduler) == f"Job {self.interval_job.id} paused."
assert _pause_job(self.cron_job, self.scheduler) == f"Job {self.cron_job.id} paused."
assert _pause_job(self.normal_job, self.scheduler) == f"Job {self.normal_job.id} paused."
def test_unpause_job(self) -> None: # noqa: ANN101
assert _unpause_job(self.interval_job, self.scheduler) == f"Job {self.interval_job.id} unpaused."
assert _unpause_job(self.cron_job, self.scheduler) == f"Job {self.cron_job.id} unpaused."
assert _unpause_job(self.normal_job, self.scheduler) == f"Job {self.normal_job.id} unpaused."

68
tests/test_parse.py Normal file
View File

@ -0,0 +1,68 @@
from datetime import datetime
import tzlocal
from discord_reminder_bot.parse import ParsedTime, parse_time
def test_parse_time() -> None:
"""Test the parse_time function."""
parsed_time: ParsedTime = parse_time("18 January 2040")
assert parsed_time.err is False
assert not parsed_time.err_msg
assert parsed_time.date_to_parse == "18 January 2040"
assert parsed_time.parsed_time
assert parsed_time.parsed_time.strftime("%Y-%m-%d %H:%M:%S") == "2040-01-18 00:00:00"
parsed_time: ParsedTime = parse_time("18 January 2040 12:00")
assert parsed_time.err is False
assert not parsed_time.err_msg
assert parsed_time.date_to_parse == "18 January 2040 12:00"
assert parsed_time.parsed_time
assert parsed_time.parsed_time.strftime("%Y-%m-%d %H:%M:%S") == "2040-01-18 12:00:00"
parsed_time: ParsedTime = parse_time("18 January 2040 12:00:00")
assert parsed_time.err is False
assert not parsed_time.err_msg
assert parsed_time.date_to_parse == "18 January 2040 12:00:00"
assert parsed_time.parsed_time
assert parsed_time.parsed_time.strftime("%Y-%m-%d %H:%M:%S") == "2040-01-18 12:00:00"
parsed_time: ParsedTime = parse_time("18 January 2040 12:00:00 UTC")
assert parsed_time.err is False
assert not parsed_time.err_msg
assert parsed_time.date_to_parse == "18 January 2040 12:00:00 UTC"
assert parsed_time.parsed_time
assert parsed_time.parsed_time.strftime("%Y-%m-%d %H:%M:%S") == "2040-01-18 13:00:00"
parsed_time: ParsedTime = parse_time("18 January 2040 12:00:00 Europe/Stockholm")
assert parsed_time.err is True
assert parsed_time.err_msg == "Could not parse the date."
assert parsed_time.date_to_parse == "18 January 2040 12:00:00 Europe/Stockholm"
assert parsed_time.parsed_time is None
def test_ParsedTime() -> None: # noqa: N802
"""Test the ParsedTime class."""
parsed_time: ParsedTime = ParsedTime(
err=False,
err_msg="",
date_to_parse="18 January 2040",
parsed_time=datetime(2040, 1, 18, 0, 0, 0, tzinfo=tzlocal.get_localzone()),
)
assert parsed_time.err is False
assert not parsed_time.err_msg
assert parsed_time.date_to_parse == "18 January 2040"
assert parsed_time.parsed_time
assert parsed_time.parsed_time.strftime("%Y-%m-%d %H:%M:%S") == "2040-01-18 00:00:00"
parsed_time: ParsedTime = ParsedTime(
err=True,
err_msg="Could not parse the date.",
date_to_parse="18 January 2040 12:00:00 Europe/Stockholm",
parsed_time=None,
)
assert parsed_time.err is True
assert parsed_time.err_msg == "Could not parse the date."
assert parsed_time.date_to_parse == "18 January 2040 12:00:00 Europe/Stockholm"
assert parsed_time.parsed_time is None