
Model checkpoint load and store logic #93

albertz opened this issue Jan 18, 2022

So far this was not really considered; I guess the idea was to keep this logic separate and use the standard RETURNN mechanisms (model, load, preload_from_files, etc.).

Together with (potentially custom) parameter initialization (#59, #92), we might unify this.
Custom param updates or assignments (#90) might be related, but maybe not.

Specifically, in PyTorch, each Module has the functions load_state_dict and state_dict, and the state dict is saved to disk via torch.save and loaded via torch.load (tutorial).
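For illustration, a minimal sketch of that save/load cycle using the standard torch API (the filename is just a placeholder):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)

# Save: state_dict() returns a dict of parameter tensors,
# which torch.save serializes to disk.
torch.save(model.state_dict(), "model.pt")

# Load: torch.load deserializes the dict,
# and load_state_dict assigns it to a compatible module.
model.load_state_dict(torch.load("model.pt"))
```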

In current RETURNN, this is complicated to do. model just specifies the checkpoint filename for the root module, which implies that there must be one specific root module. Loading other submodules with other parameters can be done via preload_from_files, matched by name prefix. For saving, RETURNN saves the whole model (including all submodules) into the given checkpoint.
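For reference, a minimal sketch of these existing RETURNN config options (paths are placeholders; see the RETURNN docs for the exact preload_from_files semantics):

```python
# RETURNN config (sketch). `model` is the checkpoint path used
# for saving (and for loading, on a continued run).
model = "/path/to/model"

# Load parameters for specific submodules from other checkpoints,
# matched by layer-name prefix.
preload_from_files = {
    "encoder": {
        "filename": "/path/to/encoder-checkpoint",
        "prefix": "encoder_",
    },
}
```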

Also related is how the root module is defined (#44).


Potential API: Module.checkpoint_filename

A checkpoint_filename attribute on a Module. This would specify the checkpoint filename for load and save, for all the params in the module and all its submodules. It would otherwise rely on the standard RETURNN logic.

Submodules could also define checkpoint_filename, which would override the parent's filename. Those parameters would then also be excluded from the parent module's checkpoint.

This also would no longer require one single root module.
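A rough sketch of how this might look (checkpoint_filename is the proposed attribute and does not exist yet; Encoder and the filenames are made up for the example):

```python
from returnn_common import nn


class Encoder(nn.Module):
    ...  # some submodule, made up for this example


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        # Proposed attribute (hypothetical): the encoder params would be
        # loaded/saved via this file and excluded from the parent checkpoint.
        self.encoder.checkpoint_filename = "encoder.ckpt"


model = Model()
# All remaining params, i.e. everything not claimed by a submodule.
model.checkpoint_filename = "model.ckpt"
```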

However, it is probably not always wanted that all modules get saved after an epoch. Maybe some (sub)module was only loaded for initialization (either directly, or indirectly as in #92 (comment)). This API would not really allow for that.

Discussion

All the iterative-looking logic currently being implemented, like

```python
y = ...
loss = cross_entropy(...)
loss.mark_as_loss()
```

defines what is being calculated per step, so it assumes one outer loop over the steps.
You can also see this more as a model definition: what is actually done per step is the model update in training via the optimizer, which is kept separate here. (The optimizer is also not really part of returnn-common yet.)

In any case, mixing this with logic that runs at initialization (which also depends on whether this is a new run starting at epoch 1, or a continued run loading a previous checkpoint), and with logic that runs per epoch, can cause confusion.

Maybe we should make this logic more explicit, i.e. define which part of the computation is executed in which context. This could be similar to our Loop (#16) logic, with context managers, like:

```python
model = Model(...)

with nn.init_ctx():
  ...

with nn.step_loop():
  ...

with nn.epoch_loop():
  ...
```

Edit: This proposal on the explicit training loop and stages was moved to a separate issue: #96
