Skip to content

Commit

Permalink
[BACKEND] Vulkan Runtime and SPIRV Codegen (apache#861)
Browse files Browse the repository at this point in the history
* [BACKEND] Vulkan Runtime and SPIRV Codegen

* fix doc
  • Loading branch information
tqchen authored Feb 2, 2018
1 parent 108e9f3 commit 79d503f
Show file tree
Hide file tree
Showing 50 changed files with 3,869 additions and 143 deletions.
25 changes: 22 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cmake_minimum_required(VERSION 3.5)
cmake_minimum_required(VERSION 3.7)
project(tvm C CXX)

if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/build/private/local_config.cmake)
Expand All @@ -22,6 +22,7 @@ endif()

tvm_option(USE_CUDA "Build with CUDA" OFF)
tvm_option(USE_OPENCL "Build with OpenCL" OFF)
tvm_option(USE_VULKAN "Build with Vulkan" OFF)
tvm_option(USE_OPENGL "Build with OpenGL" OFF)
tvm_option(USE_METAL "Build with Metal" OFF)
tvm_option(USE_RPC "Build with RPC" ON)
Expand Down Expand Up @@ -88,9 +89,11 @@ file(GLOB_RECURSE HALIDEIR_SRCS HalideIR/src/*.cpp)
list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS})
file(GLOB RUNTIME_SRCS src/runtime/*.cc)
file(GLOB COMPILER_LLVM_SRCS src/codegen/llvm/*.cc)
file(GLOB COMPILER_VULKAN_SRCS src/codegen/spirv/*.cc)
file(GLOB RUNTIME_CUDA_SRCS src/runtime/cuda/*.cc)
file(GLOB RUNTIME_OPENCL_SRCS src/runtime/opencl/*.cc)
file(GLOB RUNTIME_OPENGL_SRCS src/runtime/opengl/*.cc)
file(GLOB RUNTIME_VULKAN_SRCS src/runtime/vulkan/*.cc)
file(GLOB RUNTIME_METAL_SRCS src/runtime/metal/*.mm)
file(GLOB RUNTIME_RPC_SRCS src/runtime/rpc/*.cc)
file(GLOB RUNTIME_GRAPH_SRCS src/runtime/graph/*.cc)
Expand Down Expand Up @@ -151,6 +154,22 @@ else(USE_OPENGL)
add_definitions(-DTVM_OPENGL_RUNTIME=0)
endif(USE_OPENGL)

if(USE_VULKAN)
find_package(Vulkan REQUIRED)
message(STATUS "Build with VULKAN support")
include_directories(${Vulkan_INCLUDE_DIRS})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${Vulkan_LIBRARIES})
list(APPEND RUNTIME_SRCS ${RUNTIME_VULKAN_SRCS})
list(APPEND COMPILER_SRCS ${COMPILER_VULKAN_SRCS})
get_filename_component(VULKAN_LIB_PATH ${Vulkan_LIBRARY} DIRECTORY)
find_library(SPIRV_TOOLS_LIB SPIRV-Tools
${VULKAN_LIB_PATH}/spirv-tools)
list(APPEND TVM_LINKER_LIBS ${SPIRV_TOOLS_LIB})
add_definitions(-DTVM_VULKAN_RUNTIME=1)
else(USE_VULKAN)
add_definitions(-DTVM_VULKAN_RUNTIME=0)
endif(USE_VULKAN)

if(USE_METAL)
find_package(OpenCL QUIET REQUIRED)
message(STATUS "Build with Metal support")
Expand All @@ -174,7 +193,7 @@ if(USE_GRAPH_RUNTIME)
endif(USE_GRAPH_RUNTIME)

if(USE_LLVM)
find_package(LLVM CONFIG REQUIRED)
find_spackage(LLVM CONFIG REQUIRED)
include_directories(${LLVM_INCLUDE_DIRS})
add_definitions(${LLVM_DEFINITIONS})
set(TVM_LLVM_VERSION ${LLVM_VERSION_MAJOR}${LLVM_VERSION_MINOR})
Expand Down Expand Up @@ -252,4 +271,4 @@ if(MSVC)
target_compile_definitions(tvm_runtime PRIVATE -DHalide_EXPORTS)
target_compile_definitions(tvm PRIVATE -DTVM_EXPORTS)
target_compile_definitions(tvm_runtime PRIVATE -DTVM_EXPORTS)
endif()
endif()
2 changes: 1 addition & 1 deletion HalideIR
16 changes: 16 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ CUDA_SRC = $(wildcard src/runtime/cuda/*.cc)
ROCM_SRC = $(wildcard src/runtime/rocm/*.cc)
OPENCL_SRC = $(wildcard src/runtime/opencl/*.cc)
OPENGL_SRC = $(wildcard src/runtime/opengl/*.cc)
VULKAN_SRC = $(wildcard src/runtime/vulkan/*.cc)
RPC_SRC = $(wildcard src/runtime/rpc/*.cc)
GRAPH_SRC = $(wildcard src/runtime/graph/*.cc)
RUNTIME_SRC = $(wildcard src/runtime/*.cc)
Expand All @@ -69,6 +70,7 @@ CUDA_OBJ = $(patsubst src/%.cc, build/%.o, $(CUDA_SRC))
ROCM_OBJ = $(patsubst src/%.cc, build/%.o, $(ROCM_SRC))
OPENCL_OBJ = $(patsubst src/%.cc, build/%.o, $(OPENCL_SRC))
OPENGL_OBJ = $(patsubst src/%.cc, build/%.o, $(OPENGL_SRC))
VULKAN_OBJ = $(patsubst src/%.cc, build/%.o, $(VULKAN_SRC))
RPC_OBJ = $(patsubst src/%.cc, build/%.o, $(RPC_SRC))
GRAPH_OBJ = $(patsubst src/%.cc, build/%.o, $(GRAPH_SRC))
CC_OBJ = $(patsubst src/%.cc, build/%.o, $(CC_SRC)) $(LLVM_OBJ)
Expand Down Expand Up @@ -129,6 +131,20 @@ else
CFLAGS += -DTVM_OPENCL_RUNTIME=0
endif

ifdef VULKAN_SDK
CFLAGS += -I$(VULKAN_SDK)/include
LDFLAGS += -L$(VULKAN_SDK)/lib
LDFLAGS += -L$(VULKAN_SDK)/lib/spirv-tools
endif

ifeq ($(USE_VULKAN), 1)
CFLAGS += -DTVM_VULKAN_RUNTIME=1
LDFLAGS += -lvulkan -lSPIRV-Tools
RUNTIME_DEP += $(VULKAN_OBJ)
else
CFLAGS += -DTVM_VULKAN_RUNTIME=0
endif

ifeq ($(USE_OPENGL), 1)
CFLAGS += -DTVM_OPENGL_RUNTIME=1
EMCC_FLAGS += -DTVM_OPENGL_RUNTIME=1
Expand Down
12 changes: 12 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,18 @@ LoweredFunc LowerTVMBuiltin(LoweredFunc f);
*/
LoweredFunc CombineContextCall(LoweredFunc f);

/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
* the most frequently accessed type for load/store
* to avoid pointer casting in backend when possible.
*
* \note implemeneted in storage_rewrite.cc
* \param f The function to be trasnformed
* \return Transformed function.
*/
LoweredFunc PointerValueTypeRewrite(LoweredFunc f);

/*!
* \brief Lower intrinsic function calls.
* \param f The device function to be lowered.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ typedef int64_t tvm_index_t;

/*! \brief Extension device types in TVM */
typedef enum {
kDLVulkan = 7,
kOpenGL = 11,

// Extension DRAM type, used for quickly test extension device
// The device api can differ depending on the xpu driver registered.
kExtDev = 12,
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from . import target

from . import ndarray as nd
from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi, rocm, opengl, ext_dev
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev

from ._ffi.runtime_ctypes import TypeCode
from ._ffi.function import Function
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class TVMContext(ctypes.Structure):
1 : 'cpu',
2 : 'gpu',
4 : 'opencl',
7 : 'vulkan',
8 : 'metal',
9 : 'vpi',
10: 'rocm',
Expand All @@ -109,6 +110,7 @@ class TVMContext(ctypes.Structure):
'nvptx': 2,
'cl': 4,
'opencl': 4,
'vulkan': 7,
'metal': 8,
'vpi': 9,
'rocm': 10,
Expand Down
43 changes: 43 additions & 0 deletions python/tvm/contrib/spirv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Utility for Interacting with SPIRV Tools"""
import subprocess
import os
from . import util


def optimize(spv_bin):
"""Optimize SPIRV using spirv-opt via CLI
Note that the spirv-opt is still experimental.
Parameters
----------
spv_bin : bytearray
The spirv file
Return
------
cobj_bin : bytearray
The HSA Code Object
"""

tmp_dir = util.tempdir()
tmp_in = tmp_dir.relpath("input.spv")
tmp_out = tmp_dir.relpath("output.spv")
with open(tmp_in, "wb") as out_file:
out_file.write(bytes(spv_bin))

sdk = os.environ.get("VULKAN_SDK", None)
cmd = os.path.join(sdk, "bin/spirv-opt") if sdk else "spirv-opt"
args = [cmd, "-O", tmp_in, "-o", tmp_out]
proc = subprocess.Popen(
args,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()

if proc.returncode != 0:
msg = "Opitmizationerror using spirv-opt:\n"
msg += str(out)
raise RuntimeError(msg)

return bytearray(open(tmp_out, "rb").read())
18 changes: 18 additions & 0 deletions python/tvm/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,23 @@ def vpi(dev_id=0):
"""
return TVMContext(9, dev_id)


def vulkan(dev_id=0):
"""Construct a Vulkan device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(7, dev_id)


def opengl(dev_id=0):
"""Construct a OpenGL device
Expand All @@ -135,6 +152,7 @@ def opengl(dev_id=0):
"""
return TVMContext(11, dev_id)


def ext_dev(dev_id=0):
"""Construct a extension device
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __init__(self,
# For now assume rocm schedule for opencl
self.keys += ("rocm", "gpu")
self.max_num_threads = 256
elif target_name in ("metal",):
elif target_name in ("metal", "vulkan"):
self.keys += ("gpu",)
self.max_num_threads = 256
elif target_name in ("opengl",):
Expand Down
2 changes: 2 additions & 0 deletions src/codegen/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,8 @@ void CodeGenC::VisitExpr_(const Let* op, std::ostream& os) { // NOLINT(*)
}

void CodeGenC::VisitExpr_(const Ramp* op, std::ostream& os) { // NOLINT(*)
// constraint of current logic
CHECK_EQ(op->base.type(), Int(32));
os << "((int" << op->lanes << ")(";
for (int i = 0; i < op->lanes; i++) {
os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")";
Expand Down
59 changes: 59 additions & 0 deletions src/codegen/codegen_common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*!
* Copyright (c) 2018 by Contributors
* \file codegen_common.h
* \brief Common utility for codegen.
*/
#ifndef TVM_CODEGEN_CODEGEN_COMMON_H_
#define TVM_CODEGEN_CODEGEN_COMMON_H_

#include <tvm/arithmetic.h>
#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace codegen {

/*!
* \brief Visit AssertStmt recursively, update align_map from condition.
* \param op The AssertStmt
* \param align_map The alignmap
* \param fvisit The recursive visitor
* \tparam FVisit the recursive visitor
*/
template<typename FVisit>
inline void VisitAssert(
const ir::AssertStmt* op,
std::unordered_map<const Variable*, arith::ModularEntry>* align_map,
FVisit fvisit) {
using namespace ir;
auto& align_map_ = *align_map;
// Detect useful invariant pattern and use them to visit child.
// Pattern: Var % const == 0
// TODO(tqchen) merge these pattern to a generic scope info visitor.
if (const EQ* eq = op->condition.as<EQ>()) {
const Mod* mod = eq->a.as<Mod>();
int64_t factor = 0, offset = 0;
if (mod && arith::GetConst(eq->b, &offset)) {
const Variable *var = mod->a.as<Variable>();
if (var && arith::GetConst(mod->b, &factor)) {
arith::ModularEntry old = align_map_[var];
if (factor > old.coeff) {
arith::ModularEntry e;
e.coeff = static_cast<int>(factor);
e.base = static_cast<int>(offset);
// new alignment info,
align_map_[var] = e;
fvisit(op->body);
// restore old info
align_map_[var] = old;
return;
}
}
}
}
fvisit(op->body);
}

} // namespace codegen
} // namespace tvm

#endif // TVM_CODEGEN_CODEGEN_COMMON_H_
31 changes: 5 additions & 26 deletions src/codegen/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <tvm/runtime/c_runtime_api.h>
#include "./codegen_llvm.h"
#include "./codegen_cpu.h"
#include "../codegen_common.h"
#include "../../pass/ir_util.h"
#include "../../arithmetic/compute_expr.h"

Expand Down Expand Up @@ -341,7 +342,7 @@ void CodeGenLLVM::GetAlignment(Type t,
int align_bits = t.bits();
while (align_bits < max_align_bits &&
me.base % 2 == 0 &&
me.coeff %2 == 0) {
me.coeff % 2 == 0) {
me.base = me.base / 2;
me.coeff = me.coeff / 2;
align_bits *= 2;
Expand Down Expand Up @@ -1026,31 +1027,9 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
}

void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
// Detect useful invariant pattern and use them to visit child.
// Pattern: Var % const == 0
// TODO(tqchen) move these pattern to a generic scope info visitor.
if (const EQ* eq = op->condition.as<EQ>()) {
const Mod* mod = eq->a.as<Mod>();
int64_t factor = 0, offset = 0;
if (mod && arith::GetConst(eq->b, &offset)) {
const Variable *var = mod->a.as<Variable>();
if (var && arith::GetConst(mod->b, &factor)) {
arith::ModularEntry old = align_map_[var];
if (factor > old.coeff) {
arith::ModularEntry e;
e.coeff = static_cast<int>(factor);
e.base = static_cast<int>(offset);
// new alignment info,
align_map_[var] = e;
this->VisitStmt(op->body);
// restore old info
align_map_[var] = old;
return;
}
}
}
}
this->VisitStmt(op->body);
VisitAssert(op, &align_map_, [this](const Stmt& body) {
this->VisitStmt(body);
});
}

void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
Expand Down
Loading

0 comments on commit 79d503f

Please sign in to comment.