[TPU] Add Trainer support for PyTorch XLA FSDP #18746
base: master
Conversation
@carmocca please take a look when you get the chance
```python
def lightning_module_state_dict(self) -> Dict[str, Any]:
    assert self.model is not None
    # this format is defined by:
    # https://github.com/pytorch/xla/blob/v2.1.0/torch_xla/distributed/fsdp/state_dict_utils.py#L122-L125
```
What will be the format after consolidating the state? It would be a major caveat if you end up with a checkpoint that can't be loaded back into Lightning.
I strongly suggest we provide our own consolidate_sharded_model_checkpoints, like we do for DeepSpeed, and make it compatible with Lightning checkpoints. That way, this method can return just self.model.state_dict(), and the sharding data can be added to the checkpoint in save_checkpoint() as an additional entry. After consolidation, we would end up with a checkpoint in the standard format that can be loaded by any strategy, or even without the Trainer.
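A minimal sketch of the split being suggested, assuming the wrapped module exposes XLA FSDP's get_shard_metadata(); the free-function signatures and the checkpoint_io argument are illustrative stand-ins for the actual strategy hooks, not the real API:

```python
from typing import Any, Dict


def lightning_module_state_dict(model) -> Dict[str, Any]:
    # return a plain state dict (the local shard) with no XLA-specific nesting,
    # so the consolidated checkpoint stays in the standard Lightning format
    return model.state_dict()


def save_checkpoint(model, checkpoint: Dict[str, Any], filepath: str, checkpoint_io) -> None:
    # attach the shard metadata as an additional entry; a Lightning-aware
    # consolidation utility could later use it to merge the shards
    checkpoint["shard_metadata"] = model.get_shard_metadata()
    checkpoint_io.save_checkpoint(checkpoint, filepath)
```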
Per https://github.com/pytorch/xla/blob/master/docs/fsdp.md?plain=1#L41, consolidate_sharded_model_checkpoints combines the model shards and each shard's metadata to create the full model state dict.
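For reference, the torch_xla FSDP docs linked above have each rank save its shard together with the shard metadata, which can then be merged offline. A sketch of that per-rank format (requires an XLA/TPU device; the toy model and file naming are just examples):

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XlaFSDP

# a toy module wrapped with XLA FSDP
model = XlaFSDP(torch.nn.Linear(8, 8).to(xm.xla_device()))

# per-rank sharded checkpoint in the layout that
# consolidate_sharded_model_checkpoints expects
ckpt = {
    "model": model.state_dict(),
    "shard_metadata": model.get_shard_metadata(),
}
ckpt_path = f"/tmp/ckpt_rank-{xm.get_ordinal()}-of-{xm.xrt_world_size()}.pth"
xm.save(ckpt, ckpt_path, master_only=False)
```

Consolidation can then be run offline, e.g. via the command-line entry point shown in the same docs: `python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts --ckpt_prefix /tmp/ckpt --ckpt_suffix "_rank-*-of-*.pth"`.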
IMO, if we wanted to create a Lightning-specific consolidate_sharded_model_checkpoints, that should be done in a separate PR.
src/lightning/pytorch/trainer/connectors/accelerator_connector.py (review thread outdated, resolved)
```python
converted = obj.state_dict()
# add shard_metadata to state. this format is defined by
# https://github.com/pytorch/xla/blob/v2.1.0/torch_xla/distributed/fsdp/state_dict_utils.py#L122-L125
```
I wonder what happens if you save a state where the key for the model is not "model":

```python
fabric.save(path, {"banana": model})
```

This would be totally valid in any other setting, but I think here it would fail since the XLA format expects these keys.
IIRC this won't work when trying to consolidate the checkpoint using consolidate_sharded_model_checkpoints (see https://github.com/Lightning-AI/lightning/pull/18746/files/e5b695aff64bc37ae1a67fba4aac4981200eecfd#diff-3908a573abf00ae5f37061f214f2a3c2616b6591e0c96206b9f48b4c7ab49ea4R457). I'm not sure how this works for individual shards.
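To make the scenario concrete, a hedged sketch using the Fabric counterpart of this strategy (the toy model, paths, and device count are placeholders; whether the non-"model" key fails at save time or only once consolidation runs is exactly the open question here):

```python
import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import XLAFSDPStrategy


def run(fabric):
    model = fabric.setup_module(torch.nn.Linear(8, 8))

    # matches the sharded-checkpoint convention: the shard lives under "model"
    # (next to "shard_metadata"), where the consolidation tooling expects it
    fabric.save("ckpts/good.ckpt", {"model": model})

    # equally valid Fabric code in any other setting, but the XLA consolidation
    # tooling would not know to look under "banana" for the weights
    fabric.save("ckpts/banana.ckpt", {"banana": model})


fabric = Fabric(accelerator="tpu", devices=8, strategy=XLAFSDPStrategy())
fabric.launch(run)
```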
```python
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
Could we extend the tests in plugins/precision/test_xla.py in a meaningful way?
```python
@property
def restore_checkpoint_after_setup(self) -> bool:
    return self._state_dict_type == "sharded"
```
For the "full" case here, we need that the user puts their entire model definition in init. So I suggest to put an error if configure_model is overridden and we're loading a checkpoint of type full state dict.
Please also note that the meaning of the state_dict_type
argument falls out of sync with the regular FSDP strategy. There it is only used to define the type of checkpoint we're saving. This argument does NOT determine what type of checkpoint we're loading, because we're detecting that automatically.
So I strongly recommend to do
@property
def restore_checkpoint_after_setup(self) -> bool:
return False
and forbid using configure_model()
until XLA's FSDP has a streamlined support for loading checkpoints into a sharded model with meta device init support.
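A sketch of the kind of guard being suggested, assuming the strategy keeps a _state_dict_type attribute as in this PR; the helper name is hypothetical, and is_overridden is Lightning's existing utility for detecting overridden hooks:

```python
from lightning.pytorch.utilities.model_helpers import is_overridden


def _check_full_state_dict_support(strategy) -> None:
    # hypothetical guard: a "full" checkpoint must be loaded into a fully
    # materialized model, so the model has to be defined in __init__ rather
    # than in the (delayed) configure_model() hook
    if strategy._state_dict_type == "full" and is_overridden("configure_model", strategy.lightning_module):
        raise NotImplementedError(
            "The XLA FSDP strategy cannot load a full-state-dict checkpoint into a model created in"
            " `configure_model()`. Define the model in `__init__` instead, or use `state_dict_type='sharded'`."
        )
```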
Codecov Report
Additional details and impacted files:

```diff
@@            Coverage Diff            @@
##           master   #18746     +/-   ##
==========================================
- Coverage      83%      54%     -29%
==========================================
  Files         439      435       -4
  Lines       34469    34671     +202
==========================================
- Hits        28706    18687   -10019
- Misses       5763    15984   +10221
```
GitGuardian id | Secret | Commit | Filename
---|---|---|---
- | Generic High Entropy Secret | 78fa3af | tests/tests_app/utilities/test_login.py
- | Base64 Basic Authentication | 78fa3af | tests/tests_app/utilities/test_login.py
Implements PyTorch XLA FSDP for TPUs. This PR is based on #17421 but focuses only on Trainer-related changes. Use the XLAFSDPStrategy to use FSDP on TPUs.

Fixes #13209
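For context, a minimal usage sketch, assuming an 8-core TPU environment and that the strategy is exposed as lightning.pytorch.strategies.XLAFSDPStrategy as proposed here; the toy module and dataset are placeholders:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import lightning.pytorch as pl
from lightning.pytorch.strategies import XLAFSDPStrategy


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
trainer = pl.Trainer(accelerator="tpu", devices=8, strategy=XLAFSDPStrategy(), max_epochs=1)
trainer.fit(ToyModel(), train_loader)
```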