Skip to content

Commit 08b90b9

Browse files
committed
feat: add verbose_filter feature and expose verbose parameter to per agent
Signed-off-by: Pavan Yekbote <pybot@amazon.com>
1 parent 5964268 commit 08b90b9

File tree

7 files changed

+316
-32
lines changed

7 files changed

+316
-32
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ public class MLChatAgentRunner implements MLAgentRunner {
122122
public static final String INJECT_DATETIME_FIELD = "inject_datetime";
123123
public static final String DATETIME_FORMAT_FIELD = "datetime_format";
124124
public static final String SYSTEM_PROMPT_FIELD = "system_prompt";
125+
public static final String VERBOSE_FILTER = "verbose_filter";
125126

126127
private static final String DEFAULT_MAX_ITERATIONS = "10";
127128
private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task";
@@ -300,6 +301,7 @@ private void runReAct(
300301
String parentInteractionId = tmpParameters.get(MLAgentExecutor.PARENT_INTERACTION_ID);
301302
boolean verbose = Boolean.parseBoolean(tmpParameters.getOrDefault(VERBOSE, "false"));
302303
boolean traceDisabled = tmpParameters.containsKey(DISABLE_TRACE) && Boolean.parseBoolean(tmpParameters.get(DISABLE_TRACE));
304+
List<String> traceFilter = parseTraceFilter(tmpParameters.get(VERBOSE_FILTER));
303305

304306
// Create root interaction.
305307
ConversationIndexMemory conversationIndexMemory = (ConversationIndexMemory) memory;
@@ -379,13 +381,15 @@ private void runReAct(
379381
lastActionInput.set(actionInput);
380382
lastToolSelectionResponse.set(thoughtResponse);
381383

382-
traceTensors
383-
.add(
384-
ModelTensors
385-
.builder()
386-
.mlModelTensors(List.of(ModelTensor.builder().name("response").result(thoughtResponse).build()))
387-
.build()
388-
);
384+
if (shouldIncludeInTrace("LLM", traceFilter)) {
385+
traceTensors
386+
.add(
387+
ModelTensors
388+
.builder()
389+
.mlModelTensors(List.of(ModelTensor.builder().name("response").result(thoughtResponse).build()))
390+
.build()
391+
);
392+
}
389393

390394
saveTraceData(
391395
conversationIndexMemory,
@@ -487,18 +491,20 @@ private void runReAct(
487491

488492
sessionMsgAnswerBuilder.append(outputToOutputString(filteredOutput));
489493
streamingWrapper.sendToolResponse(outputToOutputString(output), sessionId, parentInteractionId);
490-
traceTensors
491-
.add(
492-
ModelTensors
493-
.builder()
494-
.mlModelTensors(
495-
Collections
496-
.singletonList(
497-
ModelTensor.builder().name("response").result(sessionMsgAnswerBuilder.toString()).build()
498-
)
499-
)
500-
.build()
501-
);
494+
if (shouldIncludeInTrace(lastAction.get(), traceFilter)) {
495+
traceTensors
496+
.add(
497+
ModelTensors
498+
.builder()
499+
.mlModelTensors(
500+
Collections
501+
.singletonList(
502+
ModelTensor.builder().name("response").result(sessionMsgAnswerBuilder.toString()).build()
503+
)
504+
)
505+
.build()
506+
);
507+
}
502508

503509
if (finalI == maxIterations - 1) {
504510
handleMaxIterationsReached(
@@ -842,6 +848,21 @@ static Map<String, String> constructLLMParams(LLMSpec llm, Map<String, String> p
842848
return tmpParameters;
843849
}
844850

851+
private static List<String> parseTraceFilter(String traceFilterParam) {
852+
if (traceFilterParam == null || traceFilterParam.trim().isEmpty()) {
853+
return null;
854+
}
855+
return List.of(traceFilterParam.split(","));
856+
}
857+
858+
private static boolean shouldIncludeInTrace(String toolName, List<String> traceFilter) {
859+
if (traceFilter == null) {
860+
return true;
861+
}
862+
863+
return traceFilter.contains(toolName);
864+
}
865+
845866
public static void returnFinalResponse(
846867
String sessionId,
847868
ActionListener<Object> listener,

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ public class MLPlanExecuteAndReflectAgentRunner implements MLAgentRunner {
154154
public static final String INJECT_DATETIME_FIELD = "inject_datetime";
155155
public static final String DATETIME_FORMAT_FIELD = "datetime_format";
156156

157+
public static final String EXECUTOR_VERBOSE = "executor_verbose";
158+
public static final String EXECUTOR_VERBOSE_FILTER = "executor_verbose_filter";
159+
157160
public MLPlanExecuteAndReflectAgentRunner(
158161
Client client,
159162
Settings settings,
@@ -435,6 +438,15 @@ private void executePlanningLoop(
435438
allParams.getOrDefault(EXECUTOR_MESSAGE_HISTORY_LIMIT, DEFAULT_EXECUTOR_MESSAGE_HISTORY_LIMIT)
436439
);
437440

441+
// Pass through verbose and verbose_filter if provided
442+
if (allParams.containsKey(EXECUTOR_VERBOSE)) {
443+
reactParams.put(AgentUtils.VERBOSE, allParams.get(EXECUTOR_VERBOSE));
444+
}
445+
446+
if (allParams.containsKey(EXECUTOR_VERBOSE_FILTER)) {
447+
reactParams.put(MLChatAgentRunner.VERBOSE_FILTER, allParams.get(EXECUTOR_VERBOSE_FILTER));
448+
}
449+
438450
AgentMLInput agentInput = AgentMLInput
439451
.AgentMLInputBuilder()
440452
.agentId(reActAgentId)
@@ -449,8 +461,9 @@ private void executePlanningLoop(
449461

450462
// Navigate through the structure to get the response
451463
Map<String, String> results = new HashMap<>();
464+
List<String> allResponses = new ArrayList<>();
452465

453-
// Process tensors in a single stream
466+
// Process tensors to collect all responses
454467
reactResult.getMlModelOutputs().stream().flatMap(output -> output.getMlModelTensors().stream()).forEach(tensor -> {
455468
switch (tensor.getName()) {
456469
case MEMORY_ID_FIELD:
@@ -459,14 +472,35 @@ private void executePlanningLoop(
459472
case PARENT_INTERACTION_ID_FIELD:
460473
results.put(PARENT_INTERACTION_ID_FIELD, tensor.getResult());
461474
break;
462-
default:
463-
Map<String, ?> dataMap = tensor.getDataAsMap();
464-
if (dataMap != null && dataMap.containsKey(RESPONSE_FIELD)) {
465-
results.put(STEP_RESULT_FIELD, (String) dataMap.get(RESPONSE_FIELD));
475+
case RESPONSE_FIELD:
476+
if (tensor.getResult() != null) {
477+
allResponses.add(tensor.getResult());
478+
} else {
479+
Map<String, ?> dataMap = tensor.getDataAsMap();
480+
if (dataMap != null && dataMap.containsKey(RESPONSE_FIELD)) {
481+
allResponses.add((String) dataMap.get(RESPONSE_FIELD));
482+
}
466483
}
467484
}
468485
});
469486

487+
if (!allResponses.isEmpty()) {
488+
StringBuilder stepResult = new StringBuilder();
489+
stepResult.append(allResponses.getLast());
490+
if (allResponses.size() > 1) {
491+
stepResult.append("\n\n<step-traces>");
492+
}
493+
494+
for (int i = 0; i < allResponses.size() - 1; i++) {
495+
stepResult.append("\n\n").append(allResponses.get(i));
496+
if (i == allResponses.size() - 2) {
497+
stepResult.append("\n</step-traces>");
498+
}
499+
}
500+
501+
results.put(STEP_RESULT_FIELD, stepResult.toString());
502+
}
503+
470504
if (!results.containsKey(STEP_RESULT_FIELD)) {
471505
throw new IllegalStateException("No valid response found in ReAct agent output");
472506
}
@@ -502,8 +536,17 @@ private void executePlanningLoop(
502536
}, e -> log.error("Failed to update task {} with executor memory ID", taskId, e)));
503537
}
504538

505-
completedSteps.add(String.format("\nStep %d: %s\n", stepsExecuted + 1, stepToExecute));
506-
completedSteps.add(String.format("\nStep %d Result: %s\n", stepsExecuted + 1, results.get(STEP_RESULT_FIELD)));
539+
completedSteps.add(String.format("\n<step-%d>\n%s\n</step-%d>\n", stepsExecuted + 1, stepToExecute, stepsExecuted + 1));
540+
completedSteps
541+
.add(
542+
String
543+
.format(
544+
"\n<step-%d-result>\n%s\n</step-%d-result>\n",
545+
stepsExecuted + 1,
546+
results.get(STEP_RESULT_FIELD),
547+
stepsExecuted + 1
548+
)
549+
);
507550

508551
saveTraceData(
509552
(ConversationIndexMemory) memory,

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ public class PromptTemplate {
2828
+ "${parameters."
2929
+ PLANNER_PROMPT_FIELD
3030
+ "} \n"
31-
+ "Objective: ${parameters."
31+
+ "Objective: ```${parameters."
3232
+ USER_PROMPT_FIELD
33-
+ "} \n\nRemember: Respond only in JSON format following the required schema.";
33+
+ "}``` \n\nRemember: Respond only in JSON format following the required schema.";
3434

3535
public static final String DEFAULT_REFLECT_PROMPT_TEMPLATE = "${parameters."
3636
+ DEFAULT_PROMPT_TOOLS_FIELD
@@ -41,10 +41,10 @@ public class PromptTemplate {
4141
+ "Objective: ```${parameters."
4242
+ USER_PROMPT_FIELD
4343
+ "}```\n\n"
44-
+ "Original plan:\n[${parameters."
44+
+ "Previous plan:\n[${parameters."
4545
+ STEPS_FIELD
4646
+ "}] \n\n"
47-
+ "You have currently executed the following steps from the original plan: \n[${parameters."
47+
+ "You have currently executed the following steps: \n[${parameters."
4848
+ COMPLETED_STEPS_FIELD
4949
+ "}] \n\n"
5050
+ "${parameters."

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ public void invokeRemoteService(
106106
SdkHttpFullRequest request;
107107
switch (connector.getActionHttpMethod(action).toUpperCase(Locale.ROOT)) {
108108
case "POST":
109-
log.debug("original payload to remote model: " + payload);
109+
log.info("\n\n\noriginal payload to remote model: " + payload);
110110
request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, POST);
111111
break;
112112
case "GET":

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ public void invokeRemoteService(
105105
SdkHttpFullRequest request;
106106
switch (connector.getActionHttpMethod(action).toUpperCase(Locale.ROOT)) {
107107
case "POST":
108-
log.debug("original payload to remote model: " + payload);
108+
log.info("\n\n\noriginal payload to remote model: " + payload);
109109
request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, POST);
110110
break;
111111
case "GET":

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,4 +1171,96 @@ public void testConstructLLMParams_DefaultValues() {
11711171
Assert.assertTrue(result.containsKey(AgentUtils.RESPONSE_FORMAT_INSTRUCTION));
11721172
Assert.assertTrue(result.containsKey(AgentUtils.TOOL_RESPONSE));
11731173
}
1174+
1175+
@Test
1176+
public void testVerboseFilterWithSpecificFields() {
1177+
// Create an MLAgent and run with verbose_filter
1178+
MLAgent mlAgent = createMLAgentWithTools();
1179+
Map<String, String> params = new HashMap<>();
1180+
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
1181+
params.put("verbose", "true");
1182+
params.put("verbose_filter", "firstTool");
1183+
1184+
mlChatAgentRunner.run(mlAgent, params, agentActionListener, null);
1185+
1186+
// Capture the response
1187+
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
1188+
verify(agentActionListener).onResponse(responseCaptor.capture());
1189+
1190+
Object capturedResponse = responseCaptor.getValue();
1191+
assertTrue(capturedResponse instanceof ModelTensorOutput);
1192+
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
1193+
1194+
// Count response fields across all outputs
1195+
int responseFieldCount = 0;
1196+
for (ModelTensors output : modelTensorOutput.getMlModelOutputs()) {
1197+
for (ModelTensor tensor : output.getMlModelTensors()) {
1198+
if ("response".equals(tensor.getName())) {
1199+
responseFieldCount++;
1200+
}
1201+
}
1202+
}
1203+
1204+
// Verify there is more than one response field
1205+
assertEquals(2, responseFieldCount);
1206+
}
1207+
1208+
@Test
1209+
public void testVerboseFilterWithInvalidPath() {
1210+
// Create an MLAgent and run with invalid verbose_filter
1211+
MLAgent mlAgent = createMLAgentWithTools();
1212+
Map<String, String> params = new HashMap<>();
1213+
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
1214+
params.put("verbose", "true");
1215+
params.put("verbose_filter", "RandomTool");
1216+
1217+
mlChatAgentRunner.run(mlAgent, params, agentActionListener, null);
1218+
1219+
// Should still work but filter nothing
1220+
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
1221+
verify(agentActionListener).onResponse(responseCaptor.capture());
1222+
1223+
Object capturedResponse = responseCaptor.getValue();
1224+
assertTrue(capturedResponse instanceof ModelTensorOutput);
1225+
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
1226+
int responseFieldCount = 0;
1227+
for (ModelTensors output : modelTensorOutput.getMlModelOutputs()) {
1228+
for (ModelTensor tensor : output.getMlModelTensors()) {
1229+
if ("response".equals(tensor.getName())) {
1230+
responseFieldCount++;
1231+
}
1232+
}
1233+
}
1234+
1235+
assertEquals(1, responseFieldCount);
1236+
}
1237+
1238+
@Test
1239+
public void testVerboseFilterWithoutVerbose() {
1240+
// Create an MLAgent and run with verbose_filter but verbose=false
1241+
MLAgent mlAgent = createMLAgentWithTools();
1242+
Map<String, String> params = new HashMap<>();
1243+
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
1244+
params.put("verbose", "false");
1245+
1246+
mlChatAgentRunner.run(mlAgent, params, agentActionListener, null);
1247+
1248+
// verbose_filter should be ignored when verbose=false
1249+
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
1250+
verify(agentActionListener).onResponse(responseCaptor.capture());
1251+
1252+
Object capturedResponse = responseCaptor.getValue();
1253+
assertTrue(capturedResponse instanceof ModelTensorOutput);
1254+
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
1255+
int responseFieldCount = 0;
1256+
for (ModelTensors output : modelTensorOutput.getMlModelOutputs()) {
1257+
for (ModelTensor tensor : output.getMlModelTensors()) {
1258+
if ("response".equals(tensor.getName())) {
1259+
responseFieldCount++;
1260+
}
1261+
}
1262+
}
1263+
1264+
assertEquals(1, responseFieldCount);
1265+
}
11741266
}

0 commit comments

Comments
 (0)