From d457f708fb9c1edc177727aed52e82fc52badedd Mon Sep 17 00:00:00 2001 From: JohnGilhuly Date: Thu, 26 Sep 2024 17:05:19 -0700 Subject: [PATCH] fix: fixes to the langgraph example agent (#4771) * Fixes to the langgraph agent * Style fixes --- .gitignore | 1 + .../code_based_agent/main.py | 3 +- .../code_based_agent/router.py | 1 - .../agent_framework_comparison/db/database.py | 2 +- .../langgraph/analyze_data.py | 64 +++++-------------- .../langgraph/generate_sql_query.py | 55 +++++++--------- .../langgraph/main.py | 3 +- .../li_workflow/main.py | 2 +- .../prompt_templates/router_template.py | 3 +- .../utils/save_agent_traces.py | 7 +- 10 files changed, 53 insertions(+), 88 deletions(-) diff --git a/.gitignore b/.gitignore index d2f6f76df3..1556d78aae 100644 --- a/.gitignore +++ b/.gitignore @@ -28,5 +28,6 @@ pyrightconfig.json # Demo data tutorials/internal/demo_llama_index/*.json +examples/agent_framework_comparison/utils/saved_traces/*.parquet .env .conda diff --git a/examples/agent_framework_comparison/code_based_agent/main.py b/examples/agent_framework_comparison/code_based_agent/main.py index c13a7371b3..d5a87dac83 100644 --- a/examples/agent_framework_comparison/code_based_agent/main.py +++ b/examples/agent_framework_comparison/code_based_agent/main.py @@ -8,6 +8,7 @@ sys.path.insert(1, os.path.join(sys.path[0], "..")) from router import router + from utils.instrument import Framework, instrument @@ -32,5 +33,5 @@ def launch_app(): if __name__ == "__main__": - instrument(project_name="code-based-agent", framework=Framework.CODE_BASED) + instrument(project_name="agent-demo", framework=Framework.CODE_BASED) launch_app() diff --git a/examples/agent_framework_comparison/code_based_agent/router.py b/examples/agent_framework_comparison/code_based_agent/router.py index e9abce29a4..bfec17d6d2 100644 --- a/examples/agent_framework_comparison/code_based_agent/router.py +++ b/examples/agent_framework_comparison/code_based_agent/router.py @@ -29,7 +29,6 @@ def router(messages, parent_context): span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, "CHAIN") span.set_attribute(SpanAttributes.INPUT_VALUE, str(messages)) span.set_attribute(SpanAttributes.INPUT_MIME_TYPE, "application/json") - span.set_attribute(SpanAttributes.LLM_TOOLS, str(skill_map.get_function_list())) if not any( isinstance(message, dict) and message.get("role") == "system" for message in messages diff --git a/examples/agent_framework_comparison/db/database.py b/examples/agent_framework_comparison/db/database.py index 93905e1fa8..a0621125d2 100644 --- a/examples/agent_framework_comparison/db/database.py +++ b/examples/agent_framework_comparison/db/database.py @@ -79,4 +79,4 @@ def get_table(): if __name__ == "__main__": - print(run_query("SELECT name FROM pragma_table_info('traces')")) + print(run_query("SELECT attributes.retrieval.documents FROM traces")) diff --git a/examples/agent_framework_comparison/langgraph/analyze_data.py b/examples/agent_framework_comparison/langgraph/analyze_data.py index 61c2648557..e88d472113 100644 --- a/examples/agent_framework_comparison/langgraph/analyze_data.py +++ b/examples/agent_framework_comparison/langgraph/analyze_data.py @@ -1,69 +1,39 @@ -import json import os import sys sys.path.insert(1, os.path.join(sys.path[0], "..")) from dotenv import load_dotenv +from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.tools import tool -from openai import OpenAI +from langchain_openai import ChatOpenAI from openinference.instrumentation import using_prompt_template -from openinference.semconv.trace import SpanAttributes -from opentelemetry import trace from prompt_templates.data_analysis_template import PROMPT_TEMPLATE, SYSTEM_PROMPT load_dotenv() @tool -def data_analyzer(args): +def data_analyzer(original_prompt: str, data: str): """Provides insights, trends, or analysis based on the data and prompt. Args: - args (dict): A dictionary containing the data to analyze and the original user prompt - that the data is based on. + original_prompt (str): The original user prompt that the data is based on. + data (str): The data to analyze. Returns: str: The analysis result. """ - if isinstance(args, dict) and "prompt" in args and "data" in args: - prompt = args["prompt"] - data = args["data"] - elif isinstance(args, str): - try: - args = json.loads(args) - prompt = args["prompt"].strip() - data = args["data"].strip() - except ValueError: - return "Invalid input: expected a dictionary with 'prompt' and 'data' keys or a string." - else: - return "Invalid input: expected a dictionary with 'prompt' and 'data' keys or a string." - - client = OpenAI() - - tracer = trace.get_tracer(__name__) - with tracer.start_as_current_span("data_analysis_tool") as span: - span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, "CHAIN") - span.set_attribute( - SpanAttributes.INPUT_VALUE, - PROMPT_TEMPLATE.format(PROMPT=prompt, DATA=data), - ) - with using_prompt_template( - template=PROMPT_TEMPLATE, - variables={"PROMPT": prompt, "DATA": data}, - version="v0.1", - ): - response = client.chat.completions.create( - model="gpt-4o", - messages=[ - {"role": "system", "content": SYSTEM_PROMPT}, - { - "role": "user", - "content": PROMPT_TEMPLATE.format(PROMPT=prompt, DATA=data), - }, - ], - ) - analysis_result = response.choices[0].message.content - span.set_attribute(SpanAttributes.OUTPUT_VALUE, analysis_result) - return analysis_result + with using_prompt_template( + template=PROMPT_TEMPLATE, + variables={"PROMPT": original_prompt, "DATA": data}, + version="v0.1", + ): + model = ChatOpenAI(model="gpt-4o") + messages = [ + SystemMessage(content=SYSTEM_PROMPT), + HumanMessage(content=PROMPT_TEMPLATE.format(PROMPT=original_prompt, DATA=data)), + ] + response = model.invoke(messages) + return response.content diff --git a/examples/agent_framework_comparison/langgraph/generate_sql_query.py b/examples/agent_framework_comparison/langgraph/generate_sql_query.py index f53403bca1..6dbae9f7e7 100644 --- a/examples/agent_framework_comparison/langgraph/generate_sql_query.py +++ b/examples/agent_framework_comparison/langgraph/generate_sql_query.py @@ -1,11 +1,10 @@ import os import sys +from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.tools import tool -from openai import OpenAI +from langchain_openai import ChatOpenAI from openinference.instrumentation import using_prompt_template -from openinference.semconv.trace import SpanAttributes -from opentelemetry import trace sys.path.insert(1, os.path.join(sys.path[0], "..")) @@ -14,16 +13,19 @@ @tool -def generate_and_run_sql_query(query: str): +def generate_and_run_sql_query(original_prompt: str): """Generates and runs an SQL query based on the prompt. Args: - query (str): A string containing the original user prompt. + original_prompt (str): A string containing the original user prompt. Returns: str: The result of the SQL query. """ + return _generate_and_run_sql_query(original_prompt, retry=True) + +def _generate_and_run_sql_query(original_prompt: str, retry: bool = False): def _sanitize_query(query): # Remove triple backticks from the query if present query = query.strip() @@ -35,14 +37,6 @@ def _sanitize_query(query): query = query[:-3].strip() return query - if isinstance(query, dict) and "prompt" in query: - prompt = query["prompt"] - elif isinstance(query, str): - prompt = query - else: - return "Invalid input: expected a dictionary with 'prompt' key or a string." - - client = OpenAI() table = get_table() schema = get_schema() @@ -51,24 +45,21 @@ def _sanitize_query(query): variables={"SCHEMA": schema, "TABLE": table}, version="v0.1", ): - response = client.chat.completions.create( - model="gpt-4o", - messages=[ - { - "role": "system", - "content": SYSTEM_PROMPT.format(SCHEMA=schema, TABLE=table), - }, - {"role": "user", "content": prompt}, - ], - ) + model = ChatOpenAI(model="gpt-4o") + messages = [ + SystemMessage(content=SYSTEM_PROMPT.format(SCHEMA=schema, TABLE=table)), + HumanMessage(content=original_prompt), + ] + response = model.invoke(messages) - sql_query = response.choices[0].message.content + sql_query = response.content - tracer = trace.get_tracer(__name__) - with tracer.start_as_current_span("run_sql_query") as span: - span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, "CHAIN") - span.set_attribute(SpanAttributes.INPUT_VALUE, sql_query) - sanitized_query = _sanitize_query(sql_query) - results = str(run_query(sanitized_query)) - span.set_attribute(SpanAttributes.OUTPUT_VALUE, results) - return results + sanitized_query = _sanitize_query(sql_query) + results = str(run_query(sanitized_query)) + if "An error occurred" in results and retry: + retry_str = ( + f"\n I've already tried this query: {sql_query} \n" + f"and got this error: {results} \n Please try again." + ) + return _generate_and_run_sql_query(original_prompt + retry_str, retry=False) + return results diff --git a/examples/agent_framework_comparison/langgraph/main.py b/examples/agent_framework_comparison/langgraph/main.py index 286571adf8..ee9d7ec3a0 100644 --- a/examples/agent_framework_comparison/langgraph/main.py +++ b/examples/agent_framework_comparison/langgraph/main.py @@ -4,6 +4,7 @@ sys.path.insert(1, os.path.join(sys.path[0], "..")) import gradio as gr from langgraph.router import run_agent + from utils.instrument import Framework, instrument @@ -17,5 +18,5 @@ def launch_app(): if __name__ == "__main__": - instrument(project_name="langgraph-agent", framework=Framework.LANGGRAPH) + instrument(project_name="langgraph-agent-demo", framework=Framework.LANGGRAPH) launch_app() diff --git a/examples/agent_framework_comparison/li_workflow/main.py b/examples/agent_framework_comparison/li_workflow/main.py index 8dd705f169..1d670c094a 100644 --- a/examples/agent_framework_comparison/li_workflow/main.py +++ b/examples/agent_framework_comparison/li_workflow/main.py @@ -17,7 +17,7 @@ async def gradio_interface(message, history): def launch_app(): - iface = gr.ChatInterface(fn=gradio_interface, title="Data Analyst Agent") + iface = gr.ChatInterface(fn=gradio_interface, title="LlamaIndex Workflow Agent") iface.launch() diff --git a/examples/agent_framework_comparison/prompt_templates/router_template.py b/examples/agent_framework_comparison/prompt_templates/router_template.py index f2d2a7e078..04dacec91b 100644 --- a/examples/agent_framework_comparison/prompt_templates/router_template.py +++ b/examples/agent_framework_comparison/prompt_templates/router_template.py @@ -1,7 +1,8 @@ SYSTEM_PROMPT = """ You are a helpful assistant that choses a tool to call based on the user's request. -All of your responses should be a tool call or text. +All of your responses should be a tool call or text. Only generate tool calls or text. +If you generate a tool call, be sure you include the original prompt as is in the parameters. Once you receive the results from all of your skills, generate a response to the user that incorporates all of the results. diff --git a/examples/agent_framework_comparison/utils/save_agent_traces.py b/examples/agent_framework_comparison/utils/save_agent_traces.py index 1c6ca95380..81d5d8a712 100644 --- a/examples/agent_framework_comparison/utils/save_agent_traces.py +++ b/examples/agent_framework_comparison/utils/save_agent_traces.py @@ -7,12 +7,13 @@ # This is used primarily to create example sets of traces for each agent. # Likely not needed for most users. def save_agent_traces(project_name: str): - directory = "utils/saved_traces" + directory = "examples/agent_framework_comparison/utils/saved_traces" os.makedirs(directory, exist_ok=True) # Save the Trace Dataset - px.Client().get_trace_dataset(project_name=project_name).save(directory=directory) + traces = px.Client().get_trace_dataset(project_name=project_name) + traces.save(directory=directory) if __name__ == "__main__": - save_agent_traces(project_name="function-calling-agent-demo") + save_agent_traces(project_name="agent-demo")