|
77 | 77 | import org.opensearch.ml.common.output.model.ModelTensorOutput; |
78 | 78 | import org.opensearch.ml.common.output.model.ModelTensors; |
79 | 79 | import org.opensearch.ml.common.spi.tools.Tool; |
| 80 | +import org.opensearch.ml.common.utils.StringUtils; |
80 | 81 | import org.opensearch.ml.engine.MLEngineClassLoader; |
81 | 82 | import org.opensearch.ml.engine.MLStaticMockBase; |
82 | 83 | import org.opensearch.ml.engine.algorithms.remote.McpConnectorExecutor; |
@@ -1692,4 +1693,105 @@ public void testGetCurrentDateTime_WithValidFormat() { |
1692 | 1693 | Assert.assertTrue(result.startsWith(DEFAULT_DATETIME_PREFIX)); |
1693 | 1694 | Assert.assertTrue(result.contains("UTC")); |
1694 | 1695 | } |
| 1696 | + |
| 1697 | + @Test |
| 1698 | + public void testConstructToolParams_ToolInputSubstitution() { |
| 1699 | + String question = "What is the population?"; |
| 1700 | + String actionInput = "{\"question\": \"Seattle 2025 population\"}"; |
| 1701 | + |
| 1702 | + Map<String, Object> queryObj = new HashMap<>(); |
| 1703 | + queryObj |
| 1704 | + .put( |
| 1705 | + "query", |
| 1706 | + Map |
| 1707 | + .of( |
| 1708 | + "neural", |
| 1709 | + Map |
| 1710 | + .of( |
| 1711 | + "population_description_embedding", |
| 1712 | + Map.of("query_text", "${parameters.question}", "model_id", "embedding_model_id") |
| 1713 | + ) |
| 1714 | + ) |
| 1715 | + ); |
| 1716 | + queryObj.put("size", 2); |
| 1717 | + queryObj.put("_source", "population_description"); |
| 1718 | + |
| 1719 | + Map<String, Tool> tools = Map.of("SearchIndexTool", tool1); |
| 1720 | + Map<String, MLToolSpec> toolSpecMap = Map |
| 1721 | + .of( |
| 1722 | + "SearchIndexTool", |
| 1723 | + MLToolSpec |
| 1724 | + .builder() |
| 1725 | + .type("SearchIndexTool") |
| 1726 | + .parameters( |
| 1727 | + Map |
| 1728 | + .of( |
| 1729 | + "index", |
| 1730 | + "test_population_data", |
| 1731 | + "input", |
| 1732 | + "{\"index\": \"${parameters.index}\", \"query\": " + StringUtils.toJson(queryObj) + "}" |
| 1733 | + ) |
| 1734 | + ) |
| 1735 | + .build() |
| 1736 | + ); |
| 1737 | + |
| 1738 | + AtomicReference<String> lastActionInput = new AtomicReference<>(); |
| 1739 | + String action = "SearchIndexTool"; |
| 1740 | + |
| 1741 | + // Execute |
| 1742 | + Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput); |
| 1743 | + |
| 1744 | + // Verify |
| 1745 | + Assert.assertTrue(toolParams.get("input").contains("\"query_text\":\"Seattle 2025 population\"")); |
| 1746 | + assertEquals("Seattle 2025 population", toolParams.get("question")); |
| 1747 | + } |
| 1748 | + |
| 1749 | + @Test |
| 1750 | + public void testConstructToolParams_NoInputKey() { |
| 1751 | + // Setup |
| 1752 | + String question = "What is the population?"; |
| 1753 | + String actionInput = "{\"question\": \"Seattle 2025 population\"}"; |
| 1754 | + |
| 1755 | + Map<String, Tool> tools = Map.of("SearchTool", tool1); |
| 1756 | + Map<String, MLToolSpec> toolSpecMap = Map |
| 1757 | + .of("SearchTool", MLToolSpec.builder().type("SearchTool").parameters(Map.of("key1", "value1")).build()); |
| 1758 | + |
| 1759 | + AtomicReference<String> lastActionInput = new AtomicReference<>(); |
| 1760 | + String action = "SearchTool"; |
| 1761 | + |
| 1762 | + // Execute |
| 1763 | + Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput); |
| 1764 | + |
| 1765 | + // Verify - should fall back to actionInput |
| 1766 | + assertEquals(actionInput, toolParams.get("input")); |
| 1767 | + assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT)); |
| 1768 | + } |
| 1769 | + |
| 1770 | + @Test |
| 1771 | + public void testConstructToolParams_InputWithMultipleSubstitutions() { |
| 1772 | + // Setup |
| 1773 | + String question = "What is the population?"; |
| 1774 | + String actionInput = "{\"city\": \"Seattle\", \"year\": \"2025\"}"; |
| 1775 | + |
| 1776 | + Map<String, Tool> tools = Map.of("SearchTool", tool1); |
| 1777 | + Map<String, MLToolSpec> toolSpecMap = Map |
| 1778 | + .of( |
| 1779 | + "SearchTool", |
| 1780 | + MLToolSpec |
| 1781 | + .builder() |
| 1782 | + .type("SearchTool") |
| 1783 | + .parameters(Map.of("input", "Find population of ${parameters.city} in ${parameters.year}")) |
| 1784 | + .build() |
| 1785 | + ); |
| 1786 | + |
| 1787 | + AtomicReference<String> lastActionInput = new AtomicReference<>(); |
| 1788 | + String action = "SearchTool"; |
| 1789 | + |
| 1790 | + // Execute |
| 1791 | + Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput); |
| 1792 | + |
| 1793 | + // Verify |
| 1794 | + assertEquals("Find population of Seattle in 2025", toolParams.get("input")); |
| 1795 | + assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT)); |
| 1796 | + } |
1695 | 1797 | } |
0 commit comments