
Conversation

@gkroiz (Contributor) commented Oct 8, 2023

Implements PyTorch XLA FSDP for TPUs. This PR is based on #17421 but focuses only on Trainer-related changes. Use XLAFSDPStrategy to enable FSDP on TPUs.

Fixes #13209
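
For context, here is a minimal usage sketch of what this PR enables (a sketch only: it assumes a TPU host with torch_xla installed and uses the demo BoringModel; the strategy is shown with default arguments):

import lightning.pytorch as pl
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.strategies import XLAFSDPStrategy

trainer = pl.Trainer(
    accelerator="tpu",
    devices=8,
    strategy=XLAFSDPStrategy(),  # shards module parameters across TPU cores via torch_xla's FSDP
)
trainer.fit(BoringModel())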

@github-actions github-actions bot added the pl (Generic label for PyTorch Lightning package) label Oct 8, 2023
@gkroiz (Contributor, PR author) commented Oct 8, 2023

@carmocca please take a look when you get the chance

@carmocca carmocca added this to the 2.1 milestone Oct 9, 2023
@carmocca carmocca self-assigned this Oct 9, 2023
@carmocca carmocca added the feature (Is an improvement or enhancement), strategy: fsdp (Fully Sharded Data Parallel), and strategy: xla labels Oct 9, 2023
@github-actions github-actions bot added the fabric (lightning.fabric.Fabric) label Oct 9, 2023
@carmocca carmocca requested a review from awaelchli October 11, 2023 13:28
@carmocca carmocca marked this pull request as ready for review October 11, 2023 13:44
gkroiz and others added 2 commits October 11, 2023 07:09
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
def lightning_module_state_dict(self) -> Dict[str, Any]:
    assert self.model is not None
    # this format is defined by:
    # https://github.com/pytorch/xla/blob/v2.1.0/torch_xla/distributed/fsdp/state_dict_utils.py#L122-L125
Contributor

What will be the format after consolidating the state? It would be a major caveat if you end up with a checkpoint that can't be loaded back into Lightning.

I strongly suggest we provide our own consolidate_sharded_model_checkpoints like for deepspeed, where we make it compatible with Lightning checkpoints. This way, this method here can return just self.model.state_dict() and the sharding data can be added to the checkpoint in save_checkpoint() as an additional entry. After consolidation, we would end up with a checkpoint that is in the standard format and can be loaded by any strategy or even without the Trainer.
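
A rough sketch of the layout being suggested (illustrative only: the subclass, the "shard_metadata" key, and the get_shard_metadata() usage are assumptions, not the final API):

from typing import Any, Dict, Optional

from lightning.pytorch.strategies import XLAFSDPStrategy


class ConsolidatableXLAFSDPStrategy(XLAFSDPStrategy):
    def lightning_module_state_dict(self) -> Dict[str, Any]:
        assert self.model is not None
        # return a plain (sharded) state dict in the standard Lightning format
        return self.model.state_dict()

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str, storage_options: Optional[Any] = None) -> None:
        # stash the sharding info as an extra entry so a Lightning-provided
        # consolidation utility can later merge the shards into a standard checkpoint
        checkpoint["shard_metadata"] = self.model.get_shard_metadata()
        super().save_checkpoint(checkpoint, filepath, storage_options=storage_options)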

Contributor Author

https://github.com/pytorch/xla/blob/master/docs/fsdp.md?plain=1#L41
consolidate_sharded_model_checkpoints combines the model shards and each shard's metadata to create the full model state dict.
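
For reference, consolidation is invoked roughly like this (the paths and prefixes here are illustrative; see the linked torch_xla docs for the exact arguments and the equivalent command-line entry point):

from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints

# merges the per-rank shard files and their shard metadata into a single full
# state dict and (by default) writes the consolidated checkpoint to disk
consolidate_sharded_model_checkpoints(
    ckpt_prefix="/path/to/checkpoints/final_ckpt",  # common prefix of the shard files
    ckpt_suffix="_rank-*-of-*.pth",                 # pattern of the per-rank suffix
)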

Contributor Author

IMO if we wanted to create a Lightning-specific consolidate_sharded_model_checkpoints, this should be done in a separate PR

converted = obj.state_dict()
# add shard_metadata to state. this format is defined by
# https://github.com/pytorch/xla/blob/v2.1.0/torch_xla/distributed/fsdp/state_dict_utils.py#L122-L125
Contributor

I wonder what happens if you save a state where the key for the model is not "model"

fabric.save(path, {"banana": model})

This would be totally valid in any other setting, but I think here it would fail since the XLA format expects these keys.

Contributor Author

IIRC this won't work when trying to consolidate the checkpoint with consolidate_sharded_model_checkpoints (https://github.com/Lightning-AI/lightning/pull/18746/files/e5b695aff64bc37ae1a67fba4aac4981200eecfd#diff-3908a573abf00ae5f37061f214f2a3c2616b6591e0c96206b9f48b4c7ab49ea4R457). I'm not sure how this behaves for individual shards.
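
To make the failure mode concrete, here is the per-shard layout the consolidation utility expects, per the format linked in the diff comments above (the helper name is hypothetical); a top-level key other than "model" would not match it:

from typing import Any, Dict

def to_xla_fsdp_shard_checkpoint(wrapped_module) -> Dict[str, Any]:
    # wrapped_module is assumed to be wrapped in torch_xla's XlaFullyShardedDataParallel
    return {
        "model": wrapped_module.state_dict(),                   # this rank's (flattened) shard
        "shard_metadata": wrapped_module.get_shard_metadata(),  # info needed to merge shards back
    }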

Contributor


Could we extend the tests in plugins/precision/test_xla.py in a meaningful way?


@property
def restore_checkpoint_after_setup(self) -> bool:
    return self._state_dict_type == "sharded"
Contributor

For the "full" case here, we need the user to put their entire model definition in __init__. So I suggest raising an error if configure_model is overridden and we're loading a full-state-dict checkpoint.

Please also note that the meaning of the state_dict_type argument falls out of sync with the regular FSDP strategy: there it is only used to define the type of checkpoint we're saving. It does NOT determine what type of checkpoint we're loading, because we detect that automatically.

So I strongly recommend doing

@property
def restore_checkpoint_after_setup(self) -> bool:
    return False

and forbid using configure_model() until XLA's FSDP has streamlined support for loading checkpoints into a sharded model with meta-device initialization.
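
One possible way to enforce that (is_overridden is an existing Lightning utility; the check's location and error message are hypothetical):

import lightning.pytorch as pl
from lightning.pytorch.utilities.model_helpers import is_overridden

def _validate_no_configure_model(lightning_module: pl.LightningModule) -> None:
    # fail fast instead of producing a full-state checkpoint that cannot be restored
    # into a model that is only defined and sharded inside configure_model()
    if is_overridden("configure_model", lightning_module):
        raise NotImplementedError(
            "XLAFSDPStrategy does not support configure_model() yet: XLA's FSDP cannot"
            " load a full state-dict checkpoint into a model sharded after setup."
        )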

codecov bot commented Oct 11, 2023

Codecov Report

Merging #18746 (904b50d) into master (7434c47) will decrease coverage by 29%.
Report is 1 commit behind head on master.
The diff coverage is 43%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #18746      +/-   ##
==========================================
- Coverage      83%      54%     -29%     
==========================================
  Files         439      435       -4     
  Lines       34469    34671     +202     
==========================================
- Hits        28706    18687   -10019     
- Misses       5763    15984   +10221     

@Borda Borda modified the milestones: 2.1, 2.1.x Oct 12, 2023
gitguardian bot commented Jan 16, 2024

⚠️ GitGuardian has uncovered 2 secrets following the scan of your pull request.

Please consider investigating the findings and remediating the incidents. Failure to do so may lead to compromising the associated services or software components.

🔎 Detected hardcoded secrets in your pull request
GitGuardian id | Secret                      | Commit  | Filename
-              | Generic High Entropy Secret | 78fa3af | tests/tests_app/utilities/test_login.py
-              | Base64 Basic Authentication | 78fa3af | tests/tests_app/utilities/test_login.py
🛠 Guidelines to remediate hardcoded secrets
  1. Understand the implications of revoking this secret by investigating where it is used in your code.
  2. Replace and store your secret safely, following best practices.
  3. Revoke and rotate this secret.
  4. If possible, rewrite git history. Rewriting git history is not a trivial act. You might completely break other contributing developers' workflow and you risk accidentally deleting legitimate data.


@awaelchli awaelchli modified the milestones: 2.1.x, 2.2.x Feb 8, 2024
@carmocca carmocca removed their assignment May 6, 2024
@carmocca carmocca modified the milestones: 2.2.x, future May 6, 2024
@Borda Borda added the run TPU label Jan 7, 2025
Development

Successfully merging this pull request may close these issues.

Feature request: FSDP native strategy for TPUs
