Handle nesting for ConvertDType, ToArray, adapt concatenate dispatch #503

Open: wants to merge 9 commits into base: dev
Conversation

@vpratz (Collaborator) commented Jun 1, 2025

This PR generalizes the transforms from the Approximator.create_default method so that they can also be used with nested inputs. For ConvertDType and ToArray, it supplies a generalization that works on dictionaries.
Edit: For Concatenate, it adds a minor change to the Adapter.concatenate dispatch function to detect the special case of only one key being present, even when that key is nested in a sequence of length one.

The recursive structure of nested inputs requires some extra utility functions, which I put into bayesflow.utils.tree.

Concatenate is equivalent to Rename if only one key is supplied. By not
calling concatenate in that case, we can accept arbitrary inputs in the
transform, as long as only one is supplied. This simplifies things e.g.
in the `BasicWorkflow`, where the user passes the `summary_variables` to
concatenate, which may be a single dict that does not need to be
concatenated.

codecov bot commented Jun 1, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

| Files with missing lines | Coverage | Δ |
|---|---|---|
| bayesflow/adapters/adapter.py | 86.34% <100.00%> | +0.51% ⬆️ |
| bayesflow/adapters/transforms/convert_dtype.py | 100.00% <100.00%> | ø |
| bayesflow/adapters/transforms/to_array.py | 82.75% <100.00%> | +3.59% ⬆️ |
| bayesflow/utils/tree.py | 100.00% <100.00%> | ø |

... and 11 files with indirect coverage changes

@vpratz vpratz requested review from LarsKue and stefanradev93 June 2, 2025 11:53
@LarsKue (Contributor) commented Jun 5, 2025

@vpratz Isn't the concat issue already addressed by this if-else in the adapter dispatch?

if isinstance(keys, str):
transform = Rename(keys, to_key=into)

@vpratz (Collaborator, Author) commented Jun 5, 2025

Thanks for the response. I hadn't seen that; the error I encountered in the transform was probably because I passed a list of length one instead of a string. With that additional context, I think it's better to just add that case to the dispatch as well and revert the changes in the transform.
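To illustrate the failure mode described here (a minimal hypothetical check, mirroring the dispatch snippet quoted above): a single key wrapped in a list does not pass the `isinstance(keys, str)` check, so `Rename` is never dispatched and the input falls through to `Concatenate`.

```python
keys = ["summary_variables"]  # one key, but wrapped in a list

# the existing dispatch only redirects bare strings to Rename
print(isinstance(keys, str))     # False: falls through to Concatenate
print(isinstance(keys[0], str))  # True: unpacking the single element would fix it
```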

@vpratz vpratz changed the title Handle nesting for ConvertDType, ToArray. Relax conditions for Concatenate if only one key is present Handle nesting for ConvertDType, ToArray, adapt concatenate dispatch Jun 7, 2025
@vpratz vpratz mentioned this pull request Jun 14, 2025
@LarsKue (Contributor) left a comment

Thank you for the PR. I think we should discuss these changes a bit more before we move on with an implementation. See comments.

@@ -482,6 +482,9 @@ def concatenate(self, keys: str | Sequence[str], *, into: str, axis: int = -1):
axis : int, optional
Along which axis to concatenate the keys. The last axis is used by default.
"""
if isinstance(keys, Sequence) and len(keys) == 1:
# unpack string if only one key is supplied, so that Rename is used below
keys = keys[0]
LarsKue (Contributor):
Nit: I think this would be more robust with some code duplication, since we ideally don't want to rely on the interaction of two unconnected if statements:

if isinstance(keys, str):
    transform = Rename(keys, to_key=into)
elif len(keys) == 1:
    transform = Rename(keys[0], to_key=into)

@@ -32,7 +33,7 @@ def get_config(self) -> dict:
return serialize(config)

def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
return data.astype(self.to_dtype, copy=False)
return map_structure(lambda d: d.astype(self.to_dtype, copy=False), data)
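For context, here is a self-contained sketch of what such a `map_structure` helper does over nested inputs (an illustration only, not the PR's actual helper from `bayesflow.utils.tree` or `keras.tree`):

```python
import numpy as np

def map_structure(func, data):
    """Apply func to every leaf, recursing through dicts, lists, and tuples."""
    if isinstance(data, dict):
        return {k: map_structure(func, v) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(map_structure(func, v) for v in data)
    return func(data)

nested = {"x": np.zeros(3, dtype="float64"), "cond": {"y": np.ones(2, dtype="int64")}}
converted = map_structure(lambda d: d.astype("float32", copy=False), nested)
# every leaf is converted; a bare ndarray input is simply passed to func
```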
LarsKue (Contributor):
This makes no sense to me. data is typed as np.ndarray, why would you need to map_structure here? If this transform ever receives non-arrays, something else is going wrong, perhaps in the FilterTransform or MapTransform.

vpratz (Collaborator, Author):
See below for the setting, I have not (yet) adapted the type hints.

except TypeError:
pass
except ValueError:
# separate statements, as optree does not allow (KeyError | TypeError | ValueError)
LarsKue (Contributor):

This is likely not an issue with optree, but your syntax. The syntax to catch any one of multiple error types is:

try:
    ...
except (ValueError, RuntimeError):
    ...

vpratz (Collaborator, Author):

Thanks, yes, I used the wrong syntax here.

LarsKue (Contributor):

I think the added complexity to these transforms is not worth it. Why do we need to support sampling nested dictionaries?

LarsKue (Contributor):

Utilities like these are very general, and need to be extensively documented and tested. I find it somewhat opaque for now what these are trying to achieve, and why this functionality isn't already present in optree or keras.tree.

@vpratz (Collaborator, Author) commented Jun 14, 2025

Thanks a lot for the review! I'm not sure the motivation behind the changes has come across clearly yet, so I'll describe it once more and we can exchange ideas on how (and whether) we want to proceed.
The scenario I would like to improve is the following:

  • a user wants to do inference on multi-modal data using a FusionNetwork, and in their simulator directly specifies summary_variables as a dictionary that will work with the FusionNetwork.
  • The user then wants to use the BasicWorkflow and sets summary_variables=["summary_variables"].
  • This results in errors, as the transforms in the default adapter cannot handle the dictionary that was provided as summary_variables.

We already have ways to work around this using group and ungroup, but that requires manually specifying the adapter and knowing how to achieve this. The main motivation for the changes at hand is to make the defaults general enough that users do not encounter errors, and that the question of what is and isn't possible with the Adapter does not get in their way as long as they only use the defaults.

@stefanradev93 mentioned that in the future we might see more summary networks that require dictionaries as inputs, so I figured this might become a more common scenario and worth resolving.

Most of the complexity is due to trying to keep ToArray invertible. If we do not consider this important, this could be simplified.
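As a rough illustration of the invertibility concern (a hypothetical sketch, not the PR's implementation): keeping a ToArray-like transform invertible over nested inputs means recording the original leaf type at each position during the forward pass, so the inverse pass can restore it.

```python
import numpy as np

class NestedToArray:
    # Hypothetical sketch: convert nested leaves to arrays on forward,
    # remember the original type per position for the inverse.
    def __init__(self):
        self.original_types = {}

    def forward(self, data, key=()):
        if isinstance(data, dict):
            return {k: self.forward(v, key + (k,)) for k, v in data.items()}
        self.original_types[key] = type(data)
        return np.asarray(data)

    def inverse(self, data, key=()):
        if isinstance(data, dict):
            return {k: self.inverse(v, key + (k,)) for k, v in data.items()}
        tp = self.original_types.get(key, np.ndarray)
        if tp is list:
            return data.tolist()
        if tp in (int, float):
            return tp(data)
        return data
```

Dropping the invertibility requirement would remove the bookkeeping of `original_types` entirely, which is the simplification mentioned above.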
Other avenues (like adapting the creation of the default adapter) would be possible as well.

What do you think? Is this a change we want, and if so, what would be the best way to get there?

@stefanradev93 (Contributor):
Yes, dictionary inputs and composite summary networks will be increasingly important.

@LarsKue (Contributor) commented Jun 14, 2025

I understand the motivation, but I don't think the Adapter should handle the increased complexity. Instead, I think we should look more closely into the Approximator and how we can facilitate having multiple summary networks there.

@vpratz (Collaborator, Author) commented Jun 14, 2025

@LarsKue This is not only about multiple summary networks, but any summary network that requires multiple inputs that do not fit in one tensor (see Stefan's comment above). Do you have any ideas regarding that?

vpratz added 4 commits June 14, 2025 21:34
- simplify map_dict to only a single structure, as we probably will not
  require the more general behavior. Add test and docstring.
- remove tree functions that were required for restoring original types
- minor cleanups to account for review comments