Skip to content

Commit 58cb76e

Browse files
committed
[LLVM] Initial support for codegen LLVM.
1 parent c8ec411 commit 58cb76e

29 files changed

+1382
-68
lines changed

HalideIR

Submodule HalideIR updated from e68ae61 to 1a11a6c

Makefile

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
ifndef config
22
ifneq ("$(wildcard ./config.mk)","")
3-
config = config.mk
3+
config ?= config.mk
44
else
5-
config = make/config.mk
5+
config ?= make/config.mk
66
endif
77
endif
88

@@ -19,31 +19,24 @@ SRC = $(wildcard src/*.cc src/*/*.cc src/*/*/*.cc)
1919
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
2020
ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)
2121

22-
ifneq ($(USE_CUDA_PATH), NONE)
23-
NVCC=$(USE_CUDA_PATH)/bin/nvcc
24-
endif
25-
2622
export LDFLAGS = -pthread -lm
27-
export CFLAGS = -std=c++11 -Wall -O2\
28-
-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC
29-
export FRAMEWORKS=
30-
31-
ifneq ($(ADD_CFLAGS), NONE)
32-
CFLAGS += $(ADD_CFLAGS)
33-
endif
23+
export CFLAGS = -std=c++11 -Wall -O2 -fno-rtti\
24+
-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0
3425

35-
ifneq ($(ADD_LDFLAGS), NONE)
36-
LDFLAGS += $(ADD_LDFLAGS)
26+
ifdef CUDA_PATH
27+
NVCC=$(CUDA_PATH)/bin/nvcc
28+
CFLAGS += -I$(CUDA_PATH)/include
29+
LDFLAGS += -L$(CUDA_PATH)/lib64
3730
endif
3831

39-
4032
ifeq ($(USE_CUDA), 1)
4133
CFLAGS += -DTVM_CUDA_RUNTIME=1
4234
LDFLAGS += -lcuda -lcudart -lnvrtc
4335
else
4436
CFLAGS += -DTVM_CUDA_RUNTIME=0
4537
endif
4638

39+
FRAMEWORKS=
4740

4841
ifeq ($(USE_OPENCL), 1)
4942
CFLAGS += -DTVM_OPENCL_RUNTIME=1
@@ -57,6 +50,23 @@ else
5750
CFLAGS += -DTVM_OPENCL_RUNTIME=0
5851
endif
5952

53+
# llvm configuration
54+
LLVM_CONFIG=llvm-config
55+
56+
ifeq ($(USE_LLVM), 1)
57+
LLVM_VERSION=$(shell $(LLVM_CONFIG) --version| cut -b 1,3)
58+
LLVM_INCLUDE=$(filter -I%, $(shell $(LLVM_CONFIG) --cxxflags))
59+
LDFLAGS += $(shell $(LLVM_CONFIG) --ldflags --libs --system-libs)
60+
CFLAGS += $(LLVM_INCLUDE) -DTVM_LLVM_VERSION=$(LLVM_VERSION)
61+
endif
62+
63+
ifdef $(ADD_CFLAGS)
64+
CFLAGS += $(ADD_CFLAGS)
65+
endif
66+
67+
ifdef $(ADD_LDFLAGS)
68+
LDFLAGS += $(ADD_LDFLAGS)
69+
endif
6070

6171
include tests/cpp/unittest.mk
6272

include/tvm/buffer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class BufferNode : public Node {
9090
Type dtype);
9191

9292
static constexpr const char* _type_key = "Buffer";
93-
TVM_DECLARE_NODE_TYPE_INFO(BufferNode);
93+
TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node);
9494
};
9595

9696
inline const BufferNode* Buffer::operator->() const {

include/tvm/codegen.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ PackedFunc BuildStackVM(
3131
LoweredFunc func,
3232
const std::unordered_map<LoweredFunc, PackedFunc>& device_funcs);
3333

34+
/*!
35+
* \brief Build a LLVM VM function, this is still beta
36+
* \param func The LoweredFunc to be build
37+
* \return A packed function representing the func.
38+
*/
39+
PackedFunc BuildLLVM(LoweredFunc func);
40+
3441
/*!
3542
* \brief Build a CUDA function with NVRTC
3643
*

include/tvm/expr.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ using Halide::Internal::make_zero;
3636
using Halide::Internal::as_const_int;
3737
using Halide::Internal::as_const_uint;
3838

39-
4039
inline Type TVMType2Type(TVMType t) {
4140
return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
4241
}
@@ -182,7 +181,7 @@ class IterVarNode : public Node {
182181
static IterVar make(Range dom, Var var, std::string thread_tag);
183182

184183
static constexpr const char* _type_key = "IterVar";
185-
TVM_DECLARE_NODE_TYPE_INFO(IterVarNode);
184+
TVM_DECLARE_NODE_TYPE_INFO(IterVarNode, Node);
186185
};
187186

188187
// inline implementations

include/tvm/ir.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,8 @@ using Halide::Internal::Realize;
200200
using Halide::Internal::Block;
201201
using Halide::Internal::IfThenElse;
202202
using Halide::Internal::Evaluate;
203+
// ir functions
204+
using Halide::Internal::is_const_power_of_two_integer;
203205

204206
} // namespace ir
205207
} // namespace tvm

include/tvm/lowered_func.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class LoweredFuncNode : public FunctionBaseNode {
9292
}
9393

9494
static constexpr const char* _type_key = "LoweredFunc";
95-
TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode);
95+
TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode, Node);
9696
};
9797

9898
// Implementations of inline functions

include/tvm/operation.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class PlaceholderOpNode : public OperationNode {
3939
Type dtype);
4040

4141
static constexpr const char* _type_key = "PlaceholderOp";
42-
TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode);
42+
TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode, OperationNode);
4343
};
4444

4545
/*!
@@ -74,7 +74,7 @@ class ComputeOpNode : public OperationNode {
7474
Expr body);
7575

7676
static constexpr const char* _type_key = "ComputeOp";
77-
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode);
77+
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
7878
};
7979

8080
/*!
@@ -123,7 +123,7 @@ class ScanOpNode : public OperationNode {
123123
Array<Tensor> state_placeholder);
124124

125125
static constexpr const char* _type_key = "ScanOp";
126-
TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode);
126+
TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode, OperationNode);
127127
};
128128

129129

include/tvm/packed_func_ext.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ struct NodeTypeChecker {
3333
// It can be turned off, but will make non strict checking.
3434
// TODO(tqchen) possibly find alternative to turn of RTTI
3535
using ContainerType = typename T::ContainerType;
36-
return (dynamic_cast<ContainerType*>(sptr) != nullptr);
36+
return sptr->derived_from<ContainerType>();
3737
}
3838
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
3939
using ContainerType = typename T::ContainerType;

0 commit comments

Comments
 (0)