Skip to content

Commit d206ffd

Browse files
[Agentic Search]Add extract JSON processor in Query Planning Tool (opensearch-project#4356)
1 parent 2585b8e commit d206ffd

File tree

3 files changed

+192
-14
lines changed

3 files changed

+192
-14
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/processor/MLExtractJsonProcessor.java

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,22 @@ public Object process(Object input) {
117117

118118
String text = (String) input;
119119
if (text.trim().isEmpty()) {
120-
return defaultValue != null ? defaultValue : input;
120+
if (defaultValue != null) {
121+
log.warn("Input text is empty, returning default value");
122+
return defaultValue;
123+
}
124+
return input;
121125
}
122126

123127
try {
124128
int start = findJsonStart(text);
125129
if (start < 0) {
130+
if (defaultValue != null) {
131+
log.warn("No JSON found in input text, returning default value");
132+
return defaultValue;
133+
}
126134
log.debug("No JSON found in text");
127-
return defaultValue != null ? defaultValue : input;
135+
return input;
128136
}
129137

130138
JsonNode jsonNode = mapper.readTree(text.substring(start));
@@ -133,16 +141,24 @@ public Object process(Object input) {
133141
if (jsonNode.isObject()) {
134142
return mapper.convertValue(jsonNode, Map.class);
135143
}
144+
if (defaultValue != null) {
145+
log.warn("Expected JSON object but found {}, returning default value", jsonNode.getNodeType());
146+
return defaultValue;
147+
}
136148
log.debug("Expected JSON object but found {}", jsonNode.getNodeType());
137-
return defaultValue != null ? defaultValue : input;
149+
return input;
138150
}
139151

140152
if (EXTRACT_TYPE_ARRAY.equalsIgnoreCase(extractType)) {
141153
if (jsonNode.isArray()) {
142154
return mapper.convertValue(jsonNode, List.class);
143155
}
156+
if (defaultValue != null) {
157+
log.warn("Expected JSON array but found {}, returning default value", jsonNode.getNodeType());
158+
return defaultValue;
159+
}
144160
log.debug("Expected JSON array but found {}", jsonNode.getNodeType());
145-
return defaultValue != null ? defaultValue : input;
161+
return input;
146162
}
147163

148164
// auto detect
@@ -153,12 +169,20 @@ public Object process(Object input) {
153169
return mapper.convertValue(jsonNode, List.class);
154170
}
155171

172+
if (defaultValue != null) {
173+
log.warn("JSON node is neither object nor array: {}, returning default value", jsonNode.getNodeType());
174+
return defaultValue;
175+
}
156176
log.debug("JSON node is neither object nor array: {}", jsonNode.getNodeType());
157-
return defaultValue != null ? defaultValue : input;
177+
return input;
158178

159179
} catch (Exception e) {
160-
log.warn("Failed to extract JSON from text: {}", e.getMessage());
161-
return defaultValue != null ? defaultValue : input;
180+
if (defaultValue != null) {
181+
log.warn("Failed to extract JSON from input text: {}, returning default value", e.getMessage());
182+
return defaultValue;
183+
}
184+
log.warn("Failed to extract JSON from input text: {}", e.getMessage());
185+
return input;
162186
}
163187
}
164188

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_TEMPLATE_SELECTION_USER_PROMPT;
1919

2020
import java.io.IOException;
21+
import java.util.ArrayList;
2122
import java.util.HashMap;
2223
import java.util.List;
2324
import java.util.Map;
@@ -35,9 +36,12 @@
3536
import org.opensearch.core.action.ActionListener;
3637
import org.opensearch.index.IndexNotFoundException;
3738
import org.opensearch.index.query.QueryBuilders;
39+
import org.opensearch.ml.common.spi.tools.Parser;
3840
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
3941
import org.opensearch.ml.common.spi.tools.WithModelTool;
4042
import org.opensearch.ml.common.utils.ToolUtils;
43+
import org.opensearch.ml.engine.processor.ProcessorChain;
44+
import org.opensearch.ml.engine.tools.parser.ToolParser;
4145
import org.opensearch.search.SearchHit;
4246
import org.opensearch.search.builder.SearchSourceBuilder;
4347
import org.opensearch.transport.client.Client;
@@ -123,6 +127,10 @@ public class QueryPlanningTool implements WithModelTool {
123127
private String description = DEFAULT_DESCRIPTION;
124128
private final Client client;
125129

130+
@Setter
131+
@Getter
132+
private Parser outputParser;
133+
126134
public QueryPlanningTool(String generationType, MLModelTool queryGenerationTool, Client client, String searchTemplates) {
127135
this.generationType = generationType;
128136
this.queryGenerationTool = queryGenerationTool;
@@ -245,11 +253,12 @@ private <T> void executeQueryPlanning(Map<String, String> parameters, ActionList
245253
try {
246254
String queryString = (String) r;
247255
if (queryString == null || queryString.isBlank() || queryString.equals("null")) {
256+
log.debug("Model failed to generate the DSL query, returning the Default match all query");
248257
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
249258
String defaultQueryString = substitutor.replace(DEFAULT_QUERY);
250259
listener.onResponse((T) defaultQueryString);
251260
} else {
252-
listener.onResponse((T) queryString);
261+
listener.onResponse((T) (outputParser != null ? outputParser.parse(queryString) : queryString));
253262
}
254263
} catch (Exception e) {
255264
IllegalArgumentException parsingException = new IllegalArgumentException(
@@ -410,11 +419,11 @@ public void init(Client client) {
410419
}
411420

412421
@Override
413-
public QueryPlanningTool create(Map<String, Object> map) {
422+
public QueryPlanningTool create(Map<String, Object> params) {
414423

415-
MLModelTool queryGenerationTool = MLModelTool.Factory.getInstance().create(map);
424+
MLModelTool queryGenerationTool = MLModelTool.Factory.getInstance().create(params);
416425

417-
String type = (String) map.get(GENERATION_TYPE_FIELD);
426+
String type = (String) params.get(GENERATION_TYPE_FIELD);
418427

419428
// defaulted to llmGenerated
420429
if (type == null || type.isEmpty()) {
@@ -431,17 +440,48 @@ public QueryPlanningTool create(Map<String, Object> map) {
431440
// Parse search templates if generation type is user_templates
432441
String searchTemplates = null;
433442
if (USER_SEARCH_TEMPLATES_TYPE_FIELD.equals(type)) {
434-
if (!map.containsKey(SEARCH_TEMPLATES_FIELD)) {
443+
if (!params.containsKey(SEARCH_TEMPLATES_FIELD)) {
435444
throw new IllegalArgumentException("search_templates field is required when generation_type is 'user_templates'");
436445
} else {
437446
// array is parsed as a json string
438-
String searchTemplatesJson = (String) map.get(SEARCH_TEMPLATES_FIELD);
447+
String searchTemplatesJson = (String) params.get(SEARCH_TEMPLATES_FIELD);
439448
validateSearchTemplates(searchTemplatesJson);
440449
searchTemplates = gson.toJson(searchTemplatesJson);
441450
}
442451
}
443452

444-
return new QueryPlanningTool(type, queryGenerationTool, client, searchTemplates);
453+
QueryPlanningTool queryPlanningTool = new QueryPlanningTool(type, queryGenerationTool, client, searchTemplates);
454+
455+
// Create parser with default extract_json processor + any custom processors
456+
queryPlanningTool.setOutputParser(createParserWithDefaultExtractJson(params));
457+
458+
return queryPlanningTool;
459+
}
460+
461+
/**
462+
* Create a parser with a default extract_json processor prepended to any custom processors.
463+
* This ensures that JSON is extracted from the LLM response before applying any custom processing.
464+
*
465+
* @param params Tool parameters that may contain custom output_processors
466+
* @return Parser with extract_json as first processor, followed by any custom processors
467+
*/
468+
private Parser createParserWithDefaultExtractJson(Map<String, Object> params) {
469+
// Extract any existing custom processors from params
470+
List<Map<String, Object>> customProcessorConfigs = ProcessorChain.extractProcessorConfigs(params);
471+
472+
// Create the default extract_json processor config
473+
Map<String, Object> extractJsonConfig = new HashMap<>();
474+
extractJsonConfig.put("type", "extract_json");
475+
extractJsonConfig.put("extract_type", "object"); // Extract JSON objects only
476+
extractJsonConfig.put("default", DEFAULT_QUERY); // Return default match all query if no JSON found
477+
478+
// Combine: default extract_json first, then any custom processors
479+
List<Map<String, Object>> combinedProcessorConfigs = new ArrayList<>();
480+
combinedProcessorConfigs.add(extractJsonConfig);
481+
combinedProcessorConfigs.addAll(customProcessorConfigs);
482+
483+
// Create parser using the combined processor configs
484+
return ToolParser.createProcessingParser(null, combinedProcessorConfigs);
445485
}
446486

447487
private void validateSearchTemplates(Object searchTemplatesObj) {

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/QueryPlanningToolTests.java

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import static org.mockito.Mockito.times;
2020
import static org.mockito.Mockito.verify;
2121
import static org.mockito.Mockito.when;
22+
import static org.opensearch.ml.common.utils.StringUtils.gson;
23+
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_QUERY;
2224
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_QUERY_PLANNING_SYSTEM_PROMPT;
2325
import static org.opensearch.ml.engine.tools.QueryPlanningTool.DEFAULT_DESCRIPTION;
2426
import static org.opensearch.ml.engine.tools.QueryPlanningTool.INDEX_MAPPING_FIELD;
@@ -37,6 +39,7 @@
3739

3840
import java.io.IOException;
3941
import java.io.InputStream;
42+
import java.util.ArrayList;
4043
import java.util.Collections;
4144
import java.util.HashMap;
4245
import java.util.List;
@@ -58,7 +61,9 @@
5861
import org.opensearch.core.action.ActionListener;
5962
import org.opensearch.core.xcontent.DeprecationHandler;
6063
import org.opensearch.core.xcontent.NamedXContentRegistry;
64+
import org.opensearch.ml.common.spi.tools.Parser;
6165
import org.opensearch.ml.common.spi.tools.Tool;
66+
import org.opensearch.ml.engine.tools.parser.ToolParser;
6267
import org.opensearch.script.StoredScriptSource;
6368
import org.opensearch.transport.client.AdminClient;
6469
import org.opensearch.transport.client.Client;
@@ -1383,4 +1388,113 @@ public void testTemplateSelectionPromptsWithDefaults() throws ExecutionException
13831388
assertTrue(firstCallParams.get("user_prompt").contains("INPUTS"));
13841389
}
13851390

1391+
// Test 1: Create tool from factory, get parser, test parser behavior directly
1392+
@SneakyThrows
1393+
@Test
1394+
public void testFactoryCreatedTool_DefaultExtractJsonParser() {
1395+
// Create tool using factory and verify the output parser is correctly configured
1396+
Map<String, Object> params = Map.of(MODEL_ID_FIELD, "test_model_id");
1397+
QueryPlanningTool tool = QueryPlanningTool.Factory.getInstance().create(params);
1398+
1399+
// Verify the output parser was created
1400+
assertNotNull("Output parser should be created by factory", tool.getOutputParser());
1401+
1402+
// Test the parser directly with different inputs
1403+
Parser outputParser = tool.getOutputParser();
1404+
1405+
// Test case 1: Extract JSON object from text
1406+
Object parsedResult1 = outputParser.parse("Here is your query: {\"query\":{\"match\":{\"title\":\"test\"}}}");
1407+
String resultWithText = parsedResult1 instanceof String ? (String) parsedResult1 : gson.toJson(parsedResult1);
1408+
assertEquals("{\"query\":{\"match\":{\"title\":\"test\"}}}", resultWithText);
1409+
1410+
// Test case 2: Extract pure JSON
1411+
Object parsedResult2 = outputParser.parse("{\"query\":{\"match\":{\"title\":\"test\"}}}");
1412+
String resultPureJson = parsedResult2 instanceof String ? (String) parsedResult2 : gson.toJson(parsedResult2);
1413+
assertEquals("{\"query\":{\"match\":{\"title\":\"test\"}}}", resultPureJson);
1414+
1415+
// Test case 3: No valid JSON - should return default template
1416+
Object parsedResult3 = outputParser.parse("No JSON here at all");
1417+
String resultNoJson = parsedResult3 instanceof String ? (String) parsedResult3 : gson.toJson(parsedResult3);
1418+
assertEquals(DEFAULT_QUERY, resultNoJson);
1419+
}
1420+
1421+
// Test 2: Create tool from factory with custom processors, verify both default and custom processors work
1422+
@SneakyThrows
1423+
@Test
1424+
public void testFactoryCreatedTool_WithCustomProcessors() {
1425+
// Create tool using factory with custom output_processors (set_field)
1426+
Map<String, Object> params = new HashMap<>();
1427+
params.put(MODEL_ID_FIELD, "test_model_id");
1428+
1429+
// Add custom processor configuration
1430+
List<Map<String, Object>> outputProcessors = new ArrayList<>();
1431+
Map<String, Object> setFieldConfig = new HashMap<>();
1432+
setFieldConfig.put("type", "set_field");
1433+
setFieldConfig.put("path", "$.metadata");
1434+
setFieldConfig.put("value", Map.of("source", "query_planner_tool"));
1435+
outputProcessors.add(setFieldConfig);
1436+
params.put("output_processors", outputProcessors);
1437+
1438+
QueryPlanningTool tool = QueryPlanningTool.Factory.getInstance().create(params);
1439+
1440+
// Verify the output parser was created
1441+
assertNotNull("Output parser should be created by factory", tool.getOutputParser());
1442+
1443+
// Test the parser - it should use BOTH default extract_json AND custom set_field processors
1444+
Parser outputParser = tool.getOutputParser();
1445+
1446+
// Test: Extract JSON from text (default extract_json) + add metadata field (custom set_field)
1447+
String inputWithText = "Here is your query: {\"query\":{\"match\":{\"title\":\"test\"}}}";
1448+
Object parsedResult = outputParser.parse(inputWithText);
1449+
String result = parsedResult instanceof String ? (String) parsedResult : gson.toJson(parsedResult);
1450+
1451+
// Verify both processors worked: extract_json extracted JSON, set_field added metadata
1452+
String expectedResult = "{\"query\":{\"match\":{\"title\":\"test\"}},\"metadata\":{\"source\":\"query_planner_tool\"}}";
1453+
assertEquals("Parser should extract JSON and add metadata field", expectedResult, result);
1454+
}
1455+
1456+
// Test 3: Create tool with mocked queryGenerationTool, manually set extract_json processor, run end-to-end
1457+
@SneakyThrows
1458+
@Test
1459+
public void testQueryPlanningTool_WithMockedMLModelTool_EndToEnd() {
1460+
mockSampleDoc();
1461+
mockGetIndexMapping();
1462+
1463+
// Mock the queryGenerationTool (MLModelTool) to return JSON embedded in text
1464+
doAnswer(invocation -> {
1465+
ActionListener<String> listener = invocation.getArgument(1);
1466+
listener.onResponse("Here is your query: {\"query\":{\"match\":{\"title\":\"test\"}}}");
1467+
return null;
1468+
}).when(queryGenerationTool).run(any(), any());
1469+
1470+
// Create tool using constructor with the mocked queryGenerationTool
1471+
QueryPlanningTool tool = new QueryPlanningTool(LLM_GENERATED_TYPE_FIELD, queryGenerationTool, client, null);
1472+
1473+
// Create extract_json processor config (same as in factory)
1474+
Map<String, Object> extractJsonConfig = new HashMap<>();
1475+
extractJsonConfig.put("type", "extract_json");
1476+
extractJsonConfig.put("extract_type", "object");
1477+
extractJsonConfig.put("default", DEFAULT_QUERY);
1478+
1479+
// Set the parser on the tool
1480+
tool.setOutputParser(ToolParser.createProcessingParser(null, List.of(extractJsonConfig)));
1481+
1482+
// Run the tool end-to-end - the output parser will return a Map, not String
1483+
CompletableFuture<Object> future = new CompletableFuture<>();
1484+
ActionListener<Object> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
1485+
1486+
Map<String, String> runParams = new HashMap<>();
1487+
runParams.put(QUESTION_FIELD, "test query");
1488+
runParams.put(INDEX_NAME_FIELD, "testIndex");
1489+
tool.run(runParams, listener);
1490+
1491+
// Trigger the async index mapping response
1492+
actionListenerCaptor.getValue().onResponse(getIndexResponse);
1493+
1494+
// Verify the JSON was extracted correctly by the parser
1495+
Object resultObj = future.get();
1496+
String result = resultObj instanceof String ? (String) resultObj : gson.toJson(resultObj);
1497+
assertEquals("{\"query\":{\"match\":{\"title\":\"test\"}}}", result);
1498+
}
1499+
13861500
}

0 commit comments

Comments
 (0)