Upgrade to CUDA10.2 for TensorRT (microsoft#3084)
* Switch to CUDA10.2

* Update win-gpu-tensorrt-ci-pipeline.yml

* Update win-gpu-tensorrt-ci-pipeline.yml

* remove dynamic_shape

* update onnx-tensorrt submodule

* check if input shape is specified for TensorRT subgraph input and enable some TensorRT unit tests

* fix format issue

* add shape inference instruction for TensorRT

* update according to the reviews

* Update win-gpu-tensorrt-ci-pipeline.yml
stevenlix authored Feb 25, 2020
1 parent d7f2cdc commit f4a5d17
Showing 6 changed files with 65 additions and 65 deletions.
6 changes: 3 additions & 3 deletions .gitmodules
@@ -46,9 +46,9 @@
[submodule "cmake/external/FeaturizersLibrary"]
path = cmake/external/FeaturizersLibrary
url = https://github.com/microsoft/FeaturizersLibrary.git
[submodule "cmake/external/onnx-tensorrt"]
path = cmake/external/onnx-tensorrt
url = https://github.com/stevenlix/onnx-tensorrt.git
[submodule "cmake/external/SafeInt/safeint"]
path = cmake/external/SafeInt/safeint
url = https://github.com/dcleblanc/SafeInt.git
[submodule "cmake/external/onnx-tensorrt"]
path = cmake/external/onnx-tensorrt
url = https://github.com/stevenlix/onnx-tensorrt.git
2 changes: 1 addition & 1 deletion cmake/external/onnx-tensorrt
5 changes: 4 additions & 1 deletion docs/execution_providers/TensorRT-ExecutionProvider.md
@@ -19,8 +19,11 @@ status = session_object.Load(model_file_name);
```
The C API details are [here](../C_API.md#c-api).

#### Shape Inference for TensorRT Subgraphs
If some operators in the model are not supported by TensorRT, ONNX Runtime partitions the graph and only sends supported subgraphs to the TensorRT execution provider. Because TensorRT requires that all inputs of these subgraphs have their shapes specified, ONNX Runtime throws an error if no input shape information is available. In this case, please run shape inference on the entire model first using the script [here](https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py).
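A minimal sketch of that workflow (not part of this commit): detect graph inputs without shape information and run the symbolic shape inference script before creating the session. The script path and its `--input`/`--output` flags are assumptions based on the script's usage and should be checked against your checkout.

```python
# Hypothetical helper: make sure every graph input has a shape before handing
# the model to a TensorRT-enabled onnxruntime build.
import subprocess
import onnx

model_path = "model.onnx"            # assumed input model path
inferred_path = "model_shaped.onnx"  # assumed output path

model = onnx.load(model_path)
missing = [i.name for i in model.graph.input
           if not i.type.tensor_type.HasField("shape")]
if missing:
    print("Inputs without shape info:", missing)
    # Invoke the symbolic shape inference script referenced above
    # (script location and flags assumed from its command-line usage).
    subprocess.check_call([
        "python",
        "onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py",
        "--input", model_path,
        "--output", inferred_path,
    ])
```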

#### Sample
To run Faster R-CNN model on TensorRT execution provider,
This example shows how to run the Faster R-CNN model on the TensorRT execution provider.

First, download Faster R-CNN onnx model from onnx model zoo [here](https://github.com/onnx/models/tree/master/vision/object_detection_segmentation/faster-rcnn).

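The sketch below (not from the original sample) shows one way to load the downloaded model with a TensorRT-enabled build of onnxruntime; the file name and the random input are assumptions, and real use requires the pre- and post-processing shown in the full sample in this document.

```python
import numpy as np
import onnxruntime as ort

# With a TensorRT-enabled build, the TensorRT execution provider handles the
# supported subgraphs and the remaining nodes fall back to CUDA/CPU.
sess = ort.InferenceSession("FasterRCNN-10.onnx")  # model zoo file name assumed

# Faster R-CNN takes a single 3xHxW float32 image; a random tensor only
# exercises engine creation and will not produce meaningful detections.
image = np.random.rand(3, 800, 800).astype(np.float32)
outputs = sess.run(None, {sess.get_inputs()[0].name: image})
print([o.shape for o in outputs])
```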
92 changes: 48 additions & 44 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -116,7 +116,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
const std::string fp16_enable_env = env_instance.GetEnvironmentVar(tensorrt_env_vars::kFP16Enable);
if (!fp16_enable_env.empty()) {
fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true);
}
}
}

TensorrtExecutionProvider::~TensorrtExecutionProvider() {}
@@ -184,8 +184,8 @@ bool FindCycleHelper(int i, const std::list<int>* adjacency_map,
// Remove nodes with empty shape (for example [1, 0]) because TensorRT 7 doesn't support empty shape
SubGraphCollection_t RemoveEmptyShapeNodes(const onnxruntime::GraphViewer& graph) {
// Here only NonZero and NonMaxSuppression related empty shape nodes are removed, particularly for Faster-rcnn and Mask-rcnn models.
// TODO: Remove the code if TensorRT fixed the issue in the future release, or find a better generic way here to work around
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder();
// TODO: Remove the code if TensorRT fixed the issue in the future release, or find a better generic way here to work around
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder();
const std::string exclude_dim_name1 = "NonZero";
const std::string exclude_dim_name2 = "NonMaxSuppression";
SubGraphCollection_t parser_nodes_vector = {{{}, false}};
@@ -202,8 +202,8 @@ SubGraphCollection_t RemoveEmptyShapeNodes(const onnxruntime::GraphViewer& graph
std::string dim_name = dim.dim_param();
if (!dim_name.empty()) {
if ((dim_name.find(exclude_dim_name1) != std::string::npos) || (dim_name.find(exclude_dim_name2) != std::string::npos)) {
exclude_node = true;
break;
exclude_node = true;
break;
}
}
}
@@ -216,9 +216,9 @@ SubGraphCollection_t RemoveEmptyShapeNodes(const onnxruntime::GraphViewer& graph

// Remove the node with empty input shape
if (!exclude_node) {
parser_nodes_vector.back().first.push_back(index);
parser_nodes_vector.back().first.push_back(index);
} else if (!parser_nodes_vector.back().first.empty()) {
parser_nodes_vector.push_back({{},false});
parser_nodes_vector.push_back({{}, false});
}
}

@@ -407,6 +407,18 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect

ORT_ENFORCE(graph_build.Resolve().IsOK());

// Check if input tensors have shapes
if (iterations > 1) {
for (const auto* input_arg : graph_build.GetInputs()) {
if (input_arg->Shape() == nullptr) {
ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"TensorRT input: " + input_arg->Name() + " has no shape specified. " +
"Please run shape inference on the onnx model first. Details can be found in " +
"https://github.com/microsoft/onnxruntime/blob/master/docs/execution_providers/TensorRT-ExecutionProvider.md#shape-inference-for-tensorrt-subgraphs"));
}
}
}

// Serialize modelproto to string
const onnxruntime::GraphViewer graph_viewer(graph_build);

@@ -453,10 +465,10 @@ void TensorrtExecutionProvider::RemoveTensorRTGraphCycles(SubGraphCollection_t&
std::unordered_map<int, std::string> index_to_node_map;
std::unordered_map<std::string, std::unordered_set<std::string>> input_to_nodes_map, node_to_outputs_map;
std::unordered_set<int> non_trt_node_index(node_index.begin(), node_index.end());
int counter = 0, id = 0;
int counter = 0, id = 0;
for (const auto& group : supported_nodes_vector) {
if (!group.first.empty()) {
// Construct subgraph from node list
// Construct subgraph from node list
std::unique_ptr<IndexedSubGraph> sub_graph = GetSubGraph(group, counter, graph);

// Create node to inputs/outputs/index maps
@@ -468,23 +480,23 @@ void TensorrtExecutionProvider::RemoveTensorRTGraphCycles(SubGraphCollection_t&
}

if (meta_def != nullptr) {
for (const auto& input: meta_def->inputs) {
for (const auto& input : meta_def->inputs) {
input_to_nodes_map[input].insert(node_name);
}
for (const auto& output: meta_def->outputs) {
for (const auto& output : meta_def->outputs) {
node_to_outputs_map[node_name].insert(output);
}
}

// Remove TensorRT nodes from node index list
for (const auto& index: group.first) {
for (const auto& index : group.first) {
non_trt_node_index.erase(node_index[index]);
}
}
}

// Add non TensorRT nodes to the maps
for (const auto& index: non_trt_node_index) {
for (const auto& index : non_trt_node_index) {
const auto& node = graph.GetNode(index);
std::string node_name = node->Name();
if (node_to_index_map.find(node_name) == node_to_index_map.end()) {
@@ -503,13 +515,13 @@ void TensorrtExecutionProvider::RemoveTensorRTGraphCycles(SubGraphCollection_t&

// Create adjacency list
int graph_size = node_to_index_map.size();
std::list<int> *adjacency_map = new std::list<int>[graph_size];
for (const auto& node: node_to_outputs_map) {
std::list<int>* adjacency_map = new std::list<int>[graph_size];
for (const auto& node : node_to_outputs_map) {
for (auto iter = node.second.begin(); iter != node.second.end(); ++iter) {
const auto& loc = input_to_nodes_map.find(*iter);
if (loc != input_to_nodes_map.end()) {
int parent_node_index = node_to_index_map.find(node.first)->second;
for (auto child_node: loc->second) {
for (auto child_node : loc->second) {
int child_node_index = node_to_index_map.find(child_node)->second;
adjacency_map[parent_node_index].push_back(child_node_index);
}
@@ -518,8 +530,8 @@ void TensorrtExecutionProvider::RemoveTensorRTGraphCycles(SubGraphCollection_t&
}

// Check cycle in the graph
bool *visited = new bool[graph_size];
bool *st = new bool[graph_size];
bool* visited = new bool[graph_size];
bool* st = new bool[graph_size];
for (int i = 0; i < graph_size; ++i) {
visited[i] = false;
st[i] = false;
@@ -529,19 +541,19 @@ void TensorrtExecutionProvider::RemoveTensorRTGraphCycles(SubGraphCollection_t&
bool has_cycle = false;
for (int i = 0; i < graph_size; ++i) {
if (FindCycleHelper(i, adjacency_map, visited, st, cycles)) {
has_cycle = true;
break;
has_cycle = true;
break;
}
}

// Remove TensorRT subgraph from the supported node list if it's part of the cycle
// Remove TensorRT subgraph from the supported node list if it's part of the cycle
if (has_cycle) {
for (int i = 0; i < static_cast<int>(cycles.size()); ++i) {
auto loc = index_to_node_map.find(cycles[i]);
if (loc != index_to_node_map.end() && loc->second.find("TRTKernel") != std::string::npos) {
int trt_node_index = std::stoi(loc->second.substr(10));
supported_nodes_vector.erase(supported_nodes_vector.begin() + trt_node_index);
trt_cycle = true;
trt_cycle = true;
break;
}
}
@@ -587,7 +599,7 @@ TensorrtExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
result.push_back(onnxruntime::make_unique<ComputeCapability>(std::move(sub_graph)));
}
}

return result;
}

@@ -660,19 +672,23 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
trt_profile->setShapeValues(input->getName(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], nb_dims);
trt_profile->setShapeValues(input->getName(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], nb_dims);
} else { // Execution tensor
bool is_dynamic_shape = false;
for (int j = 0, end = nb_dims; j < end; ++j) {
// For dynamic shape subgraph, a dummy engine is created at compile phase.
// Real engine will be created at compute phase based on input data
if (dims.d[j] == -1) { // Dynamic shape
dims_min.d[j] = 1;
dims_opt.d[j] = 1;
dims_max.d[j] = 1;
is_dynamic_shape = true;
}
}
// TRT6: Optimization profile need to be provided for all inputs if any of them has dynamic shape
trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMIN, dims_min);
trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kOPT, dims_opt);
trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMAX, dims_max);

if (is_dynamic_shape) {
trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMIN, dims_min);
trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kOPT, dims_opt);
trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMAX, dims_max);
}
}
}

@@ -764,8 +780,8 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
*p = {context->allocate_func, context->release_func, context->allocator_handle, parsers_[context->node_name].get(),
engines_[context->node_name].get(), contexts_[context->node_name].get(), builders_[context->node_name].get(),
networks_[context->node_name].get(), input_info_[context->node_name], output_info_[context->node_name],
input_shape_ranges_[context->node_name], output_shapes_[context->node_name], &tensorrt_mu_, &fp16_enable_,
&max_workspace_size_};
input_shape_ranges_[context->node_name], output_shapes_[context->node_name], &tensorrt_mu_, &fp16_enable_,
&max_workspace_size_};
*state = p.release();
return 0;
};
@@ -790,14 +806,9 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
int total_bindings = num_binding_inputs + num_binding_outputs;
std::vector<void*> buffers(total_bindings);

bool dynamic_shape = false;
auto trt_context = trt_state->context;
if (!trt_context->allInputDimensionsSpecified() || !trt_context->allInputShapesSpecified()) {
dynamic_shape = true;
}

// Update shape ranges
bool dimension_update = false;
auto trt_context = trt_state->context;
auto trt_builder = trt_state->builder;
nvinfer1::IOptimizationProfile* trt_profile = nullptr;
for (int i = 0, end = num_binding_inputs; i < end; ++i) {
@@ -857,20 +868,13 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
}
}
}

// TensorRT6 requires optimization profile to be defined for all inputs if any input dimension is symbolic
if (dimension_update && dynamic_shape) {
trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMIN, dims_min);
trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kOPT, dims_opt);
trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMAX, dims_max);
}
}

// Regenerate engine and context
// Only one profile is generated, so no need to explicitly set optimization profile
if (dimension_update) {
auto trt_config = unique_pointer<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr));
trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr));
trt_config->addOptimizationProfile(trt_profile);
if (*(trt_state->fp16_enable_ptr) && trt_builder->platformHasFastFp16()) {
trt_config->setFlag(nvinfer1::BuilderFlag::kFP16);
@@ -985,4 +989,4 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:

return Status::OK();
}
} // namespace onnxruntime
} // namespace onnxruntime
21 changes: 7 additions & 14 deletions onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc
@@ -35,8 +35,7 @@ TEST(TensorOpTest, SpaceToDepthTest_1) {
1.1f, 1.3f,
3.1f, 3.3f};
test.AddOutput<float>("output", {N, C * blocksize * blocksize, H / blocksize, W / blocksize}, result);
// TensorRT has error: Expected output shape [{1,8,1,2}] did not match run output shape [{8,1,1,2}] for output
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
test.Run();
}

TEST(TensorOpTest, SpaceToDepthTest_2) {
@@ -70,8 +69,7 @@ TEST(TensorOpTest, SpaceToDepthTest_2) {
98., 101., 66., 69., 84., 87., 102., 105., 67., 70., 85.,
88., 103., 106., 68., 71., 86., 89., 104., 107.};
test.AddOutput<float>("output", {2, 27, 1, 2}, result);
// TensorRT has error: Expected output shape [{2,27,1,2}] did not match run output shape [{54,1,1,2}] for output
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
test.Run();
}

TEST(TensorOpTest, DepthToSpaceTest_1) {
@@ -102,8 +100,7 @@ TEST(TensorOpTest, DepthToSpaceTest_1) {
2.0f, 2.1f, 2.2f, 2.3f,
3.0f, 3.1f, 3.2f, 3.3f};
test.AddOutput<float>("output", {N, C / (blocksize * blocksize), H * blocksize, W * blocksize}, result);
// TensorRT output shape mismatches
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
test.Run();
}

TEST(TensorOpTest, DepthToSpaceTest_2) {
@@ -146,8 +143,7 @@ TEST(TensorOpTest, DepthToSpaceTest_2) {
122., 140., 123., 141., 88., 106., 89., 107., 124., 142., 125.,
143.};
test.AddOutput<float>("output", {2, 3, 6, 4}, result);
// TensorRT output shape mismatches
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
test.Run();
}

TEST(TensorOpTest, DepthToSpaceTest_3) {
@@ -190,8 +186,7 @@ TEST(TensorOpTest, DepthToSpaceTest_3) {
122., 140., 123., 141., 88., 106., 89., 107., 124., 142., 125.,
143.};
test.AddOutput<float>("output", {2, 3, 6, 4}, result);
// TensorRT output shape mismatches
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
test.Run();
}

TEST(TensorOpTest, DepthToSpaceTest_4) {
@@ -235,8 +230,7 @@ TEST(TensorOpTest, DepthToSpaceTest_4) {
122., 140., 123., 141., 88., 106., 89., 107., 124., 142., 125.,
143.};
test.AddOutput<float>("output", {2, 3, 6, 4}, result);
// TensorRT output shape mismatches
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
test.Run();
}

TEST(TensorOpTest, DepthToSpaceTest_5) {
@@ -263,8 +257,7 @@ TEST(TensorOpTest, DepthToSpaceTest_5) {
21., 30., 22., 31., 23., 32.};

test.AddOutput<float>("output", {1, 1, 4, 6}, result);
// TensorRT output shape mismatches
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
test.Run();
}

} // namespace test
win-gpu-tensorrt-ci-pipeline.yml
@@ -42,7 +42,7 @@ jobs:
displayName: 'Generate cmake config'
inputs:
scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py'
arguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 16 2019" --msvc_toolset 14.16 --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="C:\local\TensorRT-7.0.0.11.cuda-10.0.cudnn7.6\TensorRT-7.0.0.11" --cuda_version=10.0 --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0" --cudnn_home="C:\local\cudnn-10.0-windows10-x64-v7.6.5.32\cuda" --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0'
arguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 16 2019" --msvc_toolset 14.16 --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="C:\local\TensorRT-7.0.0.11.cuda-10.2.cudnn7.6\TensorRT-7.0.0.11" --cuda_version=10.2 --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.2" --cudnn_home="C:\local\cudnn-10.2-windows10-x64-v7.6.5.32\cuda" --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0'
workingDirectory: '$(Build.BinariesDirectory)'

- task: VSBuild@1
@@ -81,7 +81,7 @@ jobs:
del wheel_filename_file
python.exe -m pip install -q --upgrade %WHEEL_FILENAME%
set PATH=$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig);%PATH%
python $(Build.SourcesDirectory)\tools\ci_build\build.py --config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 16 2019" --msvc_toolset 14.16 --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="C:\local\TensorRT-7.0.0.11.cuda-10.0.cudnn7.6\TensorRT-7.0.0.11" --cuda_version=10.0 --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0" --cudnn_home="C:\local\cudnn-10.0-windows10-x64-v7.6.5.32\cuda" --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0
python $(Build.SourcesDirectory)\tools\ci_build\build.py --config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 16 2019" --msvc_toolset 14.16 --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="C:\local\TensorRT-7.0.0.11.cuda-10.2.cudnn7.6\TensorRT-7.0.0.11" --cuda_version=10.2 --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.2" --cudnn_home="C:\local\cudnn-10.2-windows10-x64-v7.6.5.32\cuda" --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0
workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)'
displayName: 'Run tests'