Skip to content

Commit e0d0c94

Browse files
committed
Support output filter, unify tool parameter handling and improve SearchIndexTool output parsing
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
1 parent a39dd33 commit e0d0c94

File tree

8 files changed

+198
-80
lines changed

8 files changed

+198
-80
lines changed

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.ml.common.utils;
77

8+
import static org.apache.commons.text.StringEscapeUtils.escapeJson;
89
import static org.opensearch.action.ValidateActions.addValidationError;
910

1011
import java.nio.ByteBuffer;
@@ -40,6 +41,7 @@
4041
import com.google.gson.JsonObject;
4142
import com.google.gson.JsonParser;
4243
import com.google.gson.JsonSyntaxException;
44+
import com.google.gson.reflect.TypeToken;
4345
import com.jayway.jsonpath.JsonPath;
4446
import com.jayway.jsonpath.PathNotFoundException;
4547
import com.networknt.schema.JsonSchema;
@@ -111,6 +113,13 @@ public static boolean isJson(String json) {
111113
}
112114
}
113115

116+
public static String escapeString(String input) {
117+
if (isJson(input)) {
118+
return input;
119+
}
120+
return escapeJson(input);
121+
}
122+
114123
public static String toUTF8(String rawString) {
115124
ByteBuffer buffer = StandardCharsets.UTF_8.encode(rawString);
116125

@@ -552,4 +561,7 @@ public static boolean matchesSafePattern(String value) {
552561
return SAFE_INPUT_PATTERN.matcher(value).matches();
553562
}
554563

564+
public static List<String> parseStringArrayToList(String jsonArrayString) {
565+
return gson.fromJson(jsonArrayString, TypeToken.getParameterized(List.class, String.class).getType());
566+
}
555567
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

Lines changed: 3 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import static org.opensearch.ml.common.CommonValue.MCP_CONNECTORS_FIELD;
1111
import static org.opensearch.ml.common.CommonValue.MCP_CONNECTOR_ID_FIELD;
1212
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
13-
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
1413
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
1514
import static org.opensearch.ml.common.utils.StringUtils.gson;
1615
import static org.opensearch.ml.common.utils.StringUtils.isJson;
@@ -81,6 +80,7 @@
8180
import org.opensearch.ml.engine.encryptor.Encryptor;
8281
import org.opensearch.ml.engine.function_calling.FunctionCalling;
8382
import org.opensearch.ml.engine.tools.McpSseTool;
83+
import org.opensearch.ml.engine.tools.ToolUtils;
8484
import org.opensearch.remote.metadata.client.GetDataObjectRequest;
8585
import org.opensearch.remote.metadata.client.SdkClient;
8686
import org.opensearch.remote.metadata.common.SdkClientUtils;
@@ -841,7 +841,8 @@ public static void createTools(
841841
return;
842842
}
843843
for (MLToolSpec toolSpec : toolSpecs) {
844-
Tool tool = createTool(toolFactories, params, toolSpec, mlAgent.getTenantId());
844+
Map<String, String> executeParams = ToolUtils.buildToolParameters(params, toolSpec, mlAgent.getTenantId());
845+
Tool tool = ToolUtils.createTool(toolFactories, executeParams, toolSpec);
845846
tools.put(tool.getName(), tool);
846847
if (toolSpec.getAttributes() != null) {
847848
if (tool.getAttributes() == null) {
@@ -856,46 +857,6 @@ public static void createTools(
856857
}
857858
}
858859

859-
public static Tool createTool(
860-
Map<String, Tool.Factory> toolFactories,
861-
Map<String, String> params,
862-
MLToolSpec toolSpec,
863-
String tenantId
864-
) {
865-
if (!toolFactories.containsKey(toolSpec.getType())) {
866-
throw new IllegalArgumentException("Tool not found: " + toolSpec.getType());
867-
}
868-
Map<String, String> executeParams = new HashMap<>();
869-
if (toolSpec.getParameters() != null) {
870-
executeParams.putAll(toolSpec.getParameters());
871-
}
872-
executeParams.put(TENANT_ID_FIELD, tenantId);
873-
for (String key : params.keySet()) {
874-
String toolNamePrefix = getToolName(toolSpec) + ".";
875-
if (key.startsWith(toolNamePrefix)) {
876-
executeParams.put(key.replace(toolNamePrefix, ""), params.get(key));
877-
}
878-
}
879-
Map<String, Object> toolParams = new HashMap<>();
880-
toolParams.putAll(executeParams);
881-
Map<String, Object> runtimeResources = toolSpec.getRuntimeResources();
882-
if (runtimeResources != null) {
883-
toolParams.putAll(runtimeResources);
884-
}
885-
Tool tool = toolFactories.get(toolSpec.getType()).create(toolParams);
886-
String toolName = getToolName(toolSpec);
887-
tool.setName(toolName);
888-
889-
if (toolSpec.getDescription() != null) {
890-
tool.setDescription(toolSpec.getDescription());
891-
}
892-
if (params.containsKey(toolName + ".description")) {
893-
tool.setDescription(params.get(toolName + ".description"));
894-
}
895-
896-
return tool;
897-
}
898-
899860
public static List<String> getToolNames(Map<String, Tool> tools) {
900861
final List<String> inputTools = new ArrayList<>();
901862
for (Map.Entry<String, Tool> entry : tools.entrySet()) {

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@
8484
import org.opensearch.transport.client.Client;
8585

8686
import com.google.common.annotations.VisibleForTesting;
87+
import com.jayway.jsonpath.JsonPath;
88+
import com.jayway.jsonpath.PathNotFoundException;
8789

8890
import lombok.Data;
8991
import lombok.NoArgsConstructor;
@@ -615,7 +617,17 @@ private static void runTool(
615617
String finalAction = action;
616618
ActionListener<Object> toolListener = ActionListener.wrap(r -> {
617619
if (functionCalling != null) {
618-
List<Map<String, Object>> toolResults = List.of(Map.of(TOOL_CALL_ID, toolCallId, TOOL_RESULT, Map.of("text", r)));
620+
String outputResponse = StringUtils.toJson(r);
621+
if (toolParams.containsKey("output_filter")) {
622+
try {
623+
Object filteredOutput = JsonPath.read(outputResponse, toolParams.get("output_filter"));
624+
outputResponse = StringUtils.toJson(filteredOutput);
625+
} catch (PathNotFoundException e) {
626+
log.error("Failed to read tool response from path [{}]", toolParams.get("output_filter"), e);
627+
}
628+
}
629+
List<Map<String, Object>> toolResults = List
630+
.of(Map.of(TOOL_CALL_ID, toolCallId, TOOL_RESULT, Map.of("text", outputResponse)));
619631
List<LLMMessage> llmMessages = functionCalling.supply(toolResults);
620632
// TODO: support multiple tool calls at the same time so that multiple LLMMessages can be generated here
621633
interactions.add(llmMessages.getFirst().getResponse());

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,12 @@
55

66
package org.opensearch.ml.engine.algorithms.agent;
77

8-
import static org.apache.commons.text.StringEscapeUtils.escapeJson;
98
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
109
import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD;
1110
import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD;
1211
import static org.opensearch.ml.common.conversation.ActionConstants.MEMORY_ID;
1312
import static org.opensearch.ml.common.conversation.ActionConstants.PARENT_INTERACTION_ID_FIELD;
1413
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DISABLE_TRACE;
15-
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTool;
1614
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit;
1715
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs;
1816
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolName;
@@ -53,10 +51,14 @@
5351
import org.opensearch.ml.engine.encryptor.Encryptor;
5452
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
5553
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
54+
import org.opensearch.ml.engine.tools.ToolUtils;
5655
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
5756
import org.opensearch.remote.metadata.client.SdkClient;
5857
import org.opensearch.transport.client.Client;
5958

59+
import com.jayway.jsonpath.JsonPath;
60+
import com.jayway.jsonpath.PathNotFoundException;
61+
6062
import lombok.Data;
6163
import lombok.NoArgsConstructor;
6264
import lombok.extern.log4j.Log4j2;
@@ -183,7 +185,8 @@ private void runAgent(
183185
for (int i = 0; i <= toolSpecs.size(); i++) {
184186
if (i == 0) {
185187
MLToolSpec toolSpec = toolSpecs.get(i);
186-
Tool tool = createTool(toolFactories, params, toolSpec, mlAgent.getTenantId());
188+
Map<String, String> executeParams = ToolUtils.buildToolParameters(params, toolSpec, mlAgent.getTenantId());
189+
Tool tool = ToolUtils.createTool(toolFactories, executeParams, toolSpec);
187190
firstStepListener = new StepListener<>();
188191
previousStepListener = firstStepListener;
189192
firstTool = tool;
@@ -265,7 +268,17 @@ private void processOutput(
265268
String toolName = getToolName(previousToolSpec);
266269
String outputKey = toolName + ".output";
267270
String outputResponse = parseResponse(output);
268-
params.put(outputKey, escapeJson(outputResponse));
271+
if (previousToolSpec.getParameters() != null && previousToolSpec.getParameters().containsKey("output_filter")) {
272+
try {
273+
Object filteredOutput = JsonPath.read(outputResponse, previousToolSpec.getParameters().get("output_filter"));
274+
params.put(outputKey, StringUtils.toJson(filteredOutput));
275+
} catch (PathNotFoundException e) {
276+
log.error("Failed to read response from path [{}]", previousToolSpec.getParameters().get("output_filter"), e);
277+
params.put(outputKey, StringUtils.escapeString(outputResponse));
278+
}
279+
} else {
280+
params.put(outputKey, StringUtils.escapeString(outputResponse));
281+
}
269282
boolean traceDisabled = params.containsKey(DISABLE_TRACE) && Boolean.parseBoolean(params.get(DISABLE_TRACE));
270283

271284
if (previousToolSpec.isIncludeOutputInAgentResponse() || finalI == toolSpecs.size()) {
@@ -351,9 +364,10 @@ private void runNextStep(
351364
StepListener<Object> nextStepListener
352365
) {
353366
MLToolSpec toolSpec = toolSpecs.get(finalI);
354-
Tool tool = createTool(toolFactories, params, toolSpec, tenantId);
367+
Map<String, String> executeParams = ToolUtils.buildToolParameters(params, toolSpec, tenantId);
368+
Tool tool = ToolUtils.createTool(toolFactories, executeParams, toolSpec);
355369
if (finalI < toolSpecs.size()) {
356-
tool.run(getToolExecuteParams(toolSpec, params, tenantId), nextStepListener);
370+
tool.run(getToolExecuteParams(toolSpec, executeParams, tenantId), nextStepListener);
357371
}
358372
}
359373

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
package org.opensearch.ml.engine.algorithms.agent;
77

8-
import static org.apache.commons.text.StringEscapeUtils.escapeJson;
98
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
109
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs;
1110
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolName;
@@ -39,11 +38,15 @@
3938
import org.opensearch.ml.common.utils.StringUtils;
4039
import org.opensearch.ml.engine.encryptor.Encryptor;
4140
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
41+
import org.opensearch.ml.engine.tools.ToolUtils;
4242
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
4343
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap;
4444
import org.opensearch.remote.metadata.client.SdkClient;
4545
import org.opensearch.transport.client.Client;
4646

47+
import com.jayway.jsonpath.JsonPath;
48+
import com.jayway.jsonpath.PathNotFoundException;
49+
4750
import lombok.Data;
4851
import lombok.NoArgsConstructor;
4952
import lombok.extern.log4j.Log4j2;
@@ -104,7 +107,8 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
104107
for (int i = 0; i <= toolSpecs.size(); i++) {
105108
if (i == 0) {
106109
MLToolSpec toolSpec = toolSpecs.get(i);
107-
Tool tool = createTool(toolSpec, mlAgent.getTenantId());
110+
Map<String, String> executeParams = ToolUtils.buildToolParameters(params, toolSpec, mlAgent.getTenantId());
111+
Tool tool = ToolUtils.createTool(toolFactories, executeParams, toolSpec);
108112
firstStepListener = new StepListener<>();
109113
previousStepListener = firstStepListener;
110114
firstTool = tool;
@@ -118,7 +122,17 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
118122
String outputKey = key + ".output";
119123

120124
String outputResponse = parseResponse(output);
121-
params.put(outputKey, escapeJson(outputResponse));
125+
if (previousToolSpec.getParameters() != null && previousToolSpec.getParameters().containsKey("output_filter")) {
126+
try {
127+
Object filteredOutput = JsonPath.read(outputResponse, previousToolSpec.getParameters().get("output_filter"));
128+
params.put(outputKey, StringUtils.toJson(filteredOutput));
129+
} catch (PathNotFoundException e) {
130+
log.error("Failed to read response from path [{}]", previousToolSpec.getParameters().get("output_filter"), e);
131+
params.put(outputKey, StringUtils.escapeString(outputResponse));
132+
}
133+
} else {
134+
params.put(outputKey, StringUtils.escapeString(outputResponse));
135+
}
122136

123137
if (previousToolSpec.isIncludeOutputInAgentResponse() || finalI == toolSpecs.size()) {
124138
if (output instanceof ModelTensorOutput) {
@@ -152,9 +166,10 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
152166
}
153167

154168
MLToolSpec toolSpec = toolSpecs.get(finalI);
155-
Tool tool = createTool(toolSpec, mlAgent.getTenantId());
169+
Map<String, String> executeParams = ToolUtils.buildToolParameters(params, toolSpec, mlAgent.getTenantId());
170+
Tool tool = ToolUtils.createTool(toolFactories, executeParams, toolSpec);
156171
if (finalI < toolSpecs.size()) {
157-
tool.run(getToolExecuteParams(toolSpec, params, mlAgent.getTenantId()), nextStepListener);
172+
tool.run(getToolExecuteParams(toolSpec, executeParams, mlAgent.getTenantId()), nextStepListener);
158173
}
159174

160175
}, e -> {
@@ -256,27 +271,6 @@ String parseResponse(Object output) throws IOException {
256271
}
257272
}
258273

259-
@VisibleForTesting
260-
Tool createTool(MLToolSpec toolSpec, String tenantId) {
261-
Map<String, String> toolParams = new HashMap<>();
262-
if (toolSpec.getParameters() != null) {
263-
toolParams.putAll(toolSpec.getParameters());
264-
}
265-
toolParams.put(TENANT_ID_FIELD, tenantId);
266-
if (!toolFactories.containsKey(toolSpec.getType())) {
267-
throw new IllegalArgumentException("Tool not found: " + toolSpec.getType());
268-
}
269-
Tool tool = toolFactories.get(toolSpec.getType()).create(toolParams);
270-
if (toolSpec.getName() != null) {
271-
tool.setName(toolSpec.getName());
272-
}
273-
274-
if (toolSpec.getDescription() != null) {
275-
tool.setDescription(toolSpec.getDescription());
276-
}
277-
return tool;
278-
}
279-
280274
@VisibleForTesting
281275
Map<String, String> getToolExecuteParams(MLToolSpec toolSpec, Map<String, String> params, String tenantId) {
282276
Map<String, String> executeParams = new HashMap<>();

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
99
import static org.opensearch.ml.common.utils.StringUtils.gson;
1010

11-
import java.util.HashMap;
1211
import java.util.Map;
1312

1413
import org.opensearch.action.ActionRequest;
@@ -148,8 +147,7 @@ public String getDefaultVersion() {
148147
}
149148

150149
private Map<String, String> extractInputParameters(Map<String, String> parameters) {
151-
Map<String, String> extractedParameters = new HashMap<>();
152-
extractedParameters.putAll(parameters);
150+
Map<String, String> extractedParameters = ToolUtils.extractRequiredParameters(parameters, attributes);
153151
if (parameters.containsKey("input")) {
154152
try {
155153
Map<String, String> chatParameters = gson.fromJson(parameters.get("input"), Map.class);

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,27 @@
88
import static org.opensearch.ml.common.CommonValue.*;
99

1010
import java.io.IOException;
11+
import java.util.ArrayList;
1112
import java.util.HashMap;
13+
import java.util.List;
1214
import java.util.Map;
1315
import java.util.Objects;
1416

1517
import org.apache.commons.lang3.StringUtils;
1618
import org.opensearch.action.search.SearchRequest;
1719
import org.opensearch.action.search.SearchResponse;
1820
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
21+
import org.opensearch.common.xcontent.XContentFactory;
1922
import org.opensearch.common.xcontent.XContentType;
2023
import org.opensearch.core.action.ActionListener;
24+
import org.opensearch.core.common.bytes.BytesReference;
2125
import org.opensearch.core.xcontent.NamedXContentRegistry;
26+
import org.opensearch.core.xcontent.ToXContent;
27+
import org.opensearch.core.xcontent.XContentBuilder;
2228
import org.opensearch.core.xcontent.XContentParser;
29+
import org.opensearch.ml.common.output.model.ModelTensor;
30+
import org.opensearch.ml.common.output.model.ModelTensorOutput;
31+
import org.opensearch.ml.common.output.model.ModelTensors;
2332
import org.opensearch.ml.common.spi.tools.Tool;
2433
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
2534
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction;
@@ -127,12 +136,29 @@ private static Map<String, Object> processResponse(SearchHit hit) {
127136
return docContent;
128137
}
129138

139+
public Map<String, Object> convertSearchResponseToMap(SearchResponse searchResponse) throws IOException {
140+
XContentBuilder builder = XContentFactory.jsonBuilder();
141+
searchResponse.toXContent(builder, ToXContent.EMPTY_PARAMS);
142+
143+
// Convert to bytes and then to map
144+
BytesReference bytes = BytesReference.bytes(builder);
145+
try (
146+
XContentParser parser = XContentType.JSON
147+
.xContent()
148+
.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, BytesReference.toBytes(bytes))
149+
) {
150+
return parser.map();
151+
}
152+
}
153+
130154
@Override
131-
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
155+
public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
132156
try {
157+
Map<String, String> parameters = ToolUtils.extractRequiredParameters(originalParameters, attributes);
133158
String input = parameters.get(INPUT_FIELD);
134159
String index = null;
135160
String query = null;
161+
boolean returnFullResponse = Boolean.parseBoolean(parameters.getOrDefault("return_full_response", "false"));
136162
if (!StringUtils.isEmpty(input)) {
137163
try {
138164
JsonObject jsonObject = GSON.fromJson(input, JsonObject.class);
@@ -165,7 +191,15 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
165191

166192
ActionListener<SearchResponse> actionListener = ActionListener.<SearchResponse>wrap(r -> {
167193
SearchHit[] hits = r.getHits().getHits();
168-
194+
if (returnFullResponse) {
195+
List<ModelTensors> outputs = new ArrayList<>();
196+
List<ModelTensor> tensors = new ArrayList<>();
197+
tensors.add(ModelTensor.builder().name(name).dataAsMap(convertSearchResponseToMap(r)).build());
198+
outputs.add(ModelTensors.builder().mlModelTensors(tensors).build());
199+
ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(outputs).build();
200+
listener.onResponse((T) output);
201+
return;
202+
}
169203
if (hits != null && hits.length > 0) {
170204
StringBuilder contextBuilder = new StringBuilder();
171205
for (SearchHit hit : hits) {

0 commit comments

Comments
 (0)