@@ -119,6 +119,7 @@ public class MLChatAgentRunner implements MLAgentRunner {
119119 public static final String LLM_INTERFACE = "_llm_interface" ;
120120
121121 private static final String DEFAULT_MAX_ITERATIONS = "10" ;
122+ private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task" ;
122123
123124 private Client client ;
124125 private Settings settings ;
@@ -321,7 +322,7 @@ private void runReAct(
321322 int maxIterations = Integer .parseInt (tmpParameters .getOrDefault (MAX_ITERATION , DEFAULT_MAX_ITERATIONS ));
322323 for (int i = 0 ; i < maxIterations ; i ++) {
323324 int finalI = i ;
324- StepListener <?> nextStepListener = new StepListener <>();
325+ StepListener <?> nextStepListener = ( i == maxIterations - 1 ) ? null : new StepListener <>();
325326
326327 lastStepListener .whenComplete (output -> {
327328 StringBuilder sessionMsgAnswerBuilder = new StringBuilder ();
@@ -390,6 +391,25 @@ private void runReAct(
390391 "LLM"
391392 );
392393
394+ if (nextStepListener == null ) {
395+ handleMaxIterationsReached (
396+ sessionId ,
397+ listener ,
398+ question ,
399+ parentInteractionId ,
400+ verbose ,
401+ traceDisabled ,
402+ traceTensors ,
403+ conversationIndexMemory ,
404+ traceNumber ,
405+ additionalInfo ,
406+ lastThought ,
407+ maxIterations ,
408+ tools
409+ );
410+ return ;
411+ }
412+
393413 if (tools .containsKey (action )) {
394414 Map <String , String > toolParams = constructToolParams (
395415 tools ,
@@ -449,7 +469,7 @@ private void runReAct(
449469 StringSubstitutor substitutor = new StringSubstitutor (Map .of (SCRATCHPAD , scratchpadBuilder ), "${parameters." , "}" );
450470 newPrompt .set (substitutor .replace (finalPrompt ));
451471 tmpParameters .put (PROMPT , newPrompt .get ());
452- if (interactions .size () > 0 ) {
472+ if (! interactions .isEmpty () ) {
453473 tmpParameters .put (INTERACTIONS , ", " + String .join (", " , interactions ));
454474 }
455475
@@ -468,34 +488,41 @@ private void runReAct(
468488 );
469489
470490 if (finalI == maxIterations - 1 ) {
471- if (verbose ) {
472- listener .onResponse (ModelTensorOutput .builder ().mlModelOutputs (traceTensors ).build ());
473- } else {
474- List <ModelTensors > finalModelTensors = createFinalAnswerTensors (
475- createModelTensors (sessionId , parentInteractionId ),
476- List .of (ModelTensor .builder ().name ("response" ).dataAsMap (Map .of ("response" , lastThought .get ())).build ())
477- );
478- listener .onResponse (ModelTensorOutput .builder ().mlModelOutputs (finalModelTensors ).build ());
479- }
480- } else {
481- ActionRequest request = new MLPredictionTaskRequest (
482- llm .getModelId (),
483- RemoteInferenceMLInput
484- .builder ()
485- .algorithm (FunctionName .REMOTE )
486- .inputDataset (RemoteInferenceInputDataSet .builder ().parameters (tmpParameters ).build ())
487- .build (),
488- null ,
489- tenantId
491+ handleMaxIterationsReached (
492+ sessionId ,
493+ listener ,
494+ question ,
495+ parentInteractionId ,
496+ verbose ,
497+ traceDisabled ,
498+ traceTensors ,
499+ conversationIndexMemory ,
500+ traceNumber ,
501+ additionalInfo ,
502+ lastThought ,
503+ maxIterations ,
504+ tools
490505 );
491- client . execute ( MLPredictionTaskAction . INSTANCE , request , ( ActionListener < MLTaskResponse >) nextStepListener ) ;
506+ return ;
492507 }
508+
509+ ActionRequest request = new MLPredictionTaskRequest (
510+ llm .getModelId (),
511+ RemoteInferenceMLInput
512+ .builder ()
513+ .algorithm (FunctionName .REMOTE )
514+ .inputDataset (RemoteInferenceInputDataSet .builder ().parameters (tmpParameters ).build ())
515+ .build (),
516+ null ,
517+ tenantId
518+ );
519+ client .execute (MLPredictionTaskAction .INSTANCE , request , (ActionListener <MLTaskResponse >) nextStepListener );
493520 }
494521 }, e -> {
495522 log .error ("Failed to run chat agent" , e );
496523 listener .onFailure (e );
497524 });
498- if (i < maxIterations - 1 ) {
525+ if (nextStepListener != null ) {
499526 lastStepListener = nextStepListener ;
500527 }
501528 }
@@ -813,6 +840,40 @@ private static void returnFinalResponse(
813840 }
814841 }
815842
843+ private void handleMaxIterationsReached (
844+ String sessionId ,
845+ ActionListener <Object > listener ,
846+ String question ,
847+ String parentInteractionId ,
848+ boolean verbose ,
849+ boolean traceDisabled ,
850+ List <ModelTensors > traceTensors ,
851+ ConversationIndexMemory conversationIndexMemory ,
852+ AtomicInteger traceNumber ,
853+ Map <String , Object > additionalInfo ,
854+ AtomicReference <String > lastThought ,
855+ int maxIterations ,
856+ Map <String , Tool > tools
857+ ) {
858+ String incompleteResponse = (lastThought .get () != null && !lastThought .get ().isEmpty () && !"null" .equals (lastThought .get ()))
859+ ? String .format ("%s. Last thought: %s" , String .format (MAX_ITERATIONS_MESSAGE , maxIterations ), lastThought .get ())
860+ : String .format (MAX_ITERATIONS_MESSAGE , maxIterations );
861+ sendFinalAnswer (
862+ sessionId ,
863+ listener ,
864+ question ,
865+ parentInteractionId ,
866+ verbose ,
867+ traceDisabled ,
868+ traceTensors ,
869+ conversationIndexMemory ,
870+ traceNumber ,
871+ additionalInfo ,
872+ incompleteResponse
873+ );
874+ cleanUpResource (tools );
875+ }
876+
816877 private void saveMessage (
817878 ConversationIndexMemory memory ,
818879 String question ,
0 commit comments