Open
Description
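I encountered this in practice with the EGNN model: `quantize_(model, int8_weight_only())` raises a `RuntimeError` on an `nn.Linear` layer that has a single input feature (`nn.Linear(1, 4)` in the minimal repro below).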
Repro:
import torch
from torch import nn
from torchao.quantization.quant_api import quantize_, int8_weight_only
class MyModule(nn.Module):
    """Minimal repro model consisting of a single ``nn.Linear(1, 4)`` layer.

    The layer has one input feature, so its weight has shape (4, 1). The
    ``RuntimeError`` in the stack trace reports a scale tensor that could
    not be viewed as ``[4, 1]``.
    """

    def __init__(self):
        super().__init__()
        # in_features=1 is the unusual part of this repro; keep it as-is so the
        # quantization failure still reproduces.
        self.lin = nn.Linear(1, 4)

    def forward(self, x):
        """Apply the linear layer: an input of shape (N, 1) yields (N, 4)."""
        return self.lin(x)
# Build the model, then quantize its Linear weights to int8 in place.
# Per the stack trace, the quantize_ call raises
# "RuntimeError: shape '[4, 1]' is invalid for input of size 1",
# so the forward pass below is never reached.
m = MyModule()
quantize_(m, int8_weight_only())
inp = torch.randn(32, 1)
output = m(inp)
print(output.shape)  # expected torch.Size([32, 4]) if quantization succeeded
Stack trace
Traceback (most recent call last):
File ".../scripts/quant_repro.py", line 16, in <module>
quantize_(m, int8_weight_only())
File ".../ao/torchao/quantization/quant_api.py", line 348, in quantize_
_replace_with_custom_fn_if_matches_filter(
File ".../ao/torchao/quantization/quant_api.py", line 187, in _replace_with_custom_fn_if_matches_filter
new_child = _replace_with_custom_fn_if_matches_filter(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../ao/torchao/quantization/quant_api.py", line 183, in _replace_with_custom_fn_if_matches_filter
model = replacement_fn(model)
^^^^^^^^^^^^^^^^^^^^^
File ".../ao/torchao/quantization/quant_api.py", line 279, in insert_subclass
lin.weight = torch.nn.Parameter(constructor(lin.weight), requires_grad=False)
^^^^^^^^^^^^^^^^^^^^^^^
File ".../ao/torchao/quantization/quant_api.py", line 441, in apply_int8wo_quant
return to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../ao/torchao/dtypes/affine_quantized_tensor.py", line 220, in from_float
int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../ao/torchao/quantization/quant_primitives.py", line 197, in quantize_affine
return _quantize_affine(
^^^^^^^^^^^^^^^^^
File ".../torch/_ops.py", line 1117, in __call__
return self._op(*args, **(kwargs or {}))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../ao/torchao/quantization/quant_primitives.py", line 227, in _quantize_affine
return _quantize_affine_no_dtype_cast(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../ao/torchao/quantization/quant_primitives.py", line 257, in _quantize_affine_no_dtype_cast
scale = scale.view(shape_after_reduction)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape '[4, 1]' is invalid for input of size 1