
Recompute checkpointing #2471

Open
dham opened this issue Jun 17, 2022 · 1 comment
dham commented Jun 17, 2022

When computing the adjoint, we want revolve-style checkpointing for the use cases where the whole simulation state can't be kept in memory, and recomputing is faster than saving to disk.

There are various algorithms for generating and executing revolve-style checkpoint schedules. See, for example, https://gitlab.inria.fr/adjoint-computation/H-Revolve/-/tree/master.

However, the basic model these follow is that the simulation is an ordered chain of operations, each of which depends only on the output of the previous one. We, by contrast, have a computation graph in which arbitrary forward dependencies are allowed. Nonetheless, the normal usage pattern is that the graph encodes a sequence of timesteps, and most data is used only in the next timestep (or a bounded number of timesteps later).

In order to benefit from the previous work, we therefore need to transform our graph into a single chain of operations. The normal usage pattern would suggest that this should be possible and fairly efficient.

The first stage in this process will be to ask the user to tell us where the timesteps end, by making a function call. This will enable us to label every block variable with the timestep number in which it is created. Input block variables (i.e. those that first appear as a dependency) shouldn't have a timestep number (or should have an out of band value such as -1). We can also gather the blocks into timestep groups on the tape.
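A minimal sketch of this timestep-labelling scheme might look like the following. The `Tape`, `Block` and `Var` classes and the `end_timestep` method are illustrative assumptions for the sketch, not the actual pyadjoint API:

```python
class Var:
    """Stand-in for a block variable."""
    pass


class Block:
    """Stand-in for a tape block with outputs and dependencies."""
    def __init__(self, outputs, dependencies):
        self.outputs = outputs
        self.dependencies = dependencies


class Tape:
    def __init__(self):
        self.blocks = []
        self.timestep_groups = [[]]   # blocks grouped by timestep
        self.current_timestep = 0

    def add_block(self, block):
        self.blocks.append(block)
        self.timestep_groups[-1].append(block)
        # Label every output with the timestep in which it is created.
        for out in block.outputs:
            out.creation_timestep = self.current_timestep
        # Inputs that first appear as dependencies get an out-of-band value.
        for dep in block.dependencies:
            if not hasattr(dep, "creation_timestep"):
                dep.creation_timestep = -1

    def end_timestep(self):
        # The user calls this to tell us where a timestep ends.
        self.current_timestep += 1
        self.timestep_groups.append([])
```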

The next challenge is to determine what needs to be checkpointed. Because we can have arbitrary reuse of old values, this is not as simple as checkpointing all the variables in block $n-1$ which will be used in block $n$. Indeed, we don't even know which variables will be used in block $n$. Instead, with thanks to @jrmaddison for the idea, we defer the creation of checkpoints. While running forward, if the checkpointing algorithm instructs us to checkpoint ("takeshot") at the start of step $n$, we simply record $n$ as the last checkpoint on the tape. Now, for every input variable to every block, we perform the following algorithm. Let $n_0$ be the creation timestep of the block variable $v$, $n_b$ be the timestep of the current block (to which $v$ is an input), $n_c$ be the timestep of the last checkpoint, and let $n_{\textrm{last}}$ be the timestep of the last use of $v$ before this block:

  1. If $n_{\textrm{last}} < n_c$, instruct $v$ to checkpoint itself. Checkpointing is idempotent, so this is safe even if $v$ was already checkpointed.
  2. For $n_{\textrm{last}} < m \le n_b$, add $v$ to the checkpointable state record of timestep $m$ (this ensures that if the tape is rerun with the checkpoints in different places, the correct data is checkpointed).
  3. Set $n_{\textrm{last}} = n_b$.
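The three steps above could be sketched roughly as follows. `BlockVariable`, `process_dependency` and the `checkpointable_state` mapping are hypothetical names used only for this sketch:

```python
class BlockVariable:
    def __init__(self, creation_timestep):
        self.creation_timestep = creation_timestep  # n_0
        self.last_use = creation_timestep           # n_last
        self.checkpointed = False

    def save_checkpoint(self):
        # Idempotent: safe to call repeatedly.
        self.checkpointed = True


def process_dependency(v, n_b, n_c, checkpointable_state):
    """Apply the three steps for an input variable v of a block in
    timestep n_b, with the last checkpoint taken at timestep n_c.

    checkpointable_state maps timestep -> set of variables that must be
    available at the start of that timestep.
    """
    # 1. If the last use predates the last checkpoint, v must save itself.
    if v.last_use < n_c:
        v.save_checkpoint()
    # 2. Record v in the checkpointable state of every intervening timestep.
    for m in range(v.last_use + 1, n_b + 1):
        checkpointable_state.setdefault(m, set()).add(v)
    # 3. Update the last-use record.
    v.last_use = n_b
```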
dham commented Jun 20, 2022

Let's consider the actions in the schedules produced by H-Revolve:

| Operation | Action |
| --- | --- |
| `[F_i]` | Executes the $i$ forward operation |
| `[F_i->j]` | Executes the $i, i+1, \ldots, j-1, j$ forward operations |
| `[B_i]` | Executes the $i$ backward operation |
| `[WD_i]` | Writes the output of the $(i-1)$ forward operation to disk |
| `[RD_i]` | Reads the output of the $(i-1)$ forward operation from disk |
| `[WM_i]` | Writes the output of the $(i-1)$ forward operation to memory |
| `[RM_i]` | Reads the output of the $(i-1)$ forward operation from memory |
| `[DM_i]` | Discards the output of the $(i-1)$ forward operation from memory |

The implementation of these actions is different when recording the tape than when rerunning an existing tape. In record mode, only forward operations and stores can occur:

| Operation | Action |
| --- | --- |
| `[F_i]` | Execute the forward operation; don't automatically checkpoint block outputs |
| `[F_i->j]` | As above |
| `[B_i]` | N/A |
| `[WD_i]` | Set $n_c$ to $i$ and switch on disk checkpointing. Switch to a new disk checkpoint file. |
| `[RD_i]` | N/A |
| `[WM_i]` | Set $n_c$ to $i$ and switch off disk checkpointing |
| `[RM_i]` | N/A |
| `[DM_i]` | N/A |

In record mode, we examine the schedule each time the user increments the timestep.
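The record-mode dispatch might be sketched as below. `RecordState` and `process_record_action` are hypothetical names; a real implementation would also run the forward operations themselves:

```python
class RecordState:
    """Illustrative record-mode bookkeeping."""
    def __init__(self):
        self.n_c = -1                    # timestep of the last checkpoint
        self.disk_checkpointing = False
        self.disk_file_index = -1        # which disk checkpoint file is active


def process_record_action(state, action, i):
    """Handle one schedule action while recording the tape."""
    if action in ("F", "F->"):
        # Forward operations execute normally; block outputs are not
        # automatically checkpointed.
        pass
    elif action == "WD":
        state.n_c = i
        state.disk_checkpointing = True
        state.disk_file_index += 1       # switch to a new checkpoint file
    elif action == "WM":
        state.n_c = i
        state.disk_checkpointing = False
    else:
        # B, RD, RM and DM cannot occur while recording.
        raise ValueError(f"action {action!r} is invalid in record mode")
```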

In replay mode, we need to maintain a set $l$ of currently live block variables. These block variables have their memory checkpoints set in order to maintain the state of the simulation, but do not constitute a persistent checkpoint. The following operations occur:

| Operation | Action |
| --- | --- |
| `[F_i]` | Execute the $i$ forward operation, memory checkpointing as we go and adding all output block variables to $l$. At the end of the step, discard from $l$, and clear the memory checkpoints of, the block variables that aren't part of the checkpointable state. |
| `[F_i->j]` | As above. |
| `[B_i]` | As `[F_i]`, but don't discard from $l$ yet. Execute the reverse operation (adjoint, or adjoint plus Hessian). Then discard from $l$ the forward and reverse values that are not in the checkpointable state at $i-1$. |
| `[WD_i]` | For each item in the checkpointable state at the start of step $i$, copy its checkpoint to disk. |
| `[RD_i]` | For each item in the checkpointable state at the start of step $i$, read its checkpoint from disk. |
| `[WM_i]` | Clear $l$ (thus making the current memory checkpoints permanent). |
| `[RM_i]` | No-op (libadjoint reads memory checkpoints on demand). |
| `[DM_i]` | Discard from $l$, and clear the memory checkpoints of, the block variables that are in the checkpointable state at $i$ but not at $i-1$. |
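A rough sketch of the live set $l$ for the `[WM_i]` and `[DM_i]` actions; `BV`, `write_memory` and `discard_memory` are illustrative names, not the actual implementation:

```python
class BV:
    """Minimal stand-in for a block variable with a memory checkpoint."""
    def __init__(self):
        self.mem_checkpoint = object()   # placeholder checkpoint payload

    def clear_memory_checkpoint(self):
        self.mem_checkpoint = None


def write_memory(live):
    """[WM_i]: clear l, making the current memory checkpoints permanent."""
    live.clear()


def discard_memory(live, checkpointable_state, i):
    """[DM_i]: drop variables in the checkpointable state at step i
    but not at step i-1, clearing their memory checkpoints."""
    to_drop = checkpointable_state[i] - checkpointable_state[i - 1]
    for v in to_drop & live:
        v.clear_memory_checkpoint()
    live -= to_drop
```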
