Skip to content

Commit 35be64a

Browse files
authored
[VPlan] Factor out common cost-computation logic to a helper (NFCI). (llvm#153361)
A number of recipes compute costs for the same opcodes, for either scalars or vectors depending on the recipe. Move the common logic out to a helper in VPRecipeWithIRFlags, which is then used by VPReplicateRecipe, VPWidenRecipe and VPInstruction. This makes it easier to cover all relevant opcodes without duplication. PR: llvm#153361
1 parent f1458ec commit 35be64a

File tree

2 files changed

+90
-67
lines changed

2 files changed

+90
-67
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,11 @@ struct VPRecipeWithIRFlags : public VPSingleDefRecipe, public VPIRFlags {
898898
}
899899

900900
void execute(VPTransformState &State) override = 0;
901+
902+
/// Compute the cost for this recipe for \p VF, using \p Opcode and \p Ctx.
903+
std::optional<InstructionCost>
904+
getCostForRecipeWithOpcode(unsigned Opcode, ElementCount VF,
905+
VPCostContext &Ctx) const;
901906
};
902907

903908
/// Helper to access the operand that contains the unroll part for this recipe

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 85 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -942,28 +942,90 @@ Value *VPInstruction::generate(VPTransformState &State) {
942942
}
943943
}
944944

