Skip to content

Commit 5aea15f

Browse files
mingshlrithin-pullela-aws
authored andcommitted
add java-time and default query
Signed-off-by: Mingshi Liu <mingshl@amazon.com>
1 parent 00980ab commit 5aea15f

File tree

4 files changed

+296
-22
lines changed

4 files changed

+296
-22
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
import java.util.List;
99
import java.util.Map;
1010

11+
import org.apache.commons.text.StringSubstitutor;
1112
import org.opensearch.core.action.ActionListener;
12-
import org.opensearch.ml.common.output.model.ModelTensors;
1313
import org.opensearch.ml.common.spi.tools.Parser;
1414
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
1515
import org.opensearch.ml.common.spi.tools.WithModelTool;
@@ -47,8 +47,10 @@ public class QueryPlanningTool implements WithModelTool {
4747
@Getter
4848
@Setter
4949
private String description = DEFAULT_DESCRIPTION;
50+
private String defaultQuery =
51+
"{ \"query\": { \"multi_match\" : { \"query\": \"${parameters.query_text}\", \"fields\": ${parameters.query_fields:-[\"*\"]} } } }";
5052
private String defaultPrompt =
51-
"You are an OpenSearch Query DSL generation assistant; using the provided index mapping ${parameters.ListIndexTool.output:-}, specified fields ${parameters.fields:-}, and the given sample queries as examples, generate an OpenSearch Query DSL to retrieve the most relevant documents for the user provided natural language question: ${parameters.query_text}\n";
53+
"You are an OpenSearch Query DSL generation assistant; try using the optional provided index mapping ${parameters.index_mapping:-}, specified fields ${parameters.query_fields:-}, and the given sample queries as examples, generate an OpenSearch Query DSL to retrieve the most relevant documents for the user provided natural language question: ${parameters.query_text}, please return the query dsl only in a string format, no other texts.\n";
5254
@Getter
5355
private Client client;
5456
@Getter
@@ -73,13 +75,26 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
7375
if (!parameters.containsKey(PROMPT_FIELD)) {
7476
parameters.put(PROMPT_FIELD, defaultPrompt);
7577
}
76-
ActionListener<List<ModelTensors>> modelListener = ActionListener.wrap(r -> {
78+
if (!validate(parameters)) {
79+
listener.onFailure(new IllegalArgumentException("Empty parameters for QueryPlanningTool: " + parameters));
80+
return;
81+
}
82+
ActionListener<T> modelListener = ActionListener.wrap(r -> {
7783
try {
78-
@SuppressWarnings("unchecked")
79-
T result = (T) outputParser.parse(r);
80-
listener.onResponse(result);
84+
String queryString = (String) r;
85+
if (queryString == null || queryString.isBlank() || queryString.isEmpty() || queryString.equals("null")) {
86+
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
87+
String defaultQueryString = substitutor.replace(this.defaultQuery);
88+
listener.onResponse((T) defaultQueryString);
89+
} else {
90+
listener.onResponse((T) queryString);
91+
}
8192
} catch (Exception e) {
82-
listener.onFailure(e);
93+
IllegalArgumentException parsingException = new IllegalArgumentException(
94+
"Error processing query string: " + r + ". Try using response_filter in agent registration if needed.",
95+
e
96+
);
97+
listener.onFailure(parsingException);
8398
}
8499
}, listener::onFailure);
85100
queryGenerationTool.run(parameters, modelListener);
@@ -145,7 +160,7 @@ public QueryPlanningTool create(Map<String, Object> map) {
145160
type = LLM_GENERATED_TYPE_FIELD;
146161
}
147162

148-
// TODO to add in , SYSTEM_SEARCH_TEMPLATES_TYPE_FIELD, USER_SEARCH_TEMPLATES_TYPE_FIELD when searchTemplatesTool is
163+
// TODO to add in SYSTEM_SEARCH_TEMPLATES_TYPE_FIELD, USER_SEARCH_TEMPLATES_TYPE_FIELD when searchTemplatesTool is
149164
// implemented.
150165
if (!LLM_GENERATED_TYPE_FIELD.equals(type)) {
151166
throw new IllegalArgumentException("Invalid generation type: " + type + ". The current supported types are llmGenerated.");

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/QueryPlanningToolTests.java

Lines changed: 68 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,18 @@
2323
import java.util.concurrent.ExecutionException;
2424

2525
import org.junit.Before;
26+
import org.junit.Rule;
2627
import org.junit.Test;
28+
import org.junit.rules.ExpectedException;
2729
import org.mockito.Mock;
2830
import org.mockito.MockitoAnnotations;
2931
import org.opensearch.core.action.ActionListener;
30-
import org.opensearch.ml.common.output.model.ModelTensor;
31-
import org.opensearch.ml.common.output.model.ModelTensors;
3232
import org.opensearch.ml.common.spi.tools.Tool;
33-
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap;
3433
import org.opensearch.transport.client.Client;
3534

35+
/**
36+
* Units test for QueryPlanningTools
37+
*/
3638
public class QueryPlanningToolTests {
3739

3840
@Mock
@@ -65,22 +67,13 @@ public void testFactoryCreate() {
6567
@Test
6668
public void testRun() throws ExecutionException, InterruptedException {
6769
String matchQueryString = "{\"query\":{\"match\":{\"title\":\"wind\"}}}";
68-
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", matchQueryString)).build();
69-
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(List.of(modelTensor)).build();
70-
List<ModelTensors> modelTensorsList = List.of(modelTensors);
71-
7270
doAnswer(invocation -> {
73-
ActionListener<List<ModelTensors>> listener = invocation.getArgument(1);
74-
listener.onResponse(modelTensorsList);
71+
ActionListener<String> listener = invocation.getArgument(1);
72+
listener.onResponse(matchQueryString);
7573
return null;
7674
}).when(queryGenerationTool).run(any(), any());
7775

7876
QueryPlanningTool tool = new QueryPlanningTool(client, "test_model_id", "llmGenerated", queryGenerationTool);
79-
tool.setOutputParser(o -> {
80-
List<ModelTensors> outputs = (List<ModelTensors>) o;
81-
return outputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
82-
});
83-
8477
final CompletableFuture<String> future = new CompletableFuture<>();
8578
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
8679
// test try to update the prompt
@@ -92,6 +85,63 @@ public void testRun() throws ExecutionException, InterruptedException {
9285
assertEquals(matchQueryString, future.get());
9386
}
9487

88+
@Test
89+
public void testRun_PredictionReturnsList_ThrowsIllegalArgumentException() throws ExecutionException, InterruptedException {
90+
thrown.expect(ExecutionException.class);
91+
thrown.expectCause(org.hamcrest.Matchers.isA(IllegalArgumentException.class));
92+
thrown.expectMessage("Error processing query string: [invalid_query]. Try using response_filter in agent registration if needed.");
93+
94+
doAnswer(invocation -> {
95+
ActionListener<List<String>> listener = invocation.getArgument(1);
96+
listener.onResponse(List.of("invalid_query"));
97+
return null;
98+
}).when(queryGenerationTool).run(any(), any());
99+
100+
QueryPlanningTool tool = new QueryPlanningTool(client, "test_model_id", "llmGenerated", queryGenerationTool);
101+
final CompletableFuture<String> future = new CompletableFuture<>();
102+
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
103+
validParams.put("query_text", "help me find some books related to wind");
104+
tool.run(validParams, listener);
105+
106+
future.get();
107+
}
108+
109+
@Test
110+
public void testRun_PredictionReturnsNull_ReturnDefaultQuery() throws ExecutionException, InterruptedException {
111+
doAnswer(invocation -> {
112+
ActionListener<String> listener = invocation.getArgument(1);
113+
listener.onResponse(null);
114+
return null;
115+
}).when(queryGenerationTool).run(any(), any());
116+
117+
QueryPlanningTool tool = new QueryPlanningTool(client, "test_model_id", "llmGenerated", queryGenerationTool);
118+
final CompletableFuture<String> future = new CompletableFuture<>();
119+
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
120+
validParams.put("query_text", "help me find some books related to wind");
121+
tool.run(validParams, listener);
122+
String multiMatchQueryString =
123+
"{ \"query\": { \"multi_match\" : { \"query\": \"help me find some books related to wind\", \"fields\": [\"*\"] } } }";
124+
assertEquals(multiMatchQueryString, future.get());
125+
}
126+
127+
@Test
128+
public void testRun_PredictionReturnsEmpty_ReturnDefaultQuery() throws ExecutionException, InterruptedException {
129+
doAnswer(invocation -> {
130+
ActionListener<String> listener = invocation.getArgument(1);
131+
listener.onResponse("");
132+
return null;
133+
}).when(queryGenerationTool).run(any(), any());
134+
135+
QueryPlanningTool tool = new QueryPlanningTool(client, "test_model_id", "llmGenerated", queryGenerationTool);
136+
final CompletableFuture<String> future = new CompletableFuture<>();
137+
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
138+
validParams.put("query_text", "help me find some books related to wind");
139+
tool.run(validParams, listener);
140+
String multiMatchQueryString =
141+
"{ \"query\": { \"multi_match\" : { \"query\": \"help me find some books related to wind\", \"fields\": [\"*\"] } } }";
142+
assertEquals(multiMatchQueryString, future.get());
143+
}
144+
95145
@Test
96146
public void testValidate() {
97147
Tool tool = QueryPlanningTool.Factory.getInstance().create(Collections.emptyMap());
@@ -114,4 +164,8 @@ public void testFactoryGetAllModelKeys() {
114164
List<String> allModelKeys = QueryPlanningTool.Factory.getInstance().getAllModelKeys();
115165
assertEquals(List.of(MODEL_ID_FIELD), allModelKeys);
116166
}
167+
168+
@Rule
169+
public ExpectedException thrown = ExpectedException.none();
170+
117171
}

plugin/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ integTest {
226226
if (System.getProperty("test.debug") != null) {
227227
jvmArgs '-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=*:5005'
228228
}
229+
jvmArgs '--add-opens', 'java.base/java.time=ALL-UNNAMED'
229230

230231
// Set this to true this if you want to see the logs in the terminal test output.
231232
// note: if left false the log output will still show in your IDE
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.rest;
7+
8+
import static org.opensearch.ml.engine.tools.QueryPlanningTool.MODEL_ID_FIELD;
9+
10+
import java.io.IOException;
11+
import java.util.List;
12+
import java.util.Map;
13+
14+
import org.apache.hc.core5.http.HttpHeaders;
15+
import org.apache.hc.core5.http.io.entity.StringEntity;
16+
import org.apache.hc.core5.http.message.BasicHeader;
17+
import org.junit.After;
18+
import org.junit.Before;
19+
import org.junit.Test;
20+
import org.opensearch.client.Response;
21+
import org.opensearch.ml.common.agent.MLAgent;
22+
import org.opensearch.ml.common.agent.MLToolSpec;
23+
import org.opensearch.ml.utils.TestHelper;
24+
25+
public class RestQueryPlanningToolIT extends MLCommonsRestTestCase {
26+
27+
private static final String IRIS_INDEX = "iris_data";
28+
private String queryPlanningModelId;
29+
private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID");
30+
private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY");
31+
private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN");
32+
private static final String GITHUB_CI_AWS_REGION = "us-west-2";
33+
34+
private final String bedrockClaudeModelConnectorEntity = "{\n"
35+
+ " \"name\": \"Amazon Bedrock Claude 3.7-sonnet connector\",\n"
36+
+ " \"description\": \"connector for base agent with tools\",\n"
37+
+ " \"version\": 1,\n"
38+
+ " \"protocol\": \"aws_sigv4\",\n"
39+
+ " \"parameters\": {\n"
40+
+ " \"region\": \""
41+
+ GITHUB_CI_AWS_REGION
42+
+ "\",\n"
43+
+ " \"service_name\": \"bedrock\",\n"
44+
+ " \"model\": \"us.anthropic.claude-3-7-sonnet-20250219-v1:0\",\n"
45+
+ " \"system_prompt\":\"please help answer the user question. \"\n"
46+
+ " },\n"
47+
+ " \"credential\": {\n"
48+
+ " \"access_key\":\" "
49+
+ AWS_ACCESS_KEY_ID
50+
+ "\",\n"
51+
+ " \"secret_key\": \""
52+
+ AWS_SECRET_ACCESS_KEY
53+
+ "\",\n"
54+
+ " \"session_token\": \""
55+
+ AWS_SESSION_TOKEN
56+
+ "\"\n"
57+
+ " },\n"
58+
+ " \"actions\": [\n"
59+
+ " {\n"
60+
+ " \"action_type\": \"predict\",\n"
61+
+ " \"method\": \"POST\",\n"
62+
+ " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/converse\",\n"
63+
+ " \"headers\": {\n"
64+
+ " \"content-type\": \"application/json\"\n"
65+
+ " },\n"
66+
+ " \"request_body\": \"{ \\\"system\\\": [{\\\"text\\\": \\\"${parameters.system_prompt}\\\"}], \\\"messages\\\": [{\\\"role\\\":\\\"user\\\",\\\"content\\\":[{\\\"text\\\":\\\"${parameters.prompt}\\\"}]}]}\"\n"
67+
+ " }\n"
68+
+ " ]\n"
69+
+ "}";
70+
71+
@Before
72+
public void setup() throws IOException, InterruptedException {
73+
ingestIrisIndexData();
74+
queryPlanningModelId = registerQueryPlanningModel();
75+
}
76+
77+
@After
78+
public void deleteIndices() throws IOException {
79+
deleteIndex(IRIS_INDEX);
80+
}
81+
82+
@Test
83+
public void testAgentWithQueryPlanningTool_DefaultPrompt() throws IOException {
84+
String agentName = "Test_QueryPlanningAgent_DefaultPrompt";
85+
String agentId = registerAgentWithQueryPlanningTool(agentName, queryPlanningModelId);
86+
assertNotNull(agentId);
87+
88+
String query = "{\"parameters\": {\"query_text\": \"How many iris flowers of type setosa are there?\"}}";
89+
Response response = executeAgent(agentId, query);
90+
String responseBody = TestHelper.httpEntityToString(response.getEntity());
91+
92+
Map<String, Object> responseMap = gson.fromJson(responseBody, Map.class);
93+
94+
List<Map<String, Object>> inferenceResults = (List<Map<String, Object>>) responseMap.get("inference_results");
95+
Map<String, Object> firstResult = inferenceResults.get(0);
96+
List<Map<String, Object>> outputArray = (List<Map<String, Object>>) firstResult.get("output");
97+
Map<String, Object> output = (Map<String, Object>) outputArray.get(0);
98+
String result = output.get("result").toString();
99+
100+
assertTrue(result.contains("query"));
101+
deleteAgent(agentId);
102+
}
103+
104+
private String registerAgentWithQueryPlanningTool(String agentName, String modelId) throws IOException {
105+
MLToolSpec listIndexTool = MLToolSpec
106+
.builder()
107+
.type("ListIndexTool")
108+
.name("MyListIndexTool")
109+
.description("A tool for list indices")
110+
.parameters(Map.of("index", IRIS_INDEX, "question", "what fields are in the index?"))
111+
.includeOutputInAgentResponse(true)
112+
.build();
113+
114+
MLToolSpec queryPlanningTool = MLToolSpec
115+
.builder()
116+
.type("QueryPlanningTool")
117+
.name("MyQueryPlanningTool")
118+
.description("A tool for planning queries")
119+
.parameters(Map.of(MODEL_ID_FIELD, modelId))
120+
.includeOutputInAgentResponse(true)
121+
.build();
122+
123+
MLAgent agent = MLAgent
124+
.builder()
125+
.name(agentName)
126+
.type("flow")
127+
.description("Test agent with QueryPlanningTool")
128+
.tools(List.of(listIndexTool, queryPlanningTool))
129+
.build();
130+
131+
return registerAgent(agentName, agent);
132+
}
133+
134+
private String registerQueryPlanningModel() throws IOException, InterruptedException {
135+
String bedrockClaudeModelName = "bedrock claude model " + randomAlphaOfLength(5);
136+
return registerRemoteModel(bedrockClaudeModelConnectorEntity, bedrockClaudeModelName, true);
137+
}
138+
139+
private void ingestIrisIndexData() throws IOException {
140+
String bulkRequestBody = "{\"index\":{\"_index\":\""
141+
+ IRIS_INDEX
142+
+ "\",\"_id\":\"1\"}}\n"
143+
+ "{\"petal_length_in_cm\":1.4,\"petal_width_in_cm\":0.2,\"sepal_length_in_cm\":5.1,\"sepal_width_in_cm\":3.5,\"species\":\"setosa\"}\n"
144+
+ "{\"index\":{\"_index\":\""
145+
+ IRIS_INDEX
146+
+ "\",\"_id\":\"2\"}}\n"
147+
+ "{\"petal_length_in_cm\":4.5,\"petal_width_in_cm\":1.5,\"sepal_length_in_cm\":6.4,\"sepal_width_in_cm\":2.9,\"species\":\"versicolor\"}\n";
148+
TestHelper
149+
.makeRequest(
150+
client(),
151+
"POST",
152+
"/_bulk",
153+
null,
154+
new StringEntity(bulkRequestBody),
155+
List.of(new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json"))
156+
);
157+
TestHelper.makeRequest(client(), "POST", "/" + IRIS_INDEX + "/_refresh", null, "", List.of());
158+
}
159+
160+
private String registerAgent(String agentName, MLAgent agent) throws IOException {
161+
Response response = TestHelper
162+
.makeRequest(
163+
client(),
164+
"POST",
165+
"/_plugins/_ml/agents/_register",
166+
null,
167+
new StringEntity(gson.toJson(agent)),
168+
List.of(new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json"))
169+
);
170+
Map<String, String> responseMap = gson.fromJson(TestHelper.httpEntityToString(response.getEntity()), Map.class);
171+
return responseMap.get("agent_id");
172+
}
173+
174+
private Response executeAgent(String agentId, String query) throws IOException {
175+
return TestHelper
176+
.makeRequest(
177+
client(),
178+
"POST",
179+
"/_plugins/_ml/agents/" + agentId + "/_execute",
180+
null,
181+
new StringEntity(query),
182+
List.of(new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json"))
183+
);
184+
}
185+
186+
private void deleteAgent(String agentId) throws IOException {
187+
TestHelper.makeRequest(client(), "DELETE", "/_plugins/_ml/agents/" + agentId, null, "", List.of());
188+
}
189+
190+
public String registerModel(String modelContent) throws IOException {
191+
Response response = TestHelper
192+
.makeRequest(
193+
client(),
194+
"POST",
195+
"/_plugins/_ml/models/_register",
196+
null,
197+
new StringEntity(modelContent),
198+
List.of(new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json"))
199+
);
200+
Map<String, String> responseMap = gson.fromJson(TestHelper.httpEntityToString(response.getEntity()), Map.class);
201+
return responseMap.get("task_id");
202+
}
203+
204+
}

0 commit comments

Comments
 (0)