Description
Issue Type
Others
Source
pip (model-compression-toolkit)
MCT Version
2.4.2
OS Platform and Distribution
Linux Ubuntu 22.04 x86_64
Python version
3.12
Describe the issue
I'm trying to quantize the RF-DETR model.
To make the model compatible with the ops MCT supports, I forked the repo and pushed the changes needed to reproduce the error to a dedicated branch ( https://github.com/StijnWoestenborghs/rf-detr/commits/feature/torch_tracing_mcp/ ).
At this point I have made sure the RFDETRNano model is torch.fx traceable, and it looks like the MCT graph can be built given a dummy representative dataset. (Note that I have currently hardcoded the shape of the representative dataset in the model as well, so that torch.fx tracing works. It is fixed to (5, 3, 384, 384), with 384x384 being the input resolution of the RFDETRNano model.)
I'm using the following pyproject.toml setup:
[project]
name = "rf-detr-imx500"
version = "0.1.0"
description = "RF-DETR-IMX500"
readme = "README.md"
requires-python = "==3.12.*"
dependencies = [
"pillow",
"mct-quantizers",
"onnxruntime~=1.17.0",
"sony-custom-layers==0.2.0",
"onnxruntime-extensions",
"model-compression-toolkit>=2.2.2",
"rfdetr",
]
[tool.uv.sources]
rfdetr = { path = "../rf-detr", editable = true }
where "../rf-detr" points to my forked rf-detr repo.
I also made sure that an inference call on the model still works as expected, on both GPU and CPU:
from PIL import Image
from rfdetr import RFDETRBase, RFDETRLarge, RFDETRNano, RFDETRSmall, RFDETRMedium
from rfdetr.util.coco_classes import COCO_CLASSES
model = RFDETRNano()
# model = RFDETRNano(device='cpu')
model.optimize_for_inference()
image = Image.open("./assets/coco_samples/785.jpg")
detections = model.predict(image, threshold=0.5)
labels = [
    f"{COCO_CLASSES[class_id]} {confidence:.2f}"
    for class_id, confidence
    in zip(detections.class_id, detections.confidence)
]
# annotation
import supervision as sv
annotated_image = image.copy()
annotated_image = sv.BoxAnnotator().annotate(annotated_image, detections)
annotated_image = sv.LabelAnnotator().annotate(annotated_image, detections, labels)
annotated_image.show()
The problem occurs when trying to quantize the model (with IMX500 as the target platform in mind).
Once the Python recursion limit is increased, the quantization script (see below) runs up to the point where it fails with:
RuntimeError: Failed to calculate activation memory cuts for graph.
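For reference, here is a minimal sketch of scoping the higher recursion limit to a single call instead of raising it for the whole process. `recursion_limit` and `depth` are hypothetical helpers written for illustration, not part of MCT, and 5000 is simply the value used in the script below. This does not address the error itself; it only keeps the workaround contained.

```python
import sys
from contextlib import contextmanager


@contextmanager
def recursion_limit(limit: int):
    """Temporarily raise the interpreter recursion limit, then restore it."""
    previous = sys.getrecursionlimit()
    # Never lower the limit below its current value.
    sys.setrecursionlimit(max(previous, limit))
    try:
        yield
    finally:
        sys.setrecursionlimit(previous)


def depth(n: int) -> int:
    """Toy recursive function standing in for a deep graph traversal."""
    return 0 if n == 0 else 1 + depth(n - 1)


original = sys.getrecursionlimit()
with recursion_limit(5000):
    # In the real script, the MCT calls would go here instead.
    reached = depth(3000)
```

After the `with` block exits, the limit is back at its previous value, so any unrelated runaway recursion elsewhere in the script still fails early.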
I also noticed that by default the quantization script starts a GPU process, even though I'm specifying model = RFDETRNano(device="cpu"). I verified that both plain inference and plain fx tracing respect the device argument. To make sure this had no effect on quantization, I also ran export CUDA_VISIBLE_DEVICES= before running the quantization, which results in exactly the same error.
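The same GPU hiding can be done from inside the script rather than from the shell. This is a minimal sketch, assuming the variable is set before the first import of torch or any other CUDA-using library; `visible_gpu_ids` is a hypothetical helper that only inspects the environment and does not query the driver.

```python
import os

# Must run before the first `import torch` (or any other CUDA-using import),
# because the variable is only read when CUDA is initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = ""


def visible_gpu_ids() -> list[str]:
    """Return the GPU ids exposed through CUDA_VISIBLE_DEVICES (hypothetical helper)."""
    raw = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    return [device.strip() for device in raw.split(",") if device.strip()]


print(visible_gpu_ids())
```

An empty list here means no device id is exposed to the process, which matches what export CUDA_VISIBLE_DEVICES= does in the shell.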
Expected behaviour
Quantization starts
Code to reproduce the issue
import numpy as np
import torch
import model_compression_toolkit as mct
from rfdetr import RFDETRBase, RFDETRLarge, RFDETRNano, RFDETRSmall, RFDETRMedium
from rfdetr.util.misc import NestedTensor, nested_tensor_from_tensor_list

print(mct.__version__)

import sys
print(sys.getrecursionlimit())
sys.setrecursionlimit(5000)
print(sys.getrecursionlimit())


def load_model():
    model = RFDETRNano(device='cpu')
    model.optimize_for_inference()
    # Return the underlying torch module
    return model.model.model


# Here we just return RANDOM data
def get_representative_dataset(n_iter: int, shape):
    def representative_dataset():
        for _ in range(n_iter):
            yield [torch.rand(shape)]
    return representative_dataset


def quantization(model, n_iter, shape):
    print(f"Starting quantization with shape: {shape}")
    tpc = mct.get_target_platform_capabilities(fw_name="pytorch",
                                               target_platform_name='imx500',
                                               target_platform_version='v1')

    # Perform post training quantization
    """
    quant_model, _ = mct.ptq.keras_post_training_quantization(
        model,
        representative_data_gen=get_representative_dataset(n_iter),
        target_platform_capabilities=tpc)
    """
    mp_config = mct.core.MixedPrecisionQuantizationConfig(num_of_images=5)
    config = mct.core.CoreConfig(mixed_precision_config=mp_config,
                                 quantization_config=mct.core.QuantizationConfig(shift_negative_activation_correction=True,
                                                                                 concat_threshold_update=True))

    # Define target Resource Utilization for mixed precision weights quantization (80% of 'standard' 8bits quantization)
    # see also https://github.com/sony/model_optimization/blob/main/tutorials/notebooks/mct_features_notebooks/pytorch/example_pytorch_mixed_precision_ptq.ipynb
    # and https://github.com/sony/model_optimization/blob/main/tutorials/notebooks/mct_features_notebooks/keras/example_keras_pruning_mnist.ipynb
    resource_utilization_data = mct.core.pytorch_resource_utilization_data(model,
                                                                           get_representative_dataset(n_iter, shape),
                                                                           config,
                                                                           target_platform_capabilities=tpc)
    resource_utilization = mct.core.ResourceUtilization(resource_utilization_data.weights_memory * 0.8)

    # Perform post training quantization
    quant_model, _ = mct.ptq.pytorch_post_training_quantization(model,
                                                                get_representative_dataset(n_iter, shape),
                                                                target_resource_utilization=resource_utilization,
                                                                core_config=config,
                                                                target_platform_capabilities=tpc)
    print('Quantized model is ready')
    return quant_model


def main():
    save_model_path = 'models/qmodel_rfdetr_nano.onnx'
    model = load_model()
    N_ITER = 1
    BATCH_SIZE = 5  # Try with batch size 1 first
    SHAPE = (BATCH_SIZE, 3, 384, 384)  # Here you must update for your model
    quant_model = quantization(model, N_ITER, SHAPE)
    mct.exporter.pytorch_export_model(
        model=quant_model,
        save_model_path=save_model_path,
        repr_dataset=get_representative_dataset(N_ITER, SHAPE)
    )


if __name__ == '__main__':
    main()
Log output
Traceback (most recent call last):
File "/home/SONY/s1000328194/git/rf-detr-imx500/quant.py", line 112, in <module>
main()
File "/home/SONY/s1000328194/git/rf-detr-imx500/quant.py", line 98, in main
quant_model = quantization(model, N_ITER, SHAPE)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/SONY/s1000328194/git/rf-detr-imx500/quant.py", line 69, in quantization
resource_utilization_data = mct.core.pytorch_resource_utilization_data(model,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/SONY/s1000328194/git/rf-detr-imx500/.venv/lib/python3.12/site-packages/model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py", line 92, in pytorch_resource_utilization_data
return compute_resource_utilization_data(in_model,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/SONY/s1000328194/git/rf-detr-imx500/.venv/lib/python3.12/site-packages/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py", line 66, in compute_resource_utilization_data
return ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantizedNonFused, BitwidthMode.QDefaultSP)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/SONY/s1000328194/git/rf-detr-imx500/.venv/lib/python3.12/site-packages/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py", line 214, in compute_resource_utilization
a_total, a_per_cut, _ = self.compute_activations_utilization(target_criterion, bitwidth_mode, act_qcs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/SONY/s1000328194/git/rf-detr-imx500/.venv/lib/python3.12/site-packages/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py", line 345, in compute_activations_utilization
return self.compute_activation_utilization_by_cut(target_criterion, bitwidth_mode, act_qcs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/SONY/s1000328194/git/rf-detr-imx500/.venv/lib/python3.12/site-packages/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py", line 398, in compute_activation_utilization_by_cut
for cut in self.cuts:
^^^^^^^^^
File "/home/SONY/s1000328194/git/rf-detr-imx500/.venv/lib/python3.12/site-packages/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py", line 149, in cuts
raise RuntimeError("Failed to calculate activation memory cuts for graph.")
RuntimeError: Failed to calculate activation memory cuts for graph.