float8 rowwise training: add FSDP workaround #1629


Merged: 1 commit merged into main on Jan 31, 2025

Conversation

@vkuzo (Contributor) commented Jan 27, 2025

Summary:

Adds the workaround from pytorch/pytorch#141881 to the torchao float8 rowwise recipe, to reduce memory usage when FSDP is enabled.

Test Plan: tested in torchtitan on Llama 3 8B training with 8 H100s; rowwise peak memory decreased from 67 GiB to 59 GiB.
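For readers unfamiliar with the recipe, a minimal sketch of enabling torchao float8 training with rowwise scaling on a model might look like the following. This is illustrative only and is not the code changed by this PR; the "rowwise" recipe string and the exact config API are assumptions that may differ across torchao versions.

```python
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Toy model standing in for the transformer being trained.
model = nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096))

# Build the rowwise-scaling recipe config; the "rowwise" identifier is an
# assumption -- check Float8LinearConfig.from_recipe_name in the installed
# torchao version for the supported recipe names.
config = Float8LinearConfig.from_recipe_name("rowwise")

# Swap eligible nn.Linear modules for float8 linears in place.
convert_to_float8_training(model, config=config)
```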



pytorch-bot (bot) commented Jan 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1629

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 066f889 with merge base 47f96f1:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jan 27, 2025
@vkuzo added the topic: performance label and removed the CLA Signed label on Jan 27, 2025
vkuzo added a commit to pytorch/torchtitan that referenced this pull request Jan 27, 2025
Summary:

This is an example of how to call float8 training with rowwise scaling
from torchao.

TODO: finalize the API in torchao, decide how to expose it in torchtitan, and optimize performance.

```
// baseline (bf16 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.compile
...
step: 20  loss:  8.4931  memory: 47.65GiB(50.16%)  tps: 5,760  mfu: 33.73%

// experiment (rowwise float8 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile
...
// torchao main branch
step: 40  loss:  7.3818  memory: 66.81GiB(70.33%)  tps: 6,412  mfu: 37.55%
// torchao with pytorch/ao#1629
step: 20  loss:  8.3823  memory: 58.55GiB(61.63%)  tps: 6,424  mfu: 37.62%

// for comparison, tensorwise float8 with float8 all-gather (on main branch)
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
...
step: 20  loss:  8.4258  memory: 47.32GiB(49.81%)  tps: 7,186  mfu: 42.08%

```

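As a rough sketch of the setup those benchmark flags drive, the rowwise-float8 conversion is typically combined with FSDP sharding and torch.compile roughly as below. The module paths, especially the FSDP2 fully_shard entry point, are assumptions and vary by torch version; the actual torchtitan wiring differs.

```python
import torch
# FSDP2 entry point; older releases expose it under torch.distributed._composable.fsdp.
from torch.distributed.fsdp import fully_shard
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

def prepare_model(model: torch.nn.Module) -> torch.nn.Module:
    # 1. Convert nn.Linear layers to float8 with the rowwise-scaling recipe
    #    ("rowwise" recipe name is an assumption; see your torchao version).
    convert_to_float8_training(model, config=Float8LinearConfig.from_recipe_name("rowwise"))
    # 2. Shard parameters and gradients across data-parallel ranks.
    fully_shard(model)
    # 3. Compile so the float8 casting/scaling ops fuse into surrounding kernels.
    return torch.compile(model)
```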
@facebook-github-bot added the CLA Signed label on Jan 27, 2025
@vkuzo merged commit 3eb18e7 into main on Jan 31, 2025
23 checks passed
vkuzo added commits to pytorch/torchtitan that referenced this pull request on Feb 7, Feb 16, Feb 20, Feb 26, and Feb 27, 2025, each with the same commit message as the Jan 27 commit above.
Labels: CLA Signed, topic: performance