Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.common.utils;

import static org.apache.commons.text.StringEscapeUtils.escapeJson;
import static org.opensearch.action.ValidateActions.addValidationError;

import java.nio.ByteBuffer;
Expand All @@ -16,6 +17,7 @@
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand All @@ -40,6 +42,7 @@
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.google.gson.JsonSyntaxException;
import com.google.gson.reflect.TypeToken;
import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.PathNotFoundException;
import com.networknt.schema.JsonSchema;
Expand Down Expand Up @@ -111,6 +114,30 @@ public static boolean isJson(String json) {
}
}

/**
* Ensures that a string is properly JSON escaped.
*
* <p>This method examines the input string and determines whether it already represents
* valid JSON content. If the input is valid JSON, it is returned unchanged. Otherwise,
* the input is treated as a plain string and escaped according to JSON string literal
* rules.</p>
*
* <p>Examples:</p>
* <pre>
* prepareJsonValue("hello") → "\"hello\""
* prepareJsonValue("\"hello\"") → "\\\"hello\\\""
* prepareJsonValue("{\"key\":123}") → {\"key\":123} (valid JSON object, unchanged)
* </pre>
* @param input
* @return
*/
public static String prepareJsonValue(String input) {
if (isJson(input)) {
return input;
}
return escapeJson(input);
}

