Skip to content

Commit 266bcfe

Browse files
authored
Adding query planning tool search template validation and integration tests (#4177)
Signed-off-by: Joshua Palis <jpalis@amazon.com>
1 parent f83026c commit 266bcfe

File tree

4 files changed

+268
-12
lines changed

4 files changed

+268
-12
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningPromptTemplate.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,14 +201,14 @@ public class QueryPlanningPromptTemplate {
201201
+ "- If no perfect match exists, pick the closest by the criteria above. Never output “none” or invent an id.";
202202

203203
public static final String TEMPLATE_SELECTION_INPUTS = "question: ${parameters.query_text}\n"
204-
+ "templates: ${parameters.search_templates}";
204+
+ "search_templates: ${parameters.search_templates}";
205205

206206
public static final String TEMPLATE_SELECTION_EXAMPLES = "Example A: \n"
207207
+ "question: 'what shoes are highly rated'\n"
208-
+ "templates:\n"
208+
+ "search_templates :\n"
209209
+ "[\n"
210-
+ "{'id':'product-search-template','description':'Searches products in an e-commerce store.'},\n"
211-
+ "{'id':'sales-value-analysis-template','description':'Aggregates sales value for top-selling products.'}\n"
210+
+ "{'template_id':'product-search-template','template_description':'Searches products in an e-commerce store.'},\n"
211+
+ "{'template_id':'sales-value-analysis-template','template_description':'Aggregates sales value for top-selling products.'}\n"
212212
+ "]\n"
213213
+ "Example output : 'product-search-template'";
214214

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
import org.opensearch.ml.common.utils.ToolUtils;
2828
import org.opensearch.transport.client.Client;
2929

30+
import com.google.gson.reflect.TypeToken;
31+
3032
import lombok.Getter;
3133
import lombok.Setter;
3234

@@ -46,13 +48,16 @@ public class QueryPlanningTool implements WithModelTool {
4648
public static final String USER_PROMPT_FIELD = "user_prompt";
4749
public static final String INDEX_MAPPING_FIELD = "index_mapping";
4850
public static final String QUERY_FIELDS_FIELD = "query_fields";
49-
private static final String GENERATION_TYPE_FIELD = "generation_type";
51+
public static final String GENERATION_TYPE_FIELD = "generation_type";
5052
private static final String LLM_GENERATED_TYPE_FIELD = "llmGenerated";
51-
private static final String USER_SEARCH_TEMPLATES_TYPE_FIELD = "user_templates";
52-
private static final String SEARCH_TEMPLATES_FIELD = "search_templates";
53+
public static final String USER_SEARCH_TEMPLATES_TYPE_FIELD = "user_templates";
54+
public static final String SEARCH_TEMPLATES_FIELD = "search_templates";
5355
public static final String TEMPLATE_FIELD = "template";
56+
private static final String TEMPLATE_ID_FIELD = "template_id";
57+
private static final String TEMPLATE_DESCRIPTION_FIELD = "template_description";
5458
private static final String DEFAULT_SYSTEM_PROMPT =
5559
"You are an OpenSearch Query DSL generation assistant, translating natural language questions to OpenSeach DSL Queries";
60+
5661
@Getter
5762
private final String generationType;
5863
@Getter
@@ -102,17 +107,19 @@ public <T> void run(Map<String, String> originalParameters, ActionListener<T> li
102107
templateSelectionParameters.put(SEARCH_TEMPLATES_FIELD, searchTemplates);
103108

104109
ActionListener<T> templateSelectionListener = ActionListener.wrap(r -> {
110+
// Default search template if LLM does not choose or if returned search template is null
111+
parameters.put(TEMPLATE_FIELD, DEFAULT_SEARCH_TEMPLATE);
105112
try {
106113
String templateId = (String) r;
107114
if (templateId == null || templateId.isBlank() || templateId.equals("null")) {
108-
// Default search template if LLM does not choose
109-
parameters.put(TEMPLATE_FIELD, DEFAULT_SEARCH_TEMPLATE);
110115
executeQueryPlanning(parameters, listener);
111116
} else {
112117
// Retrieve search template by ID
113118
GetStoredScriptRequest getStoredScriptRequest = new GetStoredScriptRequest(templateId);
114119
client.admin().cluster().getStoredScript(getStoredScriptRequest, ActionListener.wrap(getStoredScriptResponse -> {
115-
parameters.put(TEMPLATE_FIELD, gson.toJson(getStoredScriptResponse.getSource().getSource()));
120+
if (getStoredScriptResponse.getSource() != null) {
121+
parameters.put(TEMPLATE_FIELD, gson.toJson(getStoredScriptResponse.getSource().getSource()));
122+
}
116123
executeQueryPlanning(parameters, listener);
117124
}, e -> { listener.onFailure(e); }));
118125
}
@@ -233,14 +240,38 @@ public QueryPlanningTool create(Map<String, Object> map) {
233240
throw new IllegalArgumentException("search_templates field is required when generation_type is 'user_templates'");
234241
} else {
235242
// array is parsed as a json string
236-
searchTemplates = gson.toJson((String) map.get(SEARCH_TEMPLATES_FIELD));
237-
243+
String searchTemplatesJson = (String) map.get(SEARCH_TEMPLATES_FIELD);
244+
validateSearchTemplates(searchTemplatesJson);
245+
searchTemplates = gson.toJson(searchTemplatesJson);
238246
}
239247
}
240248

241249
return new QueryPlanningTool(type, queryGenerationTool, client, searchTemplates);
242250
}
243251

252+
private void validateSearchTemplates(Object searchTemplatesObj) {
253+
List<Map<String, String>> templates = gson.fromJson(searchTemplatesObj.toString(), new TypeToken<List<Map<String, String>>>() {
254+
}.getType());
255+
256+
for (Map<String, String> template : templates) {
257+
validateTemplateFields(template);
258+
}
259+
}
260+
261+
private void validateTemplateFields(Map<String, String> template) {
262+
// Validate templateId
263+
String templateId = template.get(TEMPLATE_ID_FIELD);
264+
if (templateId == null || templateId.isBlank()) {
265+
throw new IllegalArgumentException("search_templates field entries must have a template_id");
266+
}
267+
268+
// Validate templateDescription
269+
String templateDescription = template.get(TEMPLATE_DESCRIPTION_FIELD);
270+
if (templateDescription == null || templateDescription.isBlank()) {
271+
throw new IllegalArgumentException("search_templates field entries must have a template_description");
272+
}
273+
}
274+
244275
@Override
245276
public String getDefaultDescription() {
246277
return DEFAULT_DESCRIPTION;

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/QueryPlanningToolTests.java

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,38 @@ public void testFactoryCreate() {
9595
assertEquals(QueryPlanningTool.TYPE, tool.getName());
9696
}
9797

98+
@Test
99+
public void testCreateWithInvalidSearchTemplatesDescription() throws IllegalArgumentException {
100+
Map<String, Object> params = new HashMap<>();
101+
params.put("generation_type", "user_templates");
102+
params.put(MODEL_ID_FIELD, "test_model_id");
103+
params
104+
.put(
105+
SYSTEM_PROMPT_FIELD,
106+
"You are a query generation agent. Generate a dsl query for the following question: ${parameters.query_text}"
107+
);
108+
params.put("query_text", "help me find some books related to wind");
109+
params.put("search_templates", "[{'template_id': 'template_id', 'template_des': 'test_description'}]");
110+
Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(params));
111+
assertEquals("search_templates field entries must have a template_description", exception.getMessage());
112+
}
113+
114+
@Test
115+
public void testCreateWithInvalidSearchTemplatesID() throws IllegalArgumentException {
116+
Map<String, Object> params = new HashMap<>();
117+
params.put("generation_type", "user_templates");
118+
params.put(MODEL_ID_FIELD, "test_model_id");
119+
params
120+
.put(
121+
SYSTEM_PROMPT_FIELD,
122+
"You are a query generation agent. Generate a dsl query for the following question: ${parameters.query_text}"
123+
);
124+
params.put("query_text", "help me find some books related to wind");
125+
params.put("search_templates", "[{'templateid': 'template_id', 'template_description': 'test_description'}]");
126+
Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(params));
127+
assertEquals("search_templates field entries must have a template_id", exception.getMessage());
128+
}
129+
98130
@Test
99131
public void testRun() throws ExecutionException, InterruptedException {
100132
String matchQueryString = "{\"query\":{\"match\":{\"title\":\"wind\"}}}";
@@ -552,4 +584,101 @@ public void testFactoryCreateWhenAgenticSearchDisabled() {
552584
Exception exception = assertThrows(OpenSearchException.class, () -> factory.create(map));
553585
assertEquals(ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE, exception.getMessage());
554586
}
587+
588+
@Test
589+
public void testCreateWithValidSearchTemplates() {
590+
Map<String, Object> params = new HashMap<>();
591+
params.put("generation_type", "user_templates");
592+
params.put(MODEL_ID_FIELD, "test_model_id");
593+
params
594+
.put(
595+
"search_templates",
596+
"[{'template_id': 'template1', 'template_description': 'description1'}, {'template_id': 'template2', 'template_description': 'description2'}]"
597+
);
598+
599+
QueryPlanningTool tool = factory.create(params);
600+
assertNotNull(tool);
601+
assertEquals("user_templates", tool.getGenerationType());
602+
}
603+
604+
@Test
605+
public void testCreateWithEmptySearchTemplatesList() {
606+
Map<String, Object> params = new HashMap<>();
607+
params.put("generation_type", "user_templates");
608+
params.put(MODEL_ID_FIELD, "test_model_id");
609+
params.put("search_templates", "[]");
610+
611+
QueryPlanningTool tool = factory.create(params);
612+
assertNotNull(tool);
613+
assertEquals("user_templates", tool.getGenerationType());
614+
}
615+
616+
@Test
617+
public void testCreateWithMissingSearchTemplatesField() {
618+
Map<String, Object> params = new HashMap<>();
619+
params.put("generation_type", "user_templates");
620+
params.put(MODEL_ID_FIELD, "test_model_id");
621+
622+
Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(params));
623+
assertEquals("search_templates field is required when generation_type is 'user_templates'", exception.getMessage());
624+
}
625+
626+
@Test
627+
public void testCreateWithInvalidSearchTemplatesJson() {
628+
Map<String, Object> params = new HashMap<>();
629+
params.put("generation_type", "user_templates");
630+
params.put(MODEL_ID_FIELD, "test_model_id");
631+
params.put("search_templates", "invalid_json");
632+
633+
assertThrows(com.google.gson.JsonSyntaxException.class, () -> factory.create(params));
634+
}
635+
636+
@Test
637+
public void testCreateWithNullTemplateId() {
638+
Map<String, Object> params = new HashMap<>();
639+
params.put("generation_type", "user_templates");
640+
params.put(MODEL_ID_FIELD, "test_model_id");
641+
params.put("search_templates", "[{'template_id': null, 'template_description': 'description'}]");
642+
643+
Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(params));
644+
assertEquals("search_templates field entries must have a template_id", exception.getMessage());
645+
}
646+
647+
@Test
648+
public void testCreateWithBlankTemplateDescription() {
649+
Map<String, Object> params = new HashMap<>();
650+
params.put("generation_type", "user_templates");
651+
params.put(MODEL_ID_FIELD, "test_model_id");
652+
params.put("search_templates", "[{'template_id': 'template1', 'template_description': ' '}]");
653+
654+
Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(params));
655+
assertEquals("search_templates field entries must have a template_description", exception.getMessage());
656+
}
657+
658+
@Test
659+
public void testCreateWithMixedValidAndInvalidTemplates() {
660+
Map<String, Object> params = new HashMap<>();
661+
params.put("generation_type", "user_templates");
662+
params.put(MODEL_ID_FIELD, "test_model_id");
663+
params
664+
.put(
665+
"search_templates",
666+
"[{'template_id': 'template1', 'template_description': 'description1'}, {'template_description': 'description2'}]"
667+
);
668+
669+
Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(params));
670+
assertEquals("search_templates field entries must have a template_id", exception.getMessage());
671+
}
672+
673+
@Test
674+
public void testCreateWithExtraFieldsInSearchTemplates() {
675+
Map<String, Object> params = new HashMap<>();
676+
params.put("generation_type", "user_templates");
677+
params.put(MODEL_ID_FIELD, "test_model_id");
678+
params.put("search_templates", "[{'template_id': 'template1', 'template_description': 'description1', 'extra_field': 'value'}]");
679+
680+
QueryPlanningTool tool = factory.create(params);
681+
assertNotNull(tool);
682+
assertEquals("user_templates", tool.getGenerationType());
683+
}
555684
}

