PyTorch pretraining, train_step_callback, staged training #1447
For PyTorch I just handled the pre-training in the model definition itself. This of course means I already construct all the layers from the beginning even if they are not used yet, but this did not require any extra work and allows for much more complex schemes (and also step-based instead of epoch-based). The only thing I am missing right now is the functionality to change e.g. batch_size dynamically per epoch.
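To illustrate the idea — this is not code from the thread, just a minimal sketch with made-up names (`StagedModel`, `grow_every`): all layers are constructed up front, and the current train step decides how many of them are actually used.

```python
import torch
from torch import nn


class StagedModel(nn.Module):
    """Builds all layers up front; uses only a prefix of them depending on the step."""

    def __init__(self, num_layers: int = 12, dim: int = 512):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True)
            for _ in range(num_layers)
        )
        # Assumed schedule: start with 4 active layers, activate one more every 10k steps.
        self.initial_layers = 4
        self.grow_every = 10_000

    def forward(self, x: torch.Tensor, global_train_step: int) -> torch.Tensor:
        num_active = min(
            len(self.layers),
            self.initial_layers + global_train_step // self.grow_every,
        )
        for layer in self.layers[:num_active]:
            x = layer(x)
        return x
```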
That's certainly also an option. But it's also limited in other ways, although you could combine it with the mechanism I was proposing here. E.g. I often also have grown the layer dimensions; this would be tricky and/or inefficient in the way you describe. Your option would also take a bit more GPU memory than necessary — e.g. my initial models were really tiny, a 2-layer Conformer or so with small dimensions. Also, as you say, there is changing any other config option, like batch size, or the dataset itself, etc. In TF, we would always have reconstructed the whole model. The proposed API here would avoid that. Those callbacks would also allow changing config options, similar to what we did in TF (although the API would look different). I'm leaning towards the last option, to have such a callback.

Currently I'm thinking about using such a return value for the callback:

```python
from typing import Optional, Union, Dict, Any

import torch
import returnn.frontend as rf


class StepCallbackReturn:
    """
    Return value of the step callback.
    """

    def __init__(
        self,
        *,
        updated_model: Optional[Union[rf.Module, torch.nn.Module]] = None,
        config_updates: Optional[Dict[str, Any]] = None,
    ):
        """
        :param updated_model: Set this in case the model was updated.
            It can be the same instance as before or a new instance.
            You should set this when you modify the model in any way,
            such that the engine can recreate any wrapper objects
            (e.g. the DDP wrapped module, or the RF wrapped module).
        :param config_updates: Can include learning_rate, batch_size, etc.
        """
        self.updated_model = updated_model
        self.config_updates = config_updates
```

It could also be extended later by more logic.
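A minimal usage sketch, assuming the engine calls a user-defined `train_step_callback` with the model, epoch and step — the signature and the schedule are assumptions, not a finalized API:

```python
def train_step_callback(*, model, epoch: int, global_train_step: int) -> StepCallbackReturn:
    # Assumed schedule: after 10k steps, increase the batch size and lower the LR.
    if global_train_step == 10_000:
        return StepCallbackReturn(
            config_updates={"batch_size": 40_000, "learning_rate": 1e-4},
        )
    return StepCallbackReturn()  # nothing changed
```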
To extend a bit on this, I currently think: it should be right before a train step, but only if this is not the very first train step after initialization (i.e. if this is not step 0) (edit: why only if not step 0? reconsider this, maybe for consistency better always...?). And I think it only makes sense in training. (Maybe we should call it `train_step_callback` then.)

Example with initialization, and example with loading an existing model: see the sketches below.
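The concrete examples are not preserved in the thread as quoted here; the following are hedged reconstructions of what they could look like. `get_model` is the RETURNN config hook mentioned further below in this issue; the callback signature, the growth step, and the checkpoint path/layout are assumptions. Both functions are variants of the same callback, shown under different names only to keep them apart:

```python
import torch

# Example with initialization: at some assumed step, construct a bigger model
# via get_model (fresh init) and copy over the weights which already exist.
def train_step_callback_init(*, model, epoch: int, global_train_step: int) -> StepCallbackReturn:
    if global_train_step == 20_000:
        new_model = get_model(epoch=epoch, step=global_train_step)
        new_model.load_state_dict(model.state_dict(), strict=False)
        return StepCallbackReturn(updated_model=new_model)
    return StepCallbackReturn()


# Example with loading an existing model: same, but the weights come from an
# existing checkpoint (hypothetical path and checkpoint layout).
def train_step_callback_load(*, model, epoch: int, global_train_step: int) -> StepCallbackReturn:
    if global_train_step == 20_000:
        new_model = get_model(epoch=epoch, step=global_train_step)
        checkpoint = torch.load("existing-model.pt", map_location="cpu")
        new_model.load_state_dict(checkpoint["model"], strict=False)
        return StepCallbackReturn(updated_model=new_model)
    return StepCallbackReturn()
```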
For other use cases (e.g. adapting the gradient accumulation or other settings), I'm also thinking about a …
(Btw, about naming: in PyTorch, the forward hook is called afterwards, so it could also have been named "post forward hook", and then there is a pre-forward hook, which is called before. Maybe not the same thing, though... Also, "hook" vs. "callback".)
When it comes to settings (config updates), e.g. the example to change grad accum dynamically, or whatever else, I was also thinking: maybe doing that in one such callback function is not so easy and convenient for the user. Maybe instead, handle it separately for a few selected supported settings (e.g. …).

Advantages: …

Disadvantages: …
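One way such a per-setting mechanism could look — a sketch only: `accum_grad_multiple_step` is a real RETURNN config option, but whether it accepts a callable like this is an assumption of this sketch:

```python
# In the RETURNN config, instead of a fixed value:
accum_grad_multiple_step = 4

# ... a function which RETURNN would call to get the current value
# (hypothetical signature, mirroring the style of other config hooks):
def accum_grad_multiple_step(*, epoch: int, global_train_step: int, **_kwargs) -> int:
    # Assumed schedule: more gradient accumulation early in training.
    if global_train_step < 10_000:
        return 4
    return 1
```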
(As initially discussed in #1120.)
How to handle pretraining? The current suggested APIs (`get_model` and co) might need to be changed, because we do not want to call `get_model` every epoch? What would the APIs look like?

Or better: leave the current functions, but have a separate function `get_stage(epoch, step) -> None|str|int` or so. When the stage changes, it signals to RETURNN to reconstruct the network.

Or, maybe more explicit: `train_step_callback(model, epoch, step) -> StepCallbackReturn`, where `StepCallbackReturn` optionally contains a new model.
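For the `get_stage` variant, a hedged sketch of how it could look in a config — the stage names and schedule are made up; only the signature comes from the proposal above:

```python
def get_stage(epoch: int, step: int):
    """Identifier of the current training stage. When the return value changes,
    RETURNN would reconstruct the network (i.e. call get_model again)."""
    if epoch <= 5:
        return "tiny"  # e.g. the small 2-layer Conformer mentioned above
    if epoch <= 15:
        return "medium"
    return None  # final full model


def get_model(*, epoch: int, step: int, **_kwargs):
    # Hypothetical: build a differently sized model depending on the stage.
    stage = get_stage(epoch, step)
    ...
```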