Re-land the PR of "Add INT8 SDPA path for CPU" #2215
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2215. Note: links to docs will display an error until the docs builds have completed. ✅ No failures as of commit 65f7d50 with merge base 96aec6a. This comment was automatically generated by Dr. CI and updates every 15 minutes.
from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import _int8_sdpa_init
from torchao.utils import TORCH_VERSION_AT_LEAST_2_7

use_cpp_avx512 = os.getenv("USE_AVX512", "0") == "1"
This feels wrong:

- If the user didn't set the flag during the build phase but only for testing, will it cause a CI failure?
- If the user built the custom op but didn't enable this flag for testing, will the tests just be skipped?

One way that comes to mind is to check whether this custom op has been registered to the CPU dispatch key correctly, for example torch._C._dispatch_dump("torchao::qscaled_dot_product"). Feel free to explore if you have any better idea.
Thanks for the suggestion, replaced with "CPU" in torch._C._dispatch_dump("torchao::qscaled_dot_product").
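For reference, a minimal sketch of how this gate might look in the tests (the op name comes from this thread; the helper and test-class names are illustrative):

```python
# Sketch only: gate the CPU tests on whether the custom op was actually
# built and registered; helper and test names here are illustrative.
import unittest

import torch


def cpu_kernel_registered(op: str = "torchao::qscaled_dot_product") -> bool:
    """True if the op has a kernel registered for the CPU dispatch key."""
    try:
        return "CPU" in torch._C._dispatch_dump(op)
    except RuntimeError:
        # The op was never registered at all, e.g. the C++ extension
        # was not compiled into this build.
        return False


@unittest.skipUnless(cpu_kernel_registered(), "INT8 SDPA CPU kernel not available")
class TestInt8SDPA(unittest.TestCase):
    ...
```

This skips cleanly when the kernel is absent and runs whenever it is registered, independent of which env flags were set at build time, which addresses both questions above.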
self.device, enabled=enable_autocast, dtype=torch.bfloat16
),
):
    _int8_sdpa_init()
For registering the custom pass, could we follow the suggestion in pytorch/pytorch#153532 (comment)?
Thanks and modified!
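For reference, a rough sketch of registering a custom Inductor post-grad pass, assuming the linked comment refers to Inductor's CustomGraphPass hook; the class name and pass body are illustrative, not this PR's actual code:

```python
import torch
from torch._inductor import config
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files


class Int8SDPAPass(CustomGraphPass):
    def __call__(self, graph: torch.fx.Graph) -> None:
        # Rewrite matched INT8 SDPA subgraphs in place (illustrative stub).
        ...

    def uuid(self) -> bytes:
        # Key the FX graph cache on this file so cached artifacts are
        # invalidated whenever the pass implementation changes.
        return get_hash_for_files((__file__,))


# Inductor runs this callable after its own post-grad passes.
config.post_grad_custom_post_pass = Int8SDPAPass()
```

Implementing uuid() is what lets the pass coexist with FX-graph caching, which is the usual sticking point with ad-hoc pass registration.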
test/test_ops.py (Outdated)
compute_max_diff,
)

use_cpp_avx512 = os.getenv("USE_AVX512", "0") == "1"
Ditto
Thanks for the suggestion, replaced with "CPU" in torch._C._dispatch_dump("torchao::qscaled_dot_product").
setup.py (Outdated)
@@ -55,6 +55,10 @@ def read_version(file_path="version.txt"):
    and platform.system() == "Darwin"
)

use_cpp_avx512 = os.getenv("USE_AVX512", "0") == "1"
This name might not be intuitive. This flag actually decides whether or not the CPP kernels are built.
Thanks. Changed to use_cpp_kernels.
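To illustrate, the renamed flag might be wired up in setup.py roughly like this; the env-var name comes from the diff above, while the gated source path is hypothetical:

```python
import os

# The name now describes what the flag does: build the C++ CPU kernels.
use_cpp_kernels = os.getenv("USE_AVX512", "0") == "1"

# Hypothetical use: only add the C++ sources when the flag is set.
cpp_sources = ["torchao/csrc/cpu/int8_sdpa.cpp"] if use_cpp_kernels else []
```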
Force-pushed from 080576b to 164a8ff (compare).
Force-pushed from 164a8ff to 65f7d50 (compare).
wheel build looks good
Re-land #1372.
Based on the original PR, there are two main modifications:
- Rename scaled_dot_product_int8 to qscaled_dot_product, in order to reuse the API for future FP8 SDPA support.
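A minimal sketch of the rename at a call site (arguments elided, since the op's full signature is not shown in this thread):

```python
import torch

# Formerly: torch.ops.torchao.scaled_dot_product_int8
q_sdpa_op = torch.ops.torchao.qscaled_dot_product
```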