Skip to content

Commit

Permalink
[dinosaur] Add prng_step to RandomnessState
Browse files Browse the repository at this point in the history
This facilitates avoiding the pattern of iterative splitting the same key, which has poor statistical properties. The recommended pattern for generating a new PRNG key is `jax.random.fold_in(state.prng_key, state.prng_step)`.

PiperOrigin-RevId: 621884598
  • Loading branch information
shoyer authored and NeuralGCM authors committed Apr 4, 2024
1 parent b317577 commit c4ec1bb
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion neuralgcm/encoders_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def encode_fwd(inputs):
params = encoder_model.init(rng, inputs)
encoded_state = (encoder_model.apply(params, rng, inputs).state).asdict()

get_shape_fn = lambda tree: jax.tree_map(lambda x: x.shape, tree)
get_shape_fn = lambda tree: jax.tree_map(np.shape, tree)
expected_shapes = primitive_equations.StateWithTime(
divergence=coords.modal_shape,
vorticity=coords.modal_shape,
Expand Down
4 changes: 2 additions & 2 deletions neuralgcm/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def test_shapes_of_model_function(
rng_key = jax.random.PRNGKey(42)
params = model.init_params(rng_key, data_trajectory, forcing_data)

get_shape_fn = lambda tree: jax.tree_util.tree_map(lambda x: x.shape, tree)
get_shape_fn = lambda tree: jax.tree_util.tree_map(np.shape, tree)

with self.subTest('encoder'):
actual = jax.tree_util.tree_leaves(
Expand Down Expand Up @@ -386,7 +386,7 @@ def test_shapes_of_model_function(

rng_key = jax.random.PRNGKey(42)
params = model.init_params(rng_key, data_trajectory, forcing_data)
get_shape_fn = lambda tree: jax.tree_util.tree_map(lambda x: x.shape, tree)
get_shape_fn = lambda tree: jax.tree_util.tree_map(np.shape, tree)

with self.subTest('encoder'):
actual = jax.tree_util.tree_leaves(
Expand Down
2 changes: 1 addition & 1 deletion neuralgcm/reference_code/metrics_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def assert_compliant(self, trajectory: typing.Pytree, is_nodal: bool) -> None:
self.n_total_wavenumbers,
)

is_compliant = tree_map(lambda x: x.shape == expected_shape, trajectory)
is_compliant = tree_map(lambda x: np.shape(x) == expected_shape, trajectory)
if not all(tree_leaves(is_compliant)):
shapes = tree_map(np.shape, trajectory)
raise ShapeError(
Expand Down
2 changes: 1 addition & 1 deletion neuralgcm/reference_code/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def add_noise_to_input_frame(
"""
del kwargs # unused.
time_zero_slice = pytree_utils.slice_along_axis(batch, 1, 0)
shapes = jax.tree.map(lambda x: x.shape, time_zero_slice)
shapes = jax.tree.map(np.shape, time_zero_slice)
rngs = jax.random.split(rng, len(jax.tree.leaves(time_zero_slice)))
rngs = jax.tree.unflatten(jax.tree.structure(time_zero_slice), rngs)

Expand Down

0 comments on commit c4ec1bb

Please sign in to comment.