
Commit 34421b1

Add INT8 SDPA path for CPU (#1372)
* [CPU] add int8 sdpa path with an explicit OP
1 parent 25034e5 commit 34421b1

9 files changed: 2783 additions & 1 deletion


setup.py

Lines changed: 21 additions & 0 deletions
@@ -55,6 +55,10 @@ def read_version(file_path="version.txt"):
     and platform.system() == "Darwin"
 )
 
+use_cpp_avx512 = os.getenv("USE_AVX512", "1") == "1" and platform.system() == "Linux"
+
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+
 version_prefix = read_version()
 # Version is version.dev year month date if using nightlies and version if not
 version = (
@@ -284,6 +288,17 @@ def get_extensions():
             ["-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always"]
         )
 
+        if use_cpp_avx512 and TORCH_VERSION_AT_LEAST_2_7:
+            if torch._C._cpu._is_avx512_supported():
+                extra_compile_args["cxx"].extend(
+                    [
+                        "-DCPU_CAPABILITY_AVX512",
+                        "-march=native",
+                        "-mfma",
+                        "-fopenmp",
+                    ]
+                )
+
         if debug_mode:
             extra_compile_args["cxx"].append("-g")
         if "nvcc" in extra_compile_args:
@@ -305,6 +320,12 @@ def get_extensions():
 
     # Collect C++ source files
     sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))
+    if IS_WINDOWS:
+        # Remove csrc/cpu/*.cpp on Windows due to the link issue: unresolved external symbol PyInit__C
+        excluded_sources = list(
+            glob.glob(os.path.join(extensions_dir, "cpu/*.cpp"), recursive=True)
+        )
+        sources = [s for s in sources if s not in excluded_sources]
 
     extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
     cuda_sources = list(
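
As a quick sanity check, the same conditions that gate these compile flags can be probed at runtime before building. This is only a sketch mirroring the checks in the hunk above; note that torch._C._cpu._is_avx512_supported() is a private PyTorch API and may change.

import os
import platform

import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_7

# Mirror the build-time gating added to setup.py above (illustrative only).
use_cpp_avx512 = os.getenv("USE_AVX512", "1") == "1" and platform.system() == "Linux"
can_build_avx512 = (
    use_cpp_avx512
    and TORCH_VERSION_AT_LEAST_2_7
    and torch._C._cpu._is_avx512_supported()
)
print("AVX512 int8 SDPA kernels would be compiled:", can_build_avx512)
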
Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
import itertools

import pytest
import torch
import torch.utils.checkpoint
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.test_case import TestCase, run_tests
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CPU
from torch.utils.cpp_extension import IS_WINDOWS

import torchao
from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import _int8_sdpa_init
from torchao.utils import TORCH_VERSION_AT_LEAST_2_7


class SelfAttnLikeModule(torch.nn.Module):
    def __init__(
        self,
        input_dim,
        has_mask,
        num_attention_heads=None,
        attention_head_size=None,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
        self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
        self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
        self.softmax = torch.nn.Softmax(dim=-1)
        assert num_attention_heads is not None
        assert attention_head_size is not None
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = attention_head_size
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size)
        self.dropout = torch.nn.Dropout(0)
        self.has_mask = has_mask

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(new_x_shape)
        return x.permute([0, 2, 1, 3])

    def forward(self, x, mask):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q = self.transpose_for_scores(q)
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)
        scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5)
        if self.has_mask and mask.dtype != scores.dtype:
            scores = scores + mask
        attention = self.softmax(scores)
        attention = self.dropout(attention)
        context_layer = torch.matmul(attention, v)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(
            context_layer.size()[:-2] + (self.all_head_size,)
        )
        return self.dense(context_layer)


class TestSDPAPatternRewriterTemplate(TestCase):
    def _clone_inputs(self, inputs):
        def clone(x):
            if not isinstance(x, torch.Tensor):
                return x
            return x.clone()

        return [clone(x) for x in inputs]

    def _check_common(
        self,
        dot_prod_attention,
        args1=None,
        contains=True,
        atol=1e-5,
        has_fuse_pattern=True,
        has_dropout=False,
        check_train=True,
        override_check_equal=False,
        dtype=torch.float,
        rtol=1.3e-6,
    ):
        if args1 is None:
            tensor_shape = (4, 2, 16, 32)
            args1 = [
                torch.randn(tensor_shape, device=self.device, dtype=dtype),
                torch.randn(tensor_shape, device=self.device, dtype=dtype),
                torch.randn(tensor_shape, device=self.device, dtype=dtype),
            ]
        else:
            args1 = list(args1)
        args2 = self._clone_inputs(args1)

        for training in [False, True] if check_train else [False]:
            for x in itertools.chain(args1[:], args2[:]):
                if isinstance(x, torch.Tensor) and x.is_floating_point():
                    x.requires_grad = training

            dropout_arg = [training] if has_dropout else []
            torch.manual_seed(1234)
            result1 = dot_prod_attention(*(args1 + dropout_arg))

            counters.clear()
            torch.manual_seed(1234)
            compiled_model = torch.compile(dot_prod_attention, fullgraph=True)
            result2, source_code = run_and_get_code(
                compiled_model,
                *(args2 + dropout_arg),
            )
            source_code = "\n".join(source_code)
            if has_fuse_pattern:
                self.assertGreaterEqual(counters["inductor"]["int8_fuse_attention"], 1)
            if contains:
                # many of the patterns get re-expanded in dispatcher
                self.assertIn(
                    "torchao.scaled_dot_product_int8",
                    source_code,
                )

            # some tests configured with very low dropout where we still want to check equality
            if not has_dropout or override_check_equal:
                self.assertEqual(result1, result2, atol=atol, rtol=1.3e-6)

            if training:
                result1.sum().backward()
                result2.sum().backward()
                for arg1, arg2 in zip(args1, args2):
                    if (
                        isinstance(arg1, torch.Tensor)
                        and arg1.is_floating_point()
                        and (not has_dropout or override_check_equal)
                    ):
                        self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol)

    @skipIfRocm
    @pytest.mark.skipif(
        not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later"
    )
    @pytest.mark.skipif(IS_WINDOWS, reason="int8 sdpa does not support windows yet")
    @config.patch({"freezing": True})
    def _test_sdpa_int8_rewriter(self):
        from torch.export import export_for_training

        import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
        from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
        from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
            X86InductorQuantizer,
        )

        # pattern is different for bs=1
        torch.manual_seed(1234)
        for dtype, has_mask, bs in itertools.product(
            [torch.float32, torch.bfloat16], [True, False], [56, 1]
        ):
            seqlen, numhead, headsize = 197, 16, 64
            mod = SelfAttnLikeModule(
                input_dim=headsize * numhead,
                has_mask=has_mask,
                num_attention_heads=numhead,
                attention_head_size=headsize,
            ).eval()
            inputs = (
                torch.randn(
                    (bs, seqlen, headsize * numhead), device=self.device, dtype=dtype
                ),
                torch.randn((bs, 1, 1, seqlen), device=self.device)
                if has_mask
                else None,
            )
            enable_autocast = dtype == torch.bfloat16
            with (
                torch.no_grad(),
                torch.amp.autocast(
                    self.device, enabled=enable_autocast, dtype=torch.bfloat16
                ),
            ):
                _int8_sdpa_init()
                quantizer = X86InductorQuantizer()
                quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
                quantizer.set_function_type_qconfig(
                    torch.matmul, quantizer.get_global_quantization_config()
                )
                export_model = export_for_training(
                    mod,
                    inputs,
                    strict=True,
                ).module()
                prepare_model = prepare_pt2e(export_model, quantizer)
                prepare_model(*inputs)
                convert_model = convert_pt2e(prepare_model)
                torchao.quantization.pt2e.move_exported_model_to_eval(convert_model)
                self._check_common(
                    convert_model, args1=inputs, check_train=False, atol=1.0
                )


if HAS_CPU:

    class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate):
        device = "cpu"
        test_sdpa_int8_rewriter_cpu = (
            TestSDPAPatternRewriterTemplate._test_sdpa_int8_rewriter
        )


if __name__ == "__main__":
    if IS_LINUX:
        run_tests()
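
For reference, below is a condensed sketch of how the new int8 SDPA path would be exercised end to end, assembled from the APIs used in the test above; as in the test, it needs torch 2.7 or later and Windows is not supported yet. The ToyAttention module and shapes are illustrative stand-ins rather than part of this commit, and whether the torchao.scaled_dot_product_int8 fusion actually fires depends on the quantized graph matching the registered pattern.

import torch

import torchao
import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
from torch._inductor import config
from torch.export import export_for_training
from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import _int8_sdpa_init
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
)


class ToyAttention(torch.nn.Module):
    # Illustrative stand-in for SelfAttnLikeModule above (hypothetical, not part of the commit).
    def __init__(self, dim=1024, heads=16):
        super().__init__()
        self.heads, self.head_dim = heads, dim // heads
        self.q_proj = torch.nn.Linear(dim, dim, bias=False)
        self.k_proj = torch.nn.Linear(dim, dim, bias=False)
        self.v_proj = torch.nn.Linear(dim, dim, bias=False)
        self.dense = torch.nn.Linear(dim, dim)

    def forward(self, x):
        b, s, d = x.shape
        shape = (b, s, self.heads, self.head_dim)
        q = self.q_proj(x).view(shape).permute(0, 2, 1, 3)
        k = self.k_proj(x).view(shape).permute(0, 2, 1, 3)
        v = self.v_proj(x).view(shape).permute(0, 2, 1, 3)
        scores = torch.matmul(q, k.transpose(-1, -2)) / (d**0.5)
        attn = torch.softmax(scores, dim=-1)
        ctx = torch.matmul(attn, v).permute(0, 2, 1, 3).contiguous().view(b, s, d)
        return self.dense(ctx)


if __name__ == "__main__":
    mod = ToyAttention().eval()
    inputs = (torch.randn(56, 197, 1024),)

    with torch.no_grad(), config.patch({"freezing": True}):
        _int8_sdpa_init()  # register the int8 SDPA fusion pass with inductor
        quantizer = X86InductorQuantizer()
        quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
        quantizer.set_function_type_qconfig(
            torch.matmul, quantizer.get_global_quantization_config()
        )
        exported = export_for_training(mod, inputs, strict=True).module()
        prepared = prepare_pt2e(exported, quantizer)
        prepared(*inputs)  # calibration pass
        converted = convert_pt2e(prepared)
        torchao.quantization.pt2e.move_exported_model_to_eval(converted)
        compiled = torch.compile(converted, fullgraph=True)
        out = compiled(*inputs)
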
