Skip to content

Commit 17a0f7a

Browse files
Make compatible with agentic search
Signed-off-by: rithin-pullela-aws <rithinp@amazon.com>
1 parent 2a00167 commit 17a0f7a

File tree

3 files changed

+63
-13
lines changed

3 files changed

+63
-13
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningPromptTemplate.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
public class QueryPlanningPromptTemplate {
44

5-
public static final String DEFAULT_QUERY =
6-
"{ \"query\": { \"multi_match\" : { \"query\": \"${parameters.query_text}\", \"fields\": ${parameters.query_fields:-[\"*\"]} } } }";
5+
public static final String DEFAULT_QUERY = "{\"size\":10,\"query\":{\"match_all\":{}}}";
76

87
public static final String QUERY_TYPE_RULES = "\nChoose query types based on user intent and fields: \n"
98
+ "match: single-token full‑text searches on analyzed text fields, \n"
@@ -31,26 +30,32 @@ public class QueryPlanningPromptTemplate {
3130
+ QUERY_TYPE_RULES
3231
+ AGGREGATION_RULES;
3332

33+
public static final String USE_QUERY_FIELDS_INSTRUCTION =
34+
"When Query Fields are provided, prioritize incorporating them into the generated query.";
35+
3436
public static final String OUTPUT_FORMAT_INSTRUCTIONS = "Output format: Output only a valid escaped JSON string or the literal \n"
3537
+ DEFAULT_QUERY
3638
+ " \nReturn exactly one JSON object. "
3739
+ "Output nothing before or after it — no code fences/backticks (`), angle brackets (< >), hash marks (#), asterisks (*), pipes (|), tildes (~), ellipses (… or ...), emojis, typographic quotes (\" \"), non-breaking spaces (U+00A0), zero-width characters (U+200B, U+FEFF), or any other markup/control characters. "
3840
+ "Use valid JSON only (standard double quotes \"; no comments; no trailing commas). "
3941
+ "This applies to formatting only, string values inside the JSON may contain any needed Unicode characters. \n"
4042
+ "Follow the examples below. \n"
43+
+ USE_QUERY_FIELDS_INSTRUCTION
4144
+ "Fallback: If the request cannot be fulfilled with the mapping (missing field, unsupported feature, etc.), \n"
4245
+ "output the literal string: "
4346
+ DEFAULT_QUERY;
4447

4548
// Individual example constants for better maintainability
4649
public static final String EXAMPLE_1 = "Example 1 — numeric range \n"
4750
+ "Input: Show all products that cost more than 50 dollars. \n"
48-
+ "Mapping: \"{ \"properties\": { \"price\": { \"type\": \"float\" } } }\" \n"
51+
+ "Mapping: { \"properties\": { \"price\": { \"type\": \"float\" }, \"cost\": { \"type\": \"float\" } } }\n"
52+
+ "query_fields: [price]"
4953
+ "Output: \"{ \"query\": { \"range\": { \"price\": { \"gt\": 50 } } } }\" \n";
5054

5155
public static final String EXAMPLE_2 = "Example 2 — text match + exact filter \n"
5256
+ "Input: Find employees in London who are active. \n"
5357
+ "Mapping: \"{ \"properties\": { \"city\": { \"type\": \"text\", \"fields\": { \"keyword\": { \"type\": \"keyword\" } } }, \"status\": { \"type\": \"keyword\" } } }\" \n"
58+
+ "query_fields: [city, status]"
5459
+ "Output: \"{ \"query\": { \"bool\": { \"must\": [ { \"match\": { \"city\": \"London\" } } ], \"filter\": [ { \"term\": { \"status\": \"active\" } } ] } } }\" \n";
5560

5661
public static final String EXAMPLE_3 =
@@ -117,7 +122,8 @@ public class QueryPlanningPromptTemplate {
117122

118123
public static final String PROMPT_SUFFIX = "GIVE THE OUTPUT PART ONLY IN YOUR RESPONSE \n"
119124
+ "Question: asked by user \n"
120-
+ "Mapping:${parameters.index_mapping:-} \n"
125+
+ "Mapping :${parameters.index_mapping:-} \n"
126+
+ "Query Fields: ${parameters.query_fields:-} "
121127
+ "Output:";
122128

123129
public static final String DEFAULT_SYSTEM_PROMPT = PROMPT_PREFIX

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.ml.engine.tools;
77

8+
import static org.opensearch.ml.common.utils.StringUtils.gson;
89
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_QUERY;
910
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_SYSTEM_PROMPT;
1011

@@ -65,6 +66,12 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
6566
if (!parameters.containsKey(SYSTEM_PROMPT_FIELD)) {
6667
parameters.put(SYSTEM_PROMPT_FIELD, DEFAULT_SYSTEM_PROMPT);
6768
}
69+
if (parameters.containsKey("index_mapping")) {
70+
parameters.put("index_mapping", gson.toJson(parameters.get("index_mapping")));
71+
}
72+
if (parameters.containsKey("query_fields")) {
73+
parameters.put("query_fields", gson.toJson(parameters.get("query_fields")));
74+
}
6875
ActionListener<T> modelListener = ActionListener.wrap(r -> {
6976
try {
7077
String queryString = (String) r;

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/QueryPlanningToolTests.java

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
import static org.mockito.ArgumentMatchers.any;
1515
import static org.mockito.Mockito.doAnswer;
1616
import static org.mockito.Mockito.mock;
17+
import static org.mockito.Mockito.verify;
1718
import static org.opensearch.ml.engine.tools.QueryPlanningTool.DEFAULT_DESCRIPTION;
1819
import static org.opensearch.ml.engine.tools.QueryPlanningTool.MODEL_ID_FIELD;
20+
import static org.opensearch.ml.engine.tools.QueryPlanningTool.SYSTEM_PROMPT_FIELD;
1921

2022
import java.util.Collections;
2123
import java.util.HashMap;
@@ -124,9 +126,8 @@ public void testRun_PredictionReturnsNull_ReturnDefaultQuery() throws ExecutionE
124126
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
125127
validParams.put("query_text", "help me find some books related to wind");
126128
tool.run(validParams, listener);
127-
String multiMatchQueryString =
128-
"{ \"query\": { \"multi_match\" : { \"query\": \"help me find some books related to wind\", \"fields\": [\"*\"] } } }";
129-
assertEquals(multiMatchQueryString, future.get());
129+
String defaultQueryString = "{\"size\":10,\"query\":{\"match_all\":{}}}";
130+
assertEquals(defaultQueryString, future.get());
130131
}
131132

132133
@Test
@@ -142,9 +143,8 @@ public void testRun_PredictionReturnsEmpty_ReturnDefaultQuery() throws Execution
142143
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
143144
validParams.put("query_text", "help me find some books related to wind");
144145
tool.run(validParams, listener);
145-
String multiMatchQueryString =
146-
"{ \"query\": { \"multi_match\" : { \"query\": \"help me find some books related to wind\", \"fields\": [\"*\"] } } }";
147-
assertEquals(multiMatchQueryString, future.get());
146+
String defaultQueryString = "{\"size\":10,\"query\":{\"match_all\":{}}}";
147+
assertEquals(defaultQueryString, future.get());
148148
}
149149

150150
@Test
@@ -160,9 +160,8 @@ public void testRun_PredictionReturnsNullString_ReturnDefaultQuery() throws Exec
160160
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
161161
validParams.put("query_text", "help me find some books related to wind");
162162
tool.run(validParams, listener);
163-
String multiMatchQueryString =
164-
"{ \"query\": { \"multi_match\" : { \"query\": \"help me find some books related to wind\", \"fields\": [\"*\"] } } }";
165-
assertEquals(multiMatchQueryString, future.get());
163+
String defaultQueryString = "{\"size\":10,\"query\":{\"match_all\":{}}}";
164+
assertEquals(defaultQueryString, future.get());
166165
}
167166

168167
@Test
@@ -269,4 +268,42 @@ public void testFactoryCreateWithInvalidType() {
269268
Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(map));
270269
assertEquals("Invalid generation type: invalid. The current supported types are llmGenerated.", exception.getMessage());
271270
}
271+
272+
273+
274+
@Test
275+
public void testAllParameterProcessing() {
276+
QueryPlanningTool tool = new QueryPlanningTool("llmGenerated", queryGenerationTool);
277+
Map<String, String> parameters = new HashMap<>();
278+
parameters.put("query_text", "test query");
279+
parameters.put("index_mapping", "{\"properties\":{\"title\":{\"type\":\"text\"}}}");
280+
parameters.put("query_fields", "[\"title\", \"content\"]");
281+
// No system_prompt - should use default
282+
283+
@SuppressWarnings("unchecked")
284+
ActionListener<String> listener = mock(ActionListener.class);
285+
286+
doAnswer(invocation -> {
287+
ActionListener<String> modelListener = invocation.getArgument(1);
288+
modelListener.onResponse("{\"query\":{\"match\":{\"title\":\"test\"}}}");
289+
return null;
290+
}).when(queryGenerationTool).run(any(), any());
291+
292+
tool.run(parameters, listener);
293+
294+
ArgumentCaptor<Map<String, String>> captor = ArgumentCaptor.forClass(Map.class);
295+
verify(queryGenerationTool).run(captor.capture(), any());
296+
297+
Map<String, String> capturedParams = captor.getValue();
298+
299+
// All parameters should be processed
300+
assertTrue(capturedParams.containsKey("query_text"));
301+
assertTrue(capturedParams.containsKey("index_mapping"));
302+
assertTrue(capturedParams.containsKey("query_fields"));
303+
assertTrue(capturedParams.containsKey(SYSTEM_PROMPT_FIELD));
304+
305+
// Processed parameters should be JSON strings
306+
assertTrue(capturedParams.get("index_mapping").startsWith("\""));
307+
assertTrue(capturedParams.get("query_fields").startsWith("\""));
308+
}
272309
}

0 commit comments

Comments
 (0)