This repository was archived by the owner on Aug 7, 2024. It is now read-only.

[wip] make Float8Linear amax init more FSDP+compile friendly #171

Closed
wants to merge 1 commit

Conversation

@vkuzo (Contributor) commented Dec 28, 2023

Summary:

We need to use functional collectives so that torch.compile can trace through the distributed code
(https://github.com/pytorch/pytorch/blob/main/torch/distributed/_functional_collectives.py).

Numerics are currently off; debugging.
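
For context, a minimal sketch of the intended change (the function and tensor names are illustrative, not the actual amax sync code in this repo): an in-place collective mutates its input, which torch.compile cannot trace cleanly, while a functional collective returns a new tensor.

```
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

def sync_amax_inplace(amax: torch.Tensor) -> torch.Tensor:
    # in-place collective: mutates `amax`, causes a graph break under torch.compile
    dist.all_reduce(amax, op=dist.ReduceOp.MAX)
    return amax

def sync_amax_functional(amax: torch.Tensor) -> torch.Tensor:
    # functional collective: returns a new tensor, which torch.compile can trace
    return funcol.all_reduce(amax, "max", group=dist.group.WORLD)
```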

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

@facebook-github-bot added the CLA Signed label on Dec 28, 2023
facebook-github-bot pushed a commit that referenced this pull request Jan 3, 2024
Summary:
This adds a couple of config options to unbreak autocast + compile + FSDP + Float8Linear. To enable these options, the user needs to do:

```
config.enable_amax_init = False
config.enable_pre_and_post_forward = False
```

The `enable_amax_init` config adds the option to disable amax initialization. Amax initialization is currently broken with this combination because:
1. FSDP is not full-graph friendly (regardless of compile)
2. the amax init function has a graph break in distributed code because it uses in-place distributed collectives. I did try to use functional collectives (#171), but that ran into numerical issues with compile, so for now we just work around it.
3. graph breaks in Float8Linear code are not supported because of the issue documented in #166
4. so, as a workaround for all of the above, we just skip amax init for now. We know from NVIDIA that this path is not needed for model convergence, and TE does not support it at all; it was nice for testing but is not necessary for training jobs.
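
For concreteness, a minimal sketch of where these flags would sit in a training setup relative to module swapping, FSDP wrapping, and torch.compile. The `float8_experimental` import paths and the swap helper below are assumptions about the repo layout at the time, not verified API; adjust to the actual code.

```
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# NOTE: import paths / helper names below are assumptions, not verified API
from float8_experimental import config
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# set the workaround flags before constructing/swapping the model
config.enable_amax_init = False
config.enable_pre_and_post_forward = False

# assumes torch.distributed is already initialized (e.g. via torchrun)
model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096))
swap_linear_with_float8_linear(model, Float8Linear)  # nn.Linear -> Float8Linear
model = FSDP(model.cuda())  # wrap with FSDP first
model = torch.compile(model)  # then compile
```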

The second config option, `enable_pre_and_post_forward`, disables the pre-forward and post-forward logic. I don't have a unit test repro for now, but this does unbreak LLaMa 7B on 8 GPUs with FSDP + compile. Specifically, the thing that is broken in pre-forward/post-forward is assignment to module attributes. My hunch is that this graph breaks when autocast + FSDP are on, and graph breaks are not supported due to (3) above.
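
For illustration, a hypothetical sketch of the pattern described above: a module whose forward does pre/post bookkeeping by assigning module attributes, which is the kind of side effect that can graph break under torch.compile. This is not the actual Float8Linear implementation.

```
import torch
import torch.nn as nn

class LinearWithBookkeeping(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_initialized = False
        self.forward_calls = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # "pre-forward": attribute assignment on the module
        self.is_initialized = True
        y = super().forward(x)
        # "post-forward": more bookkeeping via attribute assignment
        self.forward_calls += 1
        return y

# assumes a CUDA device is available
m = torch.compile(LinearWithBookkeeping(16, 16).cuda())
with torch.autocast("cuda", dtype=torch.bfloat16):
    out = m(torch.randn(2, 16, device="cuda"))
```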

Pull Request resolved: #172

Test Plan:
```
// unit / integration tests
with-proxy test/test_everything.sh

// run the LLaMa 7b trainer on 8 GPUs with autocast + compile + FSDP + Float8Linear, no compile errors
```

Reviewed By: drisspg

Differential Revision: D52468625

Pulled By: vkuzo

fbshipit-source-id: be4fac927b8520602ed018e96d7a49056e9c6e06
@drisspg closed this Apr 3, 2024