From 40569cb91c0a73228b237e41400a58fff13aa97f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Hells=C3=A9n?= Date: Mon, 20 Mar 2023 02:18:04 +0100 Subject: [PATCH] Refactor create_pages.py and write tests --- discord_reminder_bot/create_pages.py | 378 +++++++++++++++++---------- tests/test_create_pages.py | 192 ++++++++++++++ tests/test_parse.py | 68 +++++ 3 files changed, 499 insertions(+), 139 deletions(-) create mode 100644 tests/test_create_pages.py create mode 100644 tests/test_parse.py diff --git a/discord_reminder_bot/create_pages.py b/discord_reminder_bot/create_pages.py index 8a0d1d6..447e9fb 100644 --- a/discord_reminder_bot/create_pages.py +++ b/discord_reminder_bot/create_pages.py @@ -1,8 +1,22 @@ +from collections.abc import Generator from typing import TYPE_CHECKING, Literal import interactions +from apscheduler.job import Job +from apscheduler.schedulers.base import BaseScheduler from apscheduler.triggers.date import DateTrigger -from interactions import ActionRow, Button, CommandContext, ComponentContext, Embed, Message, Modal, TextInput +from interactions import ( + ActionRow, + Button, + ButtonStyle, + Channel, + CommandContext, + ComponentContext, + Embed, + Message, + Modal, + TextInput, +) from interactions.ext.paginator import Page, Paginator, RowPosition from discord_reminder_bot.countdown import calculate @@ -11,132 +25,200 @@ from discord_reminder_bot.settings import scheduler if TYPE_CHECKING: from datetime import datetime - from apscheduler.job import Job max_message_length: Literal[1010] = 1010 max_title_length: Literal[90] = 90 -async def create_pages(ctx: CommandContext) -> list[Page]: - """Create pages for the paginator. +def _get_trigger_text(job: Job) -> str: + """Get trigger time from a reminder and calculate how many days, hours and minutes till trigger. Args: - ctx: The context of the command. + job: The job. Can be cron, interval or normal. Returns: - list[Page]: A list of pages. + str: The trigger time and countdown till trigger. If the job is paused, it will return "_Paused_". """ - pages: list[Page] = [] - - jobs: list[Job] = scheduler.get_jobs() - for job in jobs: - channel_id: int = job.kwargs.get("channel_id") - guild_id: int = job.kwargs.get("guild_id") - - if ctx.guild is None: - await ctx.send("I can't find the server you're in. Are you sure you're in a server?", ephemeral=True) - return pages - if ctx.guild.channels is None: - await ctx.send("I can't find the channel you're in.", ephemeral=True) - return pages - - # Only add reminders from channels in the server we run "/reminder list" in - # Check if channel is in the Discord server, if not, skip it. - for channel in ctx.guild.channels: - if int(channel.id) == channel_id or ctx.guild_id == guild_id: - trigger_time: datetime | None = ( - job.trigger.run_date if type(job.trigger) is DateTrigger else job.next_run_time - ) - - # Paused reminders returns None - if trigger_time is None: - trigger_value: str | None = None - trigger_text: str = "Paused" - else: - trigger_value = f'{trigger_time.strftime("%Y-%m-%d %H:%M")} (in {calculate(job)})' - trigger_text = trigger_value - - message: str = job.kwargs.get("message") - message = f"{message[:1000]}..." if len(message) > max_message_length else message - - edit_button: Button = interactions.Button( - label="Edit", - style=interactions.ButtonStyle.PRIMARY, - custom_id="edit", - ) - pause_button: Button = interactions.Button( - label="Pause", - style=interactions.ButtonStyle.PRIMARY, - custom_id="pause", - ) - unpause_button: Button = interactions.Button( - label="Unpause", - style=interactions.ButtonStyle.PRIMARY, - custom_id="unpause", - ) - remove_button: Button = interactions.Button( - label="Remove", - style=interactions.ButtonStyle.DANGER, - custom_id="remove", - ) - - embed: Embed = interactions.Embed( - title=f"{job.id}", - fields=[ - interactions.EmbedField( - name="**Channel:**", - value=f"#{channel.name}", - ), - interactions.EmbedField( - name="**Message:**", - value=f"{message}", - ), - ], - ) - - if trigger_value is not None: - embed.add_field( - name="**Trigger:**", - value=f"{trigger_text}", - ) - else: - embed.add_field( - name="**Trigger:**", - value="_Paused_", - ) - - components: list[Button] = [ - edit_button, - remove_button, - ] - - if type(job.trigger) is not DateTrigger: - # Get trigger time for cron and interval jobs - trigger_time = job.next_run_time - pause_or_unpause_button: Button = unpause_button if trigger_time is None else pause_button - components.insert(1, pause_or_unpause_button) - - # Add a page to pages list - title: str = f"{message[:87]}..." if len(message) > max_title_length else message - pages.append( - Page( - embeds=embed, - title=title, - components=ActionRow(components=components), # type: ignore # noqa: PGH003 - callback=callback, - position=RowPosition.BOTTOM, - ), - ) - - return pages + # TODO: Add support for cron jobs and interval jobs + trigger_time: datetime | None = job.trigger.run_date if type(job.trigger) is DateTrigger else job.next_run_time + return "_Paused_" if trigger_time is None else f'{trigger_time.strftime("%Y-%m-%d %H:%M")} (in {calculate(job)})' -async def callback(self: Paginator, ctx: ComponentContext) -> Message | None: # noqa: PLR0911 +def _make_button(label: str, style: ButtonStyle) -> Button: + """Make a button. + + Args: + label: The label of the button. + style: The style of the button. + + Returns: + Button: The button. + """ + return interactions.Button( + label=label, + style=style, + custom_id=label.lower(), + ) + + +def _get_pause_or_unpause_button(job: Job) -> Button | None: + """Get pause or unpause button. + + If the job is paused, it will return the unpause button. + If the job is not paused, it will return the pause button. + If the job is not a cron or interval job, it will return None. + + Args: + job: The job. Can be cron, interval or normal. + + Returns: + Button | None: The pause or unpause button. If the job is not a cron or interval job, it will return None. + """ + if type(job.trigger) is not DateTrigger: + pause_button: Button = _make_button("Pause", interactions.ButtonStyle.PRIMARY) + unpause_button: Button = _make_button("Unpause", interactions.ButtonStyle.PRIMARY) + + if not hasattr(job, "next_run_time"): + return pause_button + + return unpause_button if job.next_run_time is None else pause_button + + return None + + +def _get_row_of_buttons(job: Job) -> ActionRow: + """Get components(buttons) for a page in /reminder list. + + These buttons are below the embed. + + Args: + job: The job. Can be cron, interval or normal. + + Returns: + ActionRow: A row of buttons. + """ + components: list[Button] = [ + _make_button("Edit", interactions.ButtonStyle.PRIMARY), + _make_button("Remove", interactions.ButtonStyle.DANGER), + ] + + # Add pause/unpause button as the second button if it's a cron or interval job + pause_or_unpause_button: Button | None = _get_pause_or_unpause_button(job=job) + if pause_or_unpause_button is not None: + components.insert(1, pause_or_unpause_button) + + # TODO: Should fix the type error + return ActionRow(components=components) # type: ignore # noqa: PGH003 + + +def _get_pages(job: Job, channel: Channel, ctx: CommandContext) -> Generator[Page, None, None]: + """Get pages for a reminder. + + Args: + job: The job. Can be cron, interval or normal. + channel: Check if the job kwargs channel ID is the same as the channel ID we looped through. + ctx: The context. Used to get the guild ID. + + Yields: + Generator[Page, None, None]: A page. + """ + # Get channel ID and guild ID from job kwargs + channel_id: int = job.kwargs.get("channel_id") + guild_id: int = job.kwargs.get("guild_id") + + if int(channel.id) == channel_id or ctx.guild_id == guild_id: + message: str = job.kwargs.get("message") + + # If message is longer than 1000 characters, truncate it + message = f"{message[:1000]}..." if len(message) > max_message_length else message + + # Create embed for the singular page + embed: Embed = interactions.Embed( + title=f"{job.id}", # Example: 593dcc18aab748faa571017454669eae + fields=[ + interactions.EmbedField( + name="**Channel:**", + value=f"#{channel.name}", # Example: #general + ), + interactions.EmbedField( + name="**Message:**", + value=f"{message}", # Example: Don't forget to feed the cat! + ), + interactions.EmbedField( + name="**Trigger:**", + value=_get_trigger_text(job=job), # Example: 2023-08-24 00:06 (in 157 days, 23 hours, 49 minutes) + ), + ], + ) + + # Truncate title if it's longer than 90 characters + # This is the text that shows up in the dropdown menu + # Example: 2: Don't forget to feed the cat! + dropdown_title: str = f"{message[:87]}..." if len(message) > max_title_length else message + + # Create a page and return it + yield Page( + embeds=embed, + title=dropdown_title, + components=_get_row_of_buttons(job), + callback=_callback, + position=RowPosition.BOTTOM, + ) + + +def _remove_job(job: Job) -> str: + """Remove a job. + + Args: + job: The job to remove. + + Returns: + str: The message to send to Discord. + """ + # TODO: Check if job exists before removing it? + # TODO: Add button to undo the removal? + channel_id: int = job.kwargs.get("channel_id") + old_message: str = job.kwargs.get("message") + trigger_time: datetime | None = job.trigger.run_date + scheduler.remove_job(job.id) + + return f"Job {job.id} removed.\n**Message:** {old_message}\n**Channel:** {channel_id}\n**Time:** {trigger_time}" + + +def _unpause_job(job: Job, custom_scheduler: BaseScheduler = scheduler) -> str: + """Unpause a job. + + Args: + job: The job to unpause. + custom_scheduler: The scheduler to use. Defaults to the global scheduler. + + Returns: + str: The message to send to Discord. + """ + # TODO: Should we check if the job is paused before unpause it? + custom_scheduler.resume_job(job.id) + return f"Job {job.id} unpaused." + + +def _pause_job(job: Job, custom_scheduler: BaseScheduler = scheduler) -> str: + """Pause a job. + + Args: + job: The job to pause. + custom_scheduler: The scheduler to use. Defaults to the global scheduler. + + Returns: + str: The message to send to Discord. + """ + # TODO: Should we check if the job is unpaused before unpause it? + custom_scheduler.pause_job(job.id) + return f"Job {job.id} paused." + + +async def _callback(self: Paginator, ctx: ComponentContext) -> Message | None: """Callback for the paginator.""" - if self.component_ctx is None: - return await ctx.send("Something went wrong.", ephemeral=True) - - if self.component_ctx.message is None: + # TODO: Create a test for this + if self.component_ctx is None or self.component_ctx.message is None: return await ctx.send("Something went wrong.", ephemeral=True) job_id: str | None = self.component_ctx.message.embeds[0].title @@ -145,7 +227,7 @@ async def callback(self: Paginator, ctx: ComponentContext) -> Message | None: # if job is None: return await ctx.send("Job not found.", ephemeral=True) - channel_id: int = job.kwargs.get("channel_id") + job.kwargs.get("channel_id") old_message: str = job.kwargs.get("message") components: list[TextInput] = [ @@ -158,6 +240,7 @@ async def callback(self: Paginator, ctx: ComponentContext) -> Message | None: # ), ] + job_type = "cron/interval" if type(job.trigger) is DateTrigger: # Get trigger time for normal reminders trigger_time: datetime | None = job.trigger.run_date @@ -172,37 +255,54 @@ async def callback(self: Paginator, ctx: ComponentContext) -> Message | None: # ), ) - else: - # Get trigger time for cron and interval jobs - trigger_time = job.next_run_time - job_type = "cron/interval" - + # Check what button was clicked and call the correct function + msg = "Something went wrong. I don't know what you clicked." if ctx.custom_id == "edit": + # TODO: Add buttons to increase/decrease hour modal: Modal = interactions.Modal( title=f"Edit {job_type} reminder.", custom_id="edit_modal", components=components, # type: ignore # noqa: PGH003 ) await ctx.popup(modal) - return None + msg = f"You modified {job_id}" + elif ctx.custom_id == "pause": + msg: str = _pause_job(job) + elif ctx.custom_id == "unpause": + msg: str = _unpause_job(job) + elif ctx.custom_id == "remove": + msg: str = _remove_job(job) - if ctx.custom_id == "pause": - scheduler.pause_job(job_id) - await ctx.send(f"Job {job_id} paused.") - return None + return await ctx.send(msg, ephemeral=True) - if ctx.custom_id == "unpause": - scheduler.resume_job(job_id) - await ctx.send(f"Job {job_id} unpaused.") - return None - if ctx.custom_id == "remove": - scheduler.remove_job(job_id) - await ctx.send( - f"Job {job_id} removed.\n" - f"**Message:** {old_message}\n" - f"**Channel:** {channel_id}\n" - f"**Time:** {trigger_time}", - ) - return None - return None +async def create_pages(ctx: CommandContext) -> list[Page]: + """Create pages for the paginator. + + Args: + ctx: The context of the command. + + Returns: + list[Page]: A list of pages. + """ + # TODO: Add tests for this + pages: list[Page] = [] + + jobs: list[Job] = scheduler.get_jobs() + for job in jobs: + # Check if we're in a server + if ctx.guild is None: + await ctx.send("I can't find the server you're in. Are you sure you're in a server?", ephemeral=True) + return [] + + # Check if we're in a channel + if ctx.guild.channels is None: + await ctx.send("I can't find the channel you're in.", ephemeral=True) + return [] + + # Only add reminders from channels in the server we run "/reminder list" in + # Check if channel is in the Discord server, if not, skip it. + for channel in ctx.guild.channels: + # Add a page for each reminder + pages.extend(iter(_get_pages(job=job, channel=channel, ctx=ctx))) + return pages diff --git a/tests/test_create_pages.py b/tests/test_create_pages.py new file mode 100644 index 0000000..2374637 --- /dev/null +++ b/tests/test_create_pages.py @@ -0,0 +1,192 @@ +import re +from datetime import datetime +from typing import TYPE_CHECKING + +import dateparser +import interactions +import pytz +from apscheduler.job import Job +from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from interactions.ext.paginator import Page + +from discord_reminder_bot.create_pages import ( + _get_pages, + _get_pause_or_unpause_button, + _get_row_of_buttons, + _get_trigger_text, + _make_button, + _pause_job, + _unpause_job, +) +from discord_reminder_bot.main import send_to_discord + +if TYPE_CHECKING: + from collections.abc import Generator + + +def _test_pause_unpause_button(job: Job, button_label: str) -> None: + button2: interactions.Button | None = _get_pause_or_unpause_button(job) + assert button2 + assert button2.label == button_label + assert button2.style == interactions.ButtonStyle.PRIMARY + assert button2.type == interactions.ComponentType.BUTTON + assert button2.emoji is None + assert button2.custom_id == button_label.lower() + assert button2.url is None + assert button2.disabled is None + + +class TestCountdown: + jobstores: dict[str, SQLAlchemyJobStore] = {"default": SQLAlchemyJobStore(url="sqlite:///:memory")} + job_defaults: dict[str, bool] = {"coalesce": True} + scheduler = AsyncIOScheduler( + jobstores=jobstores, + timezone=pytz.timezone("Europe/Stockholm"), + job_defaults=job_defaults, + ) + + parsed_date: datetime | None = dateparser.parse( + "18 January 2040", + settings={ + "PREFER_DATES_FROM": "future", + "TO_TIMEZONE": "Europe/Stockholm", + }, + ) + assert parsed_date + + run_date: str = parsed_date.strftime("%Y-%m-%d %H:%M:%S") + normal_job: Job = scheduler.add_job( + send_to_discord, + run_date=run_date, + kwargs={ + "channel_id": 865712621109772329, + "message": "Running PyTest", + "author_id": 126462229892694018, + }, + ) + + cron_job: Job = scheduler.add_job( + send_to_discord, + "cron", + minute="0", + kwargs={ + "channel_id": 865712621109772329, + "message": "Running PyTest", + "author_id": 126462229892694018, + }, + ) + + interval_job: Job = scheduler.add_job( + send_to_discord, + "interval", + minutes=1, + kwargs={ + "channel_id": 865712621109772329, + "message": "Running PyTest", + "author_id": 126462229892694018, + }, + ) + + def test_get_trigger_text(self) -> None: # noqa: ANN101 + # FIXME: This try except train should be replaced with a better solution lol + trigger_text: str = _get_trigger_text(self.normal_job) + try: + regex: str = r"2040-01-18 00:00 \(in \d+ days, \d+ hours, \d+ minutes\)" + assert re.match(regex, trigger_text) + except AssertionError: + try: + regex2: str = r"2040-01-18 00:00 \(in \d+ days, \d+ minutes\)" + assert re.match(regex2, trigger_text) + except AssertionError: + regex3: str = r"2040-01-18 00:00 \(in \d+ days\, \d+ minutes\)" + assert re.match(regex3, trigger_text) + + def test_make_button(self) -> None: # noqa: ANN101 + button_name: str = "Test" + + button: interactions.Button = _make_button(label=button_name, style=interactions.ButtonStyle.PRIMARY) + assert button.label == button_name + assert button.style == interactions.ButtonStyle.PRIMARY + assert button.custom_id == button_name.lower() + assert button.disabled is None + assert button.emoji is None + + def test_get_pause_or_unpause_button(self) -> None: # noqa: ANN101 + button: interactions.Button | None = _get_pause_or_unpause_button(self.normal_job) + assert button is None + + _test_pause_unpause_button(self.cron_job, "Pause") + self.cron_job.pause() + + _test_pause_unpause_button(self.cron_job, "Unpause") + self.cron_job.resume() + + _test_pause_unpause_button(self.interval_job, "Pause") + self.interval_job.pause() + + _test_pause_unpause_button(self.interval_job, "Unpause") + self.interval_job.resume() + + def test_get_row_of_buttons(self) -> None: # noqa: ANN101 + row: interactions.ActionRow = _get_row_of_buttons(self.normal_job) + assert row + assert row.components + + # A normal job should have 2 buttons, edit and delete + assert len(row.components) == 2 # noqa: PLR2004 + + row2: interactions.ActionRow = _get_row_of_buttons(self.cron_job) + assert row2 + assert row2.components + + # A cron job should have 3 buttons, edit, delete and pause/unpause + assert len(row2.components) == 3 # noqa: PLR2004 + + # A cron job should have 3 buttons, edit, delete and pause/unpause + assert len(row2.components) == 3 # noqa: PLR2004 + + def test_get_pages(self) -> None: # noqa: ANN101 + ctx = None # TODO: We should check ctx as well and not only channel id + channel: interactions.Channel = interactions.Channel(id=interactions.Snowflake(865712621109772329)) + + pages: Generator[Page, None, None] = _get_pages(job=self.normal_job, channel=channel, ctx=ctx) # type: ignore # noqa: PGH003, E501 + assert pages + + for page in pages: + assert page + assert page.title == "Running PyTest" + assert page.components + assert page.embeds + assert page.embeds.fields is not None # type: ignore # noqa: PGH003 + assert page.embeds.fields[0].name == "**Channel:**" # type: ignore # noqa: PGH003 + assert page.embeds.fields[0].value == "#" # type: ignore # noqa: PGH003 + assert page.embeds.fields[1].name == "**Message:**" # type: ignore # noqa: PGH003 + assert page.embeds.fields[1].value == "Running PyTest" # type: ignore # noqa: PGH003 + assert page.embeds.fields[2].name == "**Trigger:**" # type: ignore # noqa: PGH003 + trigger_text: str = page.embeds.fields[2].value # type: ignore # noqa: PGH003 + + # FIXME: This try except train should be replaced with a better solution lol + try: + regex: str = r"2040-01-18 00:00 \(in \d+ days, \d+ hours, \d+ minutes\)" + assert re.match(regex, trigger_text) + except AssertionError: + try: + regex2: str = r"2040-01-18 00:00 \(in \d+ days, \d+ minutes\)" + assert re.match(regex2, trigger_text) + except AssertionError: + regex3: str = r"2040-01-18 00:00 \(in \d+ days\, \d+ minutes\)" + assert re.match(regex3, trigger_text) + + # Check if type is Page + assert isinstance(page, Page) + + def test_pause_job(self) -> None: # noqa: ANN101 + assert _pause_job(self.interval_job, self.scheduler) == f"Job {self.interval_job.id} paused." + assert _pause_job(self.cron_job, self.scheduler) == f"Job {self.cron_job.id} paused." + assert _pause_job(self.normal_job, self.scheduler) == f"Job {self.normal_job.id} paused." + + def test_unpause_job(self) -> None: # noqa: ANN101 + assert _unpause_job(self.interval_job, self.scheduler) == f"Job {self.interval_job.id} unpaused." + assert _unpause_job(self.cron_job, self.scheduler) == f"Job {self.cron_job.id} unpaused." + assert _unpause_job(self.normal_job, self.scheduler) == f"Job {self.normal_job.id} unpaused." diff --git a/tests/test_parse.py b/tests/test_parse.py new file mode 100644 index 0000000..91d7bd4 --- /dev/null +++ b/tests/test_parse.py @@ -0,0 +1,68 @@ +from datetime import datetime + +import tzlocal + +from discord_reminder_bot.parse import ParsedTime, parse_time + + +def test_parse_time() -> None: + """Test the parse_time function.""" + parsed_time: ParsedTime = parse_time("18 January 2040") + assert parsed_time.err is False + assert not parsed_time.err_msg + assert parsed_time.date_to_parse == "18 January 2040" + assert parsed_time.parsed_time + assert parsed_time.parsed_time.strftime("%Y-%m-%d %H:%M:%S") == "2040-01-18 00:00:00" + + parsed_time: ParsedTime = parse_time("18 January 2040 12:00") + assert parsed_time.err is False + assert not parsed_time.err_msg + assert parsed_time.date_to_parse == "18 January 2040 12:00" + assert parsed_time.parsed_time + assert parsed_time.parsed_time.strftime("%Y-%m-%d %H:%M:%S") == "2040-01-18 12:00:00" + + parsed_time: ParsedTime = parse_time("18 January 2040 12:00:00") + assert parsed_time.err is False + assert not parsed_time.err_msg + assert parsed_time.date_to_parse == "18 January 2040 12:00:00" + assert parsed_time.parsed_time + assert parsed_time.parsed_time.strftime("%Y-%m-%d %H:%M:%S") == "2040-01-18 12:00:00" + + parsed_time: ParsedTime = parse_time("18 January 2040 12:00:00 UTC") + assert parsed_time.err is False + assert not parsed_time.err_msg + assert parsed_time.date_to_parse == "18 January 2040 12:00:00 UTC" + assert parsed_time.parsed_time + assert parsed_time.parsed_time.strftime("%Y-%m-%d %H:%M:%S") == "2040-01-18 13:00:00" + + parsed_time: ParsedTime = parse_time("18 January 2040 12:00:00 Europe/Stockholm") + assert parsed_time.err is True + assert parsed_time.err_msg == "Could not parse the date." + assert parsed_time.date_to_parse == "18 January 2040 12:00:00 Europe/Stockholm" + assert parsed_time.parsed_time is None + + +def test_ParsedTime() -> None: # noqa: N802 + """Test the ParsedTime class.""" + parsed_time: ParsedTime = ParsedTime( + err=False, + err_msg="", + date_to_parse="18 January 2040", + parsed_time=datetime(2040, 1, 18, 0, 0, 0, tzinfo=tzlocal.get_localzone()), + ) + assert parsed_time.err is False + assert not parsed_time.err_msg + assert parsed_time.date_to_parse == "18 January 2040" + assert parsed_time.parsed_time + assert parsed_time.parsed_time.strftime("%Y-%m-%d %H:%M:%S") == "2040-01-18 00:00:00" + + parsed_time: ParsedTime = ParsedTime( + err=True, + err_msg="Could not parse the date.", + date_to_parse="18 January 2040 12:00:00 Europe/Stockholm", + parsed_time=None, + ) + assert parsed_time.err is True + assert parsed_time.err_msg == "Could not parse the date." + assert parsed_time.date_to_parse == "18 January 2040 12:00:00 Europe/Stockholm" + assert parsed_time.parsed_time is None