---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
[<ipython-input-14-34a99b71cb82>](https://colab.corp.google.com/drive/1M9-A_LXuCfjEkmDmhVOqaLlhVkl4BGfJ#) in <module>()
----> 1 jax.tree_util.tree_map(lambda x: x + 1, {1: 7, "y": 42})
1 frames
[google3/third_party/py/jax/_src/tree_util.py](https://colab.corp.google.com/drive/1M9-A_LXuCfjEkmDmhVOqaLlhVkl4BGfJ#) in tree_map(f, tree, is_leaf, *rest)
205 [[5, 7, 9], [6, 1, 2]]
206 """
--> 207 leaves, treedef = tree_flatten(tree, is_leaf)
208 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
209 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
[google3/third_party/py/jax/_src/tree_util.py](https://colab.corp.google.com/drive/1M9-A_LXuCfjEkmDmhVOqaLlhVkl4BGfJ#) in tree_flatten(tree, is_leaf)
58 element is a treedef representing the structure of the flattened tree.
59 """
---> 60 return pytree.flatten(tree, is_leaf)
61
62
TypeError: '<' not supported between instances of 'str' and 'int'
What jax/jaxlib version are you using?
No response
Which accelerator(s) are you using?
No response
Additional system info
No response
NVIDIA GPU info
No response
Indeed, dicts must have sortable key sets to work with JAX's pytree machinery. The reason is that we want dict metadata equality (for hashing/caching purposes) to depend only on the set of keys and not on the insertion order; if it depended on insertion order, we'd get too many surprising cache misses.
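To illustrate the point about insertion order, here is a small sketch (not from the original comment) comparing the tree structures of two dicts that hold the same keys inserted in different orders. It only uses `jax.tree_util.tree_structure` and `jax.tree_util.tree_leaves`; the variable names are made up for the example.

```python
from jax import tree_util

# Two dicts with the same key set but different insertion order.
first = {"a": 1, "b": 2}
second = {"b": 2, "a": 1}

# The treedef (the dict's pytree metadata) depends on the key set only,
# so both dicts produce equal structures.
same_structure = tree_util.tree_structure(first) == tree_util.tree_structure(second)

# Leaves come out in sorted-key order, regardless of insertion order.
leaves_first = tree_util.tree_leaves(first)
leaves_second = tree_util.tree_leaves(second)

print(same_structure, leaves_first, leaves_second)
```

Getting a canonical order like this is what requires the keys to be comparable with each other, which is where a dict mixing `int` and `str` keys fails.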
That said, we should document that requirement, and perhaps raise a clearer error.
If you have dicts with unsortable key sets and you want to use them with pytrees, one approach would be to make a new pytree node and specify the flattening/unflattening (and metadata) logic yourself, e.g.
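The code that originally followed this comment is not preserved here. As a stand-in, below is a minimal sketch of the approach it describes, under stated assumptions: `MixedKeyDict` is a hypothetical wrapper class, and ordering keys by `(type name, repr)` is just one possible way to get a deterministic order that does not depend on insertion order. It is registered with `jax.tree_util.register_pytree_node_class`.

```python
from jax import tree_util


@tree_util.register_pytree_node_class
class MixedKeyDict:
    """Hypothetical wrapper for a dict whose keys are not mutually sortable."""

    def __init__(self, data):
        self.data = dict(data)

    def tree_flatten(self):
        # Assumption: order keys by (type name, repr) so the order is
        # deterministic even when the keys cannot be compared with `<`.
        keys = tuple(sorted(self.data, key=lambda k: (type(k).__name__, repr(k))))
        children = tuple(self.data[k] for k in keys)
        # The ordered keys serve as the static, hashable metadata.
        return children, keys

    @classmethod
    def tree_unflatten(cls, keys, children):
        return cls(zip(keys, children))


tree = MixedKeyDict({1: 7, "y": 42})
result = tree_util.tree_map(lambda x: x + 1, tree)
print(result.data)  # {1: 8, 'y': 43}

# The metadata does not depend on insertion order.
reordered = MixedKeyDict({"y": 42, 1: 7})
same_structure = tree_util.tree_structure(tree) == tree_util.tree_structure(reordered)
```

One caveat of this particular ordering: two distinct keys with the same type name and the same `repr` would tie, and their relative order would then fall back to insertion order, so a real implementation may want a stricter key function.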
Description
Calling `jax.tree_util.tree_map(lambda x: x + 1, {1: 7, "y": 42})` currently raises a TypeError, probably because the dict keys are sorted during flattening.
Related issue: #4085
Error message: `TypeError: '<' not supported between instances of 'str' and 'int'`