Skip to content

Commit 07f2695

Browse files
committed
feat: allow per to process additional info
Signed-off-by: Pavan Yekbote <pybot@amazon.com>
1 parent 5964268 commit 07f2695

File tree

3 files changed

+81
-9
lines changed

3 files changed

+81
-9
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
99
import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD;
10+
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD;
1011
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD;
1112
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD;
1213
import static org.opensearch.ml.common.utils.MLTaskUtils.updateMLTaskDirectly;
@@ -460,9 +461,9 @@ private void executePlanningLoop(
460461
results.put(PARENT_INTERACTION_ID_FIELD, tensor.getResult());
461462
break;
462463
default:
463-
Map<String, ?> dataMap = tensor.getDataAsMap();
464-
if (dataMap != null && dataMap.containsKey(RESPONSE_FIELD)) {
465-
results.put(STEP_RESULT_FIELD, (String) dataMap.get(RESPONSE_FIELD));
464+
String stepResult = parseTensorDataMap(tensor);
465+
if (stepResult != null) {
466+
results.put(STEP_RESULT_FIELD, stepResult);
466467
}
467468
}
468469
});
@@ -502,8 +503,17 @@ private void executePlanningLoop(
502503
}, e -> log.error("Failed to update task {} with executor memory ID", taskId, e)));
503504
}
504505

505-
completedSteps.add(String.format("\nStep %d: %s\n", stepsExecuted + 1, stepToExecute));
506-
completedSteps.add(String.format("\nStep %d Result: %s\n", stepsExecuted + 1, results.get(STEP_RESULT_FIELD)));
506+
completedSteps.add(String.format("\n<step-%d>\n%s\n</step-%d>\n", stepsExecuted + 1, stepToExecute, stepsExecuted + 1));
507+
completedSteps
508+
.add(
509+
String
510+
.format(
511+
"\n<step-%d-result>\n%s\n</step-%d-result>\n",
512+
stepsExecuted + 1,
513+
results.get(STEP_RESULT_FIELD),
514+
stepsExecuted + 1
515+
)
516+
);
507517

508518
saveTraceData(
509519
(ConversationIndexMemory) memory,
@@ -544,6 +554,39 @@ private void executePlanningLoop(
544554
client.execute(MLPredictionTaskAction.INSTANCE, request, planListener);
545555
}
546556

557+
@VisibleForTesting
558+
String parseTensorDataMap(ModelTensor tensor) {
559+
Map<String, ?> dataMap = tensor.getDataAsMap();
560+
if (dataMap == null) {
561+
return null;
562+
}
563+
564+
StringBuilder stepResult = new StringBuilder();
565+
if (dataMap.containsKey(RESPONSE_FIELD)) {
566+
stepResult.append((String) dataMap.get(RESPONSE_FIELD));
567+
}
568+
569+
if (dataMap.containsKey(INTERACTIONS_ADDITIONAL_INFO_FIELD)) {
570+
stepResult.append("\n<step-traces>\n");
571+
((Map<String, Object>) dataMap.get(INTERACTIONS_ADDITIONAL_INFO_FIELD))
572+
.forEach(
573+
(key, value) -> stepResult
574+
.append("<")
575+
.append(key)
576+
.append(">")
577+
.append("\n")
578+
.append(value)
579+
.append("\n")
580+
.append("</")
581+
.append(key)
582+
.append(">")
583+
);
584+
stepResult.append("\n</step-traces>\n");
585+
}
586+
587+
return stepResult.toString();
588+
}
589+
547590
@VisibleForTesting
548591
Map<String, Object> parseLLMOutput(Map<String, String> allParams, ModelTensorOutput modelTensorOutput) {
549592
Map<String, Object> modelOutput = new HashMap<>();

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ public class PromptTemplate {
2828
+ "${parameters."
2929
+ PLANNER_PROMPT_FIELD
3030
+ "} \n"
31-
+ "Objective: ${parameters."
31+
+ "Objective: ```${parameters."
3232
+ USER_PROMPT_FIELD
33-
+ "} \n\nRemember: Respond only in JSON format following the required schema.";
33+
+ "}``` \n\nRemember: Respond only in JSON format following the required schema.";
3434

3535
public static final String DEFAULT_REFLECT_PROMPT_TEMPLATE = "${parameters."
3636
+ DEFAULT_PROMPT_TOOLS_FIELD
@@ -41,10 +41,10 @@ public class PromptTemplate {
4141
+ "Objective: ```${parameters."
4242
+ USER_PROMPT_FIELD
4343
+ "}```\n\n"
44-
+ "Original plan:\n[${parameters."
44+
+ "Previous plan:\n[${parameters."
4545
+ STEPS_FIELD
4646
+ "}] \n\n"
47-
+ "You have currently executed the following steps from the original plan: \n[${parameters."
47+
+ "You have currently executed the following steps: \n[${parameters."
4848
+ COMPLETED_STEPS_FIELD
4949
+ "}] \n\n"
5050
+ "${parameters."

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import static org.junit.Assert.assertEquals;
99
import static org.junit.Assert.assertFalse;
1010
import static org.junit.Assert.assertNotNull;
11+
import static org.junit.Assert.assertNull;
1112
import static org.junit.Assert.assertThrows;
1213
import static org.junit.Assert.assertTrue;
1314
import static org.mockito.ArgumentMatchers.any;
@@ -677,6 +678,34 @@ public void testSaveAndReturnFinalResult() {
677678
assertEquals(finalResult, secondModelTensorList.get(0).getDataAsMap().get("response"));
678679
}
679680

681+
@Test
682+
public void testParseTensorDataMap() {
683+
// Test with response only
684+
Map<String, Object> dataMap = new HashMap<>();
685+
dataMap.put("response", "test response");
686+
ModelTensor tensor = ModelTensor.builder().dataAsMap(dataMap).build();
687+
688+
String result = mlPlanExecuteAndReflectAgentRunner.parseTensorDataMap(tensor);
689+
assertEquals("test response", result);
690+
691+
// Test with additional info
692+
Map<String, Object> additionalInfo = new HashMap<>();
693+
additionalInfo.put("trace1", "content1");
694+
additionalInfo.put("trace2", "content2");
695+
dataMap.put("additional_info", additionalInfo);
696+
697+
result = mlPlanExecuteAndReflectAgentRunner.parseTensorDataMap(tensor);
698+
assertTrue(result.contains("test response"));
699+
assertTrue(result.contains("<step-traces>"));
700+
assertTrue(result.contains("<trace1>\ncontent1\n</trace1>"));
701+
assertTrue(result.contains("<trace2>\ncontent2\n</trace2>"));
702+
assertTrue(result.contains("</step-traces>"));
703+
704+
// Test with null dataMap
705+
ModelTensor nullTensor = ModelTensor.builder().build();
706+
assertNull(mlPlanExecuteAndReflectAgentRunner.parseTensorDataMap(nullTensor));
707+
}
708+
680709
@Test
681710
public void testUpdateTaskWithExecutorAgentInfo() {
682711
MLAgent mlAgent = createMLAgentWithTools();

0 commit comments

Comments
 (0)