@@ -125,6 +125,7 @@ public class MLChatAgentRunner implements MLAgentRunner {
125125 public static final String SYSTEM_PROMPT_FIELD = "system_prompt" ;
126126
127127 private static final String DEFAULT_MAX_ITERATIONS = "10" ;
128+ private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task" ;
128129
129130 private Client client ;
130131 private Settings settings ;
@@ -327,7 +328,7 @@ private void runReAct(
327328 int maxIterations = Integer .parseInt (tmpParameters .getOrDefault (MAX_ITERATION , DEFAULT_MAX_ITERATIONS ));
328329 for (int i = 0 ; i < maxIterations ; i ++) {
329330 int finalI = i ;
330- StepListener <?> nextStepListener = new StepListener <>();
331+ StepListener <?> nextStepListener = ( i == maxIterations - 1 ) ? null : new StepListener <>();
331332
332333 lastStepListener .whenComplete (output -> {
333334 StringBuilder sessionMsgAnswerBuilder = new StringBuilder ();
@@ -396,6 +397,25 @@ private void runReAct(
396397 "LLM"
397398 );
398399
400+ if (nextStepListener == null ) {
401+ handleMaxIterationsReached (
402+ sessionId ,
403+ listener ,
404+ question ,
405+ parentInteractionId ,
406+ verbose ,
407+ traceDisabled ,
408+ traceTensors ,
409+ conversationIndexMemory ,
410+ traceNumber ,
411+ additionalInfo ,
412+ lastThought ,
413+ maxIterations ,
414+ tools
415+ );
416+ return ;
417+ }
418+
399419 if (tools .containsKey (action )) {
400420 Map <String , String > toolParams = constructToolParams (
401421 tools ,
@@ -455,7 +475,7 @@ private void runReAct(
455475 StringSubstitutor substitutor = new StringSubstitutor (Map .of (SCRATCHPAD , scratchpadBuilder ), "${parameters." , "}" );
456476 newPrompt .set (substitutor .replace (finalPrompt ));
457477 tmpParameters .put (PROMPT , newPrompt .get ());
458- if (interactions .size () > 0 ) {
478+ if (! interactions .isEmpty () ) {
459479 tmpParameters .put (INTERACTIONS , ", " + String .join (", " , interactions ));
460480 }
461481
@@ -474,34 +494,41 @@ private void runReAct(
474494 );
475495
476496 if (finalI == maxIterations - 1 ) {
477- if (verbose ) {
478- listener .onResponse (ModelTensorOutput .builder ().mlModelOutputs (traceTensors ).build ());
479- } else {
480- List <ModelTensors > finalModelTensors = createFinalAnswerTensors (
481- createModelTensors (sessionId , parentInteractionId ),
482- List .of (ModelTensor .builder ().name ("response" ).dataAsMap (Map .of ("response" , lastThought .get ())).build ())
483- );
484- listener .onResponse (ModelTensorOutput .builder ().mlModelOutputs (finalModelTensors ).build ());
485- }
486- } else {
487- ActionRequest request = new MLPredictionTaskRequest (
488- llm .getModelId (),
489- RemoteInferenceMLInput
490- .builder ()
491- .algorithm (FunctionName .REMOTE )
492- .inputDataset (RemoteInferenceInputDataSet .builder ().parameters (tmpParameters ).build ())
493- .build (),
494- null ,
495- tenantId
497+ handleMaxIterationsReached (
498+ sessionId ,
499+ listener ,
500+ question ,
501+ parentInteractionId ,
502+ verbose ,
503+ traceDisabled ,
504+ traceTensors ,
505+ conversationIndexMemory ,
506+ traceNumber ,
507+ additionalInfo ,
508+ lastThought ,
509+ maxIterations ,
510+ tools
496511 );
497- client . execute ( MLPredictionTaskAction . INSTANCE , request , ( ActionListener < MLTaskResponse >) nextStepListener ) ;
512+ return ;
498513 }
514+
515+ ActionRequest request = new MLPredictionTaskRequest (
516+ llm .getModelId (),
517+ RemoteInferenceMLInput
518+ .builder ()
519+ .algorithm (FunctionName .REMOTE )
520+ .inputDataset (RemoteInferenceInputDataSet .builder ().parameters (tmpParameters ).build ())
521+ .build (),
522+ null ,
523+ tenantId
524+ );
525+ client .execute (MLPredictionTaskAction .INSTANCE , request , (ActionListener <MLTaskResponse >) nextStepListener );
499526 }
500527 }, e -> {
501528 log .error ("Failed to run chat agent" , e );
502529 listener .onFailure (e );
503530 });
504- if (i < maxIterations - 1 ) {
531+ if (nextStepListener != null ) {
505532 lastStepListener = nextStepListener ;
506533 }
507534 }
@@ -837,6 +864,40 @@ private static void returnFinalResponse(
837864 }
838865 }
839866
867+ private void handleMaxIterationsReached (
868+ String sessionId ,
869+ ActionListener <Object > listener ,
870+ String question ,
871+ String parentInteractionId ,
872+ boolean verbose ,
873+ boolean traceDisabled ,
874+ List <ModelTensors > traceTensors ,
875+ ConversationIndexMemory conversationIndexMemory ,
876+ AtomicInteger traceNumber ,
877+ Map <String , Object > additionalInfo ,
878+ AtomicReference <String > lastThought ,
879+ int maxIterations ,
880+ Map <String , Tool > tools
881+ ) {
882+ String incompleteResponse = (lastThought .get () != null && !lastThought .get ().isEmpty () && !"null" .equals (lastThought .get ()))
883+ ? String .format ("%s. Last thought: %s" , String .format (MAX_ITERATIONS_MESSAGE , maxIterations ), lastThought .get ())
884+ : String .format (MAX_ITERATIONS_MESSAGE , maxIterations );
885+ sendFinalAnswer (
886+ sessionId ,
887+ listener ,
888+ question ,
889+ parentInteractionId ,
890+ verbose ,
891+ traceDisabled ,
892+ traceTensors ,
893+ conversationIndexMemory ,
894+ traceNumber ,
895+ additionalInfo ,
896+ incompleteResponse
897+ );
898+ cleanUpResource (tools );
899+ }
900+
840901 private void saveMessage (
841902 ConversationIndexMemory memory ,
842903 String question ,
0 commit comments