
unwrap_tensor_subclass and nested tensor subclasses issue #515

Closed
@vayuda

Description


I'm noticing strange behavior when trying to create a tensor subclass that holds another tensor subclass.
Here is a minified repro (add this to the bottom of torchao/dtypes/affine_quantized_tensor.py):

# TestLayout.post_process wraps the freshly quantized int data in a
# PlainAQTLayout, so the resulting AffineQuantizedTensor ends up holding a
# tensor subclass nested inside another tensor subclass.
@dataclass(frozen=True)
class TestLayout(LayoutType):
    scales: torch.Tensor
    zeros: torch.Tensor

    def post_process(self, input: torch.Tensor) -> torch.Tensor:
        return PlainAQTLayout.from_plain(input, self.scales, self.zeros, PlainLayoutType())
    
@register_layout_cls(TestLayout)
class TestAQTLayout(PlainAQTLayout):
    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # self.int_data is the nested PlainAQTLayout created in TestLayout.post_process
        return self.int_data.get_plain()[0], self.scale, self.zero_point
    
    @classmethod
    def from_plain(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        layout_type: LayoutType,
    ):
        assert isinstance(layout_type, TestLayout)
        return cls(int_data, scale, zero_point, layout_type)
    
if __name__ == "__main__":
    from torchao.quantization.quant_api import quantize_, _get_linear_subclass_inserter 
    from torchao.utils import unwrap_tensor_subclass
    def test_quant():
        def apply_test_quant(weight):
            layout_type = TestLayout(torch.tensor([1.0]), torch.tensor([0.0]))
            mapping_type = MappingType.ASYMMETRIC
            block_size = (1, 2)
            quant_min = 0
            quant_max = 8
            eps = torch.finfo(torch.float32).eps
            zero_point_dtype = torch.int32
            zero_point_domain = ZeroPointDomain.INT
            
            return to_affine_quantized(
                weight, mapping_type, block_size, torch.uint8,
                quant_min=quant_min, quant_max=quant_max, eps=eps,
                zero_point_dtype=zero_point_dtype,
                zero_point_domain=zero_point_domain,
                layout_type=layout_type,
            )
    
        return _get_linear_subclass_inserter(apply_test_quant)
    class LinearModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(32, 10)        
        def forward(self, x):
            return self.linear(x)
    test_input = torch.randn(32)
    m = LinearModel()
    m.forward(test_input)
    quantize_(m, test_quant())
    m.forward(test_input)
    m = unwrap_tensor_subclass(m)
    m = torch.compile(m, fullgraph=True)
    m.forward(test_input)

When running this code, I get the following error when calling the model after compiling it:

File "/home/swan/pytorch/ao/min_repro.py", line 199, in dequantize
    int_data, scale, zero_point = self.layout_tensor.get_plain()
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/swan/pytorch/ao/min_repro.py", line 854, in get_plain
    return self.int_data.get_plain()[0], self.scale, self.zero_point
           ^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.TorchRuntimeError: Failed running call_method forward(*(ParametrizedLinear(
  in_features=32, out_features=10, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): UnwrapTensorSubclass()
    )
  )
), FakeTensor(..., size=(32,))), **{}):
'FakeTensor' object has no attribute 'get_plain'

There might be an issue with how unwrap_tensor_subclass handles nested tensor subclasses, but I'm not sure why this doesn't work, since AffineQuantizedTensor is able to hold an AQTLayout tensor and works just fine.
You can check out https://github.com/vayuda/ao/blob/intx/torchao/dtypes/affine_quantized_tensor.py#L593 for what I'm trying to do.
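For context, here is a minimal sketch (not torchao's actual implementation; the class name NestedLayoutTensor is made up, and __torch_dispatch__ is omitted) of the traceable wrapper-subclass contract that torch.compile relies on. Every inner tensor, including one that is itself a subclass like PlainAQTLayout, has to be listed in __tensor_flatten__ so it can be fakeified recursively as a subclass; if a nested subclass gets replaced by a plain FakeTensor instead, subclass-only methods like get_plain() are gone, which matches the error above.

import torch

class NestedLayoutTensor(torch.Tensor):
    # Hypothetical wrapper subclass whose int_data is itself a tensor subclass.

    @staticmethod
    def __new__(cls, int_data, scale, zero_point):
        # Wrapper subclass: the outer tensor carries no storage of its own.
        return torch.Tensor._make_wrapper_subclass(cls, int_data.shape, dtype=int_data.dtype)

    def __init__(self, int_data, scale, zero_point):
        self.int_data = int_data        # may itself be a tensor subclass
        self.scale = scale
        self.zero_point = zero_point

    def __tensor_flatten__(self):
        # All inner tensors must be named here; anything not listed is not
        # traced as a subclass under torch.compile and loses methods such
        # as get_plain().
        return ["int_data", "scale", "zero_point"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        return NestedLayoutTensor(
            inner_tensors["int_data"],
            inner_tensors["scale"],
            inner_tensors["zero_point"],
        )

If unwrap_tensor_subclass's parametrization only rebuilds the outermost subclass from plain tensors, the nested layout tensor stored in int_data would show up as a FakeTensor during tracing, which would explain the 'FakeTensor' object has no attribute 'get_plain' failure, but I haven't confirmed that this is what happens.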
