enable torch.compile for mxfp8_cublas recipe #1841
Conversation
Stack from ghstack (oldest at bottom):
Summary: This PR enables `MXLinear` with the `mxfp8_cublas` recipe to use torch.compile. The current approach is a short-term workaround until pytorch/pytorch#148461 is done. Since we can't use e8m0 in torchinductor or triton yet, we create a custom op wrapper around `torch._scaled_mm` which takes `uint8` scales and does the cast to e8m0 inside the wrapper, where torchinductor can't see it.

Test Plan:

```
// this now works (although performance is not ideal due to #1788)
python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250305_test --mx_recipe_name mxfp8_cublas

// we can also uncomment the hardware check and run the unit test
pytest test/prototype/mx_formats -s -k test_linear_compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 033d817549f80d7d0d8cf549f748411cc1f3ac6a
ghstack-comment-id: 2701679811
Pull Request resolved: #1841
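For readers unfamiliar with the trick, here is a minimal sketch of the wrapper idea. The op name, namespace, and exact signature below are hypothetical (the real op added in this PR may differ), and it assumes a recent PyTorch with `torch.library.custom_op` and the `torch.float8_e8m0fnu` dtype:

```python
import torch
from torch.library import custom_op

# Hypothetical namespace/op name for illustration only. The key idea: the
# uint8 -> e8m0 bitcast happens *inside* the custom op, which torchinductor
# treats as opaque, so no e8m0 tensor ever appears in the compiled graph.
@custom_op("mx_sketch::scaled_mm_e8m0", mutates_args=())
def scaled_mm_e8m0(
    a: torch.Tensor,           # fp8 (e4m3) activations
    b: torch.Tensor,           # fp8 (e4m3) weights
    a_scale_u8: torch.Tensor,  # e8m0 scales stored as raw uint8 bits
    b_scale_u8: torch.Tensor,
) -> torch.Tensor:
    # Bitcast the uint8 scale payloads to e8m0 where inductor can't see it
    # (both dtypes are 1 byte, so .view(dtype) reinterprets in place).
    a_scale = a_scale_u8.view(torch.float8_e8m0fnu)
    b_scale = b_scale_u8.view(torch.float8_e8m0fnu)
    return torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=torch.bfloat16)

# Fake (meta) implementation so torch.compile can trace shapes and dtypes
# without running the real cuBLAS kernel.
@scaled_mm_e8m0.register_fake
def _(a, b, a_scale_u8, b_scale_u8):
    return torch.empty(a.shape[0], b.shape[1], dtype=torch.bfloat16, device=a.device)
```

At call time the MX linear path would pass the raw `uint8` scale tensors; inductor compiles everything around the opaque op, and cuBLAS still receives e8m0 scales.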
```python
@pytest.mark.skipif(
    is_sm_at_least_100(),
    reason="triton does not work yet on CUDA capability 10.0",
)
@pytest.mark.skipif(
    not is_sm_at_least_100(),
    reason="MX gemms require CUDA capability 10.0",
)
```
Combining skipif `is_sm_at_least_100()` with skipif `not is_sm_at_least_100()` will prevent the test from ever running, so I just want to confirm: is this test intentionally being skipped until the new release of pytorch (with triton that supports compute capability 10.0) is part of CI?
yes, that's correct. It's skipped in CI because we don't have B200s in CI, and it's skipped locally because it requires building triton from source. I uncomment these tests if I need to run them, for now.
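For reference, a minimal self-contained sketch of why the stacked marks always skip, using a stand-in `cond()` rather than the real `is_sm_at_least_100()`: pytest ORs stacked `skipif` conditions, so when one mark checks `cond()` and the other checks `not cond()`, at least one is always true.

```python
import pytest

def cond() -> bool:
    # stand-in for is_sm_at_least_100(); the return value doesn't matter
    return True

# Stacked skipif marks are OR'ed: the test runs only if *every* condition
# is false. cond() and not cond() can't both be false, so this never runs.
@pytest.mark.skipif(cond(), reason="skipped when cond() is true")
@pytest.mark.skipif(not cond(), reason="skipped when cond() is false")
def test_never_runs():
    raise AssertionError("unreachable: one skipif condition always holds")
```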