Skip to content

AssertionError During Quantization of torch.empty_like(), torch.ones_like, and torch.randn_like #2146

Open
@defaultd661

Description

@defaultd661

🐛 Describe the bug

Similar to #146621, when quantizing a model containing a torch.empty_like(), torch.ones_like(), or torch.randn_like() operation using PT2E (prepare_pt2e), the process fails with an assertion error inside _maybe_insert_input_observers_for_node. The root cause is that torch.empty_like(), torch.ones_like(), and torch.randn_like() carry kwargs in the exported graph, but the current code assumes that most aten ops (except a few explicitly listed ones) have no kwargs.

torch.empty_like

import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer, QuantizationSpec
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer.utils import _annotate_input_qspec_map, _annotate_output_qspec

class EmptyLikeModule(torch.nn.Module):
    """Minimal repro module: forward is a single torch.empty_like call.

    empty_like lowers to an aten op carrying kwargs in the exported graph,
    which is what trips the PT2E observer-insertion assert described above.
    """

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(t)
        return out

class TestQuantizer(Quantizer):

    def annotate(self, model: torch.fx.GraphModule) ->torch.fx.GraphModule:
        qspec = QuantizationSpec(torch.int8, HistogramObserver, qscheme=
            torch.per_tensor_symmetric)
        for node in model.graph.nodes:
            for input_node in node.all_input_nodes:
                _annotate_input_qspec_map(node, input_node, qspec)
            _annotate_output_qspec(node, qspec)
        return model

    def validate(self, model: torch.fx.GraphModule) ->None:
        pass
    
def test_bug():
    """Export EmptyLikeModule and run prepare_pt2e on it.

    Per the report above, prepare_pt2e raises an AssertionError inside
    _maybe_insert_input_observers_for_node on affected torch versions.
    """
    module = EmptyLikeModule()
    exported = torch.export.export(module, (torch.randn(10),))
    prepare_pt2e(exported.graph_module, TestQuantizer())

# Running this script directly reproduces the reported failure.
if __name__ == '__main__':
    test_bug()

torch.ones_like

import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer, QuantizationSpec
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer.utils import _annotate_input_qspec_map, _annotate_output_qspec

class OnesLikeModule(torch.nn.Module):
    """Minimal repro module: forward is a single torch.ones_like call.

    ones_like lowers to an aten op carrying kwargs in the exported graph,
    which is what trips the PT2E observer-insertion assert described above.
    """

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        out = torch.ones_like(t)
        return out

class TestQuantizer(Quantizer):

    def annotate(self, model: torch.fx.GraphModule) ->torch.fx.GraphModule:
        qspec = QuantizationSpec(torch.int8, HistogramObserver, qscheme=
            torch.per_tensor_symmetric)
        for node in model.graph.nodes:
            for input_node in node.all_input_nodes:
                _annotate_input_qspec_map(node, input_node, qspec)
            _annotate_output_qspec(node, qspec)
        return model

    def validate(self, model: torch.fx.GraphModule) ->None:
        pass

def test_bug():
    """Export OnesLikeModule and run prepare_pt2e on it.

    Per the report above, prepare_pt2e raises an AssertionError inside
    _maybe_insert_input_observers_for_node on affected torch versions.
    """
    module = OnesLikeModule()
    exported = torch.export.export(module, (torch.randn(10),))
    prepare_pt2e(exported.graph_module, TestQuantizer())

# Running this script directly reproduces the reported failure.
if __name__ == '__main__':
    test_bug()

torch.randn_like

import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer, QuantizationSpec
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer.utils import _annotate_input_qspec_map, _annotate_output_qspec

class RandnLikeModule(torch.nn.Module):
    """Minimal repro module: forward is a single torch.randn_like call.

    randn_like lowers to an aten op carrying kwargs in the exported graph,
    which is what trips the PT2E observer-insertion assert described above.
    """

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        out = torch.randn_like(t)
        return out

class TestQuantizer(Quantizer):

    def annotate(self, model: torch.fx.GraphModule) ->torch.fx.GraphModule:
        qspec = QuantizationSpec(torch.int8, HistogramObserver, qscheme=
            torch.per_tensor_symmetric)
        for node in model.graph.nodes:
            for input_node in node.all_input_nodes:
                _annotate_input_qspec_map(node, input_node, qspec)
            _annotate_output_qspec(node, qspec)
        return model

    def validate(self, model: torch.fx.GraphModule) ->None:
        pass

def test_bug():
    """Export RandnLikeModule and run prepare_pt2e on it.

    Per the report above, prepare_pt2e raises an AssertionError inside
    _maybe_insert_input_observers_for_node on affected torch versions.
    """
    module = RandnLikeModule()
    exported = torch.export.export(module, (torch.randn(10),))
    prepare_pt2e(exported.graph_module, TestQuantizer())

# Running this script directly reproduces the reported failure.
if __name__ == '__main__':
    test_bug()

Versions

PyTorch version: 2.7.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel @msaroufim

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions