Skip to content

Commit addf057

Browse files
authored
Update parameter handling of tools (#618)
* Add parameter extraction utilities for tool inputs - Add utilities for extracting required parameters and JSON input parameters - Apply parameter extraction in AbstractRetrieverTool and RAGTool - Define TOOL_REQUIRED_PARAMS constant for consistent parameter handling Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Standardize parameter handling in all Tool implementations - Update all Tool interface implementations to use extractInputParameters utility Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Update release note Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Declare origin of helper method extractInputParameters Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Remove displaced comment in javadoc Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Fix failed test in AbstractRetrieverToolTests Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Replace copied tool utils to library ones Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> --------- Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent bc4d9ca commit addf057

15 files changed

+45
-21
lines changed
Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,18 @@
1-
## Version 3.2.0 Release Notes
1+
## Version 3.2.0.0 Release Notes
22

3-
Compatible with OpenSearch and OpenSearch Dashboards version 3.2.0
3+
Compatible with OpenSearch and OpenSearch Dashboards version 3.2.0.0
44

55
### Features
66
* Support dynamic tool in agent framework ([#606](https://github.com/opensearch-project/skills/pull/606))
77

88
### Enhancements
99
* Merge index schema meta ([#596](https://github.com/opensearch-project/skills/pull/596))
10+
* Mask error message in PPLTool ([#609](https://github.com/opensearch-project/skills/pull/609))
1011

1112
### Bug Fixes
1213
* Fix attributes handling in dynamic tool ([#607](https://github.com/opensearch-project/skills/pull/607))
13-
* Mask error message in PPLTool ([#609](https://github.com/opensearch-project/skills/pull/609))
14-
15-
### Infrastructure
16-
* Update the maven snapshot publish endpoint and credential ([#601](https://github.com/opensearch-project/skills/pull/601))
17-
* Gradle and Lombok bump, changing CI java to 24 and adjusting AD getConfigRequest ([#615](https://github.com/opensearch-project/skills/pull/615))
1814

1915
### Maintenance
20-
* [AUTO] Increment version to 3.2.0-SNAPSHOT ([#605](https://github.com/opensearch-project/skills/pull/605))
16+
* Update the maven snapshot publish endpoint and credential ([#601](https://github.com/opensearch-project/skills/pull/601))
17+
* Bump gradle, java, lombok and fix ad configrequest change ([#615](https://github.com/opensearch-project/skills/pull/615))
18+
* Bump version to 3.2.0.0 ([#605](https://github.com/opensearch-project/skills/pull/605))

src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.opensearch.core.xcontent.NamedXContentRegistry;
2323
import org.opensearch.core.xcontent.XContentParser;
2424
import org.opensearch.ml.common.spi.tools.Tool;
25+
import org.opensearch.ml.common.utils.ToolUtils;
2526
import org.opensearch.search.SearchHit;
2627
import org.opensearch.search.builder.SearchSourceBuilder;
2728
import org.opensearch.transport.client.Client;
@@ -94,7 +95,8 @@ protected <T> SearchRequest buildSearchRequest(Map<String, String> parameters) t
9495
}
9596

9697
@Override
97-
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
98+
public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
99+
Map<String, String> parameters = ToolUtils.extractInputParameters(originalParameters, attributes);
98100
SearchRequest searchRequest;
99101
try {
100102
searchRequest = buildSearchRequest(parameters);

src/main/java/org/opensearch/agent/tools/CreateAlertTool.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
4040
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
4141
import org.opensearch.ml.common.utils.StringUtils;
42+
import org.opensearch.ml.common.utils.ToolUtils;
4243
import org.opensearch.transport.client.Client;
4344

4445
import com.google.gson.reflect.TypeToken;
@@ -133,7 +134,8 @@ public boolean validate(Map<String, String> parameters) {
133134
}
134135

135136
@Override
136-
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
137+
public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
138+
Map<String, String> parameters = ToolUtils.extractInputParameters(originalParameters, attributes);
137139
Map<String, String> tmpParams = new HashMap<>(parameters);
138140
if (!tmpParams.containsKey("indices") || Strings.isEmpty(tmpParams.get("indices"))) {
139141
throw new IllegalArgumentException(

src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.opensearch.ml.common.spi.tools.WithModelTool;
4545
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
4646
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
47+
import org.opensearch.ml.common.utils.ToolUtils;
4748
import org.opensearch.transport.client.Client;
4849

4950
import com.google.common.collect.ImmutableMap;
@@ -169,6 +170,7 @@ public CreateAnomalyDetectorTool(Client client, String modelId, String modelType
169170
*/
170171
@Override
171172
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
173+
parameters = ToolUtils.extractInputParameters(parameters, attributes);
172174
final String tenantId = parameters.get(TENANT_ID_FIELD);
173175
Map<String, String> enrichedParameters = enrichParameters(parameters);
174176
String indexName = enrichedParameters.get("index");

src/main/java/org/opensearch/agent/tools/DynamicTool.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.opensearch.core.xcontent.XContentParser;
2626
import org.opensearch.ml.common.spi.tools.Tool;
2727
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
28+
import org.opensearch.ml.common.utils.ToolUtils;
2829
import org.opensearch.rest.DynamicRestRequestCreator;
2930
import org.opensearch.rest.DynamicToolExecutor;
3031
import org.opensearch.rest.RestRequest;
@@ -114,7 +115,8 @@ public boolean validate(Map<String, String> map) {
114115
}
115116

116117
@Override
117-
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
118+
public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
119+
Map<String, String> parameters = ToolUtils.extractInputParameters(originalParameters, attributes);
118120
RestRequest.Method method = RestRequest.Method.valueOf(parameters.get(METHOD_KEY));
119121
String uri = parameters.get(URI_KEY);
120122
String requestBody = parameters.get(REQUEST_BODY_KEY);

src/main/java/org/opensearch/agent/tools/LogPatternTool.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.opensearch.core.common.util.CollectionUtils;
3232
import org.opensearch.core.xcontent.NamedXContentRegistry;
3333
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
34+
import org.opensearch.ml.common.utils.ToolUtils;
3435
import org.opensearch.search.SearchHit;
3536
import org.opensearch.sql.plugin.transport.PPLQueryAction;
3637
import org.opensearch.sql.plugin.transport.TransportPPLQueryRequest;
@@ -107,7 +108,8 @@ protected String getQueryBody(String queryText) {
107108
}
108109

109110
@Override
110-
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
111+
public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
112+
Map<String, String> parameters = ToolUtils.extractInputParameters(originalParameters, attributes);
111113
String dsl = parameters.get(INPUT_FIELD);
112114
String ppl = parameters.get(PPL_FIELD);
113115
if (!StringUtils.isBlank(dsl)) {

src/main/java/org/opensearch/agent/tools/PPLTool.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import org.opensearch.ml.common.spi.tools.WithModelTool;
5353
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
5454
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
55+
import org.opensearch.ml.common.utils.ToolUtils;
5556
import org.opensearch.search.SearchHit;
5657
import org.opensearch.search.builder.SearchSourceBuilder;
5758
import org.opensearch.sql.plugin.transport.PPLQueryAction;
@@ -196,7 +197,8 @@ public PPLTool(
196197

197198
@SuppressWarnings("unchecked")
198199
@Override
199-
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
200+
public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
201+
Map<String, String> parameters = ToolUtils.extractInputParameters(originalParameters, attributes);
200202
final String tenantId = parameters.get(TENANT_ID_FIELD);
201203
extractFromChatParameters(parameters);
202204
String indexName = getIndexNameFromParameters(parameters);

src/main/java/org/opensearch/agent/tools/RAGTool.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.opensearch.ml.common.spi.tools.WithModelTool;
2929
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
3030
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
31+
import org.opensearch.ml.common.utils.ToolUtils;
3132
import org.opensearch.transport.client.Client;
3233

3334
import com.google.gson.Gson;
@@ -95,7 +96,9 @@ public Object parse(Object o) {
9596
};
9697
}
9798

98-
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
99+
public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
100+
Map<String, String> parameters = ToolUtils.extractInputParameters(originalParameters, attributes);
101+
99102
final String tenantId = parameters.get(TENANT_ID_FIELD);
100103

101104
String input = null;

src/main/java/org/opensearch/agent/tools/SearchAlertsTool.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.opensearch.ml.common.spi.tools.Parser;
2222
import org.opensearch.ml.common.spi.tools.Tool;
2323
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
24+
import org.opensearch.ml.common.utils.ToolUtils;
2425
import org.opensearch.transport.client.Client;
2526
import org.opensearch.transport.client.node.NodeClient;
2627

@@ -70,7 +71,8 @@ public Object parse(Object o) {
7071
}
7172

7273
@Override
73-
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
74+
public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
75+
Map<String, String> parameters = ToolUtils.extractInputParameters(originalParameters, attributes);
7476
final String tableSortOrder = parameters.getOrDefault("sortOrder", "asc");
7577
final String tableSortString = parameters.getOrDefault("sortString", "monitor_name.keyword");
7678
final int tableSize = parameters.containsKey("size") && StringUtils.isNumeric(parameters.get("size"))

src/main/java/org/opensearch/agent/tools/SearchAnomalyDetectorsTool.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.opensearch.ml.common.spi.tools.Parser;
3737
import org.opensearch.ml.common.spi.tools.Tool;
3838
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
39+
import org.opensearch.ml.common.utils.ToolUtils;
3940
import org.opensearch.search.SearchHit;
4041
import org.opensearch.search.builder.SearchSourceBuilder;
4142
import org.opensearch.search.sort.SortOrder;
@@ -94,7 +95,8 @@ public Object parse(Object o) {
9495
// number of total detectors. The output will likely need to be updated, standardized, and include more fields in the
9596
// future to cover a sufficient amount of potential questions the agent will need to handle.
9697
@Override
97-
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
98+
public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
99+
Map<String, String> parameters = ToolUtils.extractInputParameters(originalParameters, attributes);
98100
final String detectorName = parameters.getOrDefault("detectorName", null);
99101
final String detectorNamePattern = parameters.getOrDefault("detectorNamePattern", null);
100102
final String indices = parameters.getOrDefault("indices", null);

0 commit comments

Comments
 (0)