Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

float8 training axiswise scaling support with per-gemm-argument configuration #940

Merged
merged 51 commits into from
Oct 7, 2024

Conversation

vkuzo
Copy link
Contributor

@vkuzo vkuzo commented Sep 24, 2024

Summary:

This PR finalizes the UX for axiswise scaling for float8 training, and introduces per-gemm-argument configurability to Float8Linear to enable exploration of future recipes. Not all combinations are supported yet. Specifically, the additional combination we now support and test is a recipe from @lw , where we do the following:

output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
grad_weight_hp = input_t_hp @ grad_output_hp

Key characteristics of this recipe:

  1. increased accuracy for grad_weight, which is important for real workloads
  2. output and weight now only need to be scaled axiswise across a single dim compared to vanilla all-axiswise, which is more amenable to fast kernels

Here is how a user can configure this:

#
# short form
#

config = torchao.float8.config.recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)

#
# or, long form
#

# output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w_go = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)

# grad_weight_hp = input_t_hp @ grad_output_hp
cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)

# ensure fast_accum is on to get fast kernels
gc_o = Float8GemmConfig(use_fast_accum=True)
gc_gi = Float8GemmConfig(use_fast_accum=True)
gc_gw = Float8GemmConfig(use_fast_accum=True)

config = Float8Config(
    cast_config_input = cc_i,
    cast_config_weight = cc_w,
    cast_config_grad_output = cc_go,
    cast_config_input_for_grad_weight = cc_i_gw,
    cast_config_weight_for_grad_output = cc_w_go,
    cast_config_grad_output_for_grad_weight = cc_go_gw,
    gemm_config_output=gc_o,
    gemm_config_grad_input=gc_gi,
    gemm_config_grad_weight=gc_gw,
)

performance

Below we provide basic performance characteristics of axiswise scaling in general, and the all-axiswise and lw recipes.

gemm performance of torch._scaled_mm

baseline: tensorwise scaling

> python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True
    fast_accum  name      M      K      N  ref_time_s  fp8_time_s  fp8_speedup
0         True     0    256    256    256    0.000004    0.000006     0.573115
1         True     1    512    512    512    0.000005    0.000007     0.659333
2         True     2   1024   1024   1024    0.000011    0.000010     1.080664
3         True     3   2048   2048   2048    0.000028    0.000017     1.596239
4         True     4   4096   4096   4096    0.000210    0.000082     2.551705
5         True     5   8192   8192   8192    0.001671    0.000680     2.457972
6         True     6  16384  16384  16384    0.015030    0.006498     2.313032
7         True     7  32768  32768  32768    0.103236    0.048097     2.146411
8        False     0    256    256    256    0.000004    0.000006     0.630061
9        False     1    512    512    512    0.000005    0.000007     0.767236
10       False     2   1024   1024   1024    0.000012    0.000008     1.391347
11       False     3   2048   2048   2048    0.000029    0.000020     1.457922
12       False     4   4096   4096   4096    0.000211    0.000101     2.100081
13       False     5   8192   8192   8192    0.001676    0.000788     2.128628
14       False     6  16384  16384  16384    0.014933    0.006351     2.351209
15       False     7  32768  32768  32768    0.103457    0.049498     2.090134                

experiment: axiswise-scaling

> python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True --scaling_granularity axiswise

    fast_accum  name      M      K      N  ref_time_s  fp8_time_s  fp8_speedup
0         True     0    256    256    256    0.000004    0.000004     0.966772
1         True     1    512    512    512    0.000005    0.000004     1.095791
2         True     2   1024   1024   1024    0.000011    0.000006     1.988363
3         True     3   2048   2048   2048    0.000027    0.000015     1.890065
4         True     4   4096   4096   4096    0.000210    0.000082     2.552356
5         True     5   8192   8192   8192    0.001674    0.001092     1.533132
6         True     6  16384  16384  16384    0.015114    0.008785     1.720480
7         True     7  32768  32768  32768    0.103286    0.071456     1.445439
8        False     0    256    256    256    0.000004    0.000004     0.899054
9        False     1    512    512    512    0.000005    0.000005     1.005340
10       False     2   1024   1024   1024    0.000011    0.000006     1.692868
11       False     3   2048   2048   2048    0.000028    0.000049     0.567655
12       False     4   4096   4096   4096    0.000210    0.000341     0.616193
13       False     5   8192   8192   8192    0.001678    0.002640     0.635541
14       False     6  16384  16384  16384    0.015051    0.021557     0.698212
15       False     7  32768  32768  32768    0.103497    0.169797     0.609533

