Skip to content

Commit 837696b

Browse files
author
Menooker
authored
Add all-in-one driver for compile-and-execute (#97)
1 parent f072ff0 commit 837696b

File tree

10 files changed

+278
-2
lines changed

10 files changed

+278
-2
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
//===-- Driver.h - The top-level MLIR compiler driver -----------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef GC_EXECUTIONENGINE_DRIVER_DRIVER_H
10+
#define GC_EXECUTIONENGINE_DRIVER_DRIVER_H
11+
12+
#include "mlir/ExecutionEngine/CRunnerUtils.h"
13+
#include "mlir/ExecutionEngine/ExecutionEngine.h"
14+
#include <memory>
15+
#include <string_view>
16+
17+
namespace mlir {
18+
class DialectRegistry;
19+
namespace gc {
20+
21+
const DialectRegistry &initCompilerAndGetDialects();
22+
23+
// the pointers to XXXMemRefType
24+
using GeneralMemrefPtr = void *;
25+
using JitModuleFuncT = void (*)(void **);
26+
27+
struct DriverOptions {
28+
/// the optimization level for the LLVM-JIT
29+
llvm::CodeGenOptLevel jitCodeGenOptLevel = llvm::CodeGenOptLevel::Aggressive;
30+
/// whether to run the MLIR transformation passes
31+
bool runTransforms = true;
32+
/// todo: target machine, etc.
33+
};
34+
35+
class JitModule {
36+
public:
37+
static llvm::Expected<std::shared_ptr<JitModule>>
38+
create(Operation *op, const DriverOptions &options = {});
39+
40+
/// args should be an array of XXXMemrefType*
41+
void call(GeneralMemrefPtr *args, std::size_t numArgs) {
42+
// Silly code, MLIR execution engine requires pointers of real args as
43+
// inputs
44+
llvm::SmallVector<void *, 32> realargs;
45+
realargs.reserve(numArgs);
46+
for (size_t i = 0; i < numArgs; i++) {
47+
realargs.push_back(&args[i]);
48+
}
49+
compute(realargs.data());
50+
}
51+
52+
/// directly call compute(). args should be an array of void*. args[i] should
53+
/// be a pointer to the real data. For passing memref, users need to 1) create
54+
/// a pointer to XXXMemrefType 2) store the pointer to pointer to
55+
/// XXXMemrefType in args[i]
56+
void callRaw(void **args) { compute(args); }
57+
58+
JitModule(std::unique_ptr<ExecutionEngine> engine, JitModuleFuncT compute);
59+
~JitModule();
60+
61+
private:
62+
std::unique_ptr<ExecutionEngine> engine;
63+
JitModuleFuncT compute;
64+
};
65+
66+
} // namespace gc
67+
} // namespace mlir
68+
69+
#endif

lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRFuncDialect)
2+
13
add_mlir_dialect_library(MLIRCPURuntimeDialect
24
CPURuntimeDialect.cpp
35
CPURuntimeOps.cpp
@@ -10,5 +12,5 @@ add_mlir_dialect_library(MLIRCPURuntimeDialect
1012
MLIRCPURuntimePassesIncGen
1113

1214
LINK_LIBS PUBLIC
13-
MLIRFuncDialect
15+
${MLIR_LINK_COMPONENTS}
1416
)

lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRFuncDialect)
2+
13
add_mlir_dialect_library(MLIRCPURuntimeTransforms
24
CPURuntimeToLLVM.cpp
35

@@ -8,7 +10,7 @@ add_mlir_dialect_library(MLIRCPURuntimeTransforms
810
MLIRCPURuntimePassesIncGen
911

1012
LINK_LIBS PUBLIC
11-
MLIRFuncDialect
13+
${MLIR_LINK_COMPONENTS}
1214
MLIRCPURuntimeDialect
1315
)
1416

lib/gc/ExecutionEngine/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
add_subdirectory(CPURuntime)
2+
add_subdirectory(Driver)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
if(GC_DEV_LINK_LLVM_DYLIB)
2+
set(LLVM_LINK_COMPONENTS
3+
LLVM
4+
)
5+
get_property(dialect_libs GLOBAL PROPERTY GC_DIALECT_LIBS)
6+
get_property(conversion_libs GLOBAL PROPERTY GC_PASS_LIBS)
7+
set(MLIR_LINK_COMPONENTS
8+
MLIR
9+
MLIRExecutionEngineShared
10+
)
11+
else()
12+
set(LLVM_LINK_COMPONENTS
13+
Core
14+
Support
15+
nativecodegen
16+
native
17+
)
18+
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
19+
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
20+
set(MLIR_LINK_COMPONENTS
21+
MLIRBuiltinToLLVMIRTranslation
22+
MLIRExecutionEngine
23+
MLIRLLVMDialect
24+
MLIRLLVMToLLVMIRTranslation
25+
MLIRToLLVMIRTranslationRegistration
26+
)
27+
endif()
28+
29+
add_mlir_library(GCJitWrapper
30+
Driver.cpp
31+
32+
ADDITIONAL_HEADER_DIRS
33+
${PROJECT_SOURCE_DIR}/include
34+
35+
LINK_LIBS PUBLIC
36+
${MLIR_LINK_COMPONENTS}
37+
${dialect_libs}
38+
${conversion_libs}
39+
GCPasses
40+
)
41+
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
//===-- Driver.cpp - Top-level MLIR compiler driver -------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "gc/ExecutionEngine/Driver/Driver.h"
10+
#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h"
11+
#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
12+
#include "gc/Transforms/Passes.h"
13+
#include "mlir/InitAllDialects.h"
14+
#include "mlir/InitAllPasses.h"
15+
#include "mlir/Pass/PassManager.h"
16+
#include "mlir/Target/LLVMIR/Dialect/All.h"
17+
#include "string.h"
18+
#include "llvm/Support/InitLLVM.h"
19+
#include "llvm/Support/TargetSelect.h"
20+
21+
namespace mlir {
22+
namespace gc {
23+
24+
static DialectRegistry initDialects() {
25+
mlir::registerAllPasses();
26+
mlir::gc::registerGraphCompilerPasses();
27+
mlir::cpuruntime::registerCPURuntimePasses();
28+
mlir::DialectRegistry registry;
29+
registry.insert<mlir::cpuruntime::CPURuntimeDialect>();
30+
mlir::registerAllDialects(registry);
31+
mlir::cpuruntime::registerConvertCPURuntimeToLLVMInterface(registry);
32+
registry.insert<mlir::onednn_graph::OneDNNGraphDialect>();
33+
llvm::InitializeNativeTarget();
34+
llvm::InitializeNativeTargetAsmPrinter();
35+
llvm::InitializeNativeTargetAsmParser();
36+
mlir::registerAllToLLVMIRTranslations(registry);
37+
return registry;
38+
}
39+
40+
const DialectRegistry &initCompilerAndGetDialects() {
41+
static DialectRegistry reg = initDialects();
42+
return reg;
43+
}
44+
45+
static const char defaultComputeName[] = "_mlir_ciface_compute";
46+
47+
llvm::Expected<std::shared_ptr<JitModule>>
48+
JitModule::create(Operation *op, const DriverOptions &options) {
49+
if (options.runTransforms) {
50+
mlir::PassManager pm{op->getContext()};
51+
populateCPUPipeline(pm);
52+
if (auto result = pm.run(op); failed(result)) {
53+
return llvm::make_error<llvm::StringError>(
54+
"MLIR pass error", llvm::inconvertibleErrorCode());
55+
}
56+
}
57+
ExecutionEngineOptions exeOptions;
58+
exeOptions.jitCodeGenOptLevel = options.jitCodeGenOptLevel;
59+
std::unique_ptr<llvm::TargetMachine> tm = nullptr;
60+
auto exec = ExecutionEngine::create(op, exeOptions, std::move(tm));
61+
if (!exec) {
62+
return exec.takeError();
63+
}
64+
auto &engine = *exec;
65+
JitModuleFuncT compute;
66+
{
67+
auto expectCompute = engine->lookupPacked(defaultComputeName);
68+
if (!expectCompute) {
69+
return expectCompute.takeError();
70+
}
71+
compute = *expectCompute;
72+
}
73+
return std::make_shared<JitModule>(std::move(engine), compute);
74+
}
75+
76+
JitModule::JitModule(std::unique_ptr<ExecutionEngine> engine,
77+
JitModuleFuncT compute)
78+
: engine{std::move(engine)}, compute{compute} {}
79+
JitModule::~JitModule() = default;
80+
81+
} // namespace gc
82+
} // namespace mlir

lib/gc/Transforms/Pipeline.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
1111
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
1212
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13+
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
1314
#include "mlir/Dialect/Linalg/Passes.h"
1415
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1516
#include "mlir/Dialect/MemRef/Transforms/Passes.h"

test/mlir/unittests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ function(add_mlir_unittest test_dirname)
1313
endfunction()
1414

1515
add_subdirectory(Example)
16+
add_subdirectory(ExecutionEngine)
1617

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
add_mlir_unittest(GCExecutionEngineTests
2+
JitWrapper.cpp
3+
)
4+
target_link_libraries(GCExecutionEngineTests
5+
PRIVATE
6+
GCJitWrapper
7+
GCCpuRuntime)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
//===-- JitWrapper.cpp - Wrapper for JIT ------------------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "gc/ExecutionEngine/Driver/Driver.h"
10+
#include "mlir/AsmParser/AsmParser.h"
11+
#include "mlir/ExecutionEngine/MemRefUtils.h"
12+
#include "mlir/IR/AsmState.h"
13+
#include "mlir/IR/BuiltinOps.h"
14+
#include "mlir/IR/MLIRContext.h"
15+
#include "mlir/Parser/Parser.h"
16+
#include "mlir/Pass/PassManager.h"
17+
#include "llvm/Support/ErrorOr.h"
18+
#include "llvm/Support/MemoryBuffer.h"
19+
#include "llvm/Support/SourceMgr.h"
20+
#include "llvm/Support/raw_ostream.h"
21+
#include "gtest/gtest.h"
22+
#include <memory>
23+
24+
using namespace mlir;
25+
26+
static const char code1[] = R"mlir(
27+
module {
28+
llvm.mlir.global constant @__num_orig_num_args(3 : i32) : i32
29+
func.func @compute(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } {
30+
%out = tensor.empty() : tensor<128xf32>
31+
%2 = linalg.add ins(%a, %b : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32>
32+
return %2 : tensor<128xf32>
33+
}
34+
}
35+
)mlir";
36+
37+
extern "C" {
38+
extern int gc_runtime_keep_alive;
39+
}
40+
41+
TEST(ExecutionEngine, JitWrapper) {
42+
gc_runtime_keep_alive = 0;
43+
MLIRContext ctx{gc::initCompilerAndGetDialects()};
44+
std::unique_ptr<llvm::MemoryBuffer> ir_buffer =
45+
llvm::MemoryBuffer::getMemBuffer(code1);
46+
// Parse the input mlir.
47+
llvm::SourceMgr sourceMgr;
48+
sourceMgr.AddNewSourceBuffer(std::move(ir_buffer), llvm::SMLoc());
49+
mlir::OwningOpRef<mlir::ModuleOp> module =
50+
mlir::parseSourceFile<mlir::ModuleOp>(sourceMgr, &ctx);
51+
ASSERT_TRUE(module);
52+
auto jited = gc::JitModule::create(module.get());
53+
bool jit_success = static_cast<bool>(jited);
54+
if (!jit_success) {
55+
auto err = jited.takeError();
56+
llvm::errs() << err;
57+
llvm::consumeError(std::move(err));
58+
}
59+
ASSERT_TRUE(jit_success);
60+
OwningMemRef<float, 1> bufA{
61+
{128}, {128}, [](float &ptr, ArrayRef<int64_t>) { ptr = 1.0f; }};
62+
OwningMemRef<float, 1> bufB{
63+
{128}, {128}, [](float &ptr, ArrayRef<int64_t> idx) { ptr = idx[0]; }};
64+
OwningMemRef<float, 1> bufC{{128}, {128}};
65+
void *args[] = {&*bufA, &*bufB, &*bufC};
66+
jited.get()->call(args, 3);
67+
for (int i = 0; i < 128; i++) {
68+
ASSERT_EQ(bufC[{i}], 1.0f + i);
69+
}
70+
}

0 commit comments

Comments
 (0)