support save&load of fsdp_optim_state #16
base: main
Conversation
We judge whether the optim_state_dict is sharded automatically.

Args:
    optim_state_dict (Dict[str, Any]): The optimizer states to be loaded.
Please add a comment explaining what optim_state_dict the other ranks should pass when rank0_only is True.
done
There is still no explanation here of how the optim_state_dict arguments from the other ranks are handled when the optimizer state is loaded on rank 0. We could let users pass None instead of an empty dictionary (since users may not know what the keys are).
I've added the explanation: when rank0_only is specified, we broadcast rank 0's optim_state_dict to the other ranks, so it doesn't matter what the other ranks pass in; in particular, they can pass None.
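For illustration, a minimal usage sketch of that convention is below. The method name load_optim_state_dict, the checkpoint handling, and the return convention are assumptions for this example, not the verified torchacc API:

```python
import torch
import torch.distributed as dist

def load_fsdp_optim_state(model, optim, ckpt_path):
    # Hypothetical sketch: only rank 0 reads the full optimizer state from
    # disk; the other ranks pass None, because with rank0_only the wrapper
    # broadcasts rank 0's state to every rank.
    optim_state_dict = None
    if dist.get_rank() == 0:
        optim_state_dict = torch.load(ckpt_path, map_location="cpu")
    # Assumed method on the FSDP-wrapped model: turns the (possibly
    # rank-0-only) full state into the sharded state expected by this rank.
    sharded_state = model.load_optim_state_dict(optim_state_dict)
    optim.load_state_dict(sharded_state)
```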
torchacc/dist/fsdp.py (Outdated)
    self.model which is sharded.
    """
    # for sharded optim_state, we return directly
    if 'shard_metadata' in optim_state_dict.keys():
I now check the world_size in the stored shard_metadata against the current shard_metadata and raise a NotImplementedError if they differ (is that suitable?).
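A minimal sketch of the described check (the helper name and the 'world_size' key inside shard_metadata are assumptions about the checkpoint layout, for illustration only):

```python
import torch.distributed as dist

def _check_shard_world_size(optim_state_dict):
    # Refuse to load a sharded optimizer state that was saved with a
    # different world size, since resharding the optimizer state is not
    # supported here.
    saved_ws = optim_state_dict['shard_metadata'].get('world_size')
    current_ws = dist.get_world_size()
    if saved_ws != current_ws:
        raise NotImplementedError(
            f"Optimizer state was saved with world_size={saved_ws}, "
            f"but the current world_size is {current_ws}; "
            f"resharding is not supported.")
```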
Overall looks good to me. Passing off to @anw90 for final review.
    return model, optim


def _train_step(
It's better to rename _train_step to _train, because _train_step might suggest that it only runs a single training step.
done
    labels = torch.zeros(batch_size, dtype=torch.int64).to(device)
    loss = model(data)
    loss = torch.nn.functional.nll_loss(loss, labels)
    loss.backward()
Do we need to call loss.backward here?
No need, but in this test case it makes no difference whether we do forward only or forward and backward.
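For reference, a forward-only version of the test step could look like the sketch below (the model, data, batch_size, and device are assumed to come from the surrounding test setup shown above):

```python
import torch
import torch.nn.functional as F

def _forward_only_step(model, data, batch_size, device):
    # Forward-only variant of the test step: the save/load test only needs
    # the optimizer state to exist, so the backward pass can be skipped.
    labels = torch.zeros(batch_size, dtype=torch.int64).to(device)
    output = model(data)
    return F.nll_loss(output, labels)
```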
What this PR does:
TODO: