Skip to content

Commit 168c496

Browse files
committed
add more UT
Signed-off-by: Mingshi Liu <mingshl@amazon.com>
1 parent 89e11ce commit 168c496

File tree

3 files changed

+75
-46
lines changed

3 files changed

+75
-46
lines changed

common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,4 +1225,46 @@ public void testDeserializeNullFloat_ToNull() {
12251225
assertTrue(m.get("fPrim").isJsonPrimitive());
12261226
assertEquals(1.0f, m.get("fPrim").getAsFloat(), 1e-9f);
12271227
}
1228+
1229+
@Test
1230+
public void testProcessTextDoc_ExceptionMessageEscaping() {
1231+
// Test the problematic exception message from the error
1232+
String problematicMessage = "Invalid payload: { \"system\": [{\"text\": \"You are a precise...\"}], \"messages\": [...] }\n"
1233+
+ "See https://github.com/google/gson/blob/main/Troubleshooting.md#unexpected-json-structure";
1234+
1235+
String escapedMessage = StringUtils.processTextDoc(problematicMessage);
1236+
1237+
// Verify that problematic characters are escaped
1238+
assertFalse(
1239+
"Escaped message should not contain unescaped newlines",
1240+
escapedMessage.contains("\n") && !escapedMessage.contains("\\n")
1241+
);
1242+
assertFalse(
1243+
"Escaped message should not contain unescaped quotes",
1244+
escapedMessage.contains("\"") && !escapedMessage.contains("\\\"")
1245+
);
1246+
}
1247+
1248+
@Test
1249+
public void testProcessTextDoc_GsonParsingErrorMessageEscaping() {
1250+
// Test the specific Gson error message pattern
1251+
String gsonError = "Expected BEGIN_ARRAY but was STRING at line 1 column 1 path $\n"
1252+
+ "See https://github.com/google/gson/blob/main/Troubleshooting.md#unexpected-json-structure";
1253+
1254+
String escapedMessage = StringUtils.processTextDoc(gsonError);
1255+
1256+
// The escaped message should be safe for JSON inclusion
1257+
assertTrue("Escaped message should be safe for JSON", !escapedMessage.contains("\n") || escapedMessage.contains("\\n"));
1258+
}
1259+
1260+
@Test
1261+
public void testProcessTextDoc_NormalMessagePassthrough() {
1262+
// Test that normal messages without special characters pass through unchanged
1263+
String normalMessage = "Tool execution failed with normal error";
1264+
1265+
String escapedMessage = StringUtils.processTextDoc(normalMessage);
1266+
1267+
// Normal messages should be handled properly
1268+
assertTrue("Normal messages should be handled properly", escapedMessage.length() > 0);
1269+
}
12281270
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ private static void runTool(
630630
TOOL_CALL_ID,
631631
toolCallId,
632632
"tool_response",
633-
"Tool " + action + " failed: " + StringUtils.processTextDoc(e.getMessage())
633+
"Tool " + action + " failed: " + processTextDoc(e.getMessage())
634634
),
635635
INTERACTIONS_PREFIX
636636
)

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656
import org.opensearch.ml.common.spi.memory.Memory;
5757
import org.opensearch.ml.common.spi.tools.Tool;
5858
import org.opensearch.ml.common.transport.MLTaskResponse;
59-
import org.opensearch.ml.common.utils.StringUtils;
6059
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
6160
import org.opensearch.ml.engine.memory.MLMemoryManager;
6261
import org.opensearch.ml.engine.tools.ReadFromScratchPadTool;
@@ -692,6 +691,38 @@ public void testToolThrowException() {
692691
assertNotNull(modelTensorOutput);
693692
}
694693

694+
@Test
695+
public void testToolExceptionMessageEscaping() {
696+
// Mock tool validation to return true
697+
when(firstTool.validate(any())).thenReturn(true);
698+
699+
// Create an MLAgent with tools
700+
MLAgent mlAgent = createMLAgentWithTools();
701+
702+
// Create parameters for the agent
703+
Map<String, String> params = createAgentParamsWithAction(FIRST_TOOL, "someInput");
704+
705+
// Mock tool to throw exception with problematic characters (quotes, newlines)
706+
String problematicMessage = "Invalid payload: { \"system\": [{\"text\": \"You are a precise...\"}] }\n" +
707+
"See https://github.com/google/gson/blob/main/Troubleshooting.md#unexpected-json-structure";
708+
709+
Mockito
710+
.doThrow(new IllegalArgumentException(problematicMessage))
711+
.when(firstTool)
712+
.run(Mockito.anyMap(), toolListenerCaptor.capture());
713+
714+
// Run the MLChatAgentRunner
715+
mlChatAgentRunner.run(mlAgent, params, agentActionListener, null);
716+
717+
// Verify that the tool's run method was called
718+
verify(firstTool).run(any(), any());
719+
720+
// Verify that the agent completes without throwing JSON parsing exceptions
721+
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
722+
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
723+
assertNotNull("Agent should complete successfully even with problematic exception messages", modelTensorOutput);
724+
}
725+
695726
@Test
696727
public void testToolParameters() {
697728
// Mock tool validation to return false.
@@ -1172,48 +1203,4 @@ public void testConstructLLMParams_DefaultValues() {
11721203
Assert.assertTrue(result.containsKey(AgentUtils.RESPONSE_FORMAT_INSTRUCTION));
11731204
Assert.assertTrue(result.containsKey(AgentUtils.TOOL_RESPONSE));
11741205
}
1175-
1176-
@Test
1177-
public void testExceptionMessageEscaping() {
1178-
// Test the problematic exception message from the error
1179-
String problematicMessage = "Invalid payload: { \"system\": [{\"text\": \"You are a precise...\"}], \"messages\": [...] }\n"
1180-
+ "See https://github.com/google/gson/blob/main/Troubleshooting.md#unexpected-json-structure";
1181-
1182-
String escapedMessage = StringUtils.processTextDoc(problematicMessage);
1183-
1184-
// Verify that problematic characters are escaped
1185-
Assert
1186-
.assertFalse(
1187-
"Escaped message should not contain unescaped newlines",
1188-
escapedMessage.contains("\n") && !escapedMessage.contains("\\n")
1189-
);
1190-
Assert
1191-
.assertFalse(
1192-
"Escaped message should not contain unescaped quotes",
1193-
escapedMessage.contains("\"") && !escapedMessage.contains("\\\"")
1194-
);
1195-
}
1196-
1197-
@Test
1198-
public void testGsonParsingErrorMessageEscaping() {
1199-
// Test the specific Gson error message pattern
1200-
String gsonError = "Expected BEGIN_ARRAY but was STRING at line 1 column 1 path $\n"
1201-
+ "See https://github.com/google/gson/blob/main/Troubleshooting.md#unexpected-json-structure";
1202-
1203-
String escapedMessage = StringUtils.processTextDoc(gsonError);
1204-
1205-
// The escaped message should be safe for JSON inclusion
1206-
Assert.assertTrue("Escaped message should be safe for JSON", !escapedMessage.contains("\n") || escapedMessage.contains("\\n"));
1207-
}
1208-
1209-
@Test
1210-
public void testNormalMessagePassthrough() {
1211-
// Test that normal messages without special characters pass through unchanged
1212-
String normalMessage = "Tool execution failed with normal error";
1213-
1214-
String escapedMessage = StringUtils.processTextDoc(normalMessage);
1215-
1216-
// Normal messages should be handled properly
1217-
Assert.assertTrue("Normal messages should be handled properly", escapedMessage.length() > 0);
1218-
}
12191206
}

0 commit comments

Comments
 (0)