33
44from aiida import orm
55from aiida_pythonjob .data .serializer import general_serializer
6- from aiida_workgraph import WorkGraph , task
6+ from aiida_workgraph import WorkGraph , task , Task , namespace
77from aiida_workgraph .socket import TaskSocketNamespace
8-
8+ from dataclasses import replace
9+ from node_graph .node_spec import SchemaSource
910from python_workflow_definition .models import PythonWorkflowDefinitionWorkflow
1011from python_workflow_definition .shared import (
1112 convert_nodes_list_to_dict ,
1213 update_node_names ,
13- remove_result ,
1414 set_result_node ,
1515 NODES_LABEL ,
1616 EDGES_LABEL ,
2424
2525
2626def load_workflow_json (file_name : str ) -> WorkGraph :
27- data = remove_result (
28- workflow_dict = PythonWorkflowDefinitionWorkflow .load_json_file (
29- file_name = file_name
30- )
31- )
27+
28+ data = PythonWorkflowDefinitionWorkflow .load_json_file (file_name = file_name )
3229
3330 wg = WorkGraph ()
3431 task_name_mapping = {}
@@ -40,24 +37,28 @@ def load_workflow_json(file_name: str) -> WorkGraph:
4037 p , m = identifier .rsplit ("." , 1 )
4138 mod = import_module (p )
4239 func = getattr (mod , m )
43- wg . add_task (func )
44- # Remove the default result output, because we will add the outputs later from the data in the link
45- del wg . tasks [ - 1 ]. outputs [ "result" ]
46- task_name_mapping [id ] = wg . tasks [ - 1 ]
40+ decorated_func = task ( outputs = namespace ()) (func )
41+ new_task = wg . add_task ( decorated_func )
42+ new_task . spec = replace ( new_task . spec , schema_source = SchemaSource . EMBEDDED )
43+ task_name_mapping [id ] = new_task
4744 else :
4845 # data task
4946 data_node = general_serializer (identifier )
5047 task_name_mapping [id ] = data_node
5148
5249 # add links
5350 for link in data [EDGES_LABEL ]:
51+ # TODO: continue here
5452 to_task = task_name_mapping [str (link [TARGET_LABEL ])]
5553 # if the input is not exit, it means we pass the data into to the kwargs
5654 # in this case, we add the input socket
57- if link [TARGET_PORT_LABEL ] not in to_task .inputs :
58- to_socket = to_task .add_input ("workgraph.any" , name = link [TARGET_PORT_LABEL ])
59- else :
60- to_socket = to_task .inputs [link [TARGET_PORT_LABEL ]]
55+ if isinstance (to_task , Task ):
56+ if link [TARGET_PORT_LABEL ] not in to_task .inputs :
57+ to_socket = to_task .add_input_spec (
58+ "workgraph.any" , name = link [TARGET_PORT_LABEL ]
59+ )
60+ else :
61+ to_socket = to_task .inputs [link [TARGET_PORT_LABEL ]]
6162 from_task = task_name_mapping [str (link [SOURCE_LABEL ])]
6263 if isinstance (from_task , orm .Data ):
6364 to_socket .value = from_task
@@ -69,16 +70,14 @@ def load_workflow_json(file_name: str) -> WorkGraph:
6970 # we add it here, and assume the output exit
7071 if link [SOURCE_PORT_LABEL ] not in from_task .outputs :
7172 # if str(link["sourcePort"]) not in from_task.outputs:
72- from_socket = from_task .add_output (
73+ from_socket = from_task .add_output_spec (
7374 "workgraph.any" ,
7475 name = link [SOURCE_PORT_LABEL ],
75- # name=str(link["sourcePort"]),
76- metadata = {"is_function_output" : True },
7776 )
7877 else :
7978 from_socket = from_task .outputs [link [SOURCE_PORT_LABEL ]]
80-
81- wg .add_link (from_socket , to_socket )
79+ if isinstance ( to_task , Task ):
80+ wg .add_link (from_socket , to_socket )
8281 except Exception as e :
8382 traceback .print_exc ()
8483 print ("Failed to link" , link , "with error:" , e )
@@ -90,12 +89,18 @@ def write_workflow_json(wg: WorkGraph, file_name: str) -> dict:
9089 node_name_mapping = {}
9190 data_node_name_mapping = {}
9291 i = 0
92+ GRAPH_LEVEL_NAMES = ["graph_inputs" , "graph_outputs" , "graph_ctx" ]
93+
9394 for node in wg .tasks :
94- executor = node .get_executor ()
95+
96+ if node .name in GRAPH_LEVEL_NAMES :
97+ continue
98+
9599 node_name_mapping [node .name ] = i
96100
97- callable_name = executor ["callable_name" ]
98- callable_name = f"{ executor ['module_path' ]} .{ callable_name } "
101+ executor = node .get_executor ()
102+ callable_name = f"{ executor .module_path } .{ executor .callable_name } "
103+
99104 data [NODES_LABEL ].append ({"id" : i , "type" : "function" , "value" : callable_name })
100105 i += 1
101106
@@ -141,6 +146,7 @@ def write_workflow_json(wg: WorkGraph, file_name: str) -> dict:
141146 SOURCE_PORT_LABEL : None ,
142147 }
143148 )
149+
144150 data [VERSION_LABEL ] = VERSION_NUMBER
145151 PythonWorkflowDefinitionWorkflow (
146152 ** set_result_node (workflow_dict = update_node_names (workflow_dict = data ))
0 commit comments