I'm noticing strange behavior when trying to create a tensor subclass that holds another tensor subclass.
Here is a minified repro (add this to the bottom of torchao/dtypes/affine_quantized_tensor.py):
@dataclass(frozen=True)
class TestLayout(LayoutType):
    scales: torch.Tensor
    zeros: torch.Tensor

    def post_process(self, input: torch.Tensor) -> torch.Tensor:
        # Wrap the quantized data in another tensor subclass, so the
        # outer layout tensor ends up holding a nested subclass.
        return PlainAQTLayout.from_plain(input, self.scales, self.zeros, PlainLayoutType())


@register_layout_cls(TestLayout)
class TestAQTLayout(PlainAQTLayout):
    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # int_data is a PlainAQTLayout here, so it has to be unwrapped too.
        return self.int_data.get_plain()[0], self.scale, self.zero_point

    @classmethod
    def from_plain(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        layout_type: LayoutType,
    ):
        assert isinstance(layout_type, TestLayout)
        return cls(int_data, scale, zero_point, layout_type)
if __name__ == "__main__":
    from torchao.quantization.quant_api import quantize_, _get_linear_subclass_inserter
    from torchao.utils import unwrap_tensor_subclass

    def test_quant():
        def apply_test_quant(weight):
            layout_type = TestLayout(torch.tensor([1.0]), torch.tensor([0.0]))
            mapping_type = MappingType.ASYMMETRIC
            block_size = (1, 2)
            quant_min = 0
            quant_max = 8
            eps = torch.finfo(torch.float32).eps
            zero_point_dtype = torch.int32
            zero_point_domain = ZeroPointDomain.INT
            return to_affine_quantized(
                weight, mapping_type, block_size, torch.uint8,
                quant_min=quant_min, quant_max=quant_max, eps=eps,
                zero_point_dtype=zero_point_dtype,
                zero_point_domain=zero_point_domain,
                layout_type=layout_type,
            )
        return _get_linear_subclass_inserter(apply_test_quant)

    class LinearModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(32, 10)

        def forward(self, x):
            return self.linear(x)

    test_input = torch.randn(32)
    m = LinearModel()
    m.forward(test_input)   # eager, unquantized: works
    quantize_(m, test_quant())
    m.forward(test_input)   # eager, quantized: works
    m = unwrap_tensor_subclass(m)
    m = torch.compile(m, fullgraph=True)
    m.forward(test_input)   # compiled: fails (see error below)
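To spell out the nesting (a hand-built sketch reusing the classes from the repro above, not part of the repro itself): after post_process runs, the outer layout tensor's int_data is itself a PlainAQTLayout, which is why get_plain() ends up being called on it.

# Illustration only; tensor shapes and values are placeholders.
inner = PlainAQTLayout.from_plain(
    torch.zeros(10, 32, dtype=torch.uint8),    # raw quantized data
    torch.tensor([1.0]), torch.tensor([0.0]),  # scale, zero_point
    PlainLayoutType(),
)
outer = TestAQTLayout.from_plain(
    inner,  # int_data is itself a tensor subclass
    torch.tensor([1.0]), torch.tensor([0.0]),
    TestLayout(torch.tensor([1.0]), torch.tensor([0.0])),
)
assert isinstance(outer.int_data, PlainAQTLayout)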
When running the repro, I get the following error when calling the model after compiling:
File "/home/swan/pytorch/ao/min_repro.py", line 199, in dequantize
int_data, scale, zero_point = self.layout_tensor.get_plain()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/swan/pytorch/ao/min_repro.py", line 854, in get_plain
return self.int_data.get_plain()[0], self.scale, self.zero_point
^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.TorchRuntimeError: Failed running call_method forward(*(ParametrizedLinear(
in_features=32, out_features=10, bias=True
(parametrizations): ModuleDict(
(weight): ParametrizationList(
(0): UnwrapTensorSubclass()
)
)
), FakeTensor(..., size=(32,))), **{}):
'FakeTensor' object has no attribute 'get_plain'
There might be an issue with how unwrap_tensor_subclass handles nested tensor subclasses, but I'm not sure why this fails while AffineQuantizedTensor is able to hold an AQTLayout tensor and work just fine.
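My guess is that the unwrapping only flattens one level of subclass. As a purely hypothetical sketch (not torchao's actual implementation), a recursive flatten would have to walk the __tensor_flatten__ protocol all the way down. is_traceable_wrapper_subclass and __tensor_flatten__ are the real PyTorch helpers; the function itself is made up for illustration:

from torch.utils._python_dispatch import is_traceable_wrapper_subclass

def flatten_recursively(t: torch.Tensor) -> list:
    # Collect every plain inner tensor of a (possibly nested) wrapper
    # subclass, e.g. a TestAQTLayout whose int_data is a PlainAQTLayout.
    if not is_traceable_wrapper_subclass(t):
        return [t]
    attr_names, _ctx = t.__tensor_flatten__()
    plain = []
    for name in attr_names:
        plain.extend(flatten_recursively(getattr(t, name)))
    return plain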
You can check out https://github.com/vayuda/ao/blob/intx/torchao/dtypes/affine_quantized_tensor.py#L593 for what I'm trying to do.