|
6 | 6 | package org.opensearch.ml.rest;
|
7 | 7 |
|
8 | 8 | import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED;
|
| 9 | +import static org.opensearch.ml.engine.tools.QueryPlanningTool.GENERATION_TYPE_FIELD; |
9 | 10 | import static org.opensearch.ml.engine.tools.QueryPlanningTool.MODEL_ID_FIELD;
|
| 11 | +import static org.opensearch.ml.engine.tools.QueryPlanningTool.SEARCH_TEMPLATES_FIELD; |
| 12 | +import static org.opensearch.ml.engine.tools.QueryPlanningTool.USER_SEARCH_TEMPLATES_TYPE_FIELD; |
10 | 13 |
|
11 | 14 | import java.io.IOException;
|
12 | 15 | import java.util.List;
|
@@ -95,6 +98,50 @@ public void testAgentWithQueryPlanningTool_DefaultPrompt() throws IOException {
|
95 | 98 | deleteAgent(agentId);
|
96 | 99 | }
|
97 | 100 |
|
| 101 | + @Test |
| 102 | + public void testAgentWithQueryPlanningTool_SearchTemplates() throws IOException { |
| 103 | + if (OPENAI_KEY == null) { |
| 104 | + return; |
| 105 | + } |
| 106 | + |
| 107 | + // Create Search Templates |
| 108 | + String templateBody = "{\"script\":{\"lang\":\"mustache\",\"source\":{\"query\":{\"match\":{\"type\":\"{{type}}\"}}}}}"; |
| 109 | + Response response = createSearchTemplate("type_search_template", templateBody); |
| 110 | + templateBody = "{\"script\":{\"lang\":\"mustache\",\"source\":{\"query\":{\"term\":{\"type\":\"{{type}}\"}}}}}"; |
| 111 | + response = createSearchTemplate("type_search_template_2", templateBody); |
| 112 | + |
| 113 | + // Register agent with search template IDs |
| 114 | + String agentName = "Test_AgentWithQueryPlanningTool_SearchTemplates"; |
| 115 | + String searchTemplates = "[{" |
| 116 | + + "\"template_id\":\"type_search_template\"," |
| 117 | + + "\"template_description\":\"this templates searches for flowers that match the given type this uses a match query\"" |
| 118 | + + "},{" |
| 119 | + + "\"template_id\":\"type_search_template_2\"," |
| 120 | + + "\"template_description\":\"this templates searches for flowers that match the given type this uses a term query\"" |
| 121 | + + "},{" |
| 122 | + + "\"template_id\":\"brand_search_template\"," |
| 123 | + + "\"template_description\":\"this templates searches for products that match the given brand\"" |
| 124 | + + "}]"; |
| 125 | + String agentId = registerQueryPlanningAgentWithSearchTemplates(agentName, queryPlanningModelId, searchTemplates); |
| 126 | + assertNotNull(agentId); |
| 127 | + |
| 128 | + String query = "{\"parameters\": {\"query_text\": \"List 5 iris flowers of type setosa\"}}"; |
| 129 | + Response agentResponse = executeAgent(agentId, query); |
| 130 | + String responseBody = TestHelper.httpEntityToString(agentResponse.getEntity()); |
| 131 | + |
| 132 | + Map<String, Object> responseMap = gson.fromJson(responseBody, Map.class); |
| 133 | + |
| 134 | + List<Map<String, Object>> inferenceResults = (List<Map<String, Object>>) responseMap.get("inference_results"); |
| 135 | + Map<String, Object> firstResult = inferenceResults.get(0); |
| 136 | + List<Map<String, Object>> outputArray = (List<Map<String, Object>>) firstResult.get("output"); |
| 137 | + Map<String, Object> output = (Map<String, Object>) outputArray.get(0); |
| 138 | + String result = output.get("result").toString(); |
| 139 | + |
| 140 | + assertTrue(result.contains("query")); |
| 141 | + assertTrue(result.contains("term")); |
| 142 | + deleteAgent(agentId); |
| 143 | + } |
| 144 | + |
98 | 145 | private String registerAgentWithQueryPlanningTool(String agentName, String modelId) throws IOException {
|
99 | 146 | MLToolSpec listIndexTool = MLToolSpec
|
100 | 147 | .builder()
|
@@ -125,6 +172,44 @@ private String registerAgentWithQueryPlanningTool(String agentName, String model
|
125 | 172 | return registerAgent(agentName, agent);
|
126 | 173 | }
|
127 | 174 |
|
| 175 | + private String registerQueryPlanningAgentWithSearchTemplates(String agentName, String modelId, String searchTemplates) |
| 176 | + throws IOException { |
| 177 | + MLToolSpec listIndexTool = MLToolSpec |
| 178 | + .builder() |
| 179 | + .type("ListIndexTool") |
| 180 | + .name("MyListIndexTool") |
| 181 | + .description("A tool for list indices") |
| 182 | + .parameters(Map.of("index", IRIS_INDEX, "question", "what fields are in the index?")) |
| 183 | + .includeOutputInAgentResponse(true) |
| 184 | + .build(); |
| 185 | + |
| 186 | + MLToolSpec queryPlanningTool = MLToolSpec |
| 187 | + .builder() |
| 188 | + .type("QueryPlanningTool") |
| 189 | + .name("MyQueryPlanningTool") |
| 190 | + .description("A tool for planning queries") |
| 191 | + .parameters( |
| 192 | + Map |
| 193 | + .ofEntries( |
| 194 | + Map.entry(MODEL_ID_FIELD, modelId), |
| 195 | + Map.entry(GENERATION_TYPE_FIELD, USER_SEARCH_TEMPLATES_TYPE_FIELD), |
| 196 | + Map.entry(SEARCH_TEMPLATES_FIELD, searchTemplates) |
| 197 | + ) |
| 198 | + ) |
| 199 | + .includeOutputInAgentResponse(true) |
| 200 | + .build(); |
| 201 | + |
| 202 | + MLAgent agent = MLAgent |
| 203 | + .builder() |
| 204 | + .name(agentName) |
| 205 | + .type("flow") |
| 206 | + .description("Test agent with QueryPlanningTool") |
| 207 | + .tools(List.of(listIndexTool, queryPlanningTool)) |
| 208 | + .build(); |
| 209 | + |
| 210 | + return registerAgent(agentName, agent); |
| 211 | + } |
| 212 | + |
128 | 213 | private String registerQueryPlanningModel() throws IOException, InterruptedException {
|
129 | 214 | String openaiModelName = "openai gpt-4o model " + randomAlphaOfLength(5);
|
130 | 215 | return registerRemoteModel(openaiConnectorEntity, openaiModelName, true);
|
@@ -177,6 +262,18 @@ private Response executeAgent(String agentId, String query) throws IOException {
|
177 | 262 | );
|
178 | 263 | }
|
179 | 264 |
|
| 265 | + private Response createSearchTemplate(String templateName, String templateBody) throws IOException { |
| 266 | + return TestHelper |
| 267 | + .makeRequest( |
| 268 | + client(), |
| 269 | + "PUT", |
| 270 | + "/_scripts/" + templateName, |
| 271 | + null, |
| 272 | + new StringEntity(templateBody), |
| 273 | + List.of(new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json")) |
| 274 | + ); |
| 275 | + } |
| 276 | + |
180 | 277 | private void deleteAgent(String agentId) throws IOException {
|
181 | 278 | TestHelper.makeRequest(client(), "DELETE", "/_plugins/_ml/agents/" + agentId, null, "", List.of());
|
182 | 279 | }
|
|
0 commit comments