Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ public static Map<String, String> extractInputParameters(Map<String, String> par
StringSubstitutor stringSubstitutor = new StringSubstitutor(parameters, "${parameters.", "}");
String input = stringSubstitutor.replace(parameters.get("input"));
extractedParameters.put("input", input);
Map<String, String> inputParameters = gson
.fromJson(input, TypeToken.getParameterized(Map.class, String.class, String.class).getType());
extractedParameters.putAll(inputParameters);
Map<String, Object> parsedInputParameters = gson
.fromJson(input, TypeToken.getParameterized(Map.class, String.class, Object.class).getType());
extractedParameters.putAll(StringUtils.getParameterMap(parsedInputParameters));
} catch (Exception exception) {
log.info("fail extract parameters from key 'input' due to" + exception.getMessage());
log.info("Failed to extract parameters from key 'input'. Falling back to raw input string.", exception);
}
}
return extractedParameters;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,4 +280,67 @@ public void testFilterToolOutput_ComplexNestedPath() {
// Should contain only the targeted deep value
assertEquals("targetValue", result);
}

@Test
public void testExtractInputParameters_WithJsonInput() {
Map<String, String> parameters = new HashMap<>();
parameters.put("param1", "value1");
parameters.put("input", "{\"key1\": \"jsonValue1\", \"key2\": \"jsonValue2\"}");

Map<String, Object> attributes = new HashMap<>();

Map<String, String> result = ToolUtils.extractInputParameters(parameters, attributes);

assertEquals(4, result.size());
assertEquals("value1", result.get("param1"));
assertEquals("{\"key1\": \"jsonValue1\", \"key2\": \"jsonValue2\"}", result.get("input"));
assertEquals("jsonValue1", result.get("key1"));
assertEquals("jsonValue2", result.get("key2"));
}

@Test
public void testExtractInputParameters_WithParameterSubstitution() {
Map<String, String> parameters = new HashMap<>();
parameters.put("param1", "substitutedValue");
parameters.put("input", "{\"message\": \"Hello ${parameters.param1}\"}");

Map<String, Object> attributes = new HashMap<>();

Map<String, String> result = ToolUtils.extractInputParameters(parameters, attributes);

assertEquals(3, result.size());
assertEquals("substitutedValue", result.get("param1"));
assertEquals("{\"message\": \"Hello substitutedValue\"}", result.get("input"));
assertEquals("Hello substitutedValue", result.get("message"));
}

@Test
public void testExtractInputParameters_WithInvalidJson() {
Map<String, String> parameters = new HashMap<>();
parameters.put("param1", "value1");
parameters.put("input", "invalid json string");

Map<String, Object> attributes = new HashMap<>();

Map<String, String> result = ToolUtils.extractInputParameters(parameters, attributes);

assertEquals(2, result.size());
assertEquals("value1", result.get("param1"));
assertEquals("invalid json string", result.get("input"));
}

