diff --git a/neuralgcm/api.py b/neuralgcm/api.py
index 0a38947..442cd85 100644
--- a/neuralgcm/api.py
+++ b/neuralgcm/api.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Public API for NeuralGCM models."""
 from __future__ import annotations
+
 import datetime
 import functools
 from typing import Any
@@ -75,10 +76,12 @@ def _prepend_dummy_time_axis(state: typing.Pytree) -> typing.Pytree:

 def _static_gin_config(method):
   """Decorator to add static gin config to a method."""
+
   @functools.wraps(method)
   def _method(self, *args, **kwargs):
     with gin_utils.specific_config(self.gin_config):
       return method(self, *args, **kwargs)
+
   return _method

@@ -185,7 +188,7 @@ def data_to_xarray(
         'geopotential': 'z',
         'temperature': 't',
         'longitude': 'lon',
-        'latitude': 'lat'
+        'latitude': 'lat',
     }
     return xarray_utils.data_to_xarray_with_renaming(
         data,
@@ -198,7 +201,10 @@
   @jax.jit
   @_static_gin_config
   def encode(
-      self, inputs: Inputs, forcings: Forcings, rng_key: typing.PRNGKeyArray
+      self,
+      inputs: Inputs,
+      forcings: Forcings,
+      rng_key: typing.PRNGKeyArray | None = None,
   ) -> State:
     """Encode from pressure-level inputs & forcings to model state.

@@ -209,7 +215,8 @@
         an array with shape `[level, longitude, latitude]` matching
         `data_coords`. Single level data (e.g., sea surface temperature)
         should have a `level` dimension of size 1.
-      rng_key: JAX RNG key to use for encoding the state.
+      rng_key: optional JAX RNG key to use for encoding the state. Required if
+        using stochastic models, otherwise ignored.

     Returns:
       Dynamical core state on sigma levels, where all arrays have dimensions
@@ -224,9 +231,7 @@

   @jax.jit
   @_static_gin_config
-  def advance(
-      self, state: State, forcings: Forcings, rng_key: typing.PRNGKeyArray
-  ) -> State:
+  def advance(self, state: State, forcings: Forcings) -> State:
     """Advance model state one timestep forward.

     Args:
@@ -237,7 +242,6 @@
         an array with shape `[level, longitude, latitude]` matching
         `data_coords`. Single level data (e.g., sea surface temperature)
         should have a `level` dimension of size 1.
-      rng_key: JAX RNG key to use for advancing the state.

     Returns:
       State advanced one time-step forward.
@@ -246,7 +250,7 @@
     sim_time = _sim_time_from_state(state)
     forcings = _prepend_dummy_time_axis(forcings)
     f = self._structure.forcing_fn(self.params, None, forcings, sim_time)
-    state = self._structure.advance_fn(self.params, rng_key, state, f)
+    state = self._structure.advance_fn(self.params, None, state, f)
     return state

   @jax.jit
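After this change, only `encode` consumes an explicit RNG key; `advance` (and `unroll`, below) derive any randomness from the key threaded through the model state. A sketch of the resulting call pattern, assuming a `model` loaded via `from_checkpoint` and `inputs`/`forcings` prepared to match `model.data_coords` (hypothetical setup, illustrative only):

```python
import jax

# Stochastic models need a key at encode time; deterministic models may
# omit it now that `rng_key` defaults to None.
state = model.encode(inputs, forcings, rng_key=jax.random.PRNGKey(0))

# No rng_key arguments anymore: randomness is carried inside `state` and
# advanced deterministically at each step.
state = model.advance(state, forcings)
final_state, outputs = model.unroll(state, forcings, steps=4)
```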
@@ -282,7 +286,6 @@ def unroll(
       self,
       state: State,
       forcings: Forcings,
-      rng_key: typing.PRNGKeyArray,
       *,
       steps: int,
       timedelta: TimedeltaLike | None = None,
@@ -292,9 +295,7 @@

     Usage:

-      advanced_state, outputs = model.unroll(
-          state, forcings, rng_key, steps=N, post_process_fn=model.decode
-      )
+      advanced_state, outputs = model.unroll(state, forcings, steps=N)

     where `advanced_state` is the advanced model state after `N` steps and
     `outputs` is a trajectory of decoded states on pressure-levels with a
@@ -306,13 +307,12 @@
         trajectory. Should include a leading time-axis, but times can be at
         any desired granularity compatible with the model (e.g., it should
         be fine to supply daily forcing data, even if producing hourly
         outputs).
-      rng_key: random key to use for advancing state.
      steps: number of time-steps to take.
      timedelta: size of each time-step to take, which must be a multiple of
        the internal model timestep. By default uses the internal model
        timestep.
-      start_with_input: if `True`, outputs are at times `[0, ..., (steps -
-        1) * timestep]` relative to the initial time; if `False`, outputs
-        are at times `[timestep, ..., steps * timestep]`.
+      start_with_input: if `True`, outputs are at times `[0, ..., (steps - 1) *
+        timestep]` relative to the initial time; if `False`, outputs are at
+        times `[timestep, ..., steps * timestep]`.

     Returns:
       A tuple of the advanced state at time `steps * timestamp`, and outputs
@@ -340,17 +340,16 @@ def compute_slice_fwd(state, forcings):
       )

     compute_slice = hk.transform(compute_slice_fwd)
-    return compute_slice.apply(self.params, rng_key, state, forcings)
+    return compute_slice.apply(self.params, None, state, forcings)

   @classmethod
   def from_checkpoint(cls, checkpoint: Any) -> PressureLevelModel:
     """Creates a PressureLevelModel from a checkpoint.

     Args:
-      checkpoint: dictionary with keys "model_config_str", "aux_ds_dict"
-        and "params" that specifies model gin configuration, supplemental
-        xarray dataset with model-specific static features, and model
-        parameters.
+      checkpoint: dictionary with keys "model_config_str", "aux_ds_dict" and
+        "params" that specifies model gin configuration, supplemental xarray
+        dataset with model-specific static features, and model parameters.

     Returns:
       Instance of a `PressureLevelModel` with weights and configuration
@@ -361,7 +360,8 @@ def from_checkpoint(cls, checkpoint: Any) -> PressureLevelModel:
     aux_ds = xarray.Dataset.from_dict(checkpoint['aux_ds_dict'])
     data_coords = model_builder.coordinate_system_from_dataset(aux_ds)
     model_specs = model_builder.get_model_specs(
-        data_coords, physics_specs, {xarray_utils.XARRAY_DS_KEY: aux_ds})
+        data_coords, physics_specs, {xarray_utils.XARRAY_DS_KEY: aux_ds}
+    )
     whirl_model = model_builder.WhirlModel(
         coords=model_specs.coords,
         dt=model_specs.dt,
diff --git a/neuralgcm/decoders.py b/neuralgcm/decoders.py
index 13c4a1e..97b95fc 100644
--- a/neuralgcm/decoders.py
+++ b/neuralgcm/decoders.py
@@ -15,6 +15,8 @@
 import functools
 from typing import Any, Callable, Dict, Optional, Tuple
+import zlib
+
 from dinosaur import coordinate_systems
 from dinosaur import primitive_equations
 from dinosaur import pytree_utils
@@ -500,6 +502,9 @@ def __call__(
     return self.transform_fn(wb_on_sigma.asdict())


+_DECODER_SALT = zlib.crc32(b'decoder')  # arbitrary int32 value
+
+
 @gin.register
 class LearnedPrimitiveToWeatherbenchDecoder(PrimitiveToWeatherbenchDecoder):
   """Similar to `PrimitiveToWeatherbenchDecoder` with learned interpolation."""
@@ -567,7 +572,7 @@ def __call__(
       self, inputs: ModelState, forcing: Forcing
   ) -> DataState:
     randomness = self.randomness_fn.unconditional_sample(
-        hk.next_rng_key()
+        jax.random.fold_in(inputs.randomness.prng_key, _DECODER_SALT)
     )
     prognostics = self.perturbation_fn(
         inputs=self.coords.with_dycore_sharding(inputs.state),
diff --git a/neuralgcm/encoders.py b/neuralgcm/encoders.py
index e0e4a0c..316d308 100644
--- a/neuralgcm/encoders.py
+++ b/neuralgcm/encoders.py
@@ -599,7 +599,9 @@ def __call__(
       inputs: DataState,
       forcing: Forcing,
   ) -> ModelState:
-    randomness = self.randomness_fn.unconditional_sample(hk.next_rng_key())
+    randomness = self.randomness_fn.unconditional_sample(
+        hk.maybe_next_rng_key()
+    )
     wb_state = self.coords.with_physics_sharding(
         weatherbench_utils.State(**self.slice_fn(inputs))
     )
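The encoder's switch from `hk.next_rng_key()` to `hk.maybe_next_rng_key()` is what makes the key optional: inside a transformed function, the latter returns a key when one was passed to `apply` and `None` otherwise, rather than raising. A minimal standalone illustration (toy function, not the NeuralGCM encoder):

```python
import haiku as hk
import jax

def forward():
  # Returns a PRNG key if `apply` received one, else None.
  return hk.maybe_next_rng_key()

f = hk.transform(forward)
params = f.init(jax.random.PRNGKey(0))
assert f.apply(params, None) is None          # deterministic call, no key
key = f.apply(params, jax.random.PRNGKey(1))  # stochastic call, key present
```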
""" - x.randomness = self.randomness_fn.unconditional_sample(hk.next_rng_key()) + x.randomness = self.randomness_fn.unconditional_sample( + hk.maybe_next_rng_key() + ) x.diagnostics = self.diagnostics_fn( x, physics_tendencies=None, forcing=forcing) return x @@ -288,8 +290,7 @@ def step_fn(x): next_state = self.corrector_fn(x.state, pp_tendency, forcing) # TODO(dkochkov) update stochastic modules to take optional state. - next_randomness = self.randomness_fn.advance( - x.randomness, rng=hk.next_rng_key()) + next_randomness = self.randomness_fn.advance(x.randomness) next_memory = x.state if x.memory is not None else None next_diagnostics = self.diagnostics_fn(x, pp_tendency) x_next = ModelState( diff --git a/neuralgcm/stochastic.py b/neuralgcm/stochastic.py index b2d87f1..a12db97 100644 --- a/neuralgcm/stochastic.py +++ b/neuralgcm/stochastic.py @@ -16,7 +16,8 @@ import abc import enum import logging -from typing import Any, Callable, Optional, Sequence, Union +from typing import Any, Callable, Optional, Sequence, TypeVar, Union +import zlib from dinosaur import coordinate_systems from dinosaur import typing @@ -83,19 +84,11 @@ def preferred_representation(self) -> PreferredRepresentation | None: """The PreferredRepresentation for this field, or None if no preference.""" @abc.abstractmethod - def unconditional_sample( - self, - rng: typing.PRNGKeyArray, - ) -> RandomnessState: + def unconditional_sample(self, rng: typing.PRNGKeyArray) -> RandomnessState: """Sample the random field unconditionally.""" @abc.abstractmethod - def advance( - self, - state: RandomnessState, - *, - rng: typing.PRNGKeyArray, - ) -> RandomnessState: + def advance(self, state: RandomnessState) -> RandomnessState: """Updates the core state of a random field.""" @abc.abstractmethod @@ -110,6 +103,18 @@ def to_nodal_values(self, core_state: CoreRandomState) -> typing.Array | None: RandomnessModule = Callable[..., RandomField] +_RANDOM_FIELD_SALT = zlib.crc32(b'random_field') # arbitrary int32 value + + +T = TypeVar('T', typing.PRNGKeyArray, None) + + +def _next_prng_key(key: T) -> T: + if key is None: + return None + return jax.random.fold_in(key, _RANDOM_FIELD_SALT) + + @gin.register class NoRandomField(RandomField): """Module that disables randomness in a given module returning `None`.""" @@ -140,22 +145,14 @@ def preferred_representation(self) -> PreferredRepresentation | None: return None def unconditional_sample( - self, - rng: typing.PRNGKeyArray, + self, rng: typing.PRNGKeyArray | None ) -> RandomnessState: """Returns a zeros initialized state.""" - del rng # unused - return RandomnessState() # default to None elements. + return RandomnessState(prng_key=rng) - def advance( - self, - state: RandomnessState, - *, - rng: typing.PRNGKeyArray, - ) -> RandomnessState: + def advance(self, state: RandomnessState) -> RandomnessState: """Updates the state of a random gaussian field.""" - del state, rng # unused. - return RandomnessState() # defaults to None elements. + return RandomnessState(prng_key=_next_prng_key(state.prng_key)) def to_nodal_values(self, core_state: CoreRandomState) -> typing.Array | None: del core_state # unused. 
@@ -203,10 +200,9 @@ def preferred_representation(self) -> PreferredRepresentation | None:

   def unconditional_sample(
       self,
-      rng: typing.PRNGKeyArray,
+      rng: typing.PRNGKeyArray | None
   ) -> RandomnessState:
     """Returns a zeros initialized state."""
-    del rng  # unused
     if self._prefer_nodal:
       core = jnp.zeros(self.coords.horizontal.nodal_shape)
     else:
@@ -215,21 +211,17 @@
        core=core,
        nodal_value=jnp.zeros(self.coords.horizontal.nodal_shape),
        modal_value=jnp.zeros(self.coords.horizontal.modal_shape),
+       prng_key=rng,
    )

-  def advance(
-      self,
-      state: RandomnessState,
-      *,
-      rng: typing.PRNGKeyArray,
-  ) -> RandomnessState:
+  def advance(self, state: RandomnessState) -> RandomnessState:
     """Updates the state of a random gaussian field."""
-    del rng  # unused
     _validate_randomness_state(state)
     return RandomnessState(
         core=jnp.zeros_like(state.core),
         nodal_value=jnp.zeros(self.coords.horizontal.nodal_shape),
         modal_value=jnp.zeros(self.coords.horizontal.modal_shape),
+        prng_key=_next_prng_key(state.prng_key),
     )

   def to_nodal_values(self, core_state: CoreRandomState) -> typing.Array | None:
@@ -355,17 +347,16 @@ def _sigma_array(self) -> jax.Array:
     # have L2 norm = radius. See http://screen/9FYVXZ5cMHoGDZk
     return normalization * sigmas_unnormed / self.coords.horizontal.radius

-  def unconditional_sample(
-      self,
-      rng: typing.PRNGKeyArray,
-  ) -> RandomnessState:
+  def unconditional_sample(self, rng: typing.PRNGKeyArray) -> RandomnessState:
     """Returns a randomly initialized state for the autoregressive process."""
     modal_shape = self.coords.horizontal.modal_shape
+    rng, next_rng = jax.random.split(rng)
     if self.variance is None:
       return RandomnessState(
           core=jnp.zeros(modal_shape),
           nodal_value=jnp.zeros(self.coords.horizontal.nodal_shape),
           modal_value=jnp.zeros(modal_shape),
+          prng_key=next_rng,
       )
     sigmas = self._sigma_array()
     weights = jnp.where(
@@ -378,21 +369,19 @@
        core=core,
        nodal_value=self.to_nodal_values(core),
        modal_value=self.to_modal_values(core),
+       prng_key=next_rng,
    )

-  def advance(
-      self,
-      state: RandomnessState,
-      *,
-      rng: typing.PRNGKeyArray,
-  ) -> RandomnessState:
+  def advance(self, state: RandomnessState) -> RandomnessState:
     """Updates the CoreRandomState of a random gaussian field."""
     _validate_randomness_state(state)
+    rng, next_rng = jax.random.split(state.prng_key)
     if self.variance is None:
       return RandomnessState(
           core=jnp.zeros_like(state.core),
           nodal_value=jnp.zeros(self.coords.horizontal.nodal_shape),
           modal_value=jnp.zeros(self.coords.horizontal.modal_shape),
+          prng_key=next_rng,
       )
     modal_shape = self.coords.horizontal.modal_shape
     eta = jax.random.truncated_normal(rng, -self.clip, self.clip, modal_shape)
@@ -403,6 +392,7 @@
        core=next_core,
        nodal_value=self.to_nodal_values(next_core),
        modal_value=self.to_modal_values(next_core),
+       prng_key=next_rng,
    )

   @property
@@ -722,10 +712,7 @@ def make_rf(correlation_time, correlation_length, variance):
   def n_fields(self) -> int:
     return self._n_fields

-  def unconditional_sample(
-      self,
-      rng: typing.PRNGKeyArray,
-  ) -> RandomnessState:
+  def unconditional_sample(self, rng: typing.PRNGKeyArray) -> RandomnessState:
     """Sample the batch GRFs unconditionally."""
     logging.info(
         '[NGCM] Calling BatchGaussianRandomFieldModule.unconditional_sample'
     )
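Because each per-field state now carries its own `prng_key` (seeded by `unconditional_sample` above), the batched `advance` below can `jax.vmap` directly over the state pytree with no separate array of step keys. A toy version of that pattern (illustrative names, not the module's API):

```python
import jax
import jax.numpy as jnp

def advance_one(core, key):
  # One AR(1)-style update, driven by the key carried next to the core.
  key, next_key = jax.random.split(key)
  return 0.9 * core + jax.random.normal(key, core.shape), next_key

cores = jnp.zeros((3, 8))                          # batch of 3 fields
keys = jax.random.split(jax.random.PRNGKey(0), 3)  # one key per field
new_cores, new_keys = jax.vmap(advance_one)(cores, keys)  # no shared rng
```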
@@ -744,22 +731,16 @@ def _unconditional_sample_one_rf(
         self._variances,
     )

-  def advance(
-      self,
-      state: RandomnessState,
-      *,
-      rng: typing.PRNGKeyArray,
-  ) -> RandomnessState:
+  def advance(self, state: RandomnessState) -> RandomnessState:
     """Updates the state of the batch of GRFs."""
     logging.info('[NGCM] Calling BatchGaussianRandomFieldModule.advance')

-    def _advance_one_rf(s, key, correlation_time, correlation_length, variance):
+    def _advance_one_rf(state, correlation_time, correlation_length, variance):
       rf = self._make_rf(correlation_time, correlation_length, variance)
-      return rf.advance(state=s, rng=key)
+      return rf.advance(state)

     return jax.vmap(_advance_one_rf)(
         state,
-        jax.random.split(rng, self.n_fields),
         self._correlation_times,
         self._correlation_lengths,
         self._variances,
     )
@@ -886,9 +867,8 @@ def unconditional_sample(
     core = {}
     nodal_values = {}
     modal_values = {}
-    for (name, rf), sample_key in zip(
-        self._random_fields.items(), jax.random.split(rng, self.n_fields)
-    ):
+    *rngs, next_rng = jax.random.split(rng, self.n_fields + 1)
+    for (name, rf), sample_key in zip(self._random_fields.items(), rngs):
       rvs = rf.unconditional_sample(sample_key)
       core[name] = rvs.core
       nodal_values[name] = rvs.nodal_value
@@ -897,21 +877,16 @@
        core=core,
        nodal_value=nodal_values,
        modal_value=modal_values,
+       prng_key=next_rng,
    )

-  def advance(
-      self,
-      state: RandomnessState,
-      *,
-      rng: typing.PRNGKeyArray,
-  ) -> RandomnessState:
+  def advance(self, state: RandomnessState) -> RandomnessState:
     """Updates the core state of a random field."""
     core = {}
     nodal_values = {}
     modal_values = {}
-    for (name, rf), sample_key in zip(
-        self._random_fields.items(), jax.random.split(rng, self.n_fields)
-    ):
+    *rngs, next_rng = jax.random.split(state.prng_key, self.n_fields + 1)
+    for (name, rf), sample_key in zip(self._random_fields.items(), rngs):
       # rvs is a RandomnessState.
-      rvs = rf.advance(RandomnessState(state.core[name]), rng=sample_key)
+      rvs = rf.advance(RandomnessState(state.core[name], prng_key=sample_key))
       core[name] = rvs.core
@@ -921,6 +896,7 @@
        core=core,
        nodal_value=nodal_values,
        modal_value=modal_values,
+       prng_key=next_rng,
    )

@@ -965,17 +941,11 @@ def unconditional_sample(
         modal_value=self.to_modal_values(rvs),
     )

-  def advance(
-      self,
-      state: RandomnessState,
-      *,
-      rng: typing.PRNGKeyArray,
-  ) -> RandomnessState:
+  def advance(self, state: RandomnessState) -> RandomnessState:
     """Updates the core state of a random field."""
     rvs = []
     for rf, s in zip(self._random_fields, state.core, strict=True):
-      rng, sample_key = jax.random.split(rng)
-      rvs.append(rf.advance(RandomnessState(s), rng=sample_key).core)
+      rvs.append(rf.advance(RandomnessState(s)).core)
     return RandomnessState(
         core=rvs,
         nodal_value=self.to_nodal_values(rvs),
diff --git a/neuralgcm/stochastic_test.py b/neuralgcm/stochastic_test.py
index 8fa920c..9f3e78b 100644
--- a/neuralgcm/stochastic_test.py
+++ b/neuralgcm/stochastic_test.py
@@ -505,7 +505,6 @@ def test_stats(
     @hk.transform
     def make_field_trajectory(key):
-      init_key, key = jax.random.split(key)
       grf = self._make_grf(
           variances=variances,
           initial_correlation_lengths=initial_correlation_lengths,
@@ -513,16 +512,14 @@
           # Do not specify the field names... Let the default naming happen.
           field_subset=field_subset,
       )
-      sample = grf.unconditional_sample(init_key)
+      sample = grf.unconditional_sample(key)

-      step_keys = jax.random.split(key, unroll_length)
-
-      def step_fn(c, x):
-        next_c = grf.advance(state=c, rng=x)
+      def step_fn(c, _):
+        next_c = grf.advance(c)
         next_output = next_c.nodal_value
         return (next_c, next_output)

-      _, trajectory = jax.lax.scan(step_fn, sample, xs=step_keys)
+      _, trajectory = jax.lax.scan(step_fn, sample, length=unroll_length)

       return sample, jax.device_get(trajectory)

     n_samples = 2000
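The test change shows the payoff of carrying keys in the state: `jax.lax.scan` no longer threads a `step_keys` array through `xs`, it just runs for a fixed `length` while the key advances inside the carry. A minimal standalone version of the same unrolling pattern (toy state and update rule):

```python
import jax
import jax.numpy as jnp

def step(carry, _):
  value, key = carry
  key, subkey = jax.random.split(key)  # the key advances inside the carry
  value = value + jax.random.normal(subkey, value.shape)
  return (value, key), value

init = (jnp.zeros(4), jax.random.PRNGKey(0))
_, trajectory = jax.lax.scan(step, init, xs=None, length=10)
# trajectory.shape == (10, 4); no per-step key array is needed.
```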