jax.tree_util.tree_map raises errors for dictionary with mixed types of keys #15358

Open
zi-w opened this issue Apr 1, 2023 · 2 comments
Labels: better_errors (Improve the error reporting), documentation

Comments

zi-w commented Apr 1, 2023

Description

The following code currently raises a TypeError, probably because the dict keys are sorted during flattening and keys of different types cannot be compared.
Related issue: #4085

jax.tree_util.tree_map(lambda x: x + 1, {1: 7, "y": 42})

Error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-14-34a99b71cb82> in <module>()
----> 1 jax.tree_util.tree_map(lambda x: x + 1, {1: 7, "y": 42})

1 frames
google3/third_party/py/jax/_src/tree_util.py in tree_map(f, tree, is_leaf, *rest)
    205     [[5, 7, 9], [6, 1, 2]]
    206   """
--> 207   leaves, treedef = tree_flatten(tree, is_leaf)
    208   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    209   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

google3/third_party/py/jax/_src/tree_util.py in tree_flatten(tree, is_leaf)
     58     element is a treedef representing the structure of the flattened tree.
     59   """
---> 60   return pytree.flatten(tree, is_leaf)
     61 
     62 

TypeError: '<' not supported between instances of 'str' and 'int'

What jax/jaxlib version are you using?

No response

Which accelerator(s) are you using?

No response

Additional system info

No response

NVIDIA GPU info

No response

zi-w added the bug label on Apr 1, 2023

mattjj (Collaborator) commented Apr 9, 2023

Thanks for opening this!

Indeed dicts must have sortable key sets to work with JAX's pytree machinery. The reason is that we want dict metadata equality (for hashing/caching purposes) to depend on the set of keys only and not the insertion order; if it depended on insertion order, we'd get too many surprising cache misses.
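
To make the insertion-order point concrete, here is a rough pure-Python illustration. It is not JAX's actual implementation; `dict_metadata` is a hypothetical stand-in that assumes the metadata is simply the sorted tuple of keys. It shows why sorting gives order-independent equality, and why mixed key types then fail.

```python
def dict_metadata(d):
    """Hypothetical stand-in for dict pytree metadata: the keys in sorted order."""
    return tuple(sorted(d))

a = {"x": 1, "y": 2}
b = {"y": 2, "x": 1}  # same keys, different insertion order

# Sorted keys compare equal regardless of insertion order...
same_when_sorted = dict_metadata(a) == dict_metadata(b)
# ...whereas keys taken in insertion order do not.
same_by_insertion = tuple(a) == tuple(b)

# Sorting needs every pair of keys to be comparable, so str/int mixes fail.
try:
    dict_metadata({1: 7, "y": 42})
    mixed_keys_error = None
except TypeError as err:
    mixed_keys_error = str(err)

print(same_when_sorted, same_by_insertion, mixed_keys_error)
```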

That said, we should document that requirement, and perhaps raise a clearer error.
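
As a sketch of what a clearer error might look like, the snippet below wraps the key sort and re-raises with a message that names the requirement. The helper `sorted_dict_keys` and the message wording are assumptions for illustration only; they are not part of JAX.

```python
def sorted_dict_keys(d):
    """Hypothetical helper: sort dict keys, or explain why that is impossible."""
    try:
        return sorted(d)
    except TypeError as err:
        key_types = sorted({type(k).__name__ for k in d})
        raise TypeError(
            "pytree dict keys must be mutually sortable, but found keys of "
            f"types {key_types}; consider a custom pytree node for such dicts"
        ) from err

try:
    sorted_dict_keys({1: 7, "y": 42})
    message = None
except TypeError as err:
    message = str(err)

print(message)
```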

If you have dicts with unsortable key sets and you want to use them with pytrees, one approach would be to make a new pytree node and specify the flattening/unflattening (and metadata) logic yourself, e.g.

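import jax

@jax.tree_util.register_pytree_node_class
class MyDict(dict):
  def tree_flatten(self):
    # Return (children, aux_data): the leaf values plus hashable metadata such as the keys.
    ...

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    # Rebuild a MyDict from the metadata and the (possibly transformed) children.
    ...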
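
One possible way to fill in the outline above is sketched here. The key ordering is an assumption of this sketch, not something specified in the thread: keys are grouped by type name before sorting, so a `str` is never compared with an `int`, while the metadata stays independent of insertion order. It assumes keys of the same type are comparable with each other.

```python
import jax

@jax.tree_util.register_pytree_node_class
class MyDict(dict):
    """A dict whose keys may mix types (illustrative example class)."""

    def _ordered_keys(self):
        # Sort by (type name, key) so keys of different types are never compared.
        return tuple(sorted(self, key=lambda k: (type(k).__name__, k)))

    def tree_flatten(self):
        keys = self._ordered_keys()
        children = tuple(self[k] for k in keys)
        # The key tuple is the static, hashable metadata for this node.
        return children, keys

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Pair each stored key with its (possibly transformed) value.
        return cls(zip(aux_data, children))

result = jax.tree_util.tree_map(lambda x: x + 1, MyDict({1: 7, "y": 42}))
print(result)
```

With this ordering, two `MyDict` instances holding the same keys flatten to the same tree structure even if the keys were inserted in a different order. Returning the keys in insertion order instead would be simpler, but would reintroduce the cache misses described above.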

What do you think?

mattjj added the documentation and better_errors labels and removed the bug label on Apr 9, 2023

zi-w (Author) commented Apr 14, 2023

Thank you for the detailed explanations! Yes, I think a better error message would be sufficient.
