Skip to content

Commit 6cd0beb

Browse files
Add Default System Prompt for the query Planner tool (opensearch-project#4046)
* Add Default system prompt Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> * Update ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningPromptTemplate.java Co-authored-by: Owais Kazi <owaiskazi19@gmail.com> Signed-off-by: Rithin Pullela <rithinp@amazon.com> * Address comments, fix IT Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> * Make compatible with agentic search Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> * spotless Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> * add test coverage Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> * spotless Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> --------- Signed-off-by: rithin-pullela-aws <rithinp@amazon.com> Signed-off-by: Rithin Pullela <rithinp@amazon.com> Co-authored-by: Owais Kazi <owaiskazi19@gmail.com>
1 parent 7f4252b commit 6cd0beb

File tree

4 files changed

+206
-23
lines changed

4 files changed

+206
-23
lines changed
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
package org.opensearch.ml.engine.tools;
2+
3+
public class QueryPlanningPromptTemplate {
4+
5+
public static final String DEFAULT_QUERY = "{\"size\":10,\"query\":{\"match_all\":{}}}";
6+
7+
public static final String QUERY_TYPE_RULES = "\nChoose query types based on user intent and fields: \n"
8+
+ "match: single-token full‑text searches on analyzed text fields, \n"
9+
+ "match_phrase: multi-token phrases on analyzed text fields (search string contains a space, hyphen, comma, etc.), \n"
10+
+ "term / terms:exact match on keyword, numeric, boolean, \n"
11+
+ "range:numeric/date comparisons (gt, lt, gte, lte), \n"
12+
+ "bool with must, should, must_not, filter: AND/OR/NOT logic, \n"
13+
+ "wildcard / prefix on keyword:\"starts with\", \"contains\", \n"
14+
+ "exists:field presence/absence, \n"
15+
+ "nested query / nested agg:Never wrap a field in nested unless the mapping for that exact path (or one of its parents) explicitly says \"type\": \"nested\". \n"
16+
+ "Otherwise use a normal query on the flattened field. \n";
17+
18+
public static final String AGGREGATION_RULES = "Aggregations (when asked for counts, averages, \"top N\", distributions): \n"
19+
+ "terms on field.keyword or numeric for grouping / top N, \n"
20+
+ "Metric aggs (avg, min, max, sum, stats, cardinality) on numeric fields, \n"
21+
+ "date_histogram, histogram, range for distributions, \n"
22+
+ "Always set \"size\": 0 when only aggregations are needed, \n"
23+
+ "Use sub‑aggregations + order for \"top N by metric\", \n"
24+
+ "If grouping by a text field, use its .keyword sub‑field.\n";
25+
26+
public static final String PROMPT_PREFIX =
27+
"You are an OpenSearch DSL expert. Your job is to convert natural‑language questions into strict JSON OpenSearch search query bodies. \n"
28+
+ "Follow every rule: Use only the provided index mapping to decide which fields exist and their types, pay close attention to index mapping. \n"
29+
+ "Do not use fields that not present in mapping. \n"
30+
+ QUERY_TYPE_RULES
31+
+ AGGREGATION_RULES;
32+
33+
public static final String USE_QUERY_FIELDS_INSTRUCTION =
34+
"When Query Fields are provided, prioritize incorporating them into the generated query.";
35+
36+
public static final String OUTPUT_FORMAT_INSTRUCTIONS = "Output format: Output only a valid escaped JSON string or the literal \n"
37+
+ DEFAULT_QUERY
38+
+ " \nReturn exactly one JSON object. "
39+
+ "Output nothing before or after it — no code fences/backticks (`), angle brackets (< >), hash marks (#), asterisks (*), pipes (|), tildes (~), ellipses (… or ...), emojis, typographic quotes (\" \"), non-breaking spaces (U+00A0), zero-width characters (U+200B, U+FEFF), or any other markup/control characters. "
40+
+ "Use valid JSON only (standard double quotes \"; no comments; no trailing commas). "
41+
+ "This applies to formatting only, string values inside the JSON may contain any needed Unicode characters. \n"
42+
+ "Follow the examples below. \n"
43+
+ USE_QUERY_FIELDS_INSTRUCTION
44+
+ "Fallback: If the request cannot be fulfilled with the mapping (missing field, unsupported feature, etc.), \n"
45+
+ "output the literal string: "
46+
+ DEFAULT_QUERY;
47+
48+
// Individual example constants for better maintainability
49+
public static final String EXAMPLE_1 = "Example 1 — numeric range \n"
50+
+ "Input: Show all products that cost more than 50 dollars. \n"
51+
+ "Mapping: { \"properties\": { \"price\": { \"type\": \"float\" }, \"cost\": { \"type\": \"float\" } } }\n"
52+
+ "query_fields: [price]"
53+
+ "Output: \"{ \"query\": { \"range\": { \"price\": { \"gt\": 50 } } } }\" \n";
54+
55+
public static final String EXAMPLE_2 = "Example 2 — text match + exact filter \n"
56+
+ "Input: Find employees in London who are active. \n"
57+
+ "Mapping: \"{ \"properties\": { \"city\": { \"type\": \"text\", \"fields\": { \"keyword\": { \"type\": \"keyword\" } } }, \"status\": { \"type\": \"keyword\" } } }\" \n"
58+
+ "query_fields: [city, status]"
59+
+ "Output: \"{ \"query\": { \"bool\": { \"must\": [ { \"match\": { \"city\": \"London\" } } ], \"filter\": [ { \"term\": { \"status\": \"active\" } } ] } } }\" \n";
60+
61+
public static final String EXAMPLE_3 =
62+
"Example 3 — match_phrase (use when search string contains a space, hyphen, comma, etc. here \"new york city\" has space) \n"
63+
+ "Input: Find employees who are active and located in New York City \n"
64+
+ "Mapping: \"{ \"properties\": { \"city\": { \"type\": \"text\", \"fields\": { \"keyword\": { \"type\": \"keyword\" } } }, \"status\": { \"type\": \"keyword\" } } }\" \n"
65+
+ "Output: \"{ \"query\": { \"bool\": { \"must\": [ { \"match_phrase\": { \"city\": \"New York City\" } } ], \"filter\": [ { \"term\": { \"status\": \"active\" } } ] } } }\" \n";
66+
67+
public static final String EXAMPLE_4 = "Example 4 — bool with SHOULD \n"
68+
+ "Input: Search articles about \"machine learning\" that are research papers or blogs. \n"
69+
+ "Mapping: \"{ \"properties\": { \"content\": { \"type\": \"text\" }, \"type\": { \"type\": \"keyword\" } } }\" \n"
70+
+ "Output: \"{ \"query\": { \"bool\": { \"must\": [ { \"match\": { \"content\": \"machine learning\" } } ], \"should\": [ { \"term\": { \"type\": \"research paper\" } }, { \"term\": { \"type\": \"blog\" } } ], \"minimum_should_match\": 1 } } }\" \n";
71+
72+
public static final String EXAMPLE_5 = "Example 5 — MUST NOT \n"
73+
+ "Input: List customers who have not made a purchase in 2023. \n"
74+
+ "Mapping: \"{ \"properties\": { \"last_purchase_date\": { \"type\": \"date\" } } }\" \n"
75+
+ "Output: \"{ \"query\": { \"bool\": { \"must_not\": [ { \"range\": { \"last_purchase_date\": { \"gte\": \"2023-01-01\", \"lte\": \"2023-12-31\" } } } ] } } }\" \n";
76+
77+
public static final String EXAMPLE_6 = "Example 6 — wildcard \n"
78+
+ "Input: Find files with names starting with \"report_\". \n"
79+
+ "Mapping: \"{ \"properties\": { \"filename\": { \"type\": \"keyword\" } } }\" \n"
80+
+ "Output: \"{ \"query\": { \"wildcard\": { \"filename\": \"report_*\" } } }\" \n";
81+
82+
public static final String EXAMPLE_7 =
83+
"Example 7 — nested query (note the index mapping says \"type\": \"nested\", do not use it for other types) \n"
84+
+ "Input: Find books where an authors first_name is John AND last_name is Doe. \n"
85+
+ "Mapping: \"{ \"properties\": { \"author\": { \"type\": \"nested\", \"properties\": { \"first_name\": { \"type\": \"text\", \"fields\": { \"keyword\": { \"type\": \"keyword\" } } }, \"last_name\": { \"type\": \"text\", \"fields\": { \"keyword\": { \"type\": \"keyword\" } } } } } } }\" \n"
86+
+ "Output: \"{ \"query\": { \"nested\": { \"path\": \"author\", \"query\": { \"bool\": { \"must\": [ { \"term\": { \"author.first_name.keyword\": \"John\" } }, { \"term\": { \"author.last_name.keyword\": \"Doe\" } } ] } } } } }\" \n";
87+
88+
public static final String EXAMPLE_8 = "Example 8 — terms aggregation \n"
89+
+ "Input: Show the number of orders per status. \n"
90+
+ "Mapping: \"{ \"properties\": { \"status\": { \"type\": \"keyword\" } } }\" \n"
91+
+ "Output: \"{ \"size\": 0, \"aggs\": { \"orders_by_status\": { \"terms\": { \"field\": \"status\" } } } }\" \n";
92+
93+
public static final String EXAMPLE_9 = "Example 9 — metric aggregation with filter \n"
94+
+ "Input: What is the average price of electronics products? \n"
95+
+ "Mapping: \"{ \"properties\": { \"category\": { \"type\": \"keyword\" }, \"price\": { \"type\": \"float\" } } }\" \n"
96+
+ "Output: \"{ \"size\": 0, \"query\": { \"term\": { \"category\": \"electronics\" } }, \"aggs\": { \"avg_price\": { \"avg\": { \"field\": \"price\" } } } }\" \n";
97+
98+
public static final String EXAMPLE_10 = "Example 10 — top N by metric \n"
99+
+ "Input: List the top 3 categories by total sales volume. \n"
100+
+ "Mapping: \"{ \"properties\": { \"category\": { \"type\": \"text\", \"fields\": { \"keyword\": { \"type\": \"keyword\" } } }, \"sales\": { \"type\": \"float\" } } }\" \n"
101+
+ "Output: \"{ \"size\": 0, \"aggs\": { \"top_categories\": { \"terms\": { \"field\": \"category.keyword\", \"size\": 3, \"order\": { \"total_sales\": \"desc\" } }, \"aggs\": { \"total_sales\": { \"sum\": { \"field\": \"sales\" } } } } } }\" \n";
102+
103+
public static final String EXAMPLE_11 = "Example 11 — fallback \n"
104+
+ "Input: Find employees who speak Klingon fluently. \n"
105+
+ "Mapping: \"{ \"properties\": { \"name\": { \"type\": \"text\" }, \"role\": { \"type\": \"keyword\" } } }\" \n"
106+
+ "Output: "
107+
+ DEFAULT_QUERY
108+
+ "\n";
109+
110+
public static final String EXAMPLES = "\nEXAMPLES: "
111+
+ EXAMPLE_1
112+
+ EXAMPLE_2
113+
+ EXAMPLE_3
114+
+ EXAMPLE_4
115+
+ EXAMPLE_5
116+
+ EXAMPLE_6
117+
+ EXAMPLE_7
118+
+ EXAMPLE_8
119+
+ EXAMPLE_9
120+
+ EXAMPLE_10
121+
+ EXAMPLE_11;
122+
123+
public static final String PROMPT_SUFFIX = "GIVE THE OUTPUT PART ONLY IN YOUR RESPONSE \n"
124+
+ "Question: asked by user \n"
125+
+ "Mapping :${parameters.index_mapping:-} \n"
126+
+ "Query Fields: ${parameters.query_fields:-} "
127+
+ "Output:";
128+
129+
public static final String DEFAULT_SYSTEM_PROMPT = PROMPT_PREFIX
130+
+ " \n "
131+
+ OUTPUT_FORMAT_INSTRUCTIONS
132+
+ " \n "
133+
+ EXAMPLES
134+
+ " \n "
135+
+ PROMPT_SUFFIX;
136+
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55

66
package org.opensearch.ml.engine.tools;
77

8+
import static org.opensearch.ml.common.utils.StringUtils.gson;
9+
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_QUERY;
10+
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_SYSTEM_PROMPT;
11+
812
import java.util.List;
913
import java.util.Map;
1014

@@ -30,7 +34,9 @@ public class QueryPlanningTool implements WithModelTool {
3034
public static final String TYPE = "QueryPlanningTool";
3135
public static final String MODEL_ID_FIELD = "model_id";
3236
private final MLModelTool queryGenerationTool;
33-
public static final String PROMPT_FIELD = "prompt";
37+
public static final String SYSTEM_PROMPT_FIELD = "system_prompt";
38+
public static final String INDEX_MAPPING_FIELD = "index_mapping";
39+
public static final String QUERY_FIELDS_FIELD = "query_fields";
3440
private static final String GENERATION_TYPE_FIELD = "generation_type";
3541
private static final String LLM_GENERATED_TYPE_FIELD = "llmGenerated";
3642
@Getter
@@ -46,10 +52,6 @@ public class QueryPlanningTool implements WithModelTool {
4652
@Getter
4753
@Setter
4854
private String description = DEFAULT_DESCRIPTION;
49-
private String defaultQuery =
50-
"{ \"query\": { \"multi_match\" : { \"query\": \"${parameters.query_text}\", \"fields\": ${parameters.query_fields:-[\"*\"]} } } }";
51-
private String defaultPrompt =
52-
"You are an OpenSearch Query DSL generation assistant; try using the optional provided index mapping ${parameters.index_mapping:-}, specified fields ${parameters.query_fields:-}, and the given sample queries as examples, generate an OpenSearch Query DSL to retrieve the most relevant documents for the user provided natural language question: ${parameters.query_text}, please return the query dsl only in a string format, no other texts.\n";
5355

5456
public QueryPlanningTool(String generationType, MLModelTool queryGenerationTool) {
5557
this.generationType = generationType;
@@ -63,15 +65,21 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
6365
listener.onFailure(new IllegalArgumentException("Empty parameters for QueryPlanningTool: " + parameters));
6466
return;
6567
}
66-
if (!parameters.containsKey(PROMPT_FIELD)) {
67-
parameters.put(PROMPT_FIELD, defaultPrompt);
68+
if (!parameters.containsKey(SYSTEM_PROMPT_FIELD)) {
69+
parameters.put(SYSTEM_PROMPT_FIELD, DEFAULT_SYSTEM_PROMPT);
70+
}
71+
if (parameters.containsKey(INDEX_MAPPING_FIELD)) {
72+
parameters.put(INDEX_MAPPING_FIELD, gson.toJson(parameters.get(INDEX_MAPPING_FIELD)));
73+
}
74+
if (parameters.containsKey(QUERY_FIELDS_FIELD)) {
75+
parameters.put(QUERY_FIELDS_FIELD, gson.toJson(parameters.get(QUERY_FIELDS_FIELD)));
6876
}
6977
ActionListener<T> modelListener = ActionListener.wrap(r -> {
7078
try {
7179
String queryString = (String) r;
7280
if (queryString == null || queryString.isBlank() || queryString.equals("null")) {
7381
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
74-
String defaultQueryString = substitutor.replace(this.defaultQuery);
82+
String defaultQueryString = substitutor.replace(DEFAULT_QUERY);
7583
listener.onResponse((T) defaultQueryString);
7684
} else {
7785
listener.onResponse((T) queryString);

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/QueryPlanningToolTests.java

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@
1414
import static org.mockito.ArgumentMatchers.any;
1515
import static org.mockito.Mockito.doAnswer;
1616
import static org.mockito.Mockito.mock;
17+
import static org.mockito.Mockito.verify;
1718
import static org.opensearch.ml.engine.tools.QueryPlanningTool.DEFAULT_DESCRIPTION;
19+
import static org.opensearch.ml.engine.tools.QueryPlanningTool.INDEX_MAPPING_FIELD;
1820
import static org.opensearch.ml.engine.tools.QueryPlanningTool.MODEL_ID_FIELD;
21+
import static org.opensearch.ml.engine.tools.QueryPlanningTool.QUERY_FIELDS_FIELD;
22+
import static org.opensearch.ml.engine.tools.QueryPlanningTool.SYSTEM_PROMPT_FIELD;
1923

2024
import java.util.Collections;
2125
import java.util.HashMap;
@@ -57,7 +61,7 @@ public void setup() {
5761
MLModelTool.Factory.getInstance().init(client);
5862
factory = new QueryPlanningTool.Factory();
5963
validParams = new HashMap<>();
60-
validParams.put("prompt", "test prompt");
64+
validParams.put(SYSTEM_PROMPT_FIELD, "test prompt");
6165
emptyParams = Collections.emptyMap();
6266
}
6367

@@ -83,7 +87,10 @@ public void testRun() throws ExecutionException, InterruptedException {
8387
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
8488
// test try to update the prompt
8589
validParams
86-
.put("prompt", "You are a query generation agent. Generate a dsl query for the following question: ${parameters.query_text}");
90+
.put(
91+
SYSTEM_PROMPT_FIELD,
92+
"You are a query generation agent. Generate a dsl query for the following question: ${parameters.query_text}"
93+
);
8794
validParams.put("query_text", "help me find some books related to wind");
8895
tool.run(validParams, listener);
8996

@@ -124,9 +131,8 @@ public void testRun_PredictionReturnsNull_ReturnDefaultQuery() throws ExecutionE
124131
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
125132
validParams.put("query_text", "help me find some books related to wind");
126133
tool.run(validParams, listener);
127-
String multiMatchQueryString =
128-
"{ \"query\": { \"multi_match\" : { \"query\": \"help me find some books related to wind\", \"fields\": [\"*\"] } } }";
129-
assertEquals(multiMatchQueryString, future.get());
134+
String defaultQueryString = "{\"size\":10,\"query\":{\"match_all\":{}}}";
135+
assertEquals(defaultQueryString, future.get());
130136
}
131137

132138
@Test
@@ -142,9 +148,8 @@ public void testRun_PredictionReturnsEmpty_ReturnDefaultQuery() throws Execution
142148
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
143149
validParams.put("query_text", "help me find some books related to wind");
144150
tool.run(validParams, listener);
145-
String multiMatchQueryString =
146-
"{ \"query\": { \"multi_match\" : { \"query\": \"help me find some books related to wind\", \"fields\": [\"*\"] } } }";
147-
assertEquals(multiMatchQueryString, future.get());
151+
String defaultQueryString = "{\"size\":10,\"query\":{\"match_all\":{}}}";
152+
assertEquals(defaultQueryString, future.get());
148153
}
149154

150155
@Test
@@ -160,9 +165,8 @@ public void testRun_PredictionReturnsNullString_ReturnDefaultQuery() throws Exec
160165
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
161166
validParams.put("query_text", "help me find some books related to wind");
162167
tool.run(validParams, listener);
163-
String multiMatchQueryString =
164-
"{ \"query\": { \"multi_match\" : { \"query\": \"help me find some books related to wind\", \"fields\": [\"*\"] } } }";
165-
assertEquals(multiMatchQueryString, future.get());
168+
String defaultQueryString = "{\"size\":10,\"query\":{\"match_all\":{}}}";
169+
assertEquals(defaultQueryString, future.get());
166170
}
167171

168172
@Test
@@ -204,7 +208,7 @@ public void testRunWithNoPrompt() {
204208
ArgumentCaptor<Map<String, String>> captor = ArgumentCaptor.forClass(Map.class);
205209
doAnswer(invocation -> {
206210
Map<String, String> params = invocation.getArgument(0);
207-
assertNotNull(params.get("prompt"));
211+
assertNotNull(params.get(SYSTEM_PROMPT_FIELD));
208212
return null;
209213
}).when(queryGenerationTool).run(captor.capture(), any());
210214
}
@@ -269,4 +273,40 @@ public void testFactoryCreateWithInvalidType() {
269273
Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(map));
270274
assertEquals("Invalid generation type: invalid. The current supported types are llmGenerated.", exception.getMessage());
271275
}
276+
277+
@Test
278+
public void testAllParameterProcessing() {
279+
QueryPlanningTool tool = new QueryPlanningTool("llmGenerated", queryGenerationTool);
280+
Map<String, String> parameters = new HashMap<>();
281+
parameters.put("query_text", "test query");
282+
parameters.put(INDEX_MAPPING_FIELD, "{\"properties\":{\"title\":{\"type\":\"text\"}}}");
283+
parameters.put(QUERY_FIELDS_FIELD, "[\"title\", \"content\"]");
284+
// No system_prompt - should use default
285+
286+
@SuppressWarnings("unchecked")
287+
ActionListener<String> listener = mock(ActionListener.class);
288+
289+
doAnswer(invocation -> {
290+
ActionListener<String> modelListener = invocation.getArgument(1);
291+
modelListener.onResponse("{\"query\":{\"match\":{\"title\":\"test\"}}}");
292+
return null;
293+
}).when(queryGenerationTool).run(any(), any());
294+
295+
tool.run(parameters, listener);
296+
297+
ArgumentCaptor<Map<String, String>> captor = ArgumentCaptor.forClass(Map.class);
298+
verify(queryGenerationTool).run(captor.capture(), any());
299+
300+
Map<String, String> capturedParams = captor.getValue();
301+
302+
// All parameters should be processed
303+
assertTrue(capturedParams.containsKey("query_text"));
304+
assertTrue(capturedParams.containsKey(INDEX_MAPPING_FIELD));
305+
assertTrue(capturedParams.containsKey(QUERY_FIELDS_FIELD));
306+
assertTrue(capturedParams.containsKey(SYSTEM_PROMPT_FIELD));
307+
308+
// Processed parameters should be JSON strings
309+
assertTrue(capturedParams.get(INDEX_MAPPING_FIELD).startsWith("\""));
310+
assertTrue(capturedParams.get(QUERY_FIELDS_FIELD).startsWith("\""));
311+
}
272312
}

0 commit comments

Comments
 (0)