
Add INT8 SDPA path for CPU #1372


Merged: 34 commits into pytorch:main on Apr 18, 2025

Conversation


@Valentine233 (Collaborator) commented Dec 3, 2024

For the integration of INT8 SDPA in TorchAO, we design a feasible path by registering a customized Inductor pass in PyTorch and adding the pattern matcher and kernel in TorchAO.

Steps:

  1. Register and implement the INT8 SDPA kernel, i.e. torchao.ops.scaled_dot_product_int8, for CPU.
  2. Add the pattern matchers for INT8 SDPA, in order to replace the decomposed ops with torchao.ops.scaled_dot_product_int8.
  3. Register a customized PyTorch pass by assigning the above pattern replacement to torch._inductor.config.post_grad_custom_pre_pass (see the sketch below).
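To make the flow concrete, here is a minimal end-user sketch of step 3, assuming the patterns from step 2 are collected in a `PatternMatcherPass`; the `quantized_model` and `example_inputs` names are placeholders and this is not the exact code in this PR:

```python
import torch
import torch._inductor.config as inductor_config
from torch._inductor.pattern_matcher import PatternMatcherPass

# Holds the INT8 SDPA replacement patterns; the PR builds these with
# register_replacement over the decomposed q/dq + SDPA graph (step 2).
patterns = PatternMatcherPass()

def int8_sdpa_pre_pass(graph: torch.fx.Graph):
    # Rewrite each matched subgraph into a call to torchao.ops.scaled_dot_product_int8.
    patterns.apply(graph)

# Step 3: hand the pass to Inductor so it runs before the regular post-grad passes.
inductor_config.post_grad_custom_pre_pass = int8_sdpa_pre_pass

compiled = torch.compile(quantized_model)  # quantized_model: output of the PT2E quant flow
out = compiled(*example_inputs)
```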

Perf:
The validation is launched for int8-bf16 on a GNR machine. The script is similar to the UT.

| Model | Mode | E2E Speedup |
| --- | --- | --- |
| BertLarge | Throughput | 1.13 |
| BertLarge | Realtime | 1.03 |
| VIT | Throughput | 1.10 |
| VIT | Realtime | 1.03 |


pytorch-bot bot commented Dec 3, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1372

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7ed497a with merge base 25034e5:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Dec 3, 2024
@Valentine233
Collaborator Author

@drisspg @jerryzh168 @jgong5 @leslie-fang-intel Please help review the POC, thanks!

@drisspg
Contributor

drisspg commented Dec 3, 2024

Will do some more review on this later today, but I think we might want to make a subfolder in prototype for "inductor_patterns", since this is a pretty particular workflow, and we should add a good README explaining how it can be used and its limitations.

@vkuzo
Contributor

vkuzo commented Dec 3, 2024

> Register a customized pass of PyTorch by defining the above patterns as torch._inductor.config.joint_custom_pre_pass.

Do I understand correctly that this is using torch.compile to change the numerics of the model by hooking up a quantization pass to Inductor? If yes, can this live in prototype for now? I'd have concerns about using torch.compile passes to change numerics as the official API; some of the challenges here include breaking the assumption that a compiler does not meaningfully change numerics.

@jerryzh168
Contributor

> Register a customized pass of PyTorch by defining the above patterns as torch._inductor.config.joint_custom_pre_pass.
>
> Do I understand correctly that this is using torch.compile to change the numerics of the model by hooking up a quantization pass to Inductor? If yes, can this live in prototype for now? I'd have concerns about using torch.compile passes to change numerics as the official API; some of the challenges here include breaking the assumption that a compiler does not meaningfully change numerics.

I believe the numerics change happens in the PT2E quant API; this should only do fusion.

```python
def _sfdp_init_int8():
    for key, register_replacement_kwargs in _gen_sfdp_patterns_int8():
        register_replacement(**register_replacement_kwargs)
    config.joint_custom_pre_pass = patterns.apply
```
Contributor


Is this the official API to add new fusion passes to Inductor? What if we have multiple fusion passes that we need to add? I.e., we probably want to move all the Intel quant passes to TorchAO in the future as well.

Collaborator Author


According to https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py#L165, I suppose only one fusion pass can be assigned, and the same holds for the other customized passes. It would be better to expand this to a list of passes. Maybe we need more comments from @Chillee @eellison.

Collaborator


Can we file an issue to track it? If multiple libraries register joint_custom_pre_pass, it will silently fail to work.

Collaborator Author


Reported an issue: pytorch/pytorch#151876
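Until Inductor supports a list of custom passes, one workaround is to compose several passes into a single callable before assigning the slot. A hedged sketch; the individual pass names here are illustrative placeholders, not APIs from this PR:

```python
import torch._inductor.config as inductor_config

def compose_passes(*passes):
    # Chain several graph passes into one callable so they can share the
    # single joint_custom_pre_pass slot.
    def combined(graph):
        for p in passes:
            p(graph)
    return combined

# int8_sdpa_patterns.apply and other_quant_fusion_pass are placeholder names
# for the individual fusion passes a library might want to register.
inductor_config.joint_custom_pre_pass = compose_passes(
    int8_sdpa_patterns.apply,
    other_quant_fusion_pass,
)
```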

@vkuzo
Contributor

vkuzo commented Dec 4, 2024

> Register a customized pass of PyTorch by defining the above patterns as torch._inductor.config.joint_custom_pre_pass.
>
> Do I understand correctly that this is using torch.compile to change the numerics of the model by hooking up a quantization pass to Inductor? If yes, can this live in prototype for now? I'd have concerns about using torch.compile passes to change numerics as the official API; some of the challenges here include breaking the assumption that a compiler does not meaningfully change numerics.
>
> I believe the numerics change happens in the PT2E quant API; this should only do fusion.

I see that rtol and atol are pretty significant when comparing the baseline to the workflow introduced in this PR; is that intended? Are you saying that with the proper setup we can do the comparison between using this Inductor pass vs not using it and have it pass with rtol/atol near zero?

@Valentine233
Collaborator Author

Valentine233 commented Dec 4, 2024

> Register a customized pass of PyTorch by defining the above patterns as torch._inductor.config.joint_custom_pre_pass.
>
> Do I understand correctly that this is using torch.compile to change the numerics of the model by hooking up a quantization pass to Inductor? If yes, can this live in prototype for now? I'd have concerns about using torch.compile passes to change numerics as the official API; some of the challenges here include breaking the assumption that a compiler does not meaningfully change numerics.
>
> I believe the numerics change happens in the PT2E quant API; this should only do fusion.
>
> I see that rtol and atol are pretty significant when comparing the baseline to the workflow introduced in this PR; is that intended? Are you saying that with the proper setup we can do the comparison between using this Inductor pass vs not using it and have it pass with rtol/atol near zero?

There are currently some numeric issues with the kernel, and I will update it eventually. Ideally for a quantized dtype the atol is expected to be 1.5 (refer to https://github.com/pytorch/pytorch/blob/main/test/quantization/core/test_quantized_op.py#L5167), and the rtol is the default value.
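For reference, a sketch of the kind of tolerance check being discussed; the tensor names are placeholders and the values mirror the linked test rather than this PR's UT:

```python
import torch

# Placeholders: output of the fused INT8 SDPA kernel vs. the decomposed
# reference path, both brought back to the same floating-point dtype.
actual = fused_int8_sdpa_output
expected = decomposed_reference_output

# For quantized kernels an absolute tolerance on the order of one quantization
# step (1.5 in the linked test) is typical; rtol is left at a small default.
torch.testing.assert_close(actual, expected, atol=1.5, rtol=1.3e-6)
```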

```cpp
// dropout_p, is_causal, attn_mask, scale,
// q_zp, q_scale,
// k_zp, k_scale,
// v_zp, v_scale,
```
Contributor


decent amount of commented code here

Collaborator Author


Thanks, the unused code has been removed.

@vkuzo
Contributor

vkuzo commented Dec 4, 2024

@jerryzh168 or @drisspg, could you help make sure the PR summary includes a high-level description of the flow and explains why the Inductor pass is used to swap in the final kernel? This isn't expected; it somewhat makes sense to me, but it would be good to explain the plan in a way that's easily discoverable.

@vkuzo
Contributor

vkuzo commented Dec 4, 2024

> Register a customized pass of PyTorch by defining the above patterns as torch._inductor.config.joint_custom_pre_pass.
>
> Do I understand correctly that this is using torch.compile to change the numerics of the model by hooking up a quantization pass to Inductor? If yes, can this live in prototype for now? I'd have concerns about using torch.compile passes to change numerics as the official API; some of the challenges here include breaking the assumption that a compiler does not meaningfully change numerics.
>
> I believe the numerics change happens in the PT2E quant API; this should only do fusion.
>
> I see that rtol and atol are pretty significant when comparing the baseline to the workflow introduced in this PR; is that intended? Are you saying that with the proper setup we can do the comparison between using this Inductor pass vs not using it and have it pass with rtol/atol near zero?
>
> There are currently some numeric issues with the kernel, and I will update it eventually. Ideally for a quantized dtype the atol is expected to be 1.5 (refer to https://github.com/pytorch/pytorch/blob/main/test/quantization/core/test_quantized_op.py#L5167), and the rtol is the default value.

cc @jerryzh168. The atol linked above is, AFAIK, for comparisons between high precision and low precision, which is numerics-changing. If we're saying that this doesn't change numerics, I'd expect atol/rtol to be near 0.

@jerryzh168
Contributor

jerryzh168 commented Dec 4, 2024

@vkuzo are you referring to atol=1.0 (https://github.com/pytorch/ao/pull/1372/files#diff-fe0aa67bd65dd5da118abe44f45104170ba9871cb14bd446a448d056599d462aR189)? This does look a bit large. I think there is some expected difference when we fuse dq - op - q patterns into real quantized ops, so slight changes to numerics are probably expected; it's a bit different from floating-point fusions.
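A tiny, self-contained illustration (not this PR's kernel) of why fusing dq-op-q patterns into a real quantized op is expected to differ slightly from the decomposed float reference: the fused path rounds intermediates to int8, introducing up to roughly half a quantization step of error per element.

```python
import torch

scale, zp = 0.05, 0
x = torch.randn(4, 8)

# Decomposed reference: intermediates stay in floating point.
ref = torch.softmax(x, dim=-1)

# Fused-style path: the same intermediate is rounded to int8 and dequantized,
# which is where the small, expected numerical difference comes from.
q = torch.clamp(torch.round(ref / scale) + zp, -128, 127)
approx = (q - zp) * scale

print((ref - approx).abs().max())  # bounded by roughly scale / 2
```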

```
@@ -71,6 +72,56 @@ def _(
    return _in_feats.new_empty((BS, OC))


def scaled_dot_product_int8(
```
Contributor


please add for_cpu to the op name, since we will likely add some gpu ops as well

Collaborator Author


This is supposed to be used by all the backends, and each backend could register its own implementation. For example, for CPU path https://github.com/pytorch/ao/pull/1372/files#diff-eaf2387d03cf16395487f5f4162420a8e84bb89e9f1221a01474d2f87ff449ddR2080-R2082.

Or do you think the API for CPU and GPU could be different?

Contributor


I think it's fine as is. We are planning to add a variant for FAv3-like fp8 attention, which would be slightly different I imagine, but it will be prototype.

cc @jbschlosser
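Following up on the backend-registration point above, a hedged sketch of how one op schema could be shared while each backend registers its own kernel via torch.library; the namespace, signature, and fallback body are illustrative, not the actual code from this PR:

```python
import torch
from torch.library import Library

# Hypothetical namespace for the sketch; the real op lives under torchao::.
lib = Library("ao_sketch", "DEF")
lib.define(
    "scaled_dot_product_int8(Tensor q, Tensor k, Tensor v, "
    "float q_scale, int q_zp, float k_scale, int k_zp, "
    "float v_scale, int v_zp) -> Tensor"
)

def scaled_dot_product_int8_cpu(q, k, v, q_scale, q_zp, k_scale, k_zp, v_scale, v_zp):
    # Placeholder CPU reference: dequantize, run fp32 SDPA, return fp32 output.
    qf = (q.float() - q_zp) * q_scale
    kf = (k.float() - k_zp) * k_scale
    vf = (v.float() - v_zp) * v_scale
    return torch.nn.functional.scaled_dot_product_attention(qf, kf, vf)

# Each backend registers its own implementation under the shared schema.
lib.impl("scaled_dot_product_int8", scaled_dot_product_int8_cpu, "CPU")
```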

yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jan 3, 2025
### Description

During the support of INT8 SDPA (pytorch/ao#1372), we found that `at::vec::vec_reduce_all<int32_t>` would go into a slow scalar path when doing sum and max. So here, we support the two reduce-related ops `reduce_add` and `reduce_max` for `vec512` and `vec256`, using the Sequence instructions.

### Details
- Support vectorized `reduce_add` and `reduce_max` for dtypes `int32` and `float32`, using the Sequence instructions;
- Implement the scalar version for fallback path in vec base;
- Add the operator `reduce` in vec base, in order to simplify the code.

Pull Request resolved: #144065
Approved by: https://github.com/mingfeima
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jan 4, 2025
In lowering, support the parameter `out_dtype` for `dequant_per_tensor` and `dequant_per_channel`.

Fix the following runtime error issue found in pytorch/ao#1372:

```
File "/home/liaoxuan/pytorch_ao/torch/_inductor/lowering.py", line 452, in wrapped
    out = decomp_fn(*args, **kwargs)
torch._dynamo.exc.BackendCompilerFailed: backend='compile_fx_wrapper' raised:
LoweringException: TypeError: quantized_decomposed_dequantize_per_tensor_default() got an unexpected keyword argument 'out_dtype'
  target: quantized_decomposed.dequantize_per_tensor.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cpu', torch.uint8, size=[1, 7, 7, 9], stride=[441, 63, 9, 1]))
  ))
  args[1]: 0.01
  args[2]: 100
  args[3]: 0
  args[4]: 255
  args[5]: torch.uint8
  kwargs: {'out_dtype': torch.bfloat16}
```

Pull Request resolved: #143845
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
@msaroufim
Member

msaroufim commented Feb 15, 2025

@Valentine233 do you have any perf benchmarks? Naively I would have imagined that SDPA isn't as crucial to implement on CPU because the caches are larger and hardware-managed, but perhaps my mathematical intuition of what CPU SDPA costs is lacking, so anything you could share would be super helpful.

@Valentine233
Collaborator Author

> do you have any perf benchmarks? Naively I would have imagined that SDPA isn't as crucial to implement on CPU because the caches are larger and hardware-managed, but perhaps my mathematical intuition of what CPU SDPA costs is lacking, so anything you could share would be super helpful.

We also encounter cache problems on CPU for long sequence lengths, so the fused SDPA does have perf benefits. I will share the results in the description as soon as they are ready.

@Xia-Weiwen
Collaborator

> @leslie-fang-intel @Xia-Weiwen do you want to move other inductor quant passes here as well

I think so, since the q and dq ops are deprecated in PyTorch core (#1372 (comment)). We may need to think more about it, such as the pace at which to do it and how to keep backward compatibility. We may also need to add more registration points for custom passes in Inductor to accommodate all the passes. Leslie may comment more if needed. Thanks.

@leslie-fang-intel
Collaborator

leslie-fang-intel commented Apr 18, 2025

> @leslie-fang-intel @Xia-Weiwen do you want to move other inductor quant passes here as well
>
> I think so, since the q and dq ops are deprecated in PyTorch core (#1372 (comment)). We may need to think more about it, such as the pace at which to do it and how to keep backward compatibility. We may also need to add more registration points for custom passes in Inductor to accommodate all the passes. Leslie may comment more if needed. Thanks.

Yeah, since the new q and dq ops like torch.ops.torchao.dequantize_affine will be registered in TorchAO and used in the TorchAO PT2E flow, I feel we need to register patterns with these new ops in TorchAO.

@Valentine233 merged commit 34421b1 into pytorch:main Apr 18, 2025
34 checks passed
@jerryzh168
Contributor

looks like there is a build issue?

```
ao/torchao/csrc/cpu/int8_sdpa.cpp:1:9: error: #pragma once in main file [-Werror,-Wpragma-once-outside-header]
    1 | #pragma once
      |         ^
1 error generated.
```

@jerryzh168
Contributor

will need to revert this one due to internal build errors, please land again

jerryzh168 added a commit that referenced this pull request Apr 21, 2025
jerryzh168 added a commit that referenced this pull request Apr 22, 2025
Revert "Add INT8 SDPA path for CPU (#1372)"

This reverts commit 34421b1.
lisjin pushed a commit to lisjin/ao that referenced this pull request Apr 22, 2025
Revert "Add INT8 SDPA path for CPU (pytorch#1372)"

This reverts commit 34421b1.
Labels: CLA Signed, topic: not user facing
9 participants