diff --git a/bot/telegram_bot.py b/bot/telegram_bot.py index a8907345..0421ca2e 100644 --- a/bot/telegram_bot.py +++ b/bot/telegram_bot.py @@ -1,48 +1,30 @@ from __future__ import annotations + +import asyncio import logging import os -import itertools -import asyncio -import telegram from uuid import uuid4 -from telegram import constants, BotCommandScopeAllGroupChats +from telegram import BotCommandScopeAllGroupChats, Update, constants from telegram import InlineKeyboardMarkup, InlineKeyboardButton, InlineQueryResultArticle -from telegram import Message, MessageEntity, Update, InputTextMessageContent, BotCommand, ChatMember +from telegram import InputTextMessageContent, BotCommand from telegram.error import RetryAfter, TimedOut -from telegram.ext import ApplicationBuilder, ContextTypes, CommandHandler, MessageHandler, \ - filters, InlineQueryHandler, CallbackQueryHandler, Application, CallbackContext +from telegram.ext import ApplicationBuilder, CommandHandler, MessageHandler, \ + filters, InlineQueryHandler, CallbackQueryHandler, Application, ContextTypes, CallbackContext from pydub import AudioSegment + +from utils import is_group_chat, get_thread_id, message_text, wrap_with_indicator, split_into_chunks, \ + edit_message_with_retry, get_stream_cutoff_values, is_allowed, get_remaining_budget, is_admin, is_within_budget, \ + get_reply_to_message_id, add_chat_request_to_usage_tracker, error_handler from openai_helper import OpenAIHelper, localized_text from usage_tracker import UsageTracker -def message_text(message: Message) -> str: - """ - Returns the text of a message, excluding any bot commands. - """ - message_txt = message.text - if message_txt is None: - return '' - - for _, text in sorted(message.parse_entities([MessageEntity.BOT_COMMAND]).items(), - key=(lambda item: item[0].offset)): - message_txt = message_txt.replace(text, '').strip() - - return message_txt if len(message_txt) > 0 else '' - - class ChatGPTTelegramBot: """ Class representing a ChatGPT Telegram Bot. """ - # Mapping of budget period to cost period - budget_cost_map = { - "monthly": "cost_month", - "daily": "cost_today", - "all-time": "cost_all_time" - } def __init__(self, config: dict, openai: OpenAIHelper): """ @@ -60,10 +42,9 @@ def __init__(self, config: dict, openai: OpenAIHelper): BotCommand(command='stats', description=localized_text('stats_description', bot_language)), BotCommand(command='resend', description=localized_text('resend_description', bot_language)) ] - self.group_commands = [ - BotCommand(command='chat', - description=localized_text('chat_description', bot_language)) - ] + self.commands + self.group_commands = [BotCommand( + command='chat', description=localized_text('chat_description', bot_language) + )] + self.commands self.disallowed_message = localized_text('disallowed', bot_language) self.budget_limit_message = localized_text('budget_limit', bot_language) self.usage = {} @@ -74,7 +55,7 @@ async def help(self, update: Update, _: ContextTypes.DEFAULT_TYPE) -> None: """ Shows the help menu. """ - commands = self.group_commands if self.is_group_chat(update) else self.commands + commands = self.group_commands if is_group_chat(update) else self.commands commands_description = [f'/{command.command} - {command.description}' for command in commands] bot_language = self.config['bot_language'] help_text = ( @@ -92,7 +73,7 @@ async def stats(self, update: Update, context: ContextTypes.DEFAULT_TYPE): """ Returns token usage statistics for current day and month. """ - if not await self.is_allowed(update, context): + if not await is_allowed(self.config, update, context): logging.warning(f'User {update.message.from_user.name} (id: {update.message.from_user.id}) ' f'is not allowed to request their usage statistics') await self.send_disallowed_message(update, context) @@ -113,7 +94,7 @@ async def stats(self, update: Update, context: ContextTypes.DEFAULT_TYPE): chat_id = update.effective_chat.id chat_messages, chat_token_length = self.openai.get_conversation_stats(chat_id) - remaining_budget = self.get_remaining_budget(update) + remaining_budget = get_remaining_budget(self.config, self.usage, update) bot_language = self.config['bot_language'] text_current_conversation = ( f"*{localized_text('stats_conversation', bot_language)[0]}*:\n" @@ -148,7 +129,7 @@ async def stats(self, update: Update, context: ContextTypes.DEFAULT_TYPE): f"${remaining_budget:.2f}.\n" ) # add OpenAI account information for admin request - if self.is_admin(user_id): + if is_admin(self.config, user_id): text_budget += ( f"{localized_text('stats_openai', bot_language)}" f"{self.openai.get_billing_current_month():.2f}" @@ -161,7 +142,7 @@ async def resend(self, update: Update, context: ContextTypes.DEFAULT_TYPE): """ Resend the last request """ - if not await self.is_allowed(update, context): + if not await is_allowed(self.config, update, context): logging.warning(f'User {update.message.from_user.name} (id: {update.message.from_user.id})' f' is not allowed to resend the message') await self.send_disallowed_message(update, context) @@ -172,7 +153,7 @@ async def resend(self, update: Update, context: ContextTypes.DEFAULT_TYPE): logging.warning(f'User {update.message.from_user.name} (id: {update.message.from_user.id})' f' does not have anything to resend') await update.effective_message.reply_text( - message_thread_id=self.get_thread_id(update), + message_thread_id=get_thread_id(update), text=localized_text('resend_failed', self.config['bot_language']) ) return @@ -189,7 +170,7 @@ async def reset(self, update: Update, context: ContextTypes.DEFAULT_TYPE): """ Resets the conversation. """ - if not await self.is_allowed(update, context): + if not await is_allowed(self.config, update, context): logging.warning(f'User {update.message.from_user.name} (id: {update.message.from_user.id}) ' f'is not allowed to reset the conversation') await self.send_disallowed_message(update, context) @@ -202,24 +183,23 @@ async def reset(self, update: Update, context: ContextTypes.DEFAULT_TYPE): reset_content = message_text(update.message) self.openai.reset_chat_history(chat_id=chat_id, content=reset_content) await update.effective_message.reply_text( - message_thread_id=self.get_thread_id(update), - text=localized_text('reset_done', self.config['bot_language']) + message_thread_id=get_thread_id(update), + text=localized_text('reset_done', self.config['bot_language']) ) async def image(self, update: Update, context: ContextTypes.DEFAULT_TYPE): """ Generates an image for the given prompt using DALLĀ·E APIs """ - if not self.config['enable_image_generation'] or not await self.check_allowed_and_within_budget(update, - context): + if not self.config['enable_image_generation'] \ + or not await self.check_allowed_and_within_budget(update, context): return - chat_id = update.effective_chat.id image_query = message_text(update.message) if image_query == '': await update.effective_message.reply_text( - message_thread_id=self.get_thread_id(update), - text=localized_text('image_no_prompt', self.config['bot_language']) + message_thread_id=get_thread_id(update), + text=localized_text('image_no_prompt', self.config['bot_language']) ) return @@ -230,7 +210,7 @@ async def _generate(): try: image_url, image_size = await self.openai.generate_image(prompt=image_query) await update.effective_message.reply_photo( - reply_to_message_id=self.get_reply_to_message_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), photo=image_url ) # add image request to users usage tracker @@ -243,13 +223,13 @@ async def _generate(): except Exception as e: logging.exception(e) await update.effective_message.reply_text( - message_thread_id=self.get_thread_id(update), - reply_to_message_id=self.get_reply_to_message_id(update), + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), text=f"{localized_text('image_fail', self.config['bot_language'])}: {str(e)}", parse_mode=constants.ParseMode.MARKDOWN ) - await self.wrap_with_indicator(update, context, _generate, constants.ChatAction.UPLOAD_PHOTO) + await wrap_with_indicator(update, context, _generate, constants.ChatAction.UPLOAD_PHOTO) async def transcribe(self, update: Update, context: ContextTypes.DEFAULT_TYPE): """ @@ -258,7 +238,7 @@ async def transcribe(self, update: Update, context: ContextTypes.DEFAULT_TYPE): if not self.config['enable_transcription'] or not await self.check_allowed_and_within_budget(update, context): return - if self.is_group_chat(update) and self.config['ignore_group_transcriptions']: + if is_group_chat(update) and self.config['ignore_group_transcriptions']: logging.info(f'Transcription coming from group chat, ignoring...') return @@ -274,8 +254,8 @@ async def _execute(): except Exception as e: logging.exception(e) await update.effective_message.reply_text( - message_thread_id=self.get_thread_id(update), - reply_to_message_id=self.get_reply_to_message_id(update), + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), text=( f"{localized_text('media_download_fail', bot_language)[0]}: " f"{str(e)}. {localized_text('media_download_fail', bot_language)[1]}" @@ -284,7 +264,6 @@ async def _execute(): ) return - # detect and extract audio from the attachment with pydub try: audio_track = AudioSegment.from_file(filename) audio_track.export(filename_mp3, format="mp3") @@ -294,8 +273,8 @@ async def _execute(): except Exception as e: logging.exception(e) await update.effective_message.reply_text( - message_thread_id=self.get_thread_id(update), - reply_to_message_id=self.get_reply_to_message_id(update), + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), text=localized_text('media_type_fail', bot_language) ) if os.path.exists(filename): @@ -306,17 +285,12 @@ async def _execute(): if user_id not in self.usage: self.usage[user_id] = UsageTracker(user_id, update.message.from_user.name) - # send decoded audio to openai try: - - # Transcribe the audio file transcript = await self.openai.transcribe(filename_mp3) - # add transcription seconds to usage tracker transcription_price = self.config['transcription_price'] self.usage[user_id].add_transcription_seconds(audio_track.duration_seconds, transcription_price) - # add guest chat request to guest usage tracker allowed_user_ids = self.config['allowed_user_ids'].split(',') if str(user_id) not in allowed_user_ids and 'guests' in self.usage: self.usage["guests"].add_transcription_seconds(audio_track.duration_seconds, transcription_price) @@ -329,12 +303,12 @@ async def _execute(): # Split into chunks of 4096 characters (Telegram's message limit) transcript_output = f"_{localized_text('transcript', bot_language)}:_\n\"{transcript}\"" - chunks = self.split_into_chunks(transcript_output) + chunks = split_into_chunks(transcript_output) for index, transcript_chunk in enumerate(chunks): await update.effective_message.reply_text( - message_thread_id=self.get_thread_id(update), - reply_to_message_id=self.get_reply_to_message_id(update) if index == 0 else None, + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update) if index == 0 else None, text=transcript_chunk, parse_mode=constants.ParseMode.MARKDOWN ) @@ -342,9 +316,7 @@ async def _execute(): # Get the response of the transcript response, total_tokens = await self.openai.get_chat_response(chat_id=chat_id, query=transcript) - # add chat request to users usage tracker self.usage[user_id].add_chat_tokens(total_tokens, self.config['token_price']) - # add guest chat request to guest usage tracker if str(user_id) not in allowed_user_ids and 'guests' in self.usage: self.usage["guests"].add_chat_tokens(total_tokens, self.config['token_price']) @@ -353,12 +325,12 @@ async def _execute(): f"_{localized_text('transcript', bot_language)}:_\n\"{transcript}\"\n\n" f"_{localized_text('answer', bot_language)}:_\n{response}" ) - chunks = self.split_into_chunks(transcript_output) + chunks = split_into_chunks(transcript_output) for index, transcript_chunk in enumerate(chunks): await update.effective_message.reply_text( - message_thread_id=self.get_thread_id(update), - reply_to_message_id=self.get_reply_to_message_id(update) if index == 0 else None, + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update) if index == 0 else None, text=transcript_chunk, parse_mode=constants.ParseMode.MARKDOWN ) @@ -366,19 +338,18 @@ async def _execute(): except Exception as e: logging.exception(e) await update.effective_message.reply_text( - message_thread_id=self.get_thread_id(update), - reply_to_message_id=self.get_reply_to_message_id(update), + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), text=f"{localized_text('transcribe_fail', bot_language)}: {str(e)}", parse_mode=constants.ParseMode.MARKDOWN ) finally: - # Cleanup files if os.path.exists(filename_mp3): os.remove(filename_mp3) if os.path.exists(filename): os.remove(filename) - await self.wrap_with_indicator(update, context, _execute, constants.ChatAction.TYPING) + await wrap_with_indicator(update, context, _execute, constants.ChatAction.TYPING) async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE): """ @@ -397,7 +368,7 @@ async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE): prompt = message_text(update.message) self.last_message[chat_id] = prompt - if self.is_group_chat(update): + if is_group_chat(update): trigger_keyword = self.config['group_trigger_keyword'] if prompt.lower().startswith(trigger_keyword.lower()): prompt = prompt[len(trigger_keyword):].strip() @@ -405,10 +376,7 @@ async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE): if update.message.reply_to_message and \ update.message.reply_to_message.text and \ update.message.reply_to_message.from_user.id != context.bot.id: - prompt = '"{reply}" {prompt}'.format( - reply=update.message.reply_to_message.text, - prompt=prompt - ) + prompt = f'"{update.message.reply_to_message.text}" {prompt}' else: if update.message.reply_to_message and update.message.reply_to_message.from_user.id == context.bot.id: logging.info('Message is a reply to the bot, allowing...') @@ -422,7 +390,7 @@ async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE): if self.config['stream']: await update.effective_message.reply_chat_action( action=constants.ChatAction.TYPING, - message_thread_id=self.get_thread_id(update) + message_thread_id=get_thread_id(update) ) stream_response = self.openai.get_chat_response_stream(chat_id=chat_id, query=prompt) @@ -436,26 +404,26 @@ async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE): if len(content.strip()) == 0: continue - stream_chunks = self.split_into_chunks(content) + stream_chunks = split_into_chunks(content) if len(stream_chunks) > 1: content = stream_chunks[-1] if stream_chunk != len(stream_chunks) - 1: stream_chunk += 1 try: - await self.edit_message_with_retry(context, chat_id, str(sent_message.message_id), - stream_chunks[-2]) + await edit_message_with_retry(context, chat_id, str(sent_message.message_id), + stream_chunks[-2]) except: pass try: sent_message = await update.effective_message.reply_text( - message_thread_id=self.get_thread_id(update), + message_thread_id=get_thread_id(update), text=content if len(content) > 0 else "..." ) except: pass continue - cutoff = self.get_stream_cutoff_values(update, content) + cutoff = get_stream_cutoff_values(update, content) cutoff += backoff if i == 0: @@ -464,8 +432,8 @@ async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE): await context.bot.delete_message(chat_id=sent_message.chat_id, message_id=sent_message.message_id) sent_message = await update.effective_message.reply_text( - message_thread_id=self.get_thread_id(update), - reply_to_message_id=self.get_reply_to_message_id(update), + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), text=content ) except: @@ -476,8 +444,8 @@ async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE): try: use_markdown = tokens != 'not_finished' - await self.edit_message_with_retry(context, chat_id, str(sent_message.message_id), - text=content, markdown=use_markdown) + await edit_message_with_retry(context, chat_id, str(sent_message.message_id), + text=content, markdown=use_markdown) except RetryAfter as e: backoff += 5 @@ -505,35 +473,37 @@ async def _reply(): response, total_tokens = await self.openai.get_chat_response(chat_id=chat_id, query=prompt) # Split into chunks of 4096 characters (Telegram's message limit) - chunks = self.split_into_chunks(response) + chunks = split_into_chunks(response) for index, chunk in enumerate(chunks): try: await update.effective_message.reply_text( - message_thread_id=self.get_thread_id(update), - reply_to_message_id=self.get_reply_to_message_id(update) if index == 0 else None, + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, + update) if index == 0 else None, text=chunk, parse_mode=constants.ParseMode.MARKDOWN ) except Exception: try: await update.effective_message.reply_text( - message_thread_id=self.get_thread_id(update), - reply_to_message_id=self.get_reply_to_message_id(update) if index == 0 else None, + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, + update) if index == 0 else None, text=chunk ) except Exception as exception: raise exception - await self.wrap_with_indicator(update, context, _reply, constants.ChatAction.TYPING) + await wrap_with_indicator(update, context, _reply, constants.ChatAction.TYPING) - self.add_chat_request_to_usage_tracker(user_id, total_tokens) + add_chat_request_to_usage_tracker(self.usage, self.config, user_id, total_tokens) except Exception as e: logging.exception(e) await update.effective_message.reply_text( - message_thread_id=self.get_thread_id(update), - reply_to_message_id=self.get_reply_to_message_id(update), + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), text=f"{localized_text('chat_fail', self.config['bot_language'])} {str(e)}", parse_mode=constants.ParseMode.MARKDOWN ) @@ -556,6 +526,9 @@ async def inline_query(self, update: Update, context: ContextTypes.DEFAULT_TYPE) await self.send_inline_query_result(update, result_id, message_content=query, callback_data=callback_data) async def send_inline_query_result(self, update: Update, result_id, message_content, callback_data=""): + """ + Send inline query result + """ try: reply_markup = None bot_language = self.config['bot_language'] @@ -580,6 +553,9 @@ async def send_inline_query_result(self, update: Update, result_id, message_cont logging.error(f'An error occurred while generating the result card for inline query {e}') async def handle_callback_inline_query(self, update: Update, context: CallbackContext): + """ + Handle the callback query from the inline query result + """ callback_data = update.callback_query.data user_id = update.callback_query.from_user.id inline_message_id = update.callback_query.inline_message_id @@ -604,9 +580,9 @@ async def handle_callback_inline_query(self, update: Update, context: CallbackCo f'{localized_text("error", bot_language)}. ' f'{localized_text("try_again", bot_language)}' ) - await self.edit_message_with_retry(context, chat_id=None, message_id=inline_message_id, - text=f'{query}\n\n_{answer_tr}:_\n{error_message}', - is_inline=True) + await edit_message_with_retry(context, chat_id=None, message_id=inline_message_id, + text=f'{query}\n\n_{answer_tr}:_\n{error_message}', + is_inline=True) return if self.config['stream']: @@ -619,16 +595,16 @@ async def handle_callback_inline_query(self, update: Update, context: CallbackCo if len(content.strip()) == 0: continue - cutoff = self.get_stream_cutoff_values(update, content) + cutoff = get_stream_cutoff_values(update, content) cutoff += backoff if i == 0: try: if sent_message is not None: - await self.edit_message_with_retry(context, chat_id=None, - message_id=inline_message_id, - text=f'{query}\n\n{answer_tr}:\n{content}', - is_inline=True) + await edit_message_with_retry(context, chat_id=None, + message_id=inline_message_id, + text=f'{query}\n\n{answer_tr}:\n{content}', + is_inline=True) except: continue @@ -642,8 +618,8 @@ async def handle_callback_inline_query(self, update: Update, context: CallbackCo # We only want to send the first 4096 characters. No chunking allowed in inline mode. text = text[:4096] - await self.edit_message_with_retry(context, chat_id=None, message_id=inline_message_id, - text=text, markdown=use_markdown, is_inline=True) + await edit_message_with_retry(context, chat_id=None, message_id=inline_message_id, + text=text, markdown=use_markdown, is_inline=True) except RetryAfter as e: backoff += 5 @@ -680,83 +656,52 @@ async def _send_inline_query_response(): text_content = text_content[:4096] # Edit the original message with the generated content - await self.edit_message_with_retry(context, chat_id=None, message_id=inline_message_id, - text=text_content, is_inline=True) + await edit_message_with_retry(context, chat_id=None, message_id=inline_message_id, + text=text_content, is_inline=True) - await self.wrap_with_indicator(update, context, _send_inline_query_response, - constants.ChatAction.TYPING, is_inline=True) + await wrap_with_indicator(update, context, _send_inline_query_response, + constants.ChatAction.TYPING, is_inline=True) - self.add_chat_request_to_usage_tracker(user_id, total_tokens) + add_chat_request_to_usage_tracker(self.usage, self.config, user_id, total_tokens) except Exception as e: logging.error(f'Failed to respond to an inline query via button callback: {e}') logging.exception(e) localized_answer = localized_text('chat_fail', self.config['bot_language']) - await self.edit_message_with_retry(context, chat_id=None, message_id=inline_message_id, - text=f"{query}\n\n_{answer_tr}:_\n{localized_answer} {str(e)}", - is_inline=True) + await edit_message_with_retry(context, chat_id=None, message_id=inline_message_id, + text=f"{query}\n\n_{answer_tr}:_\n{localized_answer} {str(e)}", + is_inline=True) - async def edit_message_with_retry(self, context: ContextTypes.DEFAULT_TYPE, chat_id: int | None, - message_id: str, text: str, markdown: bool = True, is_inline: bool = False): + async def check_allowed_and_within_budget(self, update: Update, context: ContextTypes.DEFAULT_TYPE, + is_inline=False) -> bool: """ - Edit a message with retry logic in case of failure (e.g. broken markdown) - :param context: The context to use - :param chat_id: The chat id to edit the message in - :param message_id: The message id to edit - :param text: The text to edit the message with - :param markdown: Whether to use markdown parse mode - :param is_inline: Whether the message to edit is an inline message - :return: None + Checks if the user is allowed to use the bot and if they are within their budget + :param update: Telegram update object + :param context: Telegram context object + :param is_inline: Boolean flag for inline queries + :return: Boolean indicating if the user is allowed to use the bot """ - try: - await context.bot.edit_message_text( - chat_id=chat_id, - message_id=int(message_id) if not is_inline else None, - inline_message_id=message_id if is_inline else None, - text=text, - parse_mode=constants.ParseMode.MARKDOWN if markdown else None - ) - except telegram.error.BadRequest as e: - if str(e).startswith("Message is not modified"): - return - try: - await context.bot.edit_message_text( - chat_id=chat_id, - message_id=int(message_id) if not is_inline else None, - inline_message_id=message_id if is_inline else None, - text=text - ) - except Exception as e: - logging.warning(f'Failed to edit message: {str(e)}') - raise e + name = update.inline_query.from_user.name if is_inline else update.message.from_user.name + user_id = update.inline_query.from_user.id if is_inline else update.message.from_user.id - except Exception as e: - logging.warning(str(e)) - raise e + if not await is_allowed(self.config, update, context, is_inline=is_inline): + logging.warning(f'User {name} (id: {user_id}) is not allowed to use the bot') + await self.send_disallowed_message(update, context, is_inline) + return False + if not is_within_budget(self.config, self.usage, update, is_inline=is_inline): + logging.warning(f'User {name} (id: {user_id}) reached their usage limit') + await self.send_budget_reached_message(update, context, is_inline) + return False - async def wrap_with_indicator(self, update: Update, context: CallbackContext, coroutine, - chat_action: constants.ChatAction = "", is_inline=False): - """ - Wraps a coroutine while repeatedly sending a chat action to the user. - """ - task = context.application.create_task(coroutine(), update=update) - while not task.done(): - if not is_inline: - context.application.create_task( - update.effective_chat.send_action(chat_action, message_thread_id=self.get_thread_id(update)) - ) - try: - await asyncio.wait_for(asyncio.shield(task), 4.5) - except asyncio.TimeoutError: - pass + return True - async def send_disallowed_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE, is_inline=False): + async def send_disallowed_message(self, update: Update, _: ContextTypes.DEFAULT_TYPE, is_inline=False): """ Sends the disallowed message to the user. """ if not is_inline: await update.effective_message.reply_text( - message_thread_id=self.get_thread_id(update), + message_thread_id=get_thread_id(update), text=self.disallowed_message, disable_web_page_preview=True ) @@ -764,240 +709,19 @@ async def send_disallowed_message(self, update: Update, context: ContextTypes.DE result_id = str(uuid4()) await self.send_inline_query_result(update, result_id, message_content=self.disallowed_message) - async def send_budget_reached_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE, is_inline=False): + async def send_budget_reached_message(self, update: Update, _: ContextTypes.DEFAULT_TYPE, is_inline=False): """ Sends the budget reached message to the user. """ if not is_inline: await update.effective_message.reply_text( - message_thread_id=self.get_thread_id(update), + message_thread_id=get_thread_id(update), text=self.budget_limit_message ) else: result_id = str(uuid4()) await self.send_inline_query_result(update, result_id, message_content=self.budget_limit_message) - async def error_handler(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None: - """ - Handles errors in the telegram-python-bot library. - """ - logging.error(f'Exception while handling an update: {context.error}') - - def get_thread_id(self, update: Update) -> int | None: - """ - Gets the message thread id for the update, if any - """ - if update.effective_message and update.effective_message.is_topic_message: - return update.effective_message.message_thread_id - return None - - def get_stream_cutoff_values(self, update: Update, content: str) -> int: - """ - Gets the stream cutoff values for the message length - """ - if self.is_group_chat(update): - # group chats have stricter flood limits - return 180 if len(content) > 1000 else 120 if len(content) > 200 else 90 if len( - content) > 50 else 50 - else: - return 90 if len(content) > 1000 else 45 if len(content) > 200 else 25 if len( - content) > 50 else 15 - - def is_group_chat(self, update: Update) -> bool: - """ - Checks if the message was sent from a group chat - """ - if not update.effective_chat: - return False - return update.effective_chat.type in [ - constants.ChatType.GROUP, - constants.ChatType.SUPERGROUP - ] - - async def is_user_in_group(self, update: Update, context: CallbackContext, user_id: int) -> bool: - """ - Checks if user_id is a member of the group - """ - try: - chat_member = await context.bot.get_chat_member(update.message.chat_id, user_id) - return chat_member.status in [ChatMember.OWNER, ChatMember.ADMINISTRATOR, ChatMember.MEMBER] - except telegram.error.BadRequest as e: - if str(e) == "User not found": - return False - else: - raise e - except Exception as e: - raise e - - async def is_allowed(self, update: Update, context: CallbackContext, is_inline=False) -> bool: - """ - Checks if the user is allowed to use the bot. - """ - if self.config['allowed_user_ids'] == '*': - return True - - user_id = update.inline_query.from_user.id if is_inline else update.message.from_user.id - if self.is_admin(user_id): - return True - name = update.inline_query.from_user.name if is_inline else update.message.from_user.name - allowed_user_ids = self.config['allowed_user_ids'].split(',') - # Check if user is allowed - if str(user_id) in allowed_user_ids: - return True - # Check if it's a group a chat with at least one authorized member - if not is_inline and self.is_group_chat(update): - admin_user_ids = self.config['admin_user_ids'].split(',') - for user in itertools.chain(allowed_user_ids, admin_user_ids): - if not user.strip(): - continue - if await self.is_user_in_group(update, context, user): - logging.info(f'{user} is a member. Allowing group chat message...') - return True - logging.info(f'Group chat messages from user {name} ' - f'(id: {user_id}) are not allowed') - return False - - def is_admin(self, user_id: int, log_no_admin=False) -> bool: - """ - Checks if the user is the admin of the bot. - The first user in the user list is the admin. - """ - if self.config['admin_user_ids'] == '-': - if log_no_admin: - logging.info('No admin user defined.') - return False - - admin_user_ids = self.config['admin_user_ids'].split(',') - - # Check if user is in the admin user list - if str(user_id) in admin_user_ids: - return True - - return False - - def get_user_budget(self, user_id) -> float | None: - """ - Get the user's budget based on their user ID and the bot configuration. - :param user_id: User id - :return: The user's budget as a float, or None if the user is not found in the allowed user list - """ - - # no budget restrictions for admins and '*'-budget lists - if self.is_admin(user_id) or self.config['user_budgets'] == '*': - return float('inf') - - user_budgets = self.config['user_budgets'].split(',') - if self.config['allowed_user_ids'] == '*': - # same budget for all users, use value in first position of budget list - if len(user_budgets) > 1: - logging.warning('multiple values for budgets set with unrestricted user list ' - 'only the first value is used as budget for everyone.') - return float(user_budgets[0]) - - allowed_user_ids = self.config['allowed_user_ids'].split(',') - if str(user_id) in allowed_user_ids: - user_index = allowed_user_ids.index(str(user_id)) - if len(user_budgets) <= user_index: - logging.warning(f'No budget set for user id: {user_id}. Budget list shorter than user list.') - return 0.0 - return float(user_budgets[user_index]) - return None - - def get_remaining_budget(self, update: Update, is_inline=False) -> float: - """ - Calculate the remaining budget for a user based on their current usage. - :param update: Telegram update object - :param is_inline: Boolean flag for inline queries - :return: The remaining budget for the user as a float - """ - user_id = update.inline_query.from_user.id if is_inline else update.message.from_user.id - name = update.inline_query.from_user.name if is_inline else update.message.from_user.name - if user_id not in self.usage: - self.usage[user_id] = UsageTracker(user_id, name) - - # Get budget for users - user_budget = self.get_user_budget(user_id) - budget_period = self.config['budget_period'] - if user_budget is not None: - cost = self.usage[user_id].get_current_cost()[self.budget_cost_map[budget_period]] - return user_budget - cost - - # Get budget for guests - if 'guests' not in self.usage: - self.usage['guests'] = UsageTracker('guests', 'all guest users in group chats') - cost = self.usage['guests'].get_current_cost()[self.budget_cost_map[budget_period]] - return self.config['guest_budget'] - cost - - def is_within_budget(self, update: Update, is_inline=False) -> bool: - """ - Checks if the user reached their usage limit. - Initializes UsageTracker for user and guest when needed. - :param update: Telegram update object - :param is_inline: Boolean flag for inline queries - :return: Boolean indicating if the user has a positive budget - """ - user_id = update.inline_query.from_user.id if is_inline else update.message.from_user.id - name = update.inline_query.from_user.name if is_inline else update.message.from_user.name - if user_id not in self.usage: - self.usage[user_id] = UsageTracker(user_id, name) - - remaining_budget = self.get_remaining_budget(update, is_inline=is_inline) - - return remaining_budget > 0 - - async def check_allowed_and_within_budget(self, update: Update, context: ContextTypes.DEFAULT_TYPE, - is_inline=False) -> bool: - """ - Checks if the user is allowed to use the bot and if they are within their budget - :param update: Telegram update object - :param context: Telegram context object - :param is_inline: Boolean flag for inline queries - :return: Boolean indicating if the user is allowed to use the bot - """ - name = update.inline_query.from_user.name if is_inline else update.message.from_user.name - user_id = update.inline_query.from_user.id if is_inline else update.message.from_user.id - - if not await self.is_allowed(update, context, is_inline=is_inline): - logging.warning(f'User {name} (id: {user_id}) ' - f'is not allowed to use the bot') - await self.send_disallowed_message(update, context, is_inline) - return False - if not self.is_within_budget(update, is_inline=is_inline): - logging.warning(f'User {name} (id: {user_id}) ' - f'reached their usage limit') - await self.send_budget_reached_message(update, context, is_inline) - return False - - return True - - def add_chat_request_to_usage_tracker(self, user_id, used_tokens): - try: - # add chat request to users usage tracker - self.usage[user_id].add_chat_tokens(used_tokens, self.config['token_price']) - # add guest chat request to guest usage tracker - allowed_user_ids = self.config['allowed_user_ids'].split(',') - if str(user_id) not in allowed_user_ids and 'guests' in self.usage: - self.usage["guests"].add_chat_tokens(used_tokens, self.config['token_price']) - except Exception as e: - logging.warning(f'Failed to add tokens to usage_logs: {str(e)}') - pass - - def get_reply_to_message_id(self, update: Update): - """ - Returns the message id of the message to reply to - :param update: Telegram update object - :return: Message id of the message to reply to, or None if quoting is disabled - """ - if self.config['enable_quoting'] or self.is_group_chat(update): - return update.message.message_id - return None - - def split_into_chunks(self, text: str, chunk_size: int = 4096) -> list[str]: - """ - Splits a string into chunks of a given size. - """ - return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)] - async def post_init(self, application: Application) -> None: """ Post initialization hook for the bot. @@ -1036,6 +760,6 @@ def run(self): ])) application.add_handler(CallbackQueryHandler(self.handle_callback_inline_query)) - application.add_error_handler(self.error_handler) + application.add_error_handler(error_handler) application.run_polling() diff --git a/bot/usage_tracker.py b/bot/usage_tracker.py index 0131a5ac..c733f9eb 100644 --- a/bot/usage_tracker.py +++ b/bot/usage_tracker.py @@ -175,6 +175,9 @@ def add_transcription_seconds(self, seconds, minute_price=0.006): json.dump(self.usage, outfile) def add_current_costs(self, request_cost): + """ + Add current cost to all_time, day and month cost and update last_update date. + """ today = date.today() last_update = date.fromisoformat(self.usage["current_cost"]["last_update"]) diff --git a/bot/utils.py b/bot/utils.py new file mode 100644 index 00000000..04c46a3e --- /dev/null +++ b/bot/utils.py @@ -0,0 +1,307 @@ +from __future__ import annotations + +import asyncio +import itertools +import logging + +import telegram +from telegram import Message, MessageEntity, Update, ChatMember, constants +from telegram.ext import CallbackContext, ContextTypes + +from usage_tracker import UsageTracker + + +def message_text(message: Message) -> str: + """ + Returns the text of a message, excluding any bot commands. + """ + message_txt = message.text + if message_txt is None: + return '' + + for _, text in sorted(message.parse_entities([MessageEntity.BOT_COMMAND]).items(), + key=(lambda item: item[0].offset)): + message_txt = message_txt.replace(text, '').strip() + + return message_txt if len(message_txt) > 0 else '' + + +async def is_user_in_group(update: Update, context: CallbackContext, user_id: int) -> bool: + """ + Checks if user_id is a member of the group + """ + try: + chat_member = await context.bot.get_chat_member(update.message.chat_id, user_id) + return chat_member.status in [ChatMember.OWNER, ChatMember.ADMINISTRATOR, ChatMember.MEMBER] + except telegram.error.BadRequest as e: + if str(e) == "User not found": + return False + else: + raise e + except Exception as e: + raise e + + +def get_thread_id(update: Update) -> int | None: + """ + Gets the message thread id for the update, if any + """ + if update.effective_message and update.effective_message.is_topic_message: + return update.effective_message.message_thread_id + return None + + +def get_stream_cutoff_values(update: Update, content: str) -> int: + """ + Gets the stream cutoff values for the message length + """ + if is_group_chat(update): + # group chats have stricter flood limits + return 180 if len(content) > 1000 else 120 if len(content) > 200 \ + else 90 if len(content) > 50 else 50 + return 90 if len(content) > 1000 else 45 if len(content) > 200 \ + else 25 if len(content) > 50 else 15 + + +def is_group_chat(update: Update) -> bool: + """ + Checks if the message was sent from a group chat + """ + if not update.effective_chat: + return False + return update.effective_chat.type in [ + constants.ChatType.GROUP, + constants.ChatType.SUPERGROUP + ] + + +def split_into_chunks(text: str, chunk_size: int = 4096) -> list[str]: + """ + Splits a string into chunks of a given size. + """ + return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)] + + +async def wrap_with_indicator(update: Update, context: CallbackContext, coroutine, + chat_action: constants.ChatAction = "", is_inline=False): + """ + Wraps a coroutine while repeatedly sending a chat action to the user. + """ + task = context.application.create_task(coroutine(), update=update) + while not task.done(): + if not is_inline: + context.application.create_task( + update.effective_chat.send_action(chat_action, message_thread_id=get_thread_id(update)) + ) + try: + await asyncio.wait_for(asyncio.shield(task), 4.5) + except asyncio.TimeoutError: + pass + + +async def edit_message_with_retry(context: ContextTypes.DEFAULT_TYPE, chat_id: int | None, + message_id: str, text: str, markdown: bool = True, is_inline: bool = False): + """ + Edit a message with retry logic in case of failure (e.g. broken markdown) + :param context: The context to use + :param chat_id: The chat id to edit the message in + :param message_id: The message id to edit + :param text: The text to edit the message with + :param markdown: Whether to use markdown parse mode + :param is_inline: Whether the message to edit is an inline message + :return: None + """ + try: + await context.bot.edit_message_text( + chat_id=chat_id, + message_id=int(message_id) if not is_inline else None, + inline_message_id=message_id if is_inline else None, + text=text, + parse_mode=constants.ParseMode.MARKDOWN if markdown else None + ) + except telegram.error.BadRequest as e: + if str(e).startswith("Message is not modified"): + return + try: + await context.bot.edit_message_text( + chat_id=chat_id, + message_id=int(message_id) if not is_inline else None, + inline_message_id=message_id if is_inline else None, + text=text + ) + except Exception as e: + logging.warning(f'Failed to edit message: {str(e)}') + raise e + + except Exception as e: + logging.warning(str(e)) + raise e + + +async def error_handler(_: object, context: ContextTypes.DEFAULT_TYPE) -> None: + """ + Handles errors in the telegram-python-bot library. + """ + logging.error(f'Exception while handling an update: {context.error}') + + +async def is_allowed(config, update: Update, context: CallbackContext, is_inline=False) -> bool: + """ + Checks if the user is allowed to use the bot. + """ + if config['allowed_user_ids'] == '*': + return True + + user_id = update.inline_query.from_user.id if is_inline else update.message.from_user.id + if is_admin(config, user_id): + return True + name = update.inline_query.from_user.name if is_inline else update.message.from_user.name + allowed_user_ids = config['allowed_user_ids'].split(',') + # Check if user is allowed + if str(user_id) in allowed_user_ids: + return True + # Check if it's a group a chat with at least one authorized member + if not is_inline and is_group_chat(update): + admin_user_ids = config['admin_user_ids'].split(',') + for user in itertools.chain(allowed_user_ids, admin_user_ids): + if not user.strip(): + continue + if await is_user_in_group(update, context, user): + logging.info(f'{user} is a member. Allowing group chat message...') + return True + logging.info(f'Group chat messages from user {name} ' + f'(id: {user_id}) are not allowed') + return False + +def is_admin(config, user_id: int, log_no_admin=False) -> bool: + """ + Checks if the user is the admin of the bot. + The first user in the user list is the admin. + """ + if config['admin_user_ids'] == '-': + if log_no_admin: + logging.info('No admin user defined.') + return False + + admin_user_ids = config['admin_user_ids'].split(',') + + # Check if user is in the admin user list + if str(user_id) in admin_user_ids: + return True + + return False + + +def get_user_budget(config, user_id) -> float | None: + """ + Get the user's budget based on their user ID and the bot configuration. + :param config: The bot configuration object + :param user_id: User id + :return: The user's budget as a float, or None if the user is not found in the allowed user list + """ + + # no budget restrictions for admins and '*'-budget lists + if is_admin(config, user_id) or config['user_budgets'] == '*': + return float('inf') + + user_budgets = config['user_budgets'].split(',') + if config['allowed_user_ids'] == '*': + # same budget for all users, use value in first position of budget list + if len(user_budgets) > 1: + logging.warning('multiple values for budgets set with unrestricted user list ' + 'only the first value is used as budget for everyone.') + return float(user_budgets[0]) + + allowed_user_ids = config['allowed_user_ids'].split(',') + if str(user_id) in allowed_user_ids: + user_index = allowed_user_ids.index(str(user_id)) + if len(user_budgets) <= user_index: + logging.warning(f'No budget set for user id: {user_id}. Budget list shorter than user list.') + return 0.0 + return float(user_budgets[user_index]) + return None + + +def get_remaining_budget(config, usage, update: Update, is_inline=False) -> float: + """ + Calculate the remaining budget for a user based on their current usage. + :param config: The bot configuration object + :param usage: The usage tracker object + :param update: Telegram update object + :param is_inline: Boolean flag for inline queries + :return: The remaining budget for the user as a float + """ + # Mapping of budget period to cost period + budget_cost_map = { + "monthly": "cost_month", + "daily": "cost_today", + "all-time": "cost_all_time" + } + + user_id = update.inline_query.from_user.id if is_inline else update.message.from_user.id + name = update.inline_query.from_user.name if is_inline else update.message.from_user.name + if user_id not in usage: + usage[user_id] = UsageTracker(user_id, name) + + # Get budget for users + user_budget = get_user_budget(config, user_id) + budget_period = config['budget_period'] + if user_budget is not None: + cost = usage[user_id].get_current_cost()[budget_cost_map[budget_period]] + return user_budget - cost + + # Get budget for guests + if 'guests' not in usage: + usage['guests'] = UsageTracker('guests', 'all guest users in group chats') + cost = usage['guests'].get_current_cost()[budget_cost_map[budget_period]] + return config['guest_budget'] - cost + + +def is_within_budget(config, usage, update: Update, is_inline=False) -> bool: + """ + Checks if the user reached their usage limit. + Initializes UsageTracker for user and guest when needed. + :param config: The bot configuration object + :param usage: The usage tracker object + :param update: Telegram update object + :param is_inline: Boolean flag for inline queries + :return: Boolean indicating if the user has a positive budget + """ + user_id = update.inline_query.from_user.id if is_inline else update.message.from_user.id + name = update.inline_query.from_user.name if is_inline else update.message.from_user.name + if user_id not in usage: + usage[user_id] = UsageTracker(user_id, name) + remaining_budget = get_remaining_budget(config, usage, update, is_inline=is_inline) + return remaining_budget > 0 + + +def add_chat_request_to_usage_tracker(usage, config, user_id, used_tokens): + """ + Add chat request to usage tracker + :param usage: The usage tracker object + :param config: The bot configuration object + :param user_id: The user id + :param used_tokens: The number of tokens used + """ + try: + # add chat request to users usage tracker + usage[user_id].add_chat_tokens(used_tokens, config['token_price']) + # add guest chat request to guest usage tracker + allowed_user_ids = config['allowed_user_ids'].split(',') + if str(user_id) not in allowed_user_ids and 'guests' in usage: + usage["guests"].add_chat_tokens(used_tokens, config['token_price']) + except Exception as e: + logging.warning(f'Failed to add tokens to usage_logs: {str(e)}') + pass + + +def get_reply_to_message_id(config, update: Update): + """ + Returns the message id of the message to reply to + :param config: Bot configuration object + :param update: Telegram update object + :return: Message id of the message to reply to, or None if quoting is disabled + """ + if config['enable_quoting'] or is_group_chat(update): + return update.message.message_id + return None