Skip to content

Commit f6a68bf

Browse files
pyek-botdhrubo-os
andauthored
Ensure chat agent returns response when max iterations are reached (#4031)
* fix: mlchatagentrunner max iterations bug Signed-off-by: Pavan Yekbote <pybot@amazon.com> * fix: gradle spotless Signed-off-by: Pavan Yekbote <pybot@amazon.com> * feat: refactor some code and add test cases Signed-off-by: Pavan Yekbote <pybot@amazon.com> * refactor: move the last iteration check earlier Signed-off-by: Pavan Yekbote <pybot@amazon.com> --------- Signed-off-by: Pavan Yekbote <pybot@amazon.com> Co-authored-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent 6e3656d commit f6a68bf

File tree

3 files changed

+167
-25
lines changed

3 files changed

+167
-25
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

Lines changed: 84 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ public class MLChatAgentRunner implements MLAgentRunner {
125125
public static final String SYSTEM_PROMPT_FIELD = "system_prompt";
126126

127127
private static final String DEFAULT_MAX_ITERATIONS = "10";
128+
private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task";
128129

129130
private Client client;
130131
private Settings settings;
@@ -327,7 +328,7 @@ private void runReAct(
327328
int maxIterations = Integer.parseInt(tmpParameters.getOrDefault(MAX_ITERATION, DEFAULT_MAX_ITERATIONS));
328329
for (int i = 0; i < maxIterations; i++) {
329330
int finalI = i;
330-
StepListener<?> nextStepListener = new StepListener<>();
331+
StepListener<?> nextStepListener = (i == maxIterations - 1) ? null : new StepListener<>();
331332

332333
lastStepListener.whenComplete(output -> {
333334
StringBuilder sessionMsgAnswerBuilder = new StringBuilder();
@@ -396,6 +397,25 @@ private void runReAct(
396397
"LLM"
397398
);
398399

400+
if (nextStepListener == null) {
401+
handleMaxIterationsReached(
402+
sessionId,
403+
listener,
404+
question,
405+
parentInteractionId,
406+
verbose,
407+
traceDisabled,
408+
traceTensors,
409+
conversationIndexMemory,
410+
traceNumber,
411+
additionalInfo,
412+
lastThought,
413+
maxIterations,
414+
tools
415+
);
416+
return;
417+
}
418+
399419
if (tools.containsKey(action)) {
400420
Map<String, String> toolParams = constructToolParams(
401421
tools,
@@ -455,7 +475,7 @@ private void runReAct(
455475
StringSubstitutor substitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder), "${parameters.", "}");
456476
newPrompt.set(substitutor.replace(finalPrompt));
457477
tmpParameters.put(PROMPT, newPrompt.get());
458-
if (interactions.size() > 0) {
478+
if (!interactions.isEmpty()) {
459479
tmpParameters.put(INTERACTIONS, ", " + String.join(", ", interactions));
460480
}
461481

@@ -474,34 +494,41 @@ private void runReAct(
474494
);
475495

476496
if (finalI == maxIterations - 1) {
477-
if (verbose) {
478-
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(traceTensors).build());
479-
} else {
480-
List<ModelTensors> finalModelTensors = createFinalAnswerTensors(
481-
createModelTensors(sessionId, parentInteractionId),
482-
List.of(ModelTensor.builder().name("response").dataAsMap(Map.of("response", lastThought.get())).build())
483-
);
484-
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build());
485-
}
486-
} else {
487-
ActionRequest request = new MLPredictionTaskRequest(
488-
llm.getModelId(),
489-
RemoteInferenceMLInput
490-
.builder()
491-
.algorithm(FunctionName.REMOTE)
492-
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build())
493-
.build(),
494-
null,
495-
tenantId
497+
handleMaxIterationsReached(
498+
sessionId,
499+
listener,
500+
question,
501+
parentInteractionId,
502+
verbose,
503+
traceDisabled,
504+
traceTensors,
505+
conversationIndexMemory,
506+
traceNumber,
507+
additionalInfo,
508+
lastThought,
509+
maxIterations,
510+
tools
496511
);
497-
client.execute(MLPredictionTaskAction.INSTANCE, request, (ActionListener<MLTaskResponse>) nextStepListener);
512+
return;
498513
}
514+
515+
ActionRequest request = new MLPredictionTaskRequest(
516+
llm.getModelId(),
517+
RemoteInferenceMLInput
518+
.builder()
519+
.algorithm(FunctionName.REMOTE)
520+
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build())
521+
.build(),
522+
null,
523+
tenantId
524+
);
525+
client.execute(MLPredictionTaskAction.INSTANCE, request, (ActionListener<MLTaskResponse>) nextStepListener);
499526
}
500527
}, e -> {
501528
log.error("Failed to run chat agent", e);
502529
listener.onFailure(e);
503530
});
504-
if (i < maxIterations - 1) {
531+
if (nextStepListener != null) {
505532
lastStepListener = nextStepListener;
506533
}
507534
}
@@ -837,6 +864,40 @@ private static void returnFinalResponse(
837864
}
838865
}
839866

867+
private void handleMaxIterationsReached(
868+
String sessionId,
869+
ActionListener<Object> listener,
870+
String question,
871+
String parentInteractionId,
872+
boolean verbose,
873+
boolean traceDisabled,
874+
List<ModelTensors> traceTensors,
875+
ConversationIndexMemory conversationIndexMemory,
876+
AtomicInteger traceNumber,
877+
Map<String, Object> additionalInfo,
878+
AtomicReference<String> lastThought,
879+
int maxIterations,
880+
Map<String, Tool> tools
881+
) {
882+
String incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get()))
883+
? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get())
884+
: String.format(MAX_ITERATIONS_MESSAGE, maxIterations);
885+
sendFinalAnswer(
886+
sessionId,
887+
listener,
888+
question,
889+
parentInteractionId,
890+
verbose,
891+
traceDisabled,
892+
traceTensors,
893+
conversationIndexMemory,
894+
traceNumber,
895+
additionalInfo,
896+
incompleteResponse
897+
);
898+
cleanUpResource(tools);
899+
}
900+
840901
private void saveMessage(
841902
ConversationIndexMemory memory,
842903
String question,

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -364,10 +364,10 @@ private void executePlanningLoop(
364364
saveAndReturnFinalResult(
365365
(ConversationIndexMemory) memory,
366366
parentInteractionId,
367-
finalResult,
368-
completedSteps.get(completedSteps.size() - 2),
369367
allParams.get(EXECUTOR_AGENT_MEMORY_ID_FIELD),
370368
allParams.get(EXECUTOR_AGENT_PARENT_INTERACTION_ID_FIELD),
369+
finalResult,
370+
null,
371371
finalListener
372372
);
373373
return;

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,87 @@ private Answer generateToolFailure(Exception e) {
978978
};
979979
}
980980

981+
@Test
982+
public void testMaxIterationsReached() {
983+
// Create LLM spec with max_iteration = 1 to force max iterations
984+
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build();
985+
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
986+
final MLAgent mlAgent = MLAgent
987+
.builder()
988+
.name("TestAgent")
989+
.type(MLAgentType.CONVERSATIONAL.name())
990+
.llm(llmSpec)
991+
.memory(mlMemorySpec)
992+
.tools(Arrays.asList(firstToolSpec))
993+
.build();
994+
995+
// Mock LLM response that doesn't contain final_answer to force max iterations
996+
Mockito
997+
.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "", "action", FIRST_TOOL)))
998+
.when(client)
999+
.execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));
1000+
1001+
Map<String, String> params = new HashMap<>();
1002+
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
1003+
1004+
mlChatAgentRunner.run(mlAgent, params, agentActionListener);
1005+
1006+
// Verify response is captured
1007+
verify(agentActionListener).onResponse(objectCaptor.capture());
1008+
Object capturedResponse = objectCaptor.getValue();
1009+
assertTrue(capturedResponse instanceof ModelTensorOutput);
1010+
1011+
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
1012+
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors();
1013+
assertEquals(1, agentOutput.size());
1014+
1015+
// Verify the response contains max iterations message
1016+
String response = (String) agentOutput.get(0).getDataAsMap().get("response");
1017+
assertEquals("Agent reached maximum iterations (1) without completing the task", response);
1018+
}
1019+
1020+
@Test
1021+
public void testMaxIterationsReachedWithValidThought() {
1022+
// Create LLM spec with max_iteration = 1 to force max iterations
1023+
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build();
1024+
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
1025+
final MLAgent mlAgent = MLAgent
1026+
.builder()
1027+
.name("TestAgent")
1028+
.type(MLAgentType.CONVERSATIONAL.name())
1029+
.llm(llmSpec)
1030+
.memory(mlMemorySpec)
1031+
.tools(Arrays.asList(firstToolSpec))
1032+
.build();
1033+
1034+
// Mock LLM response with valid thought
1035+
Mockito
1036+
.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to use the first tool", "action", FIRST_TOOL)))
1037+
.when(client)
1038+
.execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));
1039+
1040+
Map<String, String> params = new HashMap<>();
1041+
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
1042+
1043+
mlChatAgentRunner.run(mlAgent, params, agentActionListener);
1044+
1045+
// Verify response is captured
1046+
verify(agentActionListener).onResponse(objectCaptor.capture());
1047+
Object capturedResponse = objectCaptor.getValue();
1048+
assertTrue(capturedResponse instanceof ModelTensorOutput);
1049+
1050+
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
1051+
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors();
1052+
assertEquals(1, agentOutput.size());
1053+
1054+
// Verify the response contains the last valid thought instead of max iterations message
1055+
String response = (String) agentOutput.get(0).getDataAsMap().get("response");
1056+
assertEquals(
1057+
"Agent reached maximum iterations (1) without completing the task. Last thought: I need to use the first tool",
1058+
response
1059+
);
1060+
}
1061+
9811062
@Test
9821063
public void testConstructLLMParams_WithSystemPromptAndDateTimeInjection() {
9831064
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();

0 commit comments

Comments
 (0)