Skip to content

Commit cc6e8fb

Browse files
authored
Filter initializers for GraphViewer with IndexedSubGraph (#5884)
* fix filtered subgraph initializer issue * minor fix * Inlcude implicit input of nodes to see if they are initializers * Add test case * minor update * Address PR comments * Fix some code error
1 parent ba739a8 commit cc6e8fb

File tree

3 files changed

+39
-4
lines changed

3 files changed

+39
-4
lines changed

include/onnxruntime/core/graph/graph_viewer.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,15 @@ class GraphViewer {
9595
*/
9696
const ConstGraphNodes& Nodes() const noexcept;
9797

98-
/** Gets the number of valid nodes in the Graph.
98+
/** Gets the number of valid nodes in the Graph.
9999
@remarks Returns the number of nodes in filter_info_ if set.
100100
*/
101101
int NumberOfNodes() const noexcept;
102102

103103
/** Gets the maximum NodeIndex value used by Nodes in the Graph. */
104104
int MaxNodeIndex() const noexcept;
105105

106-
/** Gets the NodeIndex values for the Graph nodes, sorted into topological order.
106+
/** Gets the NodeIndex values for the Graph nodes, sorted into topological order.
107107
@remarks Filtered using filter_info_ if set.
108108
*/
109109
const std::vector<NodeIndex>& GetNodesInTopologicalOrder(ExecutionOrder order = ExecutionOrder::DEFAULT) const;
@@ -138,7 +138,7 @@ class GraphViewer {
138138

139139
/**
140140
returns true if 'name' is an initializer, and is constant and cannot be overridden at runtime.
141-
@param check_outer_scope If true and the 'graph_' is a subgraph, check parent graph/s for 'name'
141+
@param check_outer_scope If true and the 'graph_' is a subgraph, check parent graph/s for 'name'
142142
if the name is not found in 'graph_'.
143143
*/
144144
bool IsConstantInitializer(const std::string& name, bool check_outer_scope) const;
@@ -188,5 +188,6 @@ class GraphViewer {
188188
std::vector<const NodeArg*> filtered_node_inputs_;
189189
std::vector<const NodeArg*> filtered_node_inputs_including_initializers_;
190190
std::vector<const NodeArg*> filtered_node_outputs_;
191+
InitializedTensorSet filtered_initializers_;
191192
};
192193
} // namespace onnxruntime

onnxruntime/core/graph/graph_viewer.cc

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,26 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info)
121121
std::copy_if(orig_order.cbegin(), orig_order.cend(), std::back_inserter(nodes_in_topological_order_),
122122
[this](NodeIndex idx) { return filtered_node_indices_.count(idx) != 0; });
123123

124+
// Filter the initializers also
125+
// Get the names of all the inputs and implicit inputs of all the nodes in this subgraph
126+
for (const auto node_idx : filtered_node_indices_) {
127+
const auto* node = GetNode(node_idx);
128+
ORT_ENFORCE(node, "Mismatch between Graph and IndexedSubGraph. Node not found: ", node_idx);
129+
const ONNX_NAMESPACE::TensorProto* tensor = nullptr;
130+
for (const auto* node_input : node->InputDefs()) {
131+
if (graph.GetInitializedTensor(node_input->Name(), tensor)) {
132+
filtered_initializers_.insert({node_input->Name(), tensor});
133+
}
134+
}
135+
136+
// The implicit inputs for subgraphs (if any)
137+
for (const auto* node_input : node->ImplicitInputDefs()) {
138+
if (graph.GetInitializedTensor(node_input->Name(), tensor)) {
139+
filtered_initializers_.insert({node_input->Name(), tensor});
140+
}
141+
}
142+
}
143+
124144
#if !defined(ORT_MINIMAL_BUILD)
125145
auto orig_priority_order = std::move(nodes_in_topological_order_with_priority_);
126146
nodes_in_topological_order_with_priority_.reserve(filter_info->nodes.size());
@@ -146,6 +166,10 @@ const std::string& GraphViewer::Description() const noexcept {
146166

147167
bool GraphViewer::GetInitializedTensor(const std::string& tensor_name,
148168
const ONNX_NAMESPACE::TensorProto*& value) const {
169+
// if we are using filtered subgraph, the initializer has to be part of the subgraph
170+
if (filter_info_ != nullptr && filtered_initializers_.find(tensor_name) == filtered_initializers_.cend())
171+
return false;
172+
149173
return graph_->GetInitializedTensor(tensor_name, value);
150174
}
151175

@@ -220,7 +244,9 @@ const std::vector<NodeIndex>& GraphViewer::GetRootNodes() const {
220244
}
221245

222246
const InitializedTensorSet& GraphViewer::GetAllInitializedTensors() const noexcept {
223-
return graph_->GetAllInitializedTensors();
247+
return (filter_info_ == nullptr)
248+
? graph_->GetAllInitializedTensors()
249+
: filtered_initializers_;
224250
}
225251

226252
const NodeArg* GraphViewer::GetNodeArg(const std::string& name) const {

onnxruntime/test/ir/graph_viewer_test.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,14 @@ TEST(GraphViewer, FilteredGraph) {
9292
EXPECT_EQ(viewer.GetOutputs().size(), final_metadef->outputs.size());
9393
EXPECT_EQ(viewer.IsSubgraph(), false)
9494
<< "GraphViewer is for a filtered set of nodes of a single graph and not a nested subgraph";
95+
96+
// Verify the viewer's initializers are filtered as well
97+
const auto& viewer_initializers = viewer.GetAllInitializedTensors();
98+
EXPECT_EQ(viewer_initializers.size(), initializers.size());
99+
// We should have less initializers in the viewer than the underlying graph
100+
EXPECT_LT(viewer_initializers.size(), graph.GetAllInitializedTensors().size());
101+
// Pick a initializers which is not in the viewer, and check it is not part of the viewer's initializers
102+
EXPECT_TRUE(viewer_initializers.count("Constant15770PastValue16469") == 0);
95103
}
96104
} // namespace test
97105
} // namespace onnxruntime

0 commit comments

Comments
 (0)