-
Notifications
You must be signed in to change notification settings - Fork 69
Handle nesting for ConvertDType, ToArray, adapt concatenate dispatch #503
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Conversation
Concatenate can be equal to rename if only one key is supplied. By not calling concatenate in that case, we can accept arbitrary inputs in the transform, as long as only one is supplied. This simplifies things e.g. in the `BasicWorkflow`, where the user passes the `summary_variables` to concatenate, which may be a single dict, which does not need to be concatenated.
Codecov ReportAll modified and coverable lines are covered by tests ✅
|
@vpratz Isn't the concat issue already addressed by this if-else in the adapter dispatch? bayesflow/bayesflow/adapters/adapter.py Lines 485 to 486 in 449a79a
|
Thanks for the response. I didn't see that, I encountered an error in the transform, probably because I passed a list of length one instead of a string. With that additional context, I think it's better to just add that case in the dispatch as well, and revert the changes in the transform. |
Moves the fix from the Concatenate transform to the concatenate method of the adapter.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the PR. I think we should discuss these changes a bit more before we move on with an implementation. See comments.
bayesflow/adapters/adapter.py
Outdated
@@ -482,6 +482,9 @@ def concatenate(self, keys: str | Sequence[str], *, into: str, axis: int = -1): | |||
axis : int, optional | |||
Along which axis to concatenate the keys. The last axis is used by default. | |||
""" | |||
if isinstance(keys, Sequence) and len(keys) == 1: | |||
# unpack string if only one key is supplied, so that Rename is used below | |||
keys = keys[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I think this would be more robust with some code duplication, since we ideally don't want to rely on the interaction of two unconnected if statements:
if isinstance(keys, str):
transform = Rename(keys, to_key=into)
elif len(keys) == 1:
transform = Rename(keys[0], to_key=into)
@@ -32,7 +33,7 @@ def get_config(self) -> dict: | |||
return serialize(config) | |||
|
|||
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: | |||
return data.astype(self.to_dtype, copy=False) | |||
return map_structure(lambda d: d.astype(self.to_dtype, copy=False), data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This makes no sense to me. data
is typed as np.ndarray
, why would you need to map_structure
here? If this transform ever receives non-arrays, something else is going wrong, perhaps in the FilterTransform
or MapTransform
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See below for the setting, I have not (yet) adapted the type hints.
except TypeError: | ||
pass | ||
except ValueError: | ||
# separate statements, as optree does not allow (KeyError | TypeError | ValueError) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is likely not an issue with optree, but your syntax. The syntax to catch any one of multiple error types is:
try:
...
except (ValueError, RuntimeError):
...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, yes, I used the wrong syntax here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the added complexity to these transforms is not worth it. Why do we need to support sampling nested dictionaries?
bayesflow/utils/tree.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Utilities like these are very general, and need to be extensively documented and tested. I find it somewhat opaque for now what these are trying to achieve, and why this functionality isn't already present in optree
or keras.tree
.
Thanks a lot for the review! I'm not sure if the motivation behind the changes is totally clear yet. I'll try to describe it once more and we can exchange ideas on how (and if) we want to proceed.
We already have ways to work around this using @stefanradev93 mentioned that in the future we might see more summary networks that require dictionaries as inputs, so I figured this might become a more common scenario and worth resolving. Most of the complexity is due to trying to keep What do you think? Is this a change we want, and if so, what would be the best way to get there? |
Yes, dictionary inputs and composite summary networks will be increasingly important. |
I understand the motivation, but I don't think the Adapter should handle the increased complexity. Instead, I think we should look moreso into the Approximator and how we can facilitate having multiple summary networks there. |
@LarsKue This is not only about multiple summary networks, but any summary network that requires multiple inputs that do not fit in one tensor (see Stefan's comment above). Do you have any ideas regarding that? |
- simplify map_dict to only a single structure, as we probably will not require the more general behavior. Add test and docstring. - remove tree functions that were required for restoring original types - minor cleanups to account for review comments
This PR generalizes the transforms from the
Approximator.create_default
method, so that they can be used with nested inputs as well. For ConvertDType and ToArray, it supplies a generalization that works on dictionaries.Edit: For Concatenate, it adds a minor change in the Adapter.concatenate dispatch function to detect the special case of only one key being present also if it is nested in a sequence of length one.
The recursive structure of nested inputs requires some extra utility functions, which I put into
bayesflow.utils.tree
.