Skip to content

Commit b72f1ec

Browse files
committed
[openmp][mlir] Lower parallel if to new fork_call_if function.
This patch adds a new runtime function `fork_call_if` and uses that to lower parallel if statements when going through OpenMPIRBuilder. This fixes an issue where the OpenMPIRBuilder passes all arguments to fork_call as a struct but this struct is not filled corretly in the non-if branch by handling the fork inside the runtime. Differential Revision: https://reviews.llvm.org/D138495
1 parent e7328a9 commit b72f1ec

File tree

8 files changed

+85
-148
lines changed

8 files changed

+85
-148
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPKinds.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ __OMP_RTL(__kmpc_flush, false, Void, IdentPtr)
203203
__OMP_RTL(__kmpc_global_thread_num, false, Int32, IdentPtr)
204204
__OMP_RTL(__kmpc_get_hardware_thread_id_in_block, false, Int32, )
205205
__OMP_RTL(__kmpc_fork_call, true, Void, IdentPtr, Int32, ParallelTaskPtr)
206+
__OMP_RTL(__kmpc_fork_call_if, false, Void, IdentPtr, Int32, ParallelTaskPtr,
207+
Int32, VoidPtr)
206208
__OMP_RTL(__kmpc_omp_taskwait, false, Int32, IdentPtr, Int32)
207209
__OMP_RTL(__kmpc_omp_taskwait_51, false, Int32, IdentPtr, Int32, Int32)
208210
__OMP_RTL(__kmpc_omp_taskyield, false, Int32, IdentPtr, Int32, /* Int */ Int32)

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 32 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -914,34 +914,21 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
914914
AllocaInst *TIDAddr = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
915915
AllocaInst *ZeroAddr = Builder.CreateAlloca(Int32, nullptr, "zero.addr");
916916

917-
// If there is an if condition we actually use the TIDAddr and ZeroAddr in the
918-
// program, otherwise we only need them for modeling purposes to get the
919-
// associated arguments in the outlined function. In the former case,
920-
// initialize the allocas properly, in the latter case, delete them later.
921-
if (IfCondition) {
922-
Builder.CreateStore(Constant::getNullValue(Int32), TIDAddr);
923-
Builder.CreateStore(Constant::getNullValue(Int32), ZeroAddr);
924-
} else {
925-
ToBeDeleted.push_back(TIDAddr);
926-
ToBeDeleted.push_back(ZeroAddr);
927-
}
917+
// We only need TIDAddr and ZeroAddr for modeling purposes to get the
918+
// associated arguments in the outlined function, so we delete them later.
919+
ToBeDeleted.push_back(TIDAddr);
920+
ToBeDeleted.push_back(ZeroAddr);
928921

929922
// Create an artificial insertion point that will also ensure the blocks we
930923
// are about to split are not degenerated.
931924
auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);
932925

933-
Instruction *ThenTI = UI, *ElseTI = nullptr;
934-
if (IfCondition)
935-
SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);
936-
937-
BasicBlock *ThenBB = ThenTI->getParent();
938-
BasicBlock *PRegEntryBB = ThenBB->splitBasicBlock(ThenTI, "omp.par.entry");
939-
BasicBlock *PRegBodyBB =
940-
PRegEntryBB->splitBasicBlock(ThenTI, "omp.par.region");
926+
BasicBlock *EntryBB = UI->getParent();
927+
BasicBlock *PRegEntryBB = EntryBB->splitBasicBlock(UI, "omp.par.entry");
928+
BasicBlock *PRegBodyBB = PRegEntryBB->splitBasicBlock(UI, "omp.par.region");
941929
BasicBlock *PRegPreFiniBB =
942-
PRegBodyBB->splitBasicBlock(ThenTI, "omp.par.pre_finalize");
943-
BasicBlock *PRegExitBB =
944-
PRegPreFiniBB->splitBasicBlock(ThenTI, "omp.par.exit");
930+
PRegBodyBB->splitBasicBlock(UI, "omp.par.pre_finalize");
931+
BasicBlock *PRegExitBB = PRegPreFiniBB->splitBasicBlock(UI, "omp.par.exit");
945932

