diff --git a/neuralgcm/encoders_test.py b/neuralgcm/encoders_test.py index 30d8622..dcea400 100644 --- a/neuralgcm/encoders_test.py +++ b/neuralgcm/encoders_test.py @@ -130,7 +130,7 @@ def encode_fwd(inputs): params = encoder_model.init(rng, inputs) encoded_state = (encoder_model.apply(params, rng, inputs).state).asdict() - get_shape_fn = lambda tree: jax.tree_map(lambda x: x.shape, tree) + get_shape_fn = lambda tree: jax.tree_map(np.shape, tree) expected_shapes = primitive_equations.StateWithTime( divergence=coords.modal_shape, vorticity=coords.modal_shape, diff --git a/neuralgcm/models_test.py b/neuralgcm/models_test.py index 6273459..10e42ba 100644 --- a/neuralgcm/models_test.py +++ b/neuralgcm/models_test.py @@ -127,7 +127,7 @@ def test_shapes_of_model_function( rng_key = jax.random.PRNGKey(42) params = model.init_params(rng_key, data_trajectory, forcing_data) - get_shape_fn = lambda tree: jax.tree_util.tree_map(lambda x: x.shape, tree) + get_shape_fn = lambda tree: jax.tree_util.tree_map(np.shape, tree) with self.subTest('encoder'): actual = jax.tree_util.tree_leaves( @@ -386,7 +386,7 @@ def test_shapes_of_model_function( rng_key = jax.random.PRNGKey(42) params = model.init_params(rng_key, data_trajectory, forcing_data) - get_shape_fn = lambda tree: jax.tree_util.tree_map(lambda x: x.shape, tree) + get_shape_fn = lambda tree: jax.tree_util.tree_map(np.shape, tree) with self.subTest('encoder'): actual = jax.tree_util.tree_leaves( diff --git a/neuralgcm/reference_code/metrics_util.py b/neuralgcm/reference_code/metrics_util.py index 258c343..b7e9bd1 100644 --- a/neuralgcm/reference_code/metrics_util.py +++ b/neuralgcm/reference_code/metrics_util.py @@ -130,7 +130,7 @@ def assert_compliant(self, trajectory: typing.Pytree, is_nodal: bool) -> None: self.n_total_wavenumbers, ) - is_compliant = tree_map(lambda x: x.shape == expected_shape, trajectory) + is_compliant = tree_map(lambda x: np.shape(x) == expected_shape, trajectory) if not all(tree_leaves(is_compliant)): shapes = tree_map(np.shape, trajectory) raise ShapeError( diff --git a/neuralgcm/reference_code/train_utils.py b/neuralgcm/reference_code/train_utils.py index a976f29..deeeb91 100644 --- a/neuralgcm/reference_code/train_utils.py +++ b/neuralgcm/reference_code/train_utils.py @@ -265,7 +265,7 @@ def add_noise_to_input_frame( """ del kwargs # unused. time_zero_slice = pytree_utils.slice_along_axis(batch, 1, 0) - shapes = jax.tree.map(lambda x: x.shape, time_zero_slice) + shapes = jax.tree.map(np.shape, time_zero_slice) rngs = jax.random.split(rng, len(jax.tree.leaves(time_zero_slice))) rngs = jax.tree.unflatten(jax.tree.structure(time_zero_slice), rngs)