Skip to content

Commit 326e7ba

Browse files
committed
Adding query planning tool search template validation and integration tests
Signed-off-by: Joshua Palis <jpalis@amazon.com>
1 parent 5cabf63 commit 326e7ba

File tree

8 files changed

+214
-20
lines changed

8 files changed

+214
-20
lines changed

common/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ dependencies {
2323
testImplementation "org.opensearch.test:framework:${opensearch_version}"
2424

2525
compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
26-
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.11.0'
26+
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.13.2'
2727
compileOnly group: 'org.json', name: 'json', version: '20231013'
2828
testImplementation group: 'org.json', name: 'json', version: '20231013'
2929
implementation('com.google.guava:guava:32.1.3-jre') {

memory/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ dependencies {
3737
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.15.2'
3838
testImplementation "org.opensearch.test:framework:${opensearch_version}"
3939
testImplementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
40-
testImplementation group: 'com.google.code.gson', name: 'gson', version: '2.11.0'
40+
testImplementation group: 'com.google.code.gson', name: 'gson', version: '2.13.2'
4141
testImplementation group: 'org.json', name: 'json', version: '20231013'
4242
testImplementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}")
4343
testImplementation("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}")

ml-algorithms/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ dependencies {
4343
implementation (group: 'com.google.guava', name: 'guava', version: '32.1.3-jre') {
4444
exclude group: 'com.google.errorprone', module: 'error_prone_annotations'
4545
}
46-
implementation group: 'com.google.code.gson', name: 'gson', version: '2.11.0'
46+
implementation group: 'com.google.code.gson', name: 'gson', version: '2.13.2'
4747
implementation platform("ai.djl:bom:0.31.1")
4848
implementation group: 'ai.djl.pytorch', name: 'pytorch-model-zoo'
4949
implementation group: 'ai.djl', name: 'api'

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import java.util.Map;
1919

2020
import org.apache.commons.text.StringSubstitutor;
21+
import org.apache.logging.log4j.LogManager;
22+
import org.apache.logging.log4j.Logger;
2123
import org.opensearch.OpenSearchException;
2224
import org.opensearch.action.admin.cluster.storedscripts.GetStoredScriptRequest;
2325
import org.opensearch.core.action.ActionListener;
@@ -27,6 +29,8 @@
2729
import org.opensearch.ml.common.utils.ToolUtils;
2830
import org.opensearch.transport.client.Client;
2931

32+
import com.google.gson.reflect.TypeToken;
33+
3034
import lombok.Getter;
3135
import lombok.Setter;
3236

@@ -46,13 +50,18 @@ public class QueryPlanningTool implements WithModelTool {
4650
public static final String USER_PROMPT_FIELD = "user_prompt";
4751
public static final String INDEX_MAPPING_FIELD = "index_mapping";
4852
public static final String QUERY_FIELDS_FIELD = "query_fields";
49-
private static final String GENERATION_TYPE_FIELD = "generation_type";
53+
public static final String GENERATION_TYPE_FIELD = "generation_type";
5054
private static final String LLM_GENERATED_TYPE_FIELD = "llmGenerated";
51-
private static final String USER_SEARCH_TEMPLATES_TYPE_FIELD = "user_templates";
52-
private static final String SEARCH_TEMPLATES_FIELD = "search_templates";
55+
public static final String USER_SEARCH_TEMPLATES_TYPE_FIELD = "user_templates";
56+
public static final String SEARCH_TEMPLATES_FIELD = "search_templates";
5357
public static final String TEMPLATE_FIELD = "template";
58+
private static final String TEMPLATE_ID_FIELD = "template_id";
59+
private static final String TEMPLATE_DESCRIPTION_FIELD = "template_description";
5460
private static final String DEFAULT_SYSTEM_PROMPT =
5561
"You are an OpenSearch Query DSL generation assistant, translating natural language questions to OpenSeach DSL Queries";
62+
63+
private static final Logger logger = LogManager.getLogger(QueryPlanningTool.class);
64+
5665
@Getter
5766
private final String generationType;
5867
@Getter
@@ -112,7 +121,12 @@ public <T> void run(Map<String, String> originalParameters, ActionListener<T> li
112121
// Retrieve search template by ID
113122
GetStoredScriptRequest getStoredScriptRequest = new GetStoredScriptRequest(templateId);
114123
client.admin().cluster().getStoredScript(getStoredScriptRequest, ActionListener.wrap(getStoredScriptResponse -> {
115-
parameters.put(TEMPLATE_FIELD, gson.toJson(getStoredScriptResponse.getSource().getSource()));
124+
if (getStoredScriptResponse.getSource() == null) {
125+
// Edge case where stored scripts aren't synced; the default search template should be used
126+
parameters.put(TEMPLATE_FIELD, DEFAULT_SEARCH_TEMPLATE);
127+
} else {
128+
parameters.put(TEMPLATE_FIELD, gson.toJson(getStoredScriptResponse.getSource().getSource()));
129+
}
116130
executeQueryPlanning(parameters, listener);
117131
}, e -> { listener.onFailure(e); }));
118132
}
@@ -233,14 +247,38 @@ public QueryPlanningTool create(Map<String, Object> map) {
233247
throw new IllegalArgumentException("search_templates field is required when generation_type is 'user_templates'");
234248
} else {
235249
// array is parsed as a json string
236-
searchTemplates = gson.toJson((String) map.get(SEARCH_TEMPLATES_FIELD));
237-
250+
String searchTemplatesJson = (String) map.get(SEARCH_TEMPLATES_FIELD);
251+
validateSearchTemplates(searchTemplatesJson);
252+
searchTemplates = gson.toJson(searchTemplatesJson);
238253
}
239254
}
240255

241256
return new QueryPlanningTool(type, queryGenerationTool, client, searchTemplates);
242257
}
243258

259+
private void validateSearchTemplates(Object searchTemplatesObj) {
260+
List<Map<String, String>> templates = gson.fromJson(searchTemplatesObj.toString(), new TypeToken<List<Map<String, String>>>() {
261+
}.getType());
262+
263+
for (Map<String, String> template : templates) {
264+
validateTemplateFields(template);
265+
}
266+
}
267+
268+
private void validateTemplateFields(Map<String, String> template) {
269+
// Validate templateId
270+
String templateId = template.get(TEMPLATE_ID_FIELD);
271+
if (templateId == null || templateId.trim().isEmpty()) {
272+
throw new IllegalArgumentException("search_templates field entries must have a template_id");
273+
}
274+
275+
// Validate templateDescription
276+
String templateDescription = template.get(TEMPLATE_DESCRIPTION_FIELD);
277+
if (templateDescription == null || templateDescription.trim().isEmpty()) {
278+
throw new IllegalArgumentException("search_templates field entries must have a template_description");
279+
}
280+
}
281+
244282
@Override
245283
public String getDefaultDescription() {
246284
return DEFAULT_DESCRIPTION;

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/QueryPlanningToolTests.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,38 @@ public void testFactoryCreate() {
9595
assertEquals(QueryPlanningTool.TYPE, tool.getName());
9696
}
9797

98+
@Test
99+
public void testCreateWithInvalidSearchTemplatesDescription() throws IllegalArgumentException {
100+
Map<String, Object> params = new HashMap<>();
101+
params.put("generation_type", "user_templates");
102+
params.put(MODEL_ID_FIELD, "test_model_id");
103+
params
104+
.put(
105+
SYSTEM_PROMPT_FIELD,
106+
"You are a query generation agent. Generate a dsl query for the following question: ${parameters.query_text}"
107+
);
108+
params.put("query_text", "help me find some books related to wind");
109+
params.put("search_templates", "[{'template_id': 'template_id', 'template_des': 'test_description'}]");
110+
Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(params));
111+
assertEquals("search_templates field entries must have a template_description", exception.getMessage());
112+
}
113+
114+
@Test
115+
public void testCreateWithInvalidSearchTemplatesID() throws IllegalArgumentException {
116+
Map<String, Object> params = new HashMap<>();
117+
params.put("generation_type", "user_templates");
118+
params.put(MODEL_ID_FIELD, "test_model_id");
119+
params
120+
.put(
121+
SYSTEM_PROMPT_FIELD,
122+
"You are a query generation agent. Generate a dsl query for the following question: ${parameters.query_text}"
123+
);
124+
params.put("query_text", "help me find some books related to wind");
125+
params.put("search_templates", "[{'templateid': 'template_id', 'template_description': 'test_description'}]");
126+
Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(params));
127+
assertEquals("search_templates field entries must have a template_id", exception.getMessage());
128+
}
129+
98130
@Test
99131
public void testRun() throws ExecutionException, InterruptedException {
100132
String matchQueryString = "{\"query\":{\"match\":{\"title\":\"wind\"}}}";

plugin/build.gradle

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ dependencies {
6868

6969
implementation group: 'software.amazon.awssdk', name: 'protocol-core', version: "2.30.18"
7070

71-
zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}"
71+
zipArchive("org.opensearch.plugin:opensearch-job-scheduler:${opensearch_build}")
72+
zipArchive("org.opensearch.plugin:opensearch-knn:${opensearch_build}")
73+
zipArchive("org.opensearch.plugin:neural-search:${opensearch_build}")
7274
compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}"
7375
implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
7476
implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
@@ -82,7 +84,8 @@ dependencies {
8284
implementation (group: 'com.google.guava', name: 'guava', version: '32.1.3-jre') {
8385
exclude group: 'com.google.errorprone', module: 'error_prone_annotations'
8486
}
85-
implementation group: 'com.google.code.gson', name: 'gson', version: '2.11.0'
87+
implementation group: 'com.google.code.gson', name: 'gson', version: '2.13.2'
88+
api "com.google.errorprone:error_prone_annotations:${versions.error_prone_annotations}"
8689
implementation group: 'org.apache.commons', name: 'commons-lang3', version: "${versions.commonslang}"
8790
implementation group: 'org.apache.commons', name: 'commons-math3', version: '3.6.1'
8891
implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
@@ -248,16 +251,38 @@ testClusters.integTest {
248251
}
249252
plugin(project.tasks.bundlePlugin.archiveFile)
250253
plugin(provider(new Callable<RegularFile>(){
251-
@Override
252-
RegularFile call() throws Exception {
253-
return new RegularFile() {
254-
@Override
255-
File getAsFile() {
256-
return configurations.zipArchive.asFileTree.getSingleFile()
254+
@Override
255+
RegularFile call() throws Exception {
256+
return new RegularFile() {
257+
@Override
258+
File getAsFile() {
259+
return configurations.zipArchive.asFileTree.matching{include "**/opensearch-job-scheduler-${opensearch_build}.zip"}.getSingleFile()
260+
}
257261
}
258262
}
259-
}
260-
}))
263+
}))
264+
plugin(provider(new Callable<RegularFile>(){
265+
@Override
266+
RegularFile call() throws Exception {
267+
return new RegularFile() {
268+
@Override
269+
File getAsFile() {
270+
return configurations.zipArchive.asFileTree.matching{include "**/opensearch-knn-${opensearch_build}.zip"}.getSingleFile()
271+
}
272+
}
273+
}
274+
}))
275+
plugin(provider(new Callable<RegularFile>(){
276+
@Override
277+
RegularFile call() throws Exception {
278+
return new RegularFile() {
279+
@Override
280+
File getAsFile() {
281+
return configurations.zipArchive.asFileTree.matching{include "**/neural-search-${opensearch_build}.zip"}.getSingleFile()
282+
}
283+
}
284+
}
285+
}))
261286

262287
nodes.each { node ->
263288
def plugins = node.plugins
@@ -430,6 +455,8 @@ configurations.all {
430455
resolutionStrategy.force "org.bouncycastle:bcprov-jdk18on:1.78.1"
431456
resolutionStrategy.force 'io.projectreactor:reactor-core:3.7.0'
432457
resolutionStrategy.force 'commons-beanutils:commons-beanutils:1.11.0'
458+
resolutionStrategy.force 'com.google.code.gson:gson:2.13.2'
459+
resolutionStrategy.force "com.google.errorprone:error_prone_annotations:${versions.error_prone_annotations}"
433460
}
434461

435462
apply plugin: 'com.netflix.nebula.ospackage'

plugin/src/test/java/org/opensearch/ml/rest/RestQueryPlanningToolIT.java

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
package org.opensearch.ml.rest;
77

88
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED;
9+
import static org.opensearch.ml.engine.tools.QueryPlanningTool.GENERATION_TYPE_FIELD;
910
import static org.opensearch.ml.engine.tools.QueryPlanningTool.MODEL_ID_FIELD;
11+
import static org.opensearch.ml.engine.tools.QueryPlanningTool.SEARCH_TEMPLATES_FIELD;
12+
import static org.opensearch.ml.engine.tools.QueryPlanningTool.USER_SEARCH_TEMPLATES_TYPE_FIELD;
1013

1114
import java.io.IOException;
1215
import java.util.List;
@@ -95,6 +98,50 @@ public void testAgentWithQueryPlanningTool_DefaultPrompt() throws IOException {
9598
deleteAgent(agentId);
9699
}
97100

101+
@Test
102+
public void testAgentWithQueryPlanningTool_SearchTemplates() throws IOException {
103+
if (OPENAI_KEY == null) {
104+
return;
105+
}
106+
107+
// Create Search Templates
108+
String templateBody = "{\"script\":{\"lang\":\"mustache\",\"source\":{\"query\":{\"match\":{\"type\":\"{{type}}\"}}}}}";
109+
Response response = createSearchTemplate("type_search_template", templateBody);
110+
templateBody = "{\"script\":{\"lang\":\"mustache\",\"source\":{\"query\":{\"term\":{\"type\":\"{{type}}\"}}}}}";
111+
response = createSearchTemplate("type_search_template_2", templateBody);
112+
113+
// Register agent with search template IDs
114+
String agentName = "Test_AgentWithQueryPlanningTool_SearchTemplates";
115+
String searchTemplates = "[{"
116+
+ "\"template_id\":\"type_search_template\","
117+
+ "\"template_description\":\"this templates searches for flowers that match the given type this uses a match query\""
118+
+ "},{"
119+
+ "\"template_id\":\"type_search_template_2\","
120+
+ "\"template_description\":\"this templates searches for flowers that match the given type this uses a term query\""
121+
+ "},{"
122+
+ "\"template_id\":\"brand_search_template\","
123+
+ "\"template_description\":\"this templates searches for products that match the given brand\""
124+
+ "}]";
125+
String agentId = registerQueryPlanningAgentWithSearchTemplates(agentName, queryPlanningModelId, searchTemplates);
126+
assertNotNull(agentId);
127+
128+
String query = "{\"parameters\": {\"query_text\": \"List 5 iris flowers of type setosa\"}}";
129+
Response agentResponse = executeAgent(agentId, query);
130+
String responseBody = TestHelper.httpEntityToString(agentResponse.getEntity());
131+
132+
Map<String, Object> responseMap = gson.fromJson(responseBody, Map.class);
133+
134+
List<Map<String, Object>> inferenceResults = (List<Map<String, Object>>) responseMap.get("inference_results");
135+
Map<String, Object> firstResult = inferenceResults.get(0);
136+
List<Map<String, Object>> outputArray = (List<Map<String, Object>>) firstResult.get("output");
137+
Map<String, Object> output = (Map<String, Object>) outputArray.get(0);
138+
String result = output.get("result").toString();
139+
140+
assertTrue(result.contains("query"));
141+
assertTrue(result.contains("term"));
142+
deleteAgent(agentId);
143+
}
144+
98145
private String registerAgentWithQueryPlanningTool(String agentName, String modelId) throws IOException {
99146
MLToolSpec listIndexTool = MLToolSpec
100147
.builder()
@@ -125,6 +172,44 @@ private String registerAgentWithQueryPlanningTool(String agentName, String model
125172
return registerAgent(agentName, agent);
126173
}
127174

175+
private String registerQueryPlanningAgentWithSearchTemplates(String agentName, String modelId, String searchTemplates)
176+
throws IOException {
177+
MLToolSpec listIndexTool = MLToolSpec
178+
.builder()
179+
.type("ListIndexTool")
180+
.name("MyListIndexTool")
181+
.description("A tool for list indices")
182+
.parameters(Map.of("index", IRIS_INDEX, "question", "what fields are in the index?"))
183+
.includeOutputInAgentResponse(true)
184+
.build();
185+
186+
MLToolSpec queryPlanningTool = MLToolSpec
187+
.builder()
188+
.type("QueryPlanningTool")
189+
.name("MyQueryPlanningTool")
190+
.description("A tool for planning queries")
191+
.parameters(
192+
Map
193+
.ofEntries(
194+
Map.entry(MODEL_ID_FIELD, modelId),
195+
Map.entry(GENERATION_TYPE_FIELD, USER_SEARCH_TEMPLATES_TYPE_FIELD),
196+
Map.entry(SEARCH_TEMPLATES_FIELD, searchTemplates)
197+
)
198+
)
199+
.includeOutputInAgentResponse(true)
200+
.build();
201+
202+
MLAgent agent = MLAgent
203+
.builder()
204+
.name(agentName)
205+
.type("flow")
206+
.description("Test agent with QueryPlanningTool")
207+
.tools(List.of(listIndexTool, queryPlanningTool))
208+
.build();
209+
210+
return registerAgent(agentName, agent);
211+
}
212+
128213
private String registerQueryPlanningModel() throws IOException, InterruptedException {
129214
String openaiModelName = "openai gpt-4o model " + randomAlphaOfLength(5);
130215
return registerRemoteModel(openaiConnectorEntity, openaiModelName, true);
@@ -177,6 +262,18 @@ private Response executeAgent(String agentId, String query) throws IOException {
177262
);
178263
}
179264

265+
private Response createSearchTemplate(String templateName, String templateBody) throws IOException {
266+
return TestHelper
267+
.makeRequest(
268+
client(),
269+
"PUT",
270+
"/_scripts/" + templateName,
271+
null,
272+
new StringEntity(templateBody),
273+
List.of(new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json"))
274+
);
275+
}
276+
180277
private void deleteAgent(String agentId) throws IOException {
181278
TestHelper.makeRequest(client(), "DELETE", "/_plugins/_ml/agents/" + agentId, null, "", List.of());
182279
}

search-processors/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ repositories {
2929
dependencies {
3030
implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
3131
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
32-
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.11.0'
32+
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.13.2'
3333
implementation "org.apache.commons:commons-lang3:${versions.commonslang}"
3434
implementation project(':opensearch-ml-memory')
3535
implementation group: 'org.opensearch', name: 'common-utils', version: "${common_utils_version}"

0 commit comments

Comments
 (0)