 from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA

 import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
-from torch._export import capture_pre_autograd_graph
+from torch.export import export_for_training
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
     X86InductorQuantizer,
@@ -65,7 +65,7 @@ def forward(self, x, mask):
         if self.has_mask:
             scores = scores + mask
         attention = self.softmax(scores)
-        # attention = self.dropout(attention)
+        attention = self.dropout(attention)
         context_layer = torch.matmul(attention, v)
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         context_layer = context_layer.view(
@@ -75,7 +75,7 @@ def forward(self, x, mask):


 def _generate_qdq_quantized_model(mod, inputs, quantizer):
     with torch.no_grad():
-        export_model = capture_pre_autograd_graph(mod, inputs)
+        export_model = export_for_training(mod, inputs).module()
         prepare_model = prepare_pt2e(export_model, quantizer)
         prepare_model(*inputs)
         convert_model = convert_pt2e(prepare_model)
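
For context, here is a minimal sketch of the PT2E flow that _generate_qdq_quantized_model now follows, assuming a PyTorch build where torch.export.export_for_training is available; the toy model, shapes, and calibration step are illustrative and not taken from the test file.

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.export import export_for_training


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return torch.relu(self.linear(x))


mod = ToyModel().eval()
example_inputs = (torch.randn(2, 16),)  # export takes the example positional args as a tuple

with torch.no_grad():
    # export_for_training returns an ExportedProgram; .module() recovers the
    # GraphModule that prepare_pt2e expects (capture_pre_autograd_graph used to
    # return such a module directly).
    export_model = export_for_training(mod, example_inputs).module()
    quantizer = X86InductorQuantizer()
    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
    prepare_model = prepare_pt2e(export_model, quantizer)
    prepare_model(*example_inputs)               # calibration pass
    convert_model = convert_pt2e(prepare_model)  # insert quantize/dequantize ops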
@@ -173,10 +173,10 @@ def _test_sdpa_rewriter_int8_1_to_4(self):
                 if dtype == torch.bfloat16
                 else contextlib.nullcontext()
             )
-            inputs = [
+            inputs = (
                 torch.randn((bs, 384, 64 * 16), device=self.device, dtype=dtype),
                 torch.randn((bs, 1, 1, 384), device=self.device) if has_mask else None,
-            ]
+            )
             with torch.no_grad(), maybe_autocast:
                 _sfdp_init_int8()
                 quantizer = X86InductorQuantizer()
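
A note on the inputs container change above: the torch.export entry points take the example positional arguments as a tuple, and _generate_qdq_quantized_model passes this object straight into export_for_training, so the test switches from a list to a tuple. A tiny illustration, assuming a recent PyTorch; the Linear module is just a stand-in.

import torch
from torch.export import export_for_training

mod = torch.nn.Linear(8, 8).eval()
ep = export_for_training(mod, (torch.randn(2, 8),))  # args as a tuple: accepted
# export_for_training(mod, [torch.randn(2, 8)])      # a list is rejected by the exporter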