Skip to content

Commit c3154c2

Browse files
committed
fix comments
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
1 parent ae9ebaf commit c3154c2

File tree

7 files changed

+71
-17
lines changed

7 files changed

+71
-17
lines changed

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import java.security.PrivilegedExceptionAction;
1818
import java.util.ArrayList;
1919
import java.util.Base64;
20+
import java.util.Collections;
2021
import java.util.HashMap;
2122
import java.util.HashSet;
2223
import java.util.List;
@@ -113,7 +114,24 @@ public static boolean isJson(String json) {
113114
}
114115
}
115116

116-
public static String escapeString(String input) {
117+
/**
118+
* Ensures that a string is properly JSON escaped.
119+
*
120+
* <p>This method examines the input string and determines whether it already represents
121+
* valid JSON content. If the input is valid JSON, it is returned unchanged. Otherwise,
122+
* the input is treated as a plain string and escaped according to JSON string literal
123+
* rules.</p>
124+
*
125+
* <p>Examples:</p>
126+
* <pre>
127+
* prepareJsonValue("hello") → "\"hello\""
128+
* prepareJsonValue("\"hello\"") → "\\\"hello\\\""
129+
* prepareJsonValue("{\"key\":123}") → {\"key\":123} (valid JSON object, unchanged)
130+
* </pre>
131+
* @param input
132+
* @return
133+
*/
134+
public static String prepareJsonValue(String input) {
117135
if (isJson(input)) {
118136
return input;
119137
}
@@ -561,7 +579,22 @@ public static boolean matchesSafePattern(String value) {
561579
return SAFE_INPUT_PATTERN.matcher(value).matches();
562580
}
563581

582+
/**
583+
* Parses a JSON array string into a List of Strings.
584+
*
585+
* @param jsonArrayString JSON array string to parse (e.g., "[\"item1\", \"item2\"]")
586+
* @return List of strings parsed from the JSON array, or an empty list if the input is
587+
* null, empty, or invalid JSON
588+
*/
564589
public static List<String> parseStringArrayToList(String jsonArrayString) {
565-
return gson.fromJson(jsonArrayString, TypeToken.getParameterized(List.class, String.class).getType());
590+
if (jsonArrayString == null || jsonArrayString.trim().isEmpty()) {
591+
return Collections.emptyList();
592+
}
593+
try {
594+
return gson.fromJson(jsonArrayString, TypeToken.getParameterized(List.class, String.class).getType());
595+
} catch (JsonSyntaxException e) {
596+
log.error("Failed to parse JSON array string: {}", jsonArrayString, e);
597+
return Collections.emptyList();
598+
}
566599
}
567600
}

common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
import org.opensearch.OpenSearchParseException;
3232
import org.opensearch.action.ActionRequestValidationException;
3333

34-
import com.google.gson.JsonSyntaxException;
3534
import com.jayway.jsonpath.JsonPath;
3635

3736
public class StringUtilsTest {
@@ -856,16 +855,16 @@ public void testValidateFields_InvalidCharacterSet() {
856855
}
857856

858857
@Test
859-
public void escapeString_returnsRawIfJson() {
858+
public void prepareJsonValue_returnsRawIfJson() {
860859
String json = "{\"key\": 123}";
861-
String result = StringUtils.escapeString(json);
860+
String result = StringUtils.prepareJsonValue(json);
862861
assertSame(json, result); // branch where isJson(input)==true
863862
}
864863

865864
@Test
866-
public void escapeString_escapesBadCharsOtherwise() {
865+
public void prepareJsonValue_escapesBadCharsOtherwise() {
867866
String input = "Tom & Jerry \"<script>";
868-
String escaped = StringUtils.escapeString(input);
867+
String escaped = StringUtils.prepareJsonValue(input);
869868
assertNotEquals(input, escaped);
870869
assertFalse(StringUtils.isJson(escaped));
871870
assertEquals("Tom & Jerry \\\"<script>", escaped);
@@ -940,7 +939,14 @@ public void testParseStringArrayToList_nonArrayJson() {
940939
String nonArrayJson = "{\"key\": \"value\"}";
941940

942941
// Act & Assert
943-
assertThrows(JsonSyntaxException.class, () -> { parseStringArrayToList(nonArrayJson); });
942+
List<String> array = parseStringArrayToList(nonArrayJson);
943+
assertEquals(0, array.size());
944+
}
945+
946+
@Test
947+
public void testParseStringArrayToList_Null() {
948+
List<String> array = parseStringArrayToList(null);
949+
assertEquals(0, array.size());
944950
}
945951

946952
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ private void processOutput(
261261
String outputKey = toolName + ".output";
262262
Map<String, String> toolParameters = ToolUtils.buildToolParameters(params, previousToolSpec, tenantId);
263263
String filteredOutput = parseResponse(filterToolOutput(toolParameters, output));
264-
params.put(outputKey, StringUtils.escapeString(filteredOutput));
264+
params.put(outputKey, StringUtils.prepareJsonValue(filteredOutput));
265265
boolean traceDisabled = params.containsKey(DISABLE_TRACE) && Boolean.parseBoolean(params.get(DISABLE_TRACE));
266266

267267
if (previousToolSpec.isIncludeOutputInAgentResponse() || finalI == toolSpecs.size()) {
@@ -349,10 +349,10 @@ private void runNextStep(
349349
StepListener<Object> nextStepListener
350350
) {
351351
MLToolSpec toolSpec = toolSpecs.get(finalI);
352-
Map<String, String> executeParams = ToolUtils.buildToolParameters(params, toolSpec, tenantId);
353-
Tool tool = ToolUtils.createTool(toolFactories, executeParams, toolSpec);
352+
Map<String, String> toolExecutionParameters = ToolUtils.buildToolParameters(params, toolSpec, tenantId);
353+
Tool tool = ToolUtils.createTool(toolFactories, toolExecutionParameters, toolSpec);
354354
if (finalI < toolSpecs.size()) {
355-
tool.run(executeParams, nextStepListener);
355+
tool.run(toolExecutionParameters, nextStepListener);
356356
}
357357
}
358358

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
113113
String outputKey = toolName + ".output";
114114
Map<String, String> toolParameters = ToolUtils.buildToolParameters(params, previousToolSpec, mlAgent.getTenantId());
115115
String filteredOutput = parseResponse(filterToolOutput(toolParameters, output));
116-
params.put(outputKey, StringUtils.escapeString(filteredOutput));
116+
params.put(outputKey, StringUtils.prepareJsonValue(filteredOutput));
117117
if (previousToolSpec.isIncludeOutputInAgentResponse() || finalI == toolSpecs.size()) {
118118
if (toolParameters.containsKey(TOOL_OUTPUT_FILTERS_FIELD)) {
119119
flowAgentOutput.add(ModelTensor.builder().name(outputKey).result(filteredOutput).build());

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ public class SearchIndexTool implements Tool {
7474
private static final Gson GSON = new GsonBuilder().serializeSpecialFloatingPointValues().create();
7575

7676
public static final Map<String, Object> DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, STRICT_FIELD, false);
77-
public static final String RETURN_FULL_RESPONSE = "return_full_response";
77+
public static final String RETURN_RAW_RESPONSE = "return_raw_response";
7878

7979
private String name = TYPE;
8080
private Map<String, Object> attributes;
@@ -137,6 +137,13 @@ private static Map<String, Object> processResponse(SearchHit hit) {
137137
return docContent;
138138
}
139139

140+
/**
141+
* Converts a SearchResponse to a Map representation for easier processing.
142+
*
143+
* @param searchResponse The search response to convert
144+
* @return Map representation of the search response
145+
* @throws IOException if conversion fails
146+
*/
140147
public Map<String, Object> convertSearchResponseToMap(SearchResponse searchResponse) throws IOException {
141148
XContentBuilder builder = XContentFactory.jsonBuilder();
142149
searchResponse.toXContent(builder, ToXContent.EMPTY_PARAMS);
@@ -159,7 +166,7 @@ public <T> void run(Map<String, String> originalParameters, ActionListener<T> li
159166
String input = parameters.get(INPUT_FIELD);
160167
String index = null;
161168
String query = null;
162-
boolean returnFullResponse = Boolean.parseBoolean(parameters.getOrDefault(RETURN_FULL_RESPONSE, "false"));
169+
boolean returnFullResponse = Boolean.parseBoolean(parameters.getOrDefault(RETURN_RAW_RESPONSE, "false"));
163170
if (!StringUtils.isEmpty(input)) {
164171
try {
165172
JsonObject jsonObject = GSON.fromJson(input, JsonObject.class);

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ToolUtils.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,14 @@
2424

2525
import com.google.gson.reflect.TypeToken;
2626
import com.jayway.jsonpath.JsonPath;
27+
import com.jayway.jsonpath.PathNotFoundException;
2728

2829
import lombok.extern.log4j.Log4j2;
2930

31+
/**
32+
* Utility class for tool-related operations including parameter extraction,
33+
* tool creation, and output filtering.
34+
*/
3035
@Log4j2
3136
public class ToolUtils {
3237

@@ -127,7 +132,10 @@ public static Object filterToolOutput(Map<String, String> toolParams, Object res
127132
String output = parseResponse(response);
128133
Object filteredOutput = JsonPath.read(output, toolParams.get(TOOL_OUTPUT_FILTERS_FIELD));
129134
return StringUtils.toJson(filteredOutput);
135+
} catch (PathNotFoundException e) {
136+
log.error("JSONPath not found: [{}]", toolParams.get(TOOL_OUTPUT_FILTERS_FIELD), e);
130137
} catch (Exception e) {
138+
// TODO: another option is returning error if failed to parse, need test to check which option is better.
131139
log.error("Failed to read tool response from path [{}]", toolParams.get(TOOL_OUTPUT_FILTERS_FIELD), e);
132140
}
133141
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/SearchIndexToolTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ public void testRunWithReturnFullResponseTrue() {
298298

299299
Map<String, String> parameters = new HashMap<>();
300300
parameters.put("input", inputString);
301-
parameters.put(SearchIndexTool.RETURN_FULL_RESPONSE, "true");
301+
parameters.put(SearchIndexTool.RETURN_RAW_RESPONSE, "true");
302302

303303
mockedSearchIndexTool.run(parameters, listener);
304304

@@ -341,7 +341,7 @@ public void testRunWithReturnFullResponseFalse() {
341341

342342
Map<String, String> parameters = new HashMap<>();
343343
parameters.put("input", inputString);
344-
parameters.put(SearchIndexTool.RETURN_FULL_RESPONSE, "false");
344+
parameters.put(SearchIndexTool.RETURN_RAW_RESPONSE, "false");
345345

346346
mockedSearchIndexTool.run(parameters, listener);
347347

0 commit comments

Comments
 (0)