🐞Describing the bug
- The issue arises when using `coremltools.optimize.torch.quantization` to quantize a model with W8A8 quantization applied to `Conv3d` layers. Quantization itself completes, but converting the quantized model to Core ML format fails with a `ValueError` complaining that the weight and its quantization scale have different ranks.
Stack Trace

```
Traceback (most recent call last):
  File "/Users/silveryu/Developer/lightsvd/diffuserskit/torch_quant.py", line 72, in <module>
    ct.convert(traced_model,
  File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/_converters_entry.py", line 635, in convert
    mlmodel = mil_convert(
  File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
  File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/converter.py", line 212, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
  File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/converter.py", line 288, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
  File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 88, in load
    return _perform_torch_convert(converter, debug)
  File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 151, in _perform_torch_convert
    prog = converter.convert()
  File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 1383, in convert
    self.convert_const()
  File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 1251, in convert_const
    self._add_const(name, val)
  File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 1206, in _add_const
    compression_op = self._construct_compression_op(val.detach().numpy(), name)
  File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 1189, in _construct_compression_op
    result = self._construct_quantization_op(val, compression_info, param_name, result)
  File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 953, in _construct_quantization_op
    raise ValueError(
ValueError: In conv1.weight, the `weight` should have same rank as `scale`, but got (1, 1, 3, 3, 3) vs (1, 1)
```
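The failing check compares the rank of the weight tensor with the rank of its quantization scale. Here is a minimal shape sketch using the values reported in the `ValueError` above; the scale tensor is a stand-in for whatever the converter records internally, not pulled from the coremltools source:

```python
import torch

# Conv3d(1, 1, 3) weight is rank 5: (out_ch, in_ch, kD, kH, kW)
weight = torch.nn.Conv3d(1, 1, 3).weight.detach()
print(tuple(weight.shape))  # (1, 1, 3, 3, 3)

# The converter sees a per-output-channel scale of rank 2, as in the error.
scale = torch.ones(1, 1)
print(weight.dim(), "vs", scale.dim())  # 5 vs 2 -> rank mismatch

# For the ranks to line up, the scale would need a singleton axis for each
# of the three kernel dimensions as well, i.e. shape (1, 1, 1, 1, 1):
scale_5d = scale.reshape(1, 1, 1, 1, 1)
print((weight / scale_5d).shape)  # broadcasts cleanly at rank 5
```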
To Reproduce
```python
import tqdm
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

import coremltools as ct
from coremltools.optimize.torch.quantization import (
    LinearQuantizer,
    LinearQuantizerConfig,
    ModuleLinearQuantizerConfig,
)


def torch_quantization(
    model: torch.nn.Module,
    sample_dataloader: DataLoader,
) -> torch.nn.Module:
    # W8A8: int8 weights, uint8 activations, symmetric scheme
    config = LinearQuantizerConfig(
        global_config=ModuleLinearQuantizerConfig(
            weight_dtype="qint8",
            activation_dtype="quint8",
            quantization_scheme="symmetric",
            milestones=[0, 1000, 1000, 0],
        )
    )
    quantizer = LinearQuantizer(model, config)
    example_inputs = next(iter(sample_dataloader))
    quantizer.prepare(example_inputs=(example_inputs,), inplace=True)
    quantizer.step()
    # Do a forward pass through the model with calibration data
    for data in tqdm.tqdm(sample_dataloader, desc="calibrating"):
        with torch.no_grad():
            model(data)
    return quantizer.finalize()


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv3d(1, 1, 3)

    def forward(self, x):
        return self.conv1(x)


class SampleDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, item):
        # (C, D, H, W); the DataLoader adds the batch dimension
        return torch.randn(1, 12, 224, 224, device="mps", dtype=torch.float32)


test_model = SimpleNet().to("mps")
quantized_model = torch_quantization(test_model, DataLoader(SampleDataset()))

# Trace and convert with the same rank-5 shape used during calibration
example_input = torch.randn(1, 1, 12, 224, 224, dtype=torch.float32)
with torch.no_grad():
    traced_model = torch.jit.trace(quantized_model.to("cpu"), example_input)
    _ = traced_model(example_input)

ct.convert(
    traced_model,
    inputs=[ct.TensorType(name="input", shape=example_input.shape)],
    convert_to="mlprogram",
    minimum_deployment_target=ct.target.iOS18,
)
```
- This code triggers the issue when performing W8A8 quantization on a model with a `Conv3d` layer.
System environment (please complete the following information):
- Coremltools: 8.1
- OS: macOS 15.3
- PyTorch: 2.4.0
- Python: 3.10.15
Additional context
- The issue appears specifically when attempting to quantize a model with `Conv3d` layers. Other layer types, such as `Conv2d`, seem to work fine; see the control sketch below.
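
For comparison, here is a sketch of the same pipeline with the `Conv3d` swapped for a `Conv2d` (input shapes adjusted to rank 4, calibration loop simplified). Per the observation above, this variant is expected to convert without the rank error; it is a control case, not a fix:

```python
import torch
from torch import nn
import coremltools as ct
from coremltools.optimize.torch.quantization import (
    LinearQuantizer,
    LinearQuantizerConfig,
    ModuleLinearQuantizerConfig,
)


class SimpleNet2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv1(x)


model = SimpleNet2d()
# Same W8A8 config as in the Conv3d repro above
config = LinearQuantizerConfig(
    global_config=ModuleLinearQuantizerConfig(
        weight_dtype="qint8",
        activation_dtype="quint8",
        quantization_scheme="symmetric",
        milestones=[0, 1000, 1000, 0],
    )
)
quantizer = LinearQuantizer(model, config)
example = torch.randn(1, 1, 224, 224)
quantizer.prepare(example_inputs=(example,), inplace=True)
quantizer.step()

# Calibration passes with random data
with torch.no_grad():
    for _ in range(10):
        model(torch.randn(1, 1, 224, 224))
quantized = quantizer.finalize()

with torch.no_grad():
    traced = torch.jit.trace(quantized, example)

# Converts without the rank error, unlike the Conv3d case
ct.convert(
    traced,
    inputs=[ct.TensorType(name="input", shape=(1, 1, 224, 224))],
    convert_to="mlprogram",
    minimum_deployment_target=ct.target.iOS18,
)
```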