@@ -140,40 +140,17 @@ public void extractFactsFromConversation(
         }
 
         try {
-            XContentBuilder messagesBuilder = jsonXContent.contentBuilder();
-            messagesBuilder.startArray();
-            Map<String, Object> strategyConfig = strategy.getStrategyConfig();
-            if (strategyConfig != null && strategyConfig.containsKey("system_prompt_message")) {
-                Object systemPromptMsg = strategyConfig.get("system_prompt_message");
-                if (systemPromptMsg != null && systemPromptMsg instanceof Map) {
-                    messagesBuilder.map((Map) systemPromptMsg);
-                }
-            }
-            for (MessageInput message : messages) {
-                message.toXContent(messagesBuilder, ToXContent.EMPTY_PARAMS);
-            }
-            if (strategyConfig != null && strategyConfig.containsKey("user_prompt_message")) {
-                Object userPromptMsg = strategyConfig.get("user_prompt_message");
-                if (userPromptMsg != null && userPromptMsg instanceof Map) {
-                    messagesBuilder.map((Map) userPromptMsg);
-                }
-            } else { // Add default user prompt (when strategyConfig is null or doesn't have user_prompt_message)
-                MessageInput message = getMessageInput("Please extract information from our conversation so far");
-                message.toXContent(messagesBuilder, ToXContent.EMPTY_PARAMS);
-            }
-
             // Always add JSON enforcement message for fact extraction
             String enforcementMsg = (strategy.getType() == MemoryStrategyType.USER_PREFERENCE)
                 ? USER_PREFERENCE_JSON_ENFORCEMENT_MESSAGE
                 : JSON_ENFORCEMENT_MESSAGE;
             MessageInput enforcementMessage = getMessageInput(enforcementMsg);
-            enforcementMessage.toXContent(messagesBuilder, ToXContent.EMPTY_PARAMS);
-
-            messagesBuilder.endArray();
-            String messagesJson = messagesBuilder.toString();
-            stringParameters.put("messages", messagesJson);
-
-            log.debug("LLM request - processing {} messages", messages.size());
+            // Create mutable copy to avoid UnsupportedOperationException
+            List<MessageInput> mutableMessages = new ArrayList<>(messages);
+            mutableMessages.add(enforcementMessage);
+            String conversationJson = serializeMessagesToJson(mutableMessages);
+            String userPrompt = "Analyze the following conversation and extract information:\n```json\n" + conversationJson + "\n```";
+            stringParameters.put("user_prompt", userPrompt);
         } catch (Exception e) {
             log.error("Failed to build messages JSON", e);
             listener.onResponse(new ArrayList<>());
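
The hunk above drops the connector-specific "messages" array and instead embeds the serialized conversation inside a single "user_prompt" parameter. A minimal sketch of what that parameter ends up holding, assuming MessageInput serializes to a role/content layout; the sample messages, field names, and class name below are illustrative, not taken from this PR:

    import java.util.HashMap;
    import java.util.Map;

    public class UserPromptSketch {
        public static void main(String[] args) {
            // Assumed serialized form of two MessageInput objects; in the real code this
            // string comes from serializeMessagesToJson(mutableMessages).
            String conversationJson = "["
                + "{\"role\":\"user\",\"content\":[{\"type\":\"text\",\"text\":\"I moved to Seattle last month\"}]},"
                + "{\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Congratulations on the move!\"}]}"
                + "]";

            // Same wrapping pattern as the new code: the conversation is fenced as JSON
            // inside a plain-text user prompt and passed as a single connector parameter.
            String userPrompt = "Analyze the following conversation and extract information:\n```json\n" + conversationJson + "\n```";

            Map<String, String> stringParameters = new HashMap<>();
            stringParameters.put("user_prompt", userPrompt);
            System.out.println(stringParameters.get("user_prompt"));
        }
    }

Presumably the matching connector template then substitutes this single parameter (e.g. ${parameters.user_prompt}) rather than a raw messages array.
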
@@ -254,23 +231,12 @@ public void makeMemoryDecisions(
         stringParameters.put("system_prompt", DEFAULT_UPDATE_MEMORY_PROMPT);
 
         String decisionRequestJson = decisionRequest.toJsonString();
+        String userPrompt = "Analyze the following old memories and newly extracted facts, then make memory decisions:\n```json\n"
+            + decisionRequestJson
+            + "\n```";
+        stringParameters.put("user_prompt", userPrompt);
 
         try {
-            XContentBuilder messagesBuilder = jsonXContent.contentBuilder();
-            messagesBuilder.startArray();
-            messagesBuilder.startObject();
-            messagesBuilder.field("role", "user");
-            messagesBuilder.startArray("content");
-            messagesBuilder.startObject();
-            messagesBuilder.field("type", "text");
-            messagesBuilder.field("text", decisionRequestJson);
-            messagesBuilder.endObject();
-            messagesBuilder.endArray();
-            messagesBuilder.endObject();
-            messagesBuilder.endArray();
-
-            String messagesJson = messagesBuilder.toString();
-            stringParameters.put("messages", messagesJson);
 
             log
                 .debug(
@@ -284,9 +250,16 @@ public void makeMemoryDecisions(
 
         MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder().modelId(llmModelId).mlInput(mlInput).build();
 
+        String defaultLlmResultPath = memoryConfig.getParameters().getOrDefault("llm_result_path", DEFAULT_LLM_RESULT_PATH).toString();
+        String llmResultPath = (String) Optional
+            .ofNullable(strategy)
+            .map(MemoryStrategy::getStrategyConfig)
+            .map(config -> config.get("llm_result_path"))
+            .orElse(defaultLlmResultPath);
+
         client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, ActionListener.wrap(response -> {
             try {
-                List<MemoryDecision> decisions = parseMemoryDecisions(response);
+                List<MemoryDecision> decisions = parseMemoryDecisions(llmResultPath, response);
                 log.debug("LLM made {} memory decisions", decisions.size());
                 listener.onResponse(decisions);
             } catch (Exception e) {
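
The block added above resolves a JsonPath, "llm_result_path", that tells the decision parser where the completion text sits in the connector response, preferring a per-strategy override and falling back to the container-level default. A small standalone sketch of that lookup using the Jayway json-path API already referenced by the new parsing code; the response shape and the path value here are assumptions for illustration, not the actual DEFAULT_LLM_RESULT_PATH:

    import com.jayway.jsonpath.JsonPath;

    import java.util.List;
    import java.util.Map;

    public class LlmResultPathSketch {
        public static void main(String[] args) {
            // Loosely modeled on a ModelTensor dataAsMap for an Anthropic-style response; illustrative only.
            Map<String, Object> dataAsMap = Map
                .of("content", List.of(Map.of("type", "text", "text", "{\"facts\":[\"User lives in Seattle\"]}")));

            // A strategy could override this via its config key "llm_result_path"; this value is an assumption.
            String llmResultPath = "$.content[0].text";
            String llmResult = JsonPath.read(dataAsMap, llmResultPath);
            System.out.println(llmResult); // {"facts":["User lives in Seattle"]}
        }
    }

A connector that nests the completion differently could set "llm_result_path" in the strategy config (for example "$.output.message.content[0].text") instead of relying on the default.
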
@@ -334,7 +307,6 @@ private List<String> parseFactsFromLLMResponse(MemoryStrategy strategy, MLOutput
         String llmResult = null;
         if (filterdResult != null) {
             llmResult = StringUtils.toJson(filterdResult);
-            llmResult = cleanMarkdownFromJson(llmResult);
         }
         if (llmResult != null) {
             llmResult = StringUtils.toJson(extractJsonProcessorChain.process(llmResult));
@@ -362,7 +334,7 @@ private List<String> parseFactsFromLLMResponse(MemoryStrategy strategy, MLOutput
         return facts;
     }
 
-    private List<MemoryDecision> parseMemoryDecisions(MLTaskResponse response) {
+    private List<MemoryDecision> parseMemoryDecisions(String llmResultPath, MLTaskResponse response) {
         try {
             MLOutput mlOutput = response.getOutput();
             if (!(mlOutput instanceof ModelTensorOutput)) {
@@ -375,25 +347,15 @@ private List<MemoryDecision> parseMemoryDecisions(MLTaskResponse response) {
                 throw new IllegalStateException("No model output tensors found");
             }
 
-            Map<String, ?> dataMap = tensors.get(0).getMlModelTensors().get(0).getDataAsMap();
-
-            String responseContent = null;
-            if (dataMap.containsKey("response")) {
-                responseContent = (String) dataMap.get("response");
-            } else if (dataMap.containsKey("content")) {
-                List<Map<String, Object>> contentList = (List<Map<String, Object>>) dataMap.get("content");
-                if (contentList != null && !contentList.isEmpty()) {
-                    Map<String, Object> firstContent = contentList.get(0);
-                    responseContent = (String) firstContent.get("text");
-                }
-            }
-
-            if (responseContent == null) {
+            Map<String, ?> dataAsMap = tensors.get(0).getMlModelTensors().get(0).getDataAsMap();
+            Object filterdResult = JsonPath.read(dataAsMap, llmResultPath);
+            if (filterdResult == null) {
                 throw new IllegalStateException("No response content found in LLM output");
             }
+            String responseContent = StringUtils.toJson(filterdResult);
 
             // Clean response content
-            responseContent = cleanMarkdownFromJson(responseContent);
+            responseContent = StringUtils.toJson(extractJsonProcessorChain.process(responseContent));
 
             List<MemoryDecision> decisions = new ArrayList<>();
             try (XContentParser parser = jsonXContent.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, responseContent)) {
@@ -433,24 +395,13 @@ public void summarizeMessages(MemoryConfiguration configuration, List<MessageInp
         stringParameters.putIfAbsent("max_summary_size", "10");
 
         try {
-            XContentBuilder messagesBuilder = jsonXContent.contentBuilder();
-            messagesBuilder.startArray();
-            for (MessageInput message : messages) {
-                message.toXContent(messagesBuilder, ToXContent.EMPTY_PARAMS);
-            }
-            if (sessionParams.containsKey("user_prompt_message")) {
-                Object userPromptMsg = sessionParams.get("user_prompt_message");
-                if (userPromptMsg != null && userPromptMsg instanceof Map) {
-                    messagesBuilder.map((Map) userPromptMsg);
-                }
-            } else {
-                MessageInput message = getMessageInput(
-                    "Please summarize our conversation, not exceed " + stringParameters.get("max_summary_size") + " words"
-                );
-                message.toXContent(messagesBuilder, ToXContent.EMPTY_PARAMS);
-            }
-            messagesBuilder.endArray();
-            stringParameters.put("messages", messagesBuilder.toString());
+            String conversationJson = serializeMessagesToJson(messages);
+            String userPrompt = "Summarize the following conversation in no more than "
+                + stringParameters.get("max_summary_size")
+                + " words:\n```json\n"
+                + conversationJson
+                + "\n```";
+            stringParameters.put("user_prompt", userPrompt);
 
             RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(stringParameters).build();
             MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build();
@@ -495,27 +446,14 @@ private boolean validatePromptFormat(String prompt) {
         return true;
     }
 
-    /**
-     * Utility method to clean markdown formatting from JSON responses.
-     * Strips ```json...``` and ```...``` wrappers that LLMs commonly add.
-     */
-    private String cleanMarkdownFromJson(String response) {
-        if (response == null) {
-            return null;
-        }
-
-        response = response.trim();
-
-        // Remove ```json...``` wrapper
-        if (response.startsWith("```json") && response.endsWith("```")) {
-            response = response.substring(7, response.length() - 3).trim();
-        }
-        // Remove ```...``` wrapper
-        else if (response.startsWith("```") && response.endsWith("```")) {
-            response = response.substring(3, response.length() - 3).trim();
+    private String serializeMessagesToJson(List<MessageInput> messages) throws IOException {
+        XContentBuilder builder = jsonXContent.contentBuilder();
+        builder.startArray();
+        for (MessageInput message : messages) {
+            message.toXContent(builder, ToXContent.EMPTY_PARAMS);
         }
-
-        return response;
+        builder.endArray();
+        return builder.toString();
     }
 
     /**