Skip to content

Commit 6128263

Browse files
committed
replace nested normal with nested. see if ci passes
1 parent eebb50e commit 6128263

File tree

2 files changed

+5
-199
lines changed

2 files changed

+5
-199
lines changed
Lines changed: 3 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -1,146 +1,12 @@
1-
from python_workflow_definition.aiida import write_workflow_json, load_workflow_json, load_workflow_json_nested
1+
from python_workflow_definition.aiida import load_workflow_json
22

3-
from aiida_workgraph import WorkGraph, task
4-
from aiida import orm, load_profile
3+
from aiida import load_profile
54

65
load_profile()
76

87
workflow_json_filename = "main.pwd.json"
98

9+
wg = load_workflow_json(workflow_json_filename)
1010

11-
# In[2]:
12-
13-
14-
from workflow import (
15-
get_sum as _get_sum,
16-
get_prod_and_div as _get_prod_and_div,
17-
get_square as _get_square,
18-
)
19-
20-
21-
# wg = WorkGraph("nested")
22-
23-
24-
# In[4]:
25-
26-
wg = load_workflow_json_nested(workflow_json_filename)
2711
wg.to_html()
28-
29-
breakpoint()
30-
31-
pass
32-
3312
wg.run()
34-
35-
#
36-
#
37-
# get_prod_and_div_task = wg.add_task(
38-
# task(outputs=["prod", "div"])(_get_prod_and_div),
39-
# x=orm.Float(1),
40-
# y=orm.Float(2),
41-
# )
42-
#
43-
#
44-
# # In[5]:
45-
#
46-
#
47-
# get_sum_task = wg.add_task(
48-
# _get_sum,
49-
# x=get_prod_and_div_task.outputs.prod,
50-
# y=get_prod_and_div_task.outputs.div,
51-
# )
52-
#
53-
#
54-
# # In[6]:
55-
#
56-
#
57-
# get_square_task = wg.add_task(
58-
# _get_square,
59-
# x=get_sum_task.outputs.result,
60-
# )
61-
#
62-
#
63-
# # In[7]:
64-
#
65-
#
66-
# write_workflow_json(wg=wg, file_name=workflow_json_filename)
67-
#
68-
#
69-
# # In[8]:
70-
#
71-
#
72-
# get_ipython().system("cat {workflow_json_filename}")
73-
#
74-
#
75-
# # ## Load Workflow with jobflow
76-
#
77-
# # In[9]:
78-
#
79-
#
80-
# from python_workflow_definition.jobflow import load_workflow_json
81-
#
82-
#
83-
# # In[10]:
84-
#
85-
#
86-
# from jobflow.managers.local import run_locally
87-
#
88-
#
89-
# # In[11]:
90-
#
91-
#
92-
# flow = load_workflow_json(file_name=workflow_json_filename)
93-
#
94-
#
95-
# # In[12]:
96-
#
97-
#
98-
# result = run_locally(flow)
99-
# result
100-
#
101-
#
102-
# # ## Load Workflow with pyiron_base
103-
#
104-
# # In[13]:
105-
#
106-
#
107-
# from python_workflow_definition.pyiron_base import load_workflow_json
108-
#
109-
#
110-
# # In[14]:
111-
#
112-
#
113-
# delayed_object_lst = load_workflow_json(file_name=workflow_json_filename)
114-
# delayed_object_lst[-1].draw()
115-
#
116-
#
117-
# # In[15]:
118-
#
119-
#
120-
# delayed_object_lst[-1].pull()
121-
#
122-
#
123-
# # ## Load Workflow with pyiron_workflow
124-
#
125-
# # In[ ]:
126-
#
127-
#
128-
# from python_workflow_definition.pyiron_workflow import load_workflow_json
129-
#
130-
#
131-
# # In[ ]:
132-
#
133-
#
134-
# wf = load_workflow_json(file_name=workflow_json_filename)
135-
#
136-
#
137-
# # In[ ]:
138-
#
139-
#
140-
# wf.draw(size=(10, 10))
141-
#
142-
#
143-
# # In[ ]:
144-
#
145-
#
146-
# wf.run()

