Skip to content

Commit

Permalink
Update TensorRT dynamic shape profile when input shape changed during runtime (microsoft#3904)
Browse files Browse the repository at this point in the history

* Update the dynamic shape range when the input shape changes at runtime

* Update tensorrt_execution_provider.cc

* Update tensorrt_execution_provider.cc
  • Loading branch information
stevenlix authored May 11, 2020
1 parent 6d2d927 commit 28f693a
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -695,9 +695,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
for (unsigned int i = 0, end = trt_network->getNbInputs(); i < end; ++i) {
auto input = trt_network->getInput(i);
nvinfer1::Dims dims = input->getDimensions();
nvinfer1::Dims dims_min = dims;
nvinfer1::Dims dims_opt = dims;
nvinfer1::Dims dims_max = dims;
nvinfer1::Dims dims_min(dims), dims_opt(dims), dims_max(dims);

int nb_dims = dims.nbDims;
if (input->isShapeTensor()) { // Shape tensor
Expand Down Expand Up @@ -851,16 +849,14 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
auto trt_builder = trt_state->builder;
nvinfer1::IOptimizationProfile* trt_profile = nullptr;
for (int i = 0, end = num_binding_inputs; i < end; ++i) {
// TODO: check if getInput indexing is same with binding index
auto input = trt_state->network->getInput(i);
nvinfer1::Dims dims = input->getDimensions();
nvinfer1::Dims dims_min = dims;
nvinfer1::Dims dims_opt = dims;
nvinfer1::Dims dims_max = dims;

// Check and update shape ranges for dynamic shape inputs
auto& shape_ranges = trt_state->input_shape_ranges;
if (shape_ranges.find(i) != shape_ranges.end()) {
// TODO: check if getInput indexing is same with binding index
auto input = trt_state->network->getInput(i);
nvinfer1::Dims dims = input->getDimensions();
nvinfer1::Dims dims_min(dims), dims_opt(dims), dims_max(dims);

const OrtValue* input_tensor = ort.KernelContext_GetInput(context, input_indexes[i]);
auto tensor_info = ort.GetTensorTypeAndShape(input_tensor);
const auto& tensor_shape = ort.GetTensorShape(tensor_info);
Expand All @@ -870,10 +866,15 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
for (int j = 0, end = nb_dims; j < end; ++j) {
auto& shape_range = shape_ranges[i];
if (shape_range.find(j) != shape_range.end()) {
dims_min.d[j] = shape_range[j].first;
dims_opt.d[j] = shape_range[j].second;
dims_max.d[j] = shape_range[j].second;

// Update minimum dimension
if (tensor_shape[j] < shape_range[j].first) {
shape_range[j].first = tensor_shape[j];
dims_min.d[j] = tensor_shape[j];
dims_opt.d[j] = tensor_shape[j];
dimension_update = true;
}
// Update maximum dimension
Expand Down

0 comments on commit 28f693a

Please sign in to comment.