|
19 | 19 | import static org.mockito.Mockito.times; |
20 | 20 | import static org.mockito.Mockito.verify; |
21 | 21 | import static org.mockito.Mockito.when; |
| 22 | +import static org.opensearch.ml.common.utils.StringUtils.gson; |
| 23 | +import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_QUERY; |
22 | 24 | import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_QUERY_PLANNING_SYSTEM_PROMPT; |
23 | 25 | import static org.opensearch.ml.engine.tools.QueryPlanningTool.DEFAULT_DESCRIPTION; |
24 | 26 | import static org.opensearch.ml.engine.tools.QueryPlanningTool.INDEX_MAPPING_FIELD; |
|
37 | 39 |
|
38 | 40 | import java.io.IOException; |
39 | 41 | import java.io.InputStream; |
| 42 | +import java.util.ArrayList; |
40 | 43 | import java.util.Collections; |
41 | 44 | import java.util.HashMap; |
42 | 45 | import java.util.List; |
|
58 | 61 | import org.opensearch.core.action.ActionListener; |
59 | 62 | import org.opensearch.core.xcontent.DeprecationHandler; |
60 | 63 | import org.opensearch.core.xcontent.NamedXContentRegistry; |
| 64 | +import org.opensearch.ml.common.spi.tools.Parser; |
61 | 65 | import org.opensearch.ml.common.spi.tools.Tool; |
| 66 | +import org.opensearch.ml.engine.tools.parser.ToolParser; |
62 | 67 | import org.opensearch.script.StoredScriptSource; |
63 | 68 | import org.opensearch.transport.client.AdminClient; |
64 | 69 | import org.opensearch.transport.client.Client; |
@@ -1383,4 +1388,113 @@ public void testTemplateSelectionPromptsWithDefaults() throws ExecutionException |
1383 | 1388 | assertTrue(firstCallParams.get("user_prompt").contains("INPUTS")); |
1384 | 1389 | } |
1385 | 1390 |
|
| 1391 | + // Test 1: Create tool from factory, get parser, test parser behavior directly |
| 1392 | + @SneakyThrows |
| 1393 | + @Test |
| 1394 | + public void testFactoryCreatedTool_DefaultExtractJsonParser() { |
| 1395 | + // Create tool using factory and verify the output parser is correctly configured |
| 1396 | + Map<String, Object> params = Map.of(MODEL_ID_FIELD, "test_model_id"); |
| 1397 | + QueryPlanningTool tool = QueryPlanningTool.Factory.getInstance().create(params); |
| 1398 | + |
| 1399 | + // Verify the output parser was created |
| 1400 | + assertNotNull("Output parser should be created by factory", tool.getOutputParser()); |
| 1401 | + |
| 1402 | + // Test the parser directly with different inputs |
| 1403 | + Parser outputParser = tool.getOutputParser(); |
| 1404 | + |
| 1405 | + // Test case 1: Extract JSON object from text |
| 1406 | + Object parsedResult1 = outputParser.parse("Here is your query: {\"query\":{\"match\":{\"title\":\"test\"}}}"); |
| 1407 | + String resultWithText = parsedResult1 instanceof String ? (String) parsedResult1 : gson.toJson(parsedResult1); |
| 1408 | + assertEquals("{\"query\":{\"match\":{\"title\":\"test\"}}}", resultWithText); |
| 1409 | + |
| 1410 | + // Test case 2: Extract pure JSON |
| 1411 | + Object parsedResult2 = outputParser.parse("{\"query\":{\"match\":{\"title\":\"test\"}}}"); |
| 1412 | + String resultPureJson = parsedResult2 instanceof String ? (String) parsedResult2 : gson.toJson(parsedResult2); |
| 1413 | + assertEquals("{\"query\":{\"match\":{\"title\":\"test\"}}}", resultPureJson); |
| 1414 | + |
| 1415 | + // Test case 3: No valid JSON - should return default template |
| 1416 | + Object parsedResult3 = outputParser.parse("No JSON here at all"); |
| 1417 | + String resultNoJson = parsedResult3 instanceof String ? (String) parsedResult3 : gson.toJson(parsedResult3); |
| 1418 | + assertEquals(DEFAULT_QUERY, resultNoJson); |
| 1419 | + } |
| 1420 | + |
| 1421 | + // Test 2: Create tool from factory with custom processors, verify both default and custom processors work |
| 1422 | + @SneakyThrows |
| 1423 | + @Test |
| 1424 | + public void testFactoryCreatedTool_WithCustomProcessors() { |
| 1425 | + // Create tool using factory with custom output_processors (set_field) |
| 1426 | + Map<String, Object> params = new HashMap<>(); |
| 1427 | + params.put(MODEL_ID_FIELD, "test_model_id"); |
| 1428 | + |
| 1429 | + // Add custom processor configuration |
| 1430 | + List<Map<String, Object>> outputProcessors = new ArrayList<>(); |
| 1431 | + Map<String, Object> setFieldConfig = new HashMap<>(); |
| 1432 | + setFieldConfig.put("type", "set_field"); |
| 1433 | + setFieldConfig.put("path", "$.metadata"); |
| 1434 | + setFieldConfig.put("value", Map.of("source", "query_planner_tool")); |
| 1435 | + outputProcessors.add(setFieldConfig); |
| 1436 | + params.put("output_processors", outputProcessors); |
| 1437 | + |
| 1438 | + QueryPlanningTool tool = QueryPlanningTool.Factory.getInstance().create(params); |
| 1439 | + |
| 1440 | + // Verify the output parser was created |
| 1441 | + assertNotNull("Output parser should be created by factory", tool.getOutputParser()); |
| 1442 | + |
| 1443 | + // Test the parser - it should use BOTH default extract_json AND custom set_field processors |
| 1444 | + Parser outputParser = tool.getOutputParser(); |
| 1445 | + |
| 1446 | + // Test: Extract JSON from text (default extract_json) + add metadata field (custom set_field) |
| 1447 | + String inputWithText = "Here is your query: {\"query\":{\"match\":{\"title\":\"test\"}}}"; |
| 1448 | + Object parsedResult = outputParser.parse(inputWithText); |
| 1449 | + String result = parsedResult instanceof String ? (String) parsedResult : gson.toJson(parsedResult); |
| 1450 | + |
| 1451 | + // Verify both processors worked: extract_json extracted JSON, set_field added metadata |
| 1452 | + String expectedResult = "{\"query\":{\"match\":{\"title\":\"test\"}},\"metadata\":{\"source\":\"query_planner_tool\"}}"; |
| 1453 | + assertEquals("Parser should extract JSON and add metadata field", expectedResult, result); |
| 1454 | + } |
| 1455 | + |
| 1456 | + // Test 3: Create tool with mocked queryGenerationTool, manually set extract_json processor, run end-to-end |
| 1457 | + @SneakyThrows |
| 1458 | + @Test |
| 1459 | + public void testQueryPlanningTool_WithMockedMLModelTool_EndToEnd() { |
| 1460 | + mockSampleDoc(); |
| 1461 | + mockGetIndexMapping(); |
| 1462 | + |
| 1463 | + // Mock the queryGenerationTool (MLModelTool) to return JSON embedded in text |
| 1464 | + doAnswer(invocation -> { |
| 1465 | + ActionListener<String> listener = invocation.getArgument(1); |
| 1466 | + listener.onResponse("Here is your query: {\"query\":{\"match\":{\"title\":\"test\"}}}"); |
| 1467 | + return null; |
| 1468 | + }).when(queryGenerationTool).run(any(), any()); |
| 1469 | + |
| 1470 | + // Create tool using constructor with the mocked queryGenerationTool |
| 1471 | + QueryPlanningTool tool = new QueryPlanningTool(LLM_GENERATED_TYPE_FIELD, queryGenerationTool, client, null); |
| 1472 | + |
| 1473 | + // Create extract_json processor config (same as in factory) |
| 1474 | + Map<String, Object> extractJsonConfig = new HashMap<>(); |
| 1475 | + extractJsonConfig.put("type", "extract_json"); |
| 1476 | + extractJsonConfig.put("extract_type", "object"); |
| 1477 | + extractJsonConfig.put("default", DEFAULT_QUERY); |
| 1478 | + |
| 1479 | + // Set the parser on the tool |
| 1480 | + tool.setOutputParser(ToolParser.createProcessingParser(null, List.of(extractJsonConfig))); |
| 1481 | + |
| 1482 | + // Run the tool end-to-end - the output parser will return a Map, not String |
| 1483 | + CompletableFuture<Object> future = new CompletableFuture<>(); |
| 1484 | + ActionListener<Object> listener = ActionListener.wrap(future::complete, future::completeExceptionally); |
| 1485 | + |
| 1486 | + Map<String, String> runParams = new HashMap<>(); |
| 1487 | + runParams.put(QUESTION_FIELD, "test query"); |
| 1488 | + runParams.put(INDEX_NAME_FIELD, "testIndex"); |
| 1489 | + tool.run(runParams, listener); |
| 1490 | + |
| 1491 | + // Trigger the async index mapping response |
| 1492 | + actionListenerCaptor.getValue().onResponse(getIndexResponse); |
| 1493 | + |
| 1494 | + // Verify the JSON was extracted correctly by the parser |
| 1495 | + Object resultObj = future.get(); |
| 1496 | + String result = resultObj instanceof String ? (String) resultObj : gson.toJson(resultObj); |
| 1497 | + assertEquals("{\"query\":{\"match\":{\"title\":\"test\"}}}", result); |
| 1498 | + } |
| 1499 | + |
1386 | 1500 | } |
0 commit comments