|
| 1 | +/* * Licensed to the Apache Software Foundation (ASF) under one |
| 2 | + * or more contributor license agreements. See the NOTICE file |
| 3 | + * distributed with this work for additional information |
| 4 | + * regarding copyright ownership. The ASF licenses this file |
| 5 | + * to you under the Apache License, Version 2.0 (the |
| 6 | + * "License"); you may not use this file except in compliance |
| 7 | + * with the License. You may obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, |
| 12 | + * software distributed under the License is distributed on an |
| 13 | + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | + * KIND, either express or implied. See the License for the |
| 15 | + * specific language governing permissions and limitations |
| 16 | + * under the License. |
| 17 | +
|
| 18 | + * file runtime/contrib/tensorrt/tensorrt_builder.h |
| 19 | + * brief Contains TensorRTBuilder class which can be used to convert a relay |
| 20 | + * program into a TRT engine which can be used for inference. |
| 21 | +*/ |
| 22 | + |
| 23 | +#ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_CALIBRATOR_H_ |
| 24 | +#define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_CALIBRATOR_H_ |
| 25 | + |
| 26 | +#include <string> |
| 27 | +#include <vector> |
| 28 | + |
| 29 | +#include "../../cuda/cuda_common.h" |
| 30 | +#include "NvInfer.h" |
| 31 | + |
| 32 | +namespace tvm { |
| 33 | +namespace runtime { |
| 34 | + |
| 35 | +class TensorRTCalibrator : public nvinfer1::IInt8EntropyCalibrator2 { |
| 36 | + public: |
| 37 | + TensorRTCalibrator(int batch_size, const std::vector<std::string>& input_names) |
| 38 | + : batch_size_(batch_size), num_batches_calibrated_(0), input_names_(input_names) {} |
| 39 | + |
| 40 | + ~TensorRTCalibrator() { |
| 41 | + // Free calibration data |
| 42 | + for (auto& inputs : data_) { |
| 43 | + for (size_t i = 0; i < inputs.size(); ++i) { |
| 44 | + delete[] inputs[i]; |
| 45 | + } |
| 46 | + } |
| 47 | + // Free buffers |
| 48 | + for (size_t i = 0; i < buffers_.size(); ++i) { |
| 49 | + CUDA_CALL(cudaFree(buffers_[i])); |
| 50 | + } |
| 51 | + } |
| 52 | + |
| 53 | + void AddBatchData(const std::vector<void*>& bindings, const std::vector<size_t>& binding_sizes) { |
| 54 | + // Copy data from GPU |
| 55 | + std::vector<float*> data_host(bindings.size(), nullptr); |
| 56 | + for (size_t i = 0; i < bindings.size(); ++i) { |
| 57 | + data_host[i] = new float[batch_size_ * binding_sizes[i]]; |
| 58 | + CUDA_CALL(cudaMemcpy(static_cast<void*>(data_host[i]), bindings[i], |
| 59 | + batch_size_ * binding_sizes[i] * sizeof(float), cudaMemcpyDeviceToHost)); |
| 60 | + } |
| 61 | + data_.push_back(data_host); |
| 62 | + data_sizes_.push_back(binding_sizes); |
| 63 | + } |
| 64 | + |
| 65 | + int getBatchSize() const override { return batch_size_; } |
| 66 | + |
| 67 | + /*! |
| 68 | + * \brief TensorRT will call this method to get next batch of data to |
| 69 | + * calibrate with. |
| 70 | + */ |
| 71 | + bool getBatch(void* bindings[], const char* names[], int nbBindings) override { |
| 72 | + AllocateBuffersIfNotAllocated(); |
| 73 | + CHECK_EQ(input_names_.size(), nbBindings); |
| 74 | + for (size_t i = 0; i < input_names_.size(); ++i) { |
| 75 | + CHECK_EQ(input_names_[i], names[i]); |
| 76 | + CUDA_CALL(cudaMemcpy(buffers_[i], data_[num_batches_calibrated_][i], |
| 77 | + batch_size_ * data_sizes_[num_batches_calibrated_][i] * sizeof(float), |
| 78 | + cudaMemcpyHostToDevice)); |
| 79 | + bindings[i] = buffers_[i]; |
| 80 | + } |
| 81 | + num_batches_calibrated_++; |
| 82 | + // TODO(trevmorr): Free data from previous batch? |
| 83 | + return (num_batches_calibrated_ < data_.size()); |
| 84 | + } |
| 85 | + |
| 86 | + const void* readCalibrationCache(size_t& length) override { |
| 87 | + if (calibration_cache_.empty()) return nullptr; |
| 88 | + length = calibration_cache_.size(); |
| 89 | + return calibration_cache_.data(); |
| 90 | + } |
| 91 | + |
| 92 | + void writeCalibrationCache(const void* cache, size_t length) override { |
| 93 | + calibration_cache_.assign(static_cast<const char*>(cache), length); |
| 94 | + } |
| 95 | + |
| 96 | + private: |
| 97 | + /*! \brief Batch size. */ |
| 98 | + int batch_size_; |
| 99 | + /*! \brief Number of batches already fed to calibrator. */ |
| 100 | + int num_batches_calibrated_; |
| 101 | + /*! \brief Storage for calibration cache. */ |
| 102 | + std::string calibration_cache_; |
| 103 | + |
| 104 | + /*! \brief Data to be used for calibration. */ |
| 105 | + std::vector<std::vector<float*>> data_; |
| 106 | + /*! \brief Number of elements for data to be used for calibration. */ |
| 107 | + std::vector<std::vector<size_t>> data_sizes_; |
| 108 | + |
| 109 | + /*! \brief Device buffers to be used for calibration. */ |
| 110 | + std::vector<void*> buffers_; |
| 111 | + |
| 112 | + /*! \brief Names of inputs */ |
| 113 | + const std::vector<std::string> input_names_; |
| 114 | + |
| 115 | + /*! \brief Allocate device memory buffers. data_sizes_ must already have one |
| 116 | + * entry. */ |
| 117 | + void AllocateBuffersIfNotAllocated() { |
| 118 | + if (!buffers_.empty()) return; |
| 119 | + CHECK_GE(data_sizes_.size(), 1); |
| 120 | + const int num_inputs = data_sizes_[0].size(); |
| 121 | + buffers_.assign(num_inputs, nullptr); |
| 122 | + for (int i = 0; i < num_inputs; ++i) { |
| 123 | + CUDA_CALL(cudaMalloc(&buffers_[i], data_sizes_[0][i] * sizeof(float))); |
| 124 | + } |
| 125 | + } |
| 126 | +}; |
| 127 | + |
| 128 | +} // namespace runtime |
| 129 | +} // namespace tvm |
| 130 | +#endif // TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_CALIBRATOR_H_ |
0 commit comments