Updated jax.tree_map to jax.tree_util.tree_map. As of [JAX 0.4.26](https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-26-april-3-2024), `jax.tree_map` is deprecated. We should switch to `jax.tree.map` if/when the JAX minimum version is >=0.4.26 for CLU.

Additional context:
* JAX PR that added the `jax.tree` module: jax-ml/jax#19588
* JAX PR that deprecated `jax.tree_map`: jax-ml/jax#19930
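
The migration is mechanical. A minimal sketch of the old and new spellings (the toy `batch` pytree below is made up purely for illustration):

```python
import jax
import jax.numpy as jnp

# A toy pytree standing in for a batch of model inputs.
batch = {"image": jnp.zeros((8, 28, 28, 1)), "label": jnp.zeros((8,), jnp.int32)}

# Deprecated alias; warns as of JAX 0.4.26:
# shapes = jax.tree_map(jnp.shape, batch)

# Drop-in replacement used throughout this commit:
shapes = jax.tree_util.tree_map(jnp.shape, batch)

# Eventual target once the minimum JAX version is >= 0.4.26:
# shapes = jax.tree.map(jnp.shape, batch)
```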

PiperOrigin-RevId: 622092167
chiamp authored and copybara-github committed Apr 5, 2024
1 parent c50acb7 commit 873ad23
Showing 5 changed files with 45 additions and 28 deletions.
3 changes: 1 addition & 2 deletions clu/deterministic_data.py
@@ -47,9 +47,8 @@
   )
   ds_iter = iter(ds)
   for _ in range(num_train_steps):
-    batch = jax.tree_map(lambda x: x._numpy(), next(ds_iter))
+    batch = jax.tree_util.tree_map(lambda x: x._numpy(), next(ds_iter))
     # (training step)
 """

 import enum
8 changes: 4 additions & 4 deletions clu/metrics.py
@@ -192,8 +192,8 @@ def reduce_step(reduced: M, metric: M) -> tuple[M, None]:
       # pylint: disable-next=protected-access
       return reduced._reduce_merge(metric), None

-    first = jax.tree_map(lambda x: x[0], self)
-    remainder = jax.tree_map(lambda x: x[1:], self)
+    first = jax.tree_util.tree_map(lambda x: x[0], self)
+    remainder = jax.tree_util.tree_map(lambda x: x[1:], self)
     # According to b/160868467#comment4, usage of `jax.lax.scan` does not add a
     # significant computational cost for simple metrics where e.g. `jnp.sum`
     # could be used instead.
@@ -365,7 +365,7 @@ def evaluate(params):
     @pool
     def copy_to_host(update):
-      return jax.tree_map(np.asarray, update)
+      return jax.tree_util.tree_map(np.asarray, update)

     futures = []
     for batch in eval_ds:
@@ -397,7 +397,7 @@ def merge(self, other: CollectingMetric) -> CollectingMetric:
       return other
     if self.values and not other.values:
       return self
-    return type(self)(jax.tree_map(np.asarray, values))
+    return type(self)(jax.tree_util.tree_map(np.asarray, values))

   def reduce(self) -> CollectingMetric:
     # Note that this is usually called from inside a `pmap()` via
56 changes: 37 additions & 19 deletions clu/metrics_test.py
@@ -107,10 +107,12 @@ def setUp(self):
     }

     # Stack all values. Can for example be used in a pmap().
-    self.model_outputs_stacked = jax.tree_map(lambda *args: jnp.stack(args),
-                                              *self.model_outputs)
-    self.model_outputs_masked_stacked = jax.tree_map(
-        lambda *args: jnp.stack(args), *self.model_outputs_masked)
+    self.model_outputs_stacked = jax.tree_util.tree_map(
+        lambda *args: jnp.stack(args), *self.model_outputs
+    )
+    self.model_outputs_masked_stacked = jax.tree_util.tree_map(
+        lambda *args: jnp.stack(args), *self.model_outputs_masked
+    )

   def make_compute_metric(self, metric_class, reduce, jit=True):
     """Returns a jitted function to compute metrics.
@@ -131,8 +133,9 @@ def compute_metric(model_outputs):
           metric_class.from_model_output(**model_output)
           for model_output in model_outputs
       ]
-      metric_stacked = jax.tree_map(lambda *args: jnp.stack(args),
-                                    *metric_list)
+      metric_stacked = jax.tree_util.tree_map(
+          lambda *args: jnp.stack(args), *metric_list
+      )
       metric = metric_stacked.reduce()
     else:
       metric = metric_class.empty()
@@ -150,15 +153,23 @@ def test_metric_last_value_reduce(self):
     metric2 = metrics.LastValue.from_model_output(jnp.array([3, 4]))
     metric3 = metrics.LastValue.from_model_output(jnp.array([3, 4]),
                                                   jnp.array([0, 0]))
-    metric12 = jax.tree_map(lambda *args: jnp.stack(args), metric1, metric2)
-    metric21 = jax.tree_map(lambda *args: jnp.stack(args), metric2, metric1)
+    metric12 = jax.tree_util.tree_map(
+        lambda *args: jnp.stack(args), metric1, metric2
+    )
+    metric21 = jax.tree_util.tree_map(
+        lambda *args: jnp.stack(args), metric2, metric1
+    )
     self.assertEqual(metric12.reduce().value, 2.5)

     chex.assert_trees_all_equal(metric12.reduce().compute(),
                                 metric21.reduce().compute())

-    metric13 = jax.tree_map(lambda *args: jnp.stack(args), metric1, metric3)
-    metric31 = jax.tree_map(lambda *args: jnp.stack(args), metric1, metric3)
+    metric13 = jax.tree_util.tree_map(
+        lambda *args: jnp.stack(args), metric1, metric3
+    )
+    metric31 = jax.tree_util.tree_map(
+        lambda *args: jnp.stack(args), metric1, metric3
+    )
     self.assertEqual(metric13.reduce().value, 1.5)
     chex.assert_trees_all_equal(metric13.reduce().compute(),
                                 metric31.reduce().compute())
@@ -187,15 +198,15 @@ def test_metric_last_value_legacy_kwarg_value(self):
   def test_metric_last_value_tree_manipulation(self):
     # Test mapping leaves to other non array values (e.g.: None).
     metric = metrics.LastValue(value=2.0)
-    metric = jax.tree_map(lambda x: None, metric)
+    metric = jax.tree_util.tree_map(lambda x: None, metric)
     self.assertIsNone(metric.total, None)
     self.assertIsNone(metric.count, None)
     metric = metrics.LastValue(value=2.0, count=3)
-    metric = jax.tree_map(lambda x: None, metric)
+    metric = jax.tree_util.tree_map(lambda x: None, metric)
     self.assertIsNone(metric.total, None)
     self.assertIsNone(metric.count, None)
     metric = metrics.LastValue(2.0)
-    metric = jax.tree_map(lambda x: None, metric)
+    metric = jax.tree_util.tree_map(lambda x: None, metric)
     self.assertIsNone(metric.total, None)
     self.assertIsNone(metric.count, None)

@@ -272,7 +283,9 @@ def rename_mask(**kwargs):
   )
   def test_merge_asserts_shape(self, metric_cls):
     metric1 = metric_cls.from_model_output(jnp.arange(3.))
-    metric2 = jax.tree_map(lambda *args: jnp.stack(args), metric1, metric1)
+    metric2 = jax.tree_util.tree_map(
+        lambda *args: jnp.stack(args), metric1, metric1
+    )
     with self.assertRaisesRegex(ValueError, r"^Expected same shape"):
       metric1.merge(metric2)

@@ -287,7 +300,9 @@ def test_accuracy(self, reduce):

   def test_last_value_asserts_shape(self):
     metric1 = metrics.LastValue.from_model_output(jnp.arange(3.))
-    metric2 = jax.tree_map(lambda *args: jnp.stack(args), metric1, metric1)
+    metric2 = jax.tree_util.tree_map(
+        lambda *args: jnp.stack(args), metric1, metric1
+    )
     with self.assertRaisesRegex(ValueError, r"^Expected same shape"):
       metric1.merge(metric2)

@@ -407,8 +422,9 @@ def test_collection_gather(self, masked, all_gather_mock):
         for model_output in (model_outputs)

     ]
-    all_gather_mock.return_value = jax.tree_map(lambda *args: jnp.stack(args),
-                                                *collections)
+    all_gather_mock.return_value = jax.tree_util.tree_map(
+        lambda *args: jnp.stack(args), *collections
+    )

     def compute_collection(model_outputs):
       collection = Collection.gather_from_model_output(**model_outputs[0])
@@ -441,7 +457,9 @@ def test_collection_asserts_replication(self):
         Collection.single_from_model_output(**model_output)
         for model_output in self.model_outputs
     ]
-    collection = jax.tree_map(lambda *args: jnp.stack(args), *collections)
+    collection = jax.tree_util.tree_map(
+        lambda *args: jnp.stack(args), *collections
+    )
     with self.assertRaisesRegex(ValueError, r"^Collection is still replicated"):
       collection.compute()

@@ -470,7 +488,7 @@ def test_collecting_metric_async(self):

     @pool
     def copy_to_host(update):
-      return jax.tree_map(np.asarray, update)
+      return jax.tree_util.tree_map(np.asarray, update)

     futures = []
     from_model_output = jax.jit(CollectingMetricAccuracy.from_model_output)
2 changes: 1 addition & 1 deletion clu/parameter_overview_test.py
@@ -97,7 +97,7 @@ def test_get_parameter_overview(self):
     rng = jax.random.PRNGKey(42)
     # Weights of a 2D convolution with 2 filters.
     variables = CNN().init(rng, jnp.zeros((2, 5, 5, 3)))
-    variables = jax.tree_map(jnp.ones_like, variables)
+    variables = jax.tree_util.tree_map(jnp.ones_like, variables)
     self.assertEqual(
         FLAX_CONV2D_PARAMETER_OVERVIEW,
         parameter_overview.get_parameter_overview(
4 changes: 2 additions & 2 deletions clu_synopsis.ipynb
@@ -1480,7 +1480,7 @@
     "\n",
     "my_metrics = None\n",
     "for batch in test_ds:\n",
-    "  batch = jax.tree_map(np.asarray, batch)\n",
+    "  batch = jax.tree_util.tree_map(np.asarray, batch)\n",
     "  update = eval_step_p(batch).unreplicate()\n",
     "  my_metrics = update if my_metrics is None else my_metrics.merge(update)\n",
     "\n",
@@ -1657,7 +1657,7 @@
     "\n",
     "my_metrics = None\n",
     "for batch in test_ds:\n",
-    "  batch = jax.tree_map(np.asarray, batch)\n",
+    "  batch = jax.tree_util.tree_map(np.asarray, batch)\n",
     "  update = eval_step_p(batch).unreplicate()\n",
     "  my_metrics = update if my_metrics is None else my_metrics.merge(update)\n",
     "\n",
