Skip to content

Commit bab295e

Browse files
spectrometerHBH, zhangxiao-stack, zhangxiao-stack
authored
[ROCm] Fix some ROCm codegen bugs (#15454)
* rocm bug fix: Module hip should be either dso exportable or binary serializable
  rocm bug fix: llvm.amdgcn.ds.bpermute Intrinsic has incorrect return type
  rocm bug fix: ptr addrspace(3) @shmem Global is external, but doesn't have external or weak linkage
  Co-authored-by: zhangxiao-stack <1244360827@qq.com>
* lint
---------
Co-authored-by: zhangxiao-stack <zhangqha@sugon.com>
Co-authored-by: zhangxiao-stack <1244360827@qq.com>
1 parent b77d659 commit bab295e

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

src/runtime/rocm/rocm_module.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ class ROCMModuleNode : public runtime::ModuleNode {
6363
}
6464

6565
const char* type_key() const final { return "hip"; }
66-
66+
int GetPropertyMask() const final {
67+
return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable;
68+
}
6769
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final;
6870

6971
void SaveToFile(const String& file_name, const String& format) final {

src/target/llvm/codegen_llvm.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -702,8 +702,8 @@ llvm::GlobalVariable* CodeGenLLVM::AllocateSharedMemory(DataType dtype, size_t s
702702
llvm::GlobalValue::LinkageTypes linkage) {
703703
llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(dtype), size);
704704
llvm::GlobalVariable* global =
705-
new llvm::GlobalVariable(*module_, type, false, linkage, nullptr, "shmem", nullptr,
706-
llvm::GlobalValue::NotThreadLocal, shared_address_space);
705+
new llvm::GlobalVariable(*module_, type, false, linkage, llvm::UndefValue::get(type), "shmem",
706+
nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
707707
#if TVM_LLVM_VERSION >= 100
708708
global->setAlignment(llvm::Align(alignment));
709709
#else

src/tir/transforms/lower_thread_allreduce.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
729729
// rocm only supports 32 bit operands for shuffling at the moment
730730
if ((target_->kind->name == "rocm") &&
731731
(std::any_of(types.begin(), types.end(), [](DataType ty) {
732-
if (ty.is_vector()) return true;
732+
if ((ty.is_vector()) || !ty.is_int()) return true;
733733
return ty.bits() != 32;
734734
}))) {
735735
return false;

0 commit comments

Comments
 (0)