Skip to content

Commit eebae78

Browse files
authored
fix llm result path; convert message to user prompt string (#4283)
* fix llm result path; convert message to user prompt string Signed-off-by: Yaliang Wu <ylwu@amazon.com> * run spotlessApply Signed-off-by: Yaliang Wu <ylwu@amazon.com> * fix failed ut Signed-off-by: Yaliang Wu <ylwu@amazon.com> --------- Signed-off-by: Yaliang Wu <ylwu@amazon.com>
1 parent 74d9ab4 commit eebae78

File tree

2 files changed

+51
-108
lines changed

2 files changed

+51
-108
lines changed

plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingService.java

Lines changed: 38 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -140,40 +140,17 @@ public void extractFactsFromConversation(
140140
}
141141

142142
try {
143-
XContentBuilder messagesBuilder = jsonXContent.contentBuilder();
144-
messagesBuilder.startArray();
145-
Map<String, Object> strategyConfig = strategy.getStrategyConfig();
146-
if (strategyConfig != null && strategyConfig.containsKey("system_prompt_message")) {
147-
Object systemPromptMsg = strategyConfig.get("system_prompt_message");
148-
if (systemPromptMsg != null && systemPromptMsg instanceof Map) {
149-
messagesBuilder.map((Map) systemPromptMsg);
150-
}
151-
}
152-
for (MessageInput message : messages) {
153-
message.toXContent(messagesBuilder, ToXContent.EMPTY_PARAMS);
154-
}
155-
if (strategyConfig != null && strategyConfig.containsKey("user_prompt_message")) {
156-
Object userPromptMsg = strategyConfig.get("user_prompt_message");
157-
if (userPromptMsg != null && userPromptMsg instanceof Map) {
158-
messagesBuilder.map((Map) userPromptMsg);
159-
}
160-
} else { // Add default user prompt (when strategyConfig is null or doesn't have user_prompt_message)
161-
MessageInput message = getMessageInput("Please extract information from our conversation so far");
162-
message.toXContent(messagesBuilder, ToXContent.EMPTY_PARAMS);
163-
}
164-
165143
// Always add JSON enforcement message for fact extraction
166144
String enforcementMsg = (strategy.getType() == MemoryStrategyType.USER_PREFERENCE)
167145
? USER_PREFERENCE_JSON_ENFORCEMENT_MESSAGE
168146
: JSON_ENFORCEMENT_MESSAGE;
169147
MessageInput enforcementMessage = getMessageInput(enforcementMsg);
170-
enforcementMessage.toXContent(messagesBuilder, ToXContent.EMPTY_PARAMS);
171-
172-
messagesBuilder.endArray();
173-
String messagesJson = messagesBuilder.toString();
174-
stringParameters.put("messages", messagesJson);
175-
176-
log.debug("LLM request - processing {} messages", messages.size());
148+
// Create mutable copy to avoid UnsupportedOperationException
149+
List<MessageInput> mutableMessages = new ArrayList<>(messages);
150+
mutableMessages.add(enforcementMessage);
151+
String conversationJson = serializeMessagesToJson(mutableMessages);
152+
String userPrompt = "Analyze the following conversation and extract information:\n```json\n" + conversationJson + "\n```";
153+
stringParameters.put("user_prompt", userPrompt);
177154
} catch (Exception e) {
178155
log.error("Failed to build messages JSON", e);
179156
listener.onResponse(new ArrayList<>());
@@ -254,23 +231,12 @@ public void makeMemoryDecisions(
254231
stringParameters.put("system_prompt", DEFAULT_UPDATE_MEMORY_PROMPT);
255232

256233
String decisionRequestJson = decisionRequest.toJsonString();
234+
String userPrompt = "Analyze the following old memories and newly extracted facts, then make memory decisions:\n```json\n"
235+
+ decisionRequestJson
236+
+ "\n```";
237+
stringParameters.put("user_prompt", userPrompt);
257238

258239
try {
259-
XContentBuilder messagesBuilder = jsonXContent.contentBuilder();
260-
messagesBuilder.startArray();
261-
messagesBuilder.startObject();
262-
messagesBuilder.field("role", "user");
263-
messagesBuilder.startArray("content");
264-
messagesBuilder.startObject();
265-
messagesBuilder.field("type", "text");
266-
messagesBuilder.field("text", decisionRequestJson);
267-
messagesBuilder.endObject();
268-
messagesBuilder.endArray();
269-
messagesBuilder.endObject();
270-
messagesBuilder.endArray();
271-
272-
String messagesJson = messagesBuilder.toString();
273-
stringParameters.put("messages", messagesJson);
274240

275241
log
276242
.debug(
@@ -284,9 +250,16 @@ public void makeMemoryDecisions(
284250

285251
MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder().modelId(llmModelId).mlInput(mlInput).build();
286252

253+
String defaultLlmResultPath = memoryConfig.getParameters().getOrDefault("llm_result_path", DEFAULT_LLM_RESULT_PATH).toString();
254+
String llmResultPath = (String) Optional
255+
.ofNullable(strategy)
256+
.map(MemoryStrategy::getStrategyConfig)
257+
.map(config -> config.get("llm_result_path"))
258+
.orElse(defaultLlmResultPath);
259+
287260
client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, ActionListener.wrap(response -> {
288261
try {
289-
List<MemoryDecision> decisions = parseMemoryDecisions(response);
262+
List<MemoryDecision> decisions = parseMemoryDecisions(llmResultPath, response);
290263
log.debug("LLM made {} memory decisions", decisions.size());
291264
listener.onResponse(decisions);
292265
} catch (Exception e) {
@@ -334,7 +307,6 @@ private List<String> parseFactsFromLLMResponse(MemoryStrategy strategy, MLOutput
334307
String llmResult = null;
335308
if (filterdResult != null) {
336309
llmResult = StringUtils.toJson(filterdResult);
337-
llmResult = cleanMarkdownFromJson(llmResult);
338310
}
339311
if (llmResult != null) {
340312
llmResult = StringUtils.toJson(extractJsonProcessorChain.process(llmResult));
@@ -362,7 +334,7 @@ private List<String> parseFactsFromLLMResponse(MemoryStrategy strategy, MLOutput
362334
return facts;
363335
}
364336

365-
private List<MemoryDecision> parseMemoryDecisions(MLTaskResponse response) {
337+
private List<MemoryDecision> parseMemoryDecisions(String llmResultPath, MLTaskResponse response) {
366338
try {
367339
MLOutput mlOutput = response.getOutput();
368340
if (!(mlOutput instanceof ModelTensorOutput)) {
@@ -375,25 +347,15 @@ private List<MemoryDecision> parseMemoryDecisions(MLTaskResponse response) {
375347
throw new IllegalStateException("No model output tensors found");
376348
}
377349

378-
Map<String, ?> dataMap = tensors.get(0).getMlModelTensors().get(0).getDataAsMap();
379-
380-
String responseContent = null;
381-
if (dataMap.containsKey("response")) {
382-
responseContent = (String) dataMap.get("response");
383-
} else if (dataMap.containsKey("content")) {
384-
List<Map<String, Object>> contentList = (List<Map<String, Object>>) dataMap.get("content");
385-
if (contentList != null && !contentList.isEmpty()) {
386-
Map<String, Object> firstContent = contentList.get(0);
387-
responseContent = (String) firstContent.get("text");
388-
}
389-
}
390-
391-
if (responseContent == null) {
350+
Map<String, ?> dataAsMap = tensors.get(0).getMlModelTensors().get(0).getDataAsMap();
351+
Object filterdResult = JsonPath.read(dataAsMap, llmResultPath);
352+
if (filterdResult == null) {
392353
throw new IllegalStateException("No response content found in LLM output");
393354
}
355+
String responseContent = StringUtils.toJson(filterdResult);
394356

395357
// Clean response content
396-
responseContent = cleanMarkdownFromJson(responseContent);
358+
responseContent = StringUtils.toJson(extractJsonProcessorChain.process(responseContent));
397359

398360
List<MemoryDecision> decisions = new ArrayList<>();
399361
try (XContentParser parser = jsonXContent.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, responseContent)) {
@@ -433,24 +395,13 @@ public void summarizeMessages(MemoryConfiguration configuration, List<MessageInp
433395
stringParameters.putIfAbsent("max_summary_size", "10");
434396

435397
try {
436-
XContentBuilder messagesBuilder = jsonXContent.contentBuilder();
437-
messagesBuilder.startArray();
438-
for (MessageInput message : messages) {
439-
message.toXContent(messagesBuilder, ToXContent.EMPTY_PARAMS);
440-
}
441-
if (sessionParams.containsKey("user_prompt_message")) {
442-
Object userPromptMsg = sessionParams.get("user_prompt_message");
443-
if (userPromptMsg != null && userPromptMsg instanceof Map) {
444-
messagesBuilder.map((Map) userPromptMsg);
445-
}
446-
} else {
447-
MessageInput message = getMessageInput(
448-
"Please summarize our conversation, not exceed " + stringParameters.get("max_summary_size") + " words"
449-
);
450-
message.toXContent(messagesBuilder, ToXContent.EMPTY_PARAMS);
451-
}
452-
messagesBuilder.endArray();
453-
stringParameters.put("messages", messagesBuilder.toString());
398+
String conversationJson = serializeMessagesToJson(messages);
399+
String userPrompt = "Summarize the following conversation in no more than "
400+
+ stringParameters.get("max_summary_size")
401+
+ " words:\n```json\n"
402+
+ conversationJson
403+
+ "\n```";
404+
stringParameters.put("user_prompt", userPrompt);
454405

455406
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(stringParameters).build();
456407
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build();
@@ -495,27 +446,14 @@ private boolean validatePromptFormat(String prompt) {
495446
return true;
496447
}
497448

498-
/**
499-
* Utility method to clean markdown formatting from JSON responses.
500-
* Strips ```json...``` and ```...``` wrappers that LLMs commonly add.
501-
*/
502-
private String cleanMarkdownFromJson(String response) {
503-
if (response == null) {
504-
return null;
505-
}
506-
507-
response = response.trim();
508-
509-
// Remove ```json...``` wrapper
510-
if (response.startsWith("```json") && response.endsWith("```")) {
511-
response = response.substring(7, response.length() - 3).trim();
512-
}
513-
// Remove ```...``` wrapper
514-
else if (response.startsWith("```") && response.endsWith("```")) {
515-
response = response.substring(3, response.length() - 3).trim();
449+
private String serializeMessagesToJson(List<MessageInput> messages) throws IOException {
450+
XContentBuilder builder = jsonXContent.contentBuilder();
451+
builder.startArray();
452+
for (MessageInput message : messages) {
453+
message.toXContent(builder, ToXContent.EMPTY_PARAMS);
516454
}
517-
518-
return response;
455+
builder.endArray();
456+
return builder.toString();
519457
}
520458

521459
/**

plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingServiceTests.java

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import static org.junit.Assert.assertFalse;
99
import static org.junit.Assert.assertNotEquals;
10+
import static org.junit.Assert.assertNotNull;
1011
import static org.junit.Assert.assertTrue;
1112
import static org.mockito.ArgumentMatchers.any;
1213
import static org.mockito.ArgumentMatchers.eq;
@@ -540,14 +541,17 @@ public void testMakeMemoryDecisions_JsonCodeBlock() {
540541
List<FactSearchResult> searchResults = Arrays.asList();
541542
MemoryConfiguration storageConfig = mock(MemoryConfiguration.class);
542543
when(storageConfig.getLlmId()).thenReturn("llm-model-123");
544+
when(storageConfig.getParameters()).thenReturn(new HashMap<>());
543545

544546
MLTaskResponse mockResponse = mock(MLTaskResponse.class);
545547
ModelTensorOutput mockOutput = mock(ModelTensorOutput.class);
546548
ModelTensors mockTensors = mock(ModelTensors.class);
547549
ModelTensor mockTensor = mock(ModelTensor.class);
548550

549551
Map<String, Object> dataMap = new HashMap<>();
550-
dataMap.put("response", "```json\n{\"memory_decisions\": []}\n```");
552+
Map<String, Object> contentItem = new HashMap<>();
553+
contentItem.put("text", "```json\n{\"memory_decisions\": []}\n```");
554+
dataMap.put("content", Arrays.asList(contentItem));
551555

552556
when(mockResponse.getOutput()).thenReturn(mockOutput);
553557
when(mockOutput.getMlModelOutputs()).thenReturn(Arrays.asList(mockTensors));
@@ -571,14 +575,17 @@ public void testMakeMemoryDecisions_PlainCodeBlock() {
571575
List<FactSearchResult> searchResults = Arrays.asList();
572576
MemoryConfiguration storageConfig = mock(MemoryConfiguration.class);
573577
when(storageConfig.getLlmId()).thenReturn("llm-model-123");
578+
when(storageConfig.getParameters()).thenReturn(new HashMap<>());
574579

575580
MLTaskResponse mockResponse = mock(MLTaskResponse.class);
576581
ModelTensorOutput mockOutput = mock(ModelTensorOutput.class);
577582
ModelTensors mockTensors = mock(ModelTensors.class);
578583
ModelTensor mockTensor = mock(ModelTensor.class);
579584

580585
Map<String, Object> dataMap = new HashMap<>();
581-
dataMap.put("response", "```\n{\"memory_decisions\": []}\n```");
586+
Map<String, Object> contentItem = new HashMap<>();
587+
contentItem.put("text", "```\n{\"memory_decisions\": []}\n```");
588+
dataMap.put("content", Arrays.asList(contentItem));
582589

583590
when(mockResponse.getOutput()).thenReturn(mockOutput);
584591
when(mockOutput.getMlModelOutputs()).thenReturn(Arrays.asList(mockTensors));
@@ -957,13 +964,11 @@ public void testExtractFactsFromConversation_JsonEnforcementMessageAppended() {
957964
MLPredictionTaskRequest request = invocation.getArgument(1);
958965
RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) request.getMlInput().getInputDataset();
959966
Map<String, String> parameters = dataset.getParameters();
960-
String messagesJson = parameters.get("messages");
967+
String userPrompt = parameters.get("user_prompt");
961968

962-
// Verify that the JSON enforcement message is included in the messages
963-
assertTrue(
964-
"JSON enforcement message should be included",
965-
messagesJson.contains("Respond NOW with ONE LINE of valid JSON ONLY")
966-
);
969+
// Verify that the JSON enforcement message is included in the user_prompt
970+
assertNotNull("user_prompt should not be null", userPrompt);
971+
assertTrue("JSON enforcement message should be included", userPrompt.contains("Respond NOW with ONE LINE of valid JSON ONLY"));
967972

968973
// Mock successful response
969974
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);

0 commit comments

Comments
 (0)