Skip to content

Commit ad3bdee

Browse files
committed
Revert "Revert "Standardize parameter handling in all Tool implementations""
This reverts commit c542568.
1 parent 789b6df commit ad3bdee

File tree

9 files changed

+23
-8
lines changed

9 files changed

+23
-8
lines changed

src/main/java/org/opensearch/agent/tools/CreateAlertTool.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,8 @@ public boolean validate(Map<String, String> parameters) {
133133
}
134134

135135
@Override
136-
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
136+
public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
137+
Map<String, String> parameters = ToolHelper.extractInputParameters(originalParameters, attributes);
137138
Map<String, String> tmpParams = new HashMap<>(parameters);
138139
if (!tmpParams.containsKey("indices") || Strings.isEmpty(tmpParams.get("indices"))) {
139140
throw new IllegalArgumentException(

src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ public CreateAnomalyDetectorTool(Client client, String modelId, String modelType
169169
*/
170170
@Override
171171
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
172+
parameters = ToolHelper.extractInputParameters(parameters, attributes);
172173
final String tenantId = parameters.get(TENANT_ID_FIELD);
173174
Map<String, String> enrichedParameters = enrichParameters(parameters);
174175
String indexName = enrichedParameters.get("index");

src/main/java/org/opensearch/agent/tools/DynamicTool.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.apache.commons.text.StringSubstitutor;
1616
import org.apache.logging.log4j.LogManager;
1717
import org.apache.logging.log4j.Logger;
18+
import org.opensearch.agent.tools.utils.ToolHelper;
1819
import org.opensearch.common.xcontent.XContentType;
1920
import org.opensearch.core.action.ActionListener;
2021
import org.opensearch.core.common.bytes.BytesReference;
@@ -114,7 +115,8 @@ public boolean validate(Map<String, String> map) {
114115
}
115116

116117
@Override
117-
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
118+
public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
119+
Map<String, String> parameters = ToolHelper.extractInputParameters(originalParameters, attributes);
118120
RestRequest.Method method = RestRequest.Method.valueOf(parameters.get(METHOD_KEY));
119121
String uri = parameters.get(URI_KEY);
120122
String requestBody = parameters.get(REQUEST_BODY_KEY);

src/main/java/org/opensearch/agent/tools/PPLTool.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,8 @@ public PPLTool(
196196

197197
@SuppressWarnings("unchecked")
198198
@Override
199-
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
199+
public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
200+
Map<String, String> parameters = ToolHelper.extractInputParameters(originalParameters, attributes);
200201
final String tenantId = parameters.get(TENANT_ID_FIELD);
201202
extractFromChatParameters(parameters);
202203
String indexName = getIndexNameFromParameters(parameters);

src/main/java/org/opensearch/agent/tools/SearchAlertsTool.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import java.util.Map;
1212

1313
import org.apache.commons.lang3.StringUtils;
14+
import org.opensearch.agent.tools.utils.ToolHelper;
1415
import org.opensearch.commons.alerting.AlertingPluginInterface;
1516
import org.opensearch.commons.alerting.action.GetAlertsRequest;
1617
import org.opensearch.commons.alerting.action.GetAlertsResponse;
@@ -70,7 +71,8 @@ public Object parse(Object o) {
7071
}
7172

7273
@Override
73-
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
74+
public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
75+
Map<String, String> parameters = ToolHelper.extractInputParameters(originalParameters, attributes);
7476
final String tableSortOrder = parameters.getOrDefault("sortOrder", "asc");
7577
final String tableSortString = parameters.getOrDefault("sortString", "monitor_name.keyword");
7678
final int tableSize = parameters.containsKey("size") && StringUtils.isNumeric(parameters.get("size"))

src/main/java/org/opensearch/agent/tools/SearchAnomalyDetectorsTool.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.opensearch.ad.transport.GetAnomalyDetectorResponse;
2424
import org.opensearch.agent.tools.utils.ToolConstants;
2525
import org.opensearch.agent.tools.utils.ToolConstants.DetectorStateString;
26+
import org.opensearch.agent.tools.utils.ToolHelper;
2627
import org.opensearch.common.lucene.uid.Versions;
2728
import org.opensearch.core.action.ActionListener;
2829
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
@@ -94,7 +95,8 @@ public Object parse(Object o) {
9495
// number of total detectors. The output will likely need to be updated, standardized, and include more fields in the
9596
// future to cover a sufficient amount of potential questions the agent will need to handle.
9697
@Override
97-
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
98+
public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
99+
Map<String, String> parameters = ToolHelper.extractInputParameters(originalParameters, attributes);
98100
final String detectorName = parameters.getOrDefault("detectorName", null);
99101
final String detectorNamePattern = parameters.getOrDefault("detectorNamePattern", null);
100102
final String indices = parameters.getOrDefault("indices", null);

src/main/java/org/opensearch/agent/tools/SearchAnomalyResultsTool.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.opensearch.action.search.SearchResponse;
1515
import org.opensearch.ad.client.AnomalyDetectionNodeClient;
1616
import org.opensearch.agent.tools.utils.ToolConstants;
17+
import org.opensearch.agent.tools.utils.ToolHelper;
1718
import org.opensearch.core.action.ActionListener;
1819
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
1920
import org.opensearch.index.IndexNotFoundException;
@@ -84,7 +85,8 @@ public Object parse(Object o) {
8485
// and total # of results. The output will likely need to be updated, standardized, and include more fields in the
8586
// future to cover a sufficient amount of potential questions the agent will need to handle.
8687
@Override
87-
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
88+
public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
89+
Map<String, String> parameters = ToolHelper.extractInputParameters(originalParameters, attributes);
8890
final String detectorId = parameters.getOrDefault("detectorId", null);
8991
final Boolean realTime = parameters.containsKey("realTime") ? Boolean.parseBoolean(parameters.get("realTime")) : null;
9092
final Double anomalyGradeThreshold = parameters.containsKey("anomalyGradeThreshold")

src/main/java/org/opensearch/agent/tools/SearchMonitorsTool.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.apache.lucene.search.join.ScoreMode;
1717
import org.opensearch.action.search.SearchRequest;
1818
import org.opensearch.action.search.SearchResponse;
19+
import org.opensearch.agent.tools.utils.ToolHelper;
1920
import org.opensearch.commons.alerting.AlertingPluginInterface;
2021
import org.opensearch.commons.alerting.action.SearchMonitorRequest;
2122
import org.opensearch.commons.alerting.model.ScheduledJob;
@@ -84,7 +85,8 @@ public Object parse(Object o) {
8485
// number of total monitors. The output will likely need to be updated, standardized, and include more fields in the
8586
// future to cover a sufficient amount of potential questions the agent will need to handle.
8687
@Override
87-
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
88+
public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
89+
Map<String, String> parameters = ToolHelper.extractInputParameters(originalParameters, attributes);
8890
final String monitorId = parameters.getOrDefault("monitorId", null);
8991
final String monitorName = parameters.getOrDefault("monitorName", null);
9092
final String monitorNamePattern = parameters.getOrDefault("monitorNamePattern", null);

src/main/java/org/opensearch/agent/tools/WebSearchTool.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.jsoup.nodes.Element;
3030
import org.jsoup.select.Elements;
3131
import org.opensearch.agent.ToolPlugin;
32+
import org.opensearch.agent.tools.utils.ToolHelper;
3233
import org.opensearch.core.action.ActionListener;
3334
import org.opensearch.ml.common.spi.tools.Tool;
3435
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
@@ -99,7 +100,8 @@ public WebSearchTool(ThreadPool threadPool) {
99100
}
100101

101102
@Override
102-
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
103+
public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
104+
Map<String, String> parameters = ToolHelper.extractInputParameters(originalParameters, attributes);
103105
try {
104106
// common search parameters
105107
String query = parameters.getOrDefault("query", parameters.get("question")).replaceAll(" ", "+");

0 commit comments

Comments
 (0)