Shape mismatch error for int8_weight_only quantization when input is of shape (N, 1) #761

Open

Description

@jbschlosser

I encountered this in practice for the EGNN model.

Repro:

import torch
from torch import nn
from torchao.quantization.quant_api import quantize_, int8_weight_only


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(1, 4)  # weight shape (out_features, in_features) = (4, 1)

    def forward(self, x):
        return self.lin(x)


m = MyModule()
quantize_(m, int8_weight_only())  # fails here; see stack trace below
inp = torch.randn(32, 1)
output = m(inp)
print(output.shape)

Stack trace
Traceback (most recent call last):
  File ".../scripts/quant_repro.py", line 16, in <module>
    quantize_(m, int8_weight_only())
  File ".../ao/torchao/quantization/quant_api.py", line 348, in quantize_
    _replace_with_custom_fn_if_matches_filter(
  File ".../ao/torchao/quantization/quant_api.py", line 187, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../ao/torchao/quantization/quant_api.py", line 183, in _replace_with_custom_fn_if_matches_filter
    model = replacement_fn(model)
            ^^^^^^^^^^^^^^^^^^^^^
  File ".../ao/torchao/quantization/quant_api.py", line 279, in insert_subclass
    lin.weight = torch.nn.Parameter(constructor(lin.weight), requires_grad=False)
                                    ^^^^^^^^^^^^^^^^^^^^^^^
  File ".../ao/torchao/quantization/quant_api.py", line 441, in apply_int8wo_quant
    return to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../ao/torchao/dtypes/affine_quantized_tensor.py", line 220, in from_float
    int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../ao/torchao/quantization/quant_primitives.py", line 197, in quantize_affine
    return _quantize_affine(
           ^^^^^^^^^^^^^^^^^
  File ".../torch/_ops.py", line 1117, in __call__
    return self._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../ao/torchao/quantization/quant_primitives.py", line 227, in _quantize_affine
    return _quantize_affine_no_dtype_cast(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../ao/torchao/quantization/quant_primitives.py", line 257, in _quantize_affine_no_dtype_cast
    scale = scale.view(shape_after_reduction)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape '[4, 1]' is invalid for input of size 1
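The failure is in the reshape of the computed scale: for the (4, 1) weight, the code expects scales of shape (4, 1) (one per output channel), but only a single scale element is produced. Stripped of context, the failing step amounts to this (illustrative values):

import torch

scale = torch.tensor([0.05])  # a single scale element instead of one per output channel
scale.view(4, 1)              # RuntimeError: shape '[4, 1]' is invalid for input of size 1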

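Until this is fixed, one possible workaround is to leave such layers unquantized. A sketch, assuming quantize_ still accepts its optional filter_fn argument (a callable receiving each module and its fully-qualified name); skip_single_feature_linears is just a name I picked for the filter:

import torch
from torch import nn
from torchao.quantization.quant_api import quantize_, int8_weight_only


def skip_single_feature_linears(module: nn.Module, fqn: str) -> bool:
    # Quantize only Linear layers with more than one input feature,
    # leaving (out_features, 1) weights like the one above in float.
    return isinstance(module, nn.Linear) and module.in_features > 1


m = MyModule()  # the module from the repro above
quantize_(m, int8_weight_only(), filter_fn=skip_single_feature_linears)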