Skip to content

Commit eaba8d9

Browse files
Mel-Chentru
authored and committed
[VP] Refactor VectorBuilder to avoid layering violation. NFC (llvm#99276)
This patch refactors the handling of reductions to eliminate layering violations:
* Introduced `getReductionIntrinsicID` in LoopUtils.h for mapping recurrence kinds to llvm.vector.reduce.* intrinsic IDs.
* Updated `VectorBuilder::createSimpleTargetReduction` to accept an llvm.vector.reduce.* intrinsic ID directly.
* Added a new function, `VPIntrinsic::getForIntrinsic`, for mapping an intrinsic ID to the functionally equivalent VP intrinsic ID.
(cherry picked from commit 6d12b3f)
1 parent 4bf04b2 commit eaba8d9

File tree

7 files changed

+129
-57
lines changed

7 files changed

+129
-57
lines changed

llvm/include/llvm/IR/IntrinsicInst.h

+4
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,10 @@ class VPIntrinsic : public IntrinsicInst {
569569
/// The llvm.vp.* intrinsics for this instruction Opcode
570570
static Intrinsic::ID getForOpcode(unsigned OC);
571571

572+
/// The llvm.vp.* intrinsics for this intrinsic ID \p Id. Return \p Id if it
573+
/// is already a VP intrinsic.
/// Returns Intrinsic::not_intrinsic when no VP intrinsic is mapped to \p Id.
574+
static Intrinsic::ID getForIntrinsic(Intrinsic::ID Id);
575+
572576
// Whether \p ID is a VP intrinsic ID.
573577
static bool isVPIntrinsic(Intrinsic::ID);
574578

llvm/include/llvm/IR/VectorBuilder.h

+2-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#ifndef LLVM_IR_VECTORBUILDER_H
1616
#define LLVM_IR_VECTORBUILDER_H
1717

18-
#include <llvm/Analysis/IVDescriptors.h>
1918
#include <llvm/IR/IRBuilder.h>
2019
#include <llvm/IR/InstrTypes.h>
2120
#include <llvm/IR/Instruction.h>
@@ -100,11 +99,11 @@ class VectorBuilder {
10099
const Twine &Name = Twine());
101100

102101
/// Emit a VP reduction intrinsic call for an llvm.vector.reduce.* intrinsic.
103-
/// \param Kind The kind of recurrence
102+
/// \param RdxID The intrinsic ID of llvm.vector.reduce.*
104103
/// \param ValTy The type of operand which the reduction operation is
105104
/// performed.
106105
/// \param VecOpArray The operand list.
107-
Value *createSimpleTargetReduction(RecurKind Kind, Type *ValTy,
106+
Value *createSimpleTargetReduction(Intrinsic::ID RdxID, Type *ValTy,
108107
ArrayRef<Value *> VecOpArray,
109108
const Twine &Name = Twine());
110109
};

llvm/include/llvm/Transforms/Utils/LoopUtils.h

+4
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,10 @@ bool canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
359359
SinkAndHoistLICMFlags &LICMFlags,
360360
OptimizationRemarkEmitter *ORE = nullptr);
361361

362+
/// Returns the llvm.vector.reduce intrinsic that corresponds to the recurrence
363+
/// kind.
/// NOTE(review): this is declared constexpr here, but the only definition
/// visible in this change lives in LoopUtils.cpp. A constexpr function is
/// implicitly inline, so a translation unit that calls it through this header
/// alone would have no definition available. Confirm all callers are in
/// LoopUtils.cpp, or drop constexpr from both declaration and definition.
364+
constexpr Intrinsic::ID getReductionIntrinsicID(RecurKind RK);
365+
362366
/// Returns the arithmetic instruction opcode used when expanding a reduction.
363367
unsigned getArithmeticReductionInstruction(Intrinsic::ID RdxID);
364368

llvm/lib/IR/IntrinsicInst.cpp

+19
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,25 @@ Intrinsic::ID VPIntrinsic::getForOpcode(unsigned IROPC) {
599599
return Intrinsic::not_intrinsic;
600600
}
601601

602+
/// File-local helper that maps intrinsic ID \p Id to an llvm.vp.* intrinsic ID.
/// Returns \p Id unchanged when ::isVPIntrinsic(Id) holds, the VP ID selected
/// by the VPIntrinsics.def expansion below when a case label matches, and
/// Intrinsic::not_intrinsic otherwise.
constexpr static Intrinsic::ID getForIntrinsic(Intrinsic::ID Id) {
603+
if (::isVPIntrinsic(Id))
604+
return Id;
605+
606+
switch (Id) {
607+
default:
608+
break;
609+
// The switch body is generated from VPIntrinsics.def via the three macros
// below. BEGIN_REGISTER_VP_INTRINSIC expands to `break;`, so control never
// falls from one registration into the next.
#define BEGIN_REGISTER_VP_INTRINSIC(VPID, ...) break;
610+
// Each functional-intrinsic property becomes a `case` label for that
// non-predicated intrinsic.
#define VP_PROPERTY_FUNCTIONAL_INTRINSIC(INTRIN) case Intrinsic::INTRIN:
611+
// A matched case label runs into this `return`, yielding the VP intrinsic ID.
// Presumably every registration in the .def file is bracketed BEGIN ... END;
// the .def file is not shown here, so verify against it.
#define END_REGISTER_VP_INTRINSIC(VPID) return Intrinsic::VPID;
612+
#include "llvm/IR/VPIntrinsics.def"
613+
}
614+
// Reached via `default` or a generated `break`: no VP intrinsic for Id.
return Intrinsic::not_intrinsic;
615+
}
616+
617+
/// Public entry point declared on VPIntrinsic; forwards to the file-local
/// ::getForIntrinsic helper and returns its result unchanged.
Intrinsic::ID VPIntrinsic::getForIntrinsic(Intrinsic::ID Id) {
618+
return ::getForIntrinsic(Id);
619+
}
620+
602621
bool VPIntrinsic::canIgnoreVectorLengthParam() const {
603622
using namespace PatternMatch;
604623

llvm/lib/IR/VectorBuilder.cpp

+5-52
Original file line numberDiff line numberDiff line change
@@ -60,60 +60,13 @@ Value *VectorBuilder::createVectorInstruction(unsigned Opcode, Type *ReturnTy,
6060
return createVectorInstructionImpl(VPID, ReturnTy, InstOpArray, Name);
6161
}
6262

63-
Value *VectorBuilder::createSimpleTargetReduction(RecurKind Kind, Type *ValTy,
63+
Value *VectorBuilder::createSimpleTargetReduction(Intrinsic::ID RdxID,
64+
Type *ValTy,
6465
ArrayRef<Value *> InstOpArray,
6566
const Twine &Name) {
66-
Intrinsic::ID VPID;
67-
switch (Kind) {
68-
case RecurKind::Add:
69-
VPID = Intrinsic::vp_reduce_add;
70-
break;
71-
case RecurKind::Mul:
72-
VPID = Intrinsic::vp_reduce_mul;
73-
break;
74-
case RecurKind::And:
75-
VPID = Intrinsic::vp_reduce_and;
76-
break;
77-
case RecurKind::Or:
78-
VPID = Intrinsic::vp_reduce_or;
79-
break;
80-
case RecurKind::Xor:
81-
VPID = Intrinsic::vp_reduce_xor;
82-
break;
83-
case RecurKind::FMulAdd:
84-
case RecurKind::FAdd:
85-
VPID = Intrinsic::vp_reduce_fadd;
86-
break;
87-
case RecurKind::FMul:
88-
VPID = Intrinsic::vp_reduce_fmul;
89-
break;
90-
case RecurKind::SMax:
91-
VPID = Intrinsic::vp_reduce_smax;
92-
break;
93-
case RecurKind::SMin:
94-
VPID = Intrinsic::vp_reduce_smin;
95-
break;
96-
case RecurKind::UMax:
97-
VPID = Intrinsic::vp_reduce_umax;
98-
break;
99-
case RecurKind::UMin:
100-
VPID = Intrinsic::vp_reduce_umin;
101-
break;
102-
case RecurKind::FMax:
103-
VPID = Intrinsic::vp_reduce_fmax;
104-
break;
105-
case RecurKind::FMin:
106-
VPID = Intrinsic::vp_reduce_fmin;
107-
break;
108-
case RecurKind::FMaximum:
109-
VPID = Intrinsic::vp_reduce_fmaximum;
110-
break;
111-
case RecurKind::FMinimum:
112-
VPID = Intrinsic::vp_reduce_fminimum;
113-
break;
114-
default:
115-
llvm_unreachable("No VPIntrinsic for this reduction");
116-
}
67+
auto VPID = VPIntrinsic::getForIntrinsic(RdxID);
68+
assert(VPReductionIntrinsic::isVPReduction(VPID) &&
69+
"No VPIntrinsic for this reduction");
11770
return createVectorInstructionImpl(VPID, ValTy, InstOpArray, Name);
11871
}
11972

llvm/lib/Transforms/Utils/LoopUtils.cpp

+42-2
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,44 @@ bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop,
918918
return true;
919919
}
920920

921+
/// Maps recurrence kind \p RK to the matching llvm.vector.reduce.* intrinsic
/// ID. Any kind without an explicit case below hits llvm_unreachable.
/// NOTE(review): the function is constexpr (hence implicitly inline) yet is
/// defined in this .cpp file while being declared in LoopUtils.h. A call from
/// another translation unit would see no definition. Confirm every caller
/// lives in this file, or drop constexpr from declaration and definition.
constexpr Intrinsic::ID llvm::getReductionIntrinsicID(RecurKind RK) {
922+
switch (RK) {
923+
default:
924+
llvm_unreachable("Unexpected recurrence kind");
925+
case RecurKind::Add:
926+
return Intrinsic::vector_reduce_add;
927+
case RecurKind::Mul:
928+
return Intrinsic::vector_reduce_mul;
929+
case RecurKind::And:
930+
return Intrinsic::vector_reduce_and;
931+
case RecurKind::Or:
932+
return Intrinsic::vector_reduce_or;
933+
case RecurKind::Xor:
934+
return Intrinsic::vector_reduce_xor;
935+
// FMulAdd shares the fadd reduction intrinsic with FAdd.
case RecurKind::FMulAdd:
936+
case RecurKind::FAdd:
937+
return Intrinsic::vector_reduce_fadd;
938+
case RecurKind::FMul:
939+
return Intrinsic::vector_reduce_fmul;
940+
case RecurKind::SMax:
941+
return Intrinsic::vector_reduce_smax;
942+
case RecurKind::SMin:
943+
return Intrinsic::vector_reduce_smin;
944+
case RecurKind::UMax:
945+
return Intrinsic::vector_reduce_umax;
946+
case RecurKind::UMin:
947+
return Intrinsic::vector_reduce_umin;
948+
case RecurKind::FMax:
949+
return Intrinsic::vector_reduce_fmax;
950+
case RecurKind::FMin:
951+
return Intrinsic::vector_reduce_fmin;
952+
case RecurKind::FMaximum:
953+
return Intrinsic::vector_reduce_fmaximum;
954+
case RecurKind::FMinimum:
955+
return Intrinsic::vector_reduce_fminimum;
956+
}
957+
}
958+
921959
unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) {
922960
switch (RdxID) {
923961
case Intrinsic::vector_reduce_fadd:
@@ -1215,12 +1253,13 @@ Value *llvm::createSimpleTargetReduction(VectorBuilder &VBuilder, Value *Src,
12151253
RecurKind Kind = Desc.getRecurrenceKind();
12161254
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
12171255
"AnyOf reduction is not supported.");
1256+
Intrinsic::ID Id = getReductionIntrinsicID(Kind);
12181257
auto *SrcTy = cast<VectorType>(Src->getType());
12191258
Type *SrcEltTy = SrcTy->getElementType();
12201259
Value *Iden =
12211260
Desc.getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags());
12221261
Value *Ops[] = {Iden, Src};
1223-
return VBuilder.createSimpleTargetReduction(Kind, SrcTy, Ops);
1262+
return VBuilder.createSimpleTargetReduction(Id, SrcTy, Ops);
12241263
}
12251264

12261265
Value *llvm::createTargetReduction(IRBuilderBase &B,
@@ -1260,9 +1299,10 @@ Value *llvm::createOrderedReduction(VectorBuilder &VBuilder,
12601299
assert(Src->getType()->isVectorTy() && "Expected a vector type");
12611300
assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
12621301

1302+
Intrinsic::ID Id = getReductionIntrinsicID(RecurKind::FAdd);
12631303
auto *SrcTy = cast<VectorType>(Src->getType());
12641304
Value *Ops[] = {Start, Src};
1265-
return VBuilder.createSimpleTargetReduction(RecurKind::FAdd, SrcTy, Ops);
1305+
return VBuilder.createSimpleTargetReduction(Id, SrcTy, Ops);
12661306
}
12671307

12681308
void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue,

llvm/unittests/IR/VPIntrinsicTest.cpp

+53
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,59 @@ TEST_F(VPIntrinsicTest, IntrinsicIDRoundTrip) {
367367
ASSERT_NE(FullTripCounts, 0u);
368368
}
369369

370+
/// Check that going from intrinsic to VP intrinsic and back results in the same
371+
/// intrinsic.
372+
TEST_F(VPIntrinsicTest, IntrinsicToVPRoundTrip) {
373+
// Set once any intrinsic completes the full round trip, so the test cannot
// pass vacuously when every iteration is skipped.
bool IsFullTrip = false;
374+
// Walk every intrinsic ID after not_intrinsic up to num_intrinsics.
Intrinsic::ID IntrinsicID = Intrinsic::not_intrinsic + 1;
375+
for (; IntrinsicID < Intrinsic::num_intrinsics; IntrinsicID++) {
376+
Intrinsic::ID VPID = VPIntrinsic::getForIntrinsic(IntrinsicID);
377+
// No equivalent VP intrinsic available.
378+
if (VPID == Intrinsic::not_intrinsic)
379+
continue;
380+
381+
// Return itself if passed intrinsic ID is VP intrinsic.
382+
if (VPIntrinsic::isVPIntrinsic(IntrinsicID)) {
383+
ASSERT_EQ(IntrinsicID, VPID);
384+
continue;
385+
}
386+
387+
// Map the VP intrinsic back to its non-predicated counterpart.
std::optional<Intrinsic::ID> RoundTripIntrinsicID =
388+
VPIntrinsic::getFunctionalIntrinsicIDForVP(VPID);
389+
// No equivalent non-predicated intrinsic available.
390+
if (!RoundTripIntrinsicID)
391+
continue;
392+
393+
ASSERT_EQ(*RoundTripIntrinsicID, IntrinsicID);
394+
IsFullTrip = true;
395+
}
396+
ASSERT_TRUE(IsFullTrip);
397+
}
398+
399+
/// Check that going from VP intrinsic to equivalent non-predicated intrinsic
400+
/// and back results in the same intrinsic.
401+
TEST_F(VPIntrinsicTest, VPToNonPredIntrinsicRoundTrip) {
402+
std::unique_ptr<Module> M = createVPDeclarationModule();
403+
// NOTE(review): assert is compiled out under NDEBUG, in which case a null
// module would be dereferenced by `*M` below rather than reported as a test
// failure. Consider a gtest assertion if that matters for release builds.
assert(M);
404+
405+
// Set once any declaration completes the full round trip, so the test cannot
// pass vacuously when every iteration is skipped.
bool IsFullTrip = false;
406+
for (const auto &VPDecl : *M) {
407+
auto VPID = VPDecl.getIntrinsicID();
408+
std::optional<Intrinsic::ID> NonPredID =
409+
VPIntrinsic::getFunctionalIntrinsicIDForVP(VPID);
410+
411+
// No equivalent non-predicated intrinsic available
412+
if (!NonPredID)
413+
continue;
414+
415+
// Map the non-predicated intrinsic back to a VP intrinsic.
Intrinsic::ID RoundTripVPID = VPIntrinsic::getForIntrinsic(*NonPredID);
416+
417+
ASSERT_EQ(RoundTripVPID, VPID);
418+
IsFullTrip = true;
419+
}
420+
ASSERT_TRUE(IsFullTrip);
421+
}
422+
370423
/// Check that VPIntrinsic::getDeclarationForParams works.
371424
TEST_F(VPIntrinsicTest, VPIntrinsicDeclarationForParams) {
372425
std::unique_ptr<Module> M = createVPDeclarationModule();

0 commit comments

Comments
 (0)