File tree 4 files changed +19
-12
lines changed
torchao/prototype/inductor/fx_passes
4 files changed +19
-12
lines changed Original file line number Diff line number Diff line change 1
1
import itertools
2
- import os
3
2
import unittest
4
3
5
4
import torch
12
11
from torch .testing ._internal .inductor_utils import HAS_CPU
13
12
14
13
import torchao
15
- from torchao .prototype .inductor .fx_passes .int8_sdpa_fusion import _int8_sdpa_init
14
+ from torchao .prototype .inductor .fx_passes .int8_sdpa_fusion import (
15
+ _int8_sdpa_init ,
16
+ custom_pass ,
17
+ )
16
18
from torchao .utils import TORCH_VERSION_AT_LEAST_2_7
17
19
18
- use_cpp_avx512 = os .getenv ("USE_AVX512" , "0" ) == "1"
19
-
20
20
21
21
class SelfAttnLikeModule (torch .nn .Module ):
22
22
def __init__ (
@@ -147,7 +147,10 @@ def _check_common(
147
147
@unittest .skipIf (
148
148
not TORCH_VERSION_AT_LEAST_2_7 , reason = "int8 sdpa requires torch 2.7 or later"
149
149
)
150
- @unittest .skipIf (not use_cpp_avx512 , reason = "cpp kernels not built" )
150
+ @unittest .skipIf (
151
+ "CPU" not in torch ._C ._dispatch_dump ("torchao::qscaled_dot_product" ),
152
+ reason = "cpp kernels not built" ,
153
+ )
151
154
@config .patch ({"freezing" : True })
152
155
def _test_sdpa_int8_rewriter (self ):
153
156
from torch .export import export_for_training
@@ -184,6 +187,7 @@ def _test_sdpa_int8_rewriter(self):
184
187
torch .amp .autocast (
185
188
self .device , enabled = enable_autocast , dtype = torch .bfloat16
186
189
),
190
+ config .patch (post_grad_custom_pre_pass = custom_pass ),
187
191
):
188
192
_int8_sdpa_init ()
189
193
quantizer = X86InductorQuantizer ()
Original file line number Diff line number Diff line change 5
5
# LICENSE file in the root directory of this source tree.
6
6
import itertools
7
7
import math
8
- import os
9
8
import sys
10
9
11
10
import pytest
32
31
compute_max_diff ,
33
32
)
34
33
35
- use_cpp_avx512 = os .getenv ("USE_AVX512" , "0" ) == "1"
36
-
37
34
if torch .version .hip is not None :
38
35
pytest .skip ("Skipping the test in ROCm" , allow_module_level = True )
39
36
@@ -160,7 +157,10 @@ def _scaled_dot_product_int8_op_ref(
160
157
not TORCH_VERSION_AT_LEAST_2_7 , reason = "int8 sdpa requires torch 2.7 or later"
161
158
)
162
159
@pytest .mark .skipif (not IS_LINUX , reason = "only support on linux" )
163
- @pytest .mark .skipif (not use_cpp_avx512 , reason = "cpp kernels not built" )
160
+ @pytest .mark .skipif (
161
+ "CPU" not in torch ._C ._dispatch_dump ("torchao::qscaled_dot_product" ),
162
+ reason = "cpp kernels not built" ,
163
+ )
164
164
@parametrize ("batch_size" , [56 , 120 ])
165
165
@parametrize ("n_head" , [2 , 16 ])
166
166
@parametrize ("q_seq_len" , [18 , 89 ])
Original file line number Diff line number Diff line change @@ -23,8 +23,9 @@ from torch._inductor.pattern_matcher import PatternMatcherPass
23
23
24
24
# Example usage
25
25
class _CustomPass (... ): # create a custom pass class
26
- config.custom_pass = _CustomPass() # define the custom pass with the patterns
27
- _register_patterns(... ) # register your own patterns
26
+ custom_pass = _CustomPass() # create an instance of custom pass
27
+ with config.patch(post_grad_custom_pre_pass=custom_pass):
28
+ _register_patterns(config.post_grad_custom_pre_pass) # register your own patterns
28
29
29
30
```
30
31
Original file line number Diff line number Diff line change @@ -365,6 +365,7 @@ def _register_int8_sdpa_lowerings(custom_pass_dict):
365
365
)
366
366
367
367
368
+ custom_pass = None
368
369
if TORCH_VERSION_AT_LEAST_2_7 :
369
370
# TORCH_VERSION_AT_LEAST_2_7 is needed for custom graph pass
370
371
from torch ._inductor .custom_graph_pass import CustomGraphPass , get_hash_for_files
@@ -380,11 +381,12 @@ def __call__(self, g: torch.fx.graph.Graph):
380
381
def uuid (self ) -> bytes :
381
382
return get_hash_for_files ((__file__ ,))
382
383
384
+ custom_pass = _CustomPass ()
385
+
383
386
384
387
@functools .lru_cache (None )
385
388
def _int8_sdpa_init ():
386
389
if TORCH_VERSION_AT_LEAST_2_7 :
387
- config .post_grad_custom_pre_pass = _CustomPass ()
388
390
_register_int8_sdpa_lowerings (config .post_grad_custom_pre_pass )
389
391
else :
390
392
pass
You can’t perform that action at this time.
0 commit comments