performance on microbenchmark of ln -> linear -> sigmoid

Note: for large square shapes, performance tends to be fp8_delayed_tensorwise > fp8_dynamic_tensorwise > fp8_dynamic_axiswise > custom_recipe. For performance of fp8_dynamic_axiswise, it seems that the gap from tensorwise is mostly due to the gemm performance being behind tensorwise.

> python benchmarks/float8/float8_roofline.py ~/local/tmp/20241004_roofline.csv

   fwd_M  fwd_K  fwd_N  bf16_gemm_s  fp8_gemm_s  fp8_axs_gemm_time_s      fp8_oh_dyn_limit  ... fp8_del_s fp8_dyn_axs_s  fp8_lw_s  fp8_dyn_sp  fp8_del_sp  fp8_dyn_axs_sp  fp8_lw_sp
0    256    256    256     0.000011    0.000018             0.000012   6.50457971014493e-6  ...  0.000043      0.000049  0.000030    0.465634    0.457907        0.398357   0.643088
1    512    512    512     0.000014    0.000020             0.000013   8.01831884057971e-6  ...  0.000047      0.000054  0.000034    0.489556    0.493467        0.432643   0.685842
2   1024   1024   1024     0.000033    0.000026             0.000017   1.40732753623188e-5  ...  0.000060      0.000063  0.000050    0.734123    0.741467        0.705941   0.891199
3   2048   2048   2048     0.000081    0.000055             0.000044   3.82931014492754e-5  ...  0.000147      0.000159  0.000142    0.815678    0.800811        0.739865   0.827441
4   4096   4096   4096     0.000632    0.000274             0.000247  0.000135172405797101  ...  0.000602      0.000622  0.000662    1.236320    1.261848        1.221755   1.147678
5   8192   8192   8192     0.005027    0.002216             0.003292  0.000522689623188406  ...  0.003665      0.004776  0.005720    1.432213    1.513035        1.161130   0.969448
6  16384  16384  16384     0.045113    0.018975             0.025706   0.00207275849275362  ...  0.024664      0.032254  0.038051    1.803456    1.883291        1.440118   1.220738
7  32768  32768  32768     0.312459    0.147255             0.214492   0.00827303397101449  ...  0.182645      0.240962  0.270973    1.696376    1.766307        1.338827   1.190552

performance on torchtitan LLaMa 3 8B on 8 H100 GPUs, float8 compute only:

  • baseline (bf16 + compile): 6,294 wps
  • f8 all-tensorwise: 7,359 wps (1.17x vs baseline)
  • f8 all-axiswise: 7,135 wps (1.13x vs baseline - surprising that this is close to all-tensorwise)
  • LW_AXISWISE_WITH_GW_HP: 6,506 wps (1.03x vs baseline)

so, looks like we have performance work to do with LW_AXISWISE_WITH_GW_HP in future PRs

accuracy

I did a very quick check that loss curves on torchtitan LLaMa 3 8B pretraining with 8 H100 GPUs look good for bf16/f8_tensorwise/f8_axiswise/f8_lw on 0.5k iterations. I will leave longer accuracy verifications for future work.

Screenshot 2024-10-04 at 10 05 24 PM

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
@vkuzo
Copy link
Contributor Author

vkuzo commented Sep 24, 2024

Copy link

pytorch-bot bot commented Sep 24, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/940

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b536435 with merge base e76db70 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 24, 2024
vkuzo added a commit that referenced this pull request Sep 24, 2024
Summary:

My brain hurts from so many long identifiers...

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: de754d285a15ebc7ab7f4e963a93a8696403d70e
ghstack-comment-id: 2372563439
Pull Request resolved: #940
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Sep 25, 2024
Summary:

My brain hurts from so many long identifiers...

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: bf83d2e1d0b881777f0b6b6a457d735439a13bcc
ghstack-comment-id: 2372563439
Pull Request resolved: #940
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Sep 25, 2024
Summary:

