Skip to content

Commit 00980ab

Browse files
mingshlrithin-pullela-aws
authored andcommitted
add query planning tool
Signed-off-by: Mingshi Liu <mingshl@amazon.com>
1 parent 7eb0d80 commit 00980ab

File tree

3 files changed

+296
-1
lines changed

3 files changed

+296
-1
lines changed
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.engine.tools;
7+
8+
import java.util.List;
9+
import java.util.Map;
10+
11+
import org.opensearch.core.action.ActionListener;
12+
import org.opensearch.ml.common.output.model.ModelTensors;
13+
import org.opensearch.ml.common.spi.tools.Parser;
14+
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
15+
import org.opensearch.ml.common.spi.tools.WithModelTool;
16+
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
17+
import org.opensearch.transport.client.Client;
18+
19+
import lombok.Getter;
20+
import lombok.Setter;
21+
import lombok.extern.log4j.Log4j2;
22+
23+
/**
24+
* This tool supports different types of query planning,
25+
* llmGenerated, systemSearchTemplates or userSearchTemplates.
26+
* //TODO only support llmGenerated for now.
27+
* //TODO to add in systemSearchTemplates or userSearchTemplates when searchTemplatesTool is implemented.
28+
*/
29+
@Log4j2
30+
@ToolAnnotation(QueryPlanningTool.TYPE)
31+
public class QueryPlanningTool implements WithModelTool {
32+
public static final String TYPE = "QueryPlanningTool";
33+
public static final String MODEL_ID_FIELD = "model_id";
34+
private final MLModelTool queryGenerationTool;
35+
public static final String PROMPT_FIELD = "prompt";
36+
private static final String GENERATION_TYPE_FIELD = "generation_type";
37+
private static final String LLM_GENERATED_TYPE_FIELD = "llmGenerated";
38+
private final String generationType;
39+
@Setter
40+
@Getter
41+
private String name = TYPE;
42+
@Getter
43+
@Setter
44+
private Map<String, Object> attributes;
45+
@VisibleForTesting
46+
static String DEFAULT_DESCRIPTION = "Use this tool to generate query plans for a given query text.";
47+
@Getter
48+
@Setter
49+
private String description = DEFAULT_DESCRIPTION;
50+
private String defaultPrompt =
51+
"You are an OpenSearch Query DSL generation assistant; using the provided index mapping ${parameters.ListIndexTool.output:-}, specified fields ${parameters.fields:-}, and the given sample queries as examples, generate an OpenSearch Query DSL to retrieve the most relevant documents for the user provided natural language question: ${parameters.query_text}\n";
52+
@Getter
53+
private Client client;
54+
@Getter
55+
private String modelId;
56+
@Setter
57+
@Getter
58+
@VisibleForTesting
59+
private Parser outputParser;
60+
@Setter
61+
@Getter
62+
private String responseField;
63+
64+
public QueryPlanningTool(Client client, String modelId, String generationType, MLModelTool queryGenerationTool) {
65+
this.client = client;
66+
this.modelId = modelId;
67+
this.generationType = generationType;
68+
this.queryGenerationTool = queryGenerationTool;
69+
}
70+
71+
@Override
72+
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
73+
if (!parameters.containsKey(PROMPT_FIELD)) {
74+
parameters.put(PROMPT_FIELD, defaultPrompt);
75+
}
76+
ActionListener<List<ModelTensors>> modelListener = ActionListener.wrap(r -> {
77+
try {
78+
@SuppressWarnings("unchecked")
79+
T result = (T) outputParser.parse(r);
80+
listener.onResponse(result);
81+
} catch (Exception e) {
82+
listener.onFailure(e);
83+
}
84+
}, listener::onFailure);
85+
queryGenerationTool.run(parameters, modelListener);
86+
}
87+
88+
@Override
89+
public String getType() {
90+
return TYPE;
91+
}
92+
93+
@Override
94+
public String getVersion() {
95+
return null;
96+
}
97+
98+
@Override
99+
public String getName() {
100+
return this.name;
101+
}
102+
103+
@Override
104+
public void setName(String s) {
105+
this.name = s;
106+
}
107+
108+
@Override
109+
public boolean validate(Map<String, String> parameters) {
110+
if (parameters == null || parameters.size() == 0) {
111+
return false;
112+
}
113+
return true;
114+
}
115+
116+
public static class Factory implements WithModelTool.Factory<QueryPlanningTool> {
117+
private Client client;
118+
119+
private static Factory INSTANCE;
120+
121+
public static Factory getInstance() {
122+
if (INSTANCE != null) {
123+
return INSTANCE;
124+
}
125+
synchronized (QueryPlanningTool.class) {
126+
if (INSTANCE != null) {
127+
return INSTANCE;
128+
}
129+
INSTANCE = new Factory();
130+
return INSTANCE;
131+
}
132+
}
133+
134+
public void init(Client client) {
135+
this.client = client;
136+
}
137+
138+
@Override
139+
public QueryPlanningTool create(Map<String, Object> map) {
140+
141+
MLModelTool queryGenerationTool = MLModelTool.Factory.getInstance().create(map);
142+
143+
String type = (String) map.get(GENERATION_TYPE_FIELD);
144+
if (type == null || type.isEmpty()) {
145+
type = LLM_GENERATED_TYPE_FIELD;
146+
}
147+
148+
// TODO to add in , SYSTEM_SEARCH_TEMPLATES_TYPE_FIELD, USER_SEARCH_TEMPLATES_TYPE_FIELD when searchTemplatesTool is
149+
// implemented.
150+
if (!LLM_GENERATED_TYPE_FIELD.equals(type)) {
151+
throw new IllegalArgumentException("Invalid generation type: " + type + ". The current supported types are llmGenerated.");
152+
}
153+
return new QueryPlanningTool(client, (String) map.get(MODEL_ID_FIELD), type, queryGenerationTool);
154+
}
155+
156+
@Override
157+
public String getDefaultDescription() {
158+
return DEFAULT_DESCRIPTION;
159+
}
160+
161+
@Override
162+
public String getDefaultType() {
163+
return TYPE;
164+
}
165+
166+
@Override
167+
public String getDefaultVersion() {
168+
return null;
169+
}
170+
171+
@Override
172+
public List<String> getAllModelKeys() {
173+
return List.of(MODEL_ID_FIELD);
174+
}
175+
}
176+
}
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.engine.tools;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.junit.Assert.assertFalse;
10+
import static org.junit.Assert.assertNotNull;
11+
import static org.junit.Assert.assertNull;
12+
import static org.junit.Assert.assertTrue;
13+
import static org.mockito.ArgumentMatchers.any;
14+
import static org.mockito.Mockito.doAnswer;
15+
import static org.opensearch.ml.engine.tools.QueryPlanningTool.DEFAULT_DESCRIPTION;
16+
import static org.opensearch.ml.engine.tools.QueryPlanningTool.MODEL_ID_FIELD;
17+
18+
import java.util.Collections;
19+
import java.util.HashMap;
20+
import java.util.List;
21+
import java.util.Map;
22+
import java.util.concurrent.CompletableFuture;
23+
import java.util.concurrent.ExecutionException;
24+
25+
import org.junit.Before;
26+
import org.junit.Test;
27+
import org.mockito.Mock;
28+
import org.mockito.MockitoAnnotations;
29+
import org.opensearch.core.action.ActionListener;
30+
import org.opensearch.ml.common.output.model.ModelTensor;
31+
import org.opensearch.ml.common.output.model.ModelTensors;
32+
import org.opensearch.ml.common.spi.tools.Tool;
33+
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap;
34+
import org.opensearch.transport.client.Client;
35+
36+
public class QueryPlanningToolTests {
37+
38+
@Mock
39+
private Client client;
40+
41+
@Mock
42+
private MLModelTool queryGenerationTool;
43+
44+
private Map<String, String> validParams;
45+
private Map<String, String> emptyParams;
46+
47+
@Before
48+
public void setup() {
49+
MockitoAnnotations.openMocks(this);
50+
MLModelTool.Factory.getInstance().init(client);
51+
QueryPlanningTool.Factory.getInstance().init(client);
52+
validParams = new HashMap<>();
53+
validParams.put("prompt", "test prompt");
54+
emptyParams = Collections.emptyMap();
55+
}
56+
57+
@Test
58+
public void testFactoryCreate() {
59+
Map<String, Object> map = Map.of(MODEL_ID_FIELD, "test_model_id");
60+
Tool tool = QueryPlanningTool.Factory.getInstance().create(map);
61+
assertNotNull(tool);
62+
assertEquals(QueryPlanningTool.TYPE, tool.getName());
63+
}
64+
65+
@Test
66+
public void testRun() throws ExecutionException, InterruptedException {
67+
String matchQueryString = "{\"query\":{\"match\":{\"title\":\"wind\"}}}";
68+
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", matchQueryString)).build();
69+
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(List.of(modelTensor)).build();
70+
List<ModelTensors> modelTensorsList = List.of(modelTensors);
71+
72+
doAnswer(invocation -> {
73+
ActionListener<List<ModelTensors>> listener = invocation.getArgument(1);
74+
listener.onResponse(modelTensorsList);
75+
return null;
76+
}).when(queryGenerationTool).run(any(), any());
77+
78+
QueryPlanningTool tool = new QueryPlanningTool(client, "test_model_id", "llmGenerated", queryGenerationTool);
79+
tool.setOutputParser(o -> {
80+
List<ModelTensors> outputs = (List<ModelTensors>) o;
81+
return outputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
82+
});
83+
84+
final CompletableFuture<String> future = new CompletableFuture<>();
85+
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
86+
// test try to update the prompt
87+
validParams
88+
.put("prompt", "You are a query generation agent. Generate a dsl query for the following question: ${parameters.query_text}");
89+
validParams.put("query_text", "help me find some books related to wind");
90+
tool.run(validParams, listener);
91+
92+
assertEquals(matchQueryString, future.get());
93+
}
94+
95+
@Test
96+
public void testValidate() {
97+
Tool tool = QueryPlanningTool.Factory.getInstance().create(Collections.emptyMap());
98+
assertTrue(tool.validate(validParams));
99+
assertFalse(tool.validate(emptyParams));
100+
assertFalse(tool.validate(null));
101+
}
102+
103+
@Test
104+
public void testToolGetters() {
105+
Tool tool = QueryPlanningTool.Factory.getInstance().create(Collections.emptyMap());
106+
assertEquals(QueryPlanningTool.TYPE, tool.getName());
107+
assertEquals(QueryPlanningTool.TYPE, tool.getType());
108+
assertEquals(DEFAULT_DESCRIPTION, tool.getDescription());
109+
assertNull(tool.getVersion());
110+
}
111+
112+
@Test
113+
public void testFactoryGetAllModelKeys() {
114+
List<String> allModelKeys = QueryPlanningTool.Factory.getInstance().getAllModelKeys();
115+
assertEquals(List.of(MODEL_ID_FIELD), allModelKeys);
116+
}
117+
}

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@
238238
import org.opensearch.ml.engine.tools.ListIndexTool;
239239
import org.opensearch.ml.engine.tools.MLModelTool;
240240
import org.opensearch.ml.engine.tools.McpSseTool;
241+
import org.opensearch.ml.engine.tools.QueryPlanningTool;
241242
import org.opensearch.ml.engine.tools.SearchIndexTool;
242243
import org.opensearch.ml.engine.tools.VisualizationsTool;
243244
import org.opensearch.ml.engine.utils.AgentModelsSearcher;
@@ -720,6 +721,7 @@ public Collection<Object> createComponents(
720721
SearchIndexTool.Factory.getInstance().init(client, xContentRegistry);
721722
VisualizationsTool.Factory.getInstance().init(client);
722723
ConnectorTool.Factory.getInstance().init(client);
724+
QueryPlanningTool.Factory.getInstance().init(client);
723725

724726
toolFactories.put(MLModelTool.TYPE, MLModelTool.Factory.getInstance());
725727
toolFactories.put(McpSseTool.TYPE, McpSseTool.Factory.getInstance());
@@ -729,7 +731,7 @@ public Collection<Object> createComponents(
729731
toolFactories.put(SearchIndexTool.TYPE, SearchIndexTool.Factory.getInstance());
730732
toolFactories.put(VisualizationsTool.TYPE, VisualizationsTool.Factory.getInstance());
731733
toolFactories.put(ConnectorTool.TYPE, ConnectorTool.Factory.getInstance());
732-
734+
toolFactories.put(QueryPlanningTool.TYPE, QueryPlanningTool.Factory.getInstance());
733735
if (externalToolFactories != null) {
734736
toolFactories.putAll(externalToolFactories);
735737
}

0 commit comments

Comments
 (0)