From eab164e1a5cd7760ee653ffe455e7a10cded7cb2 Mon Sep 17 00:00:00 2001
From: stevenlix <38092805+stevenlix@users.noreply.github.com>
Date: Fri, 15 Jan 2021 09:59:56 -0800
Subject: [PATCH] Add python example of TensorRT INT8 inference on ResNet
 model (#6255)

* add trt int8 example on resnet model
* Update e2e_tensorrt_resnet_example.py
* remove keras dependency and update class names
* move ImageNetDataReader and ImageClassificationEvaluator to tensorrt resnet example
* simplify e2e_tensorrt_resnet_example.py
* Update preprocessing.py
* merge tensorrt_calibrate
* Update calibrate.py
* Update calibrate.py
* generalize calibrate
* Update calibrate.py
* fix issues
* fix formating
* remove augment_all
---
 .../e2e_tensorrt_resnet_example.py            | 332 ++++++++++++++++++
 .../python/tools/quantization/calibrate.py    |  64 ++--
 .../python/tools/quantization/quant_utils.py  |   2 +-
 3 files changed, 368 insertions(+), 30 deletions(-)
 create mode 100644 onnxruntime/python/tools/quantization/E2E_example_model/e2e_tensorrt_resnet_example.py

diff --git a/onnxruntime/python/tools/quantization/E2E_example_model/e2e_tensorrt_resnet_example.py b/onnxruntime/python/tools/quantization/E2E_example_model/e2e_tensorrt_resnet_example.py
new file mode 100644
index 0000000000000..bfe8a0f53ec38
--- /dev/null
+++ b/onnxruntime/python/tools/quantization/E2E_example_model/e2e_tensorrt_resnet_example.py
@@ -0,0 +1,332 @@
+import os
+import glob
+import scipy.io
+import numpy as np
+from PIL import Image
+import onnx
+import onnxruntime
+from onnxruntime.quantization import CalibrationDataReader, calibrate, write_calibration_table
+
+
+class ImageNetDataReader(CalibrationDataReader):
+    def __init__(self, image_folder,
+                 width=224,
+                 height=224,
+                 start_index=0,
+                 end_index=0,
+                 stride=1,
+                 batch_size=1,
+                 model_path='augmented_model.onnx',
+                 input_name='data'):
+        '''
+        :param image_folder: image dataset folder
+        :param width: image width
+        :param height: image height
+        :param start_index: start index of images
+        :param end_index: end index of images (0 means the whole dataset)
+        :param stride: number of images consumed by each get_next() call
+        :param batch_size: batch size of inference
+        :param model_path: model name and path
+        :param input_name: model input name
+        '''
+        self.image_folder = image_folder + "/val"
+        self.model_path = model_path
+        self.preprocess_flag = True
+        self.enum_data_dicts = iter([])
+        self.datasize = 0
+        self.width = width
+        self.height = height
+        self.start_index = start_index
+        self.end_index = len(os.listdir(self.image_folder)) if end_index == 0 else end_index
+        self.stride = stride if stride >= 1 else 1
+        self.batch_size = batch_size
+        self.input_name = input_name
+
+    def get_dataset_size(self):
+        return len(os.listdir(self.image_folder))
+
+    def get_input_name(self):
+        # Look the input name up from the model only if it was not supplied.
+        if self.input_name:
+            return
+        session = onnxruntime.InferenceSession(self.model_path, providers=['CPUExecutionProvider'])
+        self.input_name = session.get_inputs()[0].name
+
+    def get_next(self):
+        iter_data = next(self.enum_data_dicts, None)
+        if iter_data:
+            return iter_data
+
+        self.enum_data_dicts = None
+        if self.start_index < self.end_index:
+            if self.batch_size == 1:
+                data = self.load_serial()
+            else:
+                data = self.load_batches()
+
+            self.start_index += self.stride
+            self.enum_data_dicts = iter(data)
+
+            return next(self.enum_data_dicts, None)
+        else:
+            return None
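+
+    # How a calibration loop consumes this reader (hypothetical driver code,
+    # shown for illustration only; ORT's calibrator performs the equivalent):
+    #
+    #   reader = ImageNetDataReader("./ILSVRC2012", input_name="data")
+    #   while True:
+    #       feed = reader.get_next()  # e.g. {"data": float32 array of shape (1, 3, 224, 224)}
+    #       if feed is None:
+    #           break
+    #       session.run(None, feed)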
+
+    def load_serial(self):
+        width = self.width
+        height = self.height
+        nchw_data_list, filename_list, image_size_list = self.preprocess_imagenet(self.image_folder, height, width,
+                                                                                  self.start_index, self.stride)
+        input_name = self.input_name
+
+        data = []
+        for i in range(len(nchw_data_list)):
+            nchw_data = nchw_data_list[i]
+            data.append({input_name: nchw_data})
+        return data
+
+    def load_batches(self):
+        width = self.width
+        height = self.height
+        batch_size = self.batch_size
+        stride = self.stride
+        input_name = self.input_name
+
+        batches = []
+        for index in range(0, stride, batch_size):
+            start_index = self.start_index + index
+            nchw_data_list, filename_list, image_size_list = self.preprocess_imagenet(self.image_folder, height, width,
+                                                                                      start_index, batch_size)
+
+            if nchw_data_list.size == 0:
+                break
+
+            nchw_data_batch = []
+            for i in range(len(nchw_data_list)):
+                nchw_data = np.squeeze(nchw_data_list[i], 0)
+                nchw_data_batch.append(nchw_data)
+            batch_data = np.stack(nchw_data_batch, axis=0)  # (batch_size, 3, height, width)
+            data = {input_name: batch_data}
+
+            batches.append(data)
+
+        return batches
+
+    def preprocess_imagenet(self, images_folder, height, width, start_index=0, size_limit=0):
+        '''
+        Loads a batch of images and preprocesses them.
+        :param images_folder: path to folder storing images
+        :param height: image height in pixels
+        :param width: image width in pixels
+        :param start_index: image index to start with
+        :param size_limit: number of images to load. Default is 0, which means all images are picked.
+        :return: list of matrices characterizing multiple images
+        '''
+        def preprocess_images(image, channels=3, height=224, width=224):
+            image = image.resize((width, height), Image.ANTIALIAS)
+            input_data = np.asarray(image).astype(np.float32)
+            if len(input_data.shape) != 2:
+                input_data = input_data.transpose([2, 0, 1])
+            else:
+                # Grayscale image: replicate the single channel three times.
+                input_data = np.stack([input_data] * 3)
+            # Standard ImageNet normalization.
+            mean = np.array([0.485, 0.456, 0.406])
+            std = np.array([0.229, 0.224, 0.225])
+            for channel in range(input_data.shape[0]):
+                input_data[channel, :, :] = (input_data[channel, :, :] / 255 - mean[channel]) / std[channel]
+            return input_data
+
+        image_names = os.listdir(images_folder)
+        image_names.sort()
+        if size_limit > 0 and len(image_names) >= size_limit:
+            end_index = start_index + size_limit
+            if end_index > len(image_names):
+                end_index = len(image_names)
+            batch_filenames = [image_names[i] for i in range(start_index, end_index)]
+        else:
+            batch_filenames = image_names
+
+        unconcatenated_batch_data = []
+        image_size_list = []
+
+        for image_name in batch_filenames:
+            image_filepath = images_folder + '/' + image_name
+            img = Image.open(image_filepath)
+            image_data = preprocess_images(img)
+            image_data = np.expand_dims(image_data, 0)
+            unconcatenated_batch_data.append(image_data)
+            image_size_list.append(np.array([img.size[1], img.size[0]], dtype=np.float32).reshape(1, 2))
+
+        batch_data = np.array(unconcatenated_batch_data)  # (num_images, 1, 3, height, width)
+        return batch_data, batch_filenames, image_size_list
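+
+    # A minimal single-image version of the preprocessing above (hypothetical
+    # helper, not used by this script), to make the tensor layout explicit:
+    #
+    #   def load_one(image_path):
+    #       img = Image.open(image_path).convert('RGB').resize((224, 224), Image.ANTIALIAS)
+    #       x = np.asarray(img, dtype=np.float32).transpose([2, 0, 1]) / 255  # HWC -> CHW
+    #       mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
+    #       std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)
+    #       return np.expand_dims((x - mean) / std, 0)  # (1, 3, 224, 224)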
+
+    def get_synset_id(self, image_folder, offset, dataset_size):
+        ilsvrc2012_meta = scipy.io.loadmat(image_folder + "/devkit/data/meta.mat")
+        id_to_synset = {}
+        for i in range(1000):
+            id = int(ilsvrc2012_meta["synsets"][i, 0][0][0][0])
+            id_to_synset[id] = ilsvrc2012_meta["synsets"][i, 0][1][0]
+
+        synset_to_id = {}
+        with open(image_folder + "/synset_words.txt", "r") as f:
+            index = 0
+            for line in f:
+                parts = line.split(" ")
+                synset_to_id[parts[0]] = index
+                index = index + 1
+
+        with open(image_folder + "/devkit/data/ILSVRC2012_validation_ground_truth.txt", "r") as f:
+            id = f.read().strip().split("\n")
+        id = list(map(int, id))
+
+        image_names = os.listdir(image_folder + "/val")
+        image_names.sort()
+        image_names = image_names[offset:offset + dataset_size]
+        seq_num = []
+        for file_name in image_names:
+            seq_num.append(int(file_name.split("_")[-1].split(".")[0]))
+        id = np.array([id[index - 1] for index in seq_num])
+        synset_id = np.array([synset_to_id[id_to_synset[index]] for index in id])
+
+        # one-hot encoding
+        synset_id_onehot = np.zeros((len(synset_id), 1000), dtype=np.float32)
+        for i, id in enumerate(synset_id):
+            synset_id_onehot[i, id] = 1.0
+        return synset_id_onehot
+
+
+class ImageClassificationEvaluator:
+    def __init__(self, model_path, synset_id,
+                 data_reader: CalibrationDataReader,
+                 providers=["TensorrtExecutionProvider"]):
+        '''
+        :param model_path: ONNX model to validate
+        :param synset_id: one-hot encoded ILSVRC2012 ground truth labels
+        :param data_reader: user implemented object, based on the CalibrationDataReader
+                            interface, that reads in and preprocesses the evaluation dataset
+        :param providers: ORT execution provider type
+        '''
+        self.model_path = model_path
+        self.data_reader = data_reader
+        self.providers = providers
+        self.prediction_result_list = []
+        self.synset_id = synset_id
+
+    def get_result(self):
+        return self.prediction_result_list
+
+    def predict(self):
+        sess_options = onnxruntime.SessionOptions()
+        sess_options.log_severity_level = 0
+        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
+        session = onnxruntime.InferenceSession(self.model_path, sess_options=sess_options, providers=self.providers)
+
+        inference_outputs_list = []
+        while True:
+            inputs = self.data_reader.get_next()
+            if not inputs:
+                break
+            output = session.run(None, inputs)
+            inference_outputs_list.append(output)
+        self.prediction_result_list = inference_outputs_list
+
+    def top_k_accuracy(self, truth, prediction, k=1):
+        '''From https://github.com/chainer/chainer/issues/606
+        '''
+        y = np.argsort(prediction)[:, -k:]
+        return np.any(y.T == truth.argmax(axis=1), axis=0).mean()
+
+    def evaluate(self, prediction_results):
+        batch_size = len(prediction_results[0][0])
+        total_val_images = len(prediction_results) * batch_size
+        y_prediction = np.empty((total_val_images, 1000), dtype=np.float32)
+        i = 0
+        for res in prediction_results:
+            y_prediction[i:i + batch_size, :] = res[0]
+            i = i + batch_size
+        print("top 1: ", self.top_k_accuracy(self.synset_id, y_prediction, k=1))
+        print("top 5: ", self.top_k_accuracy(self.synset_id, y_prediction, k=5))
+
+
+def convert_model_batch_to_dynamic(model_path):
+    model = onnx.load(model_path)
+    model_input = model.graph.input
+    input_name = model_input[0].name
+    shape = model_input[0].type.tensor_type.shape
+    dim = shape.dim
+    if not dim[0].dim_param:
+        # Replace the static batch dimension with a symbolic one ('N').
+        dim[0].dim_param = 'N'
+    model = onnx.shape_inference.infer_shapes(model)
+    model_path = os.path.splitext(model_path)[0] + "_dynamic.onnx"
+    onnx.save(model, model_path)
+    return [model_path, input_name]
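+
+# Example of the conversion above: a ResNet50 exported with a fixed batch,
+# input shape (1, 3, 224, 224), is rewritten so the first dimension becomes
+# the symbolic 'N', and the new model is saved next to the original
+# (paths shown here are the defaults used further below):
+#
+#   new_path, input_name = convert_model_batch_to_dynamic("./resnet50-v2-7.onnx")
+#   # new_path == "./resnet50-v2-7_dynamic.onnx", input_name == "data"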
+
+def get_dataset_size(dataset_path, calibration_dataset_size, batch_size):
+    total_dataset_size = len(os.listdir(dataset_path + "/val"))
+    if calibration_dataset_size > total_dataset_size:
+        print("Warning: calibration data size is bigger than available dataset. Will assign half of the dataset for calibration.")
+        calibration_dataset_size = total_dataset_size // 2
+    calibration_dataset_size = (calibration_dataset_size // batch_size) * batch_size
+    if calibration_dataset_size == 0:
+        print("Warning: No dataset is assigned for calibration. Please use a bigger dataset.")
+
+    prediction_dataset_size = ((total_dataset_size - calibration_dataset_size) // batch_size) * batch_size
+    if prediction_dataset_size <= 0:
+        print("Warning: No dataset is assigned for evaluation. Please use a bigger dataset.")
+    return [calibration_dataset_size, prediction_dataset_size]
+
+
+if __name__ == '__main__':
+    '''
+    TensorRT EP INT8 inference on ResNet model.
+
+    This script uses the ILSVRC2012 ImageNet dataset for calibration and prediction.
+    Please prepare the dataset as follows:
+    1. Create dataset folder 'ILSVRC2012' in the workspace.
+    2. Download the ILSVRC2012 validation dataset and development kit from http://www.image-net.org/challenges/LSVRC/2012/downloads.
+    3. Extract the validation dataset JPEG files to 'ILSVRC2012/val'.
+    4. Extract the development kit to 'ILSVRC2012/devkit'. Two files in the development kit are used: 'ILSVRC2012_validation_ground_truth.txt' and 'meta.mat'.
+    5. Download 'synset_words.txt' from https://github.com/HoldenCaulfieldRye/caffe/blob/master/data/ilsvrc12/synset_words.txt into 'ILSVRC2012/'.
+
+    Please download the ResNet50 model from the ONNX model zoo https://github.com/onnx/models/blob/master/vision/classification/resnet/model/resnet50-v2-7.tar.gz
+    and untar the model into the workspace.
+    '''
+
+    # Dataset settings
+    model_path = "./resnet50-v2-7.onnx"
+    ilsvrc2012_dataset_path = "./ILSVRC2012"
+    augmented_model_path = "./augmented_model.onnx"
+    batch_size = 20
+    calibration_dataset_size = 1000  # Size of dataset for calibration
+
+    # INT8 calibration setting
+    calibration_table_generation_enable = True  # Enable/Disable INT8 calibration
+
+    # TensorRT EP INT8 settings
+    os.environ["ORT_TENSORRT_FP16_ENABLE"] = "1"  # Enable FP16 precision
+    os.environ["ORT_TENSORRT_INT8_ENABLE"] = "1"  # Enable INT8 precision
+    os.environ["ORT_TENSORRT_INT8_CALIBRATION_TABLE_NAME"] = "calibration.flatbuffers"  # Calibration table name
+    os.environ["ORT_TENSORRT_ENGINE_CACHE_ENABLE"] = "1"  # Enable engine caching
+    execution_provider = ["TensorrtExecutionProvider"]
+
+    # Convert static batch to dynamic batch
+    [new_model_path, input_name] = convert_model_batch_to_dynamic(model_path)
+
+    # Get calibration and prediction dataset size
+    [calibration_dataset_size, prediction_dataset_size] = get_dataset_size(ilsvrc2012_dataset_path, calibration_dataset_size, batch_size)
+
+    # Generate INT8 calibration table
+    if calibration_table_generation_enable:
+        data_reader = ImageNetDataReader(ilsvrc2012_dataset_path,
+                                         start_index=0,
+                                         end_index=calibration_dataset_size,
+                                         stride=calibration_dataset_size,
+                                         batch_size=batch_size,
+                                         model_path=augmented_model_path,
+                                         input_name=input_name)
+        # For TensorRT calibration, augment all FP32 tensors (empty op_types),
+        # disable ORT graph optimization and skip quantization parameter calculation.
+        calibration_cache = calibrate(new_model_path,
+                                      data_reader,
+                                      op_types=[],
+                                      providers=["CUDAExecutionProvider"],
+                                      ort_graph_optimization_enable=False,
+                                      quantization_params_calculation_enable=False)
+        write_calibration_table(calibration_cache)
+
+    # Run prediction in TensorRT EP
+    data_reader = ImageNetDataReader(ilsvrc2012_dataset_path,
+                                     start_index=calibration_dataset_size,
+                                     end_index=calibration_dataset_size + prediction_dataset_size,
+                                     stride=prediction_dataset_size,
+                                     batch_size=batch_size,
+                                     model_path=new_model_path,
+                                     input_name=input_name)
+    synset_id = data_reader.get_synset_id(ilsvrc2012_dataset_path, calibration_dataset_size, prediction_dataset_size)  # Generate synset id
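+
+    # Note: synset_id is a one-hot ground-truth matrix of shape
+    # (prediction_dataset_size, 1000); evaluate() compares it against the
+    # model's class scores via top_k_accuracy, e.g.:
+    #
+    #   truth = np.array([[0., 1.], [1., 0.]])     # labels 1 and 0
+    #   pred = np.array([[0.2, 0.8], [0.9, 0.1]])  # both predicted correctly
+    #   top_k_accuracy(truth, pred, k=1)           # -> 1.0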
+    evaluator = ImageClassificationEvaluator(new_model_path, synset_id, data_reader, providers=execution_provider)
+    evaluator.predict()
+    result = evaluator.get_result()
+    evaluator.evaluate(result)
diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py
index 7f008a1f590ae..24610903520ab 100644
--- a/onnxruntime/python/tools/quantization/calibrate.py
+++ b/onnxruntime/python/tools/quantization/calibrate.py
@@ -43,7 +43,6 @@ def __init__(self, model, data_reader: CalibrationDataReader, calibrate_op_types
         :param black_nodes: operator names that should not be quantized, default = ''
         :param white_nodes: operator names that force to be quantized, default = ''
         :param augmented_model_path: save augmented_model to this path
-
         '''
         if isinstance(model, string_types):
             self.model = onnx.load(model)
@@ -66,7 +65,7 @@ def set_data_reader(self, data_reader):
     def get_calibration_cache(self):
         return self.calibration_cache
 
-    def augment_graph(self, augment_all_ops=False):
+    def augment_graph(self):
         '''
         Adds ReduceMin and ReduceMax nodes to all quantization_candidates op type nodes in
         model and ensures their outputs are stored as part of the graph output
@@ -84,11 +83,8 @@
         tensors_to_calibrate = set()
 
         for node in model.graph.node:
-            if augment_all_ops:
-                should_be_calibrate = True
-            else:
-                should_be_calibrate = ((node.op_type in self.calibrate_op_types) and
-                                       (node.name not in self.black_nodes)) or (node.name in self.white_nodes)
+            should_be_calibrate = ((node.op_type in self.calibrate_op_types) and
+                                   (node.name not in self.black_nodes)) or (node.name in self.white_nodes) or ((not self.calibrate_op_types) and (node.name not in self.black_nodes))
             if should_be_calibrate:
                 for tensor_name in itertools.chain(node.input, node.output):
                     if tensor_name in value_infos.keys():
@@ -99,10 +95,10 @@
 
         # If augmenting all ops, it's possible that some nodes' input value are 0.
        # Can't reduce on dim with value of 0 if 'keepdims' is false, therefore set keepdims to 1.
-        if augment_all_ops:
-            keepdims_value = 1
-        else:
+        if self.calibrate_op_types:
             keepdims_value = 0
+        else:
+            keepdims_value = 1
 
         for tensor in tensors_to_calibrate:
             # Adding ReduceMin nodes
@@ -129,26 +125,28 @@
         return model
 
     #Using augmented outputs to generate inputs for quantization
-    def get_intermediate_outputs(self, calib_mode='naive', providers=None):
+    def get_intermediate_outputs(self, calib_mode='naive', providers=None, ort_graph_optimization_enable=True):
         '''
-        Gather intermediate model outputs after running inference
-        parameter calib_mode: type 'naive' gives (ReduceMin, ReduceMax) pairs
-                              for each augmented node across test data sets, where
-                              the first element is a minimum of all ReduceMin values
-                              and the second element is a maximum of all ReduceMax
-                              values;
-        :return: dictionary mapping: {added node names: (ReduceMin, ReduceMax) pairs }
+        Gather intermediate model outputs after running inference
+        parameter calib_mode: type 'naive' gives (ReduceMin, ReduceMax) pairs
+                              for each augmented node across test data sets, where
+                              the first element is a minimum of all ReduceMin values
+                              and the second element is a maximum of all ReduceMax
+                              values;
+        parameter providers: ONNX Runtime execution providers
+        parameter ort_graph_optimization_enable: enable all ONNX Runtime graph optimizations, default = True
+        :return: dictionary mapping: {added node names: (ReduceMin, ReduceMax) pairs }
         '''
         #conduct inference session and get intermediate outputs
-        if providers:
+        if ort_graph_optimization_enable:
+            session = onnxruntime.InferenceSession(self.augmented_model_path, None)
+        else:
             sess_options = onnxruntime.SessionOptions()
             sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL #ORT_ENABLE_BASIC
             session = onnxruntime.InferenceSession(self.augmented_model_path, sess_options=sess_options, providers=providers)
-        else:
-            session = onnxruntime.InferenceSession(self.augmented_model_path, None)
 
         #number of outputs in original model
         num_model_outputs = len(self.model.graph.output)
@@ -329,8 +327,8 @@ def calculate_calibration_data(model,
     calibrator = get_calibrator(model,
                                 calibration_data_reader,
                                 op_types_to_quantize,
+                                nodes_to_exclude,
                                 nodes_to_quantize,
-                                nodes_to_exclude,
                                 augmented_model_path=augmented_model_path)
 
     if not os.path.exists(augmented_model_path):
@@ -365,15 +363,21 @@ def calibrate(model,
               op_types=['Conv', 'MatMul'],
               black_nodes=[],
               white_nodes=[],
-              augmented_model_path='augmented_model.onnx'):
+              augmented_model_path='augmented_model.onnx',
+              providers=["CPUExecutionProvider"],
+              ort_graph_optimization_enable=True,
+              quantization_params_calculation_enable=True):
     '''
-    Given an onnx model, augment and run the augmented model on calibration data set, aggregate and calculate the quantization parameters.
+    Given an onnx model, augment and run the augmented model on a calibration data set, then aggregate and calculate the quantization parameters.
     :param model: ONNX model to calibrate. It can be a ModelProto or a model path
     :param data_reader: user implemented object to read in and preprocess calibration dataset
                         based on CalibrationDataReader interface
-    :param op_types: operator types to be calibrated and quantized, default = 'Conv,MatMul'
+    :param op_types: operator types to be calibrated and quantized, default = 'Conv,MatMul'. Empty means to quantize all FP32 tensors (except black_nodes)
     :param black_nodes: operator names that should not be quantized, default = ''
     :param white_nodes: operator names that force to be quantized, default = ''
     :param augmented_model_path: save augmented_model to this path
+    :param providers: execution providers to run calibration
+    :param ort_graph_optimization_enable: enable all ONNX Runtime graph optimizations, default = True
+    :param quantization_params_calculation_enable: enable quantization parameter calculation, default = True
     '''
     #1. initialize a calibrater
     calibrater = ONNXCalibrater(model, data_reader, op_types, black_nodes, white_nodes, augmented_model_path)
@@ -381,9 +385,11 @@ def calibrate(model,
     augmented_model = calibrater.augment_graph()
     onnx.save(augmented_model, augmented_model_path)
     #3. generate quantization thresholds
-    dict_for_quantization = calibrater.get_intermediate_outputs()
+    dict_for_quantization = calibrater.get_intermediate_outputs(providers=providers, ort_graph_optimization_enable=ort_graph_optimization_enable)
     #4. generate quantization parameters dict
-    quantization_params_dict = calibrater.calculate_quantization_params(dict_for_quantization)
-
+    quantization_params_dict = {}
+    if quantization_params_calculation_enable:
+        quantization_params_dict = calibrater.calculate_quantization_params(dict_for_quantization)
     print("Calibrated,quantized parameters calculated and returned.")
-    return quantization_params_dict
+
+    return quantization_params_dict if quantization_params_calculation_enable else dict_for_quantization
diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py
index 0c2d80ecbc4e8..2b3e430492a4f 100644
--- a/onnxruntime/python/tools/quantization/quant_utils.py
+++ b/onnxruntime/python/tools/quantization/quant_utils.py
@@ -197,7 +197,7 @@ def write_calibration_table(calibration_cache):
     import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable
     import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue
 
-    print(calibration_cache)
+    print("calibration cache: ", calibration_cache)
 
     with open("calibration.json", 'w') as file:
         file.write(json.dumps(calibration_cache)) # use `json.loads` to do the reverse
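
Usage sketch for the extended `calibrate` API, a minimal illustration only
(`model.onnx` and `MyDataReader` are hypothetical placeholders; the TensorRT
call mirrors e2e_tensorrt_resnet_example.py):

    from onnxruntime.quantization import calibrate, write_calibration_table

    dr = MyDataReader()

    # Default path (behavior unchanged by this patch): returns the
    # quantization parameters dict computed from the (min, max) ranges.
    params = calibrate("model.onnx", dr)

    # TensorRT path: augment all FP32 tensors (empty op_types), skip the
    # parameter calculation, and get back the raw {tensor: (min, max)}
    # dictionary expected by write_calibration_table.
    cache = calibrate("model.onnx", dr, op_types=[],
                      providers=["CUDAExecutionProvider"],
                      ort_graph_optimization_enable=False,
                      quantization_params_calculation_enable=False)
    write_calibration_table(cache)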