Skip to content

Commit 443a405

Browse files
authored
enhance tool input parsing and add agentic rag tutorial (opensearch-project#4023)
* enhance tool input parsing and add agentic rag tutorial Signed-off-by: Yaliang Wu <ylwu@amazon.com> * remove tutorial; add unit test Signed-off-by: Yaliang Wu <ylwu@amazon.com> * fix unit test Signed-off-by: Yaliang Wu <ylwu@amazon.com> --------- Signed-off-by: Yaliang Wu <ylwu@amazon.com>
1 parent 83c0dcd commit 443a405

File tree

2 files changed

+110
-1
lines changed

2 files changed

+110
-1
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -940,7 +940,14 @@ public static Map<String, String> constructToolParams(
940940
toolParams.putAll(params);
941941
}
942942
} else {
943-
toolParams.put("input", actionInput);
943+
if (toolParams.containsKey("input")) {
944+
String input = toolParams.get("input");
945+
StringSubstitutor substitutor = new StringSubstitutor(toolParams, "${parameters.", "}");
946+
input = substitutor.replace(input);
947+
toolParams.put("input", input);
948+
} else {
949+
toolParams.put("input", actionInput);
950+
}
944951
}
945952
return toolParams;
946953
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
import org.opensearch.ml.common.output.model.ModelTensorOutput;
7878
import org.opensearch.ml.common.output.model.ModelTensors;
7979
import org.opensearch.ml.common.spi.tools.Tool;
80+
import org.opensearch.ml.common.utils.StringUtils;
8081
import org.opensearch.ml.engine.MLEngineClassLoader;
8182
import org.opensearch.ml.engine.MLStaticMockBase;
8283
import org.opensearch.ml.engine.algorithms.remote.McpConnectorExecutor;
@@ -1692,4 +1693,105 @@ public void testGetCurrentDateTime_WithValidFormat() {
16921693
Assert.assertTrue(result.startsWith(DEFAULT_DATETIME_PREFIX));
16931694
Assert.assertTrue(result.contains("UTC"));
16941695
}
1696+
1697+
@Test
1698+
public void testConstructToolParams_ToolInputSubstitution() {
1699+
String question = "What is the population?";
1700+
String actionInput = "{\"question\": \"Seattle 2025 population\"}";
1701+
1702+
Map<String, Object> queryObj = new HashMap<>();
1703+
queryObj
1704+
.put(
1705+
"query",
1706+
Map
1707+
.of(
1708+
"neural",
1709+
Map
1710+
.of(
1711+
"population_description_embedding",
1712+
Map.of("query_text", "${parameters.question}", "model_id", "embedding_model_id")
1713+
)
1714+
)
1715+
);
1716+
queryObj.put("size", 2);
1717+
queryObj.put("_source", "population_description");
1718+
1719+
Map<String, Tool> tools = Map.of("SearchIndexTool", tool1);
1720+
Map<String, MLToolSpec> toolSpecMap = Map
1721+
.of(
1722+
"SearchIndexTool",
1723+
MLToolSpec
1724+
.builder()
1725+
.type("SearchIndexTool")
1726+
.parameters(
1727+
Map
1728+
.of(
1729+
"index",
1730+
"test_population_data",
1731+
"input",
1732+
"{\"index\": \"${parameters.index}\", \"query\": " + StringUtils.toJson(queryObj) + "}"
1733+
)
1734+
)
1735+
.build()
1736+
);
1737+
1738+
AtomicReference<String> lastActionInput = new AtomicReference<>();
1739+
String action = "SearchIndexTool";
1740+
1741+
// Execute
1742+
Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput);
1743+
1744+
// Verify
1745+
Assert.assertTrue(toolParams.get("input").contains("\"query_text\":\"Seattle 2025 population\""));
1746+
assertEquals("Seattle 2025 population", toolParams.get("question"));
1747+
}
1748+
1749+
@Test
1750+
public void testConstructToolParams_NoInputKey() {
1751+
// Setup
1752+
String question = "What is the population?";
1753+
String actionInput = "{\"question\": \"Seattle 2025 population\"}";
1754+
1755+
Map<String, Tool> tools = Map.of("SearchTool", tool1);
1756+
Map<String, MLToolSpec> toolSpecMap = Map
1757+
.of("SearchTool", MLToolSpec.builder().type("SearchTool").parameters(Map.of("key1", "value1")).build());
1758+
1759+
AtomicReference<String> lastActionInput = new AtomicReference<>();
1760+
String action = "SearchTool";
1761+
1762+
// Execute
1763+
Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput);
1764+
1765+
// Verify - should fall back to actionInput
1766+
assertEquals(actionInput, toolParams.get("input"));
1767+
assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT));
1768+
}
1769+
1770+
@Test
1771+
public void testConstructToolParams_InputWithMultipleSubstitutions() {
1772+
// Setup
1773+
String question = "What is the population?";
1774+
String actionInput = "{\"city\": \"Seattle\", \"year\": \"2025\"}";
1775+
1776+
Map<String, Tool> tools = Map.of("SearchTool", tool1);
1777+
Map<String, MLToolSpec> toolSpecMap = Map
1778+
.of(
1779+
"SearchTool",
1780+
MLToolSpec
1781+
.builder()
1782+
.type("SearchTool")
1783+
.parameters(Map.of("input", "Find population of ${parameters.city} in ${parameters.year}"))
1784+
.build()
1785+
);
1786+
1787+
AtomicReference<String> lastActionInput = new AtomicReference<>();
1788+
String action = "SearchTool";
1789+
1790+
// Execute
1791+
Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput);
1792+
1793+
// Verify
1794+
assertEquals("Find population of Seattle in 2025", toolParams.get("input"));
1795+
assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT));
1796+
}
16951797
}

0 commit comments

Comments
 (0)