Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions onnxruntime/contrib_ops/cpu/quadric/quadric_custom_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,13 @@ QuadricCustomOp::Info::Info(const onnxruntime::Node& node, const GraphViewer& su
subgraph_input_names.insert(input->Name());
}

// This is commented out because we include initializers as inputs to the custom op, but
// This is only an inequality because we include initializers as inputs to the custom op, but
// *NOT* the sub-graph. As a result, the number of inputs differs. Unfortunately, ORT doesn't do
// a great job of telling us whether something is truly an initializer or not, so we can't
// effectively check whether an input is an initializer or not.
/*ORT_ENFORCE(num_subgraph_inputs == static_cast<size_t>(num_inputs),
"'QuadricCustomOp' node has ", num_inputs, " inputs which doesn't match the subgraph's ",
ORT_ENFORCE(num_subgraph_inputs <= static_cast<size_t>(num_inputs),
"'QuadricCustomOp' node (", node.Name(), ") has ", num_inputs, " inputs which is fewer than the subgraph's ",
num_subgraph_inputs, " inputs.");
*/

auto& subgraph_outputs = subgraph.GetOutputs();
auto num_subgraph_outputs = subgraph_outputs.size();
Expand Down Expand Up @@ -226,23 +225,29 @@ Status QuadricCustomOp::SetupSubgraphExecutionInfo(const SessionState& session_s
info_ = std::make_unique<QuadricCustomOp::Info>(node, subgraph_session_state.GetGraphViewer());

const auto& subgraph_map = subgraph_session_state.GetOrtValueNameIdxMap();
auto num_subgraph_inputs = subgraph_session_state.GetGraphViewer().GetInputs().size();

std::vector<std::string> feed_names;

const auto& input_defs = node.InputDefs();
for (size_t i = 0, end = info_->num_inputs; i < end; ++i) {
for (size_t i = 0, end = num_subgraph_inputs; i < end; ++i) {
const auto* input = input_defs[i];
// Not all subgraph inputs will have names that correspond to the node's inputs. The inputs
// that diverge like this are limited *only* to initializers and we don't need to create
// feeds for them. Furthermore, since they are not actually used by the custom op (and
// not even by the sub-graph since the subgraph contains its own version of initializers)
// they end up getting removed from the graph during an optimization step and so we can't
// prove that it's an initializer using Graph::IsInitializedTensor

if (info_->subgraph_input_names.find(input->Name()) != info_->subgraph_input_names.end()) {
feed_names.push_back(input->Name());
info_->used_inputs[i] = true;
std::string input_name = input->Name();
// Strip-off any '/duplicated'
auto pos = input_name.find("/duplicated");
if (pos != std::string::npos) {
input_name = input_name.erase(pos);
}
ORT_ENFORCE(info_->subgraph_input_names.find(input_name) != info_->subgraph_input_names.end(),
"Could not match input ", input_name, " with any subgraph input.");
feed_names.push_back(input_name);
info_->used_inputs[i] = true;
}

std::unique_ptr<FeedsFetchesManager> ffm;
Expand Down