Add Python example of TensorRT INT8 inference on ResNet model #6255

Merged: 18 commits, Jan 15, 2021
Changes from 1 commit
generalize calibrate
stevenlix committed Jan 12, 2021
commit d4141c25116f8bcb97334222938dfa27a6ee07df
@@ -292,7 +292,7 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
Please download Resnet50 model from ONNX model zoo https://github.com/onnx/models/blob/master/vision/classification/resnet/model/resnet50-v2-7.tar.gz
Untar the model into the workspace
'''

# Dataset settings
model_path = "./resnet50-v2-7.onnx"
ilsvrc2012_dataset_path = "./ILSVRC2012"
@@ -319,7 +319,8 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
# Generate INT8 calibration table
if calibration_table_generation_enable:
data_reader = ImageNetDataReader(ilsvrc2012_dataset_path,start_index=0, end_index=calibration_dataset_size, stride=calibration_dataset_size, batch_size=batch_size, model_path=augmented_model_path, input_name=input_name)
-    calibration_cache = calibrate(new_model_path, data_reader, providers=["CUDAExecutionProvider"], tensorrt_calibration=True)
+    # For TensorRT calibration, augment all FP32 tensors, disable ORT graph optimization and skip quantization parameter calculation
+    calibration_cache = calibrate(new_model_path, data_reader, augment_all=True, providers=["CUDAExecutionProvider"], ort_graph_optimization_enable=False, quantization_params_calculation_enable=False)
write_calibration_table(calibration_cache)

# Run prediction in Tensorrt EP
@@ -40,7 +40,7 @@ def get_calibration_table(model_path, augmented_model_path, calibration_dataset)
for i in range(0, total_data_size, stride):
data_reader = YoloV3DataReader(calibration_dataset,start_index=start_index, end_index=start_index+stride, stride=stride, batch_size=1, model_path=augmented_model_path)
calibrator.set_data_reader(data_reader)
-        generate_calibration_table(calibrator, model_path, augmented_model_path, False, data_reader)
+        generate_calibration_table(calibrator, model_path, augmented_model_path, False, data_reader, True)
start_index += stride


@@ -56,7 +56,7 @@ def get_calibration_table(model_path, augmented_model_path, calibration_dataset)
# data_reader = YoloV3VisionDataReader(calibration_dataset, width=512, height=288, stride=1000, batch_size=20, model_path=augmented_model_path)
# data_reader = YoloV3VisionDataReader(calibration_dataset, width=608, height=384, stride=1000, batch_size=20, model_path=augmented_model_path)
# calibrator.set_data_reader(data_reader)
-    # generate_calibration_table(calibrator, model_path, augmented_model_path, True, data_reader)
+    # generate_calibration_table(calibrator, model_path, augmented_model_path, True, data_reader, True)

write_calibration_table(calibrator.get_calibration_cache())
print('calibration table generated and saved.')
65 changes: 37 additions & 28 deletions onnxruntime/python/tools/quantization/calibrate.py
@@ -31,7 +31,7 @@ def get_next(self) -> dict:
raise NotImplementedError

class ONNXCalibrater:
-    def __init__(self, model_path, data_reader: CalibrationDataReader, calibrate_op_types, black_nodes, white_nodes,
+    def __init__(self, model_path, data_reader: CalibrationDataReader, calibrate_op_types, black_nodes, white_nodes, augment_all,
augmented_model_path):
'''
:param model_path: ONNX model to calibrate
@@ -40,14 +40,15 @@ def __init__(self, model_path, data_reader: CalibrationDataReader, calibrate_op_
:param op_types: operator types to be calibrated and quantized, default = 'Conv,MatMul'
:param black_nodes: operator names that should not be quantized, default = ''
:param white_nodes: operator names that force to be quantized, default = ''
+        :param augment_all: augment all FP32 activation tensors, default = False
Member commented:

@yufenglee , what's your take on this API? I think we are trying to give the user flexibility to specify which nodes to quantize.
So far the API supports selection by op type and by node name; now we want to add an option by data type / an "all" option.
The op_types, black_nodes, and white_nodes arguments all work and are consistent with each other, but I'm not sure about augment_all.
Meaning, in some sense, if you see augment_all, do op_types, black_nodes, etc. all carry no meaning and should be ignored?
Is there a better way to make them all consistent with each other?

jywu-msft (Member) commented on Jan 13, 2021:

What if an explicit empty list or None meant "augment all"? I.e., empty op_types, empty black_nodes, and empty white_nodes, or None for each.

Contributor (Author) commented:

Everything empty means no augmentation. We could add all nodes to the white_nodes list for TRT, but that requires the user to get the graph's node list first, before calibration.

jywu-msft (Member) commented on Jan 14, 2021:

It doesn't have to mean no augmentation. The default could mean to augment everything, unless you explicitly specify which op types or node names to include/exclude. We can define the semantics and document them clearly. I'm not sure it makes sense for empty to mean no augmentation: why would the user call the API if there is nothing to do?
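A minimal sketch of the selection semantics proposed above (a hypothetical helper, not code from this PR): it mirrors the existing should_be_calibrate test in augment_graph, but treats all-empty/None selectors as "augment everything".

```python
# Hypothetical helper illustrating the "empty means augment all" proposal.
# `node` is an onnx NodeProto; None/empty selectors mean "not specified".
def should_augment(node, op_types=None, black_nodes=None, white_nodes=None):
    if not op_types and not black_nodes and not white_nodes:
        return True  # nothing specified: augment every node's outputs
    in_scope = (not op_types or node.op_type in op_types) and \
               node.name not in (black_nodes or [])
    return in_scope or node.name in (white_nodes or [])
```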

:param augmented_model_path: save augmented_model to this path

'''
self.model_path = model_path
self.data_reader = data_reader
self.calibrate_op_types = calibrate_op_types
self.black_nodes = black_nodes
self.white_nodes = white_nodes
+        self.augment_all = augment_all
self.augmented_model_path = augmented_model_path
self.input_name_to_nodes = {}
self.calibration_cache = {} # save temporary calibration table
@@ -59,7 +60,7 @@ def set_data_reader(self, data_reader):
def get_calibration_cache(self):
return self.calibration_cache

-    def augment_graph(self, augment_all_ops=False):
+    def augment_graph(self):
'''
Adds ReduceMin and ReduceMax nodes to all quantization_candidates op type nodes in
model and ensures their outputs are stored as part of the graph output
@@ -77,7 +78,7 @@ def augment_graph(self, augment_all_ops=False):
tensors_to_calibrate = set()

for node in model.graph.node:
-            if augment_all_ops:
+            if self.augment_all:
should_be_calibrate = True
else:
should_be_calibrate = ((node.op_type in self.calibrate_op_types) and
@@ -92,7 +93,7 @@ def augment_graph(self, augment_all_ops=False):

# If augmenting all ops, it's possible that some nodes' input value are 0.
# Can't reduce on dim with value of 0 if 'keepdims' is false, therefore set keepdims to 1.
-        if augment_all_ops:
+        if self.augment_all:
keepdims_value = 1
else:
keepdims_value = 0
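For orientation, a standalone sketch of the augmentation idea (assumed helper names and shapes; not this file's exact code): ReduceMin/ReduceMax nodes are appended for each selected tensor and exposed as graph outputs, with keepdims=1 to stay safe on zero-sized dims, per the comment above.

```python
from onnx import helper, TensorProto

# Illustrative only: expose a (min, max) pair for every tensor to calibrate,
# so one inference pass over calibration data yields per-tensor ranges.
def add_min_max_outputs(model, tensor_names, keepdims=1):
    for name in tensor_names:
        for op_type in ("ReduceMin", "ReduceMax"):
            out_name = name + "_" + op_type
            model.graph.node.append(
                helper.make_node(op_type, [name], [out_name], keepdims=keepdims))
            # Output shape left unspecified; with keepdims=1 it is rank-preserving.
            model.graph.output.append(
                helper.make_tensor_value_info(out_name, TensorProto.FLOAT, None))
    return model
```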
@@ -122,24 +123,26 @@ def augment_graph(self, augment_all_ops=False):
return model

#Using augmented outputs to generate inputs for quantization
-    def get_intermediate_outputs(self, calib_mode='naive', providers=None, tensorrt_calibration=False):
+    def get_intermediate_outputs(self, calib_mode='naive', providers=None, ort_graph_optimization_enable=True):
'''
-        Gather intermediate model outputs after running inference
-        parameter calib_mode: type 'naive' gives (ReduceMin, ReduceMax) pairs
-                              for each augmented node across test data sets, where
-                              the first element is a minimum of all ReduceMin values
-                              and the second element is a maximum of all ReduceMax
-                              values;
-        :return: dictionary mapping: {added node names: (ReduceMin, ReduceMax) pairs }
+        Gather intermediate model outputs after running inference
+        parameter calib_mode: type 'naive' gives (ReduceMin, ReduceMax) pairs
+                              for each augmented node across test data sets, where
+                              the first element is a minimum of all ReduceMin values
+                              and the second element is a maximum of all ReduceMax
+                              values;
+        parameter providers: Onnxruntime execution providers
+        parameter ort_graph_optimization_enable: Enable all OnnxRuntime graph optimizations, default = True
+        :return: dictionary mapping: {added node names: (ReduceMin, ReduceMax) pairs }
'''

#conduct inference session and get intermediate outputs
-        if tensorrt_calibration:
+        if ort_graph_optimization_enable:
+            session = onnxruntime.InferenceSession(self.augmented_model_path, None)
+        else:
            sess_options = onnxruntime.SessionOptions()
            sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL #ORT_ENABLE_BASIC
            session = onnxruntime.InferenceSession(self.augmented_model_path, sess_options=sess_options, providers=providers)
-        else:
-            session = onnxruntime.InferenceSession(self.augmented_model_path, None)

#number of outputs in original model
model = onnx.load(self.model_path)
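To make the 'naive' mode concrete, a tiny standalone sketch of the aggregation the docstring describes (hypothetical helper, not this method's body): fold per-batch (ReduceMin, ReduceMax) pairs into a single range per tensor.

```python
# Fold per-batch (ReduceMin, ReduceMax) readings into one (min, max) per tensor.
def naive_aggregate(batches):
    ranges = {}
    for batch in batches:  # each batch: {tensor_name: (reduce_min, reduce_max)}
        for name, (lo, hi) in batch.items():
            prev_lo, prev_hi = ranges.get(name, (lo, hi))
            ranges[name] = (min(prev_lo, lo), max(prev_hi, hi))
    return ranges

# naive_aggregate([{"conv1": (-1.0, 2.0)}, {"conv1": (-3.0, 1.5)}])
# -> {"conv1": (-3.0, 2.0)}
```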
@@ -295,9 +298,10 @@ def get_calibrator(model_path,
op_types=['Conv', 'MatMul'],
black_nodes=[],
white_nodes=[],
+                   augment_all=False,
augmented_model_path='augmented_model.onnx'):

-    calibrator = ONNXCalibrater(model_path, data_reader, op_types, black_nodes, white_nodes, augmented_model_path)
+    calibrator = ONNXCalibrater(model_path, data_reader, op_types, black_nodes, white_nodes, augment_all, augmented_model_path)

return calibrator

@@ -308,6 +312,7 @@ def calculate_calibration_data(model_path,
activation_type=QuantType.QUInt8,
nodes_to_quantize=[],
nodes_to_exclude=[],
+                                augment_all=False,
augmented_model_path='augmented_model.onnx'):

if activation_type != QuantType.QUInt8:
@@ -320,23 +325,23 @@ def calculate_calibration_data(model_path,
print("augmented model path: %s" % augmented_model_path)

if not calibrator:
-        calibrator = get_calibrator(model_path, calibration_data_reader, op_types_to_quantize, nodes_to_quantize, nodes_to_exclude, augmented_model_path=augmented_model_path)
+        calibrator = get_calibrator(model_path, calibration_data_reader, op_types_to_quantize, nodes_to_exclude, nodes_to_quantize, augment_all, augmented_model_path=augmented_model_path)

if not os.path.exists(augmented_model_path):
augmented_model = calibrator.augment_graph(augment_all_ops=True)
onnx.save(augmented_model, augmented_model_path)

calibrator.get_intermediate_outputs(providers=["CUDAExecutionProvider"])

-def generate_calibration_table(calibrator, model_path, augmented_model_path, remove_previous_flag, data_reader, calibration_dataset=None, stride=5000, batch_size=20):
+def generate_calibration_table(calibrator, model_path, augmented_model_path, remove_previous_flag, data_reader, augment_all, calibration_dataset=None, stride=5000, batch_size=20):

if remove_previous_flag and os.path.exists(augmented_model_path):
os.remove(augmented_model_path)
print("remove previously generated %s and start to generate a new one." % (augmented_model_path))

if not calibrator:
-        calibrator = get_calibrator(model_path, data_reader, augmented_model_path=augmented_model_path)
-        calculate_calibration_data(model_path, calibrator, augmented_model_path=augmented_model_path)
+        calibrator = get_calibrator(model_path, data_reader, augment_all=augment_all, augmented_model_path=augmented_model_path)
+        calculate_calibration_data(model_path, calibrator, augment_all=augment_all, augmented_model_path=augmented_model_path)

return calibrator.get_calibration_cache()

@@ -345,31 +350,35 @@ def calibrate(model_path,
op_types=['Conv', 'MatMul'],
black_nodes=[],
white_nodes=[],
+              augment_all=False,
augmented_model_path='augmented_model.onnx',
providers=["CPUExecutionProvider"],
-              tensorrt_calibration=False):
+              ort_graph_optimization_enable=True,
+              quantization_params_calculation_enable=True):
'''
Given an onnx model, augment and run the augmented model on calibration data set, aggregate and calculate the quantization parameters.
:param model_path: ONNX model to calibrate
:param data_reader: user implemented object to read in and preprocess calibration dataset based on CalibrationDataReader interface
:param op_types: operator types to be calibrated and quantized, default = 'Conv,MatMul'
:param black_nodes: operator names that should not be quantized, default = ''
:param white_nodes: operator names that force to be quantized, default = ''
+    :param augment_all: augment all FP32 activation tensors, default = False
:param augmented_model_path: save augmented_model to this path
:param providers: execution providers to run calibration
-    :tensorrt_calibration: TensorRT calibration. Quantization parameter calculation will be skipped
+    :param ort_graph_optimization_enable: enable all OnnxRuntime graph optimizations, default = True
+    :param quantization_params_calculation_enable: enable quantization parameter calculation, default = True
'''
#1. initialize a calibrater
-    calibrater = ONNXCalibrater(model_path, data_reader, op_types, black_nodes, white_nodes, augmented_model_path)
+    calibrater = ONNXCalibrater(model_path, data_reader, op_types, black_nodes, white_nodes, augment_all, augmented_model_path)
#2. augment
augmented_model = calibrater.augment_graph()
onnx.save(augmented_model, augmented_model_path)
#3. generate quantization thresholds
-    dict_for_quantization = calibrater.get_intermediate_outputs(providers=providers, tensorrt_calibration=tensorrt_calibration)
+    dict_for_quantization = calibrater.get_intermediate_outputs(providers=providers, ort_graph_optimization_enable=ort_graph_optimization_enable)
#4. generate quantization parameters dict
-    quantization_params_dict = {}
-    if not tensorrt_calibration:
+    quantization_params_dict = {}
+    if quantization_params_calculation_enable:
quantization_params_dict = calibrater.calculate_quantization_params(dict_for_quantization)
print("Calibrated,quantized parameters calculated and returned.")

-    return dict_for_quantization if tensorrt_calibration else quantization_params_dict
+    return quantization_params_dict if quantization_params_calculation_enable else dict_for_quantization
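Putting the new flags together, a hedged sketch of a TensorRT-style calibration run (the import path, the input name "data", and DummyDataReader are assumptions for illustration; a real reader would feed preprocessed ImageNet batches, as in the ResNet example above):

```python
import numpy as np
from onnxruntime.quantization.calibrate import (CalibrationDataReader,
                                                calibrate,
                                                write_calibration_table)

class DummyDataReader(CalibrationDataReader):
    """Stand-in reader: yields a few random NCHW batches, then None."""
    def __init__(self, input_name, num_batches=4):
        self.input_name = input_name
        self.batches = (np.random.rand(1, 3, 224, 224).astype(np.float32)
                        for _ in range(num_batches))

    def get_next(self):
        batch = next(self.batches, None)
        return None if batch is None else {self.input_name: batch}

cache = calibrate("resnet50-v2-7.onnx",
                  DummyDataReader("data"),
                  augment_all=True,                              # range every FP32 tensor
                  providers=["CUDAExecutionProvider"],
                  ort_graph_optimization_enable=False,           # keep tensors unfused
                  quantization_params_calculation_enable=False)  # TRT consumes raw min/max
write_calibration_table(cache)
```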