diff --git a/bot/bot.py b/bot/bot.py index 5375d8f..1b846ee 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -115,6 +115,7 @@ async def wrapper(message: Message, context: CallbackContext, question: str) -> if ( not filters.is_known_user(username) and user.message_counter.value >= config.conversation.message_limit.count > 0 + and not user.message_counter.is_expired() ): # this is a group user and they have exceeded the message limit wait_for = models.format_timedelta(user.message_counter.expires_after()) diff --git a/tests/test_commands.py b/tests/test_commands.py index 0ca6108..74c4169 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -350,7 +350,7 @@ async def test_no_mention(self): self.assertEqual(self.bot.text, "") -class UserMessageLimitTest(unittest.IsolatedAsyncioTestCase, Helper): +class MessageLimitTest(unittest.IsolatedAsyncioTestCase, Helper): def setUp(self): self.ai = FakeGPT() askers.TextAsker.model = self.ai @@ -388,6 +388,22 @@ async def test_unknown_user(self): await self.command(update, self.context) self.assertTrue(self.bot.text.startswith("Please wait")) + async def test_expired(self): + config.conversation.message_limit.count = 3 + + user = User(id=2, first_name="Bob", is_bot=False, username="bob") + # the counter has reached the limit, but the value has expired + user_data = { + "message_counter": {"value": 3, "timestamp": dt.datetime.now() - dt.timedelta(hours=1)} + } + self.application.user_data[user.id] = user_data + context = CallbackContext(self.application, chat_id=1, user_id=user.id) + + update = self._create_update(11, text="What is your name?", user=user) + await self.command(update, context) + self.assertEqual(self.bot.text, "What is your name?") + self.assertEqual(user_data["message_counter"]["value"], 1) + async def test_unlimited(self): config.conversation.message_limit.count = 0 other_user = User(id=2, first_name="Bob", is_bot=False, username="bob")