Skip to content

Commit d0792dc

Browse files
pyek-botdhrubo-os
andcommitted
Ensure chat agent returns response when max iterations are reached (opensearch-project#4031)
* fix: mlchatagentrunner max iterations bug Signed-off-by: Pavan Yekbote <pybot@amazon.com> * fix: gradle spotless Signed-off-by: Pavan Yekbote <pybot@amazon.com> * feat: refactor some code and add test cases Signed-off-by: Pavan Yekbote <pybot@amazon.com> * refactor: move the last iteration check earlier Signed-off-by: Pavan Yekbote <pybot@amazon.com> --------- Signed-off-by: Pavan Yekbote <pybot@amazon.com> Co-authored-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent 06e20ac commit d0792dc

File tree

3 files changed

+166
-25
lines changed

3 files changed

+166
-25
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

Lines changed: 84 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ public class MLChatAgentRunner implements MLAgentRunner {
119119
public static final String LLM_INTERFACE = "_llm_interface";
120120

121121
private static final String DEFAULT_MAX_ITERATIONS = "10";
122+
private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task";
122123

123124
private Client client;
124125
private Settings settings;
@@ -321,7 +322,7 @@ private void runReAct(
321322
int maxIterations = Integer.parseInt(tmpParameters.getOrDefault(MAX_ITERATION, DEFAULT_MAX_ITERATIONS));
322323
for (int i = 0; i < maxIterations; i++) {
323324
int finalI = i;
324-
StepListener<?> nextStepListener = new StepListener<>();
325+
StepListener<?> nextStepListener = (i == maxIterations - 1) ? null : new StepListener<>();
325326

326327
lastStepListener.whenComplete(output -> {
327328
StringBuilder sessionMsgAnswerBuilder = new StringBuilder();
@@ -390,6 +391,25 @@ private void runReAct(
390391
"LLM"
391392
);
392393

394+
if (nextStepListener == null) {
395+
handleMaxIterationsReached(
396+
sessionId,
397+
listener,
398+
question,
399+
parentInteractionId,
400+
verbose,
401+
traceDisabled,
402+
traceTensors,
403+
conversationIndexMemory,
404+
traceNumber,
405+
additionalInfo,
406+
lastThought,
407+
maxIterations,
408+
tools
409+
);
410+
return;
411+
}
412+
393413
if (tools.containsKey(action)) {
394414
Map<String, String> toolParams = constructToolParams(
395415
tools,
@@ -449,7 +469,7 @@ private void runReAct(
449469
StringSubstitutor substitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder), "${parameters.", "}");
450470
newPrompt.set(substitutor.replace(finalPrompt));
451471
tmpParameters.put(PROMPT, newPrompt.get());
452-
if (interactions.size() > 0) {
472+
if (!interactions.isEmpty()) {
453473
tmpParameters.put(INTERACTIONS, ", " + String.join(", ", interactions));
454474
}
455475

@@ -468,34 +488,41 @@ private void runReAct(
468488
);
469489

470490
if (finalI == maxIterations - 1) {
471-
if (verbose) {
472-
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(traceTensors).build());
473-
} else {
474-
List<ModelTensors> finalModelTensors = createFinalAnswerTensors(
475-
createModelTensors(sessionId, parentInteractionId),
476-
List.of(ModelTensor.builder().name("response").dataAsMap(Map.of("response", lastThought.get())).build())
477-
);
478-
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build());
479-
}
480-
} else {
481-
ActionRequest request = new MLPredictionTaskRequest(
482-
llm.getModelId(),
483-
RemoteInferenceMLInput
484-
.builder()
485-
.algorithm(FunctionName.REMOTE)
486-
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build())
487-
.build(),
488-
null,
489-
tenantId
491+
handleMaxIterationsReached(
492+
sessionId,
493+
listener,
494+
question,
495+
parentInteractionId,
496+
verbose,
497+
traceDisabled,
498+
traceTensors,
499+
conversationIndexMemory,
500+
traceNumber,
501+
additionalInfo,
502+
lastThought,
503+
maxIterations,
504+
tools
490505
);
491-
client.execute(MLPredictionTaskAction.INSTANCE, request, (ActionListener<MLTaskResponse>) nextStepListener);
506+
return;
492507
}
508+
509+
ActionRequest request = new MLPredictionTaskRequest(
510+
llm.getModelId(),
511+
RemoteInferenceMLInput
512+
.builder()
513+
.algorithm(FunctionName.REMOTE)
514+
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build())
515+
.build(),
516+
null,
517+
tenantId
518+
);
519+
client.execute(MLPredictionTaskAction.INSTANCE, request, (ActionListener<MLTaskResponse>) nextStepListener);
493520
}
494521
}, e -> {
495522
log.error("Failed to run chat agent", e);
496523
listener.onFailure(e);
497524
});
498-
if (i < maxIterations - 1) {
525+
if (nextStepListener != null) {
499526
lastStepListener = nextStepListener;
500527
}
501528
}
@@ -813,6 +840,40 @@ private static void returnFinalResponse(
813840
}
814841
}
815842

843+
private void handleMaxIterationsReached(
844+
String sessionId,
845+
ActionListener<Object> listener,
846+
String question,
847+
String parentInteractionId,
848+
boolean verbose,
849+
boolean traceDisabled,
850+
List<ModelTensors> traceTensors,
851+
ConversationIndexMemory conversationIndexMemory,
852+
AtomicInteger traceNumber,
853+
Map<String, Object> additionalInfo,
854+
AtomicReference<String> lastThought,
855+
int maxIterations,
856+
Map<String, Tool> tools
857+
) {
858+
String incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get()))
859+
? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get())
860+
: String.format(MAX_ITERATIONS_MESSAGE, maxIterations);
861+
sendFinalAnswer(
862+
sessionId,
863+
listener,
864+
question,
865+
parentInteractionId,
866+
verbose,
867+
traceDisabled,
868+
traceTensors,
869+
conversationIndexMemory,
870+
traceNumber,
871+
additionalInfo,
872+
incompleteResponse
873+
);
874+
cleanUpResource(tools);
875+
}
876+
816877
private void saveMessage(
817878
ConversationIndexMemory memory,
818879
String question,

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,10 +345,10 @@ private void executePlanningLoop(
345345
saveAndReturnFinalResult(
346346
(ConversationIndexMemory) memory,
347347
parentInteractionId,
348-
finalResult,
349-
completedSteps.get(completedSteps.size() - 2),
350348
allParams.get(EXECUTOR_AGENT_MEMORY_ID_FIELD),
351349
allParams.get(EXECUTOR_AGENT_PARENT_INTERACTION_ID_FIELD),
350+
finalResult,
351+
null,
352352
finalListener
353353
);
354354
return;

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,4 +977,84 @@ private Answer generateToolFailure(Exception e) {
977977
};
978978
}
979979

980+
@Test
981+
public void testMaxIterationsReached() {
982+
// Create LLM spec with max_iteration = 1 to force max iterations
983+
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build();
984+
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
985+
final MLAgent mlAgent = MLAgent
986+
.builder()
987+
.name("TestAgent")
988+
.type(MLAgentType.CONVERSATIONAL.name())
989+
.llm(llmSpec)
990+
.memory(mlMemorySpec)
991+
.tools(Arrays.asList(firstToolSpec))
992+
.build();
993+
994+
// Mock LLM response that doesn't contain final_answer to force max iterations
995+
Mockito
996+
.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "", "action", FIRST_TOOL)))
997+
.when(client)
998+
.execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));
999+
1000+
Map<String, String> params = new HashMap<>();
1001+
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
1002+
1003+
mlChatAgentRunner.run(mlAgent, params, agentActionListener);
1004+
1005+
// Verify response is captured
1006+
verify(agentActionListener).onResponse(objectCaptor.capture());
1007+
Object capturedResponse = objectCaptor.getValue();
1008+
assertTrue(capturedResponse instanceof ModelTensorOutput);
1009+
1010+
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
1011+
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors();
1012+
assertEquals(1, agentOutput.size());
1013+
1014+
// Verify the response contains max iterations message
1015+
String response = (String) agentOutput.get(0).getDataAsMap().get("response");
1016+
assertEquals("Agent reached maximum iterations (1) without completing the task", response);
1017+
}
1018+
1019+
@Test
1020+
public void testMaxIterationsReachedWithValidThought() {
1021+
// Create LLM spec with max_iteration = 1 to force max iterations
1022+
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build();
1023+
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
1024+
final MLAgent mlAgent = MLAgent
1025+
.builder()
1026+
.name("TestAgent")
1027+
.type(MLAgentType.CONVERSATIONAL.name())
1028+
.llm(llmSpec)
1029+
.memory(mlMemorySpec)
1030+
.tools(Arrays.asList(firstToolSpec))
1031+
.build();
1032+
1033+
// Mock LLM response with valid thought
1034+
Mockito
1035+
.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to use the first tool", "action", FIRST_TOOL)))
1036+
.when(client)
1037+
.execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));
1038+
1039+
Map<String, String> params = new HashMap<>();
1040+
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
1041+
1042+
mlChatAgentRunner.run(mlAgent, params, agentActionListener);
1043+
1044+
// Verify response is captured
1045+
verify(agentActionListener).onResponse(objectCaptor.capture());
1046+
Object capturedResponse = objectCaptor.getValue();
1047+
assertTrue(capturedResponse instanceof ModelTensorOutput);
1048+
1049+
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
1050+
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors();
1051+
assertEquals(1, agentOutput.size());
1052+
1053+
// Verify the response contains the last valid thought instead of max iterations message
1054+
String response = (String) agentOutput.get(0).getDataAsMap().get("response");
1055+
assertEquals(
1056+
"Agent reached maximum iterations (1) without completing the task. Last thought: I need to use the first tool",
1057+
response
1058+
);
1059+
}
9801060
}

0 commit comments

Comments
 (0)