-
Notifications
You must be signed in to change notification settings - Fork 795
[SYCL] Add group local memory call lowering pass #3329
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
162cc19
2c78c26
43782d8
10a3c54
f0d3556
2277e52
4f34903
8350829
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| // RUN: %clang_cc1 -fsycl-is-device -triple spir64-unknown-unknown-sycldevice -disable-llvm-passes -S -emit-llvm %s -o - | FileCheck %s | ||
| // RUN: %clang_cc1 -fsycl-is-device -triple spir64-unknown-unknown-sycldevice -S -emit-llvm %s -o - | FileCheck %s | ||
|
||
|
|
||
| // CHECK: [[WGLOCALMEM_1:@WGLocalMem.*]] = internal addrspace(3) global [8 x i8] undef, align 8 | ||
| // CHECK: [[WGLOCALMEM_2:@WGLocalMem.*]] = internal addrspace(3) global [4 x i8] undef, align 4 | ||
| // CHECK: [[WGLOCALMEM_3:@WGLocalMem.*]] = internal addrspace(3) global [128 x i8] undef, align 4 | ||
|
|
||
| #include "Inputs/sycl.hpp" | ||
|
|
||
| constexpr size_t WgSize = 32; | ||
| constexpr size_t WgCount = 4; | ||
| constexpr size_t Size = WgSize * WgCount; | ||
|
|
||
| class KernelA; | ||
| class KernelB; | ||
|
|
||
| using namespace cl::sycl; | ||
|
|
||
| int main() { | ||
| queue Q; | ||
| { | ||
| Q.submit([&](handler &cgh) { | ||
| cgh.parallel_for<KernelA>( | ||
| range<1>(Size), [=](item<1> Item) { | ||
| auto *Ptr1 = group_local_memory<long>(); | ||
| auto *Ptr2 = group_local_memory<float>(); | ||
| }); | ||
| }); | ||
| } | ||
|
|
||
| { | ||
| Q.submit([&](handler &cgh) { | ||
| cgh.parallel_for<KernelB>( | ||
| range<1>(Size), [=](item<1> Item) { | ||
| auto *Ptr3 = group_local_memory<int[WgSize]>(); | ||
| }); | ||
| }); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| //===-- LowerWGLocalMemory.h - SYCL kernel local memory allocation pass ---===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // Replaces calls to __sycl_allocateLocalMemory(Size, Alignment) function with | ||
mlychkov marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| // allocation of memory in local address space at the kernel scope. | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #ifndef LLVM_SYCLLOWERIR_LOWERWGLOCALMEMORY_H | ||
| #define LLVM_SYCLLOWERIR_LOWERWGLOCALMEMORY_H | ||
|
|
||
| #include "llvm/IR/Module.h" | ||
| #include "llvm/IR/PassManager.h" | ||
|
|
||
| namespace llvm { | ||
|
|
||
| class SYCLLowerWGLocalMemoryPass | ||
| : public PassInfoMixin<SYCLLowerWGLocalMemoryPass> { | ||
| public: | ||
| PreservedAnalyses run(Module &M, ModuleAnalysisManager &); | ||
| }; | ||
|
|
||
| ModulePass *createSYCLLowerWGLocalMemoryPass(); | ||
mlychkov marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| void initializeSYCLLowerWGLocalMemoryLegacyPass(PassRegistry &); | ||
|
|
||
| } // namespace llvm | ||
|
|
||
| #endif // LLVM_SYCLLOWERIR_LOWERWGLOCALMEMORY_H | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,124 @@ | ||
| //===-- LowerWGLocalMemory.cpp - SYCL kernel local memory allocation pass -===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // This pass replaces calls to __sycl_allocateLocalMemory(Size, Alignment) | ||
| // function with allocation of memory in local address space at the kernel | ||
| // scope. | ||
mlychkov marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include "llvm/SYCLLowerIR/LowerWGLocalMemory.h" | ||
| #include "llvm/IR/Function.h" | ||
| #include "llvm/IR/IRBuilder.h" | ||
| #include "llvm/IR/InstIterator.h" | ||
| #include "llvm/InitializePasses.h" | ||
| #include "llvm/Pass.h" | ||
|
|
||
| using namespace llvm; | ||
|
|
||
| #define DEBUG_TYPE "LowerWGLocalMemory" | ||
|
|
||
| static constexpr char SYCL_ALLOCLOCALMEM_CALL[] = "__sycl_allocateLocalMemory"; | ||
mlychkov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| static constexpr char LOCALMEMORY_GV_PREF[] = "WGLocalMem"; | ||
|
|
||
| namespace { | ||
| class SYCLLowerWGLocalMemoryLegacy : public ModulePass { | ||
kbobrovs marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| public: | ||
| static char ID; | ||
|
|
||
| SYCLLowerWGLocalMemoryLegacy() : ModulePass(ID) { | ||
| initializeSYCLLowerWGLocalMemoryLegacyPass( | ||
| *PassRegistry::getPassRegistry()); | ||
| } | ||
|
|
||
| bool runOnModule(Module &M) override { | ||
| ModuleAnalysisManager DummyMAM; | ||
| auto PA = Impl.run(M, DummyMAM); | ||
| return !PA.areAllPreserved(); | ||
| } | ||
|
|
||
| private: | ||
| SYCLLowerWGLocalMemoryPass Impl; | ||
| }; | ||
| } // namespace | ||
|
|
||
| char SYCLLowerWGLocalMemoryLegacy::ID = 0; | ||
| INITIALIZE_PASS(SYCLLowerWGLocalMemoryLegacy, "sycllowerwglocalmemory", | ||
| "Replace __sycl_allocateLocalMemory with allocation of memory " | ||
| "in local address space", | ||
| false, false) | ||
|
|
||
| ModulePass *llvm::createSYCLLowerWGLocalMemoryPass() { | ||
| return new SYCLLowerWGLocalMemoryLegacy(); | ||
| } | ||
|
|
||
| static bool lowerAllocaLocalMem(Module &M) { | ||
| SmallVector<CallInst *, 8> ToReplace; | ||
| for (Function &F : M) { | ||
| CallingConv::ID CC = F.getCallingConv(); | ||
mlychkov marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| for (auto &I : instructions(F)) { | ||
| auto *CI = dyn_cast<CallInst>(&I); | ||
| Function *Callee = nullptr; | ||
| if (!CI || !(Callee = CI->getCalledFunction())) | ||
| continue; | ||
| StringRef Name = Callee->getName(); | ||
| if (Name != SYCL_ALLOCLOCALMEM_CALL) | ||
| continue; | ||
|
|
||
| // TODO: Static local memory allocation should be requested only in | ||
| // spir kernel scope. | ||
mlychkov marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| assert((CC == llvm::CallingConv::SPIR_FUNC || | ||
mlychkov marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| CC == llvm::CallingConv::SPIR_KERNEL) && | ||
| "WG static local memery can be allocated only in kernel scope"); | ||
mlychkov marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| ToReplace.push_back(CI); | ||
kbobrovs marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
| } | ||
|
|
||
| if (ToReplace.empty()) | ||
| return false; | ||
|
|
||
| for (auto *CI : ToReplace) { | ||
| Value *ArgSize = CI->getArgOperand(0); | ||
| uint64_t Size = cast<llvm::ConstantInt>(ArgSize)->getZExtValue(); | ||
| Value *ArgAlign = CI->getArgOperand(1); | ||
| uint64_t Alignment = cast<llvm::ConstantInt>(ArgAlign)->getZExtValue(); | ||
|
|
||
| IRBuilder<> Builder(CI); | ||
| Type *LocalMemArrayTy = ArrayType::get(Builder.getInt8Ty(), Size); | ||
| unsigned LocalAS = | ||
| CI->getFunctionType()->getReturnType()->getPointerAddressSpace(); | ||
| auto *LocalMemArrayGV = | ||
| new GlobalVariable(M, // module | ||
| LocalMemArrayTy, // type | ||
| false, // isConstant | ||
| GlobalValue::InternalLinkage, // Linkage | ||
| UndefValue::get(LocalMemArrayTy), // Initializer | ||
| LOCALMEMORY_GV_PREF, // Name prefix | ||
| nullptr, // InsertBefore | ||
| GlobalVariable::NotThreadLocal, // ThreadLocalMode | ||
| LocalAS // AddressSpace | ||
| ); | ||
| LocalMemArrayGV->setAlignment(Align(Alignment)); | ||
|
|
||
| Value *LocalMemArrayGVPtr = Builder.CreatePointerCast( | ||
| LocalMemArrayGV, | ||
| Builder.getInt8PtrTy(LocalMemArrayGV->getAddressSpace())); | ||
| CI->replaceAllUsesWith(LocalMemArrayGVPtr); | ||
| CI->eraseFromParent(); | ||
| } | ||
| return true; | ||
| } | ||
|
|
||
| PreservedAnalyses SYCLLowerWGLocalMemoryPass::run(Module &M, | ||
| ModuleAnalysisManager &) { | ||
| if (lowerAllocaLocalMem(M)) | ||
| return PreservedAnalyses::none(); | ||
| return PreservedAnalyses::all(); | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.