src/python_workflow_definition/aiida.py

Lines changed: 2 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
)
2626

2727

28-
def load_workflow_json_nested(file_name: str) -> WorkGraph:
28+
def load_workflow_json(file_name: str) -> WorkGraph:
2929
"""Load a workflow from JSON with support for nested workflows.
3030
3131
This function recursively loads workflows, properly exposing inputs/outputs
@@ -86,7 +86,7 @@ def load_workflow_json_nested(file_name: str) -> WorkGraph:
8686
workflow_path = parent_dir / workflow_file
8787

8888
# Recursively load the sub-workflow with proper input/output exposure
89-
sub_wg = load_workflow_json_nested(file_name=str(workflow_path))
89+
sub_wg = load_workflow_json(file_name=str(workflow_path))
9090

9191
# Add the sub-workflow as a task - it will automatically have the right inputs/outputs
9292
workflow_task = wg.add_task(sub_wg)
@@ -203,66 +203,6 @@ def load_workflow_json_nested(file_name: str) -> WorkGraph:
203203
return wg
204204

205205

206-
def load_workflow_json(file_name: str) -> WorkGraph:
207-
208-
data = PythonWorkflowDefinitionWorkflow.load_json_file(file_name=file_name)
209-
210-
wg = WorkGraph()
211-
task_name_mapping = {}
212-
213-
for id, identifier in convert_nodes_list_to_dict(
214-
nodes_list=data[NODES_LABEL]
215-
).items():
216-
if isinstance(identifier, str) and "." in identifier:
217-
p, m = identifier.rsplit(".", 1)
218-
mod = import_module(p)
219-
func = getattr(mod, m)
220-
decorated_func = task(outputs=namespace())(func)
221-
new_task = wg.add_task(decorated_func)
222-
new_task.spec = replace(new_task.spec, schema_source=SchemaSource.EMBEDDED)
223-
task_name_mapping[id] = new_task
224-
else:
225-
# data task
226-
data_node = general_serializer(identifier)
227-
task_name_mapping[id] = data_node
228-
229-
# add links
230-
for link in data[EDGES_LABEL]:
231-
# TODO: continue here
232-
to_task = task_name_mapping[str(link[TARGET_LABEL])]
233-
# if the input is not exit, it means we pass the data into to the kwargs
234-
# in this case, we add the input socket
235-
if isinstance(to_task, Task):
236-
if link[TARGET_PORT_LABEL] not in to_task.inputs:
237-
to_socket = to_task.add_input_spec(
238-
"workgraph.any", name=link[TARGET_PORT_LABEL]
239-
)
240-
else:
241-
to_socket = to_task.inputs[link[TARGET_PORT_LABEL]]
242-
from_task = task_name_mapping[str(link[SOURCE_LABEL])]
243-
if isinstance(from_task, orm.Data):
244-
to_socket.value = from_task
245-
else:
246-
try:
247-
if link[SOURCE_PORT_LABEL] is None:
248-
link[SOURCE_PORT_LABEL] = "result"
249-
# because we are not define the outputs explicitly during the pythonjob creation
250-
# we add it here, and assume the output exit
251-
if link[SOURCE_PORT_LABEL] not in from_task.outputs:
252-
# if str(link["sourcePort"]) not in from_task.outputs:
253-
from_socket = from_task.add_output_spec(
254-
"workgraph.any",
255-
name=link[SOURCE_PORT_LABEL],
256-
)
257-
else:
258-
from_socket = from_task.outputs[link[SOURCE_PORT_LABEL]]
259-
if isinstance(to_task, Task):
260-
wg.add_link(from_socket, to_socket)
261-
except Exception as e:
262-
traceback.print_exc()
263-
print("Failed to link", link, "with error:", e)
264-
return wg
265-
266206

267207
def write_workflow_json(wg: WorkGraph, file_name: str) -> dict:
268208
data = {NODES_LABEL: [], EDGES_LABEL: []}

0 commit comments

Comments
 (0)