This repository was archived by the owner on Aug 7, 2024. It is now read-only.
Summary:
This adds two changes:
1. `Float8LinearMixin` saves a constructor that FSDP can use to construct the `Float8Tensor` for `w_fp8`. This is needed so that FSDP can manage the unsharded gradient, and because FSDP prefers to own the underlying data/storage for all-gather.
2. `Float8Linear.forward()` checks if `self._w_fp8` has been set (by FSDP) and skips the cast itself if so.
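The two changes above can be sketched in plain Python. This is a toy illustration under stated assumptions, not the code from this commit: `ToyFloat8Tensor`, `ToyFloat8LinearMixin`, `ToyFloat8Linear`, `_fp8_ctor`, `cast_w_to_fp8`, `cast_count`, and the fixed scale are all hypothetical names and values, and lists of floats stand in for real tensors. Only `_w_fp8` and the check in `forward()` come from the summary.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class ToyFloat8Tensor:
    """Hypothetical stand-in for Float8Tensor: scaled data plus its scale."""
    data: List[float]
    scale: float


class ToyFloat8LinearMixin:
    """Sketch of change 1: save a constructor an external owner can call."""

    def __init__(self) -> None:
        # Saved constructor (hypothetical name). An external owner, FSDP in
        # the commit, could call it to wrap storage that it allocated itself.
        self._fp8_ctor: Callable[[List[float], float], ToyFloat8Tensor] = ToyFloat8Tensor
        # Left as None unless the external owner sets it.
        self._w_fp8: Optional[ToyFloat8Tensor] = None
        # Counts local casts so the skip in forward() is observable.
        self.cast_count = 0

    def cast_w_to_fp8(self, w: List[float], scale: float) -> ToyFloat8Tensor:
        self.cast_count += 1
        return self._fp8_ctor([v * scale for v in w], scale)


class ToyFloat8Linear(ToyFloat8LinearMixin):
    def __init__(self, weight: List[float]) -> None:
        super().__init__()
        self.weight = weight

    def forward(self, x: List[float]) -> float:
        # Sketch of change 2: reuse the pre-set weight if there is one,
        # otherwise cast locally.
        if self._w_fp8 is not None:
            w_fp8 = self._w_fp8
        else:
            w_fp8 = self.cast_w_to_fp8(self.weight, scale=2.0)
        # Undo the scale and take a dot product with the input.
        return sum((w / w_fp8.scale) * xi for w, xi in zip(w_fp8.data, x))


layer = ToyFloat8Linear([1.0, 2.0])
out_local = layer.forward([3.0, 4.0])      # casts the weight itself

# Simulate the external owner building the weight via the saved constructor.
layer._w_fp8 = layer._fp8_ctor([2.0, 4.0], 2.0)
out_preset = layer.forward([3.0, 4.0])     # skips the local cast
```

In this sketch both calls return the same value, and `cast_count` stays at 1 after the second call, because `forward()` finds `_w_fp8` already set.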
I have tested this with P865757339 (not cleaned up), but I do not think we need to land it yet. (This does mean that later changes to this repo could break FSDP fp8 all-gather, but I think that is fine for now.)
Pull Request resolved: #130
Reviewed By: awgu
Differential Revision: D50754666
Pulled By: drisspg
fbshipit-source-id: 9f7a9bfc9f2b3cb7455cd8b8642f0c4e4a55ee64