Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,22 @@ public Object process(Object input) {

String text = (String) input;
if (text.trim().isEmpty()) {
return defaultValue != null ? defaultValue : input;
if (defaultValue != null) {
log.warn("Input text is empty, returning default value");
return defaultValue;
}
return input;
}

try {
int start = findJsonStart(text);
if (start < 0) {
if (defaultValue != null) {
log.warn("No JSON found in input text, returning default value");
return defaultValue;
}
log.debug("No JSON found in text");
return defaultValue != null ? defaultValue : input;
return input;
}

JsonNode jsonNode = mapper.readTree(text.substring(start));
Expand All @@ -133,16 +141,24 @@ public Object process(Object input) {
if (jsonNode.isObject()) {
return mapper.convertValue(jsonNode, Map.class);
}
if (defaultValue != null) {
log.warn("Expected JSON object but found {}, returning default value", jsonNode.getNodeType());
return defaultValue;
}
log.debug("Expected JSON object but found {}", jsonNode.getNodeType());
return defaultValue != null ? defaultValue : input;
return input;
}

if (EXTRACT_TYPE_ARRAY.equalsIgnoreCase(extractType)) {
if (jsonNode.isArray()) {
return mapper.convertValue(jsonNode, List.class);
}
if (defaultValue != null) {
log.warn("Expected JSON array but found {}, returning default value", jsonNode.getNodeType());
return defaultValue;
}
log.debug("Expected JSON array but found {}", jsonNode.getNodeType());
return defaultValue != null ? defaultValue : input;
return input;
}

// auto detect
Expand All @@ -153,12 +169,20 @@ public Object process(Object input) {
return mapper.convertValue(jsonNode, List.class);
}

if (defaultValue != null) {
log.warn("JSON node is neither object nor array: {}, returning default value", jsonNode.getNodeType());
return defaultValue;
}
log.debug("JSON node is neither object nor array: {}", jsonNode.getNodeType());
return defaultValue != null ? defaultValue : input;
return input;

} catch (Exception e) {
log.warn("Failed to extract JSON from text: {}", e.getMessage());
return defaultValue != null ? defaultValue : input;
if (defaultValue != null) {
log.warn("Failed to extract JSON from input text: {}, returning default value", e.getMessage());
return defaultValue;
}
log.warn("Failed to extract JSON from input text: {}", e.getMessage());
return input;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_TEMPLATE_SELECTION_USER_PROMPT;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -35,9 +36,12 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.spi.tools.WithModelTool;
import org.opensearch.ml.common.utils.ToolUtils;
import org.opensearch.ml.engine.processor.ProcessorChain;
import org.opensearch.ml.engine.tools.parser.ToolParser;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.transport.client.Client;
Expand Down Expand Up @@ -123,6 +127,10 @@ public class QueryPlanningTool implements WithModelTool {
private String description = DEFAULT_DESCRIPTION;
private final Client client;

@Setter
@Getter
private Parser outputParser;

public QueryPlanningTool(String generationType, MLModelTool queryGenerationTool, Client client, String searchTemplates) {
this.generationType = generationType;
this.queryGenerationTool = queryGenerationTool;
Expand Down Expand Up @@ -245,11 +253,12 @@ private <T> void executeQueryPlanning(Map<String, String> parameters, ActionList
try {
String queryString = (String) r;
if (queryString == null || queryString.isBlank() || queryString.equals("null")) {
log.debug("Model failed to generate the DSL query, returning the Default match all query");
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
String defaultQueryString = substitutor.replace(DEFAULT_QUERY);
listener.onResponse((T) defaultQueryString);
} else {
listener.onResponse((T) queryString);
listener.onResponse((T) (outputParser != null ? outputParser.parse(queryString) : queryString));
}
} catch (Exception e) {
IllegalArgumentException parsingException = new IllegalArgumentException(
Expand Down Expand Up @@ -410,11 +419,11 @@ public void init(Client client) {
}

@Override
public QueryPlanningTool create(Map<String, Object> map) {
public QueryPlanningTool create(Map<String, Object> params) {

MLModelTool queryGenerationTool = MLModelTool.Factory.getInstance().create(map);
MLModelTool queryGenerationTool = MLModelTool.Factory.getInstance().create(params);

String type = (String) map.get(GENERATION_TYPE_FIELD);
String type = (String) params.get(GENERATION_TYPE_FIELD);

// defaulted to llmGenerated
if (type == null || type.isEmpty()) {
Expand All @@ -431,17 +440,48 @@ public QueryPlanningTool create(Map<String, Object> map) {
// Parse search templates if generation type is user_templates
String searchTemplates = null;
if (USER_SEARCH_TEMPLATES_TYPE_FIELD.equals(type)) {
if (!map.containsKey(SEARCH_TEMPLATES_FIELD)) {
if (!params.containsKey(SEARCH_TEMPLATES_FIELD)) {
throw new IllegalArgumentException("search_templates field is required when generation_type is 'user_templates'");
} else {
// array is parsed as a json string
String searchTemplatesJson = (String) map.get(SEARCH_TEMPLATES_FIELD);
String searchTemplatesJson = (String) params.get(SEARCH_TEMPLATES_FIELD);
validateSearchTemplates(searchTemplatesJson);
searchTemplates = gson.toJson(searchTemplatesJson);
}
}

return new QueryPlanningTool(type, queryGenerationTool, client, searchTemplates);
QueryPlanningTool queryPlanningTool = new QueryPlanningTool(type, queryGenerationTool, client, searchTemplates);

// Create parser with default extract_json processor + any custom processors
queryPlanningTool.setOutputParser(createParserWithDefaultExtractJson(params));

return queryPlanningTool;
}

/**
* Create a parser with a default extract_json processor prepended to any custom processors.
* This ensures that JSON is extracted from the LLM response before applying any custom processing.
*
* @param params Tool parameters that may contain custom output_processors
* @return Parser with extract_json as first processor, followed by any custom processors
*/
private Parser createParserWithDefaultExtractJson(Map<String, Object> params) {
// Extract any existing custom processors from params
List<Map<String, Object>> customProcessorConfigs = ProcessorChain.extractProcessorConfigs(params);

// Create the default extract_json processor config
Map<String, Object> extractJsonConfig = new HashMap<>();
extractJsonConfig.put("type", "extract_json");
extractJsonConfig.put("extract_type", "object"); // Extract JSON objects only
extractJsonConfig.put("default", DEFAULT_QUERY); // Return default match all query if no JSON found

// Combine: default extract_json first, then any custom processors
List<Map<String, Object>> combinedProcessorConfigs = new ArrayList<>();
combinedProcessorConfigs.add(extractJsonConfig);
combinedProcessorConfigs.addAll(customProcessorConfigs);

// Create parser using the combined processor configs
return ToolParser.createProcessingParser(null, combinedProcessorConfigs);
}

private void validateSearchTemplates(Object searchTemplatesObj) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_QUERY;
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_QUERY_PLANNING_SYSTEM_PROMPT;
import static org.opensearch.ml.engine.tools.QueryPlanningTool.DEFAULT_DESCRIPTION;
import static org.opensearch.ml.engine.tools.QueryPlanningTool.INDEX_MAPPING_FIELD;
Expand All @@ -37,6 +39,7 @@

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand All @@ -58,7 +61,9 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.engine.tools.parser.ToolParser;
import org.opensearch.script.StoredScriptSource;
import org.opensearch.transport.client.AdminClient;
import org.opensearch.transport.client.Client;
Expand Down Expand Up @@ -1383,4 +1388,113 @@ public void testTemplateSelectionPromptsWithDefaults() throws ExecutionException
assertTrue(firstCallParams.get("user_prompt").contains("INPUTS"));
}

// Test 1: Create tool from factory, get parser, test parser behavior directly
@SneakyThrows
@Test
public void testFactoryCreatedTool_DefaultExtractJsonParser() {
// Create tool using factory and verify the output parser is correctly configured
Map<String, Object> params = Map.of(MODEL_ID_FIELD, "test_model_id");
QueryPlanningTool tool = QueryPlanningTool.Factory.getInstance().create(params);

// Verify the output parser was created
assertNotNull("Output parser should be created by factory", tool.getOutputParser());

// Test the parser directly with different inputs
Parser outputParser = tool.getOutputParser();

// Test case 1: Extract JSON object from text
Object parsedResult1 = outputParser.parse("Here is your query: {\"query\":{\"match\":{\"title\":\"test\"}}}");
String resultWithText = parsedResult1 instanceof String ? (String) parsedResult1 : gson.toJson(parsedResult1);
assertEquals("{\"query\":{\"match\":{\"title\":\"test\"}}}", resultWithText);

// Test case 2: Extract pure JSON
Object parsedResult2 = outputParser.parse("{\"query\":{\"match\":{\"title\":\"test\"}}}");
String resultPureJson = parsedResult2 instanceof String ? (String) parsedResult2 : gson.toJson(parsedResult2);
assertEquals("{\"query\":{\"match\":{\"title\":\"test\"}}}", resultPureJson);

// Test case 3: No valid JSON - should return default template
Object parsedResult3 = outputParser.parse("No JSON here at all");
String resultNoJson = parsedResult3 instanceof String ? (String) parsedResult3 : gson.toJson(parsedResult3);
assertEquals(DEFAULT_QUERY, resultNoJson);
}

// Test 2: Create tool from factory with custom processors, verify both default and custom processors work
@SneakyThrows
@Test
public void testFactoryCreatedTool_WithCustomProcessors() {
// Create tool using factory with custom output_processors (set_field)
Map<String, Object> params = new HashMap<>();
params.put(MODEL_ID_FIELD, "test_model_id");

// Add custom processor configuration
List<Map<String, Object>> outputProcessors = new ArrayList<>();
Map<String, Object> setFieldConfig = new HashMap<>();
setFieldConfig.put("type", "set_field");
setFieldConfig.put("path", "$.metadata");
setFieldConfig.put("value", Map.of("source", "query_planner_tool"));
outputProcessors.add(setFieldConfig);
params.put("output_processors", outputProcessors);

QueryPlanningTool tool = QueryPlanningTool.Factory.getInstance().create(params);

// Verify the output parser was created
assertNotNull("Output parser should be created by factory", tool.getOutputParser());

// Test the parser - it should use BOTH default extract_json AND custom set_field processors
Parser outputParser = tool.getOutputParser();

// Test: Extract JSON from text (default extract_json) + add metadata field (custom set_field)
String inputWithText = "Here is your query: {\"query\":{\"match\":{\"title\":\"test\"}}}";
Object parsedResult = outputParser.parse(inputWithText);
String result = parsedResult instanceof String ? (String) parsedResult : gson.toJson(parsedResult);

// Verify both processors worked: extract_json extracted JSON, set_field added metadata
String expectedResult = "{\"query\":{\"match\":{\"title\":\"test\"}},\"metadata\":{\"source\":\"query_planner_tool\"}}";
assertEquals("Parser should extract JSON and add metadata field", expectedResult, result);
}

// Test 3: Create tool with mocked queryGenerationTool, manually set extract_json processor, run end-to-end
@SneakyThrows
@Test
public void testQueryPlanningTool_WithMockedMLModelTool_EndToEnd() {
mockSampleDoc();
mockGetIndexMapping();

// Mock the queryGenerationTool (MLModelTool) to return JSON embedded in text
doAnswer(invocation -> {
ActionListener<String> listener = invocation.getArgument(1);
listener.onResponse("Here is your query: {\"query\":{\"match\":{\"title\":\"test\"}}}");
return null;
}).when(queryGenerationTool).run(any(), any());

// Create tool using constructor with the mocked queryGenerationTool
QueryPlanningTool tool = new QueryPlanningTool(LLM_GENERATED_TYPE_FIELD, queryGenerationTool, client, null);

// Create extract_json processor config (same as in factory)
Map<String, Object> extractJsonConfig = new HashMap<>();
extractJsonConfig.put("type", "extract_json");
extractJsonConfig.put("extract_type", "object");
extractJsonConfig.put("default", DEFAULT_QUERY);

// Set the parser on the tool
tool.setOutputParser(ToolParser.createProcessingParser(null, List.of(extractJsonConfig)));

// Run the tool end-to-end - the output parser will return a Map, not String
CompletableFuture<Object> future = new CompletableFuture<>();
ActionListener<Object> listener = ActionListener.wrap(future::complete, future::completeExceptionally);

Map<String, String> runParams = new HashMap<>();
runParams.put(QUESTION_FIELD, "test query");
runParams.put(INDEX_NAME_FIELD, "testIndex");
tool.run(runParams, listener);

// Trigger the async index mapping response
actionListenerCaptor.getValue().onResponse(getIndexResponse);

// Verify the JSON was extracted correctly by the parser
Object resultObj = future.get();
String result = resultObj instanceof String ? (String) resultObj : gson.toJson(resultObj);
assertEquals("{\"query\":{\"match\":{\"title\":\"test\"}}}", result);
}

}
Loading