This repository was archived by the owner on Aug 7, 2024. It is now read-only.
enable autocast + compile + FSDP + Float8Linear #172
Closed
Conversation
Force-pushed from 4fa2654 to e3ab9c9 (Compare)
@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
vkuzo commented on Dec 31, 2023:
@@ -335,7 +342,7 @@ def forward(self, x):
        y = self.cast_y_to_float8_in_bw(y, self.emulate)

        if self.bias is not None:
-           y = y + self.bias.to(self.bias_dtype)
+           y = y + self.bias.to(y.dtype)
This change is not called out in the configs; it just removes the need to store a module attribute, which also makes this code more friendly to full-graph compile. Modifying module attributes seems to graph break if autocast and FSDP are both enabled.
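To illustrate the idea, here is a minimal standalone sketch with a hypothetical module (not the repository's Float8Linear code): casting the bias to `y.dtype` at call time avoids storing and reading a separate `bias_dtype` module attribute inside `forward`, which is the kind of module-attribute access that can get in the way of full-graph compile.

```
# Minimal standalone sketch with a hypothetical module, for illustration only;
# this is not the repository's Float8Linear code.
import torch
import torch.nn as nn

class BiasCastExample(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features, dtype=torch.bfloat16))
        # Bias kept in float32; no self.bias_dtype attribute is stored.
        self.bias = nn.Parameter(torch.randn(out_features, dtype=torch.float32))

    def forward(self, x):
        y = x @ self.weight.t()
        if self.bias is not None:
            # Cast to the activation's dtype directly instead of consulting a
            # stored attribute, so forward() needs no extra module state.
            y = y + self.bias.to(y.dtype)
        return y

# Usage: the bias is cast to match y's dtype at call time.
m = BiasCastExample(8, 16)
out = m(torch.randn(4, 8, dtype=torch.bfloat16))  # out.dtype == torch.bfloat16
```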
drisspg approved these changes on Jan 2, 2024:
Okay this makes sense
Summary: This adds a couple of config options to unbreak autocast + compile + FSDP + Float8Linear. To enable these options, the user needs to do:

```
config.enable_amax_init = False
config.enable_pre_and_post_forward = False
```

The `enable_amax_init` config adds the option to disable amax initialization. The reason this path is currently broken is:
1. FSDP is not full-graph friendly (regardless of compile).
2. The amax init function has a graph break in distributed code because it uses inplace distributed collectives. I did try functional collectives, but that ran into numerical issues with compile, so for now we just work around it.
3. Graph breaks in Float8Linear code are not supported because of the issue documented in #166.
4. So, as a workaround for all of the above, we just skip amax init for now. We know from NVIDIA that this path is not needed for model convergence, and TE does not support it at all. It was nice for testing but is not necessary for training jobs.

The second config option disables the pre-forward and post-forward hooks. I don't have a unit test repro for now, but this does unbreak LLaMa 7B on 8 GPUs with FSDP + compile. Specifically, the thing that is broken in pre-forward/post-forward is assignment to module attributes. My hunch is that this graph breaks if autocast + FSDP are on, and graph breaks are not supported due to (3) above.

Test Plan:

```
// unit / integration tests
with-proxy test/test_everything.sh

// run the LLaMa 7B trainer on 8 GPUs with autocast + compile + FSDP + Float8Linear, no compile errors
```

Reviewers:

Subscribers:

Tasks:

Tags:
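For context, a rough usage sketch of how a training script might set these options before wrapping and compiling the model; the `float8_experimental.config` import path and the surrounding setup are assumptions for illustration, not taken verbatim from this PR.

```
# Rough usage sketch only; the `float8_experimental.config` import path and
# the surrounding setup are assumptions for illustration.
import float8_experimental.config as config

# Work around graph breaks when combining autocast + torch.compile + FSDP +
# Float8Linear: skip amax initialization and the pre/post-forward hooks.
config.enable_amax_init = False
config.enable_pre_and_post_forward = False

# ...then swap linears for Float8Linear, wrap with FSDP, torch.compile the
# model, and train under autocast as usual.
```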
Force-pushed from e3ab9c9 to 2d97dde (Compare)
@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Labels: CLA Signed, Merged