diff --git a/README.md b/README.md
index 051363ce6..b587bdc52 100644
--- a/README.md
+++ b/README.md
@@ -45,6 +45,10 @@ paddle2onnx --model_dir saved_inference_model \
             --model_filename model.pdmodel \
             --params_filename model.pdiparams \
             --save_file model.onnx
+paddle2onnx --model_dir ch_ppstructure_mobile_v2.0_SLANet_infer \
+            --model_filename inference.pdmodel \
+            --params_filename inference.pdiparams \
+            --save_file inference.onnx
 ```
 
 The adjustable conversion parameters are listed in the table below:
diff --git a/paddle2onnx/mapper/exporter.cc b/paddle2onnx/mapper/exporter.cc
index 474b693f9..932dc7124 100644
--- a/paddle2onnx/mapper/exporter.cc
+++ b/paddle2onnx/mapper/exporter.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -36,6 +36,38 @@ namespace paddle2onnx {
 MapperHelper *MapperHelper::helper = nullptr;
 int32_t OnnxHelper::opset_version = 7;
 
+bool ModelExporter::IsWhileSupported(const PaddleParser &parser,
+                                     const int64_t &block_id,
+                                     const int64_t &op_id) {
+  auto x_info = parser.GetOpInput(block_id, op_id, "X");
+  auto out_info = parser.GetOpOutput(block_id, op_id, "Out");
+  auto cond_info = parser.GetOpInput(block_id, op_id, "Condition");
+  std::set<std::string> input_names;
+  for (size_t i = 0; i < x_info.size(); ++i) {
+    input_names.insert(x_info[i].name);
+  }
+  input_names.insert(cond_info[0].name);
+
+  for (size_t i = 0; i < out_info.size(); ++i) {
+    auto iter = input_names.find(out_info[i].name);
+    if (iter == input_names.end()) {
+      P2OLogger() << "Cannot find output:" << out_info[i].name
+                  << " in input tensors while converting operator 'while', "
+                     "Paddle2ONNX doesn't support this situation now."
+                  << std::endl;
+      return false;
+    }
+  }
+
+  for (size_t i = 0; i < x_info.size(); ++i) {
+    if (x_info[i].is_tensor_array) {
+      P2OLogger() << "LodTensorArray is not supported." << std::endl;
+      return false;
+    }
+  }
+  return true;
+}
+
 bool ModelExporter::IsOpsRegistered(const PaddleParser &parser,
                                     bool enable_experimental_op) {
   OnnxHelper temp_helper;
@@ -45,21 +77,16 @@ bool ModelExporter::IsOpsRegistered(const PaddleParser &parser,
       auto op = parser.GetOpDesc(i, j);
       if (op.type() == "feed" || op.type() == "fetch") {
         continue;
-      }
-
-      if (op.type() == "conditional_block" || op.type() == "select_input") {
+      } else if (op.type() == "conditional_block" ||
+                 op.type() == "select_input") {
         continue;
-      }
-#if 0
-      if (op.type() == "while" && enable_experimental_op)
-      {
-        if (!IsLoopSupported(parser, i, j))
-        {
-          unsupported_ops.insert("while");
-        }
-        continue;
+      } else if (op.type() == "while" && enable_experimental_op) {
+        if (!IsWhileSupported(parser, i, j)) {
+          unsupported_ops.insert("while");
+        }
+        continue;
       }
-#endif
+
       if (custom_ops.find(op.type()) != custom_ops.end()) {
         continue;
       }
@@ -107,28 +134,23 @@ int32_t ModelExporter::GetMinOpsetVersion(const PaddleParser &parser) {
       }
 
       int current_opset = 7;
-
       if (op.type() == "select_input") {
         P2OLogger() << "Detected there's control flow "
                        "op('conditional_block/select_input') in your model, "
                     << "this requires the minimal opset version of 11."
                     << std::endl;
         current_opset = 11;
+      } else if (op.type() == "while") {
+        P2OLogger()
+            << "Detected there's control flow 'while' op in your model, "
+            << "this requires the minimal opset version of 13." << std::endl;
+        current_opset = 13;
       } else {
         auto mapper =
             MapperHelper::Get()->CreateMapper(op.type(), parser, &helper, i, j);
         current_opset = mapper->GetMinOpsetVersion(verbose_);
         delete mapper;
       }
-#if 0
-      if (op.type() == "while")
-      {
-        P2OLogger() << "Detected there's control flow 'while' op in your model, "
-                    << "this requires the minimal opset version of 13."
-                    << std::endl;
-        current_opset = 13;
-      }
-#endif
 
       if (current_opset > max_opset) {
         max_opset = current_opset;
@@ -262,72 +284,26 @@ void ModelExporter::ExportParameters(
   }
 }
 
-ONNX_NAMESPACE::GraphProto ModelExporter::ExportConditionalBlock(
-    const PaddleParser &parser, int32_t block_id, int32_t op_id,
-    const std::string &output_names) {
-  auto op = parser.GetOpDesc(block_id, op_id);
-
-  // Get sub_block_idx
-  int32_t sub_block_idx = -1;
-  for (size_t i = 0; i < op.attrs_size(); ++i) {
-    if (op.attrs(i).name() == "sub_block") {
-      sub_block_idx = op.attrs(i).block_idx();
-      break;
-    }
-  }
-  Assert(sub_block_idx != -1,
-         "Due to the unsupported sub_block_idx, the conversion is aborted.");
-
-  std::vector<std::shared_ptr<ONNX_NAMESPACE::TensorProto>> temp_parameters;
-
-  std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> temp_inputs;
-  // auto input_info = parser.GetOpInput(block_id, op_id, "Input");
-  // for (int index = 0; index < input_info.size(); index++)
-  // {
-  //   temp_inputs.push_back(std::move(MakeValueInfo(input_info[index])));
-  // }
-
-  std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> temp_outputs;
-  auto out_info = parser.GetOpOutput(block_id, op_id, "Out");
-  for (int index = 0; index < out_info.size(); index++) {
-    if (out_info[index].name != output_names) {
-      continue;
-    }
-    temp_outputs.push_back(std::move(MakeValueInfo(out_info[index])));
-  }
-  return std::move(ExportBlock(parser, sub_block_idx, temp_parameters,
-                               temp_inputs, temp_outputs));
-}
-
-ONNX_NAMESPACE::GraphProto ModelExporter::ExportFillConstant(
-    const PaddleParser &parser, OnnxHelper *temp_helper, int32_t block_id,
-    int32_t op_id, const std::string &output_names) {
-  ONNX_NAMESPACE::GraphProto graph;
-  graph.set_name("PaddlePaddle fill_constant Graph " + std::to_string(op_id));
-  auto op = parser.GetOpDesc(block_id, op_id);  // fill_constant
-  auto out_info = parser.GetOpOutput(block_id, op_id, "Out");
-
-  *(graph.add_output()) = (*MakeValueInfo(out_info[0]));
-  for (auto &item : temp_helper->nodes) {
-    if (item->output(0) == output_names) {
-      *(graph.add_node()) = (*item.get());
-      break;
-    }
-  }
-
-  return std::move(graph);
-}
 ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
     const PaddleParser &parser, int32_t block_id,
     std::vector<std::shared_ptr<ONNX_NAMESPACE::TensorProto>> &parameters,
     std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> &inputs,
-    std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> &outputs) {
+    std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> &outputs,
+    OnnxHelper *helper, bool is_while_block) {
   ONNX_NAMESPACE::GraphProto graph;
   graph.set_name("PaddlePaddle Graph " + std::to_string(block_id));
-  OnnxHelper temp_helper;
   auto num_ops = parser.NumOfOps(block_id);
-  temp_helper.nodes.reserve(num_ops * 3);
-  temp_helper.Clear();
+
+  // Init OnnxHelper: reuse the caller's helper if one is passed in,
+  // otherwise create a temporary one for this block.
+  OnnxHelper *temp_helper = nullptr;
+  if (helper == nullptr) {
+    temp_helper = new OnnxHelper();
+    temp_helper->nodes.reserve(num_ops * 3);
+    temp_helper->Clear();
+  } else {
+    temp_helper = helper;
+  }
+
   for (auto op_id = 0; op_id < num_ops; ++op_id) {
     auto op = parser.GetOpDesc(block_id, op_id);
     if (op.type() == "feed") {
@@ -341,69 +317,20 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
       }
       continue;
     } else if (op.type() == "select_input") {
-      auto input_info = parser.GetOpInput(block_id, op_id, "X");
-
-      Assert(input_info.size() == 2,
-             "Only support when number of select_input's input_node is 2.");
-
-      // Build else sub graph
-      auto else_node_name = input_info[0].name;
-      auto conditional_block_cood_it = sub_block_map_.find(else_node_name);
-      Assert(conditional_block_cood_it != sub_block_map_.end(),
-             "Can't find select_input else_input node.");
-      auto conditional_block_cood = conditional_block_cood_it->second;
-      ONNX_NAMESPACE::GraphProto else_graph, then_graph;
-      auto else_node = parser.GetOpDesc(conditional_block_cood.first,
-                                        conditional_block_cood.second);
-
-      if (else_node.type().find("conditional_block") != std::string::npos) {
-        else_graph = ExportConditionalBlock(
-            parser, conditional_block_cood.first,
-            conditional_block_cood.second, else_node_name);
-      } else {
-        else_graph = ExportFillConstant(
-            parser, &temp_helper, conditional_block_cood.first,
-            conditional_block_cood.second, else_node_name);
-      }
-
-      // Build then sub graph
-      auto then_node_name = input_info[1].name;
-      conditional_block_cood_it = sub_block_map_.find(then_node_name);
-      Assert(conditional_block_cood_it != sub_block_map_.end(),
-             "Can't find select_input then_input node.");
-      conditional_block_cood = conditional_block_cood_it->second;
-      auto then_node = parser.GetOpDesc(conditional_block_cood.first,
-                                        conditional_block_cood.second);
-
-      // use node.type() to make sure correctness
-      if (then_node.type().find("conditional_block") != std::string::npos) {
-        then_graph = ExportConditionalBlock(
-            parser, conditional_block_cood.first,
-            conditional_block_cood.second, then_node_name);
-      } else {
-        then_graph = ExportFillConstant(
-            parser, &temp_helper, conditional_block_cood.first,
-            conditional_block_cood.second, then_node_name);
-      }
-
-      auto cond_info = parser.GetOpInput(block_id, op_id, "Mask");
-      auto output_info = parser.GetOpOutput(block_id, op_id, "Out");
-      auto cond_name = temp_helper.AutoCast(
-          cond_info[0].name, cond_info[0].dtype, P2ODataType::BOOL);
-      auto node =
-          temp_helper.MakeNode("If", {cond_name}, {output_info[0].name});
-      AddAttribute(node, "then_branch", then_graph);
-      AddAttribute(node, "else_branch", else_graph);
+      ExportSelectInput(parser, temp_helper, block_id, op_id);
       continue;
     } else if (op.type() == "fill_constant") {
       auto out_info = parser.GetOpOutput(block_id, op_id, "Out");
       sub_block_map_[out_info[0].name] = {block_id, op_id};
+    } else if (op.type() == "while") {
+      ExportWhile(parser, temp_helper, block_id, op_id);
+      continue;
     }
-    ExportOp(parser, &temp_helper, opset_version_, block_id, op_id, verbose_);
+    ExportOp(parser, temp_helper, opset_version_, block_id, op_id, verbose_);
   }
 
-  ProcessGraphDumplicateNames(parameters, inputs, outputs, temp_helper.nodes,
-                              temp_helper.quantize_info);
+  ProcessGraphDumplicateNames(parameters, inputs, outputs, temp_helper->nodes,
                              temp_helper->quantize_info, is_while_block);
 
   // Process the model according to deploy_mackend_
   if (parser.is_quantized_model) {
@@ -422,14 +349,14 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
                deploy_backend_ + ".");
     }
     P2OLogger() << "Deploy backend is: " << deploy_backend_ << std::endl;
-    quantize_processer_->ProcessQuantizeModel(&parameters, &inputs, &outputs,
-                                              &temp_helper.nodes, &temp_helper,
-                                              parser, calibration_cache_);
+    quantize_processer_->ProcessQuantizeModel(
+        &parameters, &inputs, &outputs, &(temp_helper->nodes), temp_helper,
+        parser, calibration_cache_);
     delete quantize_processer_;
     quantize_processer_ = nullptr;
 
     // Update int8 weights in quantized OP to float32
-    UpdateParameters(temp_helper.updated_params, parameters);
+    UpdateParameters(temp_helper->updated_params, parameters);
   }
 
   for (auto &item : parameters) {
@@ -444,14 +371,17 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
     *(graph.add_output()) = (*item.get());
   }
 
-  for (auto &item : temp_helper.nodes) {
+  for (auto &item : temp_helper->nodes) {
     *(graph.add_node()) = (*item.get());
   }
 
-  for (auto &item : temp_helper.value_infos) {
+  for (auto &item : temp_helper->value_infos) {
     *(graph.add_value_info()) = (*item.get());
   }
 
+  if (helper == nullptr) {
+    delete temp_helper;
+  }
   return std::move(graph);
 }
 
@@ -553,12 +483,6 @@ void ModelExporter::ExportOp(const PaddleParser &parser, OnnxHelper *helper,
                              int32_t opset_version, int64_t block_id,
                              int64_t op_id, bool verbose) {
   auto op = parser.GetOpDesc(block_id, op_id);
-#if 0
-  if (op.type() == "while")
-  {
-    return ExportLoop(parser, helper, opset_version, block_id, op_id, verbose);
-  }
-#endif
   if (MapperHelper::Get()->IsRegistered(op.type())) {
     auto mapper = MapperHelper::Get()->CreateMapper(op.type(), parser, helper,
                                                     block_id, op_id);
@@ -581,26 +505,43 @@ void ModelExporter::ProcessGraphDumplicateNames(
     std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> &inputs,
     std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> &outputs,
     std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>> &nodes,
-    std::map<std::string, QuantizeInfo> &quantize_info) {
-  std::map<std::string, std::string> renamer;
+    std::map<std::string, QuantizeInfo> &quantize_info, bool is_while_block) {
+  /********************* Create Tensor Names *********************/
+  for (auto &item : nodes) {
+    for (size_t i = 0; i < item->input_size(); ++i) {
+      if (item->name().find("Loop") != std::string::npos) {
+        // P2OLogger() << "nodes item input:" << item->input(i) << std::endl;
+        while_tensor_names_.insert(item->input(i));
+      }
+    }
+    for (size_t i = 0; i < item->output_size(); ++i) {
+      if (item->name().find("Loop") != std::string::npos) {
+        // P2OLogger() << "nodes item output:" << item->output(i) << std::endl;
+        while_tensor_names_.insert(item->output(i));
+      }
+    }
+  }
+  // for (const auto& tensor_name : while_tensor_names_) {
+  //   tensor_names_.erase(tensor_name);
+  // }
+  /********************* Create Tensor Names *********************/
+
+  /********************* Rename *********************/
   for (auto &item : parameters) {
     for (size_t i = 0; i < item->output_size(); ++i) {
       if (tensor_names_.find(item->output(i)) != tensor_names_.end()) {
-        Assert(false, "There's dumplicate names in exported parameters.");
+        P2OLogger()
+            << "[WARNING] There's duplicate names in exported parameters.";
+        continue;
       }
       tensor_names_.insert(item->output(i));
     }
   }
 
   for (auto &item : inputs) {
-    if (tensor_names_.find(item->name()) != tensor_names_.end()) {
-      continue;
-      // Assert(false, "There's dumplicate names:" + item->name() + " in
-      // exported parameters and inputs.");
-    }
     tensor_names_.insert(item->name());
   }
-
+  std::map<std::string, std::string> renamer;
   for (auto &item : nodes) {
     // update node inputs
     for (size_t i = 0; i < item->input_size(); ++i) {
@@ -617,15 +558,17 @@ void ModelExporter::ProcessGraphDumplicateNames(
     // dumplicate name
     for (size_t i = 0; i < item->output_size(); ++i) {
       if (tensor_names_.find(item->output(i)) != tensor_names_.end()) {
+        if (is_while_block) {
+          if (while_tensor_names_.find(item->output(i)) !=
+              while_tensor_names_.end()) {
+            // P2OLogger() << "Skip: " << item->output(i) << std::endl;
+            continue;
+          }
+        }
         std::string renamed_tensor_name = item->output(i);
         while (renamer.find(renamed_tensor_name) != renamer.end()) {
           renamed_tensor_name = renamer[renamed_tensor_name];
         }
-        auto new_tensor_name =
-            MapperHelper::Get()->GenName(renamed_tensor_name);
-        // P2OLogger() << "Find dumplicate output name '" << renamed_tensor_name
-        //             << "', it will rename to '" << new_tensor_name << "'."
-        //             << std::endl;
+        auto new_tensor_name = MapperHelper::Get()->GenName(renamed_tensor_name);
         if (quantize_info.find(renamed_tensor_name) != quantize_info.end()) {
           quantize_info[new_tensor_name] = quantize_info[renamed_tensor_name];
         }
@@ -645,6 +588,7 @@ void ModelExporter::ProcessGraphDumplicateNames(
       item->set_name(updated_name);
     }
   }
+  /********************* Rename *********************/
 }
 
 void ModelExporter::SaveExternalData(::ONNX_NAMESPACE::GraphProto *graph,
@@ -749,9 +693,8 @@ std::string ModelExporter::Run(
   std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> inputs;
   std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> outputs;
   ExportInputOutputs(parser, inputs, outputs);
-
   // Export Blocks
-  tensor_names_.clear();
+  tensor_names_.clear();
   auto share_graph = ExportBlock(parser, 0, parameters, inputs, outputs);
   *onnx_model_.mutable_graph() = share_graph;
 
@@ -827,4 +770,4 @@ ONNX_NAMESPACE::ModelProto ModelExporter::Optimize(
   return ONNX_NAMESPACE::optimization::Optimize(model, passes);
 }
 
-}  // namespace paddle2onnx
+}  // namespace paddle2onnx
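
> Editor's note: to make the renaming pass in `ProcessGraphDumplicateNames` easier to follow, here is a minimal, self-contained Python sketch of the same idea. The dict-based graph shape, the `gen_name` suffix scheme, and all tensor names are illustrative only, not the converter's real data structures; the `while_names` skip mirrors the new `while_tensor_names_` behavior for Loop bodies.

```python
def dedup_names(nodes, seen, while_names, is_while_block):
    """nodes: [{"inputs": [...], "outputs": [...]}]; seen: set of used names."""
    renamer = {}
    counter = {}

    def gen_name(base):  # mimics the "<base>.<counter>" style of GenName
        counter[base] = counter.get(base, -1) + 1
        return "{}.{}".format(base, counter[base])

    for node in nodes:
        # Rewrite inputs through any renames recorded so far (follow chains).
        new_inputs = []
        for name in node["inputs"]:
            while name in renamer:
                name = renamer[name]
            new_inputs.append(name)
        node["inputs"] = new_inputs

        for idx, out in enumerate(node["outputs"]):
            if out in seen:
                if is_while_block and out in while_names:
                    # Loop-carried names must keep matching the Loop body.
                    continue
                base = out
                while base in renamer:  # follow earlier rename chains
                    base = renamer[base]
                fresh = gen_name(base)
                renamer[base] = fresh
                node["outputs"][idx] = fresh
                out = fresh
            seen.add(out)
    return nodes


g = [{"inputs": ["x"], "outputs": ["y"]},
     {"inputs": ["y"], "outputs": ["y"]}]
dedup_names(g, {"x"}, set(), False)
# -> the second node now produces "y.0"; any later node that listed "y" as an
#    input would be rewritten to "y.0" through the renamer map.
```
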
diff --git a/paddle2onnx/mapper/exporter.h b/paddle2onnx/mapper/exporter.h
index f3d8577d9..87e72b445 100644
--- a/paddle2onnx/mapper/exporter.h
+++ b/paddle2onnx/mapper/exporter.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -87,14 +87,15 @@ class ModelExporter {
   void ExportParameters(
       const PaddleParser &parser,
       std::vector<std::shared_ptr<ONNX_NAMESPACE::TensorProto>> &parameters);
-  // Process dumplicate tensor names in paddle model
   std::set<std::string> tensor_names_;
+  std::set<std::string> while_tensor_names_;
   void ProcessGraphDumplicateNames(
       std::vector<std::shared_ptr<ONNX_NAMESPACE::TensorProto>> &parameters,
       std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> &inputs,
       std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> &outputs,
       std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>> &nodes,
-      std::map<std::string, QuantizeInfo> &quantize_info);
+      std::map<std::string, QuantizeInfo> &quantize_info,
+      bool is_while_block = false);
   // Update constant node in parameters. When process quantize model, the weight
   // dtype may be int8, it should be convet to float32 and use this function to
   // update converted params.
@@ -103,28 +104,35 @@ class ModelExporter {
       std::vector<std::shared_ptr<ONNX_NAMESPACE::TensorProto>> &parameters);
 
   //
   std::map<std::string, std::pair<int32_t, int32_t>> sub_block_map_;
+  ONNX_NAMESPACE::GraphProto ExportFillConstant(const PaddleParser &parser,
+                                                OnnxHelper *temp_helper,
+                                                int32_t block_id, int32_t op_id,
+                                                const std::string &output_name);
   ONNX_NAMESPACE::GraphProto ExportConditionalBlock(
-      const PaddleParser &parser, int32_t block_id, int32_t op_id,
-      const std::string &output_names);
-  ONNX_NAMESPACE::GraphProto ExportFillConstant(
-      const PaddleParser &parser, OnnxHelper *temp_helper,
-      int32_t block_id, int32_t op_id,
-      const std::string &output_names);
+      const PaddleParser &parser, OnnxHelper *temp_helper, int32_t block_id,
+      int32_t op_id, const std::string &output_name);
+  void ExportSelectInput(const PaddleParser &parser, OnnxHelper *temp_helper,
+                         int32_t block_id, int32_t op_id);
+  void ExportWhile(const PaddleParser &parser, OnnxHelper *temp_helper,
+                   int32_t block_id, int32_t op_id);
   ONNX_NAMESPACE::GraphProto ExportBlock(
       const PaddleParser &parser, int32_t block_id,
       std::vector<std::shared_ptr<ONNX_NAMESPACE::TensorProto>> &parameters,
      std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> &inputs,
-      std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> &outputs);
+      std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> &outputs,
+      OnnxHelper *helper = nullptr, bool is_while_block = false);
 
   void ExportOp(const PaddleParser &parser, OnnxHelper *helper,
                 int32_t opset_version, int64_t block_id, int64_t op_id,
                 bool verbose);
+
+  bool IsWhileSupported(const PaddleParser &parser, const int64_t &block_id,
+                        const int64_t &op_id);
+
 #if 0
-  bool IsLoopSupported(const PaddleParser &parser, const int64_t &block_id,
-                       const int64_t &op_id);
-  void ExportLoop(const PaddleParser &parser, OnnxHelper *helper,
-                  int32_t opset_version, int64_t block_id, int64_t op_id,
-                  bool verbose);
+  void ExportLoop(const PaddleParser &parser, OnnxHelper *helper,
+                  int32_t opset_version, int64_t block_id, int64_t op_id,
+                  bool verbose);
 #endif
   ONNX_NAMESPACE::ModelProto Optimize(const ONNX_NAMESPACE::ModelProto &model);
   void CovertCustomOps(const PaddleParser &parser, OnnxHelper *helper,
diff --git a/paddle2onnx/mapper/loop.cc b/paddle2onnx/mapper/loop.cc
index c94abd392..c60b99ad9 100644
--- a/paddle2onnx/mapper/loop.cc
+++ b/paddle2onnx/mapper/loop.cc
@@ -16,35 +16,6 @@
 #include "paddle2onnx/mapper/exporter.h"
 
 namespace paddle2onnx {
-
-bool ModelExporter::IsLoopSupported(const PaddleParser& parser,
-                                    const int64_t& block_id,
-                                    const int64_t& op_id) {
-  auto x_info = parser.GetOpInput(block_id, op_id, "X");
-  auto out_info = parser.GetOpOutput(block_id, op_id, "Out");
-  auto cond_info = parser.GetOpInput(block_id, op_id, "Condition");
-  std::set<std::string> input_names;
-  for (size_t i = 0; i < x_info.size(); ++i) {
-    input_names.insert(x_info[i].name);
-  }
-  input_names.insert(cond_info[0].name);
-
-  for (size_t i = 0; i < out_info.size(); ++i) {
-    auto iter = input_names.find(out_info[i].name);
-    if (iter == input_names.end()) {
-      P2OLogger() << "Cannot find output:" << out_info[i].name << " in input tensors while converting operator 'while', Paddle2ONNX doesn't support this situation now." << std::endl;
-      return false;
-    }
-  }
-  for (size_t i = 0; i < x_info.size(); ++i) {
-    if (x_info[i].is_tensor_array) {
-      P2OLogger() << "LodTensorArray is not supported." << std::endl;
-      return false;
-    }
-  }
-  return true;
-}
-
 void ModelExporter::ExportLoop(const PaddleParser& parser, OnnxHelper* helper,
                                int32_t opset_version, int64_t block_id,
                                int64_t op_id, bool verbose) {
@@ -192,6 +163,5 @@ void ModelExporter::ExportLoop(const PaddleParser& parser, OnnxHelper* helper,
     attr->set_type(ONNX_NAMESPACE::AttributeProto::GRAPH);
     *(attr->mutable_g()) = *(graph.get());
   }
-
 }  // namespace paddle2onnx
 #endif
\ No newline at end of file
diff --git a/paddle2onnx/mapper/nn/interpolate.cc b/paddle2onnx/mapper/nn/interpolate.cc
index 366436cc0..a3679cd05 100755
--- a/paddle2onnx/mapper/nn/interpolate.cc
+++ b/paddle2onnx/mapper/nn/interpolate.cc
@@ -34,7 +34,6 @@ int32_t InterpolateMapper::GetMinOpsetVersion(bool verbose) {
                     << x_info[0].Rank() << std::endl;
     return -1;
   }
-  Logger(verbose, 11) << RequireOpset(11) << std::endl;
   return 11;
 }
 
diff --git a/paddle2onnx/mapper/register_mapper.h b/paddle2onnx/mapper/register_mapper.h
index 6dc43b2ff..b2758cda8 100644
--- a/paddle2onnx/mapper/register_mapper.h
+++ b/paddle2onnx/mapper/register_mapper.h
@@ -80,7 +80,8 @@ class MapperHelper {
   }
 
   std::string GenName(const std::string& op_name) {
-    std::string key = "p2o." + op_name + ".";
+    // std::string key = "p2o." + op_name + ".";
+    std::string key = op_name + ".";
     if (name_counter.find(key) == name_counter.end()) {
       name_counter[key] = 0;
     } else {
diff --git a/paddle2onnx/mapper/select_input.cc b/paddle2onnx/mapper/select_input.cc
new file mode 100644
index 000000000..192b3bc5c
--- /dev/null
+++ b/paddle2onnx/mapper/select_input.cc
@@ -0,0 +1,111 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle2onnx/mapper/exporter.h"
+namespace paddle2onnx {
+ONNX_NAMESPACE::GraphProto ModelExporter::ExportFillConstant(
+    const PaddleParser &parser, OnnxHelper *temp_helper, int32_t block_id,
+    int32_t op_id, const std::string &output_name) {
+  ONNX_NAMESPACE::GraphProto graph;
+  graph.set_name("PaddlePaddle fill_constant Graph " + std::to_string(op_id));
+
+  // Add input
+
+  // Add node: move the matching fill_constant node out of the outer graph
+  // and into the subgraph.
+  auto &nodes = temp_helper->nodes;
+  for (int i = 0; i < nodes.size(); i++) {
+    auto &item = nodes[i];
+    if (item->output(0) == output_name) {
+      *(graph.add_node()) = (*item.get());
+      nodes.erase(nodes.begin() + i);
+      break;
+    }
+  }
+
+  // Add output
+  auto out_info = parser.GetOpOutput(block_id, op_id, "Out");
+  *(graph.add_output()) = (*MakeValueInfo(out_info[0]));
+  return std::move(graph);
+}
+
+ONNX_NAMESPACE::GraphProto ModelExporter::ExportConditionalBlock(
+    const PaddleParser &parser, OnnxHelper *temp_helper, int32_t block_id,
+    int32_t op_id, const std::string &output_name) {
+  auto op = parser.GetOpDesc(block_id, op_id);
+
+  // Get sub_block_idx
+  int32_t sub_block_idx = -1;
+  for (size_t i = 0; i < op.attrs_size(); ++i) {
+    if (op.attrs(i).name() == "sub_block") {
+      sub_block_idx = op.attrs(i).block_idx();
+      break;
+    }
+  }
+  Assert(sub_block_idx != -1,
+         "Due to the unsupported sub_block_idx, the conversion is aborted.");
+
+  // Export sub_block
+  std::vector<std::shared_ptr<ONNX_NAMESPACE::TensorProto>> temp_parameters;
+  std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> temp_inputs;
+  std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> temp_outputs;
+  auto out_info = parser.GetOpOutput(block_id, op_id, "Out");
+  for (int index = 0; index < out_info.size(); index++) {
+    if (out_info[index].name != output_name) {
+      continue;
+    }
+    temp_outputs.push_back(std::move(MakeValueInfo(out_info[index])));
+  }
+
+  return ExportBlock(parser, sub_block_idx, temp_parameters, temp_inputs,
+                     temp_outputs);
+}
+
+void ModelExporter::ExportSelectInput(const PaddleParser &parser,
+                                      OnnxHelper *temp_helper,
+                                      int32_t block_id, int32_t op_id) {
+  auto input_info = parser.GetOpInput(block_id, op_id, "X");
+
+  Assert(input_info.size() == 2,
+         "Only support when number of select_input's input_node is 2.");
+
+  ONNX_NAMESPACE::GraphProto graphs[2];
+  for (int i = 0; i < input_info.size(); i++) {
+    auto node_name = input_info[i].name;
+    auto conditional_block_cood_it = sub_block_map_.find(node_name);
+    Assert(conditional_block_cood_it != sub_block_map_.end(),
+           "Can't find select_input input node.");
+    auto conditional_block_cood = conditional_block_cood_it->second;
+    auto node = parser.GetOpDesc(conditional_block_cood.first,
+                                 conditional_block_cood.second);
+
+    if (node.type().find("conditional_block") != std::string::npos) {
+      graphs[i] = ExportConditionalBlock(
+          parser, temp_helper, conditional_block_cood.first,
+          conditional_block_cood.second, node_name);
+    } else {
+      graphs[i] =
+          ExportFillConstant(parser, temp_helper, conditional_block_cood.first,
+                             conditional_block_cood.second, node_name);
+    }
+  }
+
+  auto cond_info = parser.GetOpInput(block_id, op_id, "Mask");
+  auto output_info = parser.GetOpOutput(block_id, op_id, "Out");
+  auto cond_name = temp_helper->AutoCast(cond_info[0].name, cond_info[0].dtype,
+                                         P2ODataType::BOOL);
+  auto node = temp_helper->MakeNode("If", {cond_name}, {output_info[0].name});
+  AddAttribute(node, "else_branch", graphs[0]);
+  AddAttribute(node, "then_branch", graphs[1]);
+}
+}  // namespace paddle2onnx
\ No newline at end of file
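
> Editor's note: to make the `select_input` → `If` mapping concrete, here is a small, self-contained sketch of the graph shape `ExportSelectInput` targets, written with the public `onnx.helper` API. All tensor and graph names are invented for illustration: a Boolean `Mask` drives an `If` node, and each branch is a subgraph with no inputs and a single output, as `ExportConditionalBlock` / `ExportFillConstant` build.

```python
import onnx
from onnx import helper, TensorProto


def make_branch(name, value):
    # Each branch: no inputs, one Constant node, one output, mirroring the
    # single-output subgraphs the exporter produces for select_input.
    const = helper.make_node(
        "Constant", [], ["out"],
        value=helper.make_tensor("v", TensorProto.FLOAT, [1], [value]))
    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1])
    return helper.make_graph([const], name, [], [out])


cond = helper.make_tensor_value_info("cond", TensorProto.BOOL, [])
res = helper.make_tensor_value_info("res", TensorProto.FLOAT, [1])
if_node = helper.make_node("If", ["cond"], ["res"],
                           then_branch=make_branch("then_g", 1.0),
                           else_branch=make_branch("else_g", 2.0))
graph = helper.make_graph([if_node], "select_input_as_if", [cond], [res])
# Opset 11 matches the minimum the exporter reports for control flow.
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])
onnx.checker.check_model(model)
```
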
diff --git a/paddle2onnx/mapper/tensor/fill_constant.cc b/paddle2onnx/mapper/tensor/fill_constant.cc
index 1fdce2c80..95cb4559e 100644
--- a/paddle2onnx/mapper/tensor/fill_constant.cc
+++ b/paddle2onnx/mapper/tensor/fill_constant.cc
@@ -36,11 +36,9 @@ int32_t FillConstantMapper::GetMinOpsetVersion(bool verbose) {
     return -1;
   }
   if (HasInput("ShapeTensorList")) {
-    Logger(verbose, 9) << "While ShapeTensorList as input, " << RequireOpset(9) << std::endl;
     return 9;
   }
   if (HasInput("ShapeTensor") && !IsConstantInput("ShapeTensor")) {
-    Logger(verbose, 9) << "While ShapeTensor as input and it's not a constant tensor, " << RequireOpset(9) << std::endl;
     return 9;
   }
   return 7;
@@ -108,12 +106,17 @@ void FillConstantMapper::Opset9() {
     shape_name = helper_->ConcatIndices(shape_info);
   }
 
-  auto node = helper_->MakeNode("ConstantOfShape", {shape_name});
+  std::shared_ptr<ONNX_NAMESPACE::NodeProto> node;
+  if (value_is_tensor) {
+    node = helper_->MakeNode("ConstantOfShape", {shape_name});
+  } else {
+    node =
+        helper_->MakeNode("ConstantOfShape", {shape_name}, {out_info[0].name});
+  }
   auto attr = node->add_attribute();
   attr->set_name("value");
   attr->set_type(ONNX_NAMESPACE::AttributeProto::TENSOR);
   auto tensor = attr->mutable_t();
-  tensor->set_name(out_info[0].name);
+  tensor->set_name(MapperHelper::Get()->GenName("ConstantOfShape.tensor"));
   tensor->set_data_type(onnx_dtype);
   tensor->add_dims(1);
   if (onnx_dtype == ONNX_NAMESPACE::TensorProto::INT32) {
@@ -145,13 +148,16 @@ void FillConstantMapper::Opset9() {
   } else {
     std::vector<int64_t> shape;
     GetAttr("shape", &shape);
-    out = helper_->Constant(shape, onnx_dtype, value);
+    if (value_is_tensor) {
+      out = helper_->Constant(shape, onnx_dtype, value);
+    } else {
+      out = helper_->Constant(out_info[0].name, shape, onnx_dtype, value);
+    }
   }
+
   if (value_is_tensor) {
     auto value_info = GetInput("ValueTensor");
     helper_->MakeNode("Add", {out, value_info[0].name}, {out_info[0].name});
-  } else {
-    helper_->MakeNode("Identity", {out}, {out_info[0].name});
   }
 }
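
> Editor's note: the `fill_constant` change above lets `ConstantOfShape` write directly to the op's output instead of appending an `Identity`. A minimal sketch of the node it now emits, using the public `onnx.helper` API (all names invented; `ConstantOfShape.tensor.0` only imitates the `GenName` style):

```python
import onnx
from onnx import helper, TensorProto

# ConstantOfShape takes a 1-D INT64 shape tensor and fills that shape with the
# single value stored in its "value" attribute. The node output carries the
# fill_constant op's own output name, so no trailing Identity is needed.
shape = helper.make_node(
    "Constant", [], ["shape"],
    value=helper.make_tensor("shape_t", TensorProto.INT64, [2], [2, 3]))
fill = helper.make_node(
    "ConstantOfShape", ["shape"], ["fill_constant_out"],
    value=helper.make_tensor("ConstantOfShape.tensor.0", TensorProto.FLOAT,
                             [1], [1.5]))
out = helper.make_tensor_value_info("fill_constant_out", TensorProto.FLOAT,
                                    [2, 3])
graph = helper.make_graph([shape, fill], "fill_constant_sketch", [], [out])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 9)])
onnx.checker.check_model(model)
```
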
diff --git a/paddle2onnx/mapper/tensor/one_hot_v2.cc b/paddle2onnx/mapper/tensor/one_hot_v2.cc
index 29df5665d..232eb03ab 100644
--- a/paddle2onnx/mapper/tensor/one_hot_v2.cc
+++ b/paddle2onnx/mapper/tensor/one_hot_v2.cc
@@ -28,7 +28,6 @@ int32_t OneHotV2Mapper::GetMinOpsetVersion(bool verbose) {
     Error() << "dtype attribute and output dtype do not match." << std::endl;
     return -1;
   }
-  Logger(verbose, 9) << RequireOpset(9) << std::endl;
   return 9;
 }
 
diff --git a/paddle2onnx/mapper/tensor/repeat_interleave.cc b/paddle2onnx/mapper/tensor/repeat_interleave.cc
index 6304935d4..ebf8b0232 100644
--- a/paddle2onnx/mapper/tensor/repeat_interleave.cc
+++ b/paddle2onnx/mapper/tensor/repeat_interleave.cc
@@ -14,61 +14,62 @@
 // limitations under the License.
 
 #include "paddle2onnx/mapper/tensor/repeat_interleave.h"
 
-namespace paddle2onnx
-{
-  REGISTER_MAPPER(repeat_interleave, RepeatInterleaveMapper)
+namespace paddle2onnx {
+REGISTER_MAPPER(repeat_interleave, RepeatInterleaveMapper)
 
-  int32_t RepeatInterleaveMapper::GetMinOpsetVersion(bool verbose)
-  {
-    constexpr int op_version = 9;
-    Logger(verbose, op_version) << RequireOpset(op_version) << std::endl;
-    return op_version;
-  }
+int32_t RepeatInterleaveMapper::GetMinOpsetVersion(bool verbose) {
+  constexpr int op_version = 9;
+  Logger(verbose, op_version) << RequireOpset(op_version) << std::endl;
+  return op_version;
+}
 
-  void RepeatInterleaveMapper::Opset9()
-  {
-    auto x_info = GetInput("X");  // shape = [1, 2, 3]
-    auto out_info = GetOutput("Out");
-    int n = x_info[0].shape[dim_];
-    int x_shape_size = x_info[0].shape.size();
+void RepeatInterleaveMapper::Opset9() {
+  auto x_info = GetInput("X");  // shape = [1, 2, 3]
+  auto out_info = GetOutput("Out");
+  int n = x_info[0].shape[dim_];
+  int x_shape_size = x_info[0].shape.size();
 
-    std::vector<int64_t> repeats;
-    int64_t repeat;
-    GetAttr("Repeats", &repeat);
-    if (repeat != 0)
-    {
-      std::vector<int64_t> rp_tmp(n, repeat);
-      repeats.assign(rp_tmp.begin(), rp_tmp.end());
-    }
+  std::vector<int64_t> repeats;
+  int64_t repeat;
+  GetAttr("Repeats", &repeat);
+  if (repeat != 0) {
+    std::vector<int64_t> rp_tmp(n, repeat);
+    repeats.assign(rp_tmp.begin(), rp_tmp.end());
+  }
 
-    std::string repeat_info_name = "";
-    if (HasInput("RepeatsTensor"))
-    {
-      auto tmp_info = GetInput("RepeatsTensor");
-      repeat_info_name = helper_->AutoCast(tmp_info[0].name, tmp_info[0].dtype, P2ODataType::INT64);
-    }
-    else
-    {
-      repeat_info_name = helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64, repeats);
-    }
+  std::string repeat_info_name = "";
+  if (HasInput("RepeatsTensor")) {
+    auto tmp_info = GetInput("RepeatsTensor");
+    repeat_info_name = helper_->AutoCast(tmp_info[0].name, tmp_info[0].dtype,
+                                         P2ODataType::INT64);
+  } else {
+    repeat_info_name =
+        helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64, repeats);
+  }
 
-    std::vector<int64_t> splits(n, 1);
+  std::vector<int64_t> splits(n, 1);
 
-    std::vector<std::string> split_repeat_info_names = helper_->Split(repeat_info_name, splits, 0);
-    std::vector<std::string> split_input_names = helper_->Split(x_info[0].name, splits, dim_);
+  std::vector<std::string> split_repeat_info_names =
+      helper_->Split(repeat_info_name, splits, 0);
+  std::vector<std::string> split_input_names =
+      helper_->Split(x_info[0].name, splits, dim_);
 
-    int n_suffix_tile = x_shape_size - dim_ - 1;
-    int n_prefix_tile = dim_;
-    std::string suffix_name = helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64, std::vector<int64_t>(n_suffix_tile, 1));
-    std::string prefix_name = helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64, std::vector<int64_t>(n_prefix_tile, 1));
+  int n_suffix_tile = x_shape_size - dim_ - 1;
+  int n_prefix_tile = dim_;
+  std::string suffix_name =
+      helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64,
+                        std::vector<int64_t>(n_suffix_tile, 1));
+  std::string prefix_name =
+      helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64,
+                        std::vector<int64_t>(n_prefix_tile, 1));
 
-    std::vector<std::string> output_names;
-    for (int i = 0; i < n; i++)
-    {
-      std::string tile_name = helper_->Concat({prefix_name, split_repeat_info_names[i], suffix_name}, 0);
-      auto node = helper_->MakeNode("Tile", {split_input_names[i], tile_name}, 1);
-      output_names.emplace_back(node->output(0));
-    }
-    helper_->Concat(output_names, out_info[0].name, dim_);
-  }
-}  // namespace paddle2onnx
\ No newline at end of file
+  std::vector<std::string> output_names;
+  for (int i = 0; i < n; i++) {
+    std::string tile_name = helper_->Concat(
+        {prefix_name, split_repeat_info_names[i], suffix_name}, 0);
+    auto node = helper_->MakeNode("Tile", {split_input_names[i], tile_name}, 1);
+    output_names.emplace_back(node->output(0));
+  }
+  helper_->Concat(output_names, out_info[0].name, dim_);
+}
+}  // namespace paddle2onnx
\ No newline at end of file
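
> Editor's note: the mapper above decomposes `repeat_interleave` into Split + Tile + Concat. A NumPy sketch of the same decomposition (names and the test values are illustrative) shows why it matches `np.repeat` semantics for per-slice repeat counts:

```python
import numpy as np


def repeat_interleave_via_tile(x, repeats, dim):
    # Split x into n slices of size 1 along dim, tile slice i by repeats[i]
    # (with 1s for every other axis, like the prefix/suffix constants in the
    # mapper), then concatenate along dim.
    slices = np.split(x, x.shape[dim], axis=dim)
    tiled = []
    for s, r in zip(slices, repeats):
        reps = [1] * x.ndim
        reps[dim] = r
        tiled.append(np.tile(s, reps))
    return np.concatenate(tiled, axis=dim)


x = np.arange(6).reshape(2, 3)
assert (repeat_interleave_via_tile(x, [2, 3], 0) ==
        np.repeat(x, [2, 3], axis=0)).all()
```
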
diff --git a/paddle2onnx/mapper/tensor/set_value.cc b/paddle2onnx/mapper/tensor/set_value.cc
index 88a1b6837..bc57e9a0f 100644
--- a/paddle2onnx/mapper/tensor/set_value.cc
+++ b/paddle2onnx/mapper/tensor/set_value.cc
@@ -37,7 +37,6 @@ int32_t SetValueMapper::GetMinOpsetVersion(bool verbose) {
             << std::endl;
     return -1;
   }
-  Logger(verbose, 12) << RequireOpset(12) << std::endl;
   return 12;
 }
 
diff --git a/paddle2onnx/mapper/while.cc b/paddle2onnx/mapper/while.cc
new file mode 100644
index 000000000..dccd9386f
--- /dev/null
+++ b/paddle2onnx/mapper/while.cc
@@ -0,0 +1,81 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle2onnx/mapper/exporter.h"
+namespace paddle2onnx {
+void ModelExporter::ExportWhile(const PaddleParser& parser,
+                                OnnxHelper* temp_helper, int32_t block_id,
+                                int32_t op_id) {
+  auto op = parser.GetOpDesc(block_id, op_id);
+  auto x_info = parser.GetOpInput(block_id, op_id, "X");
+  auto cond_info = parser.GetOpInput(block_id, op_id, "Condition");
+  auto out_info = parser.GetOpOutput(block_id, op_id, "Out");
+
+  ONNX_NAMESPACE::GraphProto graph;
+  /********************* Create Body Graph *********************/
+  int32_t sub_block_idx = -1;
+  for (size_t i = 0; i < op.attrs_size(); ++i) {
+    if (op.attrs(i).name() == "sub_block") {
+      sub_block_idx = op.attrs(i).block_idx();
+      break;
+    }
+  }
+  Assert(sub_block_idx > 0, "Cannot find sub_block in while operator.");
+
+  std::vector<std::shared_ptr<ONNX_NAMESPACE::TensorProto>> parameters;
+  std::vector<std::string> input_names;
+  std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> inputs;
+  std::vector<std::string> output_names;
+  std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> outputs;
+
+  // The Loop body's first input is the iteration counter.
+  auto iter_name = MapperHelper::Get()->GenName("loop.iter");
+  TensorInfo iter_info(iter_name, std::vector<int64_t>(1, 1),
+                       P2ODataType::INT64);
+  inputs.push_back(std::move(MakeValueInfo(iter_info)));
+
+  // Make cond
+  input_names.push_back(cond_info[0].name);
+  inputs.push_back(std::move(MakeValueInfo(cond_info[0])));
+  outputs.push_back(std::move(MakeValueInfo(cond_info[0])));
+
+  // Make other inputs: every non-tensor-array state becomes a loop-carried
+  // dependency of the Loop body.
+  for (size_t i = 0; i < x_info.size(); ++i) {
+    if (std::find(input_names.begin(), input_names.end(), x_info[i].name) !=
+        input_names.end()) {
+      continue;
+    }
+
+    if (!(x_info[i].is_tensor_array)) {
+      inputs.push_back(std::move(MakeValueInfo(x_info[i])));
+    }
+    input_names.push_back(x_info[i].name);
+    outputs.push_back(std::move(MakeValueInfo(x_info[i])));
+  }
+
+  graph = ExportBlock(parser, sub_block_idx, parameters, inputs, outputs,
+                      nullptr, true);
+  /********************* Create Body Graph *********************/
+
+  // Make a fake iteration count; termination is driven by the condition.
+  auto fake_iter = temp_helper->Constant(ONNX_NAMESPACE::TensorProto::INT64,
+                                         std::vector<int64_t>(1, 1024));
+  input_names.insert(input_names.begin(), fake_iter);
+  for (int i = 2; i < input_names.size(); i++) {
+    output_names.push_back(input_names[i]);
+  }
+  auto loop_node = temp_helper->MakeNode("Loop", input_names, output_names);
+  AddAttribute(loop_node, "body", graph);
+}
+}  // namespace paddle2onnx
\ No newline at end of file
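
> Editor's note: a minimal, self-contained sketch of the Loop wiring `ExportWhile` produces, built with the public `onnx.helper` API. All names are illustrative; the body computes `x = x + 1` while `x < 4`, with a fake trip count of 1024 as in the patch so that termination is driven by the condition alone.

```python
import onnx
from onnx import helper, TensorProto

# Loop body: (iter_num, cond_in, x_in) -> (cond_out, x_out)
iter_num = helper.make_tensor_value_info("iter_num", TensorProto.INT64, [])
cond_in = helper.make_tensor_value_info("cond_in", TensorProto.BOOL, [])
x_in = helper.make_tensor_value_info("x_in", TensorProto.FLOAT, [])
one = helper.make_node("Constant", [], ["one"],
    value=helper.make_tensor("one_t", TensorProto.FLOAT, [], [1.0]))
add = helper.make_node("Add", ["x_in", "one"], ["x_out"])
four = helper.make_node("Constant", [], ["four"],
    value=helper.make_tensor("four_t", TensorProto.FLOAT, [], [4.0]))
less = helper.make_node("Less", ["x_out", "four"], ["cond_out"])
cond_out = helper.make_tensor_value_info("cond_out", TensorProto.BOOL, [])
x_out = helper.make_tensor_value_info("x_out", TensorProto.FLOAT, [])
body = helper.make_graph([one, add, four, less], "while_body",
                         [iter_num, cond_in, x_in], [cond_out, x_out])

# Outer graph: Loop inputs are [trip_count, initial cond, loop-carried x].
trip = helper.make_node("Constant", [], ["trip"],
    value=helper.make_tensor("trip_t", TensorProto.INT64, [], [1024]))
cond0 = helper.make_node("Constant", [], ["cond0"],
    value=helper.make_tensor("cond_t", TensorProto.BOOL, [], [True]))
loop = helper.make_node("Loop", ["trip", "cond0", "x0"], ["x_final"],
                        body=body)
x0 = helper.make_tensor_value_info("x0", TensorProto.FLOAT, [])
x_final = helper.make_tensor_value_info("x_final", TensorProto.FLOAT, [])
graph = helper.make_graph([trip, cond0, loop], "while_as_loop",
                          [x0], [x_final])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.checker.check_model(model)
```
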
diff --git a/paddle2onnx/parser/parser.cc b/paddle2onnx/parser/parser.cc
index e8c6c0d8a..0debb87cd 100755
--- a/paddle2onnx/parser/parser.cc
+++ b/paddle2onnx/parser/parser.cc
@@ -453,19 +453,21 @@ void PaddleParser::GetBlocksOps() {
 TensorInfo PaddleParser::GetTensorInfo(
     const std::string& name,
     const paddle2onnx::framework::proto::BlockDesc& block) const {
-  auto block_idx = block.idx();
-  auto iter = _blocks_var_name2id[block_idx].find(name);
-  if (iter == _blocks_var_name2id[block_idx].end()) {
-    if (block_idx == 0) {
-      Assert(false,
-             "Cannot find " + name + " in _blocks_var_name2id(global block).");
-    } else {
-      block_idx = block.parent_idx();
-      iter = _blocks_var_name2id[block_idx].find(name);
-      Assert(iter != _blocks_var_name2id[block_idx].end(),
-             "Cannot find " + name + " in _blocks_var_name2id(parent block).");
+  int32_t block_idx = block.idx();
+  bool is_find = false;
+  auto iter = _blocks_var_name2id[block_idx].begin();
+  do {
+    iter = _blocks_var_name2id[block_idx].find(name);
+    if (iter != _blocks_var_name2id[block_idx].end()) {
+      is_find = true;
+      break;
     }
-  }
+    block_idx--;
+  } while (block_idx >= 0);
+
+  Assert(is_find,
+         "Cannot find " + name + " in _blocks_var_name2id(global block).");
+
   auto var_idx = iter->second;
 
   // Dangerous conversion, lod tensor array is under limited supporting
@@ -582,9 +584,9 @@ std::vector<TensorInfo> PaddleParser::GetOpAttrVar(
   std::vector<TensorInfo> inputs;
   for (auto i = 0; i < op.attrs_size(); ++i) {
     if (op.attrs(i).name() == name) {
-      Assert(IsAttrVar(op, i), "Required AttrVar: " + name +
-                                   " type is Variable in operator: " +
-                                   op.type());
+      Assert(IsAttrVar(op, i),
+             "Required AttrVar: " + name +
+                 " type is Variable in operator: " + op.type());
       // Case 1: Attribute is a single Var
       if (op.attrs(i).has_var_name()) {
         inputs.push_back(GetTensorInfo(op.attrs(i).var_name(), block));
@@ -630,8 +632,8 @@ void PaddleParser::GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op,
       found = true;
       if (IsAttrVar(op, i)) break;
       Assert(op.attrs(i).has_i() || op.attrs(i).has_l(),
-             "Cannot find int32/int64 data from attr: " + name + " in op:" +
-                 op.type());
+             "Cannot find int32/int64 data from attr: " + name +
+                 " in op:" + op.type());
       if (op.attrs(i).has_i()) {
         *res = (int64_t)(op.attrs(i).i());
       } else {
@@ -728,8 +730,8 @@ void PaddleParser::GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op,
       found = true;
       if (IsAttrVar(op, i)) break;
       Assert(op.attrs(i).floats_size() >= 0,
-             "Cannot find list of float data from attr: " + name + " in op: " +
-                 op.type());
+             "Cannot find list of float data from attr: " + name +
+                 " in op: " + op.type());
       for (auto j = 0; j < op.attrs(i).floats_size(); ++j) {
         res->push_back(static_cast<float>(op.attrs(i).floats(j)));
       }
@@ -739,7 +741,6 @@ void PaddleParser::GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op,
   Assert(found, "Cannot found attribute " + name + " in op: " + op.type());
 }
 
-
 void PaddleParser::GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op,
                              const std::string& name,
                              std::vector<double>* res) const {
@@ -750,8 +751,8 @@ void PaddleParser::GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op,
       found = true;
      if (IsAttrVar(op, i)) break;
       Assert(op.attrs(i).float64s_size() >= 0,
-             "Cannot find list of double data from attr: " + name + " in op: " +
-                 op.type());
+             "Cannot find list of double data from attr: " + name +
+                 " in op: " + op.type());
       for (auto j = 0; j < op.attrs(i).float64s_size(); ++j) {
         res->push_back(static_cast<double>(op.attrs(i).float64s(j)));
       }
attr: " + name + + " in op: " + op.type()); for (auto j = 0; j < op.attrs(i).float64s_size(); ++j) { res->push_back(static_cast(op.attrs(i).float64s(j))); } @@ -770,8 +771,8 @@ void PaddleParser::GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op, found = true; if (IsAttrVar(op, i)) break; Assert(op.attrs(i).bools_size() >= 0, - "Cannot find list of double data from attr: " + name + " in op: " + - op.type()); + "Cannot find list of double data from attr: " + name + + " in op: " + op.type()); for (auto j = 0; j < op.attrs(i).bools_size(); ++j) { res->push_back(static_cast(op.attrs(i).bools(j))); } @@ -881,30 +882,31 @@ bool PaddleParser::ExistsDumplicateTensorName() const { return false; } -#define DECLARE_GET_OP_SCALARS(scalar_type, target_type) \ -template <> \ -void PaddleParser::GetOpScalarsAttr(const paddle2onnx::framework::proto::OpDesc& op, \ - const std::string& name, \ - std::vector* res) const { \ - bool found = false; \ - res->clear(); \ - for (auto i = 0; i < op.attrs_size(); ++i) { \ - if (op.attrs(i).name() == name) { \ - found = true; \ - if (IsAttrVar(op, i)) break; \ - Assert(op.attrs(i).scalars_size() >= 0, \ - "Cannot find list of scalars data from attr: " + name + \ - " in op: " + op.type()); \ - for (auto j = 0; j < op.attrs(i).scalars_size(); ++j) { \ - Assert(op.attrs(i).scalars(j).has_##scalar_type(), \ - "Scalar type does not match with " #scalar_type); \ - res->push_back(static_cast(op.attrs(i).scalars(j).scalar_type())); \ - } \ - break; \ - } \ - } \ - Assert(found, "Cannot found attribute " + name + " in op: " + op.type()); \ -} +#define DECLARE_GET_OP_SCALARS(scalar_type, target_type) \ + template <> \ + void PaddleParser::GetOpScalarsAttr( \ + const paddle2onnx::framework::proto::OpDesc& op, \ + const std::string& name, std::vector* res) const { \ + bool found = false; \ + res->clear(); \ + for (auto i = 0; i < op.attrs_size(); ++i) { \ + if (op.attrs(i).name() == name) { \ + found = true; \ + if (IsAttrVar(op, i)) break; \ + Assert(op.attrs(i).scalars_size() >= 0, \ + "Cannot find list of scalars data from attr: " + name + \ + " in op: " + op.type()); \ + for (auto j = 0; j < op.attrs(i).scalars_size(); ++j) { \ + Assert(op.attrs(i).scalars(j).has_##scalar_type(), \ + "Scalar type does not match with " #scalar_type); \ + res->push_back( \ + static_cast(op.attrs(i).scalars(j).scalar_type())); \ + } \ + break; \ + } \ + } \ + Assert(found, "Cannot found attribute " + name + " in op: " + op.type()); \ + } DECLARE_GET_OP_SCALARS(i, int64_t) DECLARE_GET_OP_SCALARS(i, int32_t) diff --git a/tests/test_while.py b/tests/test_while.py new file mode 100644 index 000000000..d94e35116 --- /dev/null +++ b/tests/test_while.py @@ -0,0 +1,127 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import paddle
+from onnxbase import APIOnnx
+from onnxbase import randtool
+
+
+class BaseNet1(paddle.nn.Layer):
+    def __init__(self):
+        super(BaseNet1, self).__init__()
+
+    def forward(self, inputs):
+        i = 0
+        while i <= 3:
+            i += 1
+            inputs += 1
+        return inputs
+
+
+def test_while_1():
+    op = BaseNet1()
+    op.eval()
+    obj = APIOnnx(op, "while", [13])
+    obj.set_input_data("input_data", paddle.to_tensor(0))
+    obj.run()
+
+
+class BaseNet2(paddle.nn.Layer):
+    def __init__(self):
+        super(BaseNet2, self).__init__()
+
+    def forward(self, i, inputs):
+        while i <= 3:
+            i += 1
+            inputs += 1
+        return inputs
+
+
+def test_while_2():
+    op = BaseNet2()
+    op.eval()
+    obj = APIOnnx(op, "while", [13])
+    obj.set_input_data("input_data", paddle.to_tensor(0), paddle.to_tensor(0))
+    obj.run()
+
+
+class BaseNet3(paddle.nn.Layer):
+    def __init__(self):
+        super(BaseNet3, self).__init__()
+
+    def forward(self, i, j, k):
+        while i <= 3:
+            j += 1
+            k += 1
+            i += 1
+        return j + k
+
+
+def test_while_3():
+    op = BaseNet3()
+    op.eval()
+    obj = APIOnnx(op, "while", [13])
+    obj.set_input_data("input_data", paddle.to_tensor(0), paddle.to_tensor(0),
+                       paddle.to_tensor(0))
+    obj.run()
+
+
+class BaseNet4(paddle.nn.Layer):
+    def __init__(self):
+        super(BaseNet4, self).__init__()
+
+    def forward(self, i, j, k):
+        while i <= 3:
+            if i < 1:
+                j += 1
+            else:
+                j += 2
+            i += 1
+        return j + k
+
+
+def test_while_4():
+    op = BaseNet4()
+    op.eval()
+    obj = APIOnnx(op, "while", [13])
+    obj.set_input_data("input_data", paddle.to_tensor(0), paddle.to_tensor(0),
+                       paddle.to_tensor(0))
+    obj.run()
+
+
+class BaseNet5(paddle.nn.Layer):
+    def __init__(self):
+        super(BaseNet5, self).__init__()
+
+    def forward(self, i, j, k):
+        while i <= 3:
+            if i < 1:
+                j += 1
+            else:
+                j += 2
+            i += 1
+        return j + k
+
+
+def test_while_5():
+    op = BaseNet5()
+    op.eval()
+    obj = APIOnnx(op, "while", [13])
+    obj.set_input_data("input_data", paddle.to_tensor(0), paddle.to_tensor(0),
+                       paddle.to_tensor(0))
+    obj.run()
+
+
+if __name__ == "__main__":
+    test_while_1()
+    test_while_2()
+    test_while_3()
+    test_while_4()
+    test_while_5()
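
> Editor's note: beyond the `onnxbase`-driven cases above, the new conversion can also be exercised end to end through `paddle.onnx.export`, which routes through paddle2onnx. This is a hedged sketch, assuming `paddle` and `paddle2onnx` are installed and that dygraph-to-static converts the tensor-conditioned loop below into a single `while` op; the layer and file names are invented.

```python
import paddle
from paddle.static import InputSpec


class Counter(paddle.nn.Layer):
    def forward(self, x):
        i = paddle.full([1], 0, dtype="int64")
        # Tensor-valued condition, so dygraph-to-static should emit a real
        # while op instead of unrolling the loop at trace time.
        while i < 4:
            x = x + 1
            i = i + 1
        return x


layer = Counter()
layer.eval()
# Writes counter.onnx; opset >= 13 is required for the Loop-based conversion.
paddle.onnx.export(
    layer,
    "counter",
    input_spec=[InputSpec(shape=[1], dtype="float32", name="x")],
    opset_version=13,
)
```
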