Add workaround to recover the perf for quantized vit in torch.compile (#926)

Add temporary workaround to recover the perf for quantized vit under torch.compile

Summary:
Recently we found a perf drop in quantized vit due to #898 (comment).
This PR adds a temporary fix until we figure out the longer-term fix.

I think ideally we should figure out why the tensor subclass check fails under torch.compile (https://github.com/pytorch/pytorch/blob/e4d294221b140fdbb49a64f297bc60c9fcc2f80e/torch/nn/modules/activation.py#L1286) and fix that.

Test Plan:
python tutorials/quantize_vit/run_vit_b_quant.py

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 authored and jainapurva committed Sep 25, 2024
1 parent 1074f9f commit 96cee95
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tutorials/quantize_vit/run_vit_b_quant.py
@@ -36,6 +36,9 @@
if not TORCH_VERSION_AT_LEAST_2_5:
    unwrap_tensor_subclass(model)

# temporary workaround to recover the perf with quantized model under torch.compile
torch.backends.mha.set_fastpath_enabled(False)

model = torch.compile(model, mode='max-autotune')

# Must run with no_grad when optimizing for inference