|
76 | 76 | import org.opensearch.ml.common.output.model.ModelTensorOutput; |
77 | 77 | import org.opensearch.ml.common.output.model.ModelTensors; |
78 | 78 | import org.opensearch.ml.common.spi.tools.Tool; |
| 79 | +import org.opensearch.ml.common.utils.StringUtils; |
79 | 80 | import org.opensearch.ml.engine.MLEngineClassLoader; |
80 | 81 | import org.opensearch.ml.engine.MLStaticMockBase; |
81 | 82 | import org.opensearch.ml.engine.algorithms.remote.McpConnectorExecutor; |
@@ -1665,4 +1666,105 @@ public void testCreateTool_ToolNotFound() { |
1665 | 1666 |
|
1666 | 1667 | assertThrows(IllegalArgumentException.class, () -> AgentUtils.createTool(toolFactories, new HashMap<>(), toolSpec, "test_tenant")); |
1667 | 1668 | } |
| 1669 | + |
| 1670 | + @Test |
| 1671 | + public void testConstructToolParams_ToolInputSubstitution() { |
| 1672 | + String question = "What is the population?"; |
| 1673 | + String actionInput = "{\"question\": \"Seattle 2025 population\"}"; |
| 1674 | + |
| 1675 | + Map<String, Object> queryObj = new HashMap<>(); |
| 1676 | + queryObj |
| 1677 | + .put( |
| 1678 | + "query", |
| 1679 | + Map |
| 1680 | + .of( |
| 1681 | + "neural", |
| 1682 | + Map |
| 1683 | + .of( |
| 1684 | + "population_description_embedding", |
| 1685 | + Map.of("query_text", "${parameters.question}", "model_id", "embedding_model_id") |
| 1686 | + ) |
| 1687 | + ) |
| 1688 | + ); |
| 1689 | + queryObj.put("size", 2); |
| 1690 | + queryObj.put("_source", "population_description"); |
| 1691 | + |
| 1692 | + Map<String, Tool> tools = Map.of("SearchIndexTool", tool1); |
| 1693 | + Map<String, MLToolSpec> toolSpecMap = Map |
| 1694 | + .of( |
| 1695 | + "SearchIndexTool", |
| 1696 | + MLToolSpec |
| 1697 | + .builder() |
| 1698 | + .type("SearchIndexTool") |
| 1699 | + .parameters( |
| 1700 | + Map |
| 1701 | + .of( |
| 1702 | + "index", |
| 1703 | + "test_population_data", |
| 1704 | + "input", |
| 1705 | + "{\"index\": \"${parameters.index}\", \"query\": " + StringUtils.toJson(queryObj) + "}" |
| 1706 | + ) |
| 1707 | + ) |
| 1708 | + .build() |
| 1709 | + ); |
| 1710 | + |
| 1711 | + AtomicReference<String> lastActionInput = new AtomicReference<>(); |
| 1712 | + String action = "SearchIndexTool"; |
| 1713 | + |
| 1714 | + // Execute |
| 1715 | + Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput); |
| 1716 | + |
| 1717 | + // Verify |
| 1718 | + Assert.assertTrue(toolParams.get("input").contains("\"query_text\":\"Seattle 2025 population\"")); |
| 1719 | + assertEquals("Seattle 2025 population", toolParams.get("question")); |
| 1720 | + } |
| 1721 | + |
| 1722 | + @Test |
| 1723 | + public void testConstructToolParams_NoInputKey() { |
| 1724 | + // Setup |
| 1725 | + String question = "What is the population?"; |
| 1726 | + String actionInput = "{\"question\": \"Seattle 2025 population\"}"; |
| 1727 | + |
| 1728 | + Map<String, Tool> tools = Map.of("SearchTool", tool1); |
| 1729 | + Map<String, MLToolSpec> toolSpecMap = Map |
| 1730 | + .of("SearchTool", MLToolSpec.builder().type("SearchTool").parameters(Map.of("key1", "value1")).build()); |
| 1731 | + |
| 1732 | + AtomicReference<String> lastActionInput = new AtomicReference<>(); |
| 1733 | + String action = "SearchTool"; |
| 1734 | + |
| 1735 | + // Execute |
| 1736 | + Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput); |
| 1737 | + |
| 1738 | + // Verify - should fall back to actionInput |
| 1739 | + assertEquals(actionInput, toolParams.get("input")); |
| 1740 | + assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT)); |
| 1741 | + } |
| 1742 | + |
| 1743 | + @Test |
| 1744 | + public void testConstructToolParams_InputWithMultipleSubstitutions() { |
| 1745 | + // Setup |
| 1746 | + String question = "What is the population?"; |
| 1747 | + String actionInput = "{\"city\": \"Seattle\", \"year\": \"2025\"}"; |
| 1748 | + |
| 1749 | + Map<String, Tool> tools = Map.of("SearchTool", tool1); |
| 1750 | + Map<String, MLToolSpec> toolSpecMap = Map |
| 1751 | + .of( |
| 1752 | + "SearchTool", |
| 1753 | + MLToolSpec |
| 1754 | + .builder() |
| 1755 | + .type("SearchTool") |
| 1756 | + .parameters(Map.of("input", "Find population of ${parameters.city} in ${parameters.year}")) |
| 1757 | + .build() |
| 1758 | + ); |
| 1759 | + |
| 1760 | + AtomicReference<String> lastActionInput = new AtomicReference<>(); |
| 1761 | + String action = "SearchTool"; |
| 1762 | + |
| 1763 | + // Execute |
| 1764 | + Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput); |
| 1765 | + |
| 1766 | + // Verify |
| 1767 | + assertEquals("Find population of Seattle in 2025", toolParams.get("input")); |
| 1768 | + assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT)); |
| 1769 | + } |
1668 | 1770 | } |
0 commit comments