Skip to content

Commit

Permalink
[SLP] Remove LHS and RHS from OperationData.
Browse files Browse the repository at this point in the history
These were only really used for 2 things. One was to check if the operand matches the phi if it exists. The other was for the createOp method to build the reduction.

For the first case we still have the operation we just need to know how to index its operands. So I've modified getLHS/getRHS to just use the opcode/kind to know how to find the right operands on an instruction that is now passed in.

For the other case we had to create an OperationData object to set the LHS/RHS values and copy the opcode/kind from another object. We would then just call createOp on that temporary object. Instead I've made LHS/RHS arguments to createOp and removed all these temporary objects.

Differential Revision: https://reviews.llvm.org/D88193
  • Loading branch information
topperc committed Sep 24, 2020
1 parent d1419c9 commit 03f22b0
Showing 1 changed file with 50 additions and 67 deletions.
117 changes: 50 additions & 67 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6312,20 +6312,13 @@ class HorizontalReduction {
/// Opcode of the instruction.
unsigned Opcode = 0;

/// Left operand of the reduction operation.
Value *LHS = nullptr;

/// Right operand of the reduction operation.
Value *RHS = nullptr;

/// Kind of the reduction operation.
ReductionKind Kind = RK_None;

/// Checks if the reduction operation can be vectorized.
bool isVectorizable() const {
return LHS && RHS &&
// We currently only support add/mul/logical && min/max reductions.
((Kind == RK_Arithmetic &&
// We currently only support add/mul/logical && min/max reductions.
return ((Kind == RK_Arithmetic &&
(Opcode == Instruction::Add || Opcode == Instruction::FAdd ||
Opcode == Instruction::Mul || Opcode == Instruction::FMul ||
Opcode == Instruction::And || Opcode == Instruction::Or ||
Expand All @@ -6336,7 +6329,8 @@ class HorizontalReduction {
}

/// Creates reduction operation with the current opcode.
Value *createOp(IRBuilder<> &Builder, const Twine &Name) const {
Value *createOp(IRBuilder<> &Builder, Value *LHS, Value *RHS,
const Twine &Name) const {
assert(isVectorizable() &&
"Expected add|fadd or min/max reduction operation.");
Value *Cmp = nullptr;
Expand Down Expand Up @@ -6377,8 +6371,8 @@ class HorizontalReduction {

/// Constructor for reduction operations with opcode and its left and
/// right operands.
OperationData(unsigned Opcode, Value *LHS, Value *RHS, ReductionKind Kind)
: Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind) {
OperationData(unsigned Opcode, ReductionKind Kind)
: Opcode(Opcode), Kind(Kind) {
assert(Kind != RK_None && "One of the reduction operations is expected.");
}

Expand Down Expand Up @@ -6411,16 +6405,14 @@ class HorizontalReduction {

/// Total number of operands in the reduction operation.
unsigned getNumberOfOperands() const {
assert(Kind != RK_None && !!*this && LHS && RHS &&
"Expected reduction operation.");
assert(Kind != RK_None && !!*this && "Expected reduction operation.");
return isMinMax() ? 3 : 2;
}

/// Checks if the instruction is in basic block \p BB.
/// For a min/max reduction check that both compare and select are in \p BB.
bool hasSameParent(Instruction *I, BasicBlock *BB, bool IsRedOp) const {
assert(Kind != RK_None && !!*this && LHS && RHS &&
"Expected reduction operation.");
assert(Kind != RK_None && !!*this && "Expected reduction operation.");
if (IsRedOp && isMinMax()) {
auto *Cmp = cast<Instruction>(cast<SelectInst>(I)->getCondition());
return I->getParent() == BB && Cmp && Cmp->getParent() == BB;
Expand All @@ -6430,8 +6422,7 @@ class HorizontalReduction {

/// Expected number of uses for reduction operations/reduced values.
bool hasRequiredNumberOfUses(Instruction *I, bool IsReductionOp) const {
assert(Kind != RK_None && !!*this && LHS && RHS &&
"Expected reduction operation.");
assert(Kind != RK_None && !!*this && "Expected reduction operation.");
// SelectInst must be used twice while the condition op must have single
// use only.
if (isMinMax())
Expand All @@ -6445,8 +6436,7 @@ class HorizontalReduction {

/// Initializes the list of reduction operations.
void initReductionOps(ReductionOpsListType &ReductionOps) {
assert(Kind != RK_None && !!*this && LHS && RHS &&
"Expected reduction operation.");
assert(Kind != RK_None && !!*this && "Expected reduction operation.");
if (isMinMax())
ReductionOps.assign(2, ReductionOpsType());
else
Expand All @@ -6455,8 +6445,7 @@ class HorizontalReduction {

/// Add all reduction operations for the reduction instruction \p I.
void addReductionOps(Instruction *I, ReductionOpsListType &ReductionOps) {
assert(Kind != RK_None && !!*this && LHS && RHS &&
"Expected reduction operation.");
assert(Kind != RK_None && !!*this && "Expected reduction operation.");
if (isMinMax()) {
ReductionOps[0].emplace_back(cast<SelectInst>(I)->getCondition());
ReductionOps[1].emplace_back(I);
Expand All @@ -6467,8 +6456,7 @@ class HorizontalReduction {

/// Checks if instruction is associative and can be vectorized.
bool isAssociative(Instruction *I) const {
assert(Kind != RK_None && *this && LHS && RHS &&
"Expected reduction operation.");
assert(Kind != RK_None && *this && "Expected reduction operation.");
switch (Kind) {
case RK_Arithmetic:
return I->isAssociative();
Expand All @@ -6493,15 +6481,13 @@ class HorizontalReduction {
/// Checks if two operation data are both a reduction op or both a reduced
/// value.
bool operator==(const OperationData &OD) const {
assert(((Kind != OD.Kind) || ((!LHS == !OD.LHS) && (!RHS == !OD.RHS))) &&
assert(((Kind != OD.Kind) || (Opcode != 0 && OD.Opcode != 0)) &&
"One of the comparing operations is incorrect.");
return this == &OD || (Kind == OD.Kind && Opcode == OD.Opcode);
return Kind == OD.Kind && Opcode == OD.Opcode;
}
bool operator!=(const OperationData &OD) const { return !(*this == OD); }
void clear() {
Opcode = 0;
LHS = nullptr;
RHS = nullptr;
Kind = RK_None;
}

Expand All @@ -6513,19 +6499,25 @@ class HorizontalReduction {

/// Get kind of reduction data.
ReductionKind getKind() const { return Kind; }
Value *getLHS() const { return LHS; }
Value *getRHS() const { return RHS; }
Type *getConditionType() const {
return isMinMax() ? CmpInst::makeCmpResultType(LHS->getType()) : nullptr;
Value *getLHS(Instruction *I) const {
if (Kind == RK_None)
return nullptr;
return I->getOperand(getFirstOperandIndex());
}
Value *getRHS(Instruction *I) const {
if (Kind == RK_None)
return nullptr;
return I->getOperand(getFirstOperandIndex() + 1);
}

/// Creates reduction operation with the current opcode with the IR flags
/// from \p ReductionOps.
Value *createOp(IRBuilder<> &Builder, const Twine &Name,
Value *createOp(IRBuilder<> &Builder, Value *LHS, Value *RHS,
const Twine &Name,
const ReductionOpsListType &ReductionOps) const {
assert(isVectorizable() &&
"Expected add|fadd or min/max reduction operation.");
auto *Op = createOp(Builder, Name);
auto *Op = createOp(Builder, LHS, RHS, Name);
switch (Kind) {
case RK_Arithmetic:
propagateIRFlags(Op, ReductionOps[0]);
Expand All @@ -6545,11 +6537,11 @@ class HorizontalReduction {
}
/// Creates reduction operation with the current opcode with the IR flags
/// from \p I.
Value *createOp(IRBuilder<> &Builder, const Twine &Name,
Instruction *I) const {
Value *createOp(IRBuilder<> &Builder, Value *LHS, Value *RHS,
const Twine &Name, Instruction *I) const {
assert(isVectorizable() &&
"Expected add|fadd or min/max reduction operation.");
auto *Op = createOp(Builder, Name);
auto *Op = createOp(Builder, LHS, RHS, Name);
switch (Kind) {
case RK_Arithmetic:
propagateIRFlags(Op, I);
Expand Down Expand Up @@ -6637,19 +6629,18 @@ class HorizontalReduction {
Value *LHS;
Value *RHS;
if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(I)) {
return OperationData(cast<BinaryOperator>(I)->getOpcode(), LHS, RHS,
RK_Arithmetic);
return OperationData(cast<BinaryOperator>(I)->getOpcode(), RK_Arithmetic);
}
if (auto *Select = dyn_cast<SelectInst>(I)) {
// Look for a min/max pattern.
if (m_UMin(m_Value(LHS), m_Value(RHS)).match(Select)) {
return OperationData(Instruction::ICmp, LHS, RHS, RK_UMin);
return OperationData(Instruction::ICmp, RK_UMin);
} else if (m_SMin(m_Value(LHS), m_Value(RHS)).match(Select)) {
return OperationData(Instruction::ICmp, LHS, RHS, RK_SMin);
return OperationData(Instruction::ICmp, RK_SMin);
} else if (m_UMax(m_Value(LHS), m_Value(RHS)).match(Select)) {
return OperationData(Instruction::ICmp, LHS, RHS, RK_UMax);
return OperationData(Instruction::ICmp, RK_UMax);
} else if (m_SMax(m_Value(LHS), m_Value(RHS)).match(Select)) {
return OperationData(Instruction::ICmp, LHS, RHS, RK_SMax);
return OperationData(Instruction::ICmp, RK_SMax);
} else {
// Try harder: look for min/max pattern based on instructions producing
// same values such as: select ((cmp Inst1, Inst2), Inst1, Inst2).
Expand Down Expand Up @@ -6693,19 +6684,19 @@ class HorizontalReduction {

case CmpInst::ICMP_ULT:
case CmpInst::ICMP_ULE:
return OperationData(Instruction::ICmp, LHS, RHS, RK_UMin);
return OperationData(Instruction::ICmp, RK_UMin);

case CmpInst::ICMP_SLT:
case CmpInst::ICMP_SLE:
return OperationData(Instruction::ICmp, LHS, RHS, RK_SMin);
return OperationData(Instruction::ICmp, RK_SMin);

case CmpInst::ICMP_UGT:
case CmpInst::ICMP_UGE:
return OperationData(Instruction::ICmp, LHS, RHS, RK_UMax);
return OperationData(Instruction::ICmp, RK_UMax);

case CmpInst::ICMP_SGT:
case CmpInst::ICMP_SGE:
return OperationData(Instruction::ICmp, LHS, RHS, RK_SMax);
return OperationData(Instruction::ICmp, RK_SMax);
}
}
}
Expand All @@ -6726,13 +6717,13 @@ class HorizontalReduction {
// r *= v1 + v2 + v3 + v4
// In such a case start looking for a tree rooted in the first '+'.
if (Phi) {
if (ReductionData.getLHS() == Phi) {
if (ReductionData.getLHS(B) == Phi) {
Phi = nullptr;
B = dyn_cast<Instruction>(ReductionData.getRHS());
B = dyn_cast<Instruction>(ReductionData.getRHS(B));
ReductionData = getOperationData(B);
} else if (ReductionData.getRHS() == Phi) {
} else if (ReductionData.getRHS(B) == Phi) {
Phi = nullptr;
B = dyn_cast<Instruction>(ReductionData.getLHS());
B = dyn_cast<Instruction>(ReductionData.getLHS(B));
ReductionData = getOperationData(B);
}
}
Expand Down Expand Up @@ -6984,11 +6975,8 @@ class HorizontalReduction {
} else {
// Update the final value in the reduction.
Builder.SetCurrentDebugLocation(Loc);
OperationData VectReductionData(ReductionData.getOpcode(),
VectorizedTree, ReducedSubTree,
ReductionData.getKind());
VectorizedTree =
VectReductionData.createOp(Builder, "op.rdx", ReductionOps);
VectorizedTree = ReductionData.createOp(
Builder, VectorizedTree, ReducedSubTree, "op.rdx", ReductionOps);
}
i += ReduxWidth;
ReduxWidth = PowerOf2Floor(NumReducedVals - i);
Expand All @@ -6999,19 +6987,15 @@ class HorizontalReduction {
for (; i < NumReducedVals; ++i) {
auto *I = cast<Instruction>(ReducedVals[i]);
Builder.SetCurrentDebugLocation(I->getDebugLoc());
OperationData VectReductionData(ReductionData.getOpcode(),
VectorizedTree, I,
ReductionData.getKind());
VectorizedTree = VectReductionData.createOp(Builder, "", ReductionOps);
VectorizedTree = ReductionData.createOp(Builder, VectorizedTree, I, "",
ReductionOps);
}
for (auto &Pair : ExternallyUsedValues) {
// Add each externally used value to the final reduction.
for (auto *I : Pair.second) {
Builder.SetCurrentDebugLocation(I->getDebugLoc());
OperationData VectReductionData(ReductionData.getOpcode(),
VectorizedTree, Pair.first,
ReductionData.getKind());
VectorizedTree = VectReductionData.createOp(Builder, "op.extra", I);
VectorizedTree = ReductionData.createOp(Builder, VectorizedTree,
Pair.first, "op.extra", I);
}
}

Expand Down Expand Up @@ -7133,9 +7117,8 @@ class HorizontalReduction {
Builder.CreateShuffleVector(TmpVec, LeftMask, "rdx.shuf.l");
Value *RightShuf =
Builder.CreateShuffleVector(TmpVec, RightMask, "rdx.shuf.r");
OperationData VectReductionData(ReductionData.getOpcode(), LeftShuf,
RightShuf, ReductionData.getKind());
TmpVec = VectReductionData.createOp(Builder, "op.rdx", ReductionOps);
TmpVec = ReductionData.createOp(Builder, LeftShuf, RightShuf, "op.rdx",
ReductionOps);
}

// The result is in the first element of the vector.
Expand Down

0 comments on commit 03f22b0

Please sign in to comment.