public static String toUTF8(String rawString) {
ByteBuffer buffer = StandardCharsets.UTF_8.encode(rawString);

Expand Down Expand Up @@ -552,4 +579,22 @@ public static boolean matchesSafePattern(String value) {
return SAFE_INPUT_PATTERN.matcher(value).matches();
}

/**
* Parses a JSON array string into a List of Strings.
*
* @param jsonArrayString JSON array string to parse (e.g., "[\"item1\", \"item2\"]")
* @return List of strings parsed from the JSON array, or an empty list if the input is
* null, empty, or invalid JSON
*/
public static List<String> parseStringArrayToList(String jsonArrayString) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/**
 * Parses a JSON array string into a List of Strings.
 * 
 * @param jsonArrayString JSON array string to parse (e.g., "[\"item1\", \"item2\"]")
 * @return List of strings parsed from the JSON array
 * @throws JsonSyntaxException if the input is not a valid JSON array
 */

if (jsonArrayString == null || jsonArrayString.trim().isEmpty()) {
return Collections.emptyList();
}
try {
return gson.fromJson(jsonArrayString, TypeToken.getParameterized(List.class, String.class).getType());
} catch (JsonSyntaxException e) {
log.error("Failed to parse JSON array string: {}", jsonArrayString, e);
return Collections.emptyList();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,13 @@

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.opensearch.ml.common.utils.StringUtils.TO_STRING_FUNCTION_NAME;
import static org.opensearch.ml.common.utils.StringUtils.collectToStringPrefixes;
import static org.opensearch.ml.common.utils.StringUtils.getJsonPath;
import static org.opensearch.ml.common.utils.StringUtils.isValidJSONPath;
import static org.opensearch.ml.common.utils.StringUtils.obtainFieldNameFromJsonPath;
import static org.opensearch.ml.common.utils.StringUtils.parseParameters;
import static org.opensearch.ml.common.utils.StringUtils.toJson;
import static org.opensearch.ml.common.utils.StringUtils.*;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -190,7 +186,7 @@ public void addDefaultMethod_NoEscape() {
public void addDefaultMethod_Escape() {
String input = "return escape(\"abc\n123\");";
String result = StringUtils.addDefaultMethod(input);
Assert.assertNotEquals(input, result);
assertNotEquals(input, result);
assertTrue(result.startsWith(StringUtils.DEFAULT_ESCAPE_FUNCTION));
}

Expand Down Expand Up @@ -858,4 +854,99 @@ public void testValidateFields_InvalidCharacterSet() {
assertTrue(exception.getMessage().contains("Field1"));
}

@Test
public void prepareJsonValue_returnsRawIfJson() {
String json = "{\"key\": 123}";
String result = StringUtils.prepareJsonValue(json);
assertSame(json, result); // branch where isJson(input)==true
}

@Test
public void prepareJsonValue_escapesBadCharsOtherwise() {
String input = "Tom & Jerry \"<script>";
String escaped = StringUtils.prepareJsonValue(input);
assertNotEquals(input, escaped);
assertFalse(StringUtils.isJson(escaped));
assertEquals("Tom & Jerry \\\"<script>", escaped);
}

@Test
public void testParseStringArrayToList_validJsonArray() {
// Arrange
String jsonArray = "[\"apple\", \"banana\", \"cherry\"]";

// Act
List<String> result = parseStringArrayToList(jsonArray);

// Assert
assertEquals(Arrays.asList("apple", "banana", "cherry"), result);
}

@Test
public void testParseStringArrayToList_emptyArray() {
// Arrange
String jsonArray = "[]";

// Act
List<String> result = parseStringArrayToList(jsonArray);

// Assert
assertTrue(result.isEmpty());
}

@Test
public void testParseStringArrayToList_withSpecialCharacters() {
// Arrange
String jsonArray = "[\"hello\", \"world!\", \"special: @#$%^&*()\"]";

// Act
List<String> result = parseStringArrayToList(jsonArray);

// Assert
assertEquals(Arrays.asList("hello", "world!", "special: @#$%^&*()"), result);
}

@Test
public void testParseStringArrayToList_withNullElement() {
// Arrange
String jsonArray = "[\"first\", null, \"third\"]";

// Act
List<String> result = parseStringArrayToList(jsonArray);

// Assert
assertEquals(3, result.size());
assertEquals("first", result.get(0));
assertNull(result.get(1));
assertEquals("third", result.get(2));
}

@Test
public void testParseStringArrayToList_jsonWithTrailingComma() {
// Arrange
String jsonWithTrailingComma = "[\"apple\", \"banana\",]"; // Invalid trailing comma

List<String> result = parseStringArrayToList(jsonWithTrailingComma);

// Assert
assertEquals(Arrays.asList("apple", "banana", null), result);
assertEquals(3, result.size());
}

@Test
public void testParseStringArrayToList_nonArrayJson() {
// Arrange
String nonArrayJson = "{\"key\": \"value\"}";

// Act & Assert
List<String> array = parseStringArrayToList(nonArrayJson);
assertEquals(0, array.size());
}

@Test
public void testParseStringArrayToList_Null() {
List<String> array = parseStringArrayToList(null);
assertEquals(0, array.size());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import static org.opensearch.ml.common.CommonValue.MCP_CONNECTORS_FIELD;
import static org.opensearch.ml.common.CommonValue.MCP_CONNECTOR_ID_FIELD;
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.utils.StringUtils.isJson;
Expand All @@ -29,6 +28,7 @@
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES;
import static org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner.RESPONSE_FIELD;
import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS;
import static org.opensearch.ml.engine.tools.ToolUtils.getToolName;

import java.io.IOException;
import java.lang.reflect.Type;
Expand Down Expand Up @@ -81,6 +81,7 @@
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.function_calling.FunctionCalling;
import org.opensearch.ml.engine.tools.McpSseTool;
import org.opensearch.ml.engine.tools.ToolUtils;
import org.opensearch.remote.metadata.client.GetDataObjectRequest;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.common.SdkClientUtils;
Expand Down Expand Up @@ -646,10 +647,6 @@ public static int getMessageHistoryLimit(Map<String, String> params) {
return messageHistoryLimitStr != null ? Integer.parseInt(messageHistoryLimitStr) : LAST_N_INTERACTIONS;
}

public static String getToolName(MLToolSpec toolSpec) {
return toolSpec.getName() != null ? toolSpec.getName() : toolSpec.getType();
}

public static List<MLToolSpec> getMlToolSpecs(MLAgent mlAgent, Map<String, String> params) {
String selectedToolsStr = params.get(SELECTED_TOOLS);
List<MLToolSpec> toolSpecs = new ArrayList<>();
Expand Down Expand Up @@ -841,7 +838,8 @@ public static void createTools(
return;
}
for (MLToolSpec toolSpec : toolSpecs) {
Tool tool = createTool(toolFactories, params, toolSpec, mlAgent.getTenantId());
Map<String, String> toolParams = ToolUtils.buildToolParameters(params, toolSpec, mlAgent.getTenantId());
Tool tool = ToolUtils.createTool(toolFactories, toolParams, toolSpec);
tools.put(tool.getName(), tool);
if (toolSpec.getAttributes() != null) {
if (tool.getAttributes() == null) {
Expand All @@ -856,55 +854,6 @@ public static void createTools(
}
}

public static Tool createTool(
Map<String, Tool.Factory> toolFactories,
Map<String, String> params,
MLToolSpec toolSpec,
String tenantId
) {
if (!toolFactories.containsKey(toolSpec.getType())) {
throw new IllegalArgumentException("Tool not found: " + toolSpec.getType());
}
Map<String, String> executeParams = new HashMap<>();
if (toolSpec.getParameters() != null) {
executeParams.putAll(toolSpec.getParameters());
}
executeParams.put(TENANT_ID_FIELD, tenantId);
for (String key : params.keySet()) {
String toolNamePrefix = getToolName(toolSpec) + ".";
if (key.startsWith(toolNamePrefix)) {
executeParams.put(key.replace(toolNamePrefix, ""), params.get(key));
}
}
Map<String, Object> toolParams = new HashMap<>();
toolParams.putAll(executeParams);
Map<String, Object> runtimeResources = toolSpec.getRuntimeResources();
if (runtimeResources != null) {
toolParams.putAll(runtimeResources);
}
Tool tool = toolFactories.get(toolSpec.getType()).create(toolParams);
String toolName = getToolName(toolSpec);
tool.setName(toolName);

if (toolSpec.getDescription() != null) {
tool.setDescription(toolSpec.getDescription());
}
if (params.containsKey(toolName + ".description")) {
tool.setDescription(params.get(toolName + ".description"));
}

return tool;
}

public static List<String> getToolNames(Map<String, Tool> tools) {
final List<String> inputTools = new ArrayList<>();
for (Map.Entry<String, Tool> entry : tools.entrySet()) {
String toolName = entry.getValue().getName();
inputTools.add(toolName);
}
return inputTools;
}

public static Map<String, String> constructToolParams(
Map<String, Tool> tools,
Map<String, MLToolSpec> toolSpecMap,
Expand All @@ -916,8 +865,15 @@ public static Map<String, String> constructToolParams(
Map<String, String> toolParams = new HashMap<>();
Map<String, String> toolSpecParams = toolSpecMap.get(action).getParameters();
Map<String, String> toolSpecConfigMap = toolSpecMap.get(action).getConfigMap();
MLToolSpec toolSpec = toolSpecMap.get(action);
if (toolSpecParams != null) {
toolParams.putAll(toolSpecParams);
for (String key : toolSpecParams.keySet()) {
String toolNamePrefix = getToolName(toolSpec) + ".";
if (key.startsWith(toolNamePrefix)) {
toolParams.put(key.replace(toolNamePrefix, ""), toolSpecParams.get(key));
}
}
}
if (toolSpecConfigMap != null) {
toolParams.putAll(toolSpecConfigMap);
Expand Down
Loading
Loading