Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -940,7 +940,14 @@ public static Map<String, String> constructToolParams(
toolParams.putAll(params);
}
} else {
toolParams.put("input", actionInput);
if (toolParams.containsKey("input")) {
String input = toolParams.get("input");
StringSubstitutor substitutor = new StringSubstitutor(toolParams, "${parameters.", "}");
input = substitutor.replace(input);
toolParams.put("input", input);
} else {
toolParams.put("input", actionInput);
}
}
return toolParams;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.MLEngineClassLoader;
import org.opensearch.ml.engine.MLStaticMockBase;
import org.opensearch.ml.engine.algorithms.remote.McpConnectorExecutor;
Expand Down Expand Up @@ -1692,4 +1693,105 @@ public void testGetCurrentDateTime_WithValidFormat() {
Assert.assertTrue(result.startsWith(DEFAULT_DATETIME_PREFIX));
Assert.assertTrue(result.contains("UTC"));
}

@Test
public void testConstructToolParams_ToolInputSubstitution() {
String question = "What is the population?";
String actionInput = "{\"question\": \"Seattle 2025 population\"}";

Map<String, Object> queryObj = new HashMap<>();
queryObj
.put(
"query",
Map
.of(
"neural",
Map
.of(
"population_description_embedding",
Map.of("query_text", "${parameters.question}", "model_id", "embedding_model_id")
)
)
);
queryObj.put("size", 2);
queryObj.put("_source", "population_description");

Map<String, Tool> tools = Map.of("SearchIndexTool", tool1);
Map<String, MLToolSpec> toolSpecMap = Map
.of(
"SearchIndexTool",
MLToolSpec
.builder()
.type("SearchIndexTool")
.parameters(
Map
.of(
"index",
"test_population_data",
"input",
"{\"index\": \"${parameters.index}\", \"query\": " + StringUtils.toJson(queryObj) + "}"
)
)
.build()
);

AtomicReference<String> lastActionInput = new AtomicReference<>();
String action = "SearchIndexTool";

// Execute
Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput);

// Verify
Assert.assertTrue(toolParams.get("input").contains("\"query_text\":\"Seattle 2025 population\""));
assertEquals("Seattle 2025 population", toolParams.get("question"));
}

@Test
public void testConstructToolParams_NoInputKey() {
// Setup
String question = "What is the population?";
String actionInput = "{\"question\": \"Seattle 2025 population\"}";

Map<String, Tool> tools = Map.of("SearchTool", tool1);
Map<String, MLToolSpec> toolSpecMap = Map
.of("SearchTool", MLToolSpec.builder().type("SearchTool").parameters(Map.of("key1", "value1")).build());

AtomicReference<String> lastActionInput = new AtomicReference<>();
String action = "SearchTool";

// Execute
Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput);

// Verify - should fall back to actionInput
assertEquals(actionInput, toolParams.get("input"));
assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT));
}

@Test
public void testConstructToolParams_InputWithMultipleSubstitutions() {
// Setup
String question = "What is the population?";
String actionInput = "{\"city\": \"Seattle\", \"year\": \"2025\"}";

Map<String, Tool> tools = Map.of("SearchTool", tool1);
Map<String, MLToolSpec> toolSpecMap = Map
.of(
"SearchTool",
MLToolSpec
.builder()
.type("SearchTool")
.parameters(Map.of("input", "Find population of ${parameters.city} in ${parameters.year}"))
.build()
);

AtomicReference<String> lastActionInput = new AtomicReference<>();
String action = "SearchTool";

// Execute
Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput);

// Verify
assertEquals("Find population of Seattle in 2025", toolParams.get("input"));
assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT));
}
}
Loading