Skip to content

Commit 1f1eef9

Browse files
FEAT-#145: oneDNN: Support for existing onednn_graph dialect ops. (#133)
* FEAT-#145: oneDNN: Support for existing onednn_graph dialect ops. Closes #145.
1 parent 5c678d9 commit 1f1eef9

17 files changed

+799
-6
lines changed

src/dnnl/JsonParser.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,11 @@ inline mlir::Attribute JsonParser::readAttr() {
245245
} else if (_str == "s64[]") {
246246
_ia64.clear();
247247
readNumArray(_ia64);
248-
attr = _builder.getI64ArrayAttr(_ia64);
248+
attr = _builder.getDenseI64ArrayAttr(_ia64);
249249
} else if (_str == "f32[]") {
250250
_fa32.clear();
251251
readNumArray(_fa32);
252-
attr = _builder.getF32ArrayAttr(_fa32);
252+
attr = _builder.getDenseF32ArrayAttr(_fa32);
253253
} else if (_str == "string") {
254254
_reader.read_string(&_str);
255255
attr = _builder.getStringAttr(_str);

src/dnnl/JsonParser.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,13 @@
2727
#include <stdfloat>
2828
#else
2929
namespace std {
30+
#if defined(__SIZEOF_FLOAT__) && __SIZEOF_FLOAT__ == 4
3031
using float32_t = float;
32+
#elif defined(__SIZEOF_DOUBLE__) && __SIZEOF_DOUBLE__ == 4
33+
using float32_t = double;
34+
#else
35+
static_assert(false, "Unable to determine 32-bit floating point type");
36+
#endif
3137
} // namespace std
3238
#endif
3339

@@ -145,8 +151,16 @@ class JsonParser {
145151
}
146152
std::unordered_map<std::string, OpBuilderFn> _opBuilders{
147153
GC_OP("Add", mlir::onednn_graph::AddOp),
154+
GC_OP("Divide", mlir::onednn_graph::DivOp),
148155
GC_OP("MatMul", mlir::onednn_graph::MatMulOp),
156+
GC_OP("Multiply", mlir::onednn_graph::MulOp),
157+
GC_OP("Pow", mlir::onednn_graph::PowOp),
158+
GC_OP("ReduceMean", mlir::onednn_graph::ReduceMeanOp),
159+
GC_OP("ReduceSum", mlir::onednn_graph::ReduceSumOp),
149160
GC_OP("ReLU", mlir::onednn_graph::ReLUOp),
161+
GC_OP("Sigmoid", mlir::onednn_graph::SigmoidOp),
162+
GC_OP("Subtract", mlir::onednn_graph::SubOp),
163+
GC_OP("Typecast", mlir::onednn_graph::TypeCastOp),
150164
};
151165
#undef GC_OP
152166

test/dnnl/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ foreach (TEST_SOURCE ${TEST_SOURCES})
1212
target_include_directories(${TEST_NAME} PRIVATE ${GC_LIB_INCLUDES})
1313
if (${TEST_NAME} MATCHES "^TestApi.*")
1414
# The API tests are linked with the shared lib
15-
target_link_libraries(${TEST_NAME} PRIVATE graph_compiler)
15+
target_link_libraries(${TEST_NAME} PRIVATE LLVMSupport graph_compiler)
1616
else ()
1717
# The other tests are linked with the static lib and have non-public includes
1818
target_link_libraries(${TEST_NAME} PRIVATE graph_compiler_static)

test/dnnl/DnnlTestUtils.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,21 @@
2121
#include <sstream>
2222
#include <string>
2323

24-
static std::string read_str_resource(const std::string &name) {
24+
#if __cplusplus > 202002L
25+
#include <stdfloat>
26+
#else
27+
namespace std {
28+
#if defined(__SIZEOF_FLOAT__) && __SIZEOF_FLOAT__ == 4
29+
using float32_t = float;
30+
#elif defined(__SIZEOF_DOUBLE__) && __SIZEOF_DOUBLE__ == 4
31+
using float32_t = double;
32+
#else
33+
static_assert(false, "No 32-bit floating point type available");
34+
#endif
35+
} // namespace std
36+
#endif
37+
38+
static std::string readStrResource(const std::string &name) {
2539
std::filesystem::path res_dir{"resources"};
2640
auto path = std::filesystem::absolute(res_dir / name);
2741
std::ifstream file(path);

test/dnnl/TestApiBasic.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#include "graph/backend/elyzor/include/dnnl_graph_compiler.h"
2424

2525
TEST(TestApiBasic, basicWorkflow) {
26-
auto json = read_str_resource("add.json");
26+
auto json = readStrResource("add.json");
2727

2828
const struct dnnl_graph_compiler_context ctx = {.num_threads = 4};
2929
const struct dnnl_graph_compiler *gc;

0 commit comments

Comments
 (0)