Skip to content

Commit a39dd33

Browse files
pyek-bot and dhrubo-os authored
Expose message history limit for PER Agent (#4016)
* feat: expose message history limit
  Signed-off-by: Pavan Yekbote <pybot@amazon.com>
* spotless
  Signed-off-by: Pavan Yekbote <pybot@amazon.com>
* feat: add test case for message history limit
  Signed-off-by: Pavan Yekbote <pybot@amazon.com>
* spotless
  Signed-off-by: Pavan Yekbote <pybot@amazon.com>
* chore: add comment about history limit context
  Signed-off-by: Pavan Yekbote <pybot@amazon.com>
---------
Signed-off-by: Pavan Yekbote <pybot@amazon.com>
Co-authored-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent 5cd4c96 commit a39dd33

File tree

2 files changed

+70
-2
lines changed

2 files changed

+70
-2
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ public class MLPlanExecuteAndReflectAgentRunner implements MLAgentRunner {
107107
"You are a dedicated helper agent working as part of a plan‑execute‑reflect framework. Your role is to receive a discrete task, execute all necessary internal reasoning or tool calls, and return a single, final response that fully addresses the task. You must never return an empty response. If you are unable to complete the task or retrieve meaningful information, you must respond with a clear explanation of the issue or what was missing. Under no circumstances should you end your reply with a question or ask for more information. If you search any index, always include the raw documents in the final result instead of summarizing the content. This is critical to give visibility into what the query retrieved.";
108108
private static final String DEFAULT_NO_ESCAPE_PARAMS = "tool_configs,_tools";
109109
private static final String DEFAULT_MAX_STEPS_EXECUTED = "20";
110-
private static final int DEFAULT_MESSAGE_HISTORY_LIMIT = 10;
111110
private static final String DEFAULT_REACT_MAX_ITERATIONS = "20";
112111

113112
// fields
@@ -138,6 +137,16 @@ public class MLPlanExecuteAndReflectAgentRunner implements MLAgentRunner {
138137
public static final String REFLECT_PROMPT_TEMPLATE_FIELD = "reflect_prompt_template";
139138
public static final String PLANNER_WITH_HISTORY_TEMPLATE_FIELD = "planner_with_history_template";
140139
public static final String EXECUTOR_MAX_ITERATIONS_FIELD = "executor_max_iterations";
140+
141+
// controls how many messages (last x) from planner memory are passed as context during planning phase
142+
// these messages are added as completed steps in the reflect prompt
143+
public static final String PLANNER_MESSAGE_HISTORY_LIMIT = "message_history_limit";
144+
private static final String DEFAULT_MESSAGE_HISTORY_LIMIT = "10";
145+
146+
// controls how many messages from executor memory are passed as context during step execution
147+
public static final String EXECUTOR_MESSAGE_HISTORY_LIMIT = "executor_message_history_limit";
148+
private static final String DEFAULT_EXECUTOR_MESSAGE_HISTORY_LIMIT = "10";
149+
141150
public static final String INJECT_DATETIME_FIELD = "inject_datetime";
142151
public static final String DATETIME_FORMAT_FIELD = "datetime_format";
143152

@@ -271,7 +280,7 @@ public void run(MLAgent mlAgent, Map<String, String> apiParams, ActionListener<O
271280
String memoryId = allParams.get(MEMORY_ID_FIELD);
272281
String memoryType = mlAgent.getMemory().getType();
273282
String appType = mlAgent.getAppType();
274-
int messageHistoryLimit = DEFAULT_MESSAGE_HISTORY_LIMIT;
283+
int messageHistoryLimit = Integer.parseInt(allParams.getOrDefault(PLANNER_MESSAGE_HISTORY_LIMIT, DEFAULT_MESSAGE_HISTORY_LIMIT));
275284

276285
// todo: use chat history instead of completed steps
277286
ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType);
@@ -417,6 +426,11 @@ private void executePlanningLoop(
417426
reactParams.put(SYSTEM_PROMPT_FIELD, allParams.getOrDefault(EXECUTOR_SYSTEM_PROMPT_FIELD, DEFAULT_EXECUTOR_SYSTEM_PROMPT));
418427
reactParams.put(LLM_RESPONSE_FILTER, allParams.get(LLM_RESPONSE_FILTER));
419428
reactParams.put(MAX_ITERATION, allParams.getOrDefault(EXECUTOR_MAX_ITERATIONS_FIELD, DEFAULT_REACT_MAX_ITERATIONS));
429+
reactParams
430+
.put(
431+
MLAgentExecutor.MESSAGE_HISTORY_LIMIT,
432+
allParams.getOrDefault(EXECUTOR_MESSAGE_HISTORY_LIMIT, DEFAULT_EXECUTOR_MESSAGE_HISTORY_LIMIT)
433+
);
420434

421435
AgentMLInput agentInput = AgentMLInput
422436
.AgentMLInputBuilder()

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
import org.opensearch.ml.common.agent.MLMemorySpec;
4848
import org.opensearch.ml.common.agent.MLToolSpec;
4949
import org.opensearch.ml.common.conversation.Interaction;
50+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
51+
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
5052
import org.opensearch.ml.common.output.model.ModelTensor;
5153
import org.opensearch.ml.common.output.model.ModelTensorOutput;
5254
import org.opensearch.ml.common.output.model.ModelTensors;
@@ -327,6 +329,58 @@ public void testExecutionWithHistory() {
327329
assertEquals("final result", responseTensor.getDataAsMap().get("response"));
328330
}
329331

332+
/**
 * Checks that the two history-limit request parameters are honoured by the runner:
 * "message_history_limit" ("5") must be the limit passed to
 * {@code conversationIndexMemory.getMessages}, and "executor_message_history_limit" ("3")
 * must appear under the key "message_history_limit" in the parameters of the request
 * sent via {@code MLExecuteTaskAction}.
 */
@Test
333+
public void testMessageHistoryLimits() {
334+
MLAgent mlAgent = createMLAgentWithTools();
335+
336+
// Stub the prediction call: every invocation answers with a plan containing one step and an empty result.
doAnswer(invocation -> {
337+
ActionListener<Object> listener = invocation.getArgument(2);
338+
ModelTensor modelTensor = ModelTensor
339+
.builder()
340+
.dataAsMap(ImmutableMap.of("response", "{\"steps\":[\"step1\"], \"result\":\"\"}"))
341+
.build();
342+
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
343+
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
344+
when(mlTaskResponse.getOutput()).thenReturn(mlModelTensorOutput);
345+
listener.onResponse(mlTaskResponse);
346+
return null;
347+
}).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any());
348+
349+
// Stub the execute-task call so it answers with a fixed "tool execution result" response.
doAnswer(invocation -> {
350+
// NOTE(review): this reads argument index 1 as the listener, while the other stubs for
// three-argument calls in this test read index 2 — confirm which index is intended,
// and whether the test's single-invocation verify below depends on the current behaviour.
ActionListener<Object> listener = invocation.getArgument(1);
351+
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "tool execution result")).build();
352+
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
353+
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
354+
when(mlExecuteTaskResponse.getOutput()).thenReturn(mlModelTensorOutput);
355+
listener.onResponse(mlExecuteTaskResponse);
356+
return null;
357+
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(MLExecuteTaskRequest.class), any());
358+
359+
// Stub the memory manager so interaction updates complete with the shared updateResponse fixture.
doAnswer(invocation -> {
360+
ActionListener<UpdateResponse> listener = invocation.getArgument(2);
361+
listener.onResponse(updateResponse);
362+
return null;
363+
}).when(mlMemoryManager).updateInteraction(any(), any(), any());
364+
365+
// Request parameters: the two history limits use distinct values (5 and 3) so the
// assertions below can tell them apart.
Map<String, String> params = new HashMap<>();
366+
params.put("question", "test question");
367+
params.put("memory_id", "test_memory_id");
368+
params.put("parent_interaction_id", "test_parent_interaction_id");
369+
params.put("message_history_limit", "5");
370+
params.put("executor_message_history_limit", "3");
371+
mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener);
372+
373+
// The planner-side limit must be the value handed to getMessages.
verify(conversationIndexMemory).getMessages(any(), eq(5));
374+
375+
// Capture the request sent via MLExecuteTaskAction to inspect its parameters.
ArgumentCaptor<MLExecuteTaskRequest> executeCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class);
376+
verify(client).execute(eq(MLExecuteTaskAction.INSTANCE), executeCaptor.capture(), any());
377+
378+
AgentMLInput agentInput = (AgentMLInput) executeCaptor.getValue().getInput();
379+
RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) agentInput.getInputDataset();
380+
Map<String, String> executorParams = dataset.getParameters();
381+
// The executor-side limit is expected under the key "message_history_limit" in the captured request.
assertEquals("3", executorParams.get("message_history_limit"));
382+
}
383+
330384
// ToDo: add test case for when max steps is reached
331385

332386
private MLAgent createMLAgentWithTools() {

0 commit comments

Comments
 (0)