Add INT8 SDPA path for CPU #1372
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1372
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 7ed497a with merge base 25034e5.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@drisspg @jerryzh168 @jgong5 @leslie-fang-intel Please help review the POC, thanks!
Will do some more review on this later today, but I think we might want to make a subfolder in prototype for "inductor_patterns", since this is a pretty particular workflow, and we should add a good README explaining how it can be used and its limitations.
Do I understand correctly that this is using torch.compile to change the numerics of the model by hooking a quantization pass into inductor? If yes, can this live in prototype for now? I'd have concerns about using torch.compile passes to change numerics as the official API; among the challenges, it breaks the assumption that a compiler does not meaningfully change numerics.
I believe the numerics change happens in the pt2e quant API; this should only do fusion.
```python
def _sfdp_init_int8():
    for key, register_replacement_kwargs in _gen_sfdp_patterns_int8():
        register_replacement(**register_replacement_kwargs)
    config.joint_custom_pre_pass = patterns.apply
```
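For context, here is a minimal sketch of what each `register_replacement` call wires up, assuming the `torch._inductor.pattern_matcher` API; the search/replace functions below are illustrative stand-ins, not the PR's actual patterns:

```python
import torch
from torch._inductor.pattern_matcher import (
    PatternMatcherPass,
    fwd_only,
    register_replacement,
)

patterns = PatternMatcherPass()

def _search(q, k, v):
    # Illustrative stand-in for the unfused (quantized) SDPA subgraph to match.
    return torch.softmax(q @ k.transpose(-2, -1), dim=-1) @ v

def _replace(q, k, v):
    # Illustrative stand-in for the fused custom op that replaces it.
    return torch.ops.torchao.scaled_dot_product_int8(q, k, v)

# Each kwargs dict yielded by _gen_sfdp_patterns_int8() would carry
# entries along these lines (commented out since the inputs are fake):
# register_replacement(
#     search_fn=_search,
#     replace_fn=_replace,
#     example_inputs=[torch.randn(2, 4, 8, 8)] * 3,
#     trace_fn=fwd_only,
#     pass_dicts=patterns,
# )
```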
Is this the official API to add new fusion passes to inductor? What if we have multiple fusion passes that we need to add? I.e., we probably want to move all Intel quant passes to torchao in the future as well.
According to https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py#L165, I suppose only one fusion pass can be assigned; the same holds for the other customized passes. It would be better to expand this to a list of passes. Maybe we need more comments from @Chillee @eellison.
Can we report an issue to track it? If multiple libraries register the `joint_custom_pre_pass`, it will silently fail to work.
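Until that is addressed upstream, one workaround is to fold several passes into the single slot yourself. This is a hypothetical sketch, not an official API; it assumes each pass is a callable over the FX graph, as `joint_custom_pre_pass` expects:

```python
import torch
from torch._inductor import config

def compose_passes(*passes):
    # Run each registered pass in order over the same FX graph.
    def combined(graph: torch.fx.Graph):
        for p in passes:
            p(graph)
    return combined

# joint_custom_pre_pass holds only one callable, so chain them explicitly
# (int8_sdpa_pass / other_pass are hypothetical names):
# config.joint_custom_pre_pass = compose_passes(int8_sdpa_pass, other_pass)
```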
Reported an issue: pytorch/pytorch#151876
There are currently some numeric issues with the kernel, which I will fix. Ideally, for a quantized dtype, the atol is expected to be 1.5 (refer to https://github.com/pytorch/pytorch/blob/main/test/quantization/core/test_quantized_op.py#L5167), and the rtol is the default value.
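As a concrete illustration of that tolerance choice (a hedged sketch; the helper name is hypothetical), the check would look like:

```python
import torch

def check_int8_sdpa(actual: torch.Tensor, expected: torch.Tensor) -> None:
    # For quantized dtypes, allow up to 1.5 units of absolute difference
    # (roughly one quantization step plus rounding); 1.3e-6 is the default
    # float32 rtol, passed explicitly since assert_close requires both
    # tolerances when either is given.
    torch.testing.assert_close(actual, expected, atol=1.5, rtol=1.3e-6)
```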
torchao/csrc/cpu/sdpa.cpp (Outdated)
```cpp
// dropout_p, is_causal, attn_mask, scale,
// q_zp, q_scale,
// k_zp, k_scale,
// v_zp, v_scale,
```
There's a decent amount of commented-out code here.
Thanks, the unused code has been removed.
@jerryzh168 or @drisspg, could you help make sure the PR summary includes a high-level description of the flow and explains why an inductor pass is used to swap in the final kernel? This isn't expected; it somewhat makes sense to me, but it would be good to explain the plan in a way that's easily discoverable.
cc @jerryzh168. The atol linked above is, AFAIK, for comparisons between high precision and low precision, which is numerics-changing. If we're saying that this doesn't change numerics, I'd expect atol/rtol to be near 0.
@vkuzo are you referring to
```python
@@ -71,6 +72,56 @@ def _(
     return _in_feats.new_empty((BS, OC))


def scaled_dot_product_int8(
```
Please add `for_cpu` to the op name, since we will likely add some GPU ops as well.
This is supposed to be used by all the backends, and each backend could register its own implementation. For example, for the CPU path: https://github.com/pytorch/ao/pull/1372/files#diff-eaf2387d03cf16395487f5f4162420a8e84bb89e9f1221a01474d2f87ff449ddR2080-R2082.
Or do you think the API for CPU and GPU could be different?
I think it's fine as is. We are planning to add a variant for FAv3-like fp8 attention, which would be slightly different I imagine, but it will be in prototype.
cc @jbschlosser
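For reference, a hedged usage sketch of the op under discussion. The keyword names mirror the commented parameter list in the kernel source above and are assumptions, as are the zero-point/scale values; running this requires a torchao build containing this PR:

```python
import torch
import torchao  # registers the torchao custom ops

# q/k/v are uint8 tensors as produced by the pt2e quantization flow.
q = torch.randint(0, 256, (1, 16, 64, 64), dtype=torch.uint8)
k = torch.randint(0, 256, (1, 16, 64, 64), dtype=torch.uint8)
v = torch.randint(0, 256, (1, 16, 64, 64), dtype=torch.uint8)

# Argument order follows the commented list: dropout_p, is_causal,
# attn_mask, scale, then per-tensor (zero_point, scale) for q/k/v.
out = torch.ops.torchao.scaled_dot_product_int8(
    q, k, v,
    dropout_p=0.0, is_causal=False, attn_mask=None, scale=None,
    q_zp=127, q_scale=0.01,
    k_zp=127, k_scale=0.01,
    v_zp=127, v_scale=0.01,
)
```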
### Description
During the support of INT8 SDPA (pytorch/ao#1372), we found that `at::vec::vec_reduce_all<int32_t>` would go into the slow scalar path when doing sum and max. So here, we support the two reduce-related ops `reduce_add` and `reduce_max` for `vec512` and `vec256`, using the Sequence instructions.

### Details
- Support vectorized `reduce_add` and `reduce_max` for dtypes `int32` and `float32`, using the Sequence instructions;
- Implement the scalar version for the fallback path in vec base;
- Add the operator `reduce` in vec base, in order to simplify the code.

Pull Request resolved: #144065
Approved by: https://github.com/mingfeima
In lowering, support the parameter `out_dtype` for `dequant_per_tensor` and `dequant_per_channel`. Fix the following runtime error found in pytorch/ao#1372:

```
  File "/home/liaoxuan/pytorch_ao/torch/_inductor/lowering.py", line 452, in wrapped
    out = decomp_fn(*args, **kwargs)
torch._dynamo.exc.BackendCompilerFailed: backend='compile_fx_wrapper' raised:
LoweringException: TypeError: quantized_decomposed_dequantize_per_tensor_default() got an unexpected keyword argument 'out_dtype'
  target: quantized_decomposed.dequantize_per_tensor.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cpu', torch.uint8, size=[1, 7, 7, 9], stride=[441, 63, 9, 1]))
  ))
  args[1]: 0.01
  args[2]: 100
  args[3]: 0
  args[4]: 255
  args[5]: torch.uint8
  kwargs: {'out_dtype': torch.bfloat16}
```

Pull Request resolved: #143845
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
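A minimal repro of the fixed behavior, built directly from the args in the error trace above; under `torch.compile` this exercises the new lowering (the shape and quantization parameters are taken from the trace):

```python
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401, registers the quantized_decomposed ops

@torch.compile
def dequant(x):
    # Same args as in the trace: scale=0.01, zero_point=100,
    # quant_min=0, quant_max=255, dtype=uint8, out_dtype=bfloat16.
    return torch.ops.quantized_decomposed.dequantize_per_tensor(
        x, 0.01, 100, 0, 255, torch.uint8, out_dtype=torch.bfloat16
    )

x = torch.randint(0, 256, (1, 7, 7, 9), dtype=torch.uint8)
print(dequant(x).dtype)  # torch.bfloat16
```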
@Valentine233 do you have any perf benchmarks? Naively I would have imagined that SDPA isn't as crucial to implement on CPU, because the caches are larger and hardware-managed, but perhaps my intuition about CPU SDPA is lacking, so anything you could share would be super helpful.
We also encounter cache problems on CPU for long sequence lengths, so the fused SDPA does have perf benefits. I will share the results in the description as soon as they are ready.
I think so, since the
Yeah, since the new
Looks like there is a build issue? `ao/torchao/csrc/cpu/int8_sdpa.cpp:1:9: error: #pragma once in main file [-Werror,-Wpragma-once-outside-header]`
Will need to revert this one due to internal build errors; please land again.
Revert "Add INT8 SDPA path for CPU (pytorch#1372)"

This reverts commit 34421b1.
For the integration of INT8 SDPA into TorchAO, we design a feasible path by registering a customized PyTorch pass and adding the pattern matcher and kernel in TorchAO.
Steps:
1. Add the custom op `torchao.ops.scaled_dot_product_int8`, for CPU.
2. Add the pattern matcher, which fuses the quantized SDPA pattern into `torchao.ops.scaled_dot_product_int8`.
3. Register the customized pass via `torch._inductor.config.post_grad_custom_pre_pass` (see the end-to-end sketch below).
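Putting the steps together, a hedged end-to-end sketch of the intended flow. The pass-registration helper `_sfdp_init_int8` comes from this PR; the pt2e quantizer setup is the standard x86 inductor one, and the details here are illustrative rather than the PR's exact test script:

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)

def quantize_and_compile(model, example_inputs):
    # 1. PT2E quantization: this is where the numerics actually change,
    #    not in the inductor pass.
    exported = torch.export.export_for_training(model, example_inputs).module()
    quantizer = X86InductorQuantizer()
    quantizer.set_global(get_default_x86_inductor_quantization_config())
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*example_inputs)  # calibration
    converted = convert_pt2e(prepared)

    # 2-3. Register the INT8 SDPA fusion pass, then compile; inductor's
    # pattern matcher swaps the quantized SDPA subgraph for the fused
    # torchao.ops.scaled_dot_product_int8 kernel.
    # _sfdp_init_int8()  # pass-registration helper from this PR
    return torch.compile(converted)
```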
Perf:
The validation is run for int8-bf16 on a GNR machine; the script is similar to the UT.