Description
What happened?
I followed the directions at https://dspy.ai/tutorials/observability/ and noticed that the model repeatedly fails to call `Retrieve`, and the returned error doesn't help it recover:
{
"thought_0": "I need to find out which baseball team Shohei Ohtani plays for. I can use the Retrieve tool to get this information.",
"tool_name_0": "Retrieve",
"tool_args_0": {
"args": {
"query": "Which baseball team does Shohei Ohtani play for?"
}
},
"observation_0": "Execution error in Retrieve: \nTraceback (most recent call last):\n File \"/Users/ribrdb/Developer/dspy-test/.conda/lib/python3.11/site-packages/dspy/predict/react.py\", line 89, in forward\n trajectory[f\"observation_{idx}\"] = self.tools[pred.next_tool_name](**pred.next_tool_args)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/Users/ribrdb/Developer/dspy-test/.conda/lib/python3.11/site-packages/dspy/utils/callback.py\", line 343, in sync_wrapper\n raise exception\n File \"/Users/ribrdb/Developer/dspy-test/.conda/lib/python3.11/site-packages/dspy/utils/callback.py\", line 339, in sync_wrapper\n results = fn(instance, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/Users/ribrdb/Developer/dspy-test/.conda/lib/python3.11/site-packages/dspy/adapters/types/tool.py\", line 166, in __call__\n result = self.func(**parsed_kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/Users/ribrdb/Developer/dspy-test/.conda/lib/python3.11/site-packages/dspy/utils/callback.py\", line 343, in sync_wrapper\n raise exception\nTypeError: Retrieve.forward() missing 1 required positional argument: 'query'",
"thought_1": "The Retrieve tool failed. I need to provide the query argument when calling the tool.",
"tool_name_1": "Retrieve",
"tool_args_1": {
"args": {
"query": "Which baseball team does Shohei Ohtani play for?"
}
},
"observation_1": "Execution error in Retrieve: \nTraceback (most recent call last):\n File \"/Users/ribrdb/Developer/dspy-test/.conda/lib/python3.11/site-packages/dspy/predict/react.py\", line 89, in forward\n trajectory[f\"observation_{idx}\"] = self.tools[pred.next_tool_name](**pred.next_tool_args)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/Users/ribrdb/Developer/dspy-test/.conda/lib/python3.11/site-packages/dspy/utils/callback.py\", line 343, in sync_wrapper\n raise exception\n File \"/Users/ribrdb/Developer/dspy-test/.conda/lib/python3.11/site-packages/dspy/utils/callback.py\", line 339, in sync_wrapper\n results = fn(instance, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/Users/ribrdb/Developer/dspy-test/.conda/lib/python3.11/site-packages/dspy/adapters/types/tool.py\", line 166, in __call__\n result = self.func(**parsed_kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/Users/ribrdb/Developer/dspy-test/.conda/lib/python3.11/site-packages/dspy/utils/callback.py\", line 343, in sync_wrapper\n raise exception\nTypeError: Retrieve.forward() missing 1 required positional argument: 'query'",
"thought_2": "The Retrieve tool is failing repeatedly. I will try using the finish tool to signal that I have finished the task.",
"tool_name_2": "finish",
"tool_args_2": {},
"observation_2": "Completed."
}
The model thinks it is already including `query`, so an error saying `query` is missing doesn't help it fix the call.
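Here is a minimal, self-contained sketch of why the call fails. `RetrieveLike` is a hypothetical stand-in for `dspy.Retrieve`, and the sketch calls it directly rather than through DSPy's `Tool` wrapper. It assumes, as the traceback suggests, that `forward` takes a required `query` plus a catch-all `**kwargs`: the nested `args` dict cannot bind to `*args` by name, so it lands in `**kwargs` and `query` is never supplied.

```python
class RetrieveLike:
    """Hypothetical stand-in mirroring the untyped __call__ of dspy.Retrieve."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(self, query: str, k: int = 1, **kwargs):
        # Assumed shape: a required `query` plus a catch-all **kwargs.
        return [f"passage about {query!r}"][:k]


# The tool arguments the model produced in the trajectory above.
tool_args = {"args": {"query": "Which baseball team does Shohei Ohtani play for?"}}

error_message = None
try:
    # `args=` can't fill *args by keyword, so it goes into **kwargs and then
    # into forward()'s **kwargs, leaving `query` unset.
    RetrieveLike()(**tool_args)
except TypeError as err:
    error_message = str(err)

print(error_message)  # ...missing 1 required positional argument: 'query'

# Unwrapping the nested dict supplies `query` by name, and the call succeeds.
print(RetrieveLike()(**tool_args["args"]))
```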
It seems like `Tool` inspects the `__call__` method, but `Retrieve.__call__` has no named parameters or type annotations:
```python
def __call__(self, *args, **kwargs):
    return self.forward(*args, **kwargs)
```
This produces an uninformative tool description (`args` and `kwargs` with empty schemas), which leads the model to think `args` should be a dict rather than a list:
When selecting the next_tool_name and its next_tool_args, the tool must be one of:
(1) Retrieve. It takes arguments {'args': {}, 'kwargs': {}}.
(2) finish, whose description is <desc>Marks the task as complete. That is, signals that all information for producing the outputs, i.e. `answer`, are now available to be extracted.</desc>. It takes arguments {}.
When providing `next_tool_args`, the value inside the field must be in JSON format
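To show how an untyped `*args, **kwargs` signature can turn into the `{'args': {}, 'kwargs': {}}` description above, here is a sketch of a naive signature-based schema builder. `naive_arg_schema` and `RetrieveLike` are hypothetical, and the sketch assumes `Tool` works roughly by reading `inspect.signature` and the parameter annotations; I haven't verified that against DSPy's source.

```python
import inspect
from typing import Any, Callable, Dict


class RetrieveLike:
    """Hypothetical stand-in with the same untyped __call__ as dspy.Retrieve."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(self, query: str, k: int = 1, **kwargs):
        return [f"passage about {query!r}"][:k]


# Hypothetical mapping from Python annotations to JSON-schema type names.
_JSON_TYPES = {str: "string", int: "integer", float: "number", bool: "boolean"}


def naive_arg_schema(fn: Callable) -> Dict[str, Dict[str, Any]]:
    """One entry per parameter; anything without a known annotation becomes {}."""
    schema: Dict[str, Dict[str, Any]] = {}
    for name, param in inspect.signature(fn).parameters.items():
        json_type = _JSON_TYPES.get(param.annotation)
        schema[name] = {"type": json_type} if json_type else {}
    return schema


print(naive_arg_schema(RetrieveLike().__call__))  # {'args': {}, 'kwargs': {}}
print(naive_arg_schema(RetrieveLike().forward))
# {'query': {'type': 'string'}, 'k': {'type': 'integer'}, 'kwargs': {}}
```

The same object yields a useful description only when the typed `forward` is the method being inspected.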
Maybe `Retrieve` should have the actual implementation in `__call__`, with `forward` delegating to `__call__`? That way `Tool` could find updated types for `Retrieve` subclasses too. Or should `Tool` directly support inspecting the `forward` method?
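A rough sketch of the second option, assuming the change would live wherever `Tool` chooses which callable to inspect: if `__call__` only accepts `*args`/`**kwargs` and the object has a `forward` method, describe `forward` instead. `signature_source` and `RetrieveLike` are hypothetical names, not DSPy APIs.

```python
import inspect
from typing import Callable

_VAR_KINDS = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)


class RetrieveLike:
    """Hypothetical stand-in with the same untyped __call__ as dspy.Retrieve."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(self, query: str, k: int = 1, **kwargs):
        return [f"passage about {query!r}"][:k]


def signature_source(obj) -> Callable:
    """Hypothetical: choose the callable whose signature should describe `obj`.

    Plain functions are used as-is. For callable objects whose __call__ only
    accepts *args/**kwargs, fall back to a `forward` method if one exists.
    """
    if inspect.isroutine(obj):
        return obj
    call = obj.__call__
    params = inspect.signature(call).parameters.values()
    forward = getattr(obj, "forward", None)
    if all(p.kind in _VAR_KINDS for p in params) and callable(forward):
        return forward
    return call


print(list(inspect.signature(signature_source(RetrieveLike())).parameters))
# ['query', 'k', 'kwargs']
```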
Maybe `Tool` should have special handling for `*args` and `**kwargs`, at least describing `*args` as a list.
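Special handling would probably need two parts: the description (`*args` as an array, `**kwargs` as an object) and the dispatch, since a value under the key `args` has to be splatted positionally rather than passed as `args=...`. A sketch with hypothetical helpers `describe_params` and `call_with_var_args`, not DSPy code:

```python
import inspect
from typing import Any, Callable, Dict, List

_P = inspect.Parameter


def describe_params(fn: Callable) -> Dict[str, Dict[str, Any]]:
    """Hypothetical: describe *args as a JSON array and **kwargs as a JSON object."""
    schema: Dict[str, Dict[str, Any]] = {}
    for name, param in inspect.signature(fn).parameters.items():
        if param.kind == _P.VAR_POSITIONAL:
            schema[name] = {"type": "array"}
        elif param.kind == _P.VAR_KEYWORD:
            schema[name] = {"type": "object"}
        else:
            schema[name] = {}  # a real implementation would also use annotations here
    return schema


def call_with_var_args(fn: Callable, tool_args: Dict[str, Any]) -> Any:
    """Hypothetical dispatcher: splat the *args list positionally, merge **kwargs by name."""
    params = inspect.signature(fn).parameters
    positional: List[Any] = []
    keyword: Dict[str, Any] = {}
    for name, value in tool_args.items():
        kind = params[name].kind if name in params else None
        if kind == _P.VAR_POSITIONAL:
            positional.extend(value)
        elif kind == _P.VAR_KEYWORD:
            keyword.update(value)
        else:
            keyword[name] = value
    return fn(*positional, **keyword)


def untyped_call(*args, **kwargs):
    """Same shape as the Retrieve.__call__ quoted above; echoes what it received."""
    return args, kwargs


print(describe_params(untyped_call))  # {'args': {'type': 'array'}, 'kwargs': {'type': 'object'}}
print(call_with_var_args(untyped_call, {"args": ["Shohei Ohtani team"], "kwargs": {"k": 1}}))
# (('Shohei Ohtani team',), {'k': 1})
```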
Should this tutorial even be using `Retrieve`? It looks like it could use `colbert` as a tool directly, which would give a proper tool signature. Also, there's no documentation anywhere for `Retrieve`, not even in the API reference, so I'm not sure what purpose introducing it serves in a debugging/observability tutorial.
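For comparison, a sketch of using ColBERTv2 directly through a small typed function. The name `search_wikipedia` and the `k` parameter are my own choices, and the result shape (dict-like hits with a `"text"` field) is an assumption about the ColBERTv2 client's output. DSPy is imported inside the function only so that the signature can be inspected without DSPy installed.

```python
import inspect


def search_wikipedia(query: str, k: int = 1) -> list[str]:
    """Return the text of the top-k Wikipedia abstracts matching `query`."""
    import dspy  # imported lazily so this sketch can be inspected without DSPy installed

    colbert = dspy.ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts")
    # Assumes each hit is a dict-like object with a "text" field.
    return [hit["text"] for hit in colbert(query, k=k)]


# The signature a tool wrapper would see: named, typed parameters.
print(inspect.signature(search_wikipedia))
```

Passing `tools=[search_wikipedia]` to `dspy.ReAct` should then give the model a typed `query` argument to fill in instead of `args`/`kwargs`, though I haven't checked the exact prompt text this produces.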
I also noticed some other strange things in the prompt about `next_tool_args`. In the system prompt, `next_tool_args` is first described as a Python dict, then there's a (questionably useful) JSON schema, and then it says the value inside `next_tool_args` must be JSON. The per-turn instructions then refer to it as a Python dict again. Switching back and forth between JSON and Python descriptions could confuse some LLMs.
system prompt:
5. next_tool_args (dict[str, Any]): All interactions will be structured in the following way, with the appropriate values filled in
...
[[ ## next_tool_args ## ]] {next_tool_args} # note: the value you produce must adhere to the JSON schema: {"type": "object", "additionalProperties": true}
...
When providing `next_tool_args`, the value inside the field must be in JSON format
user prompt trailer:
Respond with the corresponding output fields, starting with the field [[ ## next_thought ## ]], then [[ ## next_tool_name ## ]] (must be formatted as a valid Python Literal['Retrieve', 'finish']), then [[ ## next_tool_args ## ]] (must be formatted as a valid Python dict[str, Any]), and then ending with the marker for [[ ## completed ## ]].
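A small illustration of why the mixed wording matters: the same arguments written as a Python dict literal are not valid JSON (single quotes, `True` instead of `true`), and each strict parser rejects the other notation. The `verbose` key is made up to show the boolean difference; it is not a real `Retrieve` argument.

```python
import ast
import json

# Same arguments in two notations; `verbose` is a made-up key for illustration.
json_style = '{"query": "Shohei Ohtani team", "verbose": true}'
python_style = "{'query': 'Shohei Ohtani team', 'verbose': True}"


def parses_as_json(text: str) -> bool:
    try:
        json.loads(text)
        return True
    except json.JSONDecodeError:
        return False


def parses_as_python_literal(text: str) -> bool:
    try:
        ast.literal_eval(text)
        return True
    except (ValueError, SyntaxError):
        return False


# Each notation parses with its own parser and yields the same dict...
print(json.loads(json_style) == ast.literal_eval(python_style))  # True
# ...but cross-parsing fails in both directions.
print(parses_as_json(python_style), parses_as_python_literal(json_style))  # False False
```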
Steps to reproduce
https://dspy.ai/tutorials/observability/
```python
import dspy
from dspy.datasets import HotPotQA
lm = dspy.LM('openai/gpt-4o-mini')
colbert = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')
dspy.configure(lm=lm, rm=colbert)
agent = dspy.ReAct("question -> answer", tools=[dspy.Retrieve(k=1)])
prediction = agent(question="Which baseball team does Shohei Ohtani play for?")
print(prediction.answer)
```
DSPy version
2.6.27