@@ -121,6 +121,26 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info)
121121 std::copy_if (orig_order.cbegin (), orig_order.cend (), std::back_inserter (nodes_in_topological_order_),
122122 [this ](NodeIndex idx) { return filtered_node_indices_.count (idx) != 0 ; });
123123
124+ // Filter the initializers also
125+ // Get the names of all the inputs and implicit inputs of all the nodes in this subgraph
126+ for (const auto node_idx : filtered_node_indices_) {
127+ const auto * node = GetNode (node_idx);
128+ ORT_ENFORCE (node, " Mismatch between Graph and IndexedSubGraph. Node not found: " , node_idx);
129+ const ONNX_NAMESPACE::TensorProto* tensor = nullptr ;
130+ for (const auto * node_input : node->InputDefs ()) {
131+ if (graph.GetInitializedTensor (node_input->Name (), tensor)) {
132+ filtered_initializers_.insert ({node_input->Name (), tensor});
133+ }
134+ }
135+
136+ // The implicit inputs for subgraphs (if any)
137+ for (const auto * node_input : node->ImplicitInputDefs ()) {
138+ if (graph.GetInitializedTensor (node_input->Name (), tensor)) {
139+ filtered_initializers_.insert ({node_input->Name (), tensor});
140+ }
141+ }
142+ }
143+
124144#if !defined(ORT_MINIMAL_BUILD)
125145 auto orig_priority_order = std::move (nodes_in_topological_order_with_priority_);
126146 nodes_in_topological_order_with_priority_.reserve (filter_info->nodes .size ());
@@ -146,6 +166,10 @@ const std::string& GraphViewer::Description() const noexcept {
146166
147167bool GraphViewer::GetInitializedTensor (const std::string& tensor_name,
148168 const ONNX_NAMESPACE::TensorProto*& value) const {
169+ // if we are using filtered subgraph, the initializer has to be part of the subgraph
170+ if (filter_info_ != nullptr && filtered_initializers_.find (tensor_name) == filtered_initializers_.cend ())
171+ return false ;
172+
149173 return graph_->GetInitializedTensor (tensor_name, value);
150174}
151175
@@ -220,7 +244,9 @@ const std::vector<NodeIndex>& GraphViewer::GetRootNodes() const {
220244}
221245
222246const InitializedTensorSet& GraphViewer::GetAllInitializedTensors () const noexcept {
223- return graph_->GetAllInitializedTensors ();
247+ return (filter_info_ == nullptr )
248+ ? graph_->GetAllInitializedTensors ()
249+ : filtered_initializers_;
224250}
225251
226252const NodeArg* GraphViewer::GetNodeArg (const std::string& name) const {
0 commit comments