@@ -71,55 +71,53 @@ class CodeGenAMDGPU : public CodeGenLLVM {
7171 void VisitStmt_ (const AllocateNode* op) final {
7272 CHECK (!is_zero (op->condition ));
7373 llvm::Value* buf = nullptr ;
74- if (op->new_expr .defined ()) {
75- CHECK_EQ (op->free_function , " nop" );
76- buf = MakeValue (op->new_expr );
77- } else {
78- int32_t constant_size = op->constant_allocation_size ();
79- CHECK_GT (constant_size, 0 )
80- << " Can only handle constant size stack allocation in GPU" ;
81- StorageInfo& info = alloc_storage_info_[op->buffer_var .get ()];
82- if (constant_size % 4 == 0 && info.alignment == 0 ) {
83- info.alignment = GetTempAllocaAlignment (op->dtype , constant_size);
84- }
85- // maximum necessary alignment in the AMD devices
86- if (info.alignment > 16 ) {
87- info.alignment = 16 ;
88- }
89- if (info.scope .rank == runtime::StorageRank::kLocal ) {
90- // const int local_address_space = 5;
91- // TODO(tqchen): for higher version of LLVM, local address space can be set.
92- llvm::AllocaInst* alloca = WithFunctionEntry ([&]() {
93- return builder_->CreateAlloca (
94- DTypeToLLVMType (op->dtype ), ConstInt32 (constant_size));
95- });
96- if (alloca->getAlignment () < static_cast <uint32_t >(info.alignment )) {
74+
75+ int32_t constant_size = op->constant_allocation_size ();
76+ CHECK_GT (constant_size, 0 )
77+ << " Can only handle constant size stack allocation in GPU" ;
78+
79+ StorageInfo& info = alloc_storage_info_[op->buffer_var .get ()];
80+ if (constant_size % 4 == 0 && info.alignment == 0 ) {
81+ info.alignment = GetTempAllocaAlignment (op->dtype , constant_size);
82+ }
83+ // maximum necessary alignment in the AMD devices
84+ if (info.alignment > 16 ) {
85+ info.alignment = 16 ;
86+ }
87+ if (info.scope .rank == runtime::StorageRank::kLocal ) {
88+ // const int local_address_space = 5;
89+ // TODO(tqchen): for higher version of LLVM, local address space can be set.
90+ llvm::AllocaInst* alloca = WithFunctionEntry ([&]() {
91+ return builder_->CreateAlloca (
92+ DTypeToLLVMType (op->dtype ), ConstInt32 (constant_size));
93+ });
94+ if (alloca->getAlignment () < static_cast <uint32_t >(info.alignment )) {
9795#if TVM_LLVM_VERSION >= 100
98- alloca->setAlignment (llvm::Align (info.alignment ));
96+ alloca->setAlignment (llvm::Align (info.alignment ));
9997#else
100- alloca->setAlignment (info.alignment );
98+ alloca->setAlignment (info.alignment );
10199#endif
102- }
103- buf = alloca;
104- } else {
105- CHECK (info.scope .rank == runtime::StorageRank::kShared )
106- << " Can only allocate shared or local memory inside kernel" ;
107- // Shared memory: address space == 3
108- const unsigned shared_address_space = 3 ;
109- llvm::Type* type = llvm::ArrayType::get (
110- DTypeToLLVMType (op->dtype ), constant_size);
111- // Allocate shared memory in global, address_space = 3
112- llvm::GlobalVariable *global = new llvm::GlobalVariable (
113- *module_, type, false , llvm::GlobalValue::PrivateLinkage, 0 , " .shared" ,
114- nullptr , llvm::GlobalValue::NotThreadLocal, shared_address_space);
100+ }
101+ buf = alloca;
102+ } else {
103+ CHECK (info.scope .rank == runtime::StorageRank::kShared )
104+ << " Can only allocate shared or local memory inside kernel" ;
105+ // Shared memory: address space == 3
106+ const unsigned shared_address_space = 3 ;
107+ llvm::Type* type = llvm::ArrayType::get (
108+ DTypeToLLVMType (op->dtype ), constant_size);
109+ // Allocate shared memory in global, address_space = 3
110+ llvm::GlobalVariable *global = new llvm::GlobalVariable (
111+ *module_, type, false , llvm::GlobalValue::PrivateLinkage, 0 , " .shared" ,
112+ nullptr , llvm::GlobalValue::NotThreadLocal, shared_address_space);
115113#if TVM_LLVM_VERSION >= 100
116- global->setAlignment (llvm::Align (info.alignment ));
114+ global->setAlignment (llvm::Align (info.alignment ));
117115#else
118- global->setAlignment (info.alignment );
116+ global->setAlignment (info.alignment );
119117#endif
120- buf = global;
121- }
118+ buf = global;
122119 }
120+
123121 buf = builder_->CreatePointerCast (
124122 buf, DTypeToLLVMType (op->dtype )->getPointerTo (
125123 buf->getType ()->getPointerAddressSpace ()));
0 commit comments