Merged

Changes from all commits (44 commits)
5a4ee2c - added initial llvm codegen for amdgpu (Aug 31, 2017)
0a5270d - fixed whitespace (Aug 31, 2017)
01a5c93 - fixed hsaco gen from ir (Aug 31, 2017)
c562873 - fixed targetmachine for rocm and added GetSource for rocm (Aug 31, 2017)
8ed9f9a - fixed whitespace issues (Aug 31, 2017)
4700efc - changed statement to use less than 100 lines (Aug 31, 2017)
230194e - added intrinsics for workgroup - rocm (Aug 31, 2017)
b1d54f6 - whitespace - newline error fix (Aug 31, 2017)
84091ff - fixed error msg for workitem-workgroup intrinsics (Aug 31, 2017)
bcf423d - added llvm ir dump for rocm codegen (Aug 31, 2017)
cf947c6 - [ROCM] changed codegen to emit proper amdgpu kernel header (Sep 1, 2017)
8c61580 - fixed whitespace error (Sep 1, 2017)
14384a1 - fixed whitespace error- 2 (Sep 1, 2017)
7bd23f4 - fixed AddFunction to not to use extra arg (Sep 5, 2017)
9860610 - fixed whitespaces (Sep 5, 2017)
4ba40cb - fixed whitespaces 2 (Sep 5, 2017)
0ecb1e2 - fixed codegen for AMDGPU - now generating valid IR (Sep 6, 2017)
0bce779 - fixed codegen depending on code review (Sep 7, 2017)
ae276a3 - reviewed alignment for amd devices (Sep 7, 2017)
54d02d6 - added code to dump code object to file (Sep 7, 2017)
e6d532d - fixed cpplint errors (Sep 7, 2017)
52d8e2d - print out IR after pass manager (Sep 7, 2017)
11d7585 - added code to dump asm, obj to file and std string (Sep 11, 2017)
1ca7418 - fixed whitespaces (Sep 11, 2017)
bb520d9 - Update codegen_amdgpu.cc (Sep 12, 2017)
fb29bed - used registry for amdgpu llvm (Sep 12, 2017)
38805f5 - Fixed whitespaces (Sep 12, 2017)
8876cde - added code for calling linker (Sep 12, 2017)
b0c38f7 - fixed formatting errors (Sep 12, 2017)
fcd7cc0 - added rocm link python interface (Sep 12, 2017)
6e9a0e9 - fixed pylint issues and added more body to the function (Sep 12, 2017)
e57aa24 - added doc string (Sep 12, 2017)
84044e3 - added doc string for module (Sep 12, 2017)
a6c053b - fixed python code after review, fixed llvm object codegen (Sep 12, 2017)
c218cd3 - fixed linker to generate code object (Sep 12, 2017)
1afa473 - removed dumping to output file and debugging log out (Sep 12, 2017)
8fd4efc - fixed lint for python code (Sep 12, 2017)
80dceee - added fault check after running linker (Sep 13, 2017)
c3b39ca - removed print statement in rocm.py (Sep 13, 2017)
678cb41 - changed rocm lld linker to raise runtimeerror than emitting error log… (Sep 13, 2017)
29b60e9 - changed the way linker command line is pass to subprocess.popen (Sep 13, 2017)
e48b48e - removed redundant code and reuse tvm utils (Sep 13, 2017)
939e3ef - removed commented out code (Sep 13, 2017)
4501e8e - removed cloning of unused modules, and put IR into string (Sep 13, 2017)
1 change: 1 addition & 0 deletions python/tvm/__init__.py
@@ -29,3 +29,4 @@
from .schedule import create_schedule
from .build_module import build, lower, build_config
from .tag import tag_scope
from .contrib import rocm as _rocm
2 changes: 2 additions & 0 deletions python/tvm/_ffi/ndarray.py
@@ -59,6 +59,8 @@ def context(dev_type, dev_id=0):
    if dev_type not in TVMContext.STR2MASK:
        if dev_type.find("nvptx") != -1:
            dev_type = "cuda"
        if dev_type.find("rocm") != -1:
            dev_type = "rocm"
        if dev_type not in TVMContext.STR2MASK:
            raise ValueError("Unknown device type %s" % dev_type)
        dev_type = TVMContext.STR2MASK[dev_type]
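As a quick illustration of the remap above, a minimal sketch, assuming context is exposed at the package level as tvm.context in this version of the API:

import tvm

# "rocm" is already a known device string, so it resolves directly.
ctx = tvm.context("rocm", 0)

# A longer target string containing "rocm" is caught by the find()
# check above and remapped to the plain "rocm" device type,
# mirroring the existing nvptx-to-cuda behavior.
ctx2 = tvm.context("rocm -mcpu=gfx900", 0)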
50 changes: 50 additions & 0 deletions python/tvm/contrib/rocm.py
@@ -0,0 +1,50 @@
"""Utility for ROCm backend"""
import subprocess
from . import util
from ..api import register_func

def rocm_link(in_file, out_file):
"""Link relocatable ELF object to shared ELF object using lld

Parameters
----------
in_file : str
Input file name (relocatable ELF object file)

out_file : str
Output file name (shared ELF object file)
"""
args = ["ld.lld", "-shared", in_file, "-o", out_file]
proc = subprocess.Popen(
args,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()

if proc.returncode != 0:
msg = "Linking error using ld.lld:\n"
msg += str(out)
raise RuntimeError(msg)

@register_func("tvm_callback_rocm_link")
def callback_rocm_link(obj_bin):
"""Links object file generated from LLVM to HSA Code Object

Parameters
----------
obj_bin : bytearray
The object file

Return
------
cobj_bin : bytearray
The HSA Code Object
"""
tmp_dir = util.tempdir()
tmp_obj = tmp_dir.relpath("rocm_kernel.o")
tmp_cobj = tmp_dir.relpath("rocm_kernel.co")
with open(tmp_obj, "wb") as out_file:
out_file.write(bytes(obj_bin))
rocm_link(tmp_obj, tmp_cobj)
cobj_bin = bytearray(open(tmp_cobj, "rb").read())
return cobj_bin
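A minimal usage sketch for the callback; the input file name is hypothetical, and in practice BuildAMDGPU fetches this function through the TVM registry rather than user code:

import tvm
from tvm.contrib import rocm  # importing registers tvm_callback_rocm_link

# Fetch the callback the same way the C++ codegen does, then hand it a
# relocatable ELF object to get back a linked HSA code object.
link = tvm.get_global_func("tvm_callback_rocm_link")
with open("rocm_kernel.o", "rb") as f:  # hypothetical input object file
    hsaco = link(bytearray(f.read()))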
188 changes: 188 additions & 0 deletions src/codegen/llvm/codegen_amdgpu.cc
@@ -0,0 +1,188 @@
/*!
 *  Copyright (c) 2017 by Contributors
 * \file codegen_amdgpu.cc
 * \brief AMDGPU code generator.
 */
#ifdef TVM_LLVM_VERSION
#if TVM_ROCM_RUNTIME

#include <tvm/runtime/device_api.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
#include "./codegen_llvm.h"
#include "../build_common.h"
#include "../../pass/ir_util.h"
#include "../../runtime/rocm/rocm_module.h"

namespace tvm {
namespace codegen {

// AMDGPU code generator.
class CodeGenAMDGPU : public CodeGenLLVM {
 public:
  void AddFunction(const LoweredFunc& f) final {
    // add function as void return value
    CodeGenLLVM::AddFunctionInternal(f, true);
    function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
  }

  void VisitStmt_(const Allocate* op) final {
    CHECK(!is_zero(op->condition));
    llvm::Value* buf = nullptr;
    if (op->new_expr.defined()) {
      CHECK_EQ(op->free_function, "nop");
      buf = MakeValue(op->new_expr);
    } else {
      int32_t constant_size = op->constant_allocation_size();
      CHECK_GT(constant_size, 0)
          << "Can only handle constant size stack allocation in GPU";
      StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
      if (constant_size % 4 == 0 && info.alignment == 0) {
        info.alignment = GetTempAllocaAlignment(op->type, constant_size);
      }
      // maximum necessary alignment in the AMD devices
      if (info.alignment > 16) {
        info.alignment = 16;
      }
      if (info.scope.rank == 2) {
        // const int local_address_space = 5;
        // TODO(tqchen): for higher version of LLVM, local address space can be set.
        llvm::AllocaInst* alloca = builder_->CreateAlloca(
            LLVMType(op->type), ConstInt32(constant_size));
        if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
          alloca->setAlignment(info.alignment);
        }
        buf = alloca;
      } else {
        CHECK_EQ(info.scope.rank, 1)
            << "Can only allocate shared or local memory inside kernel";
        // Shared memory: address space == 3
        const unsigned shared_address_space = 3;
Member: Check if the address space is consistent with the AMD GPU backend.

Author: I have to make another pass on the codegen part, as there are obvious differences between nvptx and amdgcn codegen. Is there a way I can see the IR directly?

Member: If you want to check the LLVM code, do module_->dump(); you have to insert it manually in the code though. Otherwise, implement GetSource in the hip module, which should give you the assembly.

Author: I'll use LOG(WARNING) << module_->dump(); to see it.

Author: Hi @tqchen, I am getting the following error:

$ python tests/python/unittest/test_codegen_device.py
Traceback (most recent call last):
  File "tests/python/unittest/test_codegen_device.py", line 1, in <module>
    import tvm
  File "/home/aditya/tvm/python/tvm/__init__.py", line 5, in <module>
    from . import tensor
  File "/home/aditya/tvm/python/tvm/tensor.py", line 4, in <module>
    from ._ffi.node import NodeBase, NodeGeneric, register_node, convert_to_node
  File "/home/aditya/tvm/python/tvm/_ffi/node.py", line 8, in <module>
    from .node_generic import NodeGeneric, convert_to_node, const
  File "/home/aditya/tvm/python/tvm/_ffi/node_generic.py", line 7, in <module>
    from .base import string_types
  File "/home/aditya/tvm/python/tvm/_ffi/base.py", line 43, in <module>
    _LIB, _LIB_NAME = _load_lib()
  File "/home/aditya/tvm/python/tvm/_ffi/base.py", line 35, in _load_lib
    lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL)
  File "/usr/lib/python2.7/ctypes/__init__.py", line 362, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: /home/aditya/tvm/lib/libtvm.so: undefined symbol: _ZNK4llvm6Module4dumpEv

Member: Usually I use module_->dump() without piping it to a stream, and it should work.

Author: Yes, I am aware of it. Once I get the LLVM IR dump, I can get a better understanding of what to change or even add more functionality.

Member: I see, seems this is something we will need frequently. Let us simply also print out the LLVM IR and save it to the code field (optional) in the ROCMModule, so we can access it with module.get_source().

Author: Updated PR with IR dump code. I see some places that need to be changed (which is why I am getting a Bus Error).
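Once the IR is saved to the module's code field, it can be inspected from Python. A minimal sketch, assuming a ROCm-enabled build and the placeholder/compute/create_schedule API of this TVM version:

import tvm

n = 1024
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))

fadd = tvm.build(s, [A, B], "rocm", target_host="llvm", name="myadd")
# The LLVM IR captured at codegen time comes back via get_source().
print(fadd.imported_modules[0].get_source())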

        llvm::Type* type = llvm::ArrayType::get(LLVMType(op->type), constant_size);
        // Allocate shared memory in global, address_space = 3
        llvm::GlobalVariable* global = new llvm::GlobalVariable(
            *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
            nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
        global->setAlignment(info.alignment);
        buf = global;
      }
    }
    buf = builder_->CreatePointerCast(
        buf, LLVMType(op->type)->getPointerTo(
            buf->getType()->getPointerAddressSpace()));
    CHECK(!var_map_.count(op->buffer_var.get()));
    var_map_[op->buffer_var.get()] = buf;
    this->VisitStmt(op->body);
  }

  // Return the thread index via intrinsics.
  llvm::Value* GetThreadIndex(const IterVar& iv) final {
    runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
    llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x;
    if (ts.rank == 1) {
      switch (ts.dim_index) {
        case 0: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x; break;
        case 1: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_y; break;
        case 2: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_z; break;
        default: LOG(FATAL) << "unknown workitem idx";
      }
    } else {
      CHECK_EQ(ts.rank, 0);
      switch (ts.dim_index) {
        case 0: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_x; break;
        case 1: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_y; break;
        case 2: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_z; break;
        default: LOG(FATAL) << "unknown workgroup idx";
      }
    }
    llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id);
    return builder_->CreateCall(f, {});
  }

  llvm::Value* CreateStorageSync(const Call* op) final {
    const std::string& sync = op->args[0].as<StringImm>()->value;
    if (sync == "warp") {
      // TODO(tqchen) warp sync in CUDA9
Member: Remove the comment here; is there any need for a warp (wavefront) synchronizer in AMD GPU?

Author: There are sync commands for AMD GPU, but if it is CUDA9 specific, current generation AMD GPUs don't support it.

      return nullptr;
    } else if (sync == "shared") {
      llvm::Function* f = llvm::Intrinsic::getDeclaration(
          module_.get(),
          ::llvm::Intrinsic::amdgcn_s_barrier);
      return builder_->CreateCall(f, {});
    } else {
      LOG(FATAL) << "Do not support sync " << sync;
      return nullptr;
    }
  }

  void InitPassManagerBuilder(llvm::PassManagerBuilder* builder) final {
    // Additional optimization hook to tweak the builder.
  }

  unsigned GetGlobalAddressSpace() {
    return 1;
  }

 protected:
  void InitTarget(llvm::TargetMachine* tm) final {
    // Maximum vector lane = float4
Member: Double check this.

Author: Yes. float4 or dwordx4 is the perfect alignment.

    native_vector_bits_ = 4 * 32;
    CodeGenLLVM::InitTarget(tm);
  }
};

runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
  CHECK(target.length() >= 4 &&
        target.substr(0, 4) == "rocm");
  llvm::TargetMachine* tm = GetLLVMTargetMachine(
      "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx900" +
      target.substr(4, target.length() - 4));

  std::unique_ptr<CodeGenAMDGPU> cg(new CodeGenAMDGPU());
  std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
  cg->Init(funcs[0]->name, tm, ctx.get(), false, false);
  for (LoweredFunc f : funcs) {
    cg->AddFunction(f);
  }

  std::unique_ptr<llvm::Module> module = cg->Finish();

  llvm::SmallString<8> dataObj, data_ll, dataAsm;
  llvm::raw_svector_ostream destObj(dataObj), dest_ll(data_ll), destAsm(dataAsm);
  destObj.SetUnbuffered();
  dest_ll.SetUnbuffered();
  destAsm.SetUnbuffered();
  module->print(dest_ll, nullptr);
  std::unique_ptr<llvm::Module> mAsm = llvm::CloneModule(module.get());
  std::unique_ptr<llvm::Module> mObj = llvm::CloneModule(module.get());
Member: Remove mObjFile and mAsmFile. We can consider holding two optional sources in RocmModule, both ll and asm, and returning them when a different source suffix is requested; that might help you in debugging.

  llvm::legacy::PassManager pass;

  CHECK(tm->addPassesToEmitFile(
      pass, destObj, llvm::TargetMachine::CGFT_ObjectFile) == 0)
      << "Cannot emit target CGFT_ObjectFile";
  pass.run(*mObj);
  std::string obj(dataObj.begin(), dataObj.end());

  const auto* f = tvm::runtime::Registry::Get("tvm_callback_rocm_link");
  CHECK(f != nullptr) << "Require tvm_callback_rocm_link to exist, do import tvm.contrib.rocm";

  TVMByteArray arr;
  arr.data = &obj[0];
  arr.size = obj.length();

  std::string hsaco = (*f)(arr);
  std::string ll(data_ll.begin(), data_ll.end());

  return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(funcs), ll);
}

TVM_REGISTER_API("codegen.build_rocm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = BuildAMDGPU(args[0], args[1]);
  });

}  // namespace codegen
}  // namespace tvm
#endif  // TVM_ROCM_RUNTIME
#endif  // TVM_LLVM_VERSION
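For orientation, a hedged sketch of how this registration is reached from Python; the direct registry lookup is illustrative, since tvm.build routes "rocm" targets here internally:

import tvm
from tvm.contrib import rocm  # ensures tvm_callback_rocm_link is registered

# TVM_REGISTER_API("codegen.build_rocm") exposes BuildAMDGPU as a
# global packed function; tvm.build dispatches device functions to it
# whenever the target string starts with "rocm".
build_rocm = tvm.get_global_func("codegen.build_rocm")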
6 changes: 5 additions & 1 deletion src/codegen/llvm/codegen_llvm.cc
@@ -100,7 +100,7 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
    Type t = arg.type();
    if (t.is_handle() && f->handle_data_type.count(arg)) {
      arg_type.push_back(
          LLVMType(f->handle_data_type[arg].type())->getPointerTo(GetGlobalAddressSpace()));
      if (!is_restricted_) {
        alias_var_set_.insert(arg.get());
      }
@@ -555,6 +555,10 @@ int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) const {
  return native_vector_bits_;
}

unsigned CodeGenLLVM::GetGlobalAddressSpace() {
  return 0;
}

void CodeGenLLVM::GetAlignment(
    Type t, const Variable* buf_var, const Expr& index,
    int* p_alignment, int* p_native_bits) {
4 changes: 4 additions & 0 deletions src/codegen/llvm/codegen_llvm.h
@@ -23,6 +23,7 @@ namespace codegen {

using namespace ir;

/*!
 * \brief A base class to generate a LLVM.
 */
@@ -148,6 +149,9 @@ class CodeGenLLVM :
  virtual void Optimize();
  // Get the maximum storage align bits of buffer pointer given storage scope.
  virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const;
  // Get correct address space depending on the backend
  virtual unsigned GetGlobalAddressSpace();

  void AddFunctionInternal(const LoweredFunc& f, bool ret_void);
  // Create extern call
  llvm::CallInst* CreateCallExtern(llvm::Type* ret,
2 changes: 2 additions & 0 deletions src/runtime/module.cc
@@ -125,6 +125,8 @@ bool RuntimeEnabled(const std::string& target) {
    f_name = "device_api.vpi";
  } else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") {
    f_name = "codegen.build_nvptx";
  } else if (target.length() >= 4 && target.substr(0, 4) == "rocm") {
    f_name = "codegen.build_rocm";
  } else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
    const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled");
    if (pf == nullptr) return false;
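With this branch in place, Python tests can probe for ROCm support before running device code. A minimal sketch, assuming the module API of this TVM version:

import tvm

# tvm.module.enabled("rocm") ends up in RuntimeEnabled, which now checks
# for the codegen.build_rocm packed function shown above.
if not tvm.module.enabled("rocm"):
    print("skip: ROCm not enabled in this build")
else:
    print("ROCm codegen and runtime are available")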
10 changes: 9 additions & 1 deletion src/runtime/rocm/rocm_module.cc
@@ -59,10 +59,17 @@ class ROCMModuleNode : public runtime::ModuleNode {
    stream->Write(data_);
  }

  std::string GetSource(const std::string& format) final {
    if (format == fmt_) { return data_; }
    if (fmt_ == "hsaco") { return data_; }
    return "";
  }

  // get a hipFunction_t from primary context in device_id
  hipFunction_t GetFunc(int device_id, const std::string& func_name) {
    std::lock_guard<std::mutex> lock(mutex_);
    // must recheck under the lock scope
    if (module_[device_id] == nullptr) {
      ROCM_DRIVER_CALL(hipModuleLoadData(&(module_[device_id]), data_.c_str()));
    }
@@ -140,7 +147,9 @@ class ROCMWrappedFunc {
    if (fcache_[device_id] == nullptr) {
      fcache_[device_id] = m_->GetFunc(device_id, func_name_);
    }
    hipStream_t strm = static_cast<hipStream_t>(ROCMThreadEntry::ThreadLocal()->stream);
    ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
    void* config[] = {
      HIP_LAUNCH_PARAM_BUFFER_POINTER, &packed_args,
@@ -181,7 +190,6 @@ PackedFunc ROCMModuleNode::GetFunction(
  CHECK_EQ(sptr_to_self.get(), this);
  CHECK_NE(name, symbol::tvm_module_main)
      << "Device function do not have main";
  auto it = fmap_.find(name);
  if (it == fmap_.end()) return PackedFunc();
  const FunctionInfo& info = it->second;
2 changes: 2 additions & 0 deletions tests/python/integration/test_gemm.py
@@ -85,6 +85,8 @@ def check_device(device):
        np.testing.assert_allclose(
            c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)

    check_device("nvptx -mcpu=sm_20")
    check_device("rocm")
    check_device("metal")
    check_device("opencl")
    check_device("cuda")
1 change: 1 addition & 0 deletions tests/python/unittest/test_codegen_device.py
@@ -82,6 +82,7 @@ def check_module_save(device, host="stackvm"):
    check_target("cuda", host="llvm")
    check_module_save("cuda", host="stackvm")
    check_target("nvptx", host="llvm")
    check_target("rocm", host="llvm")

if __name__ == "__main__":
    test_add_pipeline()