Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Parallel Execution of Nodes in Workflows #8192

Merged
merged 148 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
148 commits
Select commit Hold shift + click to select a range
8217c46
add new graph structure
takatost Jun 24, 2024
fe27c97
add runtime graph
takatost Jun 25, 2024
216910a
add runtime state of graph
takatost Jun 25, 2024
aaa98c7
optimize
takatost Jun 26, 2024
1d8ecac
save
takatost Jun 26, 2024
8375517
save
takatost Jun 29, 2024
0f19b2a
optimize graph
takatost Jul 2, 2024
1b6cd97
completed graph init test
takatost Jul 4, 2024
03f56a0
refactor graph
takatost Jul 5, 2024
fed068a
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
takatost Jul 7, 2024
1adaf42
refactor graph
takatost Jul 7, 2024
0e885a3
refactor runtime
takatost Jul 8, 2024
d77b689
completed parallel tests
takatost Jul 10, 2024
821e09b
add run logics
takatost Jul 12, 2024
00fb23d
graph engine implement
takatost Jul 15, 2024
00ec36d
add graph engine test
takatost Jul 16, 2024
775e52d
merge
takatost Jul 16, 2024
4ef3d4e
optimize
takatost Jul 16, 2024
16e2d00
optimize
takatost Jul 16, 2024
cc96acd
fix bugs
takatost Jul 17, 2024
90e518b
fix bugs
takatost Jul 17, 2024
f67a88f
fix test
takatost Jul 17, 2024
7ad77e9
fix test
takatost Jul 18, 2024
dad1a96
finished answer stream output
takatost Jul 19, 2024
beaac50
fix bug
takatost Jul 19, 2024
a603e01
fix bug
takatost Jul 22, 2024
2c695de
fix bugs
takatost Jul 22, 2024
0fe5165
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
takatost Jul 23, 2024
7303b53
fix bug
takatost Jul 23, 2024
e9bfeda
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
takatost Jul 23, 2024
ec77607
save
takatost Jul 23, 2024
833584b
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
takatost Jul 24, 2024
f4eb7cd
add end stream output test
takatost Jul 24, 2024
4097f7c
add parallel branch output
takatost Jul 25, 2024
7c67ba8
remove threadpool
takatost Jul 25, 2024
df13316
fix lint
takatost Jul 25, 2024
ae351bd
add iteration support
takatost Jul 25, 2024
a31feac
fix iteration
takatost Jul 25, 2024
38f8c45
add events in interation node
takatost Jul 26, 2024
beea1e1
fix lint
takatost Jul 26, 2024
483f71f
fix logging
takatost Jul 26, 2024
63addf8
add parallel branch events
takatost Jul 26, 2024
88dcd7b
fix bug
takatost Jul 26, 2024
0818b7b
remove iteration special logic
takatost Jul 26, 2024
917aacb
add chatflow app event convert
takatost Jul 30, 2024
c9bb366
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
takatost Jul 30, 2024
8d27ec3
fix bug
takatost Jul 30, 2024
8401a11
feat(workflow): integrate workflow entry with advanced chat app
takatost Aug 13, 2024
14d020f
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
takatost Aug 13, 2024
2980e31
fix issues when merging from main
takatost Aug 13, 2024
674af04
fix migration version depends
takatost Aug 13, 2024
6f6b32e
feat(workflow): integrate workflow entry with workflow app
takatost Aug 14, 2024
1da5862
feat(workflow): fix iteration single debug
takatost Aug 14, 2024
702df31
fix(workflow): fix generate issues in workflow
takatost Aug 15, 2024
c519265
fix: unit tests in workflow
takatost Aug 15, 2024
db9b0ee
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
takatost Aug 15, 2024
91e51ce
fix(workflow): issues by merging main branch
takatost Aug 15, 2024
90221c0
fix: unit tests
takatost Aug 15, 2024
5b5e6e3
fix: answer node unit tests
takatost Aug 15, 2024
1973f50
feat: frontend support parallel
zxhlyh Aug 16, 2024
352c45c
feat(workflow): integrate parallel into workflow apps
takatost Aug 16, 2024
5d78657
fix(workflow): issues in workflow parallels
takatost Aug 16, 2024
755a965
fix(workflow): add parallel id into published events
takatost Aug 18, 2024
617ea4b
fix(workflow): fix parallel bug
takatost Aug 20, 2024
1d88b62
fix(workflow): fix node link to previous node issue
takatost Aug 20, 2024
412be6d
fix bug
takatost Aug 21, 2024
35be41b
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
takatost Aug 21, 2024
e34497d
fix: merge issues
takatost Aug 21, 2024
92072e2
fix: ruff issues
takatost Aug 21, 2024
d6da7b0
fix dialogue_count
takatost Aug 22, 2024
ec4fc78
fix iteration start node
takatost Aug 22, 2024
fe2b300
fix lint
takatost Aug 22, 2024
5b22d8f
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
takatost Aug 22, 2024
42899fb
fix bug
takatost Aug 22, 2024
85d3197
fix end node bug
takatost Aug 24, 2024
4771e85
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
takatost Aug 24, 2024
6c61776
fix test
takatost Aug 25, 2024
1016db1
feat: parallel hover
zxhlyh Aug 26, 2024
76bb8d1
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
takatost Aug 26, 2024
9c8144e
feat: parallel hover
zxhlyh Aug 26, 2024
b9f34f6
fix: iteration start node id
takatost Aug 26, 2024
4e3dc36
fix: workflow run edge status
zxhlyh Aug 27, 2024
4256e9d
chore(iteration): keep start_node_id using in parallel start nodes
takatost Aug 27, 2024
cd52633
fix(graph_engine): parent_parallel_id missing
takatost Aug 27, 2024
b0a81c6
fix(workflow): parallel execution after if-else that only one branch …
takatost Aug 28, 2024
8ba5673
feat: iteration support parallel
zxhlyh Aug 28, 2024
c2bb114
fix(workflow): parallel not yield
takatost Aug 28, 2024
4418fa1
fix: bug
zxhlyh Aug 28, 2024
74c8004
fix(graph_engine): fix execute loops in parallel
takatost Aug 28, 2024
6b6750b
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
takatost Aug 28, 2024
5d34e08
fix: migration
takatost Aug 28, 2024
790dd3b
fix(workflow): duplicate nodes in parallel
takatost Aug 28, 2024
ae22015
fix(workflow): loop check
takatost Aug 28, 2024
f43596f
fix: parallel branch limit
zxhlyh Aug 29, 2024
3e257ae
update the workflow parallel log
YIXIAO0 Aug 29, 2024
32a11cb
update the parallel workflow log for iteration and chatflow preview
YIXIAO0 Aug 29, 2024
1bde57e
delete console logs
YIXIAO0 Aug 29, 2024
7c9081a
fix
zxhlyh Aug 30, 2024
708256e
Merge branch 'feat/workflow-parallel-support' of github.com:langgeniu…
YIXIAO0 Aug 30, 2024
e3ae529
update the onNodeFinished method for nodes being passed through more …
YIXIAO0 Aug 30, 2024
2b5b856
solve the branch issue
YIXIAO0 Aug 30, 2024
e329518
fix a typo
YIXIAO0 Aug 30, 2024
77e62f7
fix(workflow): run node in multi parallel bugs
takatost Aug 30, 2024
162e967
fix(workflow): missing parallel event in workflow app
takatost Aug 30, 2024
d7c0ca8
feat: inner parallels will be added to its corresponding branch
YIXIAO0 Aug 30, 2024
ee1587c
fix: make the End node always nested in the root
YIXIAO0 Aug 30, 2024
71a7d89
fix styling
YIXIAO0 Aug 30, 2024
29b1ce7
fix: node end status
YIXIAO0 Sep 1, 2024
0dabf79
fix(workflow): fix merge branch node id err
takatost Sep 2, 2024
52b4623
fix(workflow): fix merge branch node id err
takatost Sep 2, 2024
43240fc
fix
zxhlyh Sep 2, 2024
5bda3a3
fix(workflow): bugs
takatost Sep 2, 2024
81d09d4
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
takatost Sep 2, 2024
7035f64
fix: next step
zxhlyh Sep 2, 2024
bbc922d
merge main
takatost Sep 2, 2024
35d9c59
Merge remote-tracking branch 'origin/feat/workflow-parallel-support' …
takatost Sep 2, 2024
70aced0
fix
zxhlyh Sep 2, 2024
166365a
feat(workflow): add thread pool
takatost Sep 2, 2024
5ca9df6
feat(workflow): add thread pool
takatost Sep 2, 2024
955884b
chore(workflow): max thread submit count
takatost Sep 2, 2024
f71c51c
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
takatost Sep 2, 2024
d929665
fix: migration
takatost Sep 2, 2024
83343ee
update parallel log
YIXIAO0 Sep 3, 2024
b28c7b1
Merge branch 'feat/workflow-parallel-support' of github.com:langgeniu…
YIXIAO0 Sep 3, 2024
3431b19
update styling and iteration log
YIXIAO0 Sep 3, 2024
36d95e4
fix(iteration): iterator_length not correct
takatost Sep 3, 2024
6bee121
update log in web app
YIXIAO0 Sep 3, 2024
78fa1f6
fix(workflow): detached session issues
takatost Sep 3, 2024
cd42dbd
update the log for iteration nodes
YIXIAO0 Sep 4, 2024
4962b2c
check node edge
zxhlyh Sep 4, 2024
5cb018e
update the method to check if a node is in iteration
YIXIAO0 Sep 4, 2024
4f5dc82
fix
zxhlyh Sep 4, 2024
c625f42
Merge branch 'main' into feat/workflow-parallel-support
zxhlyh Sep 4, 2024
44038b9
fix: iteration copy
zxhlyh Sep 4, 2024
4663463
fix: refine the "isInIteration" for workflow
YIXIAO0 Sep 4, 2024
7e30487
feat: update dsl version
takatost Sep 4, 2024
94432a0
chore: update package versions to 0.8.0-beta1 (#7979)
laipz8200 Sep 4, 2024
4b0d2bf
chore: update build-push.yml to remove unnecessary tags
laipz8200 Sep 4, 2024
ab5bb18
fix(workflow): parent parallel logic
takatost Sep 9, 2024
351702d
fix: iteration log in debug
zxhlyh Sep 9, 2024
4071ea4
fix(workflow): add iteration_id in node start/finished event
takatost Sep 9, 2024
6140f22
fix: iteration log in webapp
zxhlyh Sep 9, 2024
aaebe8f
update the log for the parallel iterations case
YIXIAO0 Sep 10, 2024
8fb6e87
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
takatost Sep 10, 2024
9b7a321
fix merge
takatost Sep 10, 2024
d531827
remove if regex
zxhlyh Sep 10, 2024
706c00c
prepare for 0.8.0
takatost Sep 10, 2024
aaae6bd
parallel tip
zxhlyh Sep 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/advanced_chat/generate_task_pipeline.py
#	api/core/app/apps/workflow/app_runner.py
#	api/core/app/apps/workflow/generate_task_pipeline.py
#	api/core/app/task_pipeline/workflow_cycle_state_manager.py
#	api/core/workflow/entities/variable_pool.py
#	api/core/workflow/nodes/code/code_node.py
#	api/core/workflow/nodes/llm/llm_node.py
#	api/core/workflow/nodes/start/start_node.py
#	api/core/workflow/nodes/variable_assigner/__init__.py
#	api/tests/integration_tests/workflow/nodes/test_llm.py
#	api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
#	api/tests/unit_tests/core/workflow/nodes/test_answer.py
#	api/tests/unit_tests/core/workflow/nodes/test_if_else.py
#	api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
  • Loading branch information
takatost committed Aug 21, 2024
commit 35be41b337f5d07f9c29552c9d20c40a20947bb2
65 changes: 62 additions & 3 deletions api/core/app/apps/advanced_chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message
Expand Down Expand Up @@ -67,8 +69,9 @@ def generate(

# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id', ''), user)
conversation_id = args.get('conversation_id')
if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)

# parse files
files = args['files'] if args.get('files') else []
Expand Down Expand Up @@ -225,6 +228,62 @@ def _generate(self, *,
message_id=message.id
)

# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
ConversationVariable.from_variable(
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]

session.commit()

# Increment dialogue count.
conversation.dialogue_count += 1

conversation_id = conversation.id
conversation_dialogue_count = conversation.dialogue_count
db.session.commit()
db.session.refresh(conversation)

inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files

user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id

# Create a variable pool.
system_inputs = {
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: conversation_id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
contexts.workflow_variable_pool.set(variable_pool)

# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(), # type: ignore
Expand Down Expand Up @@ -296,7 +355,7 @@ def _generate_worker(self, flask_app: Flask,
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true':
if os.environ.get("DEBUG", "false").lower() == 'true':
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
Expand Down
8 changes: 4 additions & 4 deletions api/core/app/apps/advanced_chat/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from events.message_event import message_was_created
from extensions.ext_database import db
Expand All @@ -69,7 +69,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow
_user: Union[Account, EndUser]
_workflow_system_variables: dict[SystemVariable, Any]
_workflow_system_variables: dict[SystemVariableKey, Any]

def __init__(
self,
Expand Down Expand Up @@ -312,7 +312,7 @@ def _process_stream_response(
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')

yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
Expand All @@ -321,7 +321,7 @@ def _process_stream_response(
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')

yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
Expand Down
12 changes: 6 additions & 6 deletions api/core/app/apps/workflow/app_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App, EndUser
Expand Down Expand Up @@ -79,14 +79,14 @@ def run(self) -> None:
user_inputs=self.application_generate_entity.single_iteration_run.inputs
)
else:

inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files

# Create a variable pool.
system_inputs = {
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id,
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
}

variable_pool = VariablePool(
Expand All @@ -98,7 +98,7 @@ def run(self) -> None:

# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)

# RUN WORKFLOW
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
Expand Down
6 changes: 4 additions & 2 deletions api/core/app/apps/workflow/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariable
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser
Expand All @@ -64,7 +66,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
_user: Union[Account, EndUser]
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariable, Any]
_workflow_system_variables: dict[SystemVariableKey, Any]

def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
Expand Down
2 changes: 1 addition & 1 deletion api/core/workflow/entities/variable_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class VariablePool(BaseModel):
description='User inputs',
)

system_variables: Mapping[SystemVariable, Any] = Field(
system_variables: Mapping[SystemVariableKey, Any] = Field(
description='System variables',
)

Expand Down
4 changes: 2 additions & 2 deletions api/core/workflow/nodes/code/code_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,8 @@ def _transform_result(self, result: dict, output_schema: Optional[dict[str, Code

@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: CodeNodeData
) -> Mapping[str, Sequence[str]]:
Expand Down
4 changes: 2 additions & 2 deletions api/core/workflow/nodes/llm/llm_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
Expand Down Expand Up @@ -370,7 +370,7 @@ def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) ->

inputs[variable_selector.variable] = variable_value

return inputs # type: ignore
return inputs

def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]:
"""
Expand Down
13 changes: 7 additions & 6 deletions api/core/workflow/nodes/start/start_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any

from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData
from models.workflow import WorkflowNodeExecutionStatus
Expand All @@ -17,11 +18,11 @@ def _run(self) -> NodeRunResult:
Run node
:return:
"""
# Get cleaned inputs
cleaned_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables

for var in self.graph_runtime_state.variable_pool.system_variables:
cleaned_inputs['sys.' + var.value] = self.graph_runtime_state.variable_pool.system_variables[var]
for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var]

return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
Expand All @@ -31,8 +32,8 @@ def _run(self) -> NodeRunResult:

@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: StartNodeData
) -> Mapping[str, Sequence[str]]:
Expand Down
110 changes: 5 additions & 105 deletions api/core/workflow/nodes/variable_assigner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,108 +1,8 @@
from .node import VariableAssignerNode
from .node_data import VariableAssignerData, WriteMode

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.app.segments import SegmentType, Variable, factory
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from extensions.ext_database import db
from models import ConversationVariable, WorkflowNodeExecutionStatus


class VariableAssignerNodeError(Exception):
pass


class WriteMode(str, Enum):
OVER_WRITE = 'over-write'
APPEND = 'append'
CLEAR = 'clear'


class VariableAssignerData(BaseNodeData):
title: str = 'Variable Assigner'
desc: Optional[str] = 'Assign a value to a variable'
assigned_variable_selector: Sequence[str]
write_mode: WriteMode
input_variable_selector: Sequence[str]


class VariableAssignerNode(BaseNode):
_node_data_cls: type[BaseNodeData] = VariableAssignerData
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER

def _run(self) -> NodeRunResult:
data = cast(VariableAssignerData, self.node_data)

# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableAssignerNodeError('assigned variable not found')

match data.write_mode:
case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_variable = original_variable.model_copy(update={'value': income_value.value})

case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={'value': updated_value})

case WriteMode.CLEAR:
income_value = get_zero_value(original_variable.value_type)
updated_variable = original_variable.model_copy(update={'value': income_value.to_object()})

case _:
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')

# Over write the variable.
self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable)

# Update conversation variable.
# TODO: Find a better way to use the database.
conversation_id = self.graph_runtime_state.variable_pool.get(['sys', 'conversation_id'])
if not conversation_id:
raise VariableAssignerNodeError('conversation_id not found')
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)

return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={
'value': income_value.to_object(),
},
)


def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableAssignerNodeError('conversation variable not found in the database')
row.data = variable.model_dump_json()
session.commit()


def get_zero_value(t: SegmentType):
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
return factory.build_segment([])
case SegmentType.OBJECT:
return factory.build_segment({})
case SegmentType.STRING:
return factory.build_segment('')
case SegmentType.NUMBER:
return factory.build_segment(0)
case _:
raise VariableAssignerNodeError(f'unsupported variable type: {t}')
__all__ = [
'VariableAssignerNode',
'VariableAssignerData',
'WriteMode',
]
2 changes: 1 addition & 1 deletion api/tests/integration_tests/workflow/nodes/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from core.model_runtime.model_providers import ModelProviderFactory
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.llm_node import LLMNode
from extensions.ext_database import db
from models.provider import ProviderType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from extensions.ext_database import db
from models.provider import ProviderType
Expand Down
Loading
You are viewing a condensed version of this merge commit. You can view the full changes here.