Skip to content

Commit

Permalink
Fix TensorRT kernel conflict issue for subgraphs of control flow oper…
Browse files Browse the repository at this point in the history
…ators (microsoft#6115)

* add static subgraph kernel index

* change kernel naming to avoid conflicts
  • Loading branch information
stevenlix authored Dec 16, 2020
1 parent 0978d2b commit aa49e47
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ bool SetDynamicRange(nvinfer1::INetworkDefinition& network, std::unordered_map<s
}
return true;
}

} // namespace

namespace google {
Expand Down Expand Up @@ -500,7 +501,7 @@ void ToGraphProtoInternal(const GraphViewer& graph, Provider_GraphProto& graph_p
}
}

std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph_t graph_nodes_index, int& kernels_index, const GraphViewer& graph) const {
std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph_t graph_nodes_index, const GraphViewer& graph) const {
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder();
std::unordered_set<size_t> node_set;
node_set.reserve(graph_nodes_index.first.size());
Expand Down Expand Up @@ -605,7 +606,7 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
// Assign inputs and outputs to subgraph's meta_def
auto meta_def = IndexedSubGraph_MetaDef::Create();
const std::string graph_type = graph.IsSubgraph() ? "subgraph" : "graph";
meta_def->name() = "TRTKernel_" + graph_type + "_" + graph.Name() + "_" + std::to_string(kernels_index++);
meta_def->name() = "TRTKernel_" + graph_type + "_" + graph.Name() + "_" + std::to_string(subgraph_id_++);
meta_def->domain() = kMSDomain;

for (const auto& input : inputs) {
Expand Down Expand Up @@ -771,11 +772,11 @@ void TensorrtExecutionProvider::RemoveTensorRTGraphCycles(SubGraphCollection_t&
std::unordered_map<int, std::string> index_to_node_map;
std::unordered_map<std::string, std::unordered_set<std::string>> input_to_nodes_map, node_to_outputs_map;
std::unordered_set<int> non_trt_node_index(node_index.begin(), node_index.end());
int counter = 0, id = 0;
int id = 0;
for (const auto& group : supported_nodes_vector) {
if (!group.first.empty()) {
// Construct subgraph from node list
std::unique_ptr<IndexedSubGraph> sub_graph = GetSubGraph(group, counter, graph);
std::unique_ptr<IndexedSubGraph> sub_graph = GetSubGraph(group, graph);

// Create node to inputs/outputs/index maps
const auto& meta_def = sub_graph->GetMetaDef();
Expand Down Expand Up @@ -901,10 +902,10 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,

// Construct subgraph capability from node list
std::vector<std::unique_ptr<ComputeCapability>> result;
int counter = 0, number_of_trt_nodes = 0;
int number_of_trt_nodes = 0;
for (const auto& group : supported_nodes_vector) {
if (!group.first.empty()) {
std::unique_ptr<IndexedSubGraph> sub_graph = GetSubGraph(group, counter, graph);
std::unique_ptr<IndexedSubGraph> sub_graph = GetSubGraph(group, graph);
result.push_back(ComputeCapability::Create(std::move(sub_graph)));
number_of_trt_nodes += group.first.size();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,11 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool engine_cache_enable_ = false;
std::string cache_path_;
nvinfer1::IRuntime* runtime_ = nullptr;

OrtMutex tensorrt_mu_;
int device_id_;
AllocatorPtr allocator_;
mutable int subgraph_id_ = 0;

std::unordered_map<std::string, tensorrt_ptr::unique_pointer<nvonnxparser::IParser>> parsers_;
std::unordered_map<std::string, tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>> engines_;
std::unordered_map<std::string, tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>> contexts_;
Expand All @@ -139,7 +141,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
std::unordered_map<std::string, std::unordered_map<std::string, std::unordered_map<int, std::pair<int64_t, int64_t>>>> input_shape_ranges_;

/**Get IndexedSubGraph based on node list of the subgraph*/
std::unique_ptr<IndexedSubGraph> GetSubGraph(SubGraph_t graph_nodes_index, int& kernels_index,
std::unique_ptr<IndexedSubGraph> GetSubGraph(SubGraph_t graph_nodes_index,
const GraphViewer& graph) const;

/**
Expand All @@ -153,7 +155,5 @@ class TensorrtExecutionProvider : public IExecutionProvider {
const GraphViewer& graph, bool* early_termination) const;

void RemoveTensorRTGraphCycles(SubGraphCollection_t& supported_nodes_vector, const GraphViewer& graph) const;

AllocatorPtr allocator_;
};
} // namespace onnxruntime

0 comments on commit aa49e47

Please sign in to comment.