|  | 
| 25 | 25 | ) | 
| 26 | 26 | 
 | 
| 27 | 27 | 
 | 
| 28 |  | -def load_workflow_json_nested(file_name: str) -> WorkGraph: | 
|  | 28 | +def load_workflow_json(file_name: str) -> WorkGraph: | 
| 29 | 29 |     """Load a workflow from JSON with support for nested workflows. | 
| 30 | 30 | 
 | 
| 31 | 31 |     This function recursively loads workflows, properly exposing inputs/outputs | 
| @@ -86,7 +86,7 @@ def load_workflow_json_nested(file_name: str) -> WorkGraph: | 
| 86 | 86 |             workflow_path = parent_dir / workflow_file | 
| 87 | 87 | 
 | 
| 88 | 88 |             # Recursively load the sub-workflow with proper input/output exposure | 
| 89 |  | -            sub_wg = load_workflow_json_nested(file_name=str(workflow_path)) | 
|  | 89 | +            sub_wg = load_workflow_json(file_name=str(workflow_path)) | 
| 90 | 90 | 
 | 
| 91 | 91 |             # Add the sub-workflow as a task - it will automatically have the right inputs/outputs | 
| 92 | 92 |             workflow_task = wg.add_task(sub_wg) | 
| @@ -203,66 +203,6 @@ def load_workflow_json_nested(file_name: str) -> WorkGraph: | 
| 203 | 203 |     return wg | 
| 204 | 204 | 
 | 
| 205 | 205 | 
 | 
| 206 |  | -def load_workflow_json(file_name: str) -> WorkGraph: | 
| 207 |  | - | 
| 208 |  | -    data = PythonWorkflowDefinitionWorkflow.load_json_file(file_name=file_name) | 
| 209 |  | - | 
| 210 |  | -    wg = WorkGraph() | 
| 211 |  | -    task_name_mapping = {} | 
| 212 |  | - | 
| 213 |  | -    for id, identifier in convert_nodes_list_to_dict( | 
| 214 |  | -        nodes_list=data[NODES_LABEL] | 
| 215 |  | -    ).items(): | 
| 216 |  | -        if isinstance(identifier, str) and "." in identifier: | 
| 217 |  | -            p, m = identifier.rsplit(".", 1) | 
| 218 |  | -            mod = import_module(p) | 
| 219 |  | -            func = getattr(mod, m) | 
| 220 |  | -            decorated_func = task(outputs=namespace())(func) | 
| 221 |  | -            new_task = wg.add_task(decorated_func) | 
| 222 |  | -            new_task.spec = replace(new_task.spec, schema_source=SchemaSource.EMBEDDED) | 
| 223 |  | -            task_name_mapping[id] = new_task | 
| 224 |  | -        else: | 
| 225 |  | -            # data task | 
| 226 |  | -            data_node = general_serializer(identifier) | 
| 227 |  | -            task_name_mapping[id] = data_node | 
| 228 |  | - | 
| 229 |  | -    # add links | 
| 230 |  | -    for link in data[EDGES_LABEL]: | 
| 231 |  | -        # TODO: continue here | 
| 232 |  | -        to_task = task_name_mapping[str(link[TARGET_LABEL])] | 
| 233 |  | -        # if the input is not exit, it means we pass the data into to the kwargs | 
| 234 |  | -        # in this case, we add the input socket | 
| 235 |  | -        if isinstance(to_task, Task): | 
| 236 |  | -            if link[TARGET_PORT_LABEL] not in to_task.inputs: | 
| 237 |  | -                to_socket = to_task.add_input_spec( | 
| 238 |  | -                    "workgraph.any", name=link[TARGET_PORT_LABEL] | 
| 239 |  | -                ) | 
| 240 |  | -            else: | 
| 241 |  | -                to_socket = to_task.inputs[link[TARGET_PORT_LABEL]] | 
| 242 |  | -        from_task = task_name_mapping[str(link[SOURCE_LABEL])] | 
| 243 |  | -        if isinstance(from_task, orm.Data): | 
| 244 |  | -            to_socket.value = from_task | 
| 245 |  | -        else: | 
| 246 |  | -            try: | 
| 247 |  | -                if link[SOURCE_PORT_LABEL] is None: | 
| 248 |  | -                    link[SOURCE_PORT_LABEL] = "result" | 
| 249 |  | -                # because we are not define the outputs explicitly during the pythonjob creation | 
| 250 |  | -                # we add it here, and assume the output exit | 
| 251 |  | -                if link[SOURCE_PORT_LABEL] not in from_task.outputs: | 
| 252 |  | -                    # if str(link["sourcePort"]) not in from_task.outputs: | 
| 253 |  | -                    from_socket = from_task.add_output_spec( | 
| 254 |  | -                        "workgraph.any", | 
| 255 |  | -                        name=link[SOURCE_PORT_LABEL], | 
| 256 |  | -                    ) | 
| 257 |  | -                else: | 
| 258 |  | -                    from_socket = from_task.outputs[link[SOURCE_PORT_LABEL]] | 
| 259 |  | -                if isinstance(to_task, Task): | 
| 260 |  | -                    wg.add_link(from_socket, to_socket) | 
| 261 |  | -            except Exception as e: | 
| 262 |  | -                traceback.print_exc() | 
| 263 |  | -                print("Failed to link", link, "with error:", e) | 
| 264 |  | -    return wg | 
| 265 |  | - | 
| 266 | 206 | 
 | 
| 267 | 207 | def write_workflow_json(wg: WorkGraph, file_name: str) -> dict: | 
| 268 | 208 |     data = {NODES_LABEL: [], EDGES_LABEL: []} | 
|  | 
0 commit comments