Skip to content

Commit 305ff65

Browse files
committed
enhance tool input parsing and add agentic rag tutorial (opensearch-project#4023)
* enhance tool input parsing and add agentic rag tutorial Signed-off-by: Yaliang Wu <ylwu@amazon.com> * remove tutorial; add unit test Signed-off-by: Yaliang Wu <ylwu@amazon.com> * fix unit test Signed-off-by: Yaliang Wu <ylwu@amazon.com> --------- Signed-off-by: Yaliang Wu <ylwu@amazon.com>
1 parent 077a177 commit 305ff65

File tree

2 files changed

+110
-1
lines changed

2 files changed

+110
-1
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -932,7 +932,14 @@ public static Map<String, String> constructToolParams(
932932
toolParams.putAll(params);
933933
}
934934
} else {
935-
toolParams.put("input", actionInput);
935+
if (toolParams.containsKey("input")) {
936+
String input = toolParams.get("input");
937+
StringSubstitutor substitutor = new StringSubstitutor(toolParams, "${parameters.", "}");
938+
input = substitutor.replace(input);
939+
toolParams.put("input", input);
940+
} else {
941+
toolParams.put("input", actionInput);
942+
}
936943
}
937944
return toolParams;
938945
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
import org.opensearch.ml.common.output.model.ModelTensorOutput;
7777
import org.opensearch.ml.common.output.model.ModelTensors;
7878
import org.opensearch.ml.common.spi.tools.Tool;
79+
import org.opensearch.ml.common.utils.StringUtils;
7980
import org.opensearch.ml.engine.MLEngineClassLoader;
8081
import org.opensearch.ml.engine.MLStaticMockBase;
8182
import org.opensearch.ml.engine.algorithms.remote.McpConnectorExecutor;
@@ -1665,4 +1666,105 @@ public void testCreateTool_ToolNotFound() {
16651666

16661667
assertThrows(IllegalArgumentException.class, () -> AgentUtils.createTool(toolFactories, new HashMap<>(), toolSpec, "test_tenant"));
16671668
}
1669+
1670+
@Test
1671+
public void testConstructToolParams_ToolInputSubstitution() {
1672+
String question = "What is the population?";
1673+
String actionInput = "{\"question\": \"Seattle 2025 population\"}";
1674+
1675+
Map<String, Object> queryObj = new HashMap<>();
1676+
queryObj
1677+
.put(
1678+
"query",
1679+
Map
1680+
.of(
1681+
"neural",
1682+
Map
1683+
.of(
1684+
"population_description_embedding",
1685+
Map.of("query_text", "${parameters.question}", "model_id", "embedding_model_id")
1686+
)
1687+
)
1688+
);
1689+
queryObj.put("size", 2);
1690+
queryObj.put("_source", "population_description");
1691+
1692+
Map<String, Tool> tools = Map.of("SearchIndexTool", tool1);
1693+
Map<String, MLToolSpec> toolSpecMap = Map
1694+
.of(
1695+
"SearchIndexTool",
1696+
MLToolSpec
1697+
.builder()
1698+
.type("SearchIndexTool")
1699+
.parameters(
1700+
Map
1701+
.of(
1702+
"index",
1703+
"test_population_data",
1704+
"input",
1705+
"{\"index\": \"${parameters.index}\", \"query\": " + StringUtils.toJson(queryObj) + "}"
1706+
)
1707+
)
1708+
.build()
1709+
);
1710+
1711+
AtomicReference<String> lastActionInput = new AtomicReference<>();
1712+
String action = "SearchIndexTool";
1713+
1714+
// Execute
1715+
Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput);
1716+
1717+
// Verify
1718+
Assert.assertTrue(toolParams.get("input").contains("\"query_text\":\"Seattle 2025 population\""));
1719+
assertEquals("Seattle 2025 population", toolParams.get("question"));
1720+
}
1721+
1722+
@Test
1723+
public void testConstructToolParams_NoInputKey() {
1724+
// Setup
1725+
String question = "What is the population?";
1726+
String actionInput = "{\"question\": \"Seattle 2025 population\"}";
1727+
1728+
Map<String, Tool> tools = Map.of("SearchTool", tool1);
1729+
Map<String, MLToolSpec> toolSpecMap = Map
1730+
.of("SearchTool", MLToolSpec.builder().type("SearchTool").parameters(Map.of("key1", "value1")).build());
1731+
1732+
AtomicReference<String> lastActionInput = new AtomicReference<>();
1733+
String action = "SearchTool";
1734+
1735+
// Execute
1736+
Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput);
1737+
1738+
// Verify - should fall back to actionInput
1739+
assertEquals(actionInput, toolParams.get("input"));
1740+
assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT));
1741+
}
1742+
1743+
@Test
1744+
public void testConstructToolParams_InputWithMultipleSubstitutions() {
1745+
// Setup
1746+
String question = "What is the population?";
1747+
String actionInput = "{\"city\": \"Seattle\", \"year\": \"2025\"}";
1748+
1749+
Map<String, Tool> tools = Map.of("SearchTool", tool1);
1750+
Map<String, MLToolSpec> toolSpecMap = Map
1751+
.of(
1752+
"SearchTool",
1753+
MLToolSpec
1754+
.builder()
1755+
.type("SearchTool")
1756+
.parameters(Map.of("input", "Find population of ${parameters.city} in ${parameters.year}"))
1757+
.build()
1758+
);
1759+
1760+
AtomicReference<String> lastActionInput = new AtomicReference<>();
1761+
String action = "SearchTool";
1762+
1763+
// Execute
1764+
Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput);
1765+
1766+
// Verify
1767+
assertEquals("Find population of Seattle in 2025", toolParams.get("input"));
1768+
assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT));
1769+
}
16681770
}

0 commit comments

Comments
 (0)