Skip calling unwrap_tensor_subclass for torch 2.7+ #1531

Merged · 2 commits · Jan 9, 2025
14 changes: 10 additions & 4 deletions test/integration/test_integration.py
@@ -76,6 +76,7 @@
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
+TORCH_VERSION_AT_LEAST_2_7,
benchmark_model,
is_fbcode,
is_sm_at_least_90,
@@ -1749,7 +1750,10 @@ def test_autoquant_min_sqnr(self, device, dtype):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.")
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
-@unittest.skip("AOTI tests are failing right now")
+@unittest.skip(
+    "AOTI tests are failing right now, repro by commenting out the skip and run: "
+    "python test/integration/test_integration.py -k TestAOTI.test_aoti_06"
+)
class TestAOTI(unittest.TestCase):
@parameterized.expand(
list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
@@ -1792,7 +1796,8 @@ def forward(self, x):
model(x)

api(model)
-unwrap_tensor_subclass(model)
+if not TORCH_VERSION_AT_LEAST_2_7:
+    unwrap_tensor_subclass(model)

# running model
model(x)
@@ -1802,7 +1807,7 @@ def forward(self, x):

example_inputs = (x,)
torch._inductor.aoti_compile_and_package(
-    torch.export.export(model, example_inputs, strict=True), example_inputs
+    torch.export.export(model, example_inputs, strict=True)
)


@@ -1851,7 +1856,8 @@ def forward(self, x):
model(x)

api(model)
-unwrap_tensor_subclass(model)
+if not TORCH_VERSION_AT_LEAST_2_7:
+    unwrap_tensor_subclass(model)

# running model
ref = model(x)
11 changes: 4 additions & 7 deletions torchao/quantization/README.md
@@ -296,9 +296,9 @@ for module, name in model.named_modules():
module.weight = nn.Parameter(to_linear_activation_quantized(module.weight, input_quant_func))
```

-#### Workaround with `unwrap_tensor_subclass` for `export`, `AOTI` and `torch.compile` (pytorch 2.4 and before only)
-The model/tensor subclass should also be compatible with AOTI and torch.export, currently we can support
-`torch.export.export` and `torch.aot_compile` with the following workaround:
+#### Workaround with `unwrap_tensor_subclass` for `export`, `AOTI` and `torch.compile`
+
+If you are using pytorch 2.6 or earlier, you need to call `unwrap_tensor_subclass` before `torch.export.export` and `aot_compile`:
```
from torchao.utils import unwrap_tensor_subclass
m_unwrapped = unwrap_tensor_subclass(m)
@@ -311,10 +311,7 @@ m = torch.export.export(m_unwrapped, example_inputs).module()
torch._export.aot_compile(m_unwrapped, example_inputs)
```

-For `torch.compile`, if you are using pytorch nightly or pytorch 2.5+, you won't need to use `unwrap_tensor_subclass` in order to be compatible with `torch.compile`,
-but if you use 2.4 or before, you'll need to use `unwrap_tensor_subclass` as well to be able to run `torch.compile` on the quantized model.
-
-Note that the workaround will not be needed after https://github.com/pytorch/pytorch/issues/129682 is fixed.
+If you are using pytorch 2.4 or earlier, you'll also need `unwrap_tensor_subclass` before calling `torch.compile`.

Note that the workaround is also required for `torch.compile` with `freezing` (`torch._inductor.config.freezing=True`) until https://github.com/pytorch/pytorch/pull/136265 is fixed.
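The version gating described above can be sketched as a small helper. This is not torchao or torch code: `export_flow`, `unwrap`, and `export` are hypothetical stand-ins, injected as callables so the control flow runs without torch installed. In the real flow, `unwrap` would be `unwrap_tensor_subclass` and `export` would be `torch.export.export` (and on newer torch, per the test diff above, `torch._inductor.aoti_compile_and_package` then takes only the exported program).

```python
# Sketch of the version-gated workaround (NOT torchao code).
# unwrap/export are injected stand-ins for unwrap_tensor_subclass
# and torch.export.export so the branching is shown in isolation.
def export_flow(model, example_inputs, *, at_least_2_7, unwrap, export):
    """Apply the unwrap workaround only when the torch version requires it."""
    if not at_least_2_7:
        model = unwrap(model)  # torch < 2.7: flatten tensor subclasses first
    # torch 2.7+: export handles tensor subclasses natively, no unwrap needed
    return export(model, example_inputs)
```

The same `if not TORCH_VERSION_AT_LEAST_2_7` branch is what the test changes in this PR add around each `unwrap_tensor_subclass` call.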

2 changes: 2 additions & 0 deletions torchao/utils.py
@@ -27,6 +27,7 @@
"TORCH_VERSION_AT_LEAST_2_4",
"TORCH_VERSION_AT_LEAST_2_5",
"TORCH_VERSION_AT_LEAST_2_6",
+"TORCH_VERSION_AT_LEAST_2_7",
# Needs to be deprecated in the future
"TORCH_VERSION_AFTER_2_2",
"TORCH_VERSION_AFTER_2_3",
@@ -367,6 +368,7 @@ def torch_version_at_least(min_version):
return is_fbcode() or compare_versions(torch.__version__, min_version) >= 0


+TORCH_VERSION_AT_LEAST_2_7 = torch_version_at_least("2.7.0")
TORCH_VERSION_AT_LEAST_2_6 = torch_version_at_least("2.6.0")
TORCH_VERSION_AT_LEAST_2_5 = torch_version_at_least("2.5.0")
TORCH_VERSION_AT_LEAST_2_4 = torch_version_at_least("2.4.0")
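The flags above come from comparing `torch.__version__` against a minimum. As a rough illustration only (a simplified parser of my own, not torchao's `compare_versions`), the comparison must tolerate nightly/local suffixes such as `2.7.0.dev20250101+cu124`:

```python
import re

# Simplified sketch of the version gate; torchao's real check lives in
# torchao/utils.py. Dev/local suffixes after the numeric part are ignored.
def version_tuple(version: str):
    """Parse '2.7.0.dev20250101+cu124' -> (2, 7, 0)."""
    m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if m is None:
        raise ValueError(f"unparsable version: {version!r}")
    return tuple(int(g or 0) for g in m.groups())

def at_least(version: str, minimum: str) -> bool:
    """True when `version` is >= `minimum`, compared numerically."""
    return version_tuple(version) >= version_tuple(minimum)
```

Under this sketch, a flag like `TORCH_VERSION_AT_LEAST_2_7` would be `at_least(torch.__version__, "2.7.0")`, and the tests in this PR skip `unwrap_tensor_subclass` exactly when it is true.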