Skip to content

Commit 5a7f4cf

Browse files
authored
Update interaction with failure message on agent execution failure (#4198)
* fix: update message with agent failure message Signed-off-by: Pavan Yekbote <pybot@amazon.com> * fix test cases Signed-off-by: Pavan Yekbote <pybot@amazon.com> * add more tests Signed-off-by: Pavan Yekbote <pybot@amazon.com> * add more test cases Signed-off-by: Pavan Yekbote <pybot@amazon.com> * feat: add more test cases Signed-off-by: Pavan Yekbote <pybot@amazon.com> * fix: return error if memory not used Signed-off-by: Pavan Yekbote <pybot@amazon.com> * fix: test case Signed-off-by: Pavan Yekbote <pybot@amazon.com> --------- Signed-off-by: Pavan Yekbote <pybot@amazon.com>
1 parent 97ad298 commit 5a7f4cf

File tree

2 files changed

+699
-9
lines changed

2 files changed

+699
-9
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java

Lines changed: 107 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,12 @@ public void execute(Input input, ActionListener<Output> listener) {
222222
String appType = mlAgent.getAppType();
223223
String question = inputDataSet.getParameters().get(QUESTION);
224224

225+
if (parentInteractionId != null && regenerateInteractionId != null) {
226+
throw new IllegalArgumentException(
227+
"Provide either `parent_interaction_id` to update an existing interaction, or `regenerate_interaction_id` to create a new one."
228+
);
229+
}
230+
225231
MLTask mlTask = MLTask
226232
.builder()
227233
.taskType(MLTaskType.AGENT_EXECUTION)
@@ -289,7 +295,52 @@ public void execute(Input input, ActionListener<Output> listener) {
289295
listener.onFailure(ex);
290296
}));
291297
} else {
292-
executeAgent(inputDataSet, mlTask, isAsync, memoryId, mlAgent, outputs, modelTensors, listener);
298+
// For existing conversations, create memory instance using factory
299+
if (memorySpec != null && memorySpec.getType() != null) {
300+
ConversationIndexMemory.Factory factory = (ConversationIndexMemory.Factory) memoryFactoryMap
301+
.get(memorySpec.getType());
302+
if (factory != null) {
303+
// memoryId exists, so create returns an object with existing memory, therefore name can
304+
// be null
305+
factory
306+
.create(
307+
null,
308+
memoryId,
309+
appType,
310+
ActionListener
311+
.wrap(
312+
createdMemory -> executeAgent(
313+
inputDataSet,
314+
mlTask,
315+
isAsync,
316+
memoryId,
317+
mlAgent,
318+
outputs,
319+
modelTensors,
320+
listener,
321+
createdMemory
322+
),
323+
ex -> {
324+
log.error("Failed to find memory with memory_id: {}", memoryId, ex);
325+
listener.onFailure(ex);
326+
}
327+
)
328+
);
329+
return;
330+
}
331+
}
332+
333+
executeAgent(
334+
inputDataSet,
335+
mlTask,
336+
isAsync,
337+
memoryId,
338+
mlAgent,
339+
outputs,
340+
modelTensors,
341+
listener,
342+
null
343+
);
293344
}
294345
} catch (Exception e) {
295346
log.error("Failed to parse ml agent {}", agentId, e);
@@ -364,7 +415,8 @@ private void saveRootInteractionAndExecute(
364415
mlAgent,
365416
outputs,
366417
modelTensors,
367-
listener
418+
listener,
419+
memory
368420
),
369421
e -> {
370422
log.error("Failed to regenerate for interaction {}", regenerateInteractionId, e);
@@ -373,7 +425,7 @@ private void saveRootInteractionAndExecute(
373425
)
374426
);
375427
} else {
376-
executeAgent(inputDataSet, mlTask, isAsync, memory.getConversationId(), mlAgent, outputs, modelTensors, listener);
428+
executeAgent(inputDataSet, mlTask, isAsync, memory.getConversationId(), mlAgent, outputs, modelTensors, listener, memory);
377429
}
378430
}, ex -> {
379431
log.error("Failed to create parent interaction", ex);
@@ -389,7 +441,8 @@ private void executeAgent(
389441
MLAgent mlAgent,
390442
List<ModelTensors> outputs,
391443
List<ModelTensor> modelTensors,
392-
ActionListener<Output> listener
444+
ActionListener<Output> listener,
445+
ConversationIndexMemory memory
393446
) {
394447
String mcpConnectorConfigJSON = (mlAgent.getParameters() != null) ? mlAgent.getParameters().get(MCP_CONNECTORS_FIELD) : null;
395448
if (mcpConnectorConfigJSON != null && !mlFeatureEnabledSetting.isMcpConnectorEnabled()) {
@@ -408,14 +461,15 @@ private void executeAgent(
408461
}
409462

410463
MLAgentRunner mlAgentRunner = getAgentRunner(mlAgent);
464+
String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID);
465+
411466
// If async is true, index ML task and return the taskID. Also add memoryID to the task if it exists
412467
if (isAsync) {
413468
Map<String, Object> agentResponse = new HashMap<>();
414469
if (memoryId != null && !memoryId.isEmpty()) {
415470
agentResponse.put(MEMORY_ID, memoryId);
416471
}
417472

418-
String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID);
419473
if (parentInteractionId != null && !parentInteractionId.isEmpty()) {
420474
agentResponse.put(PARENT_INTERACTION_ID, parentInteractionId);
421475
}
@@ -432,15 +486,28 @@ private void executeAgent(
432486
outputBuilder.setResponse(agentResponse);
433487
}
434488
listener.onResponse(outputBuilder);
435-
ActionListener<Object> agentActionListener = createAsyncTaskUpdater(mlTask, outputs, modelTensors);
489+
ActionListener<Object> agentActionListener = createAsyncTaskUpdater(
490+
mlTask,
491+
outputs,
492+
modelTensors,
493+
parentInteractionId,
494+
memory
495+
);
436496
inputDataSet.getParameters().put(TASK_ID_FIELD, taskId);
437497
mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener);
438498
}, e -> {
439499
log.error("Failed to create task for agent async execution", e);
440500
listener.onFailure(e);
441501
}));
442502
} else {
443-
ActionListener<Object> agentActionListener = createAgentActionListener(listener, outputs, modelTensors, mlAgent.getType());
503+
ActionListener<Object> agentActionListener = createAgentActionListener(
504+
listener,
505+
outputs,
506+
modelTensors,
507+
mlAgent.getType(),
508+
parentInteractionId,
509+
memory
510+
);
444511
mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener);
445512
}
446513
}
@@ -450,7 +517,9 @@ private ActionListener<Object> createAgentActionListener(
450517
ActionListener<Output> listener,
451518
List<ModelTensors> outputs,
452519
List<ModelTensor> modelTensors,
453-
String agentType
520+
String agentType,
521+
String parentInteractionId,
522+
ConversationIndexMemory memory
454523
) {
455524
return ActionListener.wrap(output -> {
456525
if (output != null) {
@@ -461,11 +530,18 @@ private ActionListener<Object> createAgentActionListener(
461530
}
462531
}, ex -> {
463532
log.error("Failed to run {} agent", agentType, ex);
533+
updateInteractionWithFailure(parentInteractionId, memory, ex.getMessage());
464534
listener.onFailure(ex);
465535
});
466536
}
467537

468-
private ActionListener<Object> createAsyncTaskUpdater(MLTask mlTask, List<ModelTensors> outputs, List<ModelTensor> modelTensors) {
538+
private ActionListener<Object> createAsyncTaskUpdater(
539+
MLTask mlTask,
540+
List<ModelTensors> outputs,
541+
List<ModelTensor> modelTensors,
542+
String parentInteractionId,
543+
ConversationIndexMemory memory
544+
) {
469545
String taskId = mlTask.getTaskId();
470546
Map<String, Object> agentResponse = new HashMap<>();
471547
Map<String, Object> updatedTask = new HashMap<>();
@@ -508,6 +584,8 @@ private ActionListener<Object> createAsyncTaskUpdater(MLTask mlTask, List<ModelT
508584
e -> log.error("Failed to update ML task {} with agent execution results", taskId)
509585
)
510586
);
587+
588+
updateInteractionWithFailure(parentInteractionId, memory, ex.getMessage());
511589
});
512590
}
513591

@@ -616,4 +694,24 @@ public void indexMLTask(MLTask mlTask, ActionListener<IndexResponse> listener) {
616694
listener.onFailure(e);
617695
}
618696
}
697+
698+
private void updateInteractionWithFailure(String interactionId, ConversationIndexMemory memory, String errorMessage) {
699+
if (interactionId != null && memory != null) {
700+
String failureMessage = "Agent execution failed: " + errorMessage;
701+
Map<String, Object> updateContent = new HashMap<>();
702+
updateContent.put(RESPONSE_FIELD, failureMessage);
703+
704+
memory
705+
.getMemoryManager()
706+
.updateInteraction(
707+
interactionId,
708+
updateContent,
709+
ActionListener
710+
.wrap(
711+
res -> log.info("Updated interaction {} with failure message", interactionId),
712+
e -> log.warn("Failed to update interaction {} with failure message", interactionId, e)
713+
)
714+
);
715+
}
716+
}
619717
}

0 commit comments

Comments
 (0)