Skip to content

Commit fe3279e

Browse files
authored
Initialize operation arguments with ONNX model constants (llvm#8)
* Save current state.
* Include constant arguments in source.
* Emit constants for Reshape second argument.
* Clean-up code.
* Add changes to gen_doc.py file.
* Propagate constant tensor to Reshape second arg to infer shape.
* Update documentation.
* Eliminate constant tensor operations when lowering to KRNL dialect.
* Replace ConstantTensorOp with ConstantOp.
* Add comment to remove temporary Constant lowering code.
* Remove unused shape inference for Constant.
* Remove comment.
* Remove explicit constant elimination.
* Refactor code.
1 parent ba02b90 commit fe3279e

File tree

11 files changed

+383
-126
lines changed

11 files changed

+383
-126
lines changed

doc/gen_doc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
("MaxPool", "ImportNodeMaxPool"),
3737
("BatchNormalization", "ImportNodeBatchNormalization"),
3838
("Pad", "ImportNodePad"),
39+
("Reshape", "ImportNodeReshape"),
3940
#("Transpose", "ImportNodeTranspose")
4041
])
4142

src/builder/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
add_library(builder
2+
frontend_dialect_helper.cpp
3+
frontend_dialect_helper.hpp
24
frontend_dialect_transformer.cpp
35
frontend_dialect_transformer.hpp
46
op_build_table.inc
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
//===------------------- frontend_dialect_helper.cpp ----------------------===//
2+
//
3+
// Copyright 2019 The IBM Research Authors.
4+
//
5+
// =============================================================================
6+
//
7+
// Helper methods for handling input ONNX models.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
#include "src/builder/frontend_dialect_helper.hpp"
12+
13+
namespace onnf {
14+
15+
// Replace every occurrence of `from` within `str` with `to`, in place.
// Scanning resumes just past each inserted replacement, so a `to` that
// contains `from` (e.g. replacing "x" with "yx") cannot loop forever.
void replaceAll(std::string &str, const std::string &from,
    const std::string &to) {
  if (from.empty())
    return;
  for (size_t pos = str.find(from); pos != std::string::npos;
       pos = str.find(from, pos)) {
    str.replace(pos, from.size(), to);
    pos += to.size();
  }
}
26+
27+
// Turn an ONNX tensor name into a legal C++-style identifier:
// '/' and '-' become '_', ':' becomes "_colon_", and a leading digit
// gets an 'n' prepended.
std::string legalize_name(std::string name) {
  for (char &c : name) {
    if (c == '/' || c == '-')
      c = '_';
  }
  replaceAll(name, ":", "_colon_");
  // If tensor name starts with a number, prepend n to make it a legal c++
  // identifier.
  if (!name.empty() && isdigit(name.front()))
    name.insert(name.begin(), 'n');
  return name;
}
37+
38+
// Look up the MLIR value previously registered for `name`. The name is
// legalized before the lookup so callers may pass the raw ONNX name.
// Asserts (debug builds) when no mapping exists.
mlir::Value OnnxOnnfSymbolMapping::GetTensorByOnnxName(
    const std::string &name) {
  const std::string key = legalize_name(name);
  auto iter = onnx_name2onnf_tensor.find(key);
  assert(iter != onnx_name2onnf_tensor.end() && "Tensor not found");
  return iter->second;
}
45+
46+
void OnnxOnnfSymbolMapping::AddMapping(
47+
const std::string &name, mlir::Value tensor) {
48+
assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
49+
"Tensor already exists.");
50+
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
51+
}
52+
53+
bool OnnxOnnfSymbolMapping::ContainKey(std::string name) {
54+
return onnx_name2onnf_tensor.count(name) != 0;
55+
}
56+
57+
template <typename T>
58+
struct TransformValueToONNXData {
59+
static const google::protobuf::RepeatedField<T> data(
60+
onnx::TensorProto initializer) {
61+
return google::protobuf::RepeatedField<T>();
62+
}
63+
};
64+
65+
template <>
66+
struct TransformValueToONNXData<double> {
67+
static const google::protobuf::RepeatedField<double> data(
68+
onnx::TensorProto initializer) {
69+
return initializer.double_data();
70+
}
71+
};
72+
73+
template <>
74+
struct TransformValueToONNXData<float> {
75+
static const google::protobuf::RepeatedField<float> data(
76+
onnx::TensorProto initializer) {
77+
return initializer.float_data();
78+
}
79+
};
80+
81+
template <>
82+
struct TransformValueToONNXData<int32_t> {
83+
static const google::protobuf::RepeatedField<int32_t> data(
84+
onnx::TensorProto initializer) {
85+
return initializer.int32_data();
86+
}
87+
};
88+
89+
template <>
90+
struct TransformValueToONNXData<int64_t> {
91+
static const google::protobuf::RepeatedField<int64_t> data(
92+
onnx::TensorProto initializer) {
93+
return initializer.int64_data();
94+
}
95+
};
96+
97+
// Helper method for constructing an array attribute from a model input.
98+
template <typename T>
99+
static T* CreateArrayAttribute(onnx::TensorProto initializer, int *size) {
100+
if (initializer.raw_data().size()) {
101+
// copy & take care of endianness
102+
std::vector<char> byteInitializer;
103+
std::copy(initializer.raw_data().begin(), initializer.raw_data().end(),
104+
back_inserter(byteInitializer));
105+
*size = initializer.raw_data().size() / sizeof(T);
106+
return reinterpret_cast<T*>(&byteInitializer[0]);
107+
}
108+
109+
// copy, no need to take care of endianness
110+
auto data = TransformValueToONNXData<T>::data(initializer);
111+
*size = data.size();
112+
return &data[0];
113+
}
114+
115+
void InitializedTensorMapping::AddMapping(
116+
std::string name, onnx::TensorProto tensor) {
117+
assert(nameToInitializedTensor.count(name) == 0 &&
118+
"Tensor initializer already mapped.");
119+
nameToInitializedTensor.emplace(name, tensor);
120+
}
121+
122+
123+
bool InitializedTensorMapping::ContainKey(std::string name) {
124+
return nameToInitializedTensor.count(name) != 0;
125+
}
126+
127+
mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
128+
mlir::Location loc, mlir::OpBuilder &builder, std::string name) {
129+
// Initializer for input.
130+
onnx::TensorProto initializer = GetInitializedTensor(name);
131+
132+
// Emit ConstantOp and record the mapping between the input and
133+
// the constant value.
134+
mlir::ArrayAttr constantArrayAttribute;
135+
mlir::Type elementType;
136+
int length;
137+
switch (initializer.data_type()) {
138+
case (onnx::TensorProto::FLOAT): {
139+
float *typeArray =
140+
CreateArrayAttribute<float>(initializer, &length);
141+
std::vector<float> arrayAttrInitializer(
142+
typeArray, typeArray + length);
143+
llvm::ArrayRef<float> array(typeArray, length);
144+
constantArrayAttribute = builder.getF32ArrayAttr(array);
145+
elementType = builder.getF32Type();
146+
break;
147+
}
148+
case (onnx::TensorProto::INT32): {
149+
int32_t *typeArray =
150+
CreateArrayAttribute<int32_t>(initializer, &length);
151+
std::vector<int32_t> arrayAttrInitializer(
152+
typeArray, typeArray + length);
153+
llvm::ArrayRef<int32_t> array(typeArray, length);
154+
constantArrayAttribute = builder.getI32ArrayAttr(array);
155+
elementType = builder.getIntegerType(32);
156+
break;
157+
}
158+
case (onnx::TensorProto::INT64): {
159+
int64_t *typeArray =
160+
CreateArrayAttribute<int64_t>(initializer, &length);
161+
std::vector<int64_t> arrayAttrInitializer(
162+
typeArray, typeArray + length);
163+
llvm::ArrayRef<int64_t> array(typeArray, length);
164+
constantArrayAttribute = builder.getI64ArrayAttr(array);
165+
elementType = builder.getIntegerType(64);
166+
break;
167+
}
168+
}
169+
170+
// Create empty sparse_value attribute.
171+
llvm::ArrayRef<int64_t> array;
172+
auto sparseValueAttribute = builder.getI64ArrayAttr(array);
173+
174+
// Create value attribute.
175+
llvm::ArrayRef<int64_t> tensorDims(initializer.dims().data(),
176+
initializer.dims().size());
177+
mlir::Type tensorType =
178+
mlir::RankedTensorType::get(tensorDims, elementType);
179+
180+
return builder.create<mlir::ONNXConstantOp>(
181+
loc, tensorType, sparseValueAttribute,
182+
constantArrayAttribute);
183+
}
184+
185+
} // namespace onnf
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
//===------------------- frontend_dialect_helper.hpp ----------------------===//
2+
//
3+
// Copyright 2019 The IBM Research Authors.
4+
//
5+
// =============================================================================
6+
//
7+
// Helper methods for handling input ONNX models.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
#pragma once
12+
13+
#include <numeric>
14+
#include <regex>
15+
#include <tuple>
16+
17+
#include "mlir/Analysis/Verifier.h"
18+
#include "mlir/Dialect/StandardOps/Ops.h"
19+
#include "mlir/IR/Attributes.h"
20+
#include "mlir/IR/Builders.h"
21+
#include "mlir/IR/Function.h"
22+
#include "mlir/IR/Location.h"
23+
#include "mlir/IR/Matchers.h"
24+
#include "mlir/IR/MLIRContext.h"
25+
#include "mlir/IR/Module.h"
26+
#include "mlir/IR/PatternMatch.h"
27+
#include "mlir/IR/StandardTypes.h"
28+
#include "mlir/IR/Types.h"
29+
30+
#include "llvm/ADT/STLExtras.h"
31+
#include "llvm/ADT/ScopedHashTable.h"
32+
#include "llvm/Support/raw_ostream.h"
33+
34+
#include "src/dialect/onnx/onnx_ops.hpp"
35+
#include "onnx/onnx_pb.h"
36+
37+
namespace onnf {
38+
39+
void replaceAll(std::string &str, const std::string &from,
40+
const std::string &to);
41+
42+
std::string legalize_name(std::string name);
43+
44+
/*!
 * Symbol table mapping ONNX tensor names to the MLIR values that
 * represent them while importing a model.
 */
struct OnnxOnnfSymbolMapping {
  /*!
   * Get MLIR tensor by onnx tensor name.
   * Asserts (debug builds) if no mapping for `name` exists.
   * @param name onnx tensor name.
   * @return onnf tensor corresponding to `name`.
   */
  mlir::Value GetTensorByOnnxName(const std::string &name);

  /*!
   * Add a new mapping from onnx tensor name to MLIR symbol.
   * Asserts (debug builds) if a mapping for `name` already exists.
   * @param name onnx tensor name.
   * @param tensor MLIR Value pointer.
   */
  void AddMapping(const std::string &name, mlir::Value tensor);

  /*!
   * Check whether a mapping for `name` is present.
   * @param name lookup key.
   * @return true when a tensor is registered under `name`.
   */
  bool ContainKey(std::string name);

private:
  /*!
   * mapping from onnx tensor names to MLIR tensor.
   */
  std::map<std::string, mlir::Value> onnx_name2onnf_tensor;
};
67+
68+
// Records the initializers (constant tensors) of an ONNX model, keyed
// by input name, so they can be emitted on demand as constant ops.
struct InitializedTensorMapping {
  // Add new entry. Asserts (debug builds) if `name` is already mapped.
  void AddMapping(std::string name, onnx::TensorProto tensor);

  // Check if input is initialized. Not all inputs are, some of the inputs
  // require input from the user and are not stored inside the ONNX model
  // itself.
  bool ContainKey(std::string name);

  // Emit constant argument (initialized arguments) as a ConstantOp.
  // This method will allow operations to use the constant data contained
  // in an ONNX model as they are being compiled.
  // This method enables the emission of such constant operation on demand.
  //
  // This will allow the propagation of shape information passed in as an
  // argument to operations such as Reshape and will enable other
  // optimizations such as constant folding.
  mlir::Value EmitInitializerForInputTensor(mlir::Location loc,
      mlir::OpBuilder &builder, std::string name);

  // Get initialized tensor. Asserts (debug builds) if `name` has no
  // initializer; guard with ContainKey when absence is expected.
  onnx::TensorProto& GetInitializedTensor(std::string name) {
    assert(nameToInitializedTensor.find(name) !=
           nameToInitializedTensor.end() &&
           "Tensor initializer not found");
    return nameToInitializedTensor.at(name);
  }

private:
  // Mapping from ONNX tensor name to InitializedTensor.
  std::map<std::string, onnx::TensorProto> nameToInitializedTensor;
};
100+
101+
} // namespace onnf

0 commit comments

Comments
 (0)