945+
std::optional<InstructionCost> VPRecipeWithIRFlags::getCostForRecipeWithOpcode(
946+
unsigned Opcode, ElementCount VF, VPCostContext &Ctx) const {
947+
Type *ScalarTy = Ctx.Types.inferScalarType(this);
948+
Type *ResultTy = VF.isVector() ? toVectorTy(ScalarTy, VF) : ScalarTy;
949+
switch (Opcode) {
950+
case Instruction::FNeg:
951+
return Ctx.TTI.getArithmeticInstrCost(Opcode, ResultTy, Ctx.CostKind);
952+
case Instruction::UDiv:
953+
case Instruction::SDiv:
954+
case Instruction::SRem:
955+
case Instruction::URem:
956+
case Instruction::Add:
957+
case Instruction::FAdd:
958+
case Instruction::Sub:
959+
case Instruction::FSub:
960+
case Instruction::Mul:
961+
case Instruction::FMul:
962+
case Instruction::FDiv:
963+
case Instruction::FRem:
964+
case Instruction::Shl:
965+
case Instruction::LShr:
966+
case Instruction::AShr:
967+
case Instruction::And:
968+
case Instruction::Or:
969+
case Instruction::Xor: {
970+
TargetTransformInfo::OperandValueInfo RHSInfo = {
971+
TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None};
972+
973+
if (VF.isVector()) {
974+
// Certain instructions can be cheaper to vectorize if they have a
975+
// constant second vector operand. One example of this are shifts on x86.
976+
VPValue *RHS = getOperand(1);
977+
RHSInfo = Ctx.getOperandInfo(RHS);
978+
979+
if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
980+
getOperand(1)->isDefinedOutsideLoopRegions())
981+
RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
982+
}
983+
984+
Instruction *CtxI = dyn_cast_or_null<Instruction>(getUnderlyingValue());
985+
SmallVector<const Value *, 4> Operands;
986+
if (CtxI)
987+
Operands.append(CtxI->value_op_begin(), CtxI->value_op_end());
988+
return Ctx.TTI.getArithmeticInstrCost(
989+
Opcode, ResultTy, Ctx.CostKind,
990+
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
991+
RHSInfo, Operands, CtxI, &Ctx.TLI);
992+
}
993+
case Instruction::Freeze:
994+
// This opcode is unknown. Assume that it is the same as 'mul'.
995+
return Ctx.TTI.getArithmeticInstrCost(Instruction::Mul, ResultTy,
996+
Ctx.CostKind);
997+
case Instruction::ExtractValue:
998+
return Ctx.TTI.getInsertExtractValueCost(Instruction::ExtractValue,
999+
Ctx.CostKind);
1000+
case Instruction::ICmp:
1001+
case Instruction::FCmp: {
1002+
Type *ScalarOpTy = Ctx.Types.inferScalarType(getOperand(0));
1003+
Type *OpTy = VF.isVector() ? toVectorTy(ScalarOpTy, VF) : ScalarOpTy;
1004+
Instruction *CtxI = dyn_cast_or_null<Instruction>(getUnderlyingValue());
1005+
return Ctx.TTI.getCmpSelInstrCost(
1006+
Opcode, OpTy, CmpInst::makeCmpResultType(OpTy), getPredicate(),
1007+
Ctx.CostKind, {TTI::OK_AnyValue, TTI::OP_None},
1008+
{TTI::OK_AnyValue, TTI::OP_None}, CtxI);
1009+
}
1010+
}
1011+
return std::nullopt;
1012+
}
1013+
9451014
InstructionCost VPInstruction::computeCost(ElementCount VF,
9461015
VPCostContext &Ctx) const {
9471016
if (Instruction::isBinaryOp(getOpcode())) {
948-
Type *ResTy = Ctx.Types.inferScalarType(this);
949-
if (!vputils::onlyFirstLaneUsed(this))
950-
ResTy = toVectorTy(ResTy, VF);
951-
952-
if (!getUnderlyingValue()) {
953-
switch (getOpcode()) {
954-
case Instruction::FMul:
955-
return Ctx.TTI.getArithmeticInstrCost(getOpcode(), ResTy, Ctx.CostKind);
956-
default:
957-
// TODO: Compute cost for VPInstructions without underlying values once
958-
// the legacy cost model has been retired.
959-
return 0;
960-
}
1017+
if (!getUnderlyingValue() && getOpcode() != Instruction::FMul) {
1018+
// TODO: Compute cost for VPInstructions without underlying values once
1019+
// the legacy cost model has been retired.
1020+
return 0;
9611021
}
9621022

9631023
assert(!doesGeneratePerAllLanes() &&
9641024
"Should only generate a vector value or single scalar, not scalars "
9651025
"for all lanes.");
966-
return Ctx.TTI.getArithmeticInstrCost(getOpcode(), ResTy, Ctx.CostKind);
1026+
return *getCostForRecipeWithOpcode(
1027+
getOpcode(),
1028+
vputils::onlyFirstLaneUsed(this) ? ElementCount::getFixed(1) : VF, Ctx);
9671029
}
9681030

9691031
switch (getOpcode()) {
@@ -2033,20 +2095,13 @@ void VPWidenRecipe::execute(VPTransformState &State) {
20332095
InstructionCost VPWidenRecipe::computeCost(ElementCount VF,
20342096
VPCostContext &Ctx) const {
20352097
switch (Opcode) {
2036-
case Instruction::FNeg: {
2037-
Type *VectorTy = toVectorTy(Ctx.Types.inferScalarType(this), VF);
2038-
return Ctx.TTI.getArithmeticInstrCost(
2039-
Opcode, VectorTy, Ctx.CostKind,
2040-
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
2041-
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None});
2042-
}
2043-
20442098
case Instruction::UDiv:
20452099
case Instruction::SDiv:
20462100
case Instruction::SRem:
20472101
case Instruction::URem:
20482102
// More complex computation, let the legacy cost-model handle this for now.
20492103
return Ctx.getLegacyCost(cast<Instruction>(getUnderlyingValue()), VF);
2104+
case Instruction::FNeg:
20502105
case Instruction::Add:
20512106
case Instruction::FAdd:
20522107
case Instruction::Sub:
@@ -2060,45 +2115,12 @@ InstructionCost VPWidenRecipe::computeCost(ElementCount VF,
20602115
case Instruction::AShr:
20612116
case Instruction::And:
20622117
case Instruction::Or:
2063-
case Instruction::Xor: {
2064-
VPValue *RHS = getOperand(1);
2065-
// Certain instructions can be cheaper to vectorize if they have a constant
2066-
// second vector operand. One example of this are shifts on x86.
2067-
TargetTransformInfo::OperandValueInfo RHSInfo = Ctx.getOperandInfo(RHS);
2068-
2069-
if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
2070-
getOperand(1)->isDefinedOutsideLoopRegions())
2071-
RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
2072-
Type *VectorTy = toVectorTy(Ctx.Types.inferScalarType(this), VF);
2073-
Instruction *CtxI = dyn_cast_or_null<Instruction>(getUnderlyingValue());
2074-
2075-
SmallVector<const Value *, 4> Operands;
2076-
if (CtxI)
2077-
Operands.append(CtxI->value_op_begin(), CtxI->value_op_end());
2078-
return Ctx.TTI.getArithmeticInstrCost(
2079-
Opcode, VectorTy, Ctx.CostKind,
2080-
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
2081-
RHSInfo, Operands, CtxI, &Ctx.TLI);
2082-
}
2083-
case Instruction::Freeze: {
2084-
// This opcode is unknown. Assume that it is the same as 'mul'.
2085-
Type *VectorTy = toVectorTy(Ctx.Types.inferScalarType(this), VF);
2086-
return Ctx.TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy,
2087-
Ctx.CostKind);
2088-
}
2089-
case Instruction::ExtractValue: {
2090-
return Ctx.TTI.getInsertExtractValueCost(Instruction::ExtractValue,
2091-
Ctx.CostKind);
2092-
}
2118+
case Instruction::Xor:
2119+
case Instruction::Freeze:
2120+
case Instruction::ExtractValue:
20932121
case Instruction::ICmp:
2094-
case Instruction::FCmp: {
2095-
Instruction *CtxI = dyn_cast_or_null<Instruction>(getUnderlyingValue());
2096-
Type *VectorTy = toVectorTy(Ctx.Types.inferScalarType(getOperand(0)), VF);
2097-
return Ctx.TTI.getCmpSelInstrCost(
2098-
Opcode, VectorTy, CmpInst::makeCmpResultType(VectorTy), getPredicate(),
2099-
Ctx.CostKind, {TTI::OK_AnyValue, TTI::OP_None},
2100-
{TTI::OK_AnyValue, TTI::OP_None}, CtxI);
2101-
}
2122+
case Instruction::FCmp:
2123+
return *getCostForRecipeWithOpcode(getOpcode(), VF, Ctx);
21022124
default:
21032125
llvm_unreachable("Unsupported opcode for instruction");
21042126
}
@@ -2972,7 +2994,6 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
29722994
// transform, avoid computing their cost multiple times for now.
29732995
Ctx.SkipCostComputation.insert(UI);
29742996

2975-
Type *ResultTy = Ctx.Types.inferScalarType(this);
29762997
switch (UI->getOpcode()) {
29772998
case Instruction::GetElementPtr:
29782999
// We mark this instruction as zero-cost because the cost of GEPs in
@@ -2996,6 +3017,7 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
29963017
SmallVector<Type *, 4> Tys;
29973018
for (VPValue *ArgOp : drop_end(operands()))
29983019
Tys.push_back(Ctx.Types.inferScalarType(ArgOp));
3020+
Type *ResultTy = Ctx.Types.inferScalarType(this);
29993021
return Ctx.TTI.getCallInstrCost(CalledFn, ResultTy, Tys, Ctx.CostKind);
30003022
}
30013023
case Instruction::Add:
@@ -3012,12 +3034,8 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
30123034
case Instruction::And:
30133035
case Instruction::Or:
30143036
case Instruction::Xor: {
3015-
auto Op2Info = Ctx.getOperandInfo(getOperand(1));
3016-
SmallVector<const Value *, 4> Operands(UI->operand_values());
3017-
return Ctx.TTI.getArithmeticInstrCost(
3018-
UI->getOpcode(), ResultTy, Ctx.CostKind,
3019-
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
3020-
Op2Info, Operands, UI, &Ctx.TLI) *
3037+
return *getCostForRecipeWithOpcode(getOpcode(), ElementCount::getFixed(1),
3038+
Ctx) *
30213039
(isSingleScalar() ? 1 : VF.getFixedValue());
30223040
}
30233041
}

0 commit comments

Comments
 (0)