My brain hurts from so many long identifiers...

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 0cfb3bb31bf908256b254e00317c7368434abea0
ghstack-comment-id: 2372563439
Pull Request resolved: #940
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Sep 27, 2024
Summary:

My brain hurts from so many long identifiers...

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 3eaa2df5258fc3795eaf9a86c892248889d06cb2
ghstack-comment-id: 2372563439
Pull Request resolved: #940
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Sep 27, 2024
Summary:

My brain hurts from so many long identifiers...

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: c13c0ee8f000127a3beec28c90ad0467f8774a29
ghstack-comment-id: 2372563439
Pull Request resolved: #940
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Sep 27, 2024
Summary:

My brain hurts from so many long identifiers...

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 3dcb57ff71343b09b5ab6d4bc70b766c35854911
ghstack-comment-id: 2372563439
Pull Request resolved: #940
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Sep 27, 2024
Summary:

My brain hurts from so many long identifiers...

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 4b97519d234404f72ed56f2608417ff9ca43edf9
ghstack-comment-id: 2372563439
Pull Request resolved: #940
@vkuzo vkuzo changed the title [wip] make scaling configurable by gemm-argument make float8 scaling configurable by gemm-argument Sep 27, 2024
Copy link

@lw lw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for taking care of this! LGTM!

# If True, this tensor is not scaled to float8 and left in its original
# precision.
# TODO(ideally before this PR lands): a better name for this
keep_in_original_precision: bool = False
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternative: add a ScalingType.DISABLED option?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO if we are configuring scaling via N different knobs (type, granularity, scale, etc) and we have "disabled" for one of the knobs, we also need "disabled" for all the other knobs to avoid inconsistencies. Having the "disable" at the level of CastConfig seems like a cleaner way to configure this, although I don't love the name.

How about this:

Float8LinearConfig(...):
    # if cast config is specified, use it to scale/cast
    # if cast config is None, leave the tensor in original precision
    cast_config_input: Optional[CastConfig]
    ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Float8LinearConfig(...): # if cast config is specified, use it to scale/cast # if cast config is None, leave the tensor in original precision cast_config_input: Optional[CastConfig] ...

Realized that this conflicts with the current behavior of using None for the per-gemm-input override of input|weight|grad_output to mean "use the original config". Going to take your suggestion instead, the inconsistency is minor.

torchao/float8/config.py Outdated Show resolved Hide resolved
class _Float8LinearRecipeName(enum.Enum):
ALL_TENSORWISE = "all_tensorwise"
ALL_AXISWISE = "all_axiswise"
LW_AXISWISE_WITH_GW_HP = "lw_axiswise_with_gw_hp"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lol, are you really gonna name this recipe after me? :P

torchao/float8/config.py Outdated Show resolved Hide resolved
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Oct 4, 2024
Summary:

My brain hurts from so many long identifiers...

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 71d847f5c6666ff307062f8cea25c50c398a1774
ghstack-comment-id: 2372563439
Pull Request resolved: #940
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Oct 4, 2024
Summary:

My brain hurts from so many long identifiers...

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: ce6eb7d45c669d5e1555f5cb424c54f3ef5bc534
ghstack-comment-id: 2372563439
Pull Request resolved: #940
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Oct 4, 2024
Summary:

My brain hurts from so many long identifiers...

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 816eaa1b75bc30ca6f859d70388d51635971e3f7
ghstack-comment-id: 2372563439
Pull Request resolved: #940
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Oct 4, 2024
Summary:

My brain hurts from so many long identifiers...

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 74601dd2d043dfcaab792f1d40c071959675e27e
ghstack-comment-id: 2372563439
Pull Request resolved: #940
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Oct 4, 2024
Summary:

My brain hurts from so many long identifiers...

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: a565ff43fab0421406b5634c8946136bea58ca2c
ghstack-comment-id: 2372563439
Pull Request resolved: #940
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Oct 4, 2024
Summary:

My brain hurts from so many long identifiers...

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: c89617546866ee882231f66cc086bf8ae2b8951d
ghstack-comment-id: 2372563439
Pull Request resolved: #940
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Oct 4, 2024
Summary:

