Refactor TensorRT EP code to better handle dynamic shape subgraphs #4504

Merged
merged 9 commits on Jul 15, 2020
add precision to trt node name
stevenlix committed Jul 15, 2020
commit 9a52a5dc2e7d0ef30f715715143a3e81b52c32d7
20 changes: 12 additions & 8 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -738,10 +738,6 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
auto trt_parser = tensorrt_ptr::unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
trt_parser->parse(string_buf.data(), string_buf.size());
trt_config->setMaxWorkspaceSize(max_workspace_size_);
- if (fp16_enable_ && trt_builder->platformHasFastFp16()) {
- trt_config->setFlag(nvinfer1::BuilderFlag::kFP16);
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled.";
- }

int num_inputs = trt_network->getNbInputs();
int num_outputs = trt_network->getNbOutputs();
@@ -773,12 +769,19 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
}
}

+ std::string trt_node_name_with_precision = fused_node->Name();
+ if (fp16_enable_ && trt_builder->platformHasFastFp16()) {
+ trt_config->setFlag(nvinfer1::BuilderFlag::kFP16);
+ trt_node_name_with_precision += "_fp16";
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled.";
+ }

// Build TRT engine here if the graph doesn't have dynamic shape input. Otherwise engine will
// be built at runtime
tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine> trt_engine;
tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext> trt_context;
if (!has_dynamic_shape) {
- std::ifstream planFile(GetEnginePath(engine_cache_path_, fused_node->Name()), std::ios::binary | std::ios::in);
+ std::ifstream planFile(GetEnginePath(engine_cache_path_, trt_node_name_with_precision), std::ios::binary | std::ios::in);
if (planFile && engine_cache_enable_) {
planFile.seekg(0, std::ios::end);
int engine_size = planFile.tellg();
@@ -787,7 +790,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
planFile.read((char*)engine_buf.get(), engine_size);
planFile.close();
trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + GetEnginePath(engine_cache_path_, fused_node->Name());
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + GetEnginePath(engine_cache_path_, trt_node_name_with_precision);
} else {
trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(trt_builder->buildEngineWithConfig(*trt_network, *trt_config));
if (trt_engine == nullptr) {
Expand All @@ -797,10 +800,10 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:

if (engine_cache_enable_) {
nvinfer1::IHostMemory* serializedModel = trt_engine->serialize();
- std::ofstream file(GetEnginePath(engine_cache_path_, fused_node->Name()), std::ios::binary | std::ios::out);
+ std::ofstream file(GetEnginePath(engine_cache_path_, trt_node_name_with_precision), std::ios::binary | std::ios::out);
file.write(reinterpret_cast<char*>(serializedModel->data()), serializedModel->size());
serializedModel->destroy();
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + GetEnginePath(engine_cache_path_, fused_node->Name());
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + GetEnginePath(engine_cache_path_, trt_node_name_with_precision);
}
}
trt_context = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext());
@@ -1310,6 +1313,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
}
}

+ cudaDeviceSynchronize();
for (const auto& binding_index : binding_buffers_to_freeup) {
Member:

Are we leaking memory if the enqueueV2() fails above?

Contributor Author:

For a FAIL status, would the session run quit or continue to try other EPs? If the latter, we may need to free the buffers when a FAIL occurs.

Member:

The session doesn't get destroyed automatically.
Can enqueue failure be intermittent? (Can there be a success after failure?)
In any case, a user could issue Run() again, or do the fallback manually (create a new session), so it does seem like it could cause a leak.

cudaFree(buffers[binding_index]);
}