|
5 | 5 |
|
6 | 6 | package org.opensearch.ml.action.memorycontainer.memory; |
7 | 7 |
|
| 8 | +import static org.junit.Assert.assertTrue; |
8 | 9 | import static org.mockito.ArgumentMatchers.any; |
9 | 10 | import static org.mockito.ArgumentMatchers.eq; |
10 | 11 | import static org.mockito.Mockito.doAnswer; |
|
25 | 26 | import org.mockito.MockitoAnnotations; |
26 | 27 | import org.opensearch.core.action.ActionListener; |
27 | 28 | import org.opensearch.core.xcontent.NamedXContentRegistry; |
| 29 | +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; |
28 | 30 | import org.opensearch.ml.common.memorycontainer.MemoryConfiguration; |
29 | 31 | import org.opensearch.ml.common.memorycontainer.MemoryDecision; |
30 | 32 | import org.opensearch.ml.common.memorycontainer.MemoryStrategy; |
|
36 | 38 | import org.opensearch.ml.common.transport.MLTaskResponse; |
37 | 39 | import org.opensearch.ml.common.transport.memorycontainer.memory.MessageInput; |
38 | 40 | import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; |
| 41 | +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; |
39 | 42 | import org.opensearch.transport.client.Client; |
40 | 43 |
|
41 | 44 | public class MemoryProcessingServiceTests { |
@@ -933,4 +936,48 @@ public void testExtractFactsFromConversation_ValidCustomPrompt() { |
933 | 936 |
|
934 | 937 | verify(client).execute(any(), any(), any()); |
935 | 938 | } |
| 939 | + |
| 940 | + @Test |
| 941 | + public void testExtractFactsFromConversation_JsonEnforcementMessageAppended() { |
| 942 | + // Test that JSON enforcement message is always appended to fact extraction requests |
| 943 | + Map<String, Object> strategyConfig = new HashMap<>(); |
| 944 | + MemoryStrategy strategy = new MemoryStrategy("id", true, MemoryStrategyType.SEMANTIC, Arrays.asList("user_id"), strategyConfig); |
| 945 | + |
| 946 | + List<MessageInput> messages = Arrays.asList(MessageInput.builder().content(testContent).role("user").build()); |
| 947 | + MemoryConfiguration storageConfig = mock(MemoryConfiguration.class); |
| 948 | + when(storageConfig.getLlmId()).thenReturn("llm-model-123"); |
| 949 | + |
| 950 | + // Capture the request to verify JSON enforcement message is included |
| 951 | + doAnswer(invocation -> { |
| 952 | + MLPredictionTaskRequest request = invocation.getArgument(1); |
| 953 | + RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) request.getMlInput().getInputDataset(); |
| 954 | + Map<String, String> parameters = dataset.getParameters(); |
| 955 | + String messagesJson = parameters.get("messages"); |
| 956 | + |
| 957 | + // Verify that the JSON enforcement message is included in the messages |
| 958 | + assertTrue( |
| 959 | + "JSON enforcement message should be included", |
| 960 | + messagesJson.contains("Respond NOW with ONE LINE of valid JSON ONLY") |
| 961 | + ); |
| 962 | + |
| 963 | + // Mock successful response |
| 964 | + ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2); |
| 965 | + List<ModelTensors> mlModelOutputs = new ArrayList<>(); |
| 966 | + List<ModelTensor> tensors = new ArrayList<>(); |
| 967 | + Map<String, Object> contents = new HashMap<>(); |
| 968 | + contents.put("content", List.of(Map.of("text", "{\"facts\":[\"Test fact\"]}"))); |
| 969 | + tensors.add(ModelTensor.builder().name("response").dataAsMap(contents).build()); |
| 970 | + mlModelOutputs.add(ModelTensors.builder().mlModelTensors(tensors).build()); |
| 971 | + MLTaskResponse output = MLTaskResponse |
| 972 | + .builder() |
| 973 | + .output(ModelTensorOutput.builder().mlModelOutputs(mlModelOutputs).build()) |
| 974 | + .build(); |
| 975 | + actionListener.onResponse(output); |
| 976 | + return null; |
| 977 | + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); |
| 978 | + |
| 979 | + memoryProcessingService.extractFactsFromConversation(messages, strategy, storageConfig, factsListener); |
| 980 | + |
| 981 | + verify(client).execute(any(), any(), any()); |
| 982 | + } |
936 | 983 | } |
0 commit comments