Skip to content

Commit 766cbcf

Browse files
FEAT-#128: oneDNN: Implement compilation and execution
1 parent eb43c64 commit 766cbcf

File tree

11 files changed

+283
-152
lines changed

11 files changed

+283
-152
lines changed

CMakeLists.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,13 @@ if(GC_ENABLE_BINDINGS_PYTHON)
9595
endif()
9696

9797
set(GC_LIB_LINKED_LIBS
98-
GCPasses
99-
MLIROneDNNGraph
98+
GCJitWrapper
99+
GCCpuRuntime
100100
)
101-
add_library(graph_compiler SHARED ${GC_LIB_SOURCES})
101+
add_mlir_library(graph_compiler SHARED ${GC_LIB_SOURCES})
102102
target_include_directories(graph_compiler PUBLIC ${GC_LIB_INCLUDES})
103-
target_compile_options(graph_compiler PRIVATE -fvisibility=hidden)
103+
target_compile_options(graph_compiler PRIVATE -fvisibility=hidden -fexceptions)
104104
target_link_options(graph_compiler PRIVATE -Wl,--gc-sections)
105-
target_link_libraries(graph_compiler PRIVATE ${GC_LIB_LINKED_LIBS})
105+
target_link_libraries(graph_compiler PUBLIC ${GC_LIB_LINKED_LIBS})
106106

107107
add_subdirectory(test)

src/dnnl/JsonParser.cpp

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,18 @@
2323

2424
#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
2525

26+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
2627
#include "mlir/ExecutionEngine/OptUtils.h"
2728
#include "mlir/IR/Builders.h"
2829
#include "mlir/IR/MLIRContext.h"
2930
#include "mlir/InitAllPasses.h"
3031

3132
#include "JsonParser.h"
3233