plugin/src/test/java/org/opensearch/ml/rest/RestQueryPlanningToolIT.java

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
package org.opensearch.ml.rest;
77

88
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED;
9+
import static org.opensearch.ml.engine.tools.QueryPlanningTool.GENERATION_TYPE_FIELD;
910
import static org.opensearch.ml.engine.tools.QueryPlanningTool.MODEL_ID_FIELD;
11+
import static org.opensearch.ml.engine.tools.QueryPlanningTool.SEARCH_TEMPLATES_FIELD;
12+
import static org.opensearch.ml.engine.tools.QueryPlanningTool.USER_SEARCH_TEMPLATES_TYPE_FIELD;
1013

1114
import java.io.IOException;
1215
import java.util.List;
@@ -95,6 +98,49 @@ public void testAgentWithQueryPlanningTool_DefaultPrompt() throws IOException {
9598
deleteAgent(agentId);
9699
}
97100

101+
@Test
102+
public void testAgentWithQueryPlanningTool_SearchTemplates() throws IOException {
103+
if (OPENAI_KEY == null) {
104+
return;
105+
}
106+
107+
// Create Search Templates
108+
String templateBody = "{\"script\":{\"lang\":\"mustache\",\"source\":{\"query\":{\"match\":{\"type\":\"{{type}}\"}}}}}";
109+
Response response = createSearchTemplate("type_search_template", templateBody);
110+
templateBody = "{\"script\":{\"lang\":\"mustache\",\"source\":{\"query\":{\"term\":{\"type\":\"{{type}}\"}}}}}";
111+
response = createSearchTemplate("type_search_template_2", templateBody);
112+
113+
// Register agent with search template IDs
114+
String agentName = "Test_AgentWithQueryPlanningTool_SearchTemplates";
115+
String searchTemplates = "[{"
116+
+ "\"template_id\":\"type_search_template\","
117+
+ "\"template_description\":\"this templates searches for flowers that match the given type this uses a match query\""
118+
+ "},{"
119+
+ "\"template_id\":\"type_search_template_2\","
120+
+ "\"template_description\":\"this templates searches for flowers that match the given type this uses a term query\""
121+
+ "},{"
122+
+ "\"template_id\":\"brand_search_template\","
123+
+ "\"template_description\":\"this templates searches for products that match the given brand\""
124+
+ "}]";
125+
String agentId = registerQueryPlanningAgentWithSearchTemplates(agentName, queryPlanningModelId, searchTemplates);
126+
assertNotNull(agentId);
127+
128+
String query = "{\"parameters\": {\"query_text\": \"List 5 iris flowers of type setosa\"}}";
129+
Response agentResponse = executeAgent(agentId, query);
130+
String responseBody = TestHelper.httpEntityToString(agentResponse.getEntity());
131+
132+
Map<String, Object> responseMap = gson.fromJson(responseBody, Map.class);
133+
134+
List<Map<String, Object>> inferenceResults = (List<Map<String, Object>>) responseMap.get("inference_results");
135+
Map<String, Object> firstResult = inferenceResults.get(0);
136+
List<Map<String, Object>> outputArray = (List<Map<String, Object>>) firstResult.get("output");
137+
Map<String, Object> output = (Map<String, Object>) outputArray.get(0);
138+
String result = output.get("result").toString();
139+
140+
assertTrue(result.contains("query"));
141+
deleteAgent(agentId);
142+
}
143+
98144
private String registerAgentWithQueryPlanningTool(String agentName, String modelId) throws IOException {
99145
MLToolSpec listIndexTool = MLToolSpec
100146
.builder()
@@ -125,6 +171,44 @@ private String registerAgentWithQueryPlanningTool(String agentName, String model
125171
return registerAgent(agentName, agent);
126172
}
127173

174+
private String registerQueryPlanningAgentWithSearchTemplates(String agentName, String modelId, String searchTemplates)
175+
throws IOException {
176+
MLToolSpec listIndexTool = MLToolSpec
177+
.builder()
178+
.type("ListIndexTool")
179+
.name("MyListIndexTool")
180+
.description("A tool for list indices")
181+
.parameters(Map.of("index", IRIS_INDEX, "question", "what fields are in the index?"))
182+
.includeOutputInAgentResponse(true)
183+
.build();
184+
185+
MLToolSpec queryPlanningTool = MLToolSpec
186+
.builder()
187+
.type("QueryPlanningTool")
188+
.name("MyQueryPlanningTool")
189+
.description("A tool for planning queries")
190+
.parameters(
191+
Map
192+
.ofEntries(
193+
Map.entry(MODEL_ID_FIELD, modelId),
194+
Map.entry(GENERATION_TYPE_FIELD, USER_SEARCH_TEMPLATES_TYPE_FIELD),
195+
Map.entry(SEARCH_TEMPLATES_FIELD, searchTemplates)
196+
)
197+
)
198+
.includeOutputInAgentResponse(true)
199+
.build();
200+
201+
MLAgent agent = MLAgent
202+
.builder()
203+
.name(agentName)
204+
.type("flow")
205+
.description("Test agent with QueryPlanningTool")
206+
.tools(List.of(listIndexTool, queryPlanningTool))
207+
.build();
208+
209+
return registerAgent(agentName, agent);
210+
}
211+
128212
private String registerQueryPlanningModel() throws IOException, InterruptedException {
129213
String openaiModelName = "openai gpt-4o model " + randomAlphaOfLength(5);
130214
return registerRemoteModel(openaiConnectorEntity, openaiModelName, true);
@@ -177,6 +261,18 @@ private Response executeAgent(String agentId, String query) throws IOException {
177261
);
178262
}
179263

264+
private Response createSearchTemplate(String templateName, String templateBody) throws IOException {
265+
return TestHelper
266+
.makeRequest(
267+
client(),
268+
"PUT",
269+
"/_scripts/" + templateName,
270+
null,
271+
new StringEntity(templateBody),
272+
List.of(new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json"))
273+
);
274+
}
275+
180276
private void deleteAgent(String agentId) throws IOException {
181277
TestHelper.makeRequest(client(), "DELETE", "/_plugins/_ml/agents/" + agentId, null, "", List.of());
182278
}

0 commit comments

Comments
 (0)