Skip to content

Commit b38b2ec

Browse files
committed
[mlir][OpenMP] Allow composite SIMD REDUCTION and IF
Reduction support: llvm#146671. IF support is fixed in this PR. The problem for the IF clause in composite constructs was that wsloop and simd both operate on the same CanonicalLoopInfo structure, with the SIMD processed first, followed by the wsloop. Previously the IF clause generated code like if (cond) { while (...) { simd_loop_body; } } else { while (...) { nonsimd_loop_body; } } The problem with this is that it invalidates the CanonicalLoopInfo structure to be processed by the wsloop later. To avoid this, in this patch I preserve the original loop, moving the IF clause inside of the loop: while (...) { if (cond) { simd_loop_body; } else { non_simd_loop_body; } } On simple examples I tried, LLVM was able to hoist the if condition outside of the loop at -O3. The disadvantage of this is that we cannot add the llvm.loop.vectorize.enable attribute on either the SIMD or non-SIMD loops because they both share a loop back edge. There's no way of solving this without keeping the old design of having two different loops, which cannot be represented using only one CanonicalLoopInfo structure. I don't think the presence or absence of this attribute makes much difference. In my testing it is the llvm.loop.parallel_accesses metadata which makes the difference to vectorization. LLVM will vectorize if legal whether or not this attribute is there in the TRUE branch. In the FALSE branch this means the loop might be vectorized even when the condition is false, but I think this is still standards compliant: OpenMP 6.0 says that when the if clause is false, the construct should be treated as if the SIMDLEN clause were one. The SIMDLEN clause is defined as a "hint". For the same reason, SIMDLEN and SAFELEN clauses are silently ignored when SIMD IF is used. I think it is better to implement SIMD IF and ignore SIMDLEN and SAFELEN and some vectorization encouragement metadata when combined with IF than to ignore IF, because IF could have correctness consequences whereas the rest are optimization hints. 
For example, the user might use the IF clause to disable SIMD programmatically when it is known not to be safe to vectorize the loop. In this case it is not at all safe to add the parallel access or SAFELEN metadata.
1 parent f271c6d commit b38b2ec

File tree

6 files changed

+150
-103
lines changed

6 files changed

+150
-103
lines changed

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5373,8 +5373,27 @@ void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop,
53735373
const Twine &NamePrefix) {
53745374
Function *F = CanonicalLoop->getFunction();
53755375

5376+
// We can't do
5377+
// if (cond) {
5378+
// simd_loop;
5379+
// } else {
5380+
// non_simd_loop;
5381+
// }
5382+
// because then the CanonicalLoopInfo would only point to one of the loops:
5383+
// leading to other constructs operating on the same loop to malfunction.
5384+
// Instead generate
5385+
// while (...) {
5386+
// if (cond) {
5387+
// simd_body;
5388+
// } else {
5389+
// not_simd_body;
5390+
// }
5391+
// }
5392+
// At least for simple loops, LLVM seems able to hoist the if out of the loop
5393+
// body at -O3
5394+
53765395
// Define where if branch should be inserted
5377-
Instruction *SplitBefore = CanonicalLoop->getPreheader()->getTerminator();
5396+
auto SplitBeforeIt = CanonicalLoop->getBody()->getFirstNonPHIIt();
53785397

53795398
// TODO: We should not rely on pass manager. Currently we use pass manager
53805399
// only for getting llvm::Loop which corresponds to given CanonicalLoopInfo
@@ -5391,37 +5410,51 @@ void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop,
53915410
Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());
53925411

53935412
// Create additional blocks for the if statement
5394-
BasicBlock *Head = SplitBefore->getParent();
5395-
Instruction *HeadOldTerm = Head->getTerminator();
5396-
llvm::LLVMContext &C = Head->getContext();
5413+
BasicBlock *Cond = SplitBeforeIt->getParent();
5414+
Instruction *CondOldTerm = Cond->getTerminator();
5415+
llvm::LLVMContext &C = Cond->getContext();
53975416
llvm::BasicBlock *ThenBlock = llvm::BasicBlock::Create(
5398-
C, NamePrefix + ".if.then", Head->getParent(), Head->getNextNode());
5417+
C, NamePrefix + ".if.then", Cond->getParent(), Cond->getNextNode());
53995418
llvm::BasicBlock *ElseBlock = llvm::BasicBlock::Create(
5400-
C, NamePrefix + ".if.else", Head->getParent(), CanonicalLoop->getExit());
5419+
C, NamePrefix + ".if.else", Cond->getParent(), CanonicalLoop->getExit());
54015420

