
Commit ecc66e4

update

1 parent 84ceeb6 commit ecc66e4

2 files changed: +5 -31 lines changed

test/prototype/inductor/test_int8_sdpa_fusion.py

Lines changed: 5 additions & 30 deletions
@@ -147,55 +147,30 @@ def _check_common(
         ):
             self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol)
 
-        iter_n = 20
-        with torch.profiler.profile(
-            activities=[torch.profiler.ProfilerActivity.CPU],
-            schedule=torch.profiler.schedule(wait=2, warmup=iter_n, active=20),
-        ) as prof:
-            for _ in range(iter_n + 22):
-                r = compiled_model(*(args2 + dropout_arg))
-                prof.step()
-        print(prof.key_averages().table(sort_by="self_cpu_time_total"))
-
     @skipIfRocm
     @config.patch({"freezing": True})
-    def _test_sdpa_rewriter_int8_1_to_4(self):
+    def _test_sdpa_int8_rewriter(self):
         # pattern is different for bs=1
         for dtype, has_mask, bs in itertools.product(
             [torch.float32, torch.bfloat16], [True, False], [56, 1]
         ):
             seqlen, numhead, headsize = 197, 16, 64
-            # dtype = torch.bfloat16
-            # has_mask = True
-            # is_bs_1 = 0
-            # if is_bs_1:
-            #     candidates = [[1, 384, 16, 64], [1, 197, 12, 64]]
-            # else:
-            #     candidates = [[120, 384, 16, 64], [224, 197, 12, 64]]
-            # candidates = [[120, 384, 16, 64]]
-            # for bs, seqlen, numhead, headsize in candidates:
             mod = SelfAttnLikeModule(
                 input_dim=headsize * numhead,
                 has_mask=has_mask,
                 num_attention_heads=numhead,
                 attention_head_size=headsize,
             ).eval()
-            maybe_autocast = (
-                torch.cpu.amp.autocast()
-                if dtype == torch.bfloat16
-                else contextlib.nullcontext()
-            )
-            print("\nTEST shape", bs, numhead, seqlen, headsize)
             inputs = (
                 torch.randn(
                     (bs, seqlen, headsize * numhead), device=self.device, dtype=dtype
-                )
-                * 10,
+                ) * 10,
                 torch.randn((bs, 1, 1, seqlen), device=self.device) * 10
                 if has_mask
                 else None,
             )
-            with torch.no_grad(), maybe_autocast:
+            enable_autocast = (dtype == torch.bfloat16)
+            with torch.no_grad(), torch.amp.autocast("cpu", enabled=enable_autocast, dtype=torch.bfloat16):
                 _int8_sdpa_init()
                 quantizer = X86InductorQuantizer()
                 quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
@@ -217,7 +192,7 @@ def _test_sdpa_rewriter_int8_1_to_4(self):
 if HAS_CPU:
     class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate):
         device = "cpu"
-        test_sdpa_rewriter_int8_1_to_4_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_int8_1_to_4
+        test_sdpa_int8_rewriter_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_int8_rewriter
 
 if __name__ == "__main__":
     if IS_LINUX:
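Note on the autocast change above: beyond deleting the leftover profiling and debug code, the test stops conditionally picking between torch.cpu.amp.autocast() and contextlib.nullcontext(), and instead always enters torch.amp.autocast with its enabled flag toggled (the torch.cpu.amp.autocast(...) spelling is deprecated in recent PyTorch in favor of torch.amp.autocast("cpu", ...)). A minimal standalone sketch of the idiom; the tensor shapes and the fixed dtype value here are illustrative, not taken from the test:

import torch

dtype = torch.bfloat16  # illustrative; the test sweeps float32 and bfloat16

# One context manager covers both cases: with enabled=False,
# autocast is a no-op and the matmul stays in float32.
with torch.no_grad(), torch.amp.autocast(
    "cpu", enabled=(dtype == torch.bfloat16), dtype=torch.bfloat16
):
    x = torch.randn(8, 8)
    y = torch.mm(x, x)  # runs in bfloat16 only when autocast is enabled
    print(y.dtype)      # torch.bfloat16 here; torch.float32 when enabled=False

This also drops the contextlib dependency from the test path, since no null context object is ever built.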

torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py

Lines changed: 0 additions & 1 deletion
@@ -45,7 +45,6 @@ def _register_int8_sdpa_pattern(pattern):
         extra_check=_is_valid_int8_sdpa_pattern(),
     )
     def int8_sdpa(match: Match, *args, **kwargs):
-        print("\n***hit int8_sdpa_pattern***\n")
         query = kwargs["query"]
         key = kwargs["key"]
         value = kwargs["value"]
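This second change simply removes a debug print that fired on every pattern match. If visibility into pattern hits is still useful while developing fusions, one common alternative is debug-level logging, which stays silent unless explicitly enabled; a small sketch, with a hypothetical handler name (on_int8_sdpa_match is not the actual torchao function):

import logging

log = logging.getLogger(__name__)

def on_int8_sdpa_match(match, **kwargs):  # hypothetical, for illustration only
    # log.debug is suppressed by default; opt in with
    # logging.basicConfig(level=logging.DEBUG) when diagnosing fusion coverage.
    log.debug("int8_sdpa pattern matched: %s", match)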
