Skip to content

Commit

Permalink
trajectory_with_inputs_and_forcing_and_stop_gradients added to model_…
Browse files Browse the repository at this point in the history
…utils.py

Does not slow down TL63. However, due to the more complex set of operations, it will be turned off by default in `experiment.py` (see http://cl/619948462). Timing
* TL63, without this, 10.6 sec/step http://screen/AtDbXzYGy4F3wXS
* TL63, with this, 10.6 sec/step  http://screen/B7juZUCGWXHafN7

See also shoyer's http://cl/522829748, which adds stop grad steps at a lower level.

PiperOrigin-RevId: 620896532
  • Loading branch information
langmore authored and NeuralGCM authors committed Apr 9, 2024
1 parent 0496d29 commit 009425a
Show file tree
Hide file tree
Showing 2 changed files with 420 additions and 1 deletion.
121 changes: 120 additions & 1 deletion neuralgcm/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import dataclasses
import functools
from typing import Any, Callable, Tuple
from typing import Any, Callable, Sequence, Tuple
from dinosaur import coordinate_systems
from dinosaur import pytree_utils
from dinosaur import typing
Expand All @@ -27,6 +27,9 @@
DynamicalSystem = Any # to prevent circular dependency on model_builder
Pytree = typing.Pytree

tree_map = jax.tree_util.tree_map
tree_leaves = jax.tree_util.tree_leaves

# Linter confused by wrapped functions
# pylint: disable=g-bare-generic

Expand Down Expand Up @@ -170,6 +173,122 @@ def _trajectory_fn(x, forcing_data, outer_steps, inner_steps=1):
return _trajectory_fn


def trajectory_with_inputs_and_forcing_and_stop_gradients(
model: DynamicalSystem,
num_init_frames: int,
start_with_input: bool = False,
stop_gradient_outer_steps: Sequence[int] = (),
) -> typing.TrajectoryFn:
"""Returns trajectory_fn that comuptes model trajectory from target data.
This extension of `trajectory_with_inputs_and_forcing` allows adding stop
gradients to the trajectory at designated steps. For example, if
`stop_gradient_outer_steps = [2]`, then gradients along the trajectory stop
at t=2. This does not mean that gradients with respect to X[2] will be zero.
It simply means that, for t > 2, gradients of X[t] with respect to X[2] will
be zero.
Wraps the default model.trajectory_fn to operate on data representation. It
corresponds to slicing `num_init_frames` from the inputs, encoding and
unrolling the trajectory.
Args:
model: model of a dynamical system used to obtain the trajectory.
num_init_frames: number of time frames used from the physics trajectory to
initialize the model state.
start_with_input: whether the firest decoded step in the output trajectory
should correspond to last input time or first future output.
stop_gradient_outer_steps: Tuple (possibly empty) indicating outer steps at
which to place stop gradients.
Returns:
Trajectory function that operates on target data trajectory by encoding
the `initial_frames` inputs and unrolls trajectory in a model space.
Decoding is not done by this function.
"""
stop_gradient_outer_steps = list(sorted(stop_gradient_outer_steps))
if num_init_frames != 1:
raise ValueError(f'{num_init_frames=} is not supported yet.')

if stop_gradient_outer_steps and min(stop_gradient_outer_steps) <= 0:
raise ValueError(
f'{stop_gradient_outer_steps=} contained non-positive values'
)

expand_dim0 = lambda tree: tree_map(lambda x_i: x_i[jnp.newaxis], tree)
concat_dim0 = lambda trees: pytree_utils.concat_along_axis(trees, axis=0)
slice_dim0 = lambda tree, idx: pytree_utils.slice_along_axis(
tree, axis=0, idx=idx
)

def concat_trajectories_with_stop_grads(
x, forcing_data, outer_steps, inner_steps=1
):
if (
stop_gradient_outer_steps
and max(stop_gradient_outer_steps) > outer_steps
):
raise ValueError(
f'{stop_gradient_outer_steps=} contained values > {outer_steps=}'
)
outer_steps_seq = list(stop_gradient_outer_steps)
if not outer_steps_seq or outer_steps_seq[-1] != outer_steps:
outer_steps_seq.append(outer_steps)

# The first leg needs to encode the input. So use
# trajectory_with_inputs_and_forcing, which does the encoding.
final_state, first_leg = trajectory_with_inputs_and_forcing(
model,
num_init_frames=num_init_frames,
start_with_input=start_with_input,
)(
x,
forcing_data=forcing_data,
outer_steps=outer_steps_seq[0],
inner_steps=inner_steps,
)

# At this point, sections contains times [0, ..., outer_steps_seq[0]]
sections = [
first_leg,
]

# Subsequent legs do not need encoding, so use model.trajectory directly.
trajectory_fn = functools.partial(
model.trajectory,
inner_steps=inner_steps,
forcing_data=forcing_data,
start_with_input=start_with_input,
)
for i in range(1, len(outer_steps_seq)):
# outer_steps_seq[-1] may or may not be in stop_gradient_outer_steps.
# The other steps will be by construction.
assert set(outer_steps_seq[:-1]).issubset(stop_gradient_outer_steps)
stop_grad_at_start = outer_steps_seq[i - 1] in stop_gradient_outer_steps

initial_state = final_state

# this_leg contains times [outer_steps_seq[0]+1, ..., outer_steps_seq[1]]
final_state, this_leg = trajectory_fn(
jax.lax.stop_gradient(initial_state)
if stop_grad_at_start
else initial_state,
outer_steps=outer_steps_seq[i] - outer_steps_seq[i - 1],
)

if stop_grad_at_start and start_with_input:
# Replace the initial point that had a stop gradient on it.
this_leg = concat_dim0([
expand_dim0(initial_state),
slice_dim0(this_leg, idx=slice(1, None)),
])
sections.append(this_leg)

return final_state, concat_dim0(sections)

return concat_trajectories_with_stop_grads


def decoded_trajectory_with_forcing(
model: DynamicalSystem,
start_with_input: bool = False,
Expand Down
Loading

0 comments on commit 009425a

Please sign in to comment.