54025421
// Create if condition branch.
5403-
Builder.SetInsertPoint(HeadOldTerm);
5422+
Builder.SetInsertPoint(CondOldTerm);
54045423
Instruction *BrInstr =
54055424
Builder.CreateCondBr(IfCond, ThenBlock, /*ifFalse*/ ElseBlock);
54065425
InsertPointTy IP{BrInstr->getParent(), ++BrInstr->getIterator()};
5407-
// Then block contains branch to omp loop which needs to be vectorized
5426+
// Then block contains branch to omp loop body which needs to be vectorized
54085427
spliceBB(IP, ThenBlock, false, Builder.getCurrentDebugLocation());
5409-
ThenBlock->replaceSuccessorsPhiUsesWith(Head, ThenBlock);
5428+
ThenBlock->replaceSuccessorsPhiUsesWith(Cond, ThenBlock);
54105429

54115430
Builder.SetInsertPoint(ElseBlock);
54125431

54135432
// Clone loop for the else branch
54145433
SmallVector<BasicBlock *, 8> NewBlocks;
54155434

5416-
VMap[CanonicalLoop->getPreheader()] = ElseBlock;
5435+
// Cond is the block that has the if clause condition
5436+
// LoopCond is omp_loop.cond
5437+
// LoopHeader is omp_loop.header
5438+
BasicBlock *LoopCond = Cond->getUniquePredecessor();
5439+
BasicBlock *LoopHeader = LoopCond->getUniquePredecessor();
5440+
assert(LoopCond && LoopHeader && "Invalid loop structure");
54175441
for (BasicBlock *Block : L->getBlocks()) {
5442+
if (Block == L->getLoopPreheader() || Block == L->getLoopLatch() ||
5443+
Block == LoopHeader || Block == LoopCond || Block == Cond) {
5444+
continue;
5445+
}
54185446
BasicBlock *NewBB = CloneBasicBlock(Block, VMap, "", F);
54195447
NewBB->moveBefore(CanonicalLoop->getExit());
54205448
VMap[Block] = NewBB;
54215449
NewBlocks.push_back(NewBB);
54225450
}
54235451
remapInstructionsInBlocks(NewBlocks, VMap);
54245452
Builder.CreateBr(NewBlocks.front());
5453+
5454+
// The loop latch must have only one predecessor. Currently it is branched to
5455+
// from both the 'then' and 'else' branches.
5456+
L->getLoopLatch()->splitBasicBlock(
5457+
L->getLoopLatch()->begin(), NamePrefix + ".pre_latch", /*Before=*/true);
54255458
}
54265459

54275460
unsigned
@@ -5478,19 +5511,6 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
54785511
if (IfCond) {
54795512
ValueToValueMapTy VMap;
54805513
createIfVersion(CanonicalLoop, IfCond, VMap, "simd");
5481-
// Add metadata to the cloned loop which disables vectorization
5482-
Value *MappedLatch = VMap.lookup(CanonicalLoop->getLatch());
5483-
assert(MappedLatch &&
5484-
"Cannot find value which corresponds to original loop latch");
5485-
assert(isa<BasicBlock>(MappedLatch) &&
5486-
"Cannot cast mapped latch block value to BasicBlock");
5487-
BasicBlock *NewLatchBlock = dyn_cast<BasicBlock>(MappedLatch);
5488-
ConstantAsMetadata *BoolConst =
5489-
ConstantAsMetadata::get(ConstantInt::getFalse(Type::getInt1Ty(Ctx)));
5490-
addBasicBlockMetadata(
5491-
NewLatchBlock,
5492-
{MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"),
5493-
BoolConst})});
54945514
}
54955515

54965516
SmallSet<BasicBlock *, 8> Reachable;
@@ -5524,6 +5544,14 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
55245544
Ctx, {MDString::get(Ctx, "llvm.loop.parallel_accesses"), AccessGroup}));
55255545
}
55265546

5547+
// FIXME: the IF clause shares a loop backedge for the SIMD and non-SIMD
5548+
// versions so we can't add the loop attributes in that case.
5549+
if (IfCond) {
5550+
// we can still add llvm.loop.parallel_access
5551+
addLoopMetadata(CanonicalLoop, LoopMDList);
5552+
return;
5553+
}
5554+
55275555
// Use the above access group metadata to create loop level
55285556
// metadata, which should be distinct for each loop.
55295557
ConstantAsMetadata *BoolConst =

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -702,30 +702,6 @@ static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation,
702702
moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
703703
}
704704

