### 🐛 Describe the bug
Similar to #146621, when quantizing a model containing `torch.empty_like()`, `torch.ones_like()`, or `torch.randn_like()` operations using PT2E (`prepare_pt2e`), the process fails with an assertion error inside `_maybe_insert_input_observers_for_node`. The root cause is that `torch.empty_like`, `torch.ones_like`, and `torch.randn_like` have kwargs, but the current code assumes that most aten ops (except a few explicitly listed ones) should not have kwargs.
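
For context, the problem is easy to see by inspecting the exported graph directly: the lowered aten `_like` nodes still carry keyword arguments after `torch.export`. A minimal inspection sketch (using only the public `torch.export` and FX node attributes shown; the exact kwargs printed may vary by PyTorch version):

```python
import torch

class EmptyLikeModule(torch.nn.Module):
    def forward(self, t: torch.Tensor):
        return torch.empty_like(t)

# Print each call_function node's target and kwargs; the aten.empty_like
# node is expected to show non-empty kwargs (export-inserted defaults),
# which is exactly what the prepare_pt2e assertion rejects.
ep = torch.export.export(EmptyLikeModule(), (torch.randn(10),))
for node in ep.graph_module.graph.nodes:
    if node.op == "call_function":
        print(node.target, node.kwargs)
```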
**torch.empty_like**
```python
import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer, QuantizationSpec
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer.utils import _annotate_input_qspec_map, _annotate_output_qspec


class EmptyLikeModule(torch.nn.Module):
    def forward(self, t: torch.Tensor):
        return torch.empty_like(t)


class TestQuantizer(Quantizer):
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        qspec = QuantizationSpec(torch.int8, HistogramObserver, qscheme=torch.per_tensor_symmetric)
        for node in model.graph.nodes:
            for input_node in node.all_input_nodes:
                _annotate_input_qspec_map(node, input_node, qspec)
            _annotate_output_qspec(node, qspec)
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass


def test_bug():
    model = EmptyLikeModule()
    exported_model = torch.export.export(model, (torch.randn(10),))
    # Fails with an AssertionError in _maybe_insert_input_observers_for_node
    prepared_model = prepare_pt2e(exported_model.graph_module, TestQuantizer())


if __name__ == "__main__":
    test_bug()
```
**torch.ones_like**
```python
import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer, QuantizationSpec
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer.utils import _annotate_input_qspec_map, _annotate_output_qspec


class OnesLikeModule(torch.nn.Module):
    def forward(self, t: torch.Tensor):
        return torch.ones_like(t)


class TestQuantizer(Quantizer):
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        qspec = QuantizationSpec(torch.int8, HistogramObserver, qscheme=torch.per_tensor_symmetric)
        for node in model.graph.nodes:
            for input_node in node.all_input_nodes:
                _annotate_input_qspec_map(node, input_node, qspec)
            _annotate_output_qspec(node, qspec)
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass


def test_bug():
    model = OnesLikeModule()
    exported_model = torch.export.export(model, (torch.randn(10),))
    # Fails with an AssertionError in _maybe_insert_input_observers_for_node
    prepared_model = prepare_pt2e(exported_model.graph_module, TestQuantizer())


if __name__ == "__main__":
    test_bug()
```
**torch.randn_like**
```python
import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer, QuantizationSpec
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer.utils import _annotate_input_qspec_map, _annotate_output_qspec


class RandnLikeModule(torch.nn.Module):
    def forward(self, t: torch.Tensor):
        return torch.randn_like(t)


class TestQuantizer(Quantizer):
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        qspec = QuantizationSpec(torch.int8, HistogramObserver, qscheme=torch.per_tensor_symmetric)
        for node in model.graph.nodes:
            for input_node in node.all_input_nodes:
                _annotate_input_qspec_map(node, input_node, qspec)
            _annotate_output_qspec(node, qspec)
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass


def test_bug():
    model = RandnLikeModule()
    exported_model = torch.export.export(model, (torch.randn(10),))
    # Fails with an AssertionError in _maybe_insert_input_observers_for_node
    prepared_model = prepare_pt2e(exported_model.graph_module, TestQuantizer())


if __name__ == "__main__":
    test_bug()
```
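
Until this is fixed upstream, one hypothetical workaround (not an official API, and assuming the assertion is triggered purely by the presence of export-inserted, default-valued kwargs on these `_like` nodes) is to strip their kwargs before calling `prepare_pt2e`:

```python
import torch

# Hypothetical helper: the op set and the assumption that the kwargs hold
# only safe-to-drop defaults (e.g. pin_memory=False) are mine, not upstream's.
_LIKE_OPS = {
    torch.ops.aten.empty_like.default,
    torch.ops.aten.ones_like.default,
    torch.ops.aten.randn_like.default,
}

def strip_like_kwargs(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in _LIKE_OPS:
            node.kwargs = {}  # drop the kwargs that trip the assertion
    gm.recompile()
    return gm
```

Usage would be `strip_like_kwargs(exported_model.graph_module)` right before the `prepare_pt2e` call in the repros above; the proper fix is presumably to relax the no-kwargs assumption in `_maybe_insert_input_observers_for_node` itself.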
### Versions
```
PyTorch version: 2.7.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
```
cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel @msaroufim