
Handle nesting for ConvertDType, ToArray, adapt concatenate dispatch #503


Open
wants to merge 9 commits into base: dev
2 changes: 2 additions & 0 deletions bayesflow/adapters/adapter.py
@@ -484,6 +484,8 @@ def concatenate(self, keys: str | Sequence[str], *, into: str, axis: int = -1):
"""
if isinstance(keys, str):
transform = Rename(keys, to_key=into)
elif len(keys) == 1:
transform = Rename(keys[0], to_key=into)
else:
transform = Concatenate(keys, into=into, axis=axis)
self.transforms.append(transform)
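For context, a minimal sketch of the caller-side effect of this branch (key names are illustrative; it assumes `Adapter.forward` accepts a plain dict of arrays, as in the test suite):

```python
import numpy as np
from bayesflow import Adapter

# A one-element key list now yields a Rename transform instead of a Concatenate,
# so the lone array is passed through under the new key without any axis handling.
adapter = Adapter().concatenate(["theta"], into="inference_variables")
result = adapter.forward({"theta": np.zeros((2, 3))})
assert result["inference_variables"].shape == (2, 3)
```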
9 changes: 5 additions & 4 deletions bayesflow/adapters/transforms/convert_dtype.py
@@ -1,4 +1,5 @@
import numpy as np
from keras.tree import map_structure

from bayesflow.utils.serialization import serializable, serialize

@@ -31,8 +32,8 @@ def get_config(self) -> dict:
}
return serialize(config)

def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
return data.astype(self.to_dtype, copy=False)
def forward(self, data: np.ndarray | dict, **kwargs) -> np.ndarray | dict:
return map_structure(lambda d: d.astype(self.to_dtype, copy=False), data)

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
return data.astype(self.from_dtype, copy=False)
def inverse(self, data: np.ndarray | dict, **kwargs) -> np.ndarray | dict:
return map_structure(lambda d: d.astype(self.from_dtype, copy=False), data)
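
A minimal sketch of what the `map_structure`-based `forward` now allows (the dictionary layout and dtypes are illustrative):

```python
import numpy as np
from keras.tree import map_structure

# Mirror the new forward: astype is applied to every leaf of a (possibly nested) structure.
to_dtype = "float32"
data = {"x": np.arange(3, dtype="float64"), "nested": {"y": np.ones(2, dtype="int64")}}
converted = map_structure(lambda d: d.astype(to_dtype, copy=False), data)

assert converted["x"].dtype == np.float32
assert converted["nested"]["y"].dtype == np.float32
```

Because `map_structure` treats a bare array as a single leaf, the existing non-nested behaviour is unchanged.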
11 changes: 10 additions & 1 deletion bayesflow/adapters/transforms/to_array.py
Contributor:
I think the added complexity to these transforms is not worth it. Why do we need to support sampling nested dictionaries?

@@ -2,6 +2,7 @@

import numpy as np

from bayesflow.utils.tree import map_dict
from bayesflow.utils.serialization import serializable, serialize

from .elementwise_transform import ElementwiseTransform
@@ -34,12 +35,20 @@ def get_config(self) -> dict:
return serialize({"original_type": self.original_type})

def forward(self, data: any, **kwargs) -> np.ndarray:
if isinstance(data, dict):
# no invertibility for dicts, do not store original type
return map_dict(np.asarray, data)

if self.original_type is None:
self.original_type = type(data)

return np.asarray(data)

def inverse(self, data: np.ndarray, **kwargs) -> any:
def inverse(self, data: np.ndarray | dict, **kwargs) -> any:
if isinstance(data, dict):
# no invertibility for dict to keep complexity low
return data

if self.original_type is None:
raise RuntimeError("Cannot call `inverse` before calling `forward` at least once.")

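A small sketch of the intended asymmetry (it assumes `ToArray` is importable from `bayesflow.adapters.transforms`; the data layout is illustrative):

```python
import numpy as np
from bayesflow.adapters.transforms import ToArray

to_array = ToArray()
nested = {"a": [0.0, 1.0], "b": {"c": [2.0, 3.0]}}

out = to_array.forward(nested)  # every leaf becomes an np.ndarray via map_dict
assert isinstance(out["b"]["c"], np.ndarray)

# For dict inputs, inverse is a deliberate no-op: the original type is not recorded.
assert to_array.inverse(out) is out
```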
29 changes: 29 additions & 0 deletions bayesflow/utils/tree.py
@@ -1,4 +1,5 @@
import optree
from typing import Callable


def flatten_shape(structure):
@@ -12,3 +13,31 @@ def is_shape_tuple(x):
namespace="keras",
)
return leaves


def map_dict(func: Callable, dictionary: dict) -> dict:
"""Applies a function to all leaves of a (possibly nested) dictionary.

Parameters
----------
func : Callable
The function to apply to the leaves.
dictionary : dict
The input dictionary.

Returns
-------
dict
A dictionary with the outputs of `func` as leaves.
"""

def is_not_dict(x):
return not isinstance(x, dict)

return optree.tree_map(
func,
dictionary,
is_leaf=is_not_dict,
none_is_leaf=True,
namespace="keras",
)
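
As a usage note, this is how the helper is exercised by `ToArray` above; a minimal sketch (the values are illustrative):

```python
import numpy as np
from bayesflow.utils.tree import map_dict

# The nested layout is preserved; only non-dict leaves are passed to the function.
nested = {"mean": [0.0, 1.0], "observables": {"a": [1, 2, 3]}}
arrays = map_dict(np.asarray, nested)

assert isinstance(arrays["observables"], dict)
assert arrays["observables"]["a"].shape == (3,)
```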
4 changes: 2 additions & 2 deletions tests/test_adapters/conftest.py
@@ -13,7 +13,9 @@ def serializable_fn(x):

return (
Adapter()
.group(["p1", "p2"], into="ps", prefix="p")
.to_array()
.ungroup("ps", prefix="p")
.as_set(["s1", "s2"])
.broadcast("t1", to="t2")
.as_time_series(["t1", "t2"])
@@ -37,8 +39,6 @@ def serializable_fn(x):
.rename("o1", "o2")
.random_subsample("s3", sample_size=33, axis=0)
.take("s3", indices=np.arange(0, 32), axis=0)
.group(["p1", "p2"], into="ps", prefix="p")
.ungroup("ps", prefix="p")
)


13 changes: 13 additions & 0 deletions tests/test_adapters/test_adapters.py
@@ -393,3 +393,16 @@ def test_nnpe(random_data):
# Both should assign noise to high-variance dimension
assert std_dim[1] > 0
assert std_glob[1] > 0


def test_single_concatenate_to_rename():
# test that single-element concatenate is converted to rename
from bayesflow import Adapter
from bayesflow.adapters.transforms import Rename, Concatenate

ad = Adapter().concatenate("a", into="b")
assert isinstance(ad[0], Rename)
ad = Adapter().concatenate(["a"], into="b")
assert isinstance(ad[0], Rename)
ad = Adapter().concatenate(["a", "b"], into="c")
assert isinstance(ad[0], Concatenate)
16 changes: 16 additions & 0 deletions tests/test_utils/test_tree.py
@@ -0,0 +1,16 @@
def test_map_dict():
from bayesflow.utils.tree import map_dict

input = {
"a": {
"x": [0, 1, 2],
},
"b": [0, 1],
"c": "foo",
}
output = map_dict(len, input)
for key, value in output.items():
if key == "a":
assert value["x"] == len(input["a"]["x"])
continue
assert value == len(input[key])
9 changes: 1 addition & 8 deletions tests/test_workflows/conftest.py
@@ -81,13 +81,6 @@ def sample(self, batch_shape: Shape, num_observations: int = 4) -> dict[str, Ten

x = mean[:, None] + noise

return dict(mean=mean, a=x, b=x)
return dict(mean=mean, observables=dict(a=x, b=x))

return FusionSimulator()


@pytest.fixture
def fusion_adapter():
from bayesflow import Adapter

return Adapter.create_default(["mean"]).group(["a", "b"], "summary_variables")
7 changes: 3 additions & 4 deletions tests/test_workflows/test_basic_workflow.py
@@ -36,14 +36,13 @@ def test_basic_workflow(tmp_path, inference_network, summary_network):
assert samples["parameters"].shape == (5, 3, 2)


def test_basic_workflow_fusion(
tmp_path, fusion_inference_network, fusion_summary_network, fusion_simulator, fusion_adapter
):
def test_basic_workflow_fusion(tmp_path, fusion_inference_network, fusion_summary_network, fusion_simulator):
workflow = bf.BasicWorkflow(
adapter=fusion_adapter,
inference_network=fusion_inference_network,
summary_network=fusion_summary_network,
simulator=fusion_simulator,
inference_variables=["mean"],
summary_variables=["observables"],
checkpoint_filepath=str(tmp_path),
)
