Merged
@@ -125,6 +125,7 @@ public class MLChatAgentRunner implements MLAgentRunner {
public static final String SYSTEM_PROMPT_FIELD = "system_prompt";

private static final String DEFAULT_MAX_ITERATIONS = "10";
Collaborator: Can customers configure this value?

Collaborator (author): Yes, the chat agent has a max_iteration parameter.

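For reference, a minimal sketch of supplying that parameter through the agent's LLM spec, mirroring the test setup further down in this diff; the value "20" and the import path are illustrative assumptions:

```java
import java.util.Map;

import org.opensearch.ml.common.agent.LLMSpec; // assumed package from the ml-commons common module

// Sketch: a "max_iteration" entry in the LLM parameters is picked up by the
// ReAct loop and overrides DEFAULT_MAX_ITERATIONS ("10").
LLMSpec llmSpec = LLMSpec
    .builder()
    .modelId("MODEL_ID")
    .parameters(Map.of("max_iteration", "20")) // illustrative value
    .build();
```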
private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task";

private Client client;
private Settings settings;
@@ -327,7 +328,7 @@ private void runReAct(
int maxIterations = Integer.parseInt(tmpParameters.getOrDefault(MAX_ITERATION, DEFAULT_MAX_ITERATIONS));
for (int i = 0; i < maxIterations; i++) {
int finalI = i;
StepListener<?> nextStepListener = new StepListener<>();
StepListener<?> nextStepListener = (i == maxIterations - 1) ? null : new StepListener<>();

lastStepListener.whenComplete(output -> {
StringBuilder sessionMsgAnswerBuilder = new StringBuilder();
@@ -396,6 +397,25 @@ private void runReAct(
"LLM"
);

if (nextStepListener == null) {
handleMaxIterationsReached(
sessionId,
listener,
question,
parentInteractionId,
verbose,
traceDisabled,
traceTensors,
conversationIndexMemory,
traceNumber,
additionalInfo,
lastThought,
maxIterations,
tools
);
return;
}

if (tools.containsKey(action)) {
Map<String, String> toolParams = constructToolParams(
tools,
@@ -455,7 +475,7 @@ private void runReAct(
StringSubstitutor substitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder), "${parameters.", "}");
newPrompt.set(substitutor.replace(finalPrompt));
tmpParameters.put(PROMPT, newPrompt.get());
if (interactions.size() > 0) {
if (!interactions.isEmpty()) {
tmpParameters.put(INTERACTIONS, ", " + String.join(", ", interactions));
}

@@ -474,34 +494,41 @@
);

if (finalI == maxIterations - 1) {
if (verbose) {
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(traceTensors).build());
} else {
List<ModelTensors> finalModelTensors = createFinalAnswerTensors(
createModelTensors(sessionId, parentInteractionId),
List.of(ModelTensor.builder().name("response").dataAsMap(Map.of("response", lastThought.get())).build())
);
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build());
}
} else {
ActionRequest request = new MLPredictionTaskRequest(
llm.getModelId(),
RemoteInferenceMLInput
.builder()
.algorithm(FunctionName.REMOTE)
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build())
.build(),
null,
tenantId
handleMaxIterationsReached(
sessionId,
listener,
question,
parentInteractionId,
verbose,
traceDisabled,
traceTensors,
conversationIndexMemory,
traceNumber,
additionalInfo,
lastThought,
maxIterations,
tools
);
client.execute(MLPredictionTaskAction.INSTANCE, request, (ActionListener<MLTaskResponse>) nextStepListener);
return;
}

ActionRequest request = new MLPredictionTaskRequest(
llm.getModelId(),
RemoteInferenceMLInput
.builder()
.algorithm(FunctionName.REMOTE)
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build())
.build(),
null,
tenantId
);
client.execute(MLPredictionTaskAction.INSTANCE, request, (ActionListener<MLTaskResponse>) nextStepListener);
}
}, e -> {
log.error("Failed to run chat agent", e);
listener.onFailure(e);
});
if (i < maxIterations - 1) {
if (nextStepListener != null) {
lastStepListener = nextStepListener;
}
}
@@ -837,6 +864,40 @@ private static void returnFinalResponse(
}
}

private void handleMaxIterationsReached(
String sessionId,
ActionListener<Object> listener,
String question,
String parentInteractionId,
boolean verbose,
boolean traceDisabled,
List<ModelTensors> traceTensors,
ConversationIndexMemory conversationIndexMemory,
AtomicInteger traceNumber,
Map<String, Object> additionalInfo,
AtomicReference<String> lastThought,
int maxIterations,
Map<String, Tool> tools
) {
String incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get()))
? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get())
: String.format(MAX_ITERATIONS_MESSAGE, maxIterations);
sendFinalAnswer(
sessionId,
listener,
question,
parentInteractionId,
verbose,
traceDisabled,
traceTensors,
conversationIndexMemory,
traceNumber,
additionalInfo,
incompleteResponse
);
cleanUpResource(tools);
}

private void saveMessage(
ConversationIndexMemory memory,
String question,
@@ -364,10 +364,10 @@ private void executePlanningLoop(
saveAndReturnFinalResult(
(ConversationIndexMemory) memory,
parentInteractionId,
finalResult,
completedSteps.get(completedSteps.size() - 2),
allParams.get(EXECUTOR_AGENT_MEMORY_ID_FIELD),
allParams.get(EXECUTOR_AGENT_PARENT_INTERACTION_ID_FIELD),
finalResult,
null,
finalListener
);
return;
@@ -978,6 +978,87 @@ private Answer generateToolFailure(Exception e) {
};
}

@Test
public void testMaxIterationsReached() {
// Create LLM spec with max_iteration = 1 to force max iterations
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build();
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
final MLAgent mlAgent = MLAgent
.builder()
.name("TestAgent")
.type(MLAgentType.CONVERSATIONAL.name())
.llm(llmSpec)
.memory(mlMemorySpec)
.tools(Arrays.asList(firstToolSpec))
.build();

// Mock LLM response that doesn't contain final_answer to force max iterations
Mockito
.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "", "action", FIRST_TOOL)))
.when(client)
.execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));

Map<String, String> params = new HashMap<>();
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");

mlChatAgentRunner.run(mlAgent, params, agentActionListener);

// Verify response is captured
verify(agentActionListener).onResponse(objectCaptor.capture());
Object capturedResponse = objectCaptor.getValue();
assertTrue(capturedResponse instanceof ModelTensorOutput);

ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors();
assertEquals(1, agentOutput.size());

// Verify the response contains max iterations message
String response = (String) agentOutput.get(0).getDataAsMap().get("response");
assertEquals("Agent reached maximum iterations (1) without completing the task", response);
}

@Test
public void testMaxIterationsReachedWithValidThought() {
// Create LLM spec with max_iteration = 1 to force max iterations
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build();
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
final MLAgent mlAgent = MLAgent
.builder()
.name("TestAgent")
.type(MLAgentType.CONVERSATIONAL.name())
.llm(llmSpec)
.memory(mlMemorySpec)
.tools(Arrays.asList(firstToolSpec))
.build();

// Mock LLM response with valid thought
Mockito
.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to use the first tool", "action", FIRST_TOOL)))
.when(client)
.execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));

Map<String, String> params = new HashMap<>();
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");

mlChatAgentRunner.run(mlAgent, params, agentActionListener);

// Verify response is captured
verify(agentActionListener).onResponse(objectCaptor.capture());
Object capturedResponse = objectCaptor.getValue();
assertTrue(capturedResponse instanceof ModelTensorOutput);

ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors();
assertEquals(1, agentOutput.size());

// Verify the response contains the last valid thought instead of max iterations message
String response = (String) agentOutput.get(0).getDataAsMap().get("response");
assertEquals(
"Agent reached maximum iterations (1) without completing the task. Last thought: I need to use the first tool",
response
);
}

@Test
public void testConstructLLMParams_WithSystemPromptAndDateTimeInjection() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();