Skip to content

Commit

Permalink
Add new tools!
Browse files Browse the repository at this point in the history
Enias Cailliau committed May 12, 2023
1 parent ebfa343 commit 3a93ae3
Showing 7 changed files with 262 additions and 14 deletions.
10 changes: 6 additions & 4 deletions src/agent/get_agent.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,8 @@
from steamship_langchain.tools import SteamshipSERP

from agent.parser import get_format_instructions, CustomParser
from agent.tools.image import GenerateImageTool
from agent.tools.my_tool import MyTool
from agent.tools.reminder import RemindMe

MODEL_NAME = "gpt-3.5-turbo" # or "gpt-4.0"
@@ -34,7 +36,9 @@ def get_tools(client: Steamship, invoke_later: Callable, chat_id: str) -> List[T
name="Search",
func=search.search,
description="useful for when you need to answer questions about current events",
)
),
MyTool(client),
GenerateImageTool(client),
]


@@ -53,6 +57,4 @@ def get_agent(client: Steamship, chat_id: str, invoke_later: Callable) -> AgentE
format_instructions=get_format_instructions(bool(tools)),
output_parser=CustomParser(),
)
return AgentExecutor.from_agent_and_tools(
agent=agent, tools=tools, verbose=VERBOSE
)
return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=VERBOSE)
16 changes: 16 additions & 0 deletions src/agent/parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import re
from typing import Union

from langchain.agents.mrkl.output_parser import MRKLOutputParser
from langchain.schema import AgentAction, AgentFinish

from agent.utils import UUID_PATTERN

FINAL_ANSWER_ACTION = "Final Answer:"

@@ -31,3 +37,13 @@ def get_format_instructions(has_tools=True) -> str:
class CustomParser(MRKLOutputParser):
def get_format_instructions(self) -> str:
return get_format_instructions(True)

def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
if FINAL_ANSWER_ACTION in text:
output = text.split(FINAL_ANSWER_ACTION)[-1].strip()
output = UUID_PATTERN.split(output)
output = [re.sub(r"^\W+", "", el) for el in output]

return AgentFinish({"output": output}, text)
cleaned_output = super().parse(text)
return cleaned_output
55 changes: 55 additions & 0 deletions src/agent/tools/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Tool for generating images."""
import json
import logging

from langchain.agents import Tool
from steamship import Steamship
from steamship.base.error import SteamshipError
from steamship.data.plugin.plugin_instance import PluginInstance

NAME = "GenerateImage"

DESCRIPTION = """
Useful for when you need to generate an image.
Input: A detailed dall-e prompt describing an image
Output: the UUID of a generated image
"""

PLUGIN_HANDLE = "stable-diffusion"


class GenerateImageTool(Tool):
"""Tool used to generate images from a text-prompt."""

client: Steamship

def __init__(self, client: Steamship):
super().__init__(
name=NAME, func=self.run, description=DESCRIPTION, client=client
)

@property
def is_single_input(self) -> bool:
"""Whether the tool only accepts a single input."""
return True

def run(self, prompt: str, **kwargs) -> str:
"""Respond to LLM prompt."""

# Use the Steamship DALL-E plugin.
image_generator = self.client.use_plugin(
plugin_handle=PLUGIN_HANDLE, config={"n": 1, "size": "768x768"}
)

logging.info(f"[{self.name}] {prompt}")
if not isinstance(prompt, str):
prompt = json.dumps(prompt)

task = image_generator.generate(text=prompt, append_output_to_file=True)
task.wait()
blocks = task.output.blocks
logging.info(f"[{self.name}] got back {len(blocks)} blocks")
if len(blocks) > 0:
logging.info(f"[{self.name}] image size: {len(blocks[0].raw())}")
return blocks[0].id
raise SteamshipError(f"[{self.name}] Tool unable to generate image!")
45 changes: 45 additions & 0 deletions src/agent/tools/my_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Use this file to create your own tool."""
import logging

from langchain import LLMChain, PromptTemplate
from langchain.agents import Tool
from steamship import Steamship
from steamship_langchain.llms.openai import OpenAI

NAME = "MyTool"

DESCRIPTION = """
Useful for when you need to come up with todo lists.
Input: an objective to create a todo list for.
Output: a todo list for that objective. Please be very clear what the objective is!
"""

PROMPT = """
You are a planner who is an expert at coming up with a todo list for a given objective.
Come up with a todo list for this objective: {objective}"
"""


class MyTool(Tool):
"""Tool used to manage to-do lists."""

client: Steamship

def __init__(self, client: Steamship):
super().__init__(
name=NAME, func=self.run, description=DESCRIPTION, client=client
)

def _get_chain(self, client):
todo_prompt = PromptTemplate.from_template(PROMPT)
return LLMChain(llm=OpenAI(client=client, temperature=0), prompt=todo_prompt)

@property
def is_single_input(self) -> bool:
"""Whether the tool only accepts a single input."""
return True

def run(self, prompt: str, **kwargs) -> str:
"""Respond to LLM prompts."""
chain = self._get_chain(self.client)
return chain.predict(objective=prompt)
39 changes: 39 additions & 0 deletions src/agent/tools/search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Tool for searching the web."""

from langchain.agents import Tool
from steamship import Steamship
from steamship_langchain.tools import SteamshipSERP

NAME = "Search"

DESCRIPTION = """
Useful for when you need to answer questions about current events
"""


class SearchTool(Tool):
"""Tool used to search for information using SERP API."""

client: Steamship

def __init__(self, client: Steamship):
super().__init__(
name=NAME, func=self.run, description=DESCRIPTION, client=client
)

@property
def is_single_input(self) -> bool:
"""Whether the tool only accepts a single input."""
return True

def run(self, prompt: str, **kwargs) -> str:
"""Respond to LLM prompts."""
search = SteamshipSERP(client=self.client)
return search.search(prompt)


if __name__ == "__main__":
with Steamship.temporary_workspace() as client:
my_tool = SearchTool(client)
result = my_tool.run("What's the weather today?")
print(result)
49 changes: 49 additions & 0 deletions src/agent/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import logging
import re
import uuid

from steamship.data.workspace import SignedUrl
from steamship.utils.signed_urls import upload_to_signed_url

UUID_PATTERN = re.compile(
r"([0-9A-Za-z]{8}-[0-9A-Za-z]{4}-[0-9A-Za-z]{4}-[0-9A-Za-z]{4}-[0-9A-Za-z]{12})"
)


def is_valid_uuid(uuid_to_test: str, version=4) -> bool:
"""Check a string to see if it is actually a UUID."""
lowered = uuid_to_test.lower()
try:
uuid_obj = uuid.UUID(lowered, version=version)
except ValueError:
return False
return str(uuid_obj) == lowered


def make_image_public(client, block):
filepath = str(uuid.uuid4())
signed_url = (
client.get_workspace()
.create_signed_url(
SignedUrl.Request(
bucket=SignedUrl.Bucket.PLUGIN_DATA,
filepath=filepath,
operation=SignedUrl.Operation.WRITE,
)
)
.signed_url
)
logging.info(f"Got signed url for uploading block content: {signed_url}")
read_signed_url = (
client.get_workspace()
.create_signed_url(
SignedUrl.Request(
bucket=SignedUrl.Bucket.PLUGIN_DATA,
filepath=filepath,
operation=SignedUrl.Operation.READ,
)
)
.signed_url
)
upload_to_signed_url(signed_url, block.raw())
return read_signed_url
62 changes: 52 additions & 10 deletions src/api.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Scaffolding to host your LangChain Chatbot on Steamship and connect it to Telegram."""
from typing import List, Optional

from steamship import Steamship
from steamship import Steamship, Block
from steamship.experimental.package_starters.telegram_bot import TelegramBot
from steamship.experimental.transports.chat import ChatMessage
from steamship.invocable import post

from agent.get_agent import get_agent
from agent.utils import is_valid_uuid, make_image_public


class LangChainTelegramChatbot(TelegramBot):
@@ -28,23 +29,64 @@ def _invoke_later(self, delay_ms: int, message: str, chat_id: str):
)

def create_response(
self, incoming_message: ChatMessage
self, incoming_message: ChatMessage
) -> Optional[List[ChatMessage]]:
"""Use the LLM to prepare the next response by appending the user input to the file and then generating."""
if incoming_message.text == "/start":
return [ChatMessage(text="New conversation started.",
chat_id=incoming_message.get_chat_id())]
return [
ChatMessage(
text="New conversation started.",
chat_id=incoming_message.get_chat_id(),
)
]

conversation = get_agent(self.client,
chat_id=incoming_message.get_chat_id(),
invoke_later=self._invoke_later)
conversation = get_agent(
self.client,
chat_id=incoming_message.get_chat_id(),
invoke_later=self._invoke_later,
)
response = conversation.run(input=incoming_message.text)

return [ChatMessage(text=response, chat_id=incoming_message.get_chat_id())]
return self.agent_output_to_chat_messages(
chat_id=incoming_message.get_chat_id(), agent_output=response
)

def agent_output_to_chat_messages(
self, chat_id: str, agent_output: List[str]
) -> List[ChatMessage]:
"""Transform the output of the Multi-Modal Agent into a list of ChatMessage objects.
The response of a ulti-Modal Agent contains one or more:
- parseable UUIDs, representing a block containing binary data, or:
- Text
This method inspects each string and creates a ChatMessage of the appropriate type.
"""
ret = []
for part_response in agent_output:
if is_valid_uuid(part_response):
block = Block.get(self.client, _id=part_response)
message = ChatMessage.from_block(
block,
chat_id=chat_id,
)
message.url = make_image_public(self.client, block)

else:
message = ChatMessage(
client=self.client,
chat_id=chat_id,
text=part_response,
)

ret.append(message)
return ret


if __name__ == '__main__':
if __name__ == "__main__":
client = Steamship()
bot = LangChainTelegramChatbot(client=client, config={"bot_token": "test"})
answer = bot.create_response(ChatMessage(text="Hi bro", chat_id="2"))
answer = bot.create_response(
ChatMessage(text="Hi bro, generate me an image of a cat", chat_id="2")
)
print("answer", answer)

0 comments on commit 3a93ae3

Please sign in to comment.