33-
mlir::ModuleOp JsonParser::parse() {
34-
std::vector<size_t> inputPorts;
34+
mlir::ModuleOp
35+
JsonParser::parse(llvm::SmallVector<size_t> &outputIds,
36+
std::unordered_map<std::size_t, Strides> &strides) {
37+
llvm::SmallVector<size_t> inputPorts;
3538
bool hasInputPorts = false;
3639
bool hasOutputPorts = false;
3740
_reader.begin_object();
@@ -57,7 +60,7 @@ mlir::ModuleOp JsonParser::parse() {
5760
readNumArray(inputPorts);
5861
} else if (_str == "output_ports") {
5962
hasOutputPorts = true;
60-
readNumArray(_outputIds);
63+
readNumArray(outputIds);
6164
} else if (_str == "graph") {
6265
_reader.begin_array();
6366
while (_reader.next_array_item()) {
@@ -87,13 +90,13 @@ mlir::ModuleOp JsonParser::parse() {
8790

8891
if (!hasOutputPorts) {
8992
// If output_ports is not specified, using the last operation's outputs.
90-
_outputIds = _uaS;
93+
outputIds.append(_uaS.begin(), _uaS.end());
9194
}
9295

9396
// The function return values.
94-
std::vector<mlir::Value> outputs;
95-
outputs.reserve(_outputIds.size());
96-
for (auto id : _outputIds) {
97+
llvm::SmallVector<mlir::Value> outputs;
98+
outputs.reserve(outputIds.size());
99+
for (auto id : outputIds) {
97100
auto entry = _valueMap.find(id);
98101
if (entry == _valueMap.end()) {
99102
_str = std::to_string(id);
@@ -103,13 +106,25 @@ mlir::ModuleOp JsonParser::parse() {
103106
}
104107
auto ret = _builder.create<mlir::func::ReturnOp>(_loc, outputs);
105108

109+
// Copying the strides for the inputs and outputs.
110+
for (auto &ids : {&_inputIds, &outputIds}) {
111+
for (auto id : *ids) {
112+
auto entry = _strides.find(id);
113+
if (entry != _strides.end()) {
114+
strides[id] = entry->second;
115+
}
116+
}
117+
}
118+
106119
// Creating the final function and moving the entry block.
107120
mlir::OpBuilder builder(_builder.getContext());
108121
auto module = builder.create<mlir::ModuleOp>(_loc);
109122
auto func = builder.create<mlir::func::FuncOp>(
110-
_loc, "main",
123+
_loc, "compute",
111124
builder.getFunctionType(_entryBlock->getArgumentTypes(),
112125
ret->getOperandTypes()));
126+
func->setAttr(mlir::LLVM::LLVMDialect::getEmitCWrapperAttrName(),
127+
mlir::UnitAttr::get(_builder.getContext()));
113128
auto entry = func.addEntryBlock();
114129
_entryBlock->moveBefore(entry);
115130
entry->erase();
@@ -251,7 +266,9 @@ inline mlir::Attribute JsonParser::readAttr() {
251266

252267
mlir::Type JsonParser::readTensorType() {
253268
GetTypeFn getTypeFn = nullptr;
269+
bool strided = false;
254270
_ia64.clear();
271+
_ia642.clear();
255272
_reader.begin_object();
256273

257274
while (_reader.next_object_item(&_str)) {
@@ -267,22 +284,17 @@ mlir::Type JsonParser::readTensorType() {
267284
} else if (_str == "shape") {
268285
readNumArray(_ia64);
269286
} else if (_str == "stride") {
270-
_ia642.clear();
271287
readNumArray(_ia642);
272-
if ((_ia642.size() > 1) ||
273-
((_ia642.size() == 1) &&
274-
(_ia642[0] != std::numeric_limits<int64_t>::min()))) {
275-
// TODO: Add support for strides
276-
throwErr<std::logic_error>("Unsupported stride value: ");
277-
}
278288
} else if (_str == "layout_type") {
279289
_reader.read_string(&_str);
280-
if ((_str != "undef") && (_str != "any")) {
290+
if (_str == "strided") {
291+
strided = true;
292+
} else if ((_str != "undef") && (_str != "any")) {
281293
throwErr<std::logic_error>("Unsupported layout_type: ");
282294
}
283295
} else if (_str == "property_type") {
284296
_reader.read_string(&_str);
285-
if ((_str != "undef") && (_str != "constant")) {
297+
if ((_str != "undef") && (_str != "variable") && (_str != "constant")) {
286298
throwErr<std::logic_error>("Unsupported property_type: ");
287299
}
288300
} else {
@@ -295,6 +307,10 @@ mlir::Type JsonParser::readTensorType() {
295307
throwErr<std::invalid_argument>("dtype is not specified");
296308
}
297309

310+
if (strided) {
311+
_strides[_uS].assign(_ia642.begin(), _ia642.end());
312+
}
313+
298314
if ((_ia64.size() == 1) &&
299315
(_ia64[0] == std::numeric_limits<int64_t>::min())) {
300316
return mlir::UnrankedTensorType::get(getTypeFn(_builder));

src/dnnl/JsonParser.h

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -43,22 +43,25 @@ using float32_t = float;
4343
#include "mlir/Parser/Parser.h"
4444
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
4545

46+
#include "dnnl_types.h"
4647
#include "graph/utils/json.hpp"
4748

49+
using Strides = llvm::SmallVector<int64_t, DNNL_MAX_NDIMS>;
50+
4851
class JsonParser {
4952
dnnl::impl::graph::utils::json::json_reader_t _reader;
5053
mlir::OpBuilder _builder;
5154
mlir::Location _loc;
5255
mlir::Block *_entryBlock;
53-
std::vector<size_t> &_inputIds;
54-
std::vector<size_t> &_outputIds;
56+
llvm::SmallVector<size_t> &_inputIds;
57+
std::unordered_map<std::size_t, Strides> _strides;
5558
// Function input and operations output values. Used to connect the
5659
// operations inputs and outputs.
5760
std::unordered_map<std::size_t, mlir::Value> _valueMap;
5861
// Temporary value holders, used by the parser
59-
std::vector<mlir::Value> _operands;
60-
std::vector<mlir::Type> _resultTypes;
61-
std::vector<mlir::NamedAttribute> _attributes;
62+
llvm::SmallVector<mlir::Value> _operands;
63+
llvm::SmallVector<mlir::Type> _resultTypes;
64+
llvm::SmallVector<mlir::NamedAttribute> _attributes;
6265
std::string _str;
6366
std::string _str2;
6467
std::size_t _uS;
@@ -70,9 +73,9 @@ class JsonParser {
7073
std::vector<std::float32_t> _fa32;
7174

7275
JsonParser(mlir::MLIRContext &context, std::istream &stream,
73-
std::vector<size_t> &inputIds, std::vector<size_t> &outputIds)
76+
llvm::SmallVector<size_t> &inputIds)
7477
: _reader(&stream), _builder(&context), _loc(_builder.getUnknownLoc()),
75-
_inputIds(inputIds), _outputIds(outputIds), _valueMap(), _operands(),
78+
_inputIds(inputIds), _strides(), _valueMap(), _operands(),
7679
_resultTypes(), _attributes(), _str(), _str2(), _uS(), _i64(), _f32(),
7780
_uaS(), _ia64(), _ia642(), _fa32() {
7881
// Creating a dummy function since we don't know the actual type yet.
@@ -82,7 +85,8 @@ class JsonParser {
8285
_builder.setInsertionPointToStart(_entryBlock);
8386
}
8487

85-
mlir::ModuleOp parse();
88+
mlir::ModuleOp parse(llvm::SmallVector<size_t> &outputIds,
89+
std::unordered_map<std::size_t, Strides> &strides);
8690
void readOp();
8791
mlir::Attribute readAttr();
8892
mlir::Type readTensorType();
@@ -120,11 +124,12 @@ class JsonParser {
120124
}
121125
}
122126

123-
template <typename T> inline void readNumArray(std::vector<T> &vec) {
127+
template <typename T, template <typename...> class Container, typename... Any>
128+
inline void readNumArray(Container<T, Any...> &c) {
124129
_reader.begin_array();
125130
for (T value; _reader.next_array_item();) {
126131
_reader.read_number(&value);
127-
vec.push_back(value);
132+
c.push_back(value);
128133
}
129134
}
130135

@@ -175,14 +180,16 @@ class JsonParser {
175180
* @param json JSON string containing the oneDNN graph.
176181
* @param inputIds Input tensor IDs are added to this vector.
177182
* @param outputIds Output tensor IDs are added to this vector.
183+
* @param strides Strides for each tensor are added to this map.
178184
* @return The resulting MLIR module.
179185
*/
180-
static mlir::ModuleOp parse(mlir::MLIRContext &context,
181-
const std::string_view &json,
182-
std::vector<size_t> &inputIds,
183-
std::vector<size_t> &outputIds) {
186+
static mlir::ModuleOp
187+
parse(mlir::MLIRContext &context, const std::string_view &json,
188+
llvm::SmallVector<size_t> &inputIds,
189+
llvm::SmallVector<size_t> &outputIds,
190+
std::unordered_map<std::size_t, Strides> &strides) {
184191
std::istringstream stream(json.data());
185-
JsonParser parser(context, stream, inputIds, outputIds);
186-
return parser.parse();
192+
JsonParser parser(context, stream, inputIds);
193+
return parser.parse(outputIds, strides);
187194
}
188195
};

0 commit comments

Comments
 (0)