|
7 | 7 |
|
8 | 8 | import static org.opensearch.ml.common.MLTask.STATE_FIELD; |
9 | 9 | import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD; |
| 10 | +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; |
10 | 11 | import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD; |
11 | 12 | import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD; |
12 | 13 | import static org.opensearch.ml.common.utils.MLTaskUtils.updateMLTaskDirectly; |
@@ -460,9 +461,9 @@ private void executePlanningLoop( |
460 | 461 | results.put(PARENT_INTERACTION_ID_FIELD, tensor.getResult()); |
461 | 462 | break; |
462 | 463 | default: |
463 | | - Map<String, ?> dataMap = tensor.getDataAsMap(); |
464 | | - if (dataMap != null && dataMap.containsKey(RESPONSE_FIELD)) { |
465 | | - results.put(STEP_RESULT_FIELD, (String) dataMap.get(RESPONSE_FIELD)); |
| 464 | + String stepResult = parseTensorDataMap(tensor); |
| 465 | + if (stepResult != null) { |
| 466 | + results.put(STEP_RESULT_FIELD, stepResult); |
466 | 467 | } |
467 | 468 | } |
468 | 469 | }); |
@@ -502,8 +503,17 @@ private void executePlanningLoop( |
502 | 503 | }, e -> log.error("Failed to update task {} with executor memory ID", taskId, e))); |
503 | 504 | } |
504 | 505 |
|
505 | | - completedSteps.add(String.format("\nStep %d: %s\n", stepsExecuted + 1, stepToExecute)); |
506 | | - completedSteps.add(String.format("\nStep %d Result: %s\n", stepsExecuted + 1, results.get(STEP_RESULT_FIELD))); |
| 506 | + completedSteps.add(String.format("\n<step-%d>\n%s\n</step-%d>\n", stepsExecuted + 1, stepToExecute, stepsExecuted + 1)); |
| 507 | + completedSteps |
| 508 | + .add( |
| 509 | + String |
| 510 | + .format( |
| 511 | + "\n<step-%d-result>\n%s\n</step-%d-result>\n", |
| 512 | + stepsExecuted + 1, |
| 513 | + results.get(STEP_RESULT_FIELD), |
| 514 | + stepsExecuted + 1 |
| 515 | + ) |
| 516 | + ); |
507 | 517 |
|
508 | 518 | saveTraceData( |
509 | 519 | (ConversationIndexMemory) memory, |
@@ -544,6 +554,39 @@ private void executePlanningLoop( |
544 | 554 | client.execute(MLPredictionTaskAction.INSTANCE, request, planListener); |
545 | 555 | } |
546 | 556 |
|
| 557 | + @VisibleForTesting |
| 558 | + String parseTensorDataMap(ModelTensor tensor) { |
| 559 | + Map<String, ?> dataMap = tensor.getDataAsMap(); |
| 560 | + if (dataMap == null) { |
| 561 | + return null; |
| 562 | + } |
| 563 | + |
| 564 | + StringBuilder stepResult = new StringBuilder(); |
| 565 | + if (dataMap.containsKey(RESPONSE_FIELD)) { |
| 566 | + stepResult.append((String) dataMap.get(RESPONSE_FIELD)); |
| 567 | + } |
| 568 | + |
| 569 | + if (dataMap.containsKey(INTERACTIONS_ADDITIONAL_INFO_FIELD)) { |
| 570 | + stepResult.append("\n<step-traces>\n"); |
| 571 | + ((Map<String, Object>) dataMap.get(INTERACTIONS_ADDITIONAL_INFO_FIELD)) |
| 572 | + .forEach( |
| 573 | + (key, value) -> stepResult |
| 574 | + .append("<") |
| 575 | + .append(key) |
| 576 | + .append(">") |
| 577 | + .append("\n") |
| 578 | + .append(value) |
| 579 | + .append("\n") |
| 580 | + .append("</") |
| 581 | + .append(key) |
| 582 | + .append(">") |
| 583 | + ); |
| 584 | + stepResult.append("\n</step-traces>\n"); |
| 585 | + } |
| 586 | + |
| 587 | + return stepResult.toString(); |
| 588 | + } |
| 589 | + |
547 | 590 | @VisibleForTesting |
548 | 591 | Map<String, Object> parseLLMOutput(Map<String, String> allParams, ModelTensorOutput modelTensorOutput) { |
549 | 592 | Map<String, Object> modelOutput = new HashMap<>(); |
|
0 commit comments