946933
auto FiniCBWrapper = [&](InsertPointTy IP) {
947934
// Hide "open-ended" blocks from the given FiniCB by setting the right jump
@@ -975,7 +962,7 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
975962
Builder.CreateLoad(Int32, ZeroAddr, "zero.addr.use");
976963
ToBeDeleted.push_back(ZeroAddrUse);
977964

978-
// ThenBB
965+
// EntryBB
979966
// |
980967
// V
981968
// PRegionEntryBB <- Privatization allocas are placed here.
@@ -998,8 +985,12 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
998985
BodyGenCB(InnerAllocaIP, CodeGenIP);
999986

1000987
LLVM_DEBUG(dbgs() << "After body codegen: " << *OuterFn << "\n");
988+
FunctionCallee RTLFn;
989+
if (IfCondition)
990+
RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call_if);
991+
else
992+
RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);
1001993

1002-
FunctionCallee RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);
1003994
if (auto *F = dyn_cast<llvm::Function>(RTLFn.getCallee())) {
1004995
if (!F->hasMetadata(llvm::LLVMContext::MD_callback)) {
1005996
llvm::LLVMContext &Ctx = F->getContext();
@@ -1034,15 +1025,30 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
10341025
CI->getParent()->setName("omp_parallel");
10351026
Builder.SetInsertPoint(CI);
10361027

1037-
// Build call __kmpc_fork_call(Ident, n, microtask, var1, .., varn);
1028+
// Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
10381029
Value *ForkCallArgs[] = {
10391030
Ident, Builder.getInt32(NumCapturedVars),
10401031
Builder.CreateBitCast(&OutlinedFn, ParallelTaskPtr)};
10411032

10421033
SmallVector<Value *, 16> RealArgs;
10431034
RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs));
1035+
if (IfCondition) {
1036+
Value *Cond = Builder.CreateSExtOrTrunc(IfCondition,
1037+
Type::getInt32Ty(M.getContext()));
1038+
RealArgs.push_back(Cond);
1039+
}
10441040
RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end());
10451041

1042+
// __kmpc_fork_call_if always expects a void ptr as the last argument
1043+
// If there are no arguments, pass a null pointer.
1044+
auto PtrTy = Type::getInt8PtrTy(M.getContext());
1045+
if (IfCondition && NumCapturedVars == 0) {
1046+
llvm::Value *Void = ConstantPointerNull::get(PtrTy);
1047+
RealArgs.push_back(Void);
1048+
}
1049+
if (IfCondition && RealArgs.back()->getType() != PtrTy)
1050+
RealArgs.back() = Builder.CreateBitCast(RealArgs.back(), PtrTy);
1051+
10461052
Builder.CreateCall(RTLFn, RealArgs);
10471053

10481054
LLVM_DEBUG(dbgs() << "With fork_call placed: "
@@ -1055,35 +1061,7 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
10551061
Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
10561062
Builder.CreateStore(Builder.CreateLoad(Int32, OutlinedAI), PrivTIDAddr);
10571063

1058-
// If no "if" clause was present we do not need the call created during
1059-
// outlining, otherwise we reuse it in the serialized parallel region.
1060-
if (!ElseTI) {
1061-
CI->eraseFromParent();
1062-
} else {
1063-
1064-
// If an "if" clause was present we are now generating the serialized
1065-
// version into the "else" branch.
1066-
Builder.SetInsertPoint(ElseTI);
1067-
1068-
// Build calls __kmpc_serialized_parallel(&Ident, GTid);
1069-
Value *SerializedParallelCallArgs[] = {Ident, ThreadID};
1070-
Builder.CreateCall(
1071-
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_serialized_parallel),
1072-
SerializedParallelCallArgs);
1073-
1074-
// OutlinedFn(&GTid, &zero, CapturedStruct);
1075-
CI->removeFromParent();
1076-
Builder.Insert(CI);
1077-
1078-
// __kmpc_end_serialized_parallel(&Ident, GTid);
1079-
Value *EndArgs[] = {Ident, ThreadID};
1080-
Builder.CreateCall(
1081-
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_serialized_parallel),
1082-
EndArgs);
1083-
1084-
LLVM_DEBUG(dbgs() << "With serialized parallel region: "
1085-
<< *Builder.GetInsertBlock()->getParent() << "\n");
1086-
}
1064+
CI->eraseFromParent();
10871065

