This repository was archived by the owner on Aug 7, 2024. It is now read-only.

enable autocast + compile + FSDP + Float8Linear #172

Closed
wants to merge 1 commit

Conversation

@vkuzo (Contributor) commented on Dec 31, 2023

Summary:

This adds two config options to unbreak autocast + compile + FSDP + Float8Linear. To enable them, the user sets:

config.enable_amax_init = False
config.enable_pre_and_post_forward = False
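
For context, a minimal usage sketch (assuming the flags live on the library's config module, as the lines above suggest; the conversion of the model's linears to Float8Linear is elided):

```python
# Sketch only: the import path below is assumed from the config lines above.
from float8_experimental import config

# work around the autocast + compile + FSDP issues described below
config.enable_amax_init = False
config.enable_pre_and_post_forward = False
```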

The `enable_amax_init` config adds the option to disable amax initialization. Amax init is currently broken in this setting for the following reasons:

  1. FSDP is not full-graph friendly (regardless of compile)
  2. the amax init function has a graph break in distributed code because it uses in-place distributed collectives (see the sketch after this list). I did try functional collectives ([wip] make Float8Linear amax init more FSDP+compile friendly #171), but that ran into numerical issues with compile, so for now we just work around it.
  3. graph breaks in Float8Linear code are not supported because of the issue documented in [wip] enable Float8Tensor as subgraph boundary #166
  4. so, as a workaround for all of the above, we just skip amax init for now. We know from NVIDIA that this path is not needed for model convergence, and TE (Transformer Engine) does not support it at all. It was nice for testing but is not necessary for training jobs.
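
To make point (2) concrete, here is a minimal sketch (not the actual amax-init code) of the in-place collective pattern that introduces the graph break, next to the functional-collective shape tried in #171; exact functional-collectives argument forms vary across PyTorch versions:

```python
# Sketch only: assumes torch.distributed is already initialized.
import torch
import torch.distributed as dist

def sync_amax_inplace(amax: torch.Tensor) -> torch.Tensor:
    # In-place all-reduce mutates `amax`; calls like this are what
    # graph-break inside a compiled region.
    dist.all_reduce(amax, op=dist.ReduceOp.MAX)
    return amax

def sync_amax_functional(amax: torch.Tensor) -> torch.Tensor:
    # Functional collectives return a new tensor instead of mutating,
    # which is friendlier to compile, but per the summary this route hit
    # numerical issues under compile.
    from torch.distributed import _functional_collectives as funcol
    return funcol.all_reduce(amax, "max", dist.group.WORLD)
```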

The second config option, `enable_pre_and_post_forward`, disables the pre-forward and post-forward hooks. I don't have a unit-test repro for now, but this does unbreak LLaMa 7B on 8 GPUs with FSDP + compile. Specifically, the thing that breaks in pre-forward/post-forward is assignment to module attributes; my hunch is that this graph-breaks when autocast + FSDP are on, and graph breaks are not supported due to (3) above.
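
As a rough illustration of the pattern being skipped (a generic sketch with made-up attribute names, not the actual Float8Linear hook code):

```python
# Sketch only: stand-in for a module that does pre/post-forward bookkeeping.
import torch
import torch.nn as nn

class HookedLinear(nn.Linear):
    # hypothetical switch mirroring config.enable_pre_and_post_forward
    enable_pre_and_post_forward = True

    def forward(self, x):
        if self.enable_pre_and_post_forward:
            # module-attribute assignment: the kind of mutation suspected of
            # graph-breaking when autocast + FSDP are both enabled
            self.forward_calls = getattr(self, "forward_calls", 0) + 1
        y = super().forward(x)
        if self.enable_pre_and_post_forward:
            self.last_output_dtype = y.dtype
        return y
```

Turning the switch off skips these assignments entirely, which keeps the forward free of module-state mutation.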

Test Plan:

// unit / integration tests
with-proxy test/test_everything.sh

// run the LLaMa 7b trainer on 8 GPUs with autocast + compile + FSDP + Float8Linear, no compile errors
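
For readers reproducing the shape of that run, a minimal sketch of the composition being exercised (not the actual trainer; assumes torch.distributed is initialized, the model's linears have already been swapped for Float8Linear, and model/data are on the right CUDA device):

```python
# Sketch only: per-rank autocast + compile + FSDP composition.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def wrap_model(model: torch.nn.Module) -> torch.nn.Module:
    # FSDP first (use_orig_params=True is required for torch.compile),
    # then compile on top of the wrapped module.
    return torch.compile(FSDP(model, use_orig_params=True))

def one_step(model: torch.nn.Module, batch: torch.Tensor) -> torch.Tensor:
    # autocast on the outside, matching the LLaMa 7B run described above
    with torch.autocast("cuda", dtype=torch.bfloat16):
        loss = model(batch).float().mean()
    loss.backward()
    return loss
```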

Reviewers:

Subscribers:

Tasks:

Tags:

@facebook-github-bot added the CLA Signed label on Dec 31, 2023
@vkuzo force-pushed the 20231229_fsdp_autocast_compile_test branch from 4fa2654 to e3ab9c9 on December 31, 2023 19:41
@facebook-github-bot

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@@ -335,7 +342,7 @@ def forward(self, x):
         y = self.cast_y_to_float8_in_bw(y, self.emulate)

         if self.bias is not None:
-            y = y + self.bias.to(self.bias_dtype)
+            y = y + self.bias.to(y.dtype)
@vkuzo (author) commented on this change:

This change is not covered by the new config options above; it just removes the need to store a module attribute, which also makes this code more friendly to full-graph compile. Modifying module attributes seems to graph-break when autocast and FSDP are both enabled.
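
A tiny standalone illustration of why reading the dtype from the local tensor works (a sketch, not the Float8Linear code): under autocast the matmul output dtype follows the autocast dtype, so casting the bias to `y.dtype` matches without tracking any module state.

```python
import torch

x = torch.randn(4, 8)
w = torch.randn(16, 8)
bias = torch.randn(16)  # float32

with torch.autocast("cpu", dtype=torch.bfloat16):
    y = torch.nn.functional.linear(x, w)  # bf16 under autocast
    out = y + bias.to(y.dtype)            # dtype read locally, no module attribute

print(y.dtype, out.dtype)  # torch.bfloat16 torch.bfloat16
```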

@drisspg (Contributor) left a comment:

Okay this makes sense

@vkuzo force-pushed the 20231229_fsdp_autocast_compile_test branch from e3ab9c9 to 2d97dde on January 2, 2024 23:09
@facebook-github-bot

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot

@vkuzo merged this pull request in 120e752.
