onnx_reader.cc
#include <cstring>
#include <fstream>
#include <iostream>
#include <set>
#include <string>
#include <vector>  // std::vector is used below but was not included directly.

#include "paddle2onnx/converter.h"
#include "paddle2onnx/mapper/exporter.h"
#include "paddle2onnx/optimizer/paddle2onnx_optimizer.h"

namespace paddle2onnx {
// Map an ONNX TensorProto element type to the integer dtype codes used by
// OnnxReader (0: float, 1: double, 2: uint8, 3: int8, 4: int32, 5: int64,
// 6: float16).
int32_t GetDataTypeFromOnnx(int dtype) {
  if (dtype == ONNX_NAMESPACE::TensorProto::FLOAT) {
    return 0;
  } else if (dtype == ONNX_NAMESPACE::TensorProto::DOUBLE) {
    return 1;
  } else if (dtype == ONNX_NAMESPACE::TensorProto::UINT8) {
    return 2;
  } else if (dtype == ONNX_NAMESPACE::TensorProto::INT8) {
    return 3;
  } else if (dtype == ONNX_NAMESPACE::TensorProto::INT32) {
    return 4;
  } else if (dtype == ONNX_NAMESPACE::TensorProto::INT64) {
    return 5;
  } else if (dtype == ONNX_NAMESPACE::TensorProto::FLOAT16) {
    return 6;
  }
  Assert(false,
         "Only support float/double/uint8/int8/int32/int64/float16 in "
         "OnnxReader.");
  return -1;
}
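
// Example sketch: looking up the dtype code of the first graph input, given a
// parsed ONNX_NAMESPACE::ModelProto `model` (the mapping above is assumed to
// mirror the dtype enum expected by consumers of OnnxReader):
//
//   int32_t code = GetDataTypeFromOnnx(
//       model.graph().input(0).type().tensor_type().elem_type());
//   // code == 0 for FLOAT, 5 for INT64, ...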
// Parse an ONNX model from a raw buffer and record the name, dtype, and shape
// of each real graph input and output.
OnnxReader::OnnxReader(const char* model_buffer, int buffer_size) {
  ONNX_NAMESPACE::ModelProto model;
  std::string content(model_buffer, model_buffer + buffer_size);
  Assert(model.ParseFromString(content),
         "Failed to parse the ONNX model from the given buffer.");

  // Graph inputs that are also initializers are weights, not real inputs.
  std::set<std::string> initializer_names;
  for (auto i = 0; i < model.graph().initializer_size(); ++i) {
    initializer_names.insert(model.graph().initializer(i).name());
  }

  num_outputs = model.graph().output_size();
  Assert(num_outputs <= 100,
         "The number of outputs exceeds 100, unexpected situation.");

  num_inputs = 0;
  for (int i = 0; i < model.graph().input_size(); ++i) {
    if (initializer_names.find(model.graph().input(i).name()) !=
        initializer_names.end()) {
      continue;
    }
    num_inputs += 1;
    Assert(num_inputs <= 100,
           "The number of inputs exceeds 100, unexpected situation.");
    // Index by the running count of real inputs, so that skipped
    // initializers do not leave gaps in the inputs array.
    auto& input_info = inputs[num_inputs - 1];
    input_info.dtype = GetDataTypeFromOnnx(
        model.graph().input(i).type().tensor_type().elem_type());
    std::strcpy(input_info.name, model.graph().input(i).name().c_str());
    auto& shape = model.graph().input(i).type().tensor_type().shape();
    int dim_size = shape.dim_size();
    input_info.rank = dim_size;
    input_info.shape = new int64_t[dim_size];
    for (int j = 0; j < dim_size; ++j) {
      input_info.shape[j] = static_cast<int64_t>(shape.dim(j).dim_value());
      // Symbolic or unknown dimensions are normalized to -1.
      if (input_info.shape[j] <= 0) {
        input_info.shape[j] = -1;
      }
    }
  }

  for (int i = 0; i < num_outputs; ++i) {
    std::strcpy(outputs[i].name, model.graph().output(i).name().c_str());
    outputs[i].dtype = GetDataTypeFromOnnx(
        model.graph().output(i).type().tensor_type().elem_type());
    auto& shape = model.graph().output(i).type().tensor_type().shape();
    int dim_size = shape.dim_size();
    outputs[i].rank = dim_size;
    outputs[i].shape = new int64_t[dim_size];
    for (int j = 0; j < dim_size; ++j) {
      outputs[i].shape[j] = static_cast<int64_t>(shape.dim(j).dim_value());
      if (outputs[i].shape[j] <= 0) {
        outputs[i].shape[j] = -1;
      }
    }
  }
}
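
// Usage sketch (assumes the OnnxReader members declared in its header:
// num_inputs, num_outputs, and the inputs/outputs arrays filled above;
// "model.onnx" is a hypothetical path):
//
//   std::ifstream fin("model.onnx", std::ios::binary);
//   std::string buf((std::istreambuf_iterator<char>(fin)),
//                   std::istreambuf_iterator<char>());
//   OnnxReader reader(buf.data(), static_cast<int>(buf.size()));
//   for (int i = 0; i < reader.num_inputs; ++i) {
//     std::cout << reader.inputs[i].name << std::endl;
//   }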
// Rewrite the graph so that the inputs of the first MultiClassNMS node become
// the graph outputs; the NMS node itself is left in place and is expected to
// be pruned as dead code by the optimizer pass below. The serialized result
// is returned through out_model/out_model_size; the caller owns the buffer.
bool RemoveMultiClassNMS(const char* model_buffer, int buffer_size,
                         char** out_model, int* out_model_size) {
  ONNX_NAMESPACE::ModelProto model;
  std::string content(model_buffer, model_buffer + buffer_size);
  Assert(model.ParseFromString(content),
         "Failed to parse the ONNX model from the given buffer.");
  auto* graph = model.mutable_graph();
  int nms_index = -1;
  std::vector<std::string> nms_inputs;
  for (int i = 0; i < graph->node_size(); ++i) {
    if (graph->node(i).op_type() == "MultiClassNMS") {
      nms_index = i;
      for (int j = 0; j < graph->node(i).input_size(); ++j) {
        nms_inputs.push_back(graph->node(i).input(j));
      }
      break;
    }
  }
  if (nms_index < 0) {
    // No MultiClassNMS node found; do not clear the outputs of an unrelated
    // model.
    return false;
  }
  // Replace the original outputs with the NMS inputs. Their exact shapes are
  // unknown here, so each is declared as a rank-3 float tensor with dynamic
  // dimensions.
  graph->clear_output();
  for (size_t i = 0; i < nms_inputs.size(); ++i) {
    auto output = graph->add_output();
    output->set_name(nms_inputs[i]);
    auto type_proto = output->mutable_type();
    auto tensor_type_proto = type_proto->mutable_tensor_type();
    tensor_type_proto->set_elem_type(ONNX_NAMESPACE::TensorProto::FLOAT);
    auto shape = tensor_type_proto->mutable_shape();
    shape->add_dim()->set_dim_value(-1);
    shape->add_dim()->set_dim_value(-1);
    shape->add_dim()->set_dim_value(-1);
  }
  auto optimized_model = ONNX_NAMESPACE::optimization::OptimizeOnnxModel(model);
  *out_model_size = optimized_model.ByteSizeLong();
  *out_model = new char[*out_model_size];
  optimized_model.SerializeToArray(*out_model, *out_model_size);
  return true;
}
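
// Usage sketch (the caller owns the returned buffer and must delete[] it;
// `buf` holds the serialized input model as in the OnnxReader example above):
//
//   char* pruned = nullptr;
//   int pruned_size = 0;
//   if (RemoveMultiClassNMS(buf.data(), static_cast<int>(buf.size()),
//                           &pruned, &pruned_size)) {
//     // pruned/pruned_size now hold the serialized model without NMS.
//     delete[] pruned;
//   }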
} // namespace paddle2onnx