
support save&load of fsdp_optim_state #16

Open

hanwen-sun wants to merge 48 commits into main

Conversation

@hanwen-sun (Contributor) commented Sep 3, 2024

What this PR does:

  1. Support flatten (including padding before sharding) and unflatten of the full_optim_state_dict save and load, covered by unit tests.
  2. Support save and load of the shard_optim_state_dict (see the usage sketch below).

TODO:

  1. Test the memory usage when checkpointing a 70B model.
  2. shard_param_on_dim_0(?)
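
For context, here is a minimal usage sketch of the feature. The method names (full_optim_state_dict, load_optim_state_dict) and the rank0_only flag are assumptions modeled on this PR's discussion and on PyTorch FSDP's API, not a confirmed torchacc interface:

```python
# Illustrative sketch only: method names and flags are assumptions
# modeled on PyTorch FSDP and may not match torchacc's exact API.
import torch

def save_optim_state(model, optim, path, rank):
    # Gather the flattened, sharded optimizer state into a full
    # (unflattened) state dict; with rank0_only=True, only rank 0
    # materializes the result.
    full_osd = model.full_optim_state_dict(optim, rank0_only=True)
    if rank == 0:
        torch.save(full_osd, path)

def load_optim_state(model, optim, path, rank):
    # Only rank 0 needs the real state dict; other ranks may pass
    # None, since rank 0's state is broadcast to them (see the
    # review discussion below).
    full_osd = torch.load(path) if rank == 0 else None
    model.load_optim_state_dict(full_osd)
```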

Quoted code context:

    We judge automatically whether the optim_state_dict is sharded.

    Args:
        optim_state_dict (Dict[str, Any]): The optimizer states to be loaded.

Reviewer (Contributor):

Please add some comments explaining what optim_state_dict the other ranks pass when rank0_only is True.

hanwen-sun (Author):

done

Reviewer (Contributor):

There is still no explanation of what optim_state_dict the other ranks pass after the optimizer state is loaded on rank 0. We could also let users pass None instead of an empty dictionary, since users may not know what the keys are.

hanwen-sun (Author):

I added an explanation: when rank0_only is specified, we broadcast rank 0's optim_state_dict to the other ranks, so it does not matter what the other ranks pass in; in particular, they can pass None (see the sketch below).
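
A conceptual sketch of that broadcast pattern, using torch.distributed.broadcast_object_list; torchacc's actual mechanism (e.g. XLA collectives) may differ:

```python
# Conceptual sketch only; torchacc's real broadcast path may differ.
# Assumes an initialized torch.distributed process group.
import torch.distributed as dist

def _broadcast_optim_state(optim_state_dict, rank):
    # Non-zero ranks may pass None; after the broadcast every rank
    # holds a copy of rank 0's optimizer state dict.
    obj_list = [optim_state_dict if rank == 0 else None]
    dist.broadcast_object_list(obj_list, src=0)
    return obj_list[0]
```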

Quoted code context:

        self.model which is sharded.
        """
        # for sharded optim_state, return it directly
        if 'shard_metadata' in optim_state_dict:

hanwen-sun (Author):

I now compare the world_size in the stored shard_metadata against the current shard_metadata and raise a NotImplementedError if they differ (is that suitable?). A sketch of the check is below.
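
A sketch of that check; the layout of shard_metadata (in particular a world_size field) is an assumption:

```python
# Hypothetical sketch; the 'world_size' key inside shard_metadata
# is an assumption about the stored layout.
import torch.distributed as dist

def _check_shard_world_size(optim_state_dict):
    saved_ws = optim_state_dict['shard_metadata'].get('world_size')
    current_ws = dist.get_world_size()
    if saved_ws != current_ws:
        # Resharding optimizer state across a different world size is
        # not supported yet, so fail loudly instead of loading
        # inconsistent shards.
        raise NotImplementedError(
            f"optim state was saved with world_size={saved_ws}, "
            f"but the current world_size is {current_ws}")
```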

@yitongh (Contributor) commented Sep 14, 2024

Overall looks good to me. Passing off to @anw90 for final review.

@yitongh requested a review from anw90 on September 14, 2024 at 09:57.

@hanwen-sun changed the title from "support fsdp_optim_state" to "support save&load of fsdp_optim_state" on Sep 20, 2024.

Quoted code context:

    return model, optim


    def _train_step(

Reviewer (Contributor):

It's better to rename _train_step to _train, because _train_step suggests it performs only a single training step.

hanwen-sun (Author):

done

Quoted code context:

    labels = torch.zeros(batch_size, dtype=torch.int64).to(device)
    output = model(data)
    loss = torch.nn.functional.nll_loss(output, labels)
    loss.backward()

Reviewer (Contributor):

Do we need to call loss.backward() here?

hanwen-sun (Author):

No need, but in this test case it makes no difference whether we run only the forward pass or both forward and backward. A sketch of the resulting _train helper follows.
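
For context, a hypothetical sketch of the renamed _train helper; tensor shapes, the step count, and the optimizer step are illustrative assumptions, not the actual test code:

```python
# Hypothetical sketch; shapes and structure may differ from the real
# tests/distributed/test_fsdp_optim_state.py.
import torch

def _train(model, optim, device, batch_size=8, steps=2):
    for _ in range(steps):
        data = torch.randn(batch_size, 128).to(device)
        labels = torch.zeros(batch_size, dtype=torch.int64).to(device)
        output = model(data)  # assumed to return log-probabilities
        loss = torch.nn.functional.nll_loss(output, labels)
        loss.backward()
        # optim.step() populates optimizer state (e.g. Adam moments),
        # which the save/load tests then exercise.
        optim.step()
        optim.zero_grad()
```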
