Skip to content

Commit

Permalink
Update TensorRT parser (microsoft#3650)
Browse files Browse the repository at this point in the history
* update onnx-tensorrt submodule

* add more model dumping point

* update trt kernel name and docker readme file

* fix minior issues

* fix format issue

* update onnx-tensorrt submodule

Co-authored-by: stevenlix <stevenlix>
  • Loading branch information
stevenlix authored Apr 24, 2020
1 parent 939d036 commit 2ab78c5
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
2 changes: 1 addition & 1 deletion dockerfiles/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ Use `docker pull` with any of the images and tags below to pull an image and try
```

## TensorRT
**Ubuntu 18.04, CUDA 10.1.243, TensorRT 6.0.1**
**Ubuntu 18.04, CUDA 10.2, TensorRT 7.0.0**

1. Build the docker image from the Dockerfile in this repository.
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph

// Assign inputs and outputs to subgraph's meta_def
auto meta_def = onnxruntime::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>();
meta_def->name = "TRTKernel_" + std::to_string(kernels_index++);
meta_def->name = "TRTKernel_" + graph.Name() + "_" + std::to_string(kernels_index++);
meta_def->domain = kMSDomain;

for (const auto& input : inputs) {
Expand Down Expand Up @@ -454,6 +454,12 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
std::string string_buf;
model_proto.SerializeToString(&string_buf);

if (dump_subgraphs_) {
// Dump TensorRT subgraph for debugging if enabled via ORT_TENSORRT_DUMP_SUBGRAPHS env variable.
std::fstream dump("TensorrtExecutionProvider_TRT_Subgraph.onnx", std::ios::out | std::ios::trunc | std::ios::binary);
model_proto.SerializeToOstream(&dump);
}

// Get supported node list recursively
SubGraphCollection_t parser_nodes_list;
TensorrtLogger& trt_logger = GetTensorrtLogger();
Expand Down

0 comments on commit 2ab78c5

Please sign in to comment.