Skip to content

Commit 04c2973

Browse files
Add extract JSON processor in Query Planning Tool
Signed-off-by: rithin-pullela-aws <rithinp@amazon.com>
1 parent 2585b8e commit 04c2973

File tree

2 files changed

+160
-7
lines changed

2 files changed

+160
-7
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_TEMPLATE_SELECTION_USER_PROMPT;
1919

2020
import java.io.IOException;
21+
import java.util.ArrayList;
2122
import java.util.HashMap;
2223
import java.util.List;
2324
import java.util.Map;
@@ -35,9 +36,12 @@
3536
import org.opensearch.core.action.ActionListener;
3637
import org.opensearch.index.IndexNotFoundException;
3738
import org.opensearch.index.query.QueryBuilders;
39+
import org.opensearch.ml.common.spi.tools.Parser;
3840
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
3941
import org.opensearch.ml.common.spi.tools.WithModelTool;
4042
import org.opensearch.ml.common.utils.ToolUtils;
43+
import org.opensearch.ml.engine.processor.ProcessorChain;
44+
import org.opensearch.ml.engine.tools.parser.ToolParser;
4145
import org.opensearch.search.SearchHit;
4246
import org.opensearch.search.builder.SearchSourceBuilder;
4347
import org.opensearch.transport.client.Client;
@@ -123,6 +127,10 @@ public class QueryPlanningTool implements WithModelTool {
123127
private String description = DEFAULT_DESCRIPTION;
124128
private final Client client;
125129

130+
@Setter
131+
@Getter
132+
private Parser outputParser;
133+
126134
public QueryPlanningTool(String generationType, MLModelTool queryGenerationTool, Client client, String searchTemplates) {
127135
this.generationType = generationType;
128136
this.queryGenerationTool = queryGenerationTool;
@@ -249,7 +257,7 @@ private <T> void executeQueryPlanning(Map<String, String> parameters, ActionList
249257
String defaultQueryString = substitutor.replace(DEFAULT_QUERY);
250258
listener.onResponse((T) defaultQueryString);
251259
} else {
252-
listener.onResponse((T) queryString);
260+
listener.onResponse((T) (outputParser != null ? outputParser.parse(queryString) : queryString));
253261
}
254262
} catch (Exception e) {
255263
IllegalArgumentException parsingException = new IllegalArgumentException(
@@ -410,11 +418,11 @@ public void init(Client client) {
410418
}
411419

412420
@Override
413-
public QueryPlanningTool create(Map<String, Object> map) {
421+
public QueryPlanningTool create(Map<String, Object> params) {
414422

415-
MLModelTool queryGenerationTool = MLModelTool.Factory.getInstance().create(map);
423+
MLModelTool queryGenerationTool = MLModelTool.Factory.getInstance().create(params);
416424

417-
String type = (String) map.get(GENERATION_TYPE_FIELD);
425+
String type = (String) params.get(GENERATION_TYPE_FIELD);
418426

419427
// defaulted to llmGenerated
420428
if (type == null || type.isEmpty()) {
@@ -431,17 +439,48 @@ public QueryPlanningTool create(Map<String, Object> map) {
431439
// Parse search templates if generation type is user_templates
432440
String searchTemplates = null;
433441
if (USER_SEARCH_TEMPLATES_TYPE_FIELD.equals(type)) {
434-
if (!map.containsKey(SEARCH_TEMPLATES_FIELD)) {
442+
if (!params.containsKey(SEARCH_TEMPLATES_FIELD)) {
435443
throw new IllegalArgumentException("search_templates field is required when generation_type is 'user_templates'");
436444
} else {
437445
// array is parsed as a json string
438-
String searchTemplatesJson = (String) map.get(SEARCH_TEMPLATES_FIELD);
446+
String searchTemplatesJson = (String) params.get(SEARCH_TEMPLATES_FIELD);
439447
validateSearchTemplates(searchTemplatesJson);
440448
searchTemplates = gson.toJson(searchTemplatesJson);
441449
}
442450
}
443451

444-
return new QueryPlanningTool(type, queryGenerationTool, client, searchTemplates);
452+
QueryPlanningTool queryPlanningTool = new QueryPlanningTool(type, queryGenerationTool, client, searchTemplates);
453+
454+
// Create parser with default extract_json processor + any custom processors
455+
queryPlanningTool.setOutputParser(createParserWithDefaultExtractJson(params));
456+
457+
return queryPlanningTool;
458+
}
459+
460+
/**
461+
* Create a parser with a default extract_json processor prepended to any custom processors.
462+
* This ensures that JSON is extracted from the LLM response before applying any custom processing.
463+
*
464+
* @param params Tool parameters that may contain custom output_processors
465+
* @return Parser with extract_json as first processor, followed by any custom processors
466+
*/
467+
private Parser createParserWithDefaultExtractJson(Map<String, Object> params) {
468+
// Extract any existing custom processors from params
469+
List<Map<String, Object>> customProcessorConfigs = ProcessorChain.extractProcessorConfigs(params);
470+
471+
// Create the default extract_json processor config
472+
Map<String, Object> extractJsonConfig = new HashMap<>();
473+
extractJsonConfig.put("type", "extract_json");
474+
extractJsonConfig.put("extract_type", "object"); // Extract JSON objects only
475+
extractJsonConfig.put("default", DEFAULT_SEARCH_TEMPLATE); // Return default search template if no JSON found
476+
477+
// Combine: default extract_json first, then any custom processors
478+
List<Map<String, Object>> combinedProcessorConfigs = new ArrayList<>();
479+
combinedProcessorConfigs.add(extractJsonConfig);
480+
combinedProcessorConfigs.addAll(customProcessorConfigs);
481+
482+
// Create parser using the combined processor configs
483+
return ToolParser.createProcessingParser(null, combinedProcessorConfigs);
445484
}
446485

447486
private void validateSearchTemplates(Object searchTemplatesObj) {

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/QueryPlanningToolTests.java

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
import static org.mockito.Mockito.times;
2020
import static org.mockito.Mockito.verify;
2121
import static org.mockito.Mockito.when;
22+
import static org.opensearch.ml.common.utils.StringUtils.gson;
2223
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_QUERY_PLANNING_SYSTEM_PROMPT;
24+
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_SEARCH_TEMPLATE;
2325
import static org.opensearch.ml.engine.tools.QueryPlanningTool.DEFAULT_DESCRIPTION;
2426
import static org.opensearch.ml.engine.tools.QueryPlanningTool.INDEX_MAPPING_FIELD;
2527
import static org.opensearch.ml.engine.tools.QueryPlanningTool.INDEX_NAME_FIELD;
@@ -37,6 +39,7 @@
3739

3840
import java.io.IOException;
3941
import java.io.InputStream;
42+
import java.util.ArrayList;
4043
import java.util.Collections;
4144
import java.util.HashMap;
4245
import java.util.List;
@@ -58,7 +61,9 @@
5861
import org.opensearch.core.action.ActionListener;
5962
import org.opensearch.core.xcontent.DeprecationHandler;
6063
import org.opensearch.core.xcontent.NamedXContentRegistry;
64+
import org.opensearch.ml.common.spi.tools.Parser;
6165
import org.opensearch.ml.common.spi.tools.Tool;
66+
import org.opensearch.ml.engine.tools.parser.ToolParser;
6267
import org.opensearch.script.StoredScriptSource;
6368
import org.opensearch.transport.client.AdminClient;
6469
import org.opensearch.transport.client.Client;
@@ -1383,4 +1388,113 @@ public void testTemplateSelectionPromptsWithDefaults() throws ExecutionException
13831388
assertTrue(firstCallParams.get("user_prompt").contains("INPUTS"));
13841389
}
13851390

1391+
// Test 1: Create tool from factory, get parser, test parser behavior directly
1392+
@SneakyThrows
1393+
@Test
1394+
public void testFactoryCreatedTool_DefaultExtractJsonParser() {
1395+
// Create tool using factory and verify the output parser is correctly configured
1396+
Map<String, Object> params = Map.of(MODEL_ID_FIELD, "test_model_id");
1397+
QueryPlanningTool tool = QueryPlanningTool.Factory.getInstance().create(params);
1398+
1399+
// Verify the output parser was created
1400+
assertNotNull("Output parser should be created by factory", tool.getOutputParser());
1401+
1402+
// Test the parser directly with different inputs
1403+
Parser outputParser = tool.getOutputParser();
1404+
1405+
// Test case 1: Extract JSON object from text
1406+
Object parsedResult1 = outputParser.parse("Here is your query: {\"query\":{\"match\":{\"title\":\"test\"}}}");
1407+
String resultWithText = parsedResult1 instanceof String ? (String) parsedResult1 : gson.toJson(parsedResult1);
1408+
assertEquals("{\"query\":{\"match\":{\"title\":\"test\"}}}", resultWithText);
1409+
1410+
// Test case 2: Extract pure JSON
1411+
Object parsedResult2 = outputParser.parse("{\"query\":{\"match\":{\"title\":\"test\"}}}");
1412+
String resultPureJson = parsedResult2 instanceof String ? (String) parsedResult2 : gson.toJson(parsedResult2);
1413+
assertEquals("{\"query\":{\"match\":{\"title\":\"test\"}}}", resultPureJson);
1414+
1415+
// Test case 3: No valid JSON - should return default template
1416+
Object parsedResult3 = outputParser.parse("No JSON here at all");
1417+
String resultNoJson = parsedResult3 instanceof String ? (String) parsedResult3 : gson.toJson(parsedResult3);
1418+
assertEquals(DEFAULT_SEARCH_TEMPLATE, resultNoJson);
1419+
}
1420+
1421+
// Test 2: Create tool from factory with custom processors, verify both default and custom processors work
1422+
@SneakyThrows
1423+
@Test
1424+
public void testFactoryCreatedTool_WithCustomProcessors() {
1425+
// Create tool using factory with custom output_processors (set_field)
1426+
Map<String, Object> params = new HashMap<>();
1427+
params.put(MODEL_ID_FIELD, "test_model_id");
1428+
1429+
// Add custom processor configuration
1430+
List<Map<String, Object>> outputProcessors = new ArrayList<>();
1431+
Map<String, Object> setFieldConfig = new HashMap<>();
1432+
setFieldConfig.put("type", "set_field");
1433+
setFieldConfig.put("path", "$.metadata");
1434+
setFieldConfig.put("value", Map.of("source", "query_planner_tool"));
1435+
outputProcessors.add(setFieldConfig);
1436+
params.put("output_processors", outputProcessors);
1437+
1438+
QueryPlanningTool tool = QueryPlanningTool.Factory.getInstance().create(params);
1439+
1440+
// Verify the output parser was created
1441+
assertNotNull("Output parser should be created by factory", tool.getOutputParser());
1442+
1443+
// Test the parser - it should use BOTH default extract_json AND custom set_field processors
1444+
Parser outputParser = tool.getOutputParser();
1445+
1446+
// Test: Extract JSON from text (default extract_json) + add metadata field (custom set_field)
1447+
String inputWithText = "Here is your query: {\"query\":{\"match\":{\"title\":\"test\"}}}";
1448+
Object parsedResult = outputParser.parse(inputWithText);
1449+
String result = parsedResult instanceof String ? (String) parsedResult : gson.toJson(parsedResult);
1450+
1451+
// Verify both processors worked: extract_json extracted JSON, set_field added metadata
1452+
String expectedResult = "{\"query\":{\"match\":{\"title\":\"test\"}},\"metadata\":{\"source\":\"query_planner_tool\"}}";
1453+
assertEquals("Parser should extract JSON and add metadata field", expectedResult, result);
1454+
}
1455+
1456+
// Test 3: Create tool with mocked queryGenerationTool, manually set extract_json processor, run end-to-end
1457+
@SneakyThrows
1458+
@Test
1459+
public void testQueryPlanningTool_WithMockedMLModelTool_EndToEnd() {
1460+
mockSampleDoc();
1461+
mockGetIndexMapping();
1462+
1463+
// Mock the queryGenerationTool (MLModelTool) to return JSON embedded in text
1464+
doAnswer(invocation -> {
1465+
ActionListener<String> listener = invocation.getArgument(1);
1466+
listener.onResponse("Here is your query: {\"query\":{\"match\":{\"title\":\"test\"}}}");
1467+
return null;
1468+
}).when(queryGenerationTool).run(any(), any());
1469+
1470+
// Create tool using constructor with the mocked queryGenerationTool
1471+
QueryPlanningTool tool = new QueryPlanningTool(LLM_GENERATED_TYPE_FIELD, queryGenerationTool, client, null);
1472+
1473+
// Create extract_json processor config (same as in factory)
1474+
Map<String, Object> extractJsonConfig = new HashMap<>();
1475+
extractJsonConfig.put("type", "extract_json");
1476+
extractJsonConfig.put("extract_type", "object");
1477+
extractJsonConfig.put("default", DEFAULT_SEARCH_TEMPLATE);
1478+
1479+
// Set the parser on the tool
1480+
tool.setOutputParser(ToolParser.createProcessingParser(null, List.of(extractJsonConfig)));
1481+
1482+
// Run the tool end-to-end - the output parser will return a Map, not String
1483+
CompletableFuture<Object> future = new CompletableFuture<>();
1484+
ActionListener<Object> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
1485+
1486+
Map<String, String> runParams = new HashMap<>();
1487+
runParams.put(QUESTION_FIELD, "test query");
1488+
runParams.put(INDEX_NAME_FIELD, "testIndex");
1489+
tool.run(runParams, listener);
1490+
1491+
// Trigger the async index mapping response
1492+
actionListenerCaptor.getValue().onResponse(getIndexResponse);
1493+
1494+
// Verify the JSON was extracted correctly by the parser
1495+
Object resultObj = future.get();
1496+
String result = resultObj instanceof String ? (String) resultObj : gson.toJson(resultObj);
1497+
assertEquals("{\"query\":{\"match\":{\"title\":\"test\"}}}", result);
1498+
}
1499+
13861500
}

0 commit comments

Comments
 (0)