Skip to content

Commit

Permalink
Use prng_key from RandomnessState instead of passing it explicitly
Browse files Browse the repository at this point in the history
This will allow for using NeuralGCM models with less boilerplate.

PiperOrigin-RevId: 621323220
  • Loading branch information
shoyer authored and NeuralGCM authors committed Apr 3, 2024
1 parent b317577 commit fb7701b
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 109 deletions.
44 changes: 22 additions & 22 deletions neuralgcm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Public API for NeuralGCM models."""
from __future__ import annotations

import datetime
import functools
from typing import Any
Expand Down Expand Up @@ -75,10 +76,12 @@ def _prepend_dummy_time_axis(state: typing.Pytree) -> typing.Pytree:

def _static_gin_config(method):
  """Wraps `method` so it runs inside the instance's gin configuration.

  Args:
    method: method to wrap. Its first argument must be an object exposing a
      `gin_config` attribute.

  Returns:
    A wrapper carrying the metadata of `method` that enters
    `gin_utils.specific_config(self.gin_config)` and then delegates to
    `method` with the same arguments.
  """

  def _with_config(self, *args, **kwargs):
    # Enter the gin config context for the duration of the wrapped call.
    scoped_config = gin_utils.specific_config(self.gin_config)
    with scoped_config:
      result = method(self, *args, **kwargs)
    return result

  return functools.wraps(method)(_with_config)


Expand Down Expand Up @@ -185,7 +188,7 @@ def data_to_xarray(
'geopotential': 'z',
'temperature': 't',
'longitude': 'lon',
'latitude': 'lat'
'latitude': 'lat',
}
return xarray_utils.data_to_xarray_with_renaming(
data,
Expand All @@ -198,7 +201,10 @@ def data_to_xarray(
@jax.jit
@_static_gin_config
def encode(
self, inputs: Inputs, forcings: Forcings, rng_key: typing.PRNGKeyArray
self,
inputs: Inputs,
forcings: Forcings,
rng_key: typing.PRNGKeyArray | None = None,
) -> State:
"""Encode from pressure-level inputs & forcings to model state.
Expand All @@ -209,7 +215,8 @@ def encode(
an array with shape `[level, longitude, latitude]` matching
`data_coords`. Single level data (e.g., sea surface temperature) should
have a `level` dimension of size 1.
rng_key: JAX RNG key to use for encoding the state.
rng_key: optional JAX RNG key to use for encoding the state. Required if
using stochastic models, otherwise ignored.
Returns:
Dynamical core state on sigma levels, where all arrays have dimensions
Expand All @@ -224,9 +231,7 @@ def encode(

@jax.jit
@_static_gin_config
def advance(
self, state: State, forcings: Forcings, rng_key: typing.PRNGKeyArray
) -> State:
def advance(self, state: State, forcings: Forcings) -> State:
"""Advance model state one timestep forward.
Args:
Expand All @@ -237,7 +242,6 @@ def advance(
an array with shape `[level, longitude, latitude]` matching
`data_coords`. Single level data (e.g., sea surface temperature) should
have a `level` dimension of size 1.
rng_key: JAX RNG key to use for advancing the state.
Returns:
State advanced one time-step forward.
Expand All @@ -246,7 +250,7 @@ def advance(
sim_time = _sim_time_from_state(state)
forcings = _prepend_dummy_time_axis(forcings)
f = self._structure.forcing_fn(self.params, None, forcings, sim_time)
state = self._structure.advance_fn(self.params, rng_key, state, f)
state = self._structure.advance_fn(self.params, None, state, f)
return state

@jax.jit
Expand Down Expand Up @@ -282,7 +286,6 @@ def unroll(
self,
state: State,
forcings: Forcings,
rng_key: typing.PRNGKeyArray,
*,
steps: int,
timedelta: TimedeltaLike | None = None,
Expand All @@ -292,9 +295,7 @@ def unroll(
Usage:
advanced_state, outputs = model.unroll(
state, forcings, rng_key, steps=N, post_process_fn=model.decode
)
advanced_state, outputs = model.unroll(state, forcings, steps=N)
where `advanced_state` is the advanced model state after `N` steps and
`outputs` is a trajectory of decoded states on pressure-levels with a
Expand All @@ -306,13 +307,12 @@ def unroll(
trajectory. Should include a leading time-axis, but times can be at any
desired granularity compatible with the model (e.g., it should be fine
to supply daily forcing data, even if producing hourly outputs).
rng_key: random key to use for advancing state.
steps: number of time-steps to take.
timedelta: size of each time-step to take, which must be a multiple of the
internal model timestep. By default uses the internal model timestep.
start_with_input: if `True`, outputs are at times `[0, ..., (steps
- 1) * timestep]` relative to the initial time; if `False`, outputs
are at times `[timestep, ..., steps * timestep]`.
start_with_input: if `True`, outputs are at times `[0, ..., (steps - 1) *
timestep]` relative to the initial time; if `False`, outputs are at
times `[timestep, ..., steps * timestep]`.
Returns:
A tuple of the advanced state at time `steps * timestep`, and outputs
Expand Down Expand Up @@ -340,17 +340,16 @@ def compute_slice_fwd(state, forcings):
)

compute_slice = hk.transform(compute_slice_fwd)
return compute_slice.apply(self.params, rng_key, state, forcings)
return compute_slice.apply(self.params, None, state, forcings)

@classmethod
def from_checkpoint(cls, checkpoint: Any) -> PressureLevelModel:
"""Creates a PressureLevelModel from a checkpoint.
Args:
checkpoint: dictionary with keys "model_config_str", "aux_ds_dict"
and "params" that specifies model gin configuration, supplemental
xarray dataset with model-specific static features, and model
parameters.
checkpoint: dictionary with keys "model_config_str", "aux_ds_dict" and
"params" that specifies model gin configuration, supplemental xarray
dataset with model-specific static features, and model parameters.
Returns:
Instance of a `PressureLevelModel` with weights and configuration
Expand All @@ -361,7 +360,8 @@ def from_checkpoint(cls, checkpoint: Any) -> PressureLevelModel:
aux_ds = xarray.Dataset.from_dict(checkpoint['aux_ds_dict'])
data_coords = model_builder.coordinate_system_from_dataset(aux_ds)
model_specs = model_builder.get_model_specs(
data_coords, physics_specs, {xarray_utils.XARRAY_DS_KEY: aux_ds})
data_coords, physics_specs, {xarray_utils.XARRAY_DS_KEY: aux_ds}
)
whirl_model = model_builder.WhirlModel(
coords=model_specs.coords,
dt=model_specs.dt,
Expand Down
7 changes: 6 additions & 1 deletion neuralgcm/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import functools
from typing import Any, Callable, Dict, Optional, Tuple
import zlib

from dinosaur import coordinate_systems
from dinosaur import primitive_equations
from dinosaur import pytree_utils
Expand Down Expand Up @@ -500,6 +502,9 @@ def __call__(
return self.transform_fn(wb_on_sigma.asdict())


# Fixed salt for decoder randomness. `zlib.crc32` returns an unsigned 32-bit
# integer, so this is an arbitrary but deterministic uint32 value.
_DECODER_SALT = zlib.crc32(b'decoder')


@gin.register
class LearnedPrimitiveToWeatherbenchDecoder(PrimitiveToWeatherbenchDecoder):
"""Similar to `PrimitiveToWeatherbenchDecoder` with learned interpolation."""
Expand Down Expand Up @@ -567,7 +572,7 @@ def __call__(
self, inputs: ModelState, forcing: Forcing
) -> DataState:
randomness = self.randomness_fn.unconditional_sample(
hk.next_rng_key()
jax.random.fold_in(inputs.randomness.prng_key, _DECODER_SALT)
)
prognostics = self.perturbation_fn(
inputs=self.coords.with_dycore_sharding(inputs.state),
Expand Down
4 changes: 3 additions & 1 deletion neuralgcm/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,9 @@ def __call__(
inputs: DataState,
forcing: Forcing,
) -> ModelState:
randomness = self.randomness_fn.unconditional_sample(hk.next_rng_key())
randomness = self.randomness_fn.unconditional_sample(
hk.maybe_next_rng_key()
)
wb_state = self.coords.with_physics_sharding(
weatherbench_utils.State(**self.slice_fn(inputs))
)
Expand Down
7 changes: 4 additions & 3 deletions neuralgcm/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def finalize_state(
Returns:
Initialized model state.
"""
x.randomness = self.randomness_fn.unconditional_sample(hk.next_rng_key())
x.randomness = self.randomness_fn.unconditional_sample(
hk.maybe_next_rng_key()
)
x.diagnostics = self.diagnostics_fn(
x, physics_tendencies=None, forcing=forcing)
return x
Expand Down Expand Up @@ -288,8 +290,7 @@ def step_fn(x):

next_state = self.corrector_fn(x.state, pp_tendency, forcing)
# TODO(dkochkov) update stochastic modules to take optional state.
next_randomness = self.randomness_fn.advance(
x.randomness, rng=hk.next_rng_key())
next_randomness = self.randomness_fn.advance(x.randomness)
next_memory = x.state if x.memory is not None else None
next_diagnostics = self.diagnostics_fn(x, pp_tendency)
x_next = ModelState(
Expand Down
Loading

0 comments on commit fb7701b

Please sign in to comment.