My brain hurts from so many long identifiers...

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: abb4fcbd1c41fafb29cc40ab7cfa296eaec60972
ghstack-comment-id: 2372563439
Pull Request resolved: #940
@vkuzo vkuzo changed the title make float8 scaling configurable by gemm-argument official axiswise scaling support with per-gemm-argument configuration Oct 5, 2024
@vkuzo vkuzo changed the title official axiswise scaling support with per-gemm-argument configuration float8 training axiswise scaling support with per-gemm-argument configuration Oct 5, 2024
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Oct 5, 2024
Summary:

My brain hurts from so many long identifiers...

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 26b3b8ff4e59cc59bf4580056575e25ca7492d4f
ghstack-comment-id: 2372563439
Pull Request resolved: #940
[ghstack-poisoned]
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Oct 7, 2024
Summary:

My brain hurts from so many long identifiers...

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: ba2f870e5e0bcda12164d81995b496b9756e86b5
ghstack-comment-id: 2372563439
Pull Request resolved: #940
[ghstack-poisoned]
@vkuzo vkuzo changed the base branch from gh/vkuzo/11/head to main October 7, 2024 20:59
vkuzo added a commit that referenced this pull request Oct 7, 2024
Summary:

My brain hurts from so many long identifiers...

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: c461f25c0725c6882b4d884bbe0cd08f159579ac
ghstack-comment-id: 2372563439
Pull Request resolved: #940
@vkuzo vkuzo merged commit dec0313 into main Oct 7, 2024
43 checks passed
jainapurva pushed a commit that referenced this pull request Oct 9, 2024
…guration (#940)

Summary:

This PR finalizes the UX for axiswise scaling for float8 training, and introduces per-gemm-argument configurability to Float8Linear to enable exploration of future recipes. Not all combinations are supported yet.  Specifically, the additional combination we now support and test is a recipe from @lw , where we do the following:

```
output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
grad_weight_hp = input_t_hp @ grad_output_hp
```

Key characteristics of this recipe:
1. increased accuracy for `grad_weight`, which is important for real workloads
2. `output` and `weight` now only need to be scaled axiswise across a single dim compared to vanilla all-axiswise, which is more amenable to fast kernels

Here is how a user can configure this:

```python
#
# short form
#

config = torchao.float8.config.recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)

#
# or, long form
#

# output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w_go = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)

# grad_weight_hp = input_t_hp @ grad_output_hp
cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)

# ensure fast_accum is on to get fast kernels
gc_o = Float8GemmConfig(use_fast_accum=True)
gc_gi = Float8GemmConfig(use_fast_accum=True)
gc_gw = Float8GemmConfig(use_fast_accum=True)

config = Float8Config(
    cast_config_input = cc_i,
    cast_config_weight = cc_w,
    cast_config_grad_output = cc_go,
    cast_config_input_for_grad_weight = cc_i_gw,
    cast_config_weight_for_grad_output = cc_w_go,
    cast_config_grad_output_for_grad_weight = cc_go_gw,
    gemm_config_output=gc_o,
    gemm_config_grad_input=gc_gi,
    gemm_config_grad_weight=gc_gw,
)
```

# performance

Below we provide basic performance characteristics of axiswise scaling in general, and the all-axiswise and lw recipes.

## gemm performance of torch._scaled_mm

baseline: tensorwise scaling

```
> python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True
    fast_accum  name      M      K      N  ref_time_s  fp8_time_s  fp8_speedup
0         True     0    256    256    256    0.000004    0.000006     0.573115
1         True     1    512    512    512    0.000005    0.000007     0.659333
2         True     2   1024   1024   1024    0.000011    0.000010     1.080664
3         True     3   2048   2048   2048    0.000028    0.000017     1.596239
4         True     4   4096   4096   4096    0.000210    0.000082     2.551705
5         True     5   8192   8192   8192    0.001671    0.000680     2.457972
6         True     6  16384  16384  16384    0.015030    0.006498     2.313032
7         True     7  32768  32768  32768    0.103236    0.048097     2.146411
8        False     0    256    256    256    0.000004    0.000006     0.630061
9        False     1    512    512    512    0.000005    0.000007     0.767236
10       False     2   1024   1024   1024    0.000012    0.000008     1.391347
11       False     3   2048   2048   2048    0.000029    0.000020     1.457922
12       False     4   4096   4096   4096    0.000211    0.000101     2.100081
13       False     5   8192   8192   8192    0.001676    0.000788     2.128628
14       False     6  16384  16384  16384    0.014933    0.006351     2.351209
15       False     7  32768  32768  32768    0.103457    0.049498     2.090134                
```

experiment: axiswise-scaling

```
> python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True --scaling_granularity axiswise

    fast_accum  name      M      K      N  ref_time_s  fp8_time_s  fp8_speedup
0         True     0    256    256    256    0.000004    0.000004     0.966772
1         True     1    512    512    512    0.000005    0.000004     1.095791
2         True     2   1024   1024   1024    0.000011    0.000006     1.988363
3         True     3   2048   2048   2048    0.000027    0.000015     1.890065
4         True     4   4096   4096   4096    0.000210    0.000082     2.552356
5         True     5   8192   8192   8192    0.001674    0.001092     1.533132
6         True     6  16384  16384  16384    0.015114    0.008785     1.720480
7         True     7  32768  32768  32768    0.103286    0.071456     1.445439
8        False     0    256    256    256    0.000004    0.000004     0.899054
9        False     1    512    512    512    0.000005    0.000005     1.005340
10       False     2   1024   1024   1024    0.000011    0.000006     1.692868
11       False     3   2048   2048   2048    0.000028    0.000049     0.567655
12       False     4   4096   4096   4096    0.000210    0.000341     0.616193
13       False     5   8192   8192   8192    0.001678    0.002640     0.635541
14       False     6  16384  16384  16384    0.015051    0.021557     0.698212
15       False     7  32768  32768  32768    0.103497    0.169797     0.609533

```

## performance on microbenchmark of ln -> linear -> sigmoid

Note: for large square shapes, performance tends to be fp8_delayed_tensorwise > fp8_dynamic_tensorwise > fp8_dynamic_axiswise > custom_recipe.  For performance of fp8_dynamic_axiswise, it seems that the gap from tensorwise is mostly due to the gemm performance being behind tensorwise.

```
> python benchmarks/float8/float8_roofline.py ~/local/tmp/20241004_roofline.csv

   fwd_M  fwd_K  fwd_N  bf16_gemm_s  fp8_gemm_s  fp8_axs_gemm_time_s      fp8_oh_dyn_limit  ... fp8_del_s fp8_dyn_axs_s  fp8_lw_s  fp8_dyn_sp  fp8_del_sp  fp8_dyn_axs_sp  fp8_lw_sp
0    256    256    256     0.000011    0.000018             0.000012   6.50457971014493e-6  ...  0.000043      0.000049  0.000030    0.465634    0.457907        0.398357   0.643088
1    512    512    512     0.000014    0.000020             0.000013   8.01831884057971e-6  ...  0.000047      0.000054  0.000034    0.489556    0.493467        0.432643   0.685842
2   1024   1024   1024     0.000033    0.000026             0.000017   1.40732753623188e-5  ...  0.000060      0.000063  0.000050    0.734123    0.741467        0.705941   0.891199
3   2048   2048   2048     0.000081    0.000055             0.000044   3.82931014492754e-5  ...  0.000147      0.000159  0.000142    0.815678    0.800811        0.739865   0.827441
4   4096   4096   4096     0.000632    0.000274             0.000247  0.000135172405797101  ...  0.000602      0.000622  0.000662    1.236320    1.261848        1.221755   1.147678
5   8192   8192   8192     0.005027    0.002216             0.003292  0.000522689623188406  ...  0.003665      0.004776  0.005720    1.432213    1.513035        1.161130   0.969448
6  16384  16384  16384     0.045113    0.018975             0.025706   0.00207275849275362  ...  0.024664      0.032254  0.038051    1.803456    1.883291        1.440118   1.220738
7  32768  32768  32768     0.312459    0.147255             0.214492   0.00827303397101449  ...  0.182645      0.240962  0.270973    1.696376    1.766307        1.338827   1.190552

```

## performance on torchtitan LLaMa 3 8B on 8 H100 GPUs, float8 compute only:

* baseline (bf16 + compile): 6,294 wps
* f8 all-tensorwise: 7,359 wps (1.17x vs baseline)
* f8 all-axiswise: 7,135 wps (1.13x vs baseline - surprising that this is close to all-tensorwise)
* LW_AXISWISE_WITH_GW_HP: 6,506 wps (1.03x vs baseline)

so, looks like we have performance work to do with `LW_AXISWISE_WITH_GW_HP` in future PRs

# accuracy

I did a very quick check that loss curves on torchtitan LLaMa 3 8B pretraining with 8 H100 GPUs look good for bf16/f8_tensorwise/f8_axiswise/f8_lw on 0.5k iterations.  I will leave longer accuracy verifications for future work.

<img width="973" alt="Screenshot 2024-10-04 at 10 05 24 PM" src="https://github.com/user-attachments/assets/0d682183-41ef-4f04-992f-cd0d0fc8a65c">


Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
jainapurva pushed a commit that referenced this pull request Oct 15, 2024
…guration (#940)

Summary:

This PR finalizes the UX for axiswise scaling for float8 training, and introduces per-gemm-argument configurability to Float8Linear to enable exploration of future recipes. Not all combinations are supported yet.  Specifically, the additional combination we now support and test is a recipe from @lw , where we do the following:

```
output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
grad_weight_hp = input_t_hp @ grad_output_hp
```

Key characteristics of this recipe:
1. increased accuracy for `grad_weight`, which is important for real workloads
2. `output` and `weight` now only need to be scaled axiswise across a single dim compared to vanilla all-axiswise, which is more amenable to fast kernels

Here is how a user can configure this:

```python
#
# short form
#

config = torchao.float8.config.recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)

#
# or, long form
#

# output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w_go = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)

# grad_weight_hp = input_t_hp @ grad_output_hp
cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)

# ensure fast_accum is on to get fast kernels
gc_o = Float8GemmConfig(use_fast_accum=True)
gc_gi = Float8GemmConfig(use_fast_accum=True)
gc_gw = Float8GemmConfig(use_fast_accum=True)

config = Float8Config(
    cast_config_input = cc_i,
    cast_config_weight = cc_w,
    cast_config_grad_output = cc_go,
    cast_config_input_for_grad_weight = cc_i_gw,
    cast_config_weight_for_grad_output = cc_w_go,
    cast_config_grad_output_for_grad_weight = cc_go_gw,
    gemm_config_output=gc_o,
    gemm_config_grad_input=gc_gi,
    gemm_config_grad_weight=gc_gw,
)
```

# performance

Below we provide basic performance characteristics of axiswise scaling in general, and the all-axiswise and lw recipes.

## gemm performance of torch._scaled_mm

baseline: tensorwise scaling

```
> python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True
    fast_accum  name      M      K      N  ref_time_s  fp8_time_s  fp8_speedup
0         True     0    256    256    256    0.000004    0.000006     0.573115
1         True     1    512    512    512    0.000005    0.000007     0.659333
2         True     2   1024   1024   1024    0.000011    0.000010     1.080664
3         True     3   2048   2048   2048    0.000028    0.000017     1.596239
4         True     4   4096   4096   4096    0.000210    0.000082     2.551705
5         True     5   8192   8192   8192    0.001671    0.000680     2.457972
6         True     6  16384  16384  16384    0.015030    0.006498     2.313032
7         True     7  32768  32768  32768    0.103236    0.048097     2.146411
8        False     0    256    256    256    0.000004    0.000006     0.630061
9        False     1    512    512    512    0.000005    0.000007     0.767236
10       False     2   1024   1024   1024    0.000012    0.000008     1.391347
11       False     3   2048   2048   2048    0.000029    0.000020     1.457922
12       False     4   4096   4096   4096    0.000211    0.000101     2.100081
13       False     5   8192   8192   8192    0.001676    0.000788     2.128628
14       False     6  16384  16384  16384    0.014933    0.006351     2.351209
15       False     7  32768  32768  32768    0.103457    0.049498     2.090134                
```

experiment: axiswise-scaling

```
> python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True --scaling_granularity axiswise

    fast_accum  name      M      K      N  ref_time_s  fp8_time_s  fp8_speedup
0         True     0    256    256    256    0.000004    0.000004     0.966772
1         True     1    512    512    512    0.000005    0.000004     1.095791
2         True     2   1024   1024   1024    0.000011    0.000006     1.988363
3         True     3   2048   2048   2048    0.000027    0.000015     1.890065
4         True     4   4096   4096   4096    0.000210    0.000082     2.552356
5         True     5   8192   8192   8192    0.001674    0.001092     1.533132
6         True     6  16384  16384  16384    0.015114    0.008785     1.720480
7         True     7  32768  32768  32768    0.103286    0.071456     1.445439
8        False     0    256    256    256    0.000004    0.000004     0.899054
9        False     1    512    512    512    0.000005    0.000005     1.005340
10       False     2   1024   1024   1024    0.000011    0.000006     1.692868
11       False     3   2048   2048   2048    0.000028    0.000049     0.567655
12       False     4   4096   4096   4096    0.000210    0.000341     0.616193
13       False     5   8192   8192   8192    0.001678    0.002640     0.635541
14       False     6  16384  16384  16384    0.015051    0.021557     0.698212
15       False     7  32768  32768  32768    0.103497    0.169797     0.609533

```

## performance on microbenchmark of ln -> linear -> sigmoid

Note: for large square shapes, performance tends to be fp8_delayed_tensorwise > fp8_dynamic_tensorwise > fp8_dynamic_axiswise > custom_recipe.  For performance of fp8_dynamic_axiswise, it seems that the gap from tensorwise is mostly due to the gemm performance being behind tensorwise.

```
> python benchmarks/float8/float8_roofline.py ~/local/tmp/20241004_roofline.csv

   fwd_M  fwd_K  fwd_N  bf16_gemm_s  fp8_gemm_s  fp8_axs_gemm_time_s      fp8_oh_dyn_limit  ... fp8_del_s fp8_dyn_axs_s  fp8_lw_s  fp8_dyn_sp  fp8_del_sp  fp8_dyn_axs_sp  fp8_lw_sp
0    256    256    256     0.000011    0.000018             0.000012   6.50457971014493e-6  ...  0.000043      0.000049  0.000030    0.465634    0.457907        0.398357   0.643088
1    512    512    512     0.000014    0.000020             0.000013   8.01831884057971e-6  ...  0.000047      0.000054  0.000034    0.489556    0.493467        0.432643   0.685842
2   1024   1024   1024     0.000033    0.000026             0.000017   1.40732753623188e-5  ...  0.000060      0.000063  0.000050    0.734123    0.741467        0.705941   0.891199
3   2048   2048   2048     0.000081    0.000055             0.000044   3.82931014492754e-5  ...  0.000147      0.000159  0.000142    0.815678    0.800811        0.739865   0.827441
4   4096   4096   4096     0.000632    0.000274             0.000247  0.000135172405797101  ...  0.000602      0.000622  0.000662    1.236320    1.261848        1.221755   1.147678
5   8192   8192   8192     0.005027    0.002216             0.003292  0.000522689623188406  ...  0.003665      0.004776  0.005720    1.432213    1.513035        1.161130   0.969448
6  16384  16384  16384     0.045113    0.018975             0.025706   0.00207275849275362  ...  0.024664      0.032254  0.038051    1.803456    1.883291        1.440118   1.220738
7  32768  32768  32768     0.312459    0.147255             0.214492   0.00827303397101449  ...  0.182645      0.240962  0.270973    1.696376    1.766307        1.338827   1.190552

```

## performance on torchtitan LLaMa 3 8B on 8 H100 GPUs, float8 compute only:

* baseline (bf16 + compile): 6,294 wps
* f8 all-tensorwise: 7,359 wps (1.17x vs baseline)
* f8 all-axiswise: 7,135 wps (1.13x vs baseline - surprising that this is close to all-tensorwise)
* LW_AXISWISE_WITH_GW_HP: 6,506 wps (1.03x vs baseline)

so, looks like we have performance work to do with `LW_AXISWISE_WITH_GW_HP` in future PRs

# accuracy

I did a very quick check that loss curves on torchtitan LLaMa 3 8B pretraining with 8 H100 GPUs look good for bf16/f8_tensorwise/f8_axiswise/f8_lw on 0.5k iterations.  I will leave longer accuracy verifications for future work.

<img width="973" alt="Screenshot 2024-10-04 at 10 05 24 PM" src="https://github.com/user-attachments/assets/0d682183-41ef-4f04-992f-cd0d0fc8a65c">


Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants