35 changes: 34 additions & 1 deletion docs/gallery/autogen/common_concepts.py
@@ -116,6 +116,35 @@ def add_multiply(
print(node.outputs.sum)
print(node.outputs.product)

# %%
# Pydantic models as annotations
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# You can annotate inputs/outputs with Pydantic models. The model defines the socket schema:
# inputs are validated, and outputs are stored as typed AiiDA nodes per field. Runtime results
# are therefore a dict of nodes (not a Pydantic instance). Use the node mapping for provenance,
# and rebuild a model only for convenience.
#
# See the node-graph guide for more details on structured models and leaf blobs.
#
# .. code-block:: python
#
#     from pydantic import BaseModel
#
#     class Inputs(BaseModel):
#         x: int
#         y: int
#
#     class Outputs(BaseModel):
#         sum: int
#         product: int
#
#     @pyfunction()
#     def add_multiply(data: Inputs) -> Outputs:
#         return Outputs(sum=data.x + data.y, product=data.x * data.y)
#
#     result, node = run_get_node(add_multiply, data=Inputs(x=2, y=3))
#     # result is {"sum": Int(...), "product": Int(...)}
#
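#
# Rebuilding a model from the returned nodes is purely a convenience; a minimal
# sketch (assuming ``Outputs`` and ``result`` from the example above, and using
# each node's ``.value``):
#
# .. code-block:: python
#
#     outputs = Outputs(**{key: n.value for key, n in result.items()})
#     print(outputs.sum, outputs.product)  # 5 6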
# %%
# Dynamic namespaces
# ~~~~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -202,7 +231,11 @@ def nested_dict_task(x, y):
# 1. The system first searches for an AiiDA data entry point matching the
#    object's type (e.g., ``ase.atoms.Atoms``).
# 2. If no specific serializer is found, it attempts to store the data using
#    ``JsonableData``. This includes Pydantic models and dataclasses (they are
#    converted to JSON-friendly dicts); see the sketch after this list.
# 3. When a Pydantic model is used as an *output schema*, results are stored as
#    typed AiiDA nodes per field and returned as a dict of nodes (not a model).
#    Use the node mapping for provenance; rebuild a Pydantic instance only for display.
# 4. If the data is not JSON-serializable, it will raise an error.
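#
# A minimal sketch of the ``JsonableData`` fallback (step 2), assuming a plain
# dataclass with no dedicated AiiDA entry point (``Point`` is a hypothetical
# example type):
#
# .. code-block:: python
#
#     from dataclasses import dataclass
#
#     from aiida_pythonjob.data.jsonable_data import JsonableData
#
#     @dataclass
#     class Point:
#         x: float
#         y: float
#
#     node = JsonableData(Point(x=1.0, y=2.0))  # stored as a JSON-friendly dict
#     print(node.obj)                           # rebuilt ``Point`` instance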
#
# Registering a custom serializer
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -22,7 +22,7 @@ requires-python = ">=3.10"
dependencies = [
"aiida-core>=2.7.1,<3",
"ase",
"node-graph~=0.6.0",
"node-graph~=0.6.1",
]

[project.optional-dependencies]
8 changes: 8 additions & 0 deletions src/aiida_pythonjob/calculations/common.py
@@ -9,6 +9,7 @@
from aiida_pythonjob.data.deserializer import deserialize_to_raw_python_data

# Attribute keys stored on ProcessNode.base.attributes
ATTR_INPUTS_SPEC = "inputs_spec"
ATTR_OUTPUTS_SPEC = "outputs_spec"
ATTR_SERIALIZERS = "serializers"
ATTR_DESERIALIZERS = "deserializers"
@@ -21,6 +22,12 @@ def add_common_function_io(spec) -> None:
:class:`~aiida.engine.CalcJobProcessSpec`.
"""
    spec.input_namespace("function_data", dynamic=True, required=True)
    spec.input(
        "metadata.inputs_spec",
        valid_type=dict,
        required=False,
        help="Specification for the inputs.",
    )
    spec.input(
        "metadata.outputs_spec",
        valid_type=dict,
@@ -107,6 +114,7 @@ def _build_process_label(self) -> str: # called by AiiDA engine

    def _setup_metadata(self, metadata: dict) -> None:  # type: ignore[override]
        """Store common metadata on the ProcessNode and forward the rest."""
        self.node.base.attributes.set(ATTR_INPUTS_SPEC, metadata.pop("inputs_spec", {}))
        self.node.base.attributes.set(ATTR_OUTPUTS_SPEC, metadata.pop("outputs_spec", {}))
        self.node.base.attributes.set(ATTR_SERIALIZERS, metadata.pop("serializers", {}))
        self.node.base.attributes.set(ATTR_DESERIALIZERS, metadata.pop("deserializers", {}))
5 changes: 5 additions & 0 deletions src/aiida_pythonjob/calculations/pyfunction.py
@@ -14,9 +14,11 @@
from aiida.orm import CalcFunctionNode, Float
from aiida.orm.nodes.data.base import to_aiida_type
from node_graph.socket_spec import SocketSpec
from node_graph.utils.struct_utils import coerce_inputs_from_spec

from aiida_pythonjob.calculations.common import (
    ATTR_DESERIALIZERS,
    ATTR_INPUTS_SPEC,
    ATTR_OUTPUTS_SPEC,
    ATTR_SERIALIZERS,
    FunctionProcessMixin,
@@ -113,6 +115,9 @@ async def run(self) -> t.Union[plumpy.process_states.Stop, int, plumpy.process_s
            inputs = dict(self.inputs.function_inputs or {})
            deserializers = self.node.base.attributes.get(ATTR_DESERIALIZERS, {})
            inputs = deserialize_to_raw_python_data(inputs, deserializers=deserializers)
            inputs_spec = self.node.base.attributes.get(ATTR_INPUTS_SPEC, {})
            if inputs_spec:
                inputs = coerce_inputs_from_spec(inputs, SocketSpec.from_dict(inputs_spec))
        except Exception as exception:
            return self.exit_codes.ERROR_DESERIALIZE_INPUTS_FAILED.format(
                exception=str(exception), traceback=traceback.format_exc()
4 changes: 4 additions & 0 deletions src/aiida_pythonjob/calculations/pythonjob.py
@@ -11,6 +11,7 @@

from aiida_pythonjob.calculations.common import (
    ATTR_DESERIALIZERS,
    ATTR_INPUTS_SPEC,
    FunctionProcessMixin,
    add_common_function_io,
)
@@ -271,6 +272,9 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:

        # Create a pickle file for the user input values
        filename = "inputs.pickle"
        inputs_spec = self.node.base.attributes.get(ATTR_INPUTS_SPEC, {})
        if inputs_spec:
            # Bundle the spec with the raw inputs so the remote script can rehydrate them
            input_values = {"__inputs_spec__": inputs_spec, "__inputs__": input_values}
        with folder.open(filename, "wb") as handle:
            pickle.dump(input_values, handle)

10 changes: 9 additions & 1 deletion src/aiida_pythonjob/calculations/tasks.py
@@ -12,8 +12,10 @@
import plumpy.process_states
from aiida.engine.processes.process import Process, ProcessState
from aiida.engine.utils import InterruptableFuture, interruptable_task
from node_graph.socket_spec import SocketSpec
from node_graph.utils.struct_utils import coerce_inputs_from_spec

from aiida_pythonjob.calculations.common import ATTR_DESERIALIZERS
from aiida_pythonjob.calculations.common import ATTR_DESERIALIZERS, ATTR_INPUTS_SPEC
from aiida_pythonjob.data.deserializer import deserialize_to_raw_python_data

logger = logging.getLogger(__name__)
@@ -43,6 +45,9 @@ async def task_run_job(process: Process, *args, **kwargs) -> Any:
    inputs = dict(process.inputs.function_inputs or {})
    deserializers = node.base.attributes.get(ATTR_DESERIALIZERS, {})
    inputs = deserialize_to_raw_python_data(inputs, deserializers=deserializers)
    inputs_spec = node.base.attributes.get(ATTR_INPUTS_SPEC, {})
    if inputs_spec:
        inputs = coerce_inputs_from_spec(inputs, SocketSpec.from_dict(inputs_spec))

    try:
        logger.info(f"scheduled request to run the function<{node.pk}>")
@@ -65,6 +70,9 @@ async def task_run_monitor_job(process: Process, *args, **kwargs) -> Any:
    inputs = dict(process.inputs.function_inputs or {})
    deserializers = node.base.attributes.get(ATTR_DESERIALIZERS, {})
    inputs = deserialize_to_raw_python_data(inputs, deserializers=deserializers)
    inputs_spec = node.base.attributes.get(ATTR_INPUTS_SPEC, {})
    if inputs_spec:
        inputs = coerce_inputs_from_spec(inputs, SocketSpec.from_dict(inputs_spec))

    try:
        logger.info(f"scheduled request to run the function<{node.pk}>")
13 changes: 13 additions & 0 deletions src/aiida_pythonjob/calculations/utils.py
@@ -61,6 +61,19 @@ def generate_script_py(
" write_error_file('UNPICKLE_INPUTS_FAILED', e, traceback.format_exc())",
" sys.exit(1)",
"",
" # 2b) Optional: rehydrate structured inputs based on inputs_spec",
" try:",
" if isinstance(inputs, dict) and '__inputs_spec__' in inputs:",
" inputs_spec = inputs.get('__inputs_spec__')",
" inputs = inputs.get('__inputs__', {})",
" if inputs_spec:",
" from node_graph.socket_spec import SocketSpec",
" from node_graph.utils.struct_utils import coerce_inputs_from_spec",
" inputs = coerce_inputs_from_spec(inputs, SocketSpec.from_dict(inputs_spec))",
" except Exception as e:",
" write_error_file('COERCE_INPUTS_FAILED', e, traceback.format_exc())",
" sys.exit(1)",
"",
]

if pickled_function:
36 changes: 36 additions & 0 deletions src/aiida_pythonjob/data/jsonable_data.py
@@ -1,10 +1,16 @@
import importlib
import json
import typing
from dataclasses import asdict, is_dataclass

import numpy as np
from aiida import orm

try:
    from pydantic import BaseModel
except Exception:  # pydantic is an optional dependency
    BaseModel = ()  # sentinel: isinstance()/issubclass() checks against () are always False

__all__ = ("JsonableData",)


Expand Down Expand Up @@ -46,6 +52,8 @@ def __init__(self, obj: typing.Any, *args, **kwargs):
        self._obj = obj

    def _validate_method(self, obj: typing.Any):
        if self._is_pydantic_instance(obj) or self._is_dataclass_instance(obj):
            return
        if not any(hasattr(obj, method) for method in self._DICT_METHODS):
            raise ValueError(f"The object must have at least one of the following methods: {self._DICT_METHODS}")
        if not any(hasattr(obj, method) for method in self._FROM_DICT_METHODS):
@@ -55,6 +63,10 @@ def _extract_dict(self, obj: typing.Any) -> dict:
"""
Attempt to call one of the recognized "to-dict" style methods on `obj` in sequence.
"""
if self._is_pydantic_instance(obj):
return obj.model_dump(exclude_none=False)
if self._is_dataclass_instance(obj):
return asdict(obj)
for method_name in self._DICT_METHODS:
method = getattr(obj, method_name, None)
if callable(method):
@@ -138,6 +150,14 @@ def _rebuild_object(self, cls_: typing.Any, attributes: dict) -> typing.Any:
"""
Attempt to reconstruct an object of type `cls_` from `attributes`.
"""
if self._is_pydantic_type(cls_):
if hasattr(cls_, "model_validate"):
return cls_.model_validate(attributes)
if hasattr(cls_, "parse_obj"):
return cls_.parse_obj(attributes)
return cls_(**attributes)
if self._is_dataclass_type(cls_):
return cls_(**attributes)
for method_name in self._FROM_DICT_METHODS:
fromdict_method = getattr(cls_, method_name, None)
if callable(fromdict_method):
@@ -151,6 +171,22 @@ def _rebuild_object(self, cls_: typing.Any, attributes: dict) -> typing.Any:
f"({self._FROM_DICT_METHODS}) nor constructor that accepts these attributes."
)

    @staticmethod
    def _is_pydantic_instance(obj: typing.Any) -> bool:
        return isinstance(obj, BaseModel)

    @staticmethod
    def _is_pydantic_type(cls_: typing.Any) -> bool:
        return isinstance(cls_, type) and issubclass(cls_, BaseModel)

    @staticmethod
    def _is_dataclass_instance(obj: typing.Any) -> bool:
        return is_dataclass(obj) and not isinstance(obj, type)

    @staticmethod
    def _is_dataclass_type(cls_: typing.Any) -> bool:
        # ``is_dataclass`` is also True for instances, so require an actual class.
        return is_dataclass(cls_) and isinstance(cls_, type)

    @property
    def obj(self) -> typing.Any:
        """Return the wrapped Python object."""
1 change: 1 addition & 0 deletions src/aiida_pythonjob/launch.py
@@ -137,6 +137,7 @@ def _prepare_common(
    _validate_inputs_against_signature(fn, serialized_inputs)

    metadata = {
        "inputs_spec": in_spec.to_dict(),
        "outputs_spec": out_spec.to_dict(),
        "serializers": merged_serializers,
        "deserializers": merged_deserializers,
3 changes: 3 additions & 0 deletions src/aiida_pythonjob/parsers/utils.py
@@ -5,6 +5,7 @@
from aiida import orm
from aiida.engine import ExitCode
from node_graph.socket_spec import SocketSpec
from node_graph.utils.struct_utils import is_structured_instance, structured_to_dict

from ..utils import _ensure_spec, serialize_ports

@@ -46,6 +47,8 @@ def parse_outputs(

    fields = spec.fields or {}
    is_dyn = spec.meta.dynamic
    if is_structured_instance(results) and (fields or is_dyn):
        results = structured_to_dict(results)

    if already_serialized(results):
        return {"result": results}, None
3 changes: 3 additions & 0 deletions src/aiida_pythonjob/utils.py
@@ -18,6 +18,7 @@
from aiida.orm import Computer, InstalledCode, Str, User, load_code, load_computer
from node_graph.socket_meta import SocketMeta
from node_graph.socket_spec import SocketSpec
from node_graph.utils.struct_utils import is_structured_instance, structured_to_dict

from aiida_pythonjob.data.serializer import general_serializer

@@ -295,6 +296,8 @@ def serialize_ports(
    # Namespace
    if spec.is_namespace():
        name = getattr(spec.meta, "help", None) or "<namespace>"
        if is_structured_instance(python_data):
            python_data = structured_to_dict(python_data)
        if not isinstance(python_data, dict):
            raise ValueError(f"Expected dict for namespace '{name}', got {type(python_data)}")

35 changes: 35 additions & 0 deletions tests/test_data.py
@@ -1,8 +1,31 @@
from dataclasses import dataclass

import aiida
import pytest

from aiida_pythonjob.data.jsonable_data import JsonableData
from aiida_pythonjob.data.serializer import all_serializers

try:
    from pydantic import BaseModel as _BaseModel
except Exception:  # pragma: no cover - optional dependency
    _BaseModel = None

if _BaseModel is not None:

    class BlobModel(_BaseModel):
        a: int
        b: int

else:
    BlobModel = None


@dataclass
class BlobDC:
    a: int
    b: int


def test_typing():
    """Test function with typing."""
@@ -90,3 +113,15 @@ def test_datetime_data():

    with pytest.raises(TypeError, match="Expected datetime.datetime"):
        DateTimeData("2024-06-01")


def test_jsonable_data_pydantic_and_dataclass():
    pytest.importorskip("pydantic")

    model = BlobModel(a=1, b=2)
    node = JsonableData(model)
    assert node.value.model_dump() == model.model_dump()

    dc = BlobDC(a=3, b=4)
    node_dc = JsonableData(dc)
    assert node_dc.value == dc
33 changes: 33 additions & 0 deletions tests/test_pyfunction.py
@@ -7,6 +7,25 @@

from aiida_pythonjob import PyFunction, prepare_pyfunction_inputs, pyfunction

try:
    from pydantic import BaseModel as _BaseModel
except Exception:  # pragma: no cover - optional dependency
    _BaseModel = None

if _BaseModel is not None:

    class PydanticInputs(_BaseModel):
        x: int
        y: int

    class PydanticOutputs(_BaseModel):
        sum: int
        product: int

else:
    PydanticInputs = None
    PydanticOutputs = None


@pyfunction()
def add(x, y):
@@ -136,6 +155,20 @@ def myfunc(x, y):
    assert result["add_multiply"]["multiply"].value == 2


def test_pydantic_inputs_outputs():
    pytest.importorskip("pydantic")

    @pyfunction()
    def add_multiply(data: PydanticInputs) -> PydanticOutputs:
        return PydanticOutputs(sum=data.x + data.y, product=data.x * data.y)

    result, node = run_get_node(add_multiply, data=PydanticInputs(x=2, y=3))

    assert result["sum"].value == 5
    assert result["product"].value == 6
    assert node.is_finished_ok


def test_override_outputs():
    """Test function with namespace output and input."""
