Skip to content

Commit 080576b

Browse files
committed
update int8 sdpa
1 parent f58fa63 commit 080576b

File tree

4 files changed

+19
-12
lines changed

4 files changed

+19
-12
lines changed

test/prototype/inductor/test_int8_sdpa_fusion.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import itertools
2-
import os
32
import unittest
43

54
import torch
@@ -12,11 +11,12 @@
1211
from torch.testing._internal.inductor_utils import HAS_CPU
1312

1413
import torchao
15-
from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import _int8_sdpa_init
14+
from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import (
15+
_int8_sdpa_init,
16+
custom_pass,
17+
)
1618
from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
1719

18-
use_cpp_avx512 = os.getenv("USE_AVX512", "0") == "1"
19-
2020

2121
class SelfAttnLikeModule(torch.nn.Module):
2222
def __init__(
@@ -147,7 +147,10 @@ def _check_common(
147147
@unittest.skipIf(
148148
not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later"
149149
)
150-
@unittest.skipIf(not use_cpp_avx512, reason="cpp kernels not built")
150+
@unittest.skipIf(
151+
"CPU" not in torch._C._dispatch_dump("torchao::qscaled_dot_product"),
152+
reason="cpp kernels not built",
153+
)
151154
@config.patch({"freezing": True})
152155
def _test_sdpa_int8_rewriter(self):
153156
from torch.export import export_for_training
@@ -184,6 +187,7 @@ def _test_sdpa_int8_rewriter(self):
184187
torch.amp.autocast(
185188
self.device, enabled=enable_autocast, dtype=torch.bfloat16
186189
),
190+
config.patch(post_grad_custom_pre_pass=custom_pass),
187191
):
188192
_int8_sdpa_init()
189193
quantizer = X86InductorQuantizer()

test/test_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# LICENSE file in the root directory of this source tree.
66
import itertools
77
import math
8-
import os
98
import sys
109

1110
import pytest
@@ -32,8 +31,6 @@
3231
compute_max_diff,
3332
)
3433

35-
use_cpp_avx512 = os.getenv("USE_AVX512", "0") == "1"
36-
3734
if torch.version.hip is not None:
3835
pytest.skip("Skipping the test in ROCm", allow_module_level=True)
3936

@@ -160,7 +157,10 @@ def _scaled_dot_product_int8_op_ref(
160157
not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later"
161158
)
162159
@pytest.mark.skipif(not IS_LINUX, reason="only support on linux")
163-
@pytest.mark.skipif(not use_cpp_avx512, reason="cpp kernels not built")
160+
@pytest.mark.skipif(
161+
"CPU" not in torch._C._dispatch_dump("torchao::qscaled_dot_product"),
162+
reason="cpp kernels not built",
163+
)
164164
@parametrize("batch_size", [56, 120])
165165
@parametrize("n_head", [2, 16])
166166
@parametrize("q_seq_len", [18, 89])

torchao/prototype/inductor/fx_passes/README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ from torch._inductor.pattern_matcher import PatternMatcherPass
2323

2424
# Example usage
2525
class _CustomPass(...): # create a custom pass class
26-
config.custom_pass = _CustomPass() # define the custom pass with the patterns
27-
_register_patterns(...) # register your own patterns
26+
custom_pass = _CustomPass() # create an instance of custom pass
27+
with config.patch(custom_pass=custom_pass):
28+
_register_patterns(config.custom_pass) # register your own passes
2829

2930
```
3031

torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ def _register_int8_sdpa_lowerings(custom_pass_dict):
365365
)
366366

367367

368+
custom_pass = None
368369
if TORCH_VERSION_AT_LEAST_2_7:
369370
# TORCH_VERSION_AT_LEAST_2_7 is needed for custom graph pass
370371
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
@@ -380,11 +381,12 @@ def __call__(self, g: torch.fx.graph.Graph):
380381
def uuid(self) -> bytes:
381382
return get_hash_for_files((__file__,))
382383

384+
custom_pass = _CustomPass()
385+
383386

384387
@functools.lru_cache(None)
385388
def _int8_sdpa_init():
386389
if TORCH_VERSION_AT_LEAST_2_7:
387-
config.post_grad_custom_pre_pass = _CustomPass()
388390
_register_int8_sdpa_lowerings(config.post_grad_custom_pre_pass)
389391
else:
390392
pass

0 commit comments

Comments
 (0)