705-
/// Helper function to map block arguments defined by ignored loop wrappers to
706-
/// LLVM values and prevent any uses of those from triggering null pointer
707-
/// dereferences.
708-
///
709-
/// This must be called after block arguments of parent wrappers have already
710-
/// been mapped to LLVM IR values.
711-
static LogicalResult
712-
convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
713-
LLVM::ModuleTranslation &moduleTranslation) {
714-
// Map block arguments directly to the LLVM value associated to the
715-
// corresponding operand. This is semantically equivalent to this wrapper not
716-
// being present.
717-
return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
718-
.Case([&](omp::SimdOp op) {
719-
forwardArgs(moduleTranslation,
720-
cast<omp::BlockArgOpenMPOpInterface>(*op));
721-
op.emitWarning() << "simd information on composite construct discarded";
722-
return success();
723-
})
724-
.Default([&](Operation *op) {
725-
return op->emitError() << "cannot ignore wrapper";
726-
});
727-
}
728-
729705
/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
730706
static LogicalResult
731707
convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -2852,17 +2828,6 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
28522828
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
28532829
auto simdOp = cast<omp::SimdOp>(opInst);
28542830

2855-
// Ignore simd in composite constructs with unsupported clauses
2856-
// TODO: Replace this once simd + clause combinations are properly supported
2857-
if (simdOp.isComposite() &&
2858-
(simdOp.getReductionByref().has_value() || simdOp.getIfExpr())) {
2859-
if (failed(convertIgnoredWrapper(simdOp, moduleTranslation)))
2860-
return failure();
2861-
2862-
return inlineConvertOmpRegions(simdOp.getRegion(), "omp.simd.region",
2863-
builder, moduleTranslation);
2864-
}
2865-
28662831
if (failed(checkImplementationStatus(opInst)))
28672832
return failure();
28682833

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
llvm.func @_QPfoo(%arg0: !llvm.ptr {fir.bindc_name = "array", llvm.nocapture}, %arg1: !llvm.ptr {fir.bindc_name = "t", llvm.nocapture}) {
4+
%0 = llvm.mlir.constant(0 : i64) : i32
5+
%1 = llvm.mlir.constant(1 : i32) : i32
6+
%2 = llvm.mlir.constant(10 : i64) : i64
7+
%3 = llvm.mlir.constant(1 : i64) : i64
8+
%4 = llvm.alloca %3 x i32 {bindc_name = "i", pinned} : (i64) -> !llvm.ptr
9+
%5 = llvm.load %arg1 : !llvm.ptr -> i32
10+
%6 = llvm.icmp "ne" %5, %0 : i32
11+
%7 = llvm.trunc %2 : i64 to i32
12+
omp.wsloop {
13+
omp.simd if(%6) {
14+
omp.loop_nest (%arg2) : i32 = (%1) to (%7) inclusive step (%1) {
15+
llvm.store %arg2, %4 : i32, !llvm.ptr
16+
%8 = llvm.load %4 : !llvm.ptr -> i32
17+
%9 = llvm.sext %8 : i32 to i64
18+
%10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr, i64) -> !llvm.ptr, i32
19+
llvm.store %8, %10 : i32, !llvm.ptr
20+
omp.yield
21+
}
22+
} {omp.composite}
23+
} {omp.composite}
24+
llvm.return
25+
}
26+
27+
// CHECK-LABEL: @_QPfoo
28+
// ...
29+
// CHECK: omp_loop.preheader: ; preds =
30+
// CHECK: store i32 0, ptr %[[LB_ADDR:.*]], align 4
31+
// CHECK: store i32 9, ptr %[[UB_ADDR:.*]], align 4
32+
// CHECK: store i32 1, ptr %[[STEP_ADDR:.*]], align 4
33+
// CHECK: %[[VAL_15:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
34+
// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %[[VAL_15]], i32 34, ptr %{{.*}}, ptr %[[LB_ADDR]], ptr %[[UB_ADDR]], ptr %[[STEP_ADDR]], i32 1, i32 0)
35+
// CHECK: %[[LB:.*]] = load i32, ptr %[[LB_ADDR]], align 4
36+
// CHECK: %[[UB:.*]] = load i32, ptr %[[UB_ADDR]], align 4
37+
// CHECK: %[[VAL_18:.*]] = sub i32 %[[UB]], %[[LB]]
38+
// CHECK: %[[COUNT:.*]] = add i32 %[[VAL_18]], 1
39+
// CHECK: br label %[[OMP_LOOP_HEADER:.*]]
40+
// CHECK: omp_loop.header: ; preds = %[[OMP_LOOP_INC:.*]], %[[OMP_LOOP_PREHEADER:.*]]
41+
// CHECK: %[[IV:.*]] = phi i32 [ 0, %[[OMP_LOOP_PREHEADER]] ], [ %[[NEW_IV:.*]], %[[OMP_LOOP_INC]] ]
42+
// CHECK: br label %[[OMP_LOOP_COND:.*]]
43+
// CHECK: omp_loop.cond: ; preds = %[[OMP_LOOP_HEADER]]
44+
// CHECK: %[[VAL_25:.*]] = icmp ult i32 %[[IV]], %[[COUNT]]
45+
// CHECK: br i1 %[[VAL_25]], label %[[OMP_LOOP_BODY:.*]], label %[[OMP_LOOP_EXIT:.*]]
46+
// CHECK: omp_loop.body: ; preds = %[[OMP_LOOP_COND]]
47+
// CHECK: %[[VAL_28:.*]] = add i32 %[[IV]], %[[LB]]
48+
// CHECK: %[[VAL_29:.*]] = mul i32 %[[VAL_28]], 1
49+
// CHECK: %[[VAL_30:.*]] = add i32 %[[VAL_29]], 1
50+
// This is the IF clause:
51+
// CHECK: br i1 %{{.*}}, label %[[SIMD_IF_THEN:.*]], label %[[SIMD_IF_ELSE:.*]]
52+
53+
// CHECK: simd.if.then: ; preds = %[[OMP_LOOP_BODY]]
54+
// CHECK: br label %[[VAL_33:.*]]
55+
// CHECK: omp.loop_nest.region: ; preds = %[[SIMD_IF_THEN]]
56+
// This version contains !llvm.access.group metadata for SIMD
57+
// CHECK: store i32 %[[VAL_30]], ptr %{{.*}}, align 4, !llvm.access.group !1
58+
// CHECK: %[[VAL_34:.*]] = load i32, ptr %{{.*}}, align 4, !llvm.access.group !1
59+
// CHECK: %[[VAL_35:.*]] = sext i32 %[[VAL_34]] to i64
60+
// CHECK: %[[VAL_36:.*]] = getelementptr i32, ptr %[[VAL_37:.*]], i64 %[[VAL_35]]
61+
// CHECK: store i32 %[[VAL_34]], ptr %[[VAL_36]], align 4, !llvm.access.group !1
62+
// CHECK: br label %[[OMP_REGION_CONT3:.*]]
63+
// CHECK: omp.region.cont3: ; preds = %[[VAL_33]]
64+
// CHECK: br label %[[SIMD_PRE_LATCH:.*]]
65+
66+
// CHECK: simd.pre_latch: ; preds = %[[OMP_REGION_CONT3]], %[[OMP_REGION_CONT35:.*]]
67+
// CHECK: br label %[[OMP_LOOP_INC]]
68+
// CHECK: omp_loop.inc: ; preds = %[[SIMD_PRE_LATCH]]
69+
// CHECK: %[[NEW_IV]] = add nuw i32 %[[IV]], 1
70+
// CHECK: br label %[[OMP_LOOP_HEADER]], !llvm.loop !2
71+
72+
// CHECK: simd.if.else: ; preds = %[[OMP_LOOP_BODY]]
73+
// CHECK: br label %[[VAL_41:.*]]
74+
// CHECK: omp.loop_nest.region4: ; preds = %[[SIMD_IF_ELSE]]
75+
// No llvm.access.group metadata for else clause
76+
// CHECK: store i32 %[[VAL_30]], ptr %{{.*}}, align 4
77+
// CHECK: %[[VAL_42:.*]] = load i32, ptr %{{.*}}, align 4
78+
// CHECK: %[[VAL_43:.*]] = sext i32 %[[VAL_42]] to i64
79+
// CHECK: %[[VAL_44:.*]] = getelementptr i32, ptr %[[VAL_37]], i64 %[[VAL_43]]
80+
// CHECK: store i32 %[[VAL_42]], ptr %[[VAL_44]], align 4
81+
// CHECK: br label %[[OMP_REGION_CONT35]]
82+
// CHECK: omp.region.cont35: ; preds = %[[VAL_41]]
83+
// CHECK: br label %[[SIMD_PRE_LATCH]]
84+
85+
// CHECK: omp_loop.exit: ; preds = %[[OMP_LOOP_COND]]
86+
// CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_15]])
87+
// CHECK: %[[VAL_45:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
88+
// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_45]])
89+
90+
// CHECK: !1 = distinct !{}
91+
// CHECK: !2 = distinct !{!2, !3}
92+
// CHECK: !3 = !{!"llvm.loop.parallel_accesses", !1}
93+
// CHECK-NOT: llvm.loop.vectorize

