Skip to content

Commit 665dac0

Browse files
authored
Fix unwrap_tensor_subclass (#2062)
init
1 parent 379cb75 commit 665dac0

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

torchao/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,8 @@ def unwrap_tensor_subclass(model, filter_fn=None):
314314
and type(child.weight) is not torch.nn.Parameter
315315
and isinstance(child.weight, torch.Tensor)
316316
and issubclass(type(child.weight), torch.Tensor)
317+
and isinstance(child.weight, TorchAOBaseTensor)
318+
and not parametrize.is_parametrized(child)
317319
):
318320
parametrize.register_parametrization(
319321
child, "weight", UnwrapTensorSubclass()

0 commit comments

Comments
 (0)