|
56 | 56 | import org.opensearch.ml.common.spi.memory.Memory; |
57 | 57 | import org.opensearch.ml.common.spi.tools.Tool; |
58 | 58 | import org.opensearch.ml.common.transport.MLTaskResponse; |
59 | | -import org.opensearch.ml.common.utils.StringUtils; |
60 | 59 | import org.opensearch.ml.engine.memory.ConversationIndexMemory; |
61 | 60 | import org.opensearch.ml.engine.memory.MLMemoryManager; |
62 | 61 | import org.opensearch.ml.engine.tools.ReadFromScratchPadTool; |
@@ -692,6 +691,38 @@ public void testToolThrowException() { |
692 | 691 | assertNotNull(modelTensorOutput); |
693 | 692 | } |
694 | 693 |
|
| 694 | + @Test |
| 695 | + public void testToolExceptionMessageEscaping() { |
| 696 | + // Mock tool validation to return true |
| 697 | + when(firstTool.validate(any())).thenReturn(true); |
| 698 | + |
| 699 | + // Create an MLAgent with tools |
| 700 | + MLAgent mlAgent = createMLAgentWithTools(); |
| 701 | + |
| 702 | + // Create parameters for the agent |
| 703 | + Map<String, String> params = createAgentParamsWithAction(FIRST_TOOL, "someInput"); |
| 704 | + |
| 705 | + // Mock tool to throw exception with problematic characters (quotes, newlines) |
| 706 | + String problematicMessage = "Invalid payload: { \"system\": [{\"text\": \"You are a precise...\"}] }\n" + |
| 707 | + "See https://github.com/google/gson/blob/main/Troubleshooting.md#unexpected-json-structure"; |
| 708 | + |
| 709 | + Mockito |
| 710 | + .doThrow(new IllegalArgumentException(problematicMessage)) |
| 711 | + .when(firstTool) |
| 712 | + .run(Mockito.anyMap(), toolListenerCaptor.capture()); |
| 713 | + |
| 714 | + // Run the MLChatAgentRunner |
| 715 | + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); |
| 716 | + |
| 717 | + // Verify that the tool's run method was called |
| 718 | + verify(firstTool).run(any(), any()); |
| 719 | + |
| 720 | + // Verify that the agent completes without throwing JSON parsing exceptions |
| 721 | + Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); |
| 722 | + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); |
| 723 | + assertNotNull("Agent should complete successfully even with problematic exception messages", modelTensorOutput); |
| 724 | + } |
| 725 | + |
695 | 726 | @Test |
696 | 727 | public void testToolParameters() { |
697 | 728 | // Mock tool validation to return false. |
@@ -1172,48 +1203,4 @@ public void testConstructLLMParams_DefaultValues() { |
1172 | 1203 | Assert.assertTrue(result.containsKey(AgentUtils.RESPONSE_FORMAT_INSTRUCTION)); |
1173 | 1204 | Assert.assertTrue(result.containsKey(AgentUtils.TOOL_RESPONSE)); |
1174 | 1205 | } |
1175 | | - |
1176 | | - @Test |
1177 | | - public void testExceptionMessageEscaping() { |
1178 | | - // Test the problematic exception message from the error |
1179 | | - String problematicMessage = "Invalid payload: { \"system\": [{\"text\": \"You are a precise...\"}], \"messages\": [...] }\n" |
1180 | | - + "See https://github.com/google/gson/blob/main/Troubleshooting.md#unexpected-json-structure"; |
1181 | | - |
1182 | | - String escapedMessage = StringUtils.processTextDoc(problematicMessage); |
1183 | | - |
1184 | | - // Verify that problematic characters are escaped |
1185 | | - Assert |
1186 | | - .assertFalse( |
1187 | | - "Escaped message should not contain unescaped newlines", |
1188 | | - escapedMessage.contains("\n") && !escapedMessage.contains("\\n") |
1189 | | - ); |
1190 | | - Assert |
1191 | | - .assertFalse( |
1192 | | - "Escaped message should not contain unescaped quotes", |
1193 | | - escapedMessage.contains("\"") && !escapedMessage.contains("\\\"") |
1194 | | - ); |
1195 | | - } |
1196 | | - |
1197 | | - @Test |
1198 | | - public void testGsonParsingErrorMessageEscaping() { |
1199 | | - // Test the specific Gson error message pattern |
1200 | | - String gsonError = "Expected BEGIN_ARRAY but was STRING at line 1 column 1 path $\n" |
1201 | | - + "See https://github.com/google/gson/blob/main/Troubleshooting.md#unexpected-json-structure"; |
1202 | | - |
1203 | | - String escapedMessage = StringUtils.processTextDoc(gsonError); |
1204 | | - |
1205 | | - // The escaped message should be safe for JSON inclusion |
1206 | | - Assert.assertTrue("Escaped message should be safe for JSON", !escapedMessage.contains("\n") || escapedMessage.contains("\\n")); |
1207 | | - } |
1208 | | - |
1209 | | - @Test |
1210 | | - public void testNormalMessagePassthrough() { |
1211 | | - // Test that normal messages without special characters pass through unchanged |
1212 | | - String normalMessage = "Tool execution failed with normal error"; |
1213 | | - |
1214 | | - String escapedMessage = StringUtils.processTextDoc(normalMessage); |
1215 | | - |
1216 | | - // Normal messages should be handled properly |
1217 | | - Assert.assertTrue("Normal messages should be handled properly", escapedMessage.length() > 0); |
1218 | | - } |
1219 | 1206 | } |
0 commit comments