mlir/test/Target/LLVMIR/openmp-llvm.mlir

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -820,8 +820,6 @@ llvm.func @simd_if(%arg0: !llvm.ptr {fir.bindc_name = "n"}, %arg1: !llvm.ptr {fi
820820
}
821821
// Be sure that llvm.loop.parallel_accesses metadata is still emitted
822822
// CHECK: llvm.loop.parallel_accesses
823-
// CHECK-NEXT: llvm.loop.vectorize.enable
824-
// CHECK: llvm.loop.vectorize.enable
825823

826824
// -----
827825

mlir/test/Target/LLVMIR/openmp-reduction.mlir

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -637,9 +637,12 @@ llvm.func @wsloop_simd_reduction(%lb : i64, %ub : i64, %step : i64) {
637637
// Outlined function.
638638
// CHECK: define internal void @[[OUTLINED]]
639639

640-
// Private reduction variable and its initialization.
640+
// reduction variable in wsloop
641641
// CHECK: %[[PRIVATE:.+]] = alloca float
642+
// reduction variable in simd
643+
// CHECK: %[[PRIVATE2:.+]] = alloca float
642644
// CHECK: store float 0.000000e+00, ptr %[[PRIVATE]]
645+
// CHECK: store float 0.000000e+00, ptr %[[PRIVATE2]]
643646

644647
// Call to the reduction function.
645648
// CHECK: call i32 @__kmpc_reduce
@@ -659,9 +662,9 @@ llvm.func @wsloop_simd_reduction(%lb : i64, %ub : i64, %step : i64) {
659662

660663
// Update of the private variable using the reduction region
661664
// (the body block currently comes after all the other blocks).
662-
// CHECK: %[[PARTIAL:.+]] = load float, ptr %[[PRIVATE]]
665+
// CHECK: %[[PARTIAL:.+]] = load float, ptr %[[PRIVATE2]]
663666
// CHECK: %[[UPDATED:.+]] = fadd float 2.000000e+00, %[[PARTIAL]]
664-
// CHECK: store float %[[UPDATED]], ptr %[[PRIVATE]]
667+
// CHECK: store float %[[UPDATED]], ptr %[[PRIVATE2]]
665668

666669
// Reduction function.
667670
// CHECK: define internal void @[[REDFUNC]]

mlir/test/Target/LLVMIR/openmp-todo.mlir

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -489,43 +489,3 @@ llvm.func @wsloop_order(%lb : i32, %ub : i32, %step : i32) {
489489
}
490490
llvm.return
491491
}
492-
493-
// -----
494-
495-
llvm.func @do_simd_if(%1 : !llvm.ptr, %5 : i32, %4 : i32, %6 : i1) {
496-
omp.wsloop {
497-
// expected-warning@below {{simd information on composite construct discarded}}
498-
omp.simd if(%6) {
499-
omp.loop_nest (%arg0) : i32 = (%5) to (%4) inclusive step (%5) {
500-
llvm.store %arg0, %1 : i32, !llvm.ptr
501-
omp.yield
502-
}
503-
} {omp.composite}
504-
} {omp.composite}
505-
llvm.return
506-
}
507-
508-
// -----
509-
510-
omp.declare_reduction @add_reduction_i32 : i32 init {
511-
^bb0(%arg0: i32):
512-
%0 = llvm.mlir.constant(0 : i32) : i32
513-
omp.yield(%0 : i32)
514-
} combiner {
515-
^bb0(%arg0: i32, %arg1: i32):
516-
%0 = llvm.add %arg0, %arg1 : i32
517-
omp.yield(%0 : i32)
518-
}
519-
llvm.func @do_simd_reduction(%1 : !llvm.ptr, %3 : !llvm.ptr, %6 : i32, %7 : i32) {
520-
omp.wsloop reduction(@add_reduction_i32 %3 -> %arg0 : !llvm.ptr) {
521-
// expected-warning@below {{simd information on composite construct discarded}}
522-
omp.simd reduction(@add_reduction_i32 %arg0 -> %arg1 : !llvm.ptr) {
523-
omp.loop_nest (%arg2) : i32 = (%7) to (%6) inclusive step (%7) {
524-
llvm.store %arg2, %1 : i32, !llvm.ptr
525-
%12 = llvm.load %arg1 : !llvm.ptr -> i32
526-
omp.yield
527-
}
528-
} {omp.composite}
529-
} {omp.composite}
530-
llvm.return
531-
}

0 commit comments

Comments
 (0)