Skip to content

Commit

Permalink
add memory id and interation id for non-verbose (#2004)
Browse files Browse the repository at this point in the history
Signed-off-by: Jing Zhang <jngz@amazon.com>
  • Loading branch information
jngz-es authored Feb 3, 2024
1 parent a62ecc1 commit b84b130
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,23 @@ private void runReAct(
);

List<ModelTensors> finalModelTensors = new ArrayList<>();
finalModelTensors
.add(
ModelTensors
.builder()
.mlModelTensors(
List
.of(
ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(),
ModelTensor
.builder()
.name(MLAgentExecutor.PARENT_INTERACTION_ID)
.result(parentInteractionId)
.build()
)
)
.build()
);
finalModelTensors
.add(
ModelTensors
Expand Down Expand Up @@ -603,6 +620,23 @@ private void runReAct(
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build());
} else {
List<ModelTensors> finalModelTensors = new ArrayList<>();
finalModelTensors
.add(
ModelTensors
.builder()
.mlModelTensors(
List
.of(
ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(),
ModelTensor
.builder()
.name(MLAgentExecutor.PARENT_INTERACTION_ID)
.result(parentInteractionId)
.build()
)
)
.build()
);
finalModelTensors
.add(
ModelTensors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,46 @@ public void testParsingJsonBlockFromResponse3() {
assertEquals("parsed final answer", modelTensor2.getResult());
}

@Test
public void testParsingJsonBlockFromResponse4() {
// Prepare the response with JSON block
String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", "
+ "\"action_input\":\"parsed action input\", \"final_answer\":\"parsed final answer\"}";
String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text";

// Mock LLM response to not contain "thought" but contain "response" with JSON block
Map<String, String> llmResponse = new HashMap<>();
llmResponse.put("response", responseWithJsonBlock);
doAnswer(getLLMAnswer(llmResponse))
.when(client)
.execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));

// Create an MLAgent and run the MLChatAgentRunner
MLAgent mlAgent = createMLAgentWithTools();
Map<String, String> params = new HashMap<>();
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
params.put("verbose", "false");
mlChatAgentRunner.run(mlAgent, params, agentActionListener);

// Capture the response passed to the listener
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
verify(agentActionListener).onResponse(responseCaptor.capture());

// Extract the captured response
Object capturedResponse = responseCaptor.getValue();
assertTrue(capturedResponse instanceof ModelTensorOutput);
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;

ModelTensor memoryIdModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0);
ModelTensor parentInteractionModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1);

// Verify that the parsed values from JSON block are correctly set
assertEquals("memory_id", memoryIdModelTensor.getName());
assertEquals("conversation_id", memoryIdModelTensor.getResult());
assertEquals("parent_interaction_id", parentInteractionModelTensor.getName());
assertEquals("parent_interaction_id", parentInteractionModelTensor.getResult());
}

@Test
public void testRunWithIncludeOutputNotSet() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
Expand All @@ -293,7 +333,7 @@ public void testRunWithIncludeOutputNotSet() {
mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener);
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors();
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors();
assertEquals(1, agentOutput.size());
// Respond with last tool output
assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response"));
Expand Down Expand Up @@ -322,7 +362,7 @@ public void testRunWithIncludeOutputMLModel() {
mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener);
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors();
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors();
assertEquals(1, agentOutput.size());
// Respond with last tool output
assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response"));
Expand Down Expand Up @@ -356,7 +396,7 @@ public void testRunWithIncludeOutputSet() {
mlChatAgentRunner.run(mlAgent, params, agentActionListener);
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors();
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors();
assertEquals(1, agentOutput.size());
// Respond with last tool output
assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response"));
Expand Down

0 comments on commit b84b130

Please sign in to comment.