@Test
public void testExtractInputParameters_NoInputParameter() {
Map<String, String> parameters = new HashMap<>();
parameters.put("param1", "value1");
parameters.put("param2", "value2");

Map<String, Object> attributes = new HashMap<>();

Map<String, String> result = ToolUtils.extractInputParameters(parameters, attributes);

assertEquals(2, result.size());
assertEquals("value1", result.get("param1"));
assertEquals("value2", result.get("param2"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,11 @@ public static Map<String, String> parseLLMOutput(
if (functionCalling != null) {
toolCalls = functionCalling.handle(tmpModelTensorOutput, parameters);
// TODO: support multiple tool calls here
toolName = toolCalls.getFirst().get("tool_name");
toolInput = toolCalls.getFirst().get("tool_input");
toolCallId = toolCalls.getFirst().get("tool_call_id");
if (!toolCalls.isEmpty()) {
toolName = toolCalls.getFirst().get("tool_name");
toolInput = toolCalls.getFirst().get("tool_input");
toolCallId = toolCalls.getFirst().get("tool_call_id");
}
} else {
String toolCallsPath = parameters.get(TOOL_CALLS_PATH);
if (toolCallsPath.startsWith("_llm_response.")) {
Expand All @@ -372,9 +374,11 @@ public static Map<String, String> parseLLMOutput(
} else {
toolCalls = JsonPath.read(dataAsMap, toolCallsPath);
}
toolName = JsonPath.read(toolCalls.get(0), parameters.get(TOOL_CALLS_TOOL_NAME));
toolInput = StringUtils.toJson(JsonPath.read(toolCalls.get(0), parameters.get(TOOL_CALLS_TOOL_INPUT)));
toolCallId = JsonPath.read(toolCalls.get(0), parameters.get(TOOL_CALL_ID_PATH));
if (!toolCalls.isEmpty()) {
toolName = JsonPath.read(toolCalls.get(0), parameters.get(TOOL_CALLS_TOOL_NAME));
toolInput = StringUtils.toJson(JsonPath.read(toolCalls.get(0), parameters.get(TOOL_CALLS_TOOL_INPUT)));
toolCallId = JsonPath.read(toolCalls.get(0), parameters.get(TOOL_CALL_ID_PATH));
}
}
String toolCallsMsgPath = parameters.get(INTERACTION_TEMPLATE_ASSISTANT_TOOL_CALLS_PATH);
String toolCallsMsgExcludePath = parameters.get(INTERACTION_TEMPLATE_ASSISTANT_TOOL_CALLS_EXCLUDE_PATH);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
Expand Down Expand Up @@ -401,10 +400,10 @@ private void executePlanningLoop(

planListener.whenComplete(llmOutput -> {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) llmOutput.getOutput();
Map<String, String> parseLLMOutput = parseLLMOutput(allParams, modelTensorOutput);
Map<String, Object> parseLLMOutput = parseLLMOutput(allParams, modelTensorOutput);

if (parseLLMOutput.get(RESULT_FIELD) != null) {
String finalResult = parseLLMOutput.get(RESULT_FIELD);
String finalResult = (String) parseLLMOutput.get(RESULT_FIELD);
saveAndReturnFinalResult(
(ConversationIndexMemory) memory,
parentInteractionId,
Expand All @@ -415,8 +414,7 @@ private void executePlanningLoop(
finalListener
);
} else {
// todo: optimize double conversion of steps (string to list to string)
List<String> steps = Arrays.stream(parseLLMOutput.get(STEPS_FIELD).split(", ")).toList();
List<String> steps = (List<String>) parseLLMOutput.get(STEPS_FIELD);
addSteps(steps, allParams, STEPS_FIELD);

String stepToExecute = steps.getFirst();
Expand Down Expand Up @@ -546,8 +544,8 @@ private void executePlanningLoop(
}

@VisibleForTesting
Map<String, String> parseLLMOutput(Map<String, String> allParams, ModelTensorOutput modelTensorOutput) {
Map<String, String> modelOutput = new HashMap<>();
Map<String, Object> parseLLMOutput(Map<String, String> allParams, ModelTensorOutput modelTensorOutput) {
Map<String, Object> modelOutput = new HashMap<>();
Map<String, ?> dataAsMap = modelTensorOutput.getMlModelOutputs().getFirst().getMlModelTensors().getFirst().getDataAsMap();
String llmResponse;
if (dataAsMap.size() == 1 && dataAsMap.containsKey(RESPONSE_FIELD)) {
Expand All @@ -571,7 +569,7 @@ Map<String, String> parseLLMOutput(Map<String, String> allParams, ModelTensorOut

if (parsedJson.containsKey(STEPS_FIELD)) {
List<String> steps = (List<String>) parsedJson.get(STEPS_FIELD);
modelOutput.put(STEPS_FIELD, String.join(", ", steps));
modelOutput.put(STEPS_FIELD, steps);
}

if (parsedJson.containsKey(RESULT_FIELD)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1861,4 +1861,41 @@ public void testCreateTool_WithNullRuntimeResources() {

verify(factory).create(argThat(toolParamsMap -> ((Map<String, Object>) toolParamsMap).get("param1").equals("value1")));
}

@Test
public void testParseLLMOutput_PathNotFoundExceptionWithEmptyToolCalls() {
Map<String, String> parameters = new HashMap<>();
parameters.put(LLM_RESPONSE_FILTER, "$.output.message.content[0].text");
parameters.put(TOOL_CALLS_PATH, "$.output.message.content[*].toolUse");
parameters.put(TOOL_CALLS_TOOL_NAME, "name");
parameters.put(TOOL_CALLS_TOOL_INPUT, "input");
parameters.put(TOOL_CALL_ID_PATH, "toolUseId");
parameters.put(LLM_FINISH_REASON_PATH, "$.stopReason");
parameters.put(LLM_FINISH_REASON_TOOL_USE, "tool_use");

Map<String, Object> dataAsMap = Map
.of("output", Map.of("message", Map.of("content", Collections.emptyList(), "role", "assistant")), "stopReason", "end_turn");

ModelTensorOutput modelTensorOutput = ModelTensorOutput
.builder()
.mlModelOutputs(
List
.of(
ModelTensors
.builder()
.mlModelTensors(List.of(ModelTensor.builder().name("response").dataAsMap(dataAsMap).build()))
.build()
)
)
.build();

Map<String, String> output = AgentUtils
.parseLLMOutput(parameters, modelTensorOutput, null, Set.of("test_tool"), new ArrayList<>(), null);

Assert.assertEquals("", output.get(THOUGHT));
Assert.assertEquals("", output.get(ACTION_INPUT));
Assert.assertEquals("", output.get(TOOL_CALL_ID));
Assert.assertTrue(output.containsKey(FINAL_ANSWER));
Assert.assertTrue(output.get(FINAL_ANSWER).contains("[]"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -525,9 +525,11 @@ public void testParseLLMOutput() {
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
ModelTensorOutput modelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();

Map<String, String> result = mlPlanExecuteAndReflectAgentRunner.parseLLMOutput(allParams, modelTensorOutput);
Map<String, Object> result = mlPlanExecuteAndReflectAgentRunner.parseLLMOutput(allParams, modelTensorOutput);

assertEquals("step1, step2", result.get(MLPlanExecuteAndReflectAgentRunner.STEPS_FIELD));
List<String> expectedSteps = Arrays.asList("step1", "step2");
List<String> actualSteps = (List<String>) result.get(MLPlanExecuteAndReflectAgentRunner.STEPS_FIELD);
assertEquals(expectedSteps, actualSteps);
assertEquals("final result", result.get(MLPlanExecuteAndReflectAgentRunner.RESULT_FIELD));

modelTensor = ModelTensor.builder().dataAsMap(Map.of(MLPlanExecuteAndReflectAgentRunner.RESPONSE_FIELD, "random response")).build();
Expand Down
Loading