@@ -147,55 +147,30 @@ def _check_common(
                 ):
                     self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol)
 
-        iter_n = 20
-        with torch.profiler.profile(
-            activities=[torch.profiler.ProfilerActivity.CPU],
-            schedule=torch.profiler.schedule(wait=2, warmup=iter_n, active=20),
-        ) as prof:
-            for _ in range(iter_n + 22):
-                r = compiled_model(*(args2 + dropout_arg))
-                prof.step()
-        print(prof.key_averages().table(sort_by="self_cpu_time_total"))
-
     @skipIfRocm
     @config.patch({"freezing": True})
-    def _test_sdpa_rewriter_int8_1_to_4(self):
+    def _test_sdpa_int8_rewriter(self):
         # pattern is different for bs=1
         for dtype, has_mask, bs in itertools.product(
             [torch.float32, torch.bfloat16], [True, False], [56, 1]
         ):
             seqlen, numhead, headsize = 197, 16, 64
-            # dtype = torch.bfloat16
-            # has_mask = True
-            # is_bs_1 = 0
-            # if is_bs_1:
-            #     candidates = [[1, 384, 16, 64], [1, 197, 12, 64]]
-            # else:
-            #     candidates = [[120, 384, 16, 64], [224, 197, 12, 64]]
-            # candidates = [[120, 384, 16, 64]]
-            # for bs, seqlen, numhead, headsize in candidates:
             mod = SelfAttnLikeModule(
                 input_dim=headsize * numhead,
                 has_mask=has_mask,
                 num_attention_heads=numhead,
                 attention_head_size=headsize,
             ).eval()
-            maybe_autocast = (
-                torch.cpu.amp.autocast()
-                if dtype == torch.bfloat16
-                else contextlib.nullcontext()
-            )
-            print("\n TEST shape", bs, numhead, seqlen, headsize)
             inputs = (
                 torch.randn(
                     (bs, seqlen, headsize * numhead), device=self.device, dtype=dtype
-                )
-                * 10,
+                ) * 10,
                 torch.randn((bs, 1, 1, seqlen), device=self.device) * 10
                 if has_mask
                 else None,
             )
-            with torch.no_grad(), maybe_autocast:
+            enable_autocast = (dtype == torch.bfloat16)
+            with torch.no_grad(), torch.amp.autocast("cpu", enabled=enable_autocast, dtype=torch.bfloat16):
                 _int8_sdpa_init()
                 quantizer = X86InductorQuantizer()
                 quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
@@ -217,7 +192,7 @@ def _test_sdpa_rewriter_int8_1_to_4(self):
 if HAS_CPU:
     class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate):
         device = "cpu"
-        test_sdpa_rewriter_int8_1_to_4_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_int8_1_to_4
+        test_sdpa_int8_rewriter_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_int8_rewriter
 
 if __name__ == "__main__":
     if IS_LINUX:
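
Note on the autocast change above (a minimal usage sketch, not part of the patch): rather than choosing between torch.cpu.amp.autocast() and contextlib.nullcontext(), the rewritten test passes an enabled flag to a single torch.amp.autocast context manager. Roughly:

    import torch

    dtype = torch.bfloat16  # or torch.float32, as iterated in the test
    enable_autocast = dtype == torch.bfloat16
    # With enabled=False, torch.amp.autocast acts as a no-op context,
    # so the float32 path runs unchanged.
    with torch.no_grad(), torch.amp.autocast(
        "cpu", enabled=enable_autocast, dtype=torch.bfloat16
    ):
        pass  # run the (quantized) model under test here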