10881066
for (Instruction *I : ToBeDeleted)
10891067
I->eraseFromParent();

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -986,38 +986,22 @@ TEST_F(OpenMPIRBuilderTest, ParallelIfCond) {
986986
EXPECT_EQ(OutlinedFn->arg_size(), 3U);
987987

988988
EXPECT_EQ(&OutlinedFn->getEntryBlock(), PrivAI->getParent());
989-
ASSERT_EQ(OutlinedFn->getNumUses(), 2U);
989+
ASSERT_EQ(OutlinedFn->getNumUses(), 1U);
990990

991-
CallInst *DirectCI = nullptr;
992991
CallInst *ForkCI = nullptr;
993992
for (User *Usr : OutlinedFn->users()) {
994-
if (isa<CallInst>(Usr)) {
995-
ASSERT_EQ(DirectCI, nullptr);
996-
DirectCI = cast<CallInst>(Usr);
997-
} else {
998-
ASSERT_TRUE(isa<ConstantExpr>(Usr));
999-
ASSERT_EQ(Usr->getNumUses(), 1U);
1000-
ASSERT_TRUE(isa<CallInst>(Usr->user_back()));
1001-
ForkCI = cast<CallInst>(Usr->user_back());
1002-
}
993+
ASSERT_TRUE(isa<ConstantExpr>(Usr));
994+
ASSERT_EQ(Usr->getNumUses(), 1U);
995+
ASSERT_TRUE(isa<CallInst>(Usr->user_back()));
996+
ForkCI = cast<CallInst>(Usr->user_back());
1003997
}
1004998

1005-
EXPECT_EQ(ForkCI->getCalledFunction()->getName(), "__kmpc_fork_call");
1006-
EXPECT_EQ(ForkCI->arg_size(), 4U);
999+
EXPECT_EQ(ForkCI->getCalledFunction()->getName(), "__kmpc_fork_call_if");
1000+
EXPECT_EQ(ForkCI->arg_size(), 5U);
10071001
EXPECT_TRUE(isa<GlobalVariable>(ForkCI->getArgOperand(0)));
10081002
EXPECT_EQ(ForkCI->getArgOperand(1),
10091003
ConstantInt::get(Type::getInt32Ty(Ctx), 1));
1010-
Value *StoredForkArg =
1011-
findStoredValueInAggregateAt(Ctx, ForkCI->getArgOperand(3), 0);
1012-
EXPECT_EQ(StoredForkArg, F->arg_begin());
1013-
1014-
EXPECT_EQ(DirectCI->getCalledFunction(), OutlinedFn);
1015-
EXPECT_EQ(DirectCI->arg_size(), 3U);
1016-
EXPECT_TRUE(isa<AllocaInst>(DirectCI->getArgOperand(0)));
1017-
EXPECT_TRUE(isa<AllocaInst>(DirectCI->getArgOperand(1)));
1018-
Value *StoredDirectArg =
1019-
findStoredValueInAggregateAt(Ctx, DirectCI->getArgOperand(2), 0);
1020-
EXPECT_EQ(StoredDirectArg, F->arg_begin());
1004+
EXPECT_EQ(ForkCI->getArgOperand(3)->getType(), Type::getInt32Ty(Ctx));
10211005
}
10221006

10231007
TEST_F(OpenMPIRBuilderTest, ParallelCancelBarrier) {

mlir/test/Target/LLVMIR/openmp-llvm.mlir

Lines changed: 4 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -151,33 +151,19 @@ llvm.func @test_omp_parallel_num_threads_3() -> () {
151151
// CHECK: define void @test_omp_parallel_if_1(i32 %[[IF_VAR_1:.*]])
152152
llvm.func @test_omp_parallel_if_1(%arg0: i32) -> () {
153153

154-
// Check that the allocas are emitted by the OpenMPIRBuilder at the top of the
155-
// function, before the condition. Allocas are only emitted by the builder when
156-
// the `if` clause is present. We match specific SSA value names since LLVM
157-
// actually produces those names.
158-
// CHECK: %tid.addr{{.*}} = alloca i32
159-
// CHECK: %zero.addr{{.*}} = alloca i32
160-
161-
// CHECK: %[[IF_COND_VAR_1:.*]] = icmp slt i32 %[[IF_VAR_1]], 0
162154
%0 = llvm.mlir.constant(0 : index) : i32
163155
%1 = llvm.icmp "slt" %arg0, %0 : i32
156+
// CHECK: %[[IF_COND_VAR_1:.*]] = icmp slt i32 %[[IF_VAR_1]], 0
157+
164158

165159
// CHECK: %[[GTN_IF_1:.*]] = call i32 @__kmpc_global_thread_num(ptr @[[SI_VAR_IF_1:.*]])
166-
// CHECK: br i1 %[[IF_COND_VAR_1]], label %[[IF_COND_TRUE_BLOCK_1:.*]], label %[[IF_COND_FALSE_BLOCK_1:.*]]
167-
// CHECK: [[IF_COND_TRUE_BLOCK_1]]:
168160
// CHECK: br label %[[OUTLINED_CALL_IF_BLOCK_1:.*]]
169161
// CHECK: [[OUTLINED_CALL_IF_BLOCK_1]]:
170-
// CHECK: call void {{.*}} @__kmpc_fork_call(ptr @[[SI_VAR_IF_1]], {{.*}} @[[OMP_OUTLINED_FN_IF_1:.*]])
162+
// CHECK: %[[I32_IF_COND_VAR_1:.*]] = sext i1 %[[IF_COND_VAR_1]] to i32
163+
// CHECK: call void @__kmpc_fork_call_if(ptr @[[SI_VAR_IF_1]], i32 0, ptr @[[OMP_OUTLINED_FN_IF_1:.*]], i32 %[[I32_IF_COND_VAR_1]], ptr null)
171164
// CHECK: br label %[[OUTLINED_EXIT_IF_1:.*]]
172165
// CHECK: [[OUTLINED_EXIT_IF_1]]:
173-
// CHECK: br label %[[OUTLINED_EXIT_IF_2:.*]]
174-
// CHECK: [[OUTLINED_EXIT_IF_2]]:
175166
// CHECK: br label %[[RETURN_BLOCK_IF_1:.*]]
176-
// CHECK: [[IF_COND_FALSE_BLOCK_1]]:
177-
// CHECK: call void @__kmpc_serialized_parallel(ptr @[[SI_VAR_IF_1]], i32 %[[GTN_IF_1]])
178-
// CHECK: call void @[[OMP_OUTLINED_FN_IF_1]]
179-
// CHECK: call void @__kmpc_end_serialized_parallel(ptr @[[SI_VAR_IF_1]], i32 %[[GTN_IF_1]])
180-
// CHECK: br label %[[RETURN_BLOCK_IF_1]]
181167
omp.parallel if(%1 : i1) {
182168
omp.barrier
183169
omp.terminator
@@ -193,58 +179,6 @@ llvm.func @test_omp_parallel_if_1(%arg0: i32) -> () {
193179

194180
// -----
195181

196-
// CHECK-LABEL: @test_nested_alloca_ip
197-
llvm.func @test_nested_alloca_ip(%arg0: i32) -> () {
198-
199-
// Check that the allocas are emitted by the OpenMPIRBuilder at the top of
200-
// the function, before the condition. Allocas are only emitted by the
201-
// builder when the `if` clause is present. We match specific SSA value names
202-
// since LLVM actually produces those names and ensure they come before the
203-
// "icmp" that is the first operation we emit.
204-
// CHECK: %tid.addr{{.*}} = alloca i32
205-
// CHECK: %zero.addr{{.*}} = alloca i32
206-
// CHECK: icmp slt i32 %{{.*}}, 0
207-
%0 = llvm.mlir.constant(0 : index) : i32
208-
%1 = llvm.icmp "slt" %arg0, %0 : i32
209-
210-
omp.parallel if(%1 : i1) {
211-
// The "parallel" operation will be outlined, check the the function is
212-
// produced. Inside that function, further allocas should be placed before
213-
// another "icmp".
214-
// CHECK: define
215-
// CHECK: %tid.addr{{.*}} = alloca i32
216-
// CHECK: %zero.addr{{.*}} = alloca i32
217-
// CHECK: icmp slt i32 %{{.*}}, 1
218-
%2 = llvm.mlir.constant(1 : index) : i32
219-
%3 = llvm.icmp "slt" %arg0, %2 : i32
220-
221-
omp.parallel if(%3 : i1) {
222-
// One more nesting level.
223-
// CHECK: define
224-
// CHECK: %tid.addr{{.*}} = alloca i32
225-
// CHECK: %zero.addr{{.*}} = alloca i32
226-
// CHECK: icmp slt i32 %{{.*}}, 2
227-
228-
%4 = llvm.mlir.constant(2 : index) : i32
229-
%5 = llvm.icmp "slt" %arg0, %4 : i32
230-
231-
omp.parallel if(%5 : i1) {
232-
omp.barrier
233-
omp.terminator
234-
}
235-
236-
omp.barrier
237-
omp.terminator
238-
}
239-
omp.barrier
240-
omp.terminator
241-
}
242-
243-
llvm.return
244-
}
245-
246-
// -----
247-
248182
// CHECK-LABEL: define void @test_omp_parallel_3()
249183
llvm.func @test_omp_parallel_3() -> () {
250184
// CHECK: [[OMP_THREAD_3_1:%.*]] = call i32 @__kmpc_global_thread_num(ptr @{{[0-9]+}})

openmp/runtime/src/kmp.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3901,6 +3901,9 @@ KMP_EXPORT kmp_int32 __kmpc_bound_num_threads(ident_t *);
39013901
KMP_EXPORT kmp_int32 __kmpc_ok_to_fork(ident_t *);
39023902
KMP_EXPORT void __kmpc_fork_call(ident_t *, kmp_int32 nargs,
39033903
kmpc_micro microtask, ...);
3904+
KMP_EXPORT void __kmpc_fork_call_if(ident_t *loc, kmp_int32 nargs,
3905+
kmpc_micro microtask, kmp_int32 cond,
3906+
void *args);
39043907

39053908
KMP_EXPORT void __kmpc_serialized_parallel(ident_t *, kmp_int32 global_tid);
39063909
KMP_EXPORT void __kmpc_end_serialized_parallel(ident_t *, kmp_int32 global_tid);

openmp/runtime/src/kmp_csupport.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,37 @@ void __kmpc_fork_call(ident_t *loc, kmp_int32 argc, kmpc_micro microtask, ...) {
330330
#endif // KMP_STATS_ENABLED
331331
}
332332

333+
/*!
334+
@ingroup PARALLEL
335+
@param loc source location information
336+
@param microtask pointer to callback routine consisting of outlined parallel
337+
construct
338+
@param cond condition for running in parallel
339+
@param args struct of pointers to shared variables that aren't global
340+
341+
Perform a fork only if the condition is true.
342+
*/
343+
void __kmpc_fork_call_if(ident_t *loc, kmp_int32 argc, kmpc_micro microtask,
344+
kmp_int32 cond, void *args) {
345+
int gtid = __kmp_entry_gtid();
346+
int zero = 0;
347+
if (cond) {
348+
if (args)
349+
__kmpc_fork_call(loc, argc, microtask, args);
350+
else
351+
__kmpc_fork_call(loc, argc, microtask);
352+
} else {
353+
__kmpc_serialized_parallel(loc, gtid);
354+
355+
if (args)
356+
microtask(&gtid, &zero, args);
357+
else
358+
microtask(&gtid, &zero);
359+
360+
__kmpc_end_serialized_parallel(loc, gtid);
361+
}
362+
}
363+
333364
/*!
334365
@ingroup PARALLEL
335366
@param loc source location information

openmp/runtime/test/lit.cfg

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ if 'INTEL_LICENSE_FILE' in os.environ:
133133
# substitutions
134134
config.substitutions.append(("%libomp-compile-and-run", \
135135
"%libomp-compile && %libomp-run"))
136+
config.substitutions.append(("%libomp-irbuilder-compile-and-run", \
137+
"%libomp-irbuilder-compile && %libomp-run"))
136138
config.substitutions.append(("%libomp-c99-compile-and-run", \
137139
"%libomp-c99-compile && %libomp-run"))
138140
config.substitutions.append(("%libomp-cxx-compile-and-run", \
@@ -143,6 +145,8 @@ config.substitutions.append(("%libomp-cxx-compile", \
143145
"%clangXX %openmp_flags %flags -std=c++17 %s -o %t" + libs))
144146
config.substitutions.append(("%libomp-compile", \
145147
"%clang %openmp_flags %flags %s -o %t" + libs))
148+
config.substitutions.append(("%libomp-irbuilder-compile", \
149+
"%clang %openmp_flags %flags -fopenmp-enable-irbuilder %s -o %t" + libs))
146150
config.substitutions.append(("%libomp-c99-compile", \
147151
"%clang %openmp_flags %flags -std=c99 %s -o %t" + libs))
148152
config.substitutions.append(("%libomp-run", "%t"))

openmp/runtime/test/parallel/omp_parallel_if.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: %libomp-compile-and-run
2+
// RUN: %libomp-irbuilder-compile-and-run
23
#include <stdio.h>
34
#include "omp_testsuite.h"
45

0 commit comments

Comments
 (0)