fix: issue #10596 by making the iteration node outputs right #11394

Merged 5 commits on Dec 7, 2024
14 changes: 1 addition & 13 deletions api/core/app/entities/queue_entities.py
@@ -2,7 +2,7 @@
from enum import Enum, StrEnum
from typing import Any, Optional

from pydantic import BaseModel, field_validator
from pydantic import BaseModel

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.node_entities import NodeRunMetadataKey
@@ -113,18 +113,6 @@ class QueueIterationNextEvent(AppQueueEvent):
output: Optional[Any] = None # output for the current iteration
duration: Optional[float] = None

@field_validator("output", mode="before")
@classmethod
def set_output(cls, v):
"""
Set output
"""
if v is None:
return None
if isinstance(v, int | float | str | bool | dict | list):
return v
raise ValueError("output must be a valid type")


class QueueIterationCompletedEvent(AppQueueEvent):
"""
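Note on the queue_entities.py change: dropping the `field_validator` means `QueueIterationNextEvent.output` stays typed as `Optional[Any]` with no runtime type check, so a per-iteration output that is not a plain int/float/str/bool/dict/list is no longer rejected. A minimal sketch of the before/after behaviour, assuming pydantic v2 and Python 3.10+; `WithValidator`, `WithoutValidator`, and `SegmentLike` are illustrative stand-ins, not the real Dify classes:

```python
from typing import Any, Optional

from pydantic import BaseModel, ValidationError, field_validator


class WithValidator(BaseModel):
    """Stand-in for the old QueueIterationNextEvent behaviour."""

    output: Optional[Any] = None

    @field_validator("output", mode="before")
    @classmethod
    def set_output(cls, v):
        # Reproduces the removed validator: only primitives/containers pass.
        if v is None:
            return None
        if isinstance(v, int | float | str | bool | dict | list):
            return v
        raise ValueError("output must be a valid type")


class WithoutValidator(BaseModel):
    """Stand-in for the new behaviour: any value is accepted as-is."""

    output: Optional[Any] = None


class SegmentLike:
    """Hypothetical non-primitive value an iteration might yield."""


try:
    WithValidator(output=SegmentLike())
except ValidationError as exc:
    print("old model rejects it:", exc.errors()[0]["msg"])

print("new model accepts it:", WithoutValidator(output=SegmentLike()).output)
```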
76 changes: 46 additions & 30 deletions api/core/workflow/nodes/iteration/iteration_node.py
@@ -9,7 +9,7 @@
from flask import Flask, current_app

from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from core.variables import IntegerVariable
from core.workflow.entities.node_entities import (
NodeRunMetadataKey,
NodeRunResult,
@@ -155,18 +155,19 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
iteration_node_data=self.node_data,
index=0,
pre_iteration_output=None,
duration=None,
)
iter_run_map: dict[str, float] = {}
outputs: list[Any] = [None] * len(iterator_list_value)
try:
if self.node_data.is_parallel:
futures: list[Future] = []
q = Queue()
q: Queue = Queue()
thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100)
for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit(
self._run_single_iter_parallel,
current_app._get_current_object(),
current_app._get_current_object(), # type: ignore
q,
iterator_list_value,
inputs,
@@ -181,6 +182,7 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
future.add_done_callback(thread_pool.task_done_callback)
futures.append(future)
succeeded_count = 0
empty_count = 0
while True:
try:
event = q.get(timeout=1)
@@ -208,33 +210,38 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
else:
for _ in range(len(iterator_list_value)):
yield from self._run_single_iter(
iterator_list_value,
variable_pool,
inputs,
outputs,
start_at,
graph_engine,
iteration_graph,
iter_run_map,
iterator_list_value=iterator_list_value,
variable_pool=variable_pool,
inputs=inputs,
outputs=outputs,
start_at=start_at,
graph_engine=graph_engine,
iteration_graph=iteration_graph,
iter_run_map=iter_run_map,
)
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs = [output for output in outputs if output is not None]

# Flatten the list of lists
if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
outputs = [item for sublist in outputs for item in sublist]

yield IterationRunSucceededEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
)

yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": jsonable_encoder(outputs)},
outputs={"output": outputs},
metadata={NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map},
)
)
@@ -248,7 +255,7 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=str(e),
@@ -280,7 +287,7 @@ def _extract_variable_selector_to_variable_mapping(
:param node_data: node data
:return:
"""
variable_mapping = {
variable_mapping: dict[str, Sequence[str]] = {
f"{node_id}.input_selector": node_data.iterator_selector,
}

@@ -308,7 +315,7 @@ def _extract_variable_selector_to_variable_mapping(
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config, config=sub_node_config
)
sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping)
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
except NotImplementedError:
sub_node_variable_mapping = {}

@@ -329,8 +336,12 @@ def _extract_variable_selector_to_variable_mapping(
return variable_mapping

def _handle_event_metadata(
self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str
) -> NodeRunStartedEvent | BaseNodeEvent:
self,
*,
event: BaseNodeEvent | InNodeEvent,
iter_run_index: int,
parallel_mode_run_id: str | None,
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
"""
add iteration metadata to event.
"""
@@ -355,6 +366,7 @@ def _handle_event_metadata(

def _run_single_iter(
self,
*,
iterator_list_value: list[str],
variable_pool: VariablePool,
inputs: dict[str, list],
@@ -373,12 +385,12 @@ def _run_single_iter(
try:
rst = graph_engine.run()
# get current iteration index
current_index = variable_pool.get([self.node_id, "index"]).value
index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(index_variable, IntegerVariable):
raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
current_index = index_variable.value
iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
next_index = int(current_index) + 1

if current_index is None:
raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
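
Note on the hunk above: the bare `variable_pool.get([self.node_id, "index"]).value` access is replaced with an explicit `IntegerVariable` check, so a missing or mistyped index now raises `IterationIndexNotFoundError` with a clear message instead of an `AttributeError` on `None` (the old `if current_index is None` check could never run). A standalone sketch of that pattern, using simplified stand-ins for the variable pool and `IntegerVariable`, not the real Dify classes:

```python
from dataclasses import dataclass


@dataclass
class IntegerVariable:
    """Stand-in for core.variables.IntegerVariable."""

    value: int


class IterationIndexNotFoundError(Exception):
    pass


def get_current_index(pool: dict[str, object], node_id: str) -> int:
    index_variable = pool.get(f"{node_id}.index")
    # Validate the type before touching .value; raise a descriptive error
    # rather than letting an AttributeError escape.
    if not isinstance(index_variable, IntegerVariable):
        raise IterationIndexNotFoundError(f"iteration {node_id} current index not found")
    return index_variable.value


print(get_current_index({"iter_1.index": IntegerVariable(2)}, "iter_1"))  # 2
```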
@@ -391,7 +403,9 @@ def _run_single_iter(
continue

if isinstance(event, NodeRunSucceededEvent):
yield self._handle_event_metadata(event, current_index, parallel_mode_run_id)
yield self._handle_event_metadata(
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
@@ -404,7 +418,7 @@ def _run_single_iter(
parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
@@ -417,7 +431,7 @@ def _run_single_iter(
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
@@ -429,9 +443,11 @@ def _run_single_iter(
)
)
return
else:
event = cast(InNodeEvent, event)
metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id)
elif isinstance(event, InNodeEvent):
# event = cast(InNodeEvent, event)
metadata_event = self._handle_event_metadata(
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
)
if isinstance(event, NodeRunFailedEvent):
if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
yield NodeInIterationFailedEvent(
@@ -513,7 +529,7 @@ def _run_single_iter(
iteration_node_data=self.node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None,
pre_iteration_output=current_iteration_output or None,
duration=duration,
)

@@ -551,7 +567,7 @@ def _run_single_iter_parallel(
index: int,
item: Any,
iter_run_map: dict[str, float],
) -> Generator[NodeEvent | InNodeEvent, None, None]:
):
"""
run single iteration in parallel mode
"""
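
Note on the `_run` changes as a whole: the collected `outputs` list is now preallocated to the iterator length, emitted raw instead of being passed through `jsonable_encoder`, and post-processed before the success event is yielded: `None` placeholders are dropped under `ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT`, and a list-of-lists result is flattened one level. A hypothetical standalone helper illustrating that post-processing step; the function name and boolean flag are illustrative, not part of the PR:

```python
from typing import Any


def postprocess_iteration_outputs(outputs: list[Any], remove_abnormal_output: bool) -> list[Any]:
    # Drop placeholders left by failed iterations when the node is configured
    # with ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT.
    if remove_abnormal_output:
        outputs = [output for output in outputs if output is not None]

    # If every iteration produced a list, flatten one level so downstream
    # nodes receive a single flat list instead of a list of lists.
    if all(isinstance(output, list) for output in outputs):
        outputs = [item for sublist in outputs for item in sublist]

    return outputs


print(postprocess_iteration_outputs([[1, 2], None, [3]], remove_abnormal_output=True))   # [1, 2, 3]
print(postprocess_iteration_outputs([["a"], ["b", "c"]], remove_abnormal_output=False))  # ['a', 'b', 'c']
```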