Skip to content

Commit

Permalink
fix: fixes to the langgraph example agent (#4771)
Browse files Browse the repository at this point in the history
* Fixes to the langgraph agent

* Style fixes
  • Loading branch information
Jgilhuly authored Sep 27, 2024
1 parent e616fcc commit d457f70
Show file tree
Hide file tree
Showing 10 changed files with 53 additions and 88 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ pyrightconfig.json

# Demo data
tutorials/internal/demo_llama_index/*.json
examples/agent_framework_comparison/utils/saved_traces/*.parquet
.env
.conda
3 changes: 2 additions & 1 deletion examples/agent_framework_comparison/code_based_agent/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

sys.path.insert(1, os.path.join(sys.path[0], ".."))
from router import router

from utils.instrument import Framework, instrument


Expand All @@ -32,5 +33,5 @@ def launch_app():


if __name__ == "__main__":
instrument(project_name="code-based-agent", framework=Framework.CODE_BASED)
instrument(project_name="agent-demo", framework=Framework.CODE_BASED)
launch_app()
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def router(messages, parent_context):
span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, "CHAIN")
span.set_attribute(SpanAttributes.INPUT_VALUE, str(messages))
span.set_attribute(SpanAttributes.INPUT_MIME_TYPE, "application/json")
span.set_attribute(SpanAttributes.LLM_TOOLS, str(skill_map.get_function_list()))

if not any(
isinstance(message, dict) and message.get("role") == "system" for message in messages
Expand Down
2 changes: 1 addition & 1 deletion examples/agent_framework_comparison/db/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,4 @@ def get_table():


if __name__ == "__main__":
print(run_query("SELECT name FROM pragma_table_info('traces')"))
print(run_query("SELECT attributes.retrieval.documents FROM traces"))
64 changes: 17 additions & 47 deletions examples/agent_framework_comparison/langgraph/analyze_data.py
Original file line number Diff line number Diff line change
@@ -1,69 +1,39 @@
import json
import os
import sys

sys.path.insert(1, os.path.join(sys.path[0], ".."))

from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from openai import OpenAI
from langchain_openai import ChatOpenAI
from openinference.instrumentation import using_prompt_template
from openinference.semconv.trace import SpanAttributes
from opentelemetry import trace
from prompt_templates.data_analysis_template import PROMPT_TEMPLATE, SYSTEM_PROMPT

load_dotenv()


@tool
def data_analyzer(args):
def data_analyzer(original_prompt: str, data: str):
"""Provides insights, trends, or analysis based on the data and prompt.
Args:
args (dict): A dictionary containing the data to analyze and the original user prompt
that the data is based on.
original_prompt (str): The original user prompt that the data is based on.
data (str): The data to analyze.
Returns:
str: The analysis result.
"""

if isinstance(args, dict) and "prompt" in args and "data" in args:
prompt = args["prompt"]
data = args["data"]
elif isinstance(args, str):
try:
args = json.loads(args)
prompt = args["prompt"].strip()
data = args["data"].strip()
except ValueError:
return "Invalid input: expected a dictionary with 'prompt' and 'data' keys or a string."
else:
return "Invalid input: expected a dictionary with 'prompt' and 'data' keys or a string."

client = OpenAI()

tracer = trace.get_tracer(__name__)
with tracer.start_as_current_span("data_analysis_tool") as span:
span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, "CHAIN")
span.set_attribute(
SpanAttributes.INPUT_VALUE,
PROMPT_TEMPLATE.format(PROMPT=prompt, DATA=data),
)
with using_prompt_template(
template=PROMPT_TEMPLATE,
variables={"PROMPT": prompt, "DATA": data},
version="v0.1",
):
response = client.chat.completions.create(
model="gpt-4o",
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{
"role": "user",
"content": PROMPT_TEMPLATE.format(PROMPT=prompt, DATA=data),
},
],
)
analysis_result = response.choices[0].message.content
span.set_attribute(SpanAttributes.OUTPUT_VALUE, analysis_result)
return analysis_result
with using_prompt_template(
template=PROMPT_TEMPLATE,
variables={"PROMPT": original_prompt, "DATA": data},
version="v0.1",
):
model = ChatOpenAI(model="gpt-4o")
messages = [
SystemMessage(content=SYSTEM_PROMPT),
HumanMessage(content=PROMPT_TEMPLATE.format(PROMPT=original_prompt, DATA=data)),
]
response = model.invoke(messages)
return response.content
55 changes: 23 additions & 32 deletions examples/agent_framework_comparison/langgraph/generate_sql_query.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import os
import sys

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from openai import OpenAI
from langchain_openai import ChatOpenAI
from openinference.instrumentation import using_prompt_template
from openinference.semconv.trace import SpanAttributes
from opentelemetry import trace

sys.path.insert(1, os.path.join(sys.path[0], ".."))

Expand All @@ -14,16 +13,19 @@


@tool
def generate_and_run_sql_query(query: str):
def generate_and_run_sql_query(original_prompt: str):
"""Generates and runs an SQL query based on the prompt.
Args:
query (str): A string containing the original user prompt.
original_prompt (str): A string containing the original user prompt.
Returns:
str: The result of the SQL query.
"""
return _generate_and_run_sql_query(original_prompt, retry=True)


def _generate_and_run_sql_query(original_prompt: str, retry: bool = False):
def _sanitize_query(query):
# Remove triple backticks from the query if present
query = query.strip()
Expand All @@ -35,14 +37,6 @@ def _sanitize_query(query):
query = query[:-3].strip()
return query

if isinstance(query, dict) and "prompt" in query:
prompt = query["prompt"]
elif isinstance(query, str):
prompt = query
else:
return "Invalid input: expected a dictionary with 'prompt' key or a string."

client = OpenAI()
table = get_table()
schema = get_schema()

Expand All @@ -51,24 +45,21 @@ def _sanitize_query(query):
variables={"SCHEMA": schema, "TABLE": table},
version="v0.1",
):
response = client.chat.completions.create(
model="gpt-4o",
messages=[
{
"role": "system",
"content": SYSTEM_PROMPT.format(SCHEMA=schema, TABLE=table),
},
{"role": "user", "content": prompt},
],
)
model = ChatOpenAI(model="gpt-4o")
messages = [
SystemMessage(content=SYSTEM_PROMPT.format(SCHEMA=schema, TABLE=table)),
HumanMessage(content=original_prompt),
]
response = model.invoke(messages)

sql_query = response.choices[0].message.content
sql_query = response.content

tracer = trace.get_tracer(__name__)
with tracer.start_as_current_span("run_sql_query") as span:
span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, "CHAIN")
span.set_attribute(SpanAttributes.INPUT_VALUE, sql_query)
sanitized_query = _sanitize_query(sql_query)
results = str(run_query(sanitized_query))
span.set_attribute(SpanAttributes.OUTPUT_VALUE, results)
return results
sanitized_query = _sanitize_query(sql_query)
results = str(run_query(sanitized_query))
if "An error occurred" in results and retry:
retry_str = (
f"\n I've already tried this query: {sql_query} \n"
f"and got this error: {results} \n Please try again."
)
return _generate_and_run_sql_query(original_prompt + retry_str, retry=False)
return results
3 changes: 2 additions & 1 deletion examples/agent_framework_comparison/langgraph/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
sys.path.insert(1, os.path.join(sys.path[0], ".."))
import gradio as gr
from langgraph.router import run_agent

from utils.instrument import Framework, instrument


Expand All @@ -17,5 +18,5 @@ def launch_app():


if __name__ == "__main__":
instrument(project_name="langgraph-agent", framework=Framework.LANGGRAPH)
instrument(project_name="langgraph-agent-demo", framework=Framework.LANGGRAPH)
launch_app()
2 changes: 1 addition & 1 deletion examples/agent_framework_comparison/li_workflow/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async def gradio_interface(message, history):


def launch_app():
iface = gr.ChatInterface(fn=gradio_interface, title="Data Analyst Agent")
iface = gr.ChatInterface(fn=gradio_interface, title="LlamaIndex Workflow Agent")
iface.launch()


Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
SYSTEM_PROMPT = """
You are a helpful assistant that choses a tool to call based on the user's request.
All of your responses should be a tool call or text.
All of your responses should be a tool call or text. Only generate tool calls or text.
If you generate a tool call, be sure you include the original prompt as is in the parameters.
Once you receive the results from all of your skills, generate a response to the
user that incorporates all of the results.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
# This is used primarily to create example sets of traces for each agent.
# Likely not needed for most users.
def save_agent_traces(project_name: str):
directory = "utils/saved_traces"
directory = "examples/agent_framework_comparison/utils/saved_traces"
os.makedirs(directory, exist_ok=True)

# Save the Trace Dataset
px.Client().get_trace_dataset(project_name=project_name).save(directory=directory)
traces = px.Client().get_trace_dataset(project_name=project_name)
traces.save(directory=directory)


if __name__ == "__main__":
save_agent_traces(project_name="function-calling-agent-demo")
save_agent_traces(project_name="agent-demo")

0 comments on commit d457f70

Please sign in to comment.