Skip to content

Commit e5adbc7

Browse files
mingshlylwu-amzn
authored andcommitted
add code coverage and address comments
Signed-off-by: Mingshi Liu <mingshl@amazon.com>
1 parent 301ee14 commit e5adbc7

File tree

3 files changed

+122
-55
lines changed

3 files changed

+122
-55
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java

Lines changed: 10 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,21 @@
1010

1111
import org.apache.commons.text.StringSubstitutor;
1212
import org.opensearch.core.action.ActionListener;
13-
import org.opensearch.ml.common.spi.tools.Parser;
1413
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
1514
import org.opensearch.ml.common.spi.tools.WithModelTool;
1615
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
1716
import org.opensearch.transport.client.Client;
1817

1918
import lombok.Getter;
2019
import lombok.Setter;
21-
import lombok.extern.log4j.Log4j2;
2220

2321
/**
2422
* This tool supports different types of query planning,
2523
* llmGenerated, systemSearchTemplates or userSearchTemplates.
2624
* //TODO only support llmGenerated for now.
2725
* //TODO to add in systemSearchTemplates or userSearchTemplates when searchTemplatesTool is implemented.
2826
*/
29-
@Log4j2
27+
3028
@ToolAnnotation(QueryPlanningTool.TYPE)
3129
public class QueryPlanningTool implements WithModelTool {
3230
public static final String TYPE = "QueryPlanningTool";
@@ -35,6 +33,7 @@ public class QueryPlanningTool implements WithModelTool {
3533
public static final String PROMPT_FIELD = "prompt";
3634
private static final String GENERATION_TYPE_FIELD = "generation_type";
3735
private static final String LLM_GENERATED_TYPE_FIELD = "llmGenerated";
36+
@Getter
3837
private final String generationType;
3938
@Setter
4039
@Getter
@@ -51,38 +50,26 @@ public class QueryPlanningTool implements WithModelTool {
5150
"{ \"query\": { \"multi_match\" : { \"query\": \"${parameters.query_text}\", \"fields\": ${parameters.query_fields:-[\"*\"]} } } }";
5251
private String defaultPrompt =
5352
"You are an OpenSearch Query DSL generation assistant; try using the optional provided index mapping ${parameters.index_mapping:-}, specified fields ${parameters.query_fields:-}, and the given sample queries as examples, generate an OpenSearch Query DSL to retrieve the most relevant documents for the user provided natural language question: ${parameters.query_text}, please return the query dsl only in a string format, no other texts.\n";
54-
@Getter
55-
private Client client;
56-
@Getter
57-
private String modelId;
58-
@Setter
59-
@Getter
60-
@VisibleForTesting
61-
private Parser outputParser;
62-
@Setter
63-
@Getter
64-
private String responseField;
6553

66-
public QueryPlanningTool(Client client, String modelId, String generationType, MLModelTool queryGenerationTool) {
67-
this.client = client;
68-
this.modelId = modelId;
54+
public QueryPlanningTool(String generationType, MLModelTool queryGenerationTool) {
6955
this.generationType = generationType;
7056
this.queryGenerationTool = queryGenerationTool;
7157
}
7258

7359
@Override
7460
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
75-
if (!parameters.containsKey(PROMPT_FIELD)) {
76-
parameters.put(PROMPT_FIELD, defaultPrompt);
77-
}
61+
7862
if (!validate(parameters)) {
7963
listener.onFailure(new IllegalArgumentException("Empty parameters for QueryPlanningTool: " + parameters));
8064
return;
8165
}
66+
if (!parameters.containsKey(PROMPT_FIELD)) {
67+
parameters.put(PROMPT_FIELD, defaultPrompt);
68+
}
8269
ActionListener<T> modelListener = ActionListener.wrap(r -> {
8370
try {
8471
String queryString = (String) r;
85-
if (queryString == null || queryString.isBlank() || queryString.isEmpty() || queryString.equals("null")) {
72+
if (queryString == null || queryString.isBlank() || queryString.equals("null")) {
8673
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
8774
String defaultQueryString = substitutor.replace(this.defaultQuery);
8875
listener.onResponse((T) defaultQueryString);
@@ -110,16 +97,6 @@ public String getVersion() {
11097
return null;
11198
}
11299

113-
@Override
114-
public String getName() {
115-
return this.name;
116-
}
117-
118-
@Override
119-
public void setName(String s) {
120-
this.name = s;
121-
}
122-
123100
@Override
124101
public boolean validate(Map<String, String> parameters) {
125102
if (parameters == null || parameters.size() == 0) {
@@ -130,8 +107,7 @@ public boolean validate(Map<String, String> parameters) {
130107

131108
public static class Factory implements WithModelTool.Factory<QueryPlanningTool> {
132109
private Client client;
133-
134-
private static Factory INSTANCE;
110+
private static volatile Factory INSTANCE;
135111

136112
public static Factory getInstance() {
137113
if (INSTANCE != null) {
@@ -165,7 +141,7 @@ public QueryPlanningTool create(Map<String, Object> map) {
165141
if (!LLM_GENERATED_TYPE_FIELD.equals(type)) {
166142
throw new IllegalArgumentException("Invalid generation type: " + type + ". The current supported types are llmGenerated.");
167143
}
168-
return new QueryPlanningTool(client, (String) map.get(MODEL_ID_FIELD), type, queryGenerationTool);
144+
return new QueryPlanningTool(type, queryGenerationTool);
169145
}
170146

171147
@Override

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/QueryPlanningToolTests.java

Lines changed: 106 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
import static org.junit.Assert.assertFalse;
1010
import static org.junit.Assert.assertNotNull;
1111
import static org.junit.Assert.assertNull;
12+
import static org.junit.Assert.assertThrows;
1213
import static org.junit.Assert.assertTrue;
1314
import static org.mockito.ArgumentMatchers.any;
1415
import static org.mockito.Mockito.doAnswer;
16+
import static org.mockito.Mockito.mock;
1517
import static org.opensearch.ml.engine.tools.QueryPlanningTool.DEFAULT_DESCRIPTION;
1618
import static org.opensearch.ml.engine.tools.QueryPlanningTool.MODEL_ID_FIELD;
1719

@@ -26,6 +28,7 @@
2628
import org.junit.Rule;
2729
import org.junit.Test;
2830
import org.junit.rules.ExpectedException;
31+
import org.mockito.ArgumentCaptor;
2932
import org.mockito.Mock;
3033
import org.mockito.MockitoAnnotations;
3134
import org.opensearch.core.action.ActionListener;
@@ -46,11 +49,13 @@ public class QueryPlanningToolTests {
4649
private Map<String, String> validParams;
4750
private Map<String, String> emptyParams;
4851

52+
private QueryPlanningTool.Factory factory;
53+
4954
@Before
5055
public void setup() {
5156
MockitoAnnotations.openMocks(this);
5257
MLModelTool.Factory.getInstance().init(client);
53-
QueryPlanningTool.Factory.getInstance().init(client);
58+
factory = new QueryPlanningTool.Factory();
5459
validParams = new HashMap<>();
5560
validParams.put("prompt", "test prompt");
5661
emptyParams = Collections.emptyMap();
@@ -73,7 +78,7 @@ public void testRun() throws ExecutionException, InterruptedException {
7378
return null;
7479
}).when(queryGenerationTool).run(any(), any());
7580

76-
QueryPlanningTool tool = new QueryPlanningTool(client, "test_model_id", "llmGenerated", queryGenerationTool);
81+
QueryPlanningTool tool = new QueryPlanningTool("llmGenerated", queryGenerationTool);
7782
final CompletableFuture<String> future = new CompletableFuture<>();
7883
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
7984
// test try to update the prompt
@@ -97,7 +102,7 @@ public void testRun_PredictionReturnsList_ThrowsIllegalArgumentException() throw
97102
return null;
98103
}).when(queryGenerationTool).run(any(), any());
99104

100-
QueryPlanningTool tool = new QueryPlanningTool(client, "test_model_id", "llmGenerated", queryGenerationTool);
105+
QueryPlanningTool tool = new QueryPlanningTool("llmGenerated", queryGenerationTool);
101106
final CompletableFuture<String> future = new CompletableFuture<>();
102107
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
103108
validParams.put("query_text", "help me find some books related to wind");
@@ -114,7 +119,7 @@ public void testRun_PredictionReturnsNull_ReturnDefaultQuery() throws ExecutionE
114119
return null;
115120
}).when(queryGenerationTool).run(any(), any());
116121

117-
QueryPlanningTool tool = new QueryPlanningTool(client, "test_model_id", "llmGenerated", queryGenerationTool);
122+
QueryPlanningTool tool = new QueryPlanningTool("llmGenerated", queryGenerationTool);
118123
final CompletableFuture<String> future = new CompletableFuture<>();
119124
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
120125
validParams.put("query_text", "help me find some books related to wind");
@@ -132,7 +137,25 @@ public void testRun_PredictionReturnsEmpty_ReturnDefaultQuery() throws Execution
132137
return null;
133138
}).when(queryGenerationTool).run(any(), any());
134139

135-
QueryPlanningTool tool = new QueryPlanningTool(client, "test_model_id", "llmGenerated", queryGenerationTool);
140+
QueryPlanningTool tool = new QueryPlanningTool("llmGenerated", queryGenerationTool);
141+
final CompletableFuture<String> future = new CompletableFuture<>();
142+
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
143+
validParams.put("query_text", "help me find some books related to wind");
144+
tool.run(validParams, listener);
145+
String multiMatchQueryString =
146+
"{ \"query\": { \"multi_match\" : { \"query\": \"help me find some books related to wind\", \"fields\": [\"*\"] } } }";
147+
assertEquals(multiMatchQueryString, future.get());
148+
}
149+
150+
@Test
151+
public void testRun_PredictionReturnsNullString_ReturnDefaultQuery() throws ExecutionException, InterruptedException {
152+
doAnswer(invocation -> {
153+
ActionListener<String> listener = invocation.getArgument(1);
154+
listener.onResponse("null");
155+
return null;
156+
}).when(queryGenerationTool).run(any(), any());
157+
158+
QueryPlanningTool tool = new QueryPlanningTool("llmGenerated", queryGenerationTool);
136159
final CompletableFuture<String> future = new CompletableFuture<>();
137160
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
138161
validParams.put("query_text", "help me find some books related to wind");
@@ -168,4 +191,82 @@ public void testFactoryGetAllModelKeys() {
168191
@Rule
169192
public ExpectedException thrown = ExpectedException.none();
170193

194+
@Test
195+
public void testRunWithNoPrompt() {
196+
QueryPlanningTool tool = new QueryPlanningTool("llmGenerated", queryGenerationTool);
197+
Map<String, String> parameters = new HashMap<>();
198+
parameters.put("query_text", "some query");
199+
@SuppressWarnings("unchecked")
200+
ActionListener<String> listener = mock(ActionListener.class);
201+
202+
tool.run(parameters, listener);
203+
204+
ArgumentCaptor<Map<String, String>> captor = ArgumentCaptor.forClass(Map.class);
205+
doAnswer(invocation -> {
206+
Map<String, String> params = invocation.getArgument(0);
207+
assertNotNull(params.get("prompt"));
208+
return null;
209+
}).when(queryGenerationTool).run(captor.capture(), any());
210+
}
211+
212+
@Test
213+
public void testRunWithInvalidParameters() {
214+
QueryPlanningTool tool = new QueryPlanningTool("llmGenerated", queryGenerationTool);
215+
@SuppressWarnings("unchecked")
216+
ActionListener<String> listener = mock(ActionListener.class);
217+
218+
tool.run(Collections.emptyMap(), listener);
219+
220+
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(Exception.class);
221+
org.mockito.Mockito.verify(listener).onFailure(captor.capture());
222+
assertEquals("Empty parameters for QueryPlanningTool: {}", captor.getValue().getMessage());
223+
}
224+
225+
@Test
226+
public void testRunModelReturnsNull() {
227+
QueryPlanningTool tool = new QueryPlanningTool("llmGenerated", queryGenerationTool);
228+
Map<String, String> parameters = new HashMap<>();
229+
parameters.put("query_text", "some query");
230+
@SuppressWarnings("unchecked")
231+
ActionListener<String> listener = mock(ActionListener.class);
232+
233+
doAnswer(invocation -> {
234+
ActionListener<String> modelListener = invocation.getArgument(1);
235+
modelListener.onResponse(null);
236+
return null;
237+
}).when(queryGenerationTool).run(any(), any());
238+
239+
tool.run(parameters, listener);
240+
241+
ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);
242+
org.mockito.Mockito.verify(listener).onResponse(captor.capture());
243+
assertNotNull(captor.getValue());
244+
}
245+
246+
@Test
247+
public void testSetName() {
248+
QueryPlanningTool tool = new QueryPlanningTool("llmGenerated", queryGenerationTool);
249+
tool.setName("NewName");
250+
assertEquals("NewName", tool.getName());
251+
}
252+
253+
@Test
254+
public void testFactoryCreateWithEmptyType() {
255+
Map<String, Object> map = new HashMap<>();
256+
map.put(QueryPlanningTool.MODEL_ID_FIELD, "modelId");
257+
Tool tool = factory.create(map);
258+
assertEquals(QueryPlanningTool.TYPE, tool.getName());
259+
assertEquals("llmGenerated", ((QueryPlanningTool) tool).getGenerationType());
260+
assertNotNull(tool);
261+
}
262+
263+
@Test
264+
public void testFactoryCreateWithInvalidType() {
265+
Map<String, Object> map = new HashMap<>();
266+
map.put("generation_type", "invalid");
267+
map.put(QueryPlanningTool.MODEL_ID_FIELD, "modelId");
268+
269+
Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(map));
270+
assertEquals("Invalid generation type: invalid. The current supported types are llmGenerated.", exception.getMessage());
271+
}
171272
}

plugin/src/test/java/org/opensearch/ml/rest/RestQueryPlanningToolIT.java

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ public class RestQueryPlanningToolIT extends MLCommonsRestTestCase {
3030
private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY");
3131
private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN");
3232
private static final String GITHUB_CI_AWS_REGION = "us-west-2";
33-
3433
private final String bedrockClaudeModelConnectorEntity = "{\n"
3534
+ " \"name\": \"Amazon Bedrock Claude 3.7-sonnet connector\",\n"
3635
+ " \"description\": \"connector for base agent with tools\",\n"
@@ -71,6 +70,9 @@ public class RestQueryPlanningToolIT extends MLCommonsRestTestCase {
7170
@Before
7271
public void setup() throws IOException, InterruptedException {
7372
ingestIrisIndexData();
73+
if (AWS_ACCESS_KEY_ID == null) {
74+
return;
75+
}
7476
queryPlanningModelId = registerQueryPlanningModel();
7577
}
7678

@@ -81,6 +83,9 @@ public void deleteIndices() throws IOException {
8183

8284
@Test
8385
public void testAgentWithQueryPlanningTool_DefaultPrompt() throws IOException {
86+
if (AWS_ACCESS_KEY_ID == null) {
87+
return;
88+
}
8489
String agentName = "Test_QueryPlanningAgent_DefaultPrompt";
8590
String agentId = registerAgentWithQueryPlanningTool(agentName, queryPlanningModelId);
8691
assertNotNull(agentId);
@@ -186,19 +191,4 @@ private Response executeAgent(String agentId, String query) throws IOException {
186191
private void deleteAgent(String agentId) throws IOException {
187192
TestHelper.makeRequest(client(), "DELETE", "/_plugins/_ml/agents/" + agentId, null, "", List.of());
188193
}
189-
190-
public String registerModel(String modelContent) throws IOException {
191-
Response response = TestHelper
192-
.makeRequest(
193-
client(),
194-
"POST",
195-
"/_plugins/_ml/models/_register",
196-
null,
197-
new StringEntity(modelContent),
198-
List.of(new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json"))
199-
);
200-
Map<String, String> responseMap = gson.fromJson(TestHelper.httpEntityToString(response.getEntity()), Map.class);
201-
return responseMap.get("task_id");
202-
}
203-
204194
}

0 commit comments

Comments
 (0)