forked from steamship-core/langchain-production-starter
-
Notifications
You must be signed in to change notification settings - Fork 448
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enias Cailliau
committed
May 12, 2023
1 parent
ebfa343
commit 3a93ae3
Showing
7 changed files
with
262 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
"""Tool for generating images.""" | ||
import json | ||
import logging | ||
|
||
from langchain.agents import Tool | ||
from steamship import Steamship | ||
from steamship.base.error import SteamshipError | ||
from steamship.data.plugin.plugin_instance import PluginInstance | ||
|
||
NAME = "GenerateImage" | ||
|
||
DESCRIPTION = """ | ||
Useful for when you need to generate an image. | ||
Input: A detailed dall-e prompt describing an image | ||
Output: the UUID of a generated image | ||
""" | ||
|
||
PLUGIN_HANDLE = "stable-diffusion" | ||
|
||
|
||
class GenerateImageTool(Tool): | ||
"""Tool used to generate images from a text-prompt.""" | ||
|
||
client: Steamship | ||
|
||
def __init__(self, client: Steamship): | ||
super().__init__( | ||
name=NAME, func=self.run, description=DESCRIPTION, client=client | ||
) | ||
|
||
@property | ||
def is_single_input(self) -> bool: | ||
"""Whether the tool only accepts a single input.""" | ||
return True | ||
|
||
def run(self, prompt: str, **kwargs) -> str: | ||
"""Respond to LLM prompt.""" | ||
|
||
# Use the Steamship DALL-E plugin. | ||
image_generator = self.client.use_plugin( | ||
plugin_handle=PLUGIN_HANDLE, config={"n": 1, "size": "768x768"} | ||
) | ||
|
||
logging.info(f"[{self.name}] {prompt}") | ||
if not isinstance(prompt, str): | ||
prompt = json.dumps(prompt) | ||
|
||
task = image_generator.generate(text=prompt, append_output_to_file=True) | ||
task.wait() | ||
blocks = task.output.blocks | ||
logging.info(f"[{self.name}] got back {len(blocks)} blocks") | ||
if len(blocks) > 0: | ||
logging.info(f"[{self.name}] image size: {len(blocks[0].raw())}") | ||
return blocks[0].id | ||
raise SteamshipError(f"[{self.name}] Tool unable to generate image!") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
"""Use this file to create your own tool.""" | ||
import logging | ||
|
||
from langchain import LLMChain, PromptTemplate | ||
from langchain.agents import Tool | ||
from steamship import Steamship | ||
from steamship_langchain.llms.openai import OpenAI | ||
|
||
NAME = "MyTool" | ||
|
||
DESCRIPTION = """ | ||
Useful for when you need to come up with todo lists. | ||
Input: an objective to create a todo list for. | ||
Output: a todo list for that objective. Please be very clear what the objective is! | ||
""" | ||
|
||
PROMPT = """ | ||
You are a planner who is an expert at coming up with a todo list for a given objective. | ||
Come up with a todo list for this objective: {objective}" | ||
""" | ||
|
||
|
||
class MyTool(Tool): | ||
"""Tool used to manage to-do lists.""" | ||
|
||
client: Steamship | ||
|
||
def __init__(self, client: Steamship): | ||
super().__init__( | ||
name=NAME, func=self.run, description=DESCRIPTION, client=client | ||
) | ||
|
||
def _get_chain(self, client): | ||
todo_prompt = PromptTemplate.from_template(PROMPT) | ||
return LLMChain(llm=OpenAI(client=client, temperature=0), prompt=todo_prompt) | ||
|
||
@property | ||
def is_single_input(self) -> bool: | ||
"""Whether the tool only accepts a single input.""" | ||
return True | ||
|
||
def run(self, prompt: str, **kwargs) -> str: | ||
"""Respond to LLM prompts.""" | ||
chain = self._get_chain(self.client) | ||
return chain.predict(objective=prompt) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
"""Tool for searching the web.""" | ||
|
||
from langchain.agents import Tool | ||
from steamship import Steamship | ||
from steamship_langchain.tools import SteamshipSERP | ||
|
||
NAME = "Search" | ||
|
||
DESCRIPTION = """ | ||
Useful for when you need to answer questions about current events | ||
""" | ||
|
||
|
||
class SearchTool(Tool): | ||
"""Tool used to search for information using SERP API.""" | ||
|
||
client: Steamship | ||
|
||
def __init__(self, client: Steamship): | ||
super().__init__( | ||
name=NAME, func=self.run, description=DESCRIPTION, client=client | ||
) | ||
|
||
@property | ||
def is_single_input(self) -> bool: | ||
"""Whether the tool only accepts a single input.""" | ||
return True | ||
|
||
def run(self, prompt: str, **kwargs) -> str: | ||
"""Respond to LLM prompts.""" | ||
search = SteamshipSERP(client=self.client) | ||
return search.search(prompt) | ||
|
||
|
||
if __name__ == "__main__": | ||
with Steamship.temporary_workspace() as client: | ||
my_tool = SearchTool(client) | ||
result = my_tool.run("What's the weather today?") | ||
print(result) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import logging | ||
import re | ||
import uuid | ||
|
||
from steamship.data.workspace import SignedUrl | ||
from steamship.utils.signed_urls import upload_to_signed_url | ||
|
||
UUID_PATTERN = re.compile( | ||
r"([0-9A-Za-z]{8}-[0-9A-Za-z]{4}-[0-9A-Za-z]{4}-[0-9A-Za-z]{4}-[0-9A-Za-z]{12})" | ||
) | ||
|
||
|
||
def is_valid_uuid(uuid_to_test: str, version=4) -> bool: | ||
"""Check a string to see if it is actually a UUID.""" | ||
lowered = uuid_to_test.lower() | ||
try: | ||
uuid_obj = uuid.UUID(lowered, version=version) | ||
except ValueError: | ||
return False | ||
return str(uuid_obj) == lowered | ||
|
||
|
||
def make_image_public(client, block): | ||
filepath = str(uuid.uuid4()) | ||
signed_url = ( | ||
client.get_workspace() | ||
.create_signed_url( | ||
SignedUrl.Request( | ||
bucket=SignedUrl.Bucket.PLUGIN_DATA, | ||
filepath=filepath, | ||
operation=SignedUrl.Operation.WRITE, | ||
) | ||
) | ||
.signed_url | ||
) | ||
logging.info(f"Got signed url for uploading block content: {signed_url}") | ||
read_signed_url = ( | ||
client.get_workspace() | ||
.create_signed_url( | ||
SignedUrl.Request( | ||
bucket=SignedUrl.Bucket.PLUGIN_DATA, | ||
filepath=filepath, | ||
operation=SignedUrl.Operation.READ, | ||
) | ||
) | ||
.signed_url | ||
) | ||
upload_to_signed_url(signed_url, block.raw()) | ||
return read_signed_url |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters