
Commit 7318d2a

update int8 sdpa cpu

1 parent 545b2b8

File tree

5 files changed: +922 −1305 lines


setup.py

Lines changed: 14 additions & 0 deletions
@@ -70,6 +70,20 @@ def get_extensions():
         "cxx": [
             "-O3" if not debug_mode else "-O0",
             "-fdiagnostics-color=always",
+            # ## AVX2
+            # "-DCPU_CAPABILITY=AVX2",
+            # "-DCPU_CAPABILITY_AVX2",
+            # "-mavx2",
+            # "-mfma",
+            # "-mf16c",
+            ## AVX512
+            "-DCPU_CAPABILITY=AVX512",
+            "-DCPU_CAPABILITY_AVX512",
+            "-mavx512f",
+            "-mavx512bw",
+            "-mavx512vl",
+            "-mavx512dq",
+            "-mfma",
         ],
         "nvcc": [
             "-O3" if not debug_mode else "-O0",

test/quantization/test_sfdp_int8_fx_pass.py

Lines changed: 5 additions & 5 deletions
@@ -16,7 +16,7 @@
 from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
 
 import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
-from torch._export import capture_pre_autograd_graph
+from torch.export import export_for_training
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
     X86InductorQuantizer,
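capture_pre_autograd_graph lives in the private torch._export namespace and has been deprecated; torch.export.export_for_training is the public replacement. Unlike the old API it returns an ExportedProgram rather than a GraphModule, which is why the call site below appends .module().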
@@ -65,7 +65,7 @@ def forward(self, x, mask):
         if self.has_mask:
             scores = scores + mask
         attention = self.softmax(scores)
-        # attention = self.dropout(attention)
+        attention = self.dropout(attention)
         context_layer = torch.matmul(attention, v)
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         context_layer = context_layer.view(
@@ -75,7 +75,7 @@ def forward(self, x, mask):
 
 
 def _generate_qdq_quantized_model(mod, inputs, quantizer):
     with torch.no_grad():
-        export_model = capture_pre_autograd_graph(mod, inputs)
+        export_model = export_for_training(mod, inputs).module()
         prepare_model = prepare_pt2e(export_model, quantizer)
         prepare_model(*inputs)
         convert_model = convert_pt2e(prepare_model)
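For context, a minimal sketch of the PT2E quantization flow this helper wraps. Here mod and inputs stand in for any eager module and a tuple of example inputs, and the set_global configuration line is an assumption rather than part of this diff:

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.export import export_for_training

quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())

with torch.no_grad():
    # Trace the eager module into a pre-autograd ATen graph.
    export_model = export_for_training(mod, inputs).module()
    # Insert observers where the quantizer's config calls for them.
    prepare_model = prepare_pt2e(export_model, quantizer)
    # Calibrate by running representative inputs through the observed model.
    prepare_model(*inputs)
    # Swap observers for quantize/dequantize (Q/DQ) ops.
    convert_model = convert_pt2e(prepare_model)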
@@ -173,10 +173,10 @@ def _test_sdpa_rewriter_int8_1_to_4(self):
             if dtype == torch.bfloat16
             else contextlib.nullcontext()
         )
-        inputs = [
+        inputs = (
             torch.randn((bs, 384, 64 * 16), device=self.device, dtype=dtype),
             torch.randn((bs, 1, 1, 384), device=self.device) if has_mask else None,
-        ]
+        )
         with torch.no_grad(), maybe_autocast:
             _sfdp_init_int8()
             quantizer = X86InductorQuantizer()
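The example inputs switch from a list to a tuple because torch.export-style APIs expect the positional example arguments as a tuple, matching the export_for_training(mod, inputs) call in _generate_qdq_quantized_model above.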
