Skip to content

Commit b2601ef

Browse files
jan-janssenGeigerJ2superstar54
authored
Upgrade to aiida-workgraph 0.7.4 (#136)
--------- Co-authored-by: Julian Geiger <julian.geiger@gmx.net> Co-authored-by: superstar54 <xingwang1991@gmail.com>
1 parent f8ba185 commit b2601ef

File tree

8 files changed

+22587
-31
lines changed

8 files changed

+22587
-31
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,4 +199,4 @@ jobflow_to_aiida_qe.json
199199
aiida_to_jobflow_qe.json
200200
pyiron_base_to_aiida_simple.json
201201
pyiron_base_to_jobflow_qe.json
202-
202+
**/*.h5

binder/environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ dependencies:
1010
- pyiron_base =0.12.0
1111
- pyiron_workflow =0.13.0
1212
- pygraphviz =1.14
13-
- aiida-workgraph =0.5.2
13+
- aiida-workgraph =0.7.4
1414
- plumpy =0.25.0
1515
- conda_subprocess =0.0.7
1616
- networkx =3.5

example_workflows/quantum_espresso/aiida.ipynb

Lines changed: 5799 additions & 1 deletion
Large diffs are not rendered by default.

example_workflows/quantum_espresso/jobflow.ipynb

Lines changed: 5747 additions & 1 deletion
Large diffs are not rendered by default.

example_workflows/quantum_espresso/pyiron_base.ipynb

Lines changed: 4993 additions & 1 deletion
Large diffs are not rendered by default.

example_workflows/quantum_espresso/pyiron_workflow.ipynb

Lines changed: 6015 additions & 1 deletion
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ dependencies = [
2222

2323
[project.optional-dependencies]
2424
aiida = [
25-
"aiida-workgraph>=0.5.1,<=0.5.2",
25+
"aiida-workgraph>=0.5.1,<=0.7.4",
2626
]
2727
jobflow = [
2828
"jobflow>=0.1.18,<=0.2.0",

src/python_workflow_definition/aiida.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33

44
from aiida import orm
55
from aiida_pythonjob.data.serializer import general_serializer
6-
from aiida_workgraph import WorkGraph, task
6+
from aiida_workgraph import WorkGraph, task, Task, namespace
77
from aiida_workgraph.socket import TaskSocketNamespace
8-
8+
from dataclasses import replace
9+
from node_graph.node_spec import SchemaSource
910
from python_workflow_definition.models import PythonWorkflowDefinitionWorkflow
1011
from python_workflow_definition.shared import (
1112
convert_nodes_list_to_dict,
1213
update_node_names,
13-
remove_result,
1414
set_result_node,
1515
NODES_LABEL,
1616
EDGES_LABEL,
@@ -24,11 +24,8 @@
2424

2525

2626
def load_workflow_json(file_name: str) -> WorkGraph:
27-
data = remove_result(
28-
workflow_dict=PythonWorkflowDefinitionWorkflow.load_json_file(
29-
file_name=file_name
30-
)
31-
)
27+
28+
data = PythonWorkflowDefinitionWorkflow.load_json_file(file_name=file_name)
3229

3330
wg = WorkGraph()
3431
task_name_mapping = {}
@@ -40,24 +37,28 @@ def load_workflow_json(file_name: str) -> WorkGraph:
4037
p, m = identifier.rsplit(".", 1)
4138
mod = import_module(p)
4239
func = getattr(mod, m)
43-
wg.add_task(func)
44-
# Remove the default result output, because we will add the outputs later from the data in the link
45-
del wg.tasks[-1].outputs["result"]
46-
task_name_mapping[id] = wg.tasks[-1]
40+
decorated_func = task(outputs=namespace())(func)
41+
new_task = wg.add_task(decorated_func)
42+
new_task.spec = replace(new_task.spec, schema_source=SchemaSource.EMBEDDED)
43+
task_name_mapping[id] = new_task
4744
else:
4845
# data task
4946
data_node = general_serializer(identifier)
5047
task_name_mapping[id] = data_node
5148

5249
# add links
5350
for link in data[EDGES_LABEL]:
51+
# TODO: continue here
5452
to_task = task_name_mapping[str(link[TARGET_LABEL])]
5553
# if the input is not exit, it means we pass the data into to the kwargs
5654
# in this case, we add the input socket
57-
if link[TARGET_PORT_LABEL] not in to_task.inputs:
58-
to_socket = to_task.add_input("workgraph.any", name=link[TARGET_PORT_LABEL])
59-
else:
60-
to_socket = to_task.inputs[link[TARGET_PORT_LABEL]]
55+
if isinstance(to_task, Task):
56+
if link[TARGET_PORT_LABEL] not in to_task.inputs:
57+
to_socket = to_task.add_input_spec(
58+
"workgraph.any", name=link[TARGET_PORT_LABEL]
59+
)
60+
else:
61+
to_socket = to_task.inputs[link[TARGET_PORT_LABEL]]
6162
from_task = task_name_mapping[str(link[SOURCE_LABEL])]
6263
if isinstance(from_task, orm.Data):
6364
to_socket.value = from_task
@@ -69,16 +70,14 @@ def load_workflow_json(file_name: str) -> WorkGraph:
6970
# we add it here, and assume the output exit
7071
if link[SOURCE_PORT_LABEL] not in from_task.outputs:
7172
# if str(link["sourcePort"]) not in from_task.outputs:
72-
from_socket = from_task.add_output(
73+
from_socket = from_task.add_output_spec(
7374
"workgraph.any",
7475
name=link[SOURCE_PORT_LABEL],
75-
# name=str(link["sourcePort"]),
76-
metadata={"is_function_output": True},
7776
)
7877
else:
7978
from_socket = from_task.outputs[link[SOURCE_PORT_LABEL]]
80-
81-
wg.add_link(from_socket, to_socket)
79+
if isinstance(to_task, Task):
80+
wg.add_link(from_socket, to_socket)
8281
except Exception as e:
8382
traceback.print_exc()
8483
print("Failed to link", link, "with error:", e)
@@ -90,12 +89,18 @@ def write_workflow_json(wg: WorkGraph, file_name: str) -> dict:
9089
node_name_mapping = {}
9190
data_node_name_mapping = {}
9291
i = 0
92+
GRAPH_LEVEL_NAMES = ["graph_inputs", "graph_outputs", "graph_ctx"]
93+
9394
for node in wg.tasks:
94-
executor = node.get_executor()
95+
96+
if node.name in GRAPH_LEVEL_NAMES:
97+
continue
98+
9599
node_name_mapping[node.name] = i
96100

97-
callable_name = executor["callable_name"]
98-
callable_name = f"{executor['module_path']}.{callable_name}"
101+
executor = node.get_executor()
102+
callable_name = f"{executor.module_path}.{executor.callable_name}"
103+
99104
data[NODES_LABEL].append({"id": i, "type": "function", "value": callable_name})
100105
i += 1
101106

@@ -141,6 +146,7 @@ def write_workflow_json(wg: WorkGraph, file_name: str) -> dict:
141146
SOURCE_PORT_LABEL: None,
142147
}
143148
)
149+
144150
data[VERSION_LABEL] = VERSION_NUMBER
145151
PythonWorkflowDefinitionWorkflow(
146152
**set_result_node(workflow_dict=update_node_names(workflow_dict=data))

0 commit comments

Comments
 (0)