Skip to content

Commit fb71867

Browse files
authored
Initiate query planning tool (opensearch-project#4006)
* add query planning tool Signed-off-by: Mingshi Liu <mingshl@amazon.com> * add java-time and default query Signed-off-by: Mingshi Liu <mingshl@amazon.com> * add code coverage and address comments Signed-off-by: Mingshi Liu <mingshl@amazon.com> * update description Signed-off-by: Mingshi Liu <mingshl@amazon.com> --------- Signed-off-by: Mingshi Liu <mingshl@amazon.com>
1 parent 606639a commit fb71867

File tree

5 files changed

+637
-1
lines changed

5 files changed

+637
-1
lines changed
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.engine.tools;
7+
8+
import java.util.List;
9+
import java.util.Map;
10+
11+
import org.apache.commons.text.StringSubstitutor;
12+
import org.opensearch.core.action.ActionListener;
13+
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
14+
import org.opensearch.ml.common.spi.tools.WithModelTool;
15+
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
16+
import org.opensearch.transport.client.Client;
17+
18+
import lombok.Getter;
19+
import lombok.Setter;
20+
21+
/**
22+
* This tool supports different types of query planning,
23+
* llmGenerated, systemSearchTemplates or userSearchTemplates.
24+
* //TODO only support llmGenerated for now.
25+
* //TODO to add in systemSearchTemplates or userSearchTemplates when searchTemplatesTool is implemented.
26+
*/
27+
28+
@ToolAnnotation(QueryPlanningTool.TYPE)
29+
public class QueryPlanningTool implements WithModelTool {
30+
public static final String TYPE = "QueryPlanningTool";
31+
public static final String MODEL_ID_FIELD = "model_id";
32+
private final MLModelTool queryGenerationTool;
33+
public static final String PROMPT_FIELD = "prompt";
34+
private static final String GENERATION_TYPE_FIELD = "generation_type";
35+
private static final String LLM_GENERATED_TYPE_FIELD = "llmGenerated";
36+
@Getter
37+
private final String generationType;
38+
@Setter
39+
@Getter
40+
private String name = TYPE;
41+
@Getter
42+
@Setter
43+
private Map<String, Object> attributes;
44+
@VisibleForTesting
45+
static String DEFAULT_DESCRIPTION = "Use this tool to generate opensearch query dsl for a given natural language question.";
46+
@Getter
47+
@Setter
48+
private String description = DEFAULT_DESCRIPTION;
49+
private String defaultQuery =
50+
"{ \"query\": { \"multi_match\" : { \"query\": \"${parameters.query_text}\", \"fields\": ${parameters.query_fields:-[\"*\"]} } } }";
51+
private String defaultPrompt =
52+
"You are an OpenSearch Query DSL generation assistant; try using the optional provided index mapping ${parameters.index_mapping:-}, specified fields ${parameters.query_fields:-}, and the given sample queries as examples, generate an OpenSearch Query DSL to retrieve the most relevant documents for the user provided natural language question: ${parameters.query_text}, please return the query dsl only in a string format, no other texts.\n";
53+
54+
public QueryPlanningTool(String generationType, MLModelTool queryGenerationTool) {
55+
this.generationType = generationType;
56+
this.queryGenerationTool = queryGenerationTool;
57+
}
58+
59+
@Override
60+
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
61+
62+
if (!validate(parameters)) {
63+
listener.onFailure(new IllegalArgumentException("Empty parameters for QueryPlanningTool: " + parameters));
64+
return;
65+
}
66+
if (!parameters.containsKey(PROMPT_FIELD)) {
67+
parameters.put(PROMPT_FIELD, defaultPrompt);
68+
}
69+
ActionListener<T> modelListener = ActionListener.wrap(r -> {
70+
try {
71+
String queryString = (String) r;
72+
if (queryString == null || queryString.isBlank() || queryString.equals("null")) {
73+
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
74+
String defaultQueryString = substitutor.replace(this.defaultQuery);
75+
listener.onResponse((T) defaultQueryString);
76+
} else {
77+
listener.onResponse((T) queryString);
78+
}
79+
} catch (Exception e) {
80+
IllegalArgumentException parsingException = new IllegalArgumentException(
81+
"Error processing query string: " + r + ". Try using response_filter in agent registration if needed.",
82+
e
83+
);
84+
listener.onFailure(parsingException);
85+
}
86+
}, listener::onFailure);
87+
queryGenerationTool.run(parameters, modelListener);
88+
}
89+
90+
@Override
91+
public String getType() {
92+
return TYPE;
93+
}
94+
95+
@Override
96+
public String getVersion() {
97+
return null;
98+
}
99+
100+
@Override
101+
public boolean validate(Map<String, String> parameters) {
102+
if (parameters == null || parameters.size() == 0) {
103+
return false;
104+
}
105+
return true;
106+
}
107+
108+
public static class Factory implements WithModelTool.Factory<QueryPlanningTool> {
109+
private Client client;
110+
private static volatile Factory INSTANCE;
111+
112+
public static Factory getInstance() {
113+
if (INSTANCE != null) {
114+
return INSTANCE;
115+
}
116+
synchronized (QueryPlanningTool.class) {
117+
if (INSTANCE != null) {
118+
return INSTANCE;
119+
}
120+
INSTANCE = new Factory();
121+
return INSTANCE;
122+
}
123+
}
124+
125+
public void init(Client client) {
126+
this.client = client;
127+
}
128+
129+
@Override
130+
public QueryPlanningTool create(Map<String, Object> map) {
131+
132+
MLModelTool queryGenerationTool = MLModelTool.Factory.getInstance().create(map);
133+
134+
String type = (String) map.get(GENERATION_TYPE_FIELD);
135+
if (type == null || type.isEmpty()) {
136+
type = LLM_GENERATED_TYPE_FIELD;
137+
}
138+
139+
// TODO to add in SYSTEM_SEARCH_TEMPLATES_TYPE_FIELD, USER_SEARCH_TEMPLATES_TYPE_FIELD when searchTemplatesTool is
140+
// implemented.
141+
if (!LLM_GENERATED_TYPE_FIELD.equals(type)) {
142+
throw new IllegalArgumentException("Invalid generation type: " + type + ". The current supported types are llmGenerated.");
143+
}
144+
return new QueryPlanningTool(type, queryGenerationTool);
145+
}
146+
147+
@Override
148+
public String getDefaultDescription() {
149+
return DEFAULT_DESCRIPTION;
150+
}
151+
152+
@Override
153+
public String getDefaultType() {
154+
return TYPE;
155+
}
156+
157+
@Override
158+
public String getDefaultVersion() {
159+
return null;
160+
}
161+
162+
@Override
163+
public List<String> getAllModelKeys() {
164+
return List.of(MODEL_ID_FIELD);
165+
}
166+
}
167+
}

0 commit comments

Comments
 (0)