Skip to content

Commit

Permalink
Add auto_dalle plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
Jipok committed Dec 17, 2023
1 parent 335344c commit b9c9dda
Show file tree
Hide file tree
Showing 21 changed files with 192 additions and 125 deletions.
22 changes: 14 additions & 8 deletions bot/openai_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def get_conversation_stats(self, chat_id: int) -> tuple[int, int]:
self.reset_chat_history(chat_id)
return len(self.conversations[chat_id]), self.__count_tokens(self.conversations[chat_id])

async def get_chat_response(self, chat_id: int, query: str) -> tuple[str, str]:
async def get_chat_response(self, bot: ChatGPTTelegramBot, tg_upd: telegram.Update, chat_id: int, query: str) -> tuple[str, str]:
"""
Gets a full response from the GPT model.
:param chat_id: The chat ID
Expand All @@ -132,7 +132,7 @@ async def get_chat_response(self, chat_id: int, query: str) -> tuple[str, str]:
plugins_used = ()
response = await self.__common_get_chat_response(chat_id, query)
if self.config['enable_functions'] and not self.conversations_vision[chat_id]:
response, plugins_used = await self.__handle_function_call(chat_id, response)
response, plugins_used = await self.__handle_function_call(bot, tg_upd, chat_id, response)
if is_direct_result(response):
return response, '0'

Expand Down Expand Up @@ -165,17 +165,19 @@ async def get_chat_response(self, chat_id: int, query: str) -> tuple[str, str]:

return answer, response.usage.total_tokens

async def get_chat_response_stream(self, chat_id: int, query: str):
async def get_chat_response_stream(self, bot: ChatGPTTelegramBot, tg_upd: telegram.Update, chat_id: int, query: str):
"""
Stream response from the GPT model.
:param chat_id: The chat ID
:param query: The query to send to the model
:return: The answer from the model and the number of tokens used, or 'not_finished'
"""
import telegram_bot
plugins_used = ()
response = await self.__common_get_chat_response(chat_id, query, stream=True)
if self.config['enable_functions'] and not self.conversations_vision[chat_id]:
response, plugins_used = await self.__handle_function_call(chat_id, response, stream=True)

if self.config['enable_functions']:
response, plugins_used = await self.__handle_function_call(bot, tg_upd, chat_id, response, stream=True)
if is_direct_result(response):
yield response, '0'
return
Expand Down Expand Up @@ -269,7 +271,7 @@ async def __common_get_chat_response(self, chat_id: int, query: str, stream=Fals
except Exception as e:
raise Exception(f"⚠️ _{localized_text('error', bot_language)}._ ⚠️\n{str(e)}") from e

async def __handle_function_call(self, chat_id, response, stream=False, times=0, plugins_used=()):
async def __handle_function_call(self, bot: ChatGPTTelegramBot, tg_upd: telegram.Update, chat_id, response, stream=False, times=0, plugins_used=()):
function_name = ''
arguments = ''
if stream:
Expand Down Expand Up @@ -301,11 +303,15 @@ async def __handle_function_call(self, chat_id, response, stream=False, times=0,
return response, plugins_used

logging.info(f'Calling function {function_name} with arguments {arguments}')
function_response = await self.plugin_manager.call_function(function_name, self, arguments)
function_response, function_response_dict = await self.plugin_manager.call_function(bot, tg_upd, chat_id, function_name, arguments)

if function_name not in plugins_used:
plugins_used += (function_name,)

# if "result" in function_response_dict and function_response_dict["result"] == "Success":
# self.__add_function_call_to_history(chat_id=chat_id, function_name=function_name, content=function_response)
# return response, plugins_used

if is_direct_result(function_response):
self.__add_function_call_to_history(chat_id=chat_id, function_name=function_name,
content=json.dumps({'result': 'Done, the content has been sent'
Expand All @@ -320,7 +326,7 @@ async def __handle_function_call(self, chat_id, response, stream=False, times=0,
function_call='auto' if times < self.config['functions_max_consecutive_calls'] else 'none',
stream=stream
)
return await self.__handle_function_call(chat_id, response, stream, times + 1, plugins_used)
return await self.__handle_function_call(bot, tg_upd, chat_id, response, stream, times + 1, plugins_used)

async def generate_image(self, prompt: str) -> tuple[str, str]:
"""
Expand Down
7 changes: 5 additions & 2 deletions bot/plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from plugins.gtts_text_to_speech import GTTSTextToSpeech
from plugins.auto_tts import AutoTextToSpeech
from plugins.auto_dalle import AutoDalle
from plugins.dice import DicePlugin
from plugins.youtube_audio_extractor import YouTubeAudioExtractorPlugin
from plugins.ddg_image_search import DDGImageSearchPlugin
Expand Down Expand Up @@ -38,6 +39,7 @@ def __init__(self, config):
'deepl_translate': DeeplTranslatePlugin,
'gtts_text_to_speech': GTTSTextToSpeech,
'auto_tts': AutoTextToSpeech,
'auto_dalle': AutoDalle,
'whois': WhoisPlugin,
'webshot': WebshotPlugin,
}
Expand All @@ -49,14 +51,15 @@ def get_functions_specs(self):
"""
return [spec for specs in map(lambda plugin: plugin.get_spec(), self.plugins) for spec in specs]

async def call_function(self, function_name, helper, arguments):
async def call_function(self, bot, tg_upd, chat_id, function_name, arguments):
"""
Call a function based on the name and parameters provided
"""
plugin = self.__get_plugin_by_function_name(function_name)
if not plugin:
return json.dumps({'error': f'Function {function_name} not found'})
return json.dumps(await plugin.execute(function_name, helper, **json.loads(arguments)), default=str)
result = await plugin.execute(function_name, bot, tg_upd, chat_id, **json.loads(arguments))
return json.dumps(result, default=str), result

def get_plugin_source_name(self, function_name) -> str:
"""
Expand Down
40 changes: 40 additions & 0 deletions bot/plugins/auto_dalle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import asyncio
import datetime
import tempfile
import traceback
from typing import Dict
import telegram

from .plugin import Plugin


class AutoDalle(Plugin):
"""
A plugin to generate image using Openai image generation API
"""

def get_source_name(self) -> str:
return "DALLE"

def get_spec(self) -> [Dict]:
return [{
"name": "dalle_image",
"description": "Create image from scratch based on a text prompt (DALL·E 3 and DALL·E 2). Send to user.",
"parameters": {
"type": "object",
"properties": {
"prompt": {"type": "string", "prompt": "Image description. Use English language for better results."},
},
"required": ["prompt"],
},
}]

async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
await bot.wrap_with_indicator(tg_upd, bot.image_gen(tg_upd, kwargs['prompt']), "upload_photo")
return {
'direct_result': {
'kind': 'none',
'format': '',
'value': 'none',
}
}
24 changes: 9 additions & 15 deletions bot/plugins/auto_tts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import tempfile
from typing import Dict
import telegram

from .plugin import Plugin

Expand All @@ -15,8 +16,8 @@ def get_source_name(self) -> str:

def get_spec(self) -> [Dict]:
return [{
"name": "translate_text_to_speech",
"description": "Translate text to speech using OpenAI API",
"name": "translate_text_to_speech_and_send",
"description": "Translate text to speech using OpenAI API and send result to user.",
"parameters": {
"type": "object",
"properties": {
Expand All @@ -26,19 +27,12 @@ def get_spec(self) -> [Dict]:
},
}]

async def execute(self, function_name, helper, **kwargs) -> Dict:
try:
bytes, text_length = await helper.generate_speech(text=kwargs['text'])
with tempfile.NamedTemporaryFile(delete=False, suffix='.opus') as temp_file:
temp_file.write(bytes.getvalue())
temp_file_path = temp_file.name
except Exception as e:
logging.exception(e)
return {"Result": "Exception: " + str(e)}
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
await bot.wrap_with_indicator(tg_upd, bot.tts_gen(tg_upd, kwargs['text']), "record_voice")
return {
'direct_result': {
'kind': 'file',
'format': 'path',
'value': temp_file_path
'kind': 'none',
'format': '',
'value': 'none',
}
}
}
3 changes: 2 additions & 1 deletion bot/plugins/crypto.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import telegram
from typing import Dict

import requests
Expand Down Expand Up @@ -26,5 +27,5 @@ def get_spec(self) -> [Dict]:
},
}]

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
return requests.get(f"https://api.coincap.io/v2/rates/{kwargs['asset']}").json()
3 changes: 2 additions & 1 deletion bot/plugins/ddg_image_search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import random
import telegram
from itertools import islice
from typing import Dict

Expand Down Expand Up @@ -49,7 +50,7 @@ def get_spec(self) -> [Dict]:
},
}]

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
with DDGS() as ddgs:
image_type = kwargs.get('type', 'photo')
ddgs_images_gen = ddgs.images(
Expand Down
3 changes: 2 additions & 1 deletion bot/plugins/ddg_translate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import telegram
from typing import Dict

from duckduckgo_search import DDGS
Expand Down Expand Up @@ -26,6 +27,6 @@ def get_spec(self) -> [Dict]:
},
}]

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
with DDGS() as ddgs:
return ddgs.translate(kwargs['text'], to=kwargs['to_language'])
3 changes: 2 additions & 1 deletion bot/plugins/ddg_web_search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from itertools import islice
import telegram
from typing import Dict

from duckduckgo_search import DDGS
Expand Down Expand Up @@ -46,7 +47,7 @@ def get_spec(self) -> [Dict]:
},
}]

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
with DDGS() as ddgs:
ddgs_gen = ddgs.text(
kwargs['query'],
Expand Down
3 changes: 2 additions & 1 deletion bot/plugins/deepl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from typing import Dict

import telegram
import requests

from .plugin import Plugin
Expand Down Expand Up @@ -33,7 +34,7 @@ def get_spec(self) -> [Dict]:
},
}]

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
if self.api_key.endswith(':fx'):
url = "https://api-free.deepl.com/v2/translate"
else:
Expand Down
3 changes: 2 additions & 1 deletion bot/plugins/dice.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import telegram
from typing import Dict

from .plugin import Plugin
Expand Down Expand Up @@ -28,7 +29,7 @@ def get_spec(self) -> [Dict]:
},
}]

async def execute(self, function_name, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
return {
'direct_result': {
'kind': 'dice',
Expand Down
3 changes: 2 additions & 1 deletion bot/plugins/gtts_text_to_speech.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import telegram
from typing import Dict

from gtts import gTTS
Expand Down Expand Up @@ -31,7 +32,7 @@ def get_spec(self) -> [Dict]:
},
}]

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
tts = gTTS(kwargs['text'], lang=kwargs.get('lang', 'en'))
output = f'gtts_{datetime.datetime.now().timestamp()}.mp3'
tts.save(output)
Expand Down
3 changes: 2 additions & 1 deletion bot/plugins/plugin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import telegram
from abc import abstractmethod, ABC
from typing import Dict

Expand All @@ -23,7 +24,7 @@ def get_spec(self) -> [Dict]:
pass

@abstractmethod
async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
"""
Execute the plugin and return a JSON serializable response
"""
Expand Down
3 changes: 2 additions & 1 deletion bot/plugins/spotify.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import telegram
from typing import Dict

import spotipy
Expand Down Expand Up @@ -111,7 +112,7 @@ def get_spec(self) -> [Dict]:
}
]

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
time_range = kwargs.get('time_range', 'short_term')
limit = kwargs.get('limit', 5)

Expand Down
3 changes: 2 additions & 1 deletion bot/plugins/weather.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import telegram
from datetime import datetime
from typing import Dict

Expand Down Expand Up @@ -57,7 +58,7 @@ def get_spec(self) -> [Dict]:
}
]

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
url = f'https://api.open-meteo.com/v1/forecast' \
f'?latitude={kwargs["latitude"]}' \
f'&longitude={kwargs["longitude"]}' \
Expand Down
3 changes: 2 additions & 1 deletion bot/plugins/webshot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os, requests, random, string
import telegram
from typing import Dict
from .plugin import Plugin

Expand Down Expand Up @@ -26,7 +27,7 @@ def generate_random_string(self, length):
characters = string.ascii_letters + string.digits
return ''.join(random.choice(characters) for _ in range(length))

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
try:
image_url = f'https://image.thum.io/get/maxAge/12/width/720/{kwargs["url"]}'

Expand Down
3 changes: 2 additions & 1 deletion bot/plugins/whois_.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import telegram
from typing import Dict
from .plugin import Plugin

Expand All @@ -24,7 +25,7 @@ def get_spec(self) -> [Dict]:
},
}]

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
try:
whois_result = whois.query(kwargs['domain'])
if whois_result is None:
Expand Down
3 changes: 2 additions & 1 deletion bot/plugins/wolfram_alpha.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import telegram
from typing import Dict

import wolframalpha
Expand Down Expand Up @@ -32,7 +33,7 @@ def get_spec(self) -> [Dict]:
}
}]

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
client = wolframalpha.Client(self.app_id)
res = client.query(kwargs['query'])
try:
Expand Down
Loading

0 comments on commit b9c9dda

Please sign in to comment.