Skip to content

Commit d288574

Browse files
authored
[TTI][RISCV] Model cost of loading constants arms of selects and compares (llvm#109824)
This follows in the spirit of 7d82c99, and extends the costing API for compares and selects to provide information about the operands passed in an analogous manner. This allows us to model the cost of materializing the vector constant, as some select-of-constants are significantly more expensive than others when you account for the cost of materializing the constants involved. This is a stepping stone towards fixing llvm#109466. A separate SLP patch will be required to utilize the new API.
1 parent 3469db8 commit d288574

23 files changed

+208
-138
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

+14-7
Original file line numberDiff line numberDiff line change
@@ -1371,11 +1371,15 @@ class TargetTransformInfo {
13711371
/// is an existing instruction that holds Opcode, it may be passed in the
13721372
/// 'I' parameter. The \p VecPred parameter can be used to indicate the select
13731373
/// is using a compare with the specified predicate as condition. When vector
1374-
/// types are passed, \p VecPred must be used for all lanes.
1374+
/// types are passed, \p VecPred must be used for all lanes. For a
1375+
/// comparison, the two operands are the natural values. For a select, the
1376+
/// two operands are the *value* operands, not the condition operand.
13751377
InstructionCost
13761378
getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
13771379
CmpInst::Predicate VecPred,
13781380
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
1381+
OperandValueInfo Op1Info = {OK_AnyValue, OP_None},
1382+
OperandValueInfo Op2Info = {OK_AnyValue, OP_None},
13791383
const Instruction *I = nullptr) const;
13801384

13811385
/// \return The expected cost of vector Insert and Extract.
@@ -2049,11 +2053,11 @@ class TargetTransformInfo::Concept {
20492053
virtual InstructionCost getCFInstrCost(unsigned Opcode,
20502054
TTI::TargetCostKind CostKind,
20512055
const Instruction *I = nullptr) = 0;
2052-
virtual InstructionCost getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
2053-
Type *CondTy,
2054-
CmpInst::Predicate VecPred,
2055-
TTI::TargetCostKind CostKind,
2056-
const Instruction *I) = 0;
2056+
virtual InstructionCost
2057+
getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
2058+
CmpInst::Predicate VecPred, TTI::TargetCostKind CostKind,
2059+
OperandValueInfo Op1Info, OperandValueInfo Op2Info,
2060+
const Instruction *I) = 0;
20572061
virtual InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
20582062
TTI::TargetCostKind CostKind,
20592063
unsigned Index, Value *Op0,
@@ -2710,8 +2714,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
27102714
InstructionCost getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
27112715
CmpInst::Predicate VecPred,
27122716
TTI::TargetCostKind CostKind,
2717+
OperandValueInfo Op1Info,
2718+
OperandValueInfo Op2Info,
27132719
const Instruction *I) override {
2714-
return Impl.getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
2720+
return Impl.getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
2721+
Op1Info, Op2Info, I);
27152722
}
27162723
InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
27172724
TTI::TargetCostKind CostKind,

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

+8-2
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,8 @@ class TargetTransformInfoImplBase {
666666
InstructionCost getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
667667
CmpInst::Predicate VecPred,
668668
TTI::TargetCostKind CostKind,
669+
TTI::OperandValueInfo Op1Info,
670+
TTI::OperandValueInfo Op2Info,
669671
const Instruction *I) const {
670672
return 1;
671673
}
@@ -1332,19 +1334,23 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
13321334
match(U, m_LogicalOr()) ? Instruction::Or : Instruction::And, Ty,
13331335
CostKind, Op1Info, Op2Info, Operands, I);
13341336
}
1337+
const auto Op1Info = TTI::getOperandInfo(Operands[1]);
1338+
const auto Op2Info = TTI::getOperandInfo(Operands[2]);
13351339
Type *CondTy = Operands[0]->getType();
13361340
return TargetTTI->getCmpSelInstrCost(Opcode, U->getType(), CondTy,
13371341
CmpInst::BAD_ICMP_PREDICATE,
1338-
CostKind, I);
1342+
CostKind, Op1Info, Op2Info, I);
13391343
}
13401344
case Instruction::ICmp:
13411345
case Instruction::FCmp: {
1346+
const auto Op1Info = TTI::getOperandInfo(Operands[0]);
1347+
const auto Op2Info = TTI::getOperandInfo(Operands[1]);
13421348
Type *ValTy = Operands[0]->getType();
13431349
// TODO: Also handle ICmp/FCmp constant expressions.
13441350
return TargetTTI->getCmpSelInstrCost(Opcode, ValTy, U->getType(),
13451351
I ? cast<CmpInst>(I)->getPredicate()
13461352
: CmpInst::BAD_ICMP_PREDICATE,
1347-
CostKind, I);
1353+
CostKind, Op1Info, Op2Info, I);
13481354
}
13491355
case Instruction::InsertElement: {
13501356
auto *IE = dyn_cast<InsertElementInst>(U);

llvm/include/llvm/CodeGen/BasicTTIImpl.h

+10-7
Original file line numberDiff line numberDiff line change
@@ -1222,18 +1222,20 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
12221222
return BaseT::getCFInstrCost(Opcode, CostKind, I);
12231223
}
12241224

1225-
InstructionCost getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
1226-
CmpInst::Predicate VecPred,
1227-
TTI::TargetCostKind CostKind,
1228-
const Instruction *I = nullptr) {
1225+
InstructionCost getCmpSelInstrCost(
1226+
unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred,
1227+
TTI::TargetCostKind CostKind,
1228+
TTI::OperandValueInfo Op1Info = {TTI::OK_AnyValue, TTI::OP_None},
1229+
TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
1230+
const Instruction *I = nullptr) {
12291231
const TargetLoweringBase *TLI = getTLI();
12301232
int ISD = TLI->InstructionOpcodeToISD(Opcode);
12311233
assert(ISD && "Invalid opcode");
12321234

12331235
// TODO: Handle other cost kinds.
12341236
if (CostKind != TTI::TCK_RecipThroughput)
12351237
return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
1236-
I);
1238+
Op1Info, Op2Info, I);
12371239

12381240
// Selects on vectors are actually vector selects.
12391241
if (ISD == ISD::SELECT) {
@@ -1260,8 +1262,9 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
12601262
unsigned Num = cast<FixedVectorType>(ValVTy)->getNumElements();
12611263
if (CondTy)
12621264
CondTy = CondTy->getScalarType();
1263-
InstructionCost Cost = thisT()->getCmpSelInstrCost(
1264-
Opcode, ValVTy->getScalarType(), CondTy, VecPred, CostKind, I);
1265+
InstructionCost Cost =
1266+
thisT()->getCmpSelInstrCost(Opcode, ValVTy->getScalarType(), CondTy,
1267+
VecPred, CostKind, Op1Info, Op2Info, I);
12651268

12661269
// Return the cost of multiple scalar invocation plus the cost of
12671270
// inserting and extracting the values.

llvm/lib/Analysis/TargetTransformInfo.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -1015,11 +1015,12 @@ InstructionCost TargetTransformInfo::getCFInstrCost(
10151015

10161016
InstructionCost TargetTransformInfo::getCmpSelInstrCost(
10171017
unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred,
1018-
TTI::TargetCostKind CostKind, const Instruction *I) const {
1018+
TTI::TargetCostKind CostKind, OperandValueInfo Op1Info,
1019+
OperandValueInfo Op2Info, const Instruction *I) const {
10191020
assert((I == nullptr || I->getOpcode() == Opcode) &&
10201021
"Opcode should reflect passed instruction.");
1021-
InstructionCost Cost =
1022-
TTIImpl->getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
1022+
InstructionCost Cost = TTIImpl->getCmpSelInstrCost(
1023+
Opcode, ValTy, CondTy, VecPred, CostKind, Op1Info, Op2Info, I);
10231024
assert(Cost >= 0 && "TTI should not produce negative costs!");
10241025
return Cost;
10251026
}

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

+7-7
Original file line numberDiff line numberDiff line change
@@ -3440,15 +3440,14 @@ InstructionCost AArch64TTIImpl::getAddressComputationCost(Type *Ty,
34403440
return 1;
34413441
}
34423442

3443-
InstructionCost AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
3444-
Type *CondTy,
3445-
CmpInst::Predicate VecPred,
3446-
TTI::TargetCostKind CostKind,
3447-
const Instruction *I) {
3443+
InstructionCost AArch64TTIImpl::getCmpSelInstrCost(
3444+
unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred,
3445+
TTI::TargetCostKind CostKind, TTI::OperandValueInfo Op1Info,
3446+
TTI::OperandValueInfo Op2Info, const Instruction *I) {
34483447
// TODO: Handle other cost kinds.
34493448
if (CostKind != TTI::TCK_RecipThroughput)
34503449
return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
3451-
I);
3450+
Op1Info, Op2Info, I);
34523451

34533452
int ISD = TLI->InstructionOpcodeToISD(Opcode);
34543453
// We don't lower some vector selects well that are wider than the register
@@ -3527,7 +3526,8 @@ InstructionCost AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
35273526

35283527
// The base case handles scalable vectors fine for now, since it treats the
35293528
// cost as 1 * legalization cost.
3530-
return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
3529+
return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
3530+
Op1Info, Op2Info, I);
35313531
}
35323532

35333533
AArch64TTIImpl::TTI::MemCmpExpansionOptions

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

+6-4
Original file line numberDiff line numberDiff line change
@@ -208,10 +208,12 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
208208
InstructionCost getAddressComputationCost(Type *Ty, ScalarEvolution *SE,
209209
const SCEV *Ptr);
210210

211-
InstructionCost getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
212-
CmpInst::Predicate VecPred,
213-
TTI::TargetCostKind CostKind,
214-
const Instruction *I = nullptr);
211+
InstructionCost getCmpSelInstrCost(
212+
unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred,
213+
TTI::TargetCostKind CostKind,
214+
TTI::OperandValueInfo Op1Info = {TTI::OK_AnyValue, TTI::OP_None},
215+
TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
216+
const Instruction *I = nullptr);
215217

216218
TTI::MemCmpExpansionOptions enableMemCmpExpansion(bool OptSize,
217219
bool IsZeroCmp) const;

llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp

+7-8
Original file line numberDiff line numberDiff line change
@@ -934,11 +934,10 @@ InstructionCost ARMTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy,
934934
return BaseT::getVectorInstrCost(Opcode, ValTy, CostKind, Index, Op0, Op1);
935935
}
936936

937-
InstructionCost ARMTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
938-
Type *CondTy,
939-
CmpInst::Predicate VecPred,
940-
TTI::TargetCostKind CostKind,
941-
const Instruction *I) {
937+
InstructionCost ARMTTIImpl::getCmpSelInstrCost(
938+
unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred,
939+
TTI::TargetCostKind CostKind, TTI::OperandValueInfo Op1Info,
940+
TTI::OperandValueInfo Op2Info, const Instruction *I) {
942941
int ISD = TLI->InstructionOpcodeToISD(Opcode);
943942

944943
// Thumb scalar code size cost for select.
@@ -1052,7 +1051,7 @@ InstructionCost ARMTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
10521051
VecValTy->getNumElements() *
10531052
getCmpSelInstrCost(Opcode, ValTy->getScalarType(),
10541053
VecCondTy->getScalarType(), VecPred,
1055-
CostKind, I);
1054+
CostKind, Op1Info, Op2Info, I);
10561055
}
10571056

10581057
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
@@ -1077,8 +1076,8 @@ InstructionCost ARMTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
10771076
if (ST->hasMVEIntegerOps() && ValTy->isVectorTy())
10781077
BaseCost = ST->getMVEVectorCostFactor(CostKind);
10791078

1080-
return BaseCost *
1081-
BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
1079+
return BaseCost * BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred,
1080+
CostKind, Op1Info, Op2Info, I);
10821081
}
10831082

10841083
InstructionCost ARMTTIImpl::getAddressComputationCost(Type *Ty,

llvm/lib/Target/ARM/ARMTargetTransformInfo.h

+6-4
Original file line numberDiff line numberDiff line change
@@ -239,10 +239,12 @@ class ARMTTIImpl : public BasicTTIImplBase<ARMTTIImpl> {
239239
TTI::TargetCostKind CostKind,
240240
const Instruction *I = nullptr);
241241

242-
InstructionCost getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
243-
CmpInst::Predicate VecPred,
244-
TTI::TargetCostKind CostKind,
245-
const Instruction *I = nullptr);
242+
InstructionCost getCmpSelInstrCost(
243+
unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred,
244+
TTI::TargetCostKind CostKind,
245+
TTI::OperandValueInfo Op1Info = {TTI::OK_AnyValue, TTI::OP_None},
246+
TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
247+
const Instruction *I = nullptr);
246248

247249
using BaseT::getVectorInstrCost;
248250
InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,

llvm/lib/Target/BPF/BPFTargetTransformInfo.h

+7-5
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,17 @@ class BPFTTIImpl : public BasicTTIImplBase<BPFTTIImpl> {
4444
return TTI::TCC_Basic;
4545
}
4646

47-
InstructionCost getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
48-
CmpInst::Predicate VecPred,
49-
TTI::TargetCostKind CostKind,
50-
const llvm::Instruction *I = nullptr) {
47+
InstructionCost getCmpSelInstrCost(
48+
unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred,
49+
TTI::TargetCostKind CostKind,
50+
TTI::OperandValueInfo Op1Info = {TTI::OK_AnyValue, TTI::OP_None},
51+
TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
52+
const llvm::Instruction *I = nullptr) {
5153
if (Opcode == Instruction::Select)
5254
return SCEVCheapExpansionBudget.getValue();
5355

5456
return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
55-
I);
57+
Op1Info, Op2Info, I);
5658
}
5759

5860
InstructionCost getArithmeticInstrCost(

llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -255,19 +255,19 @@ InstructionCost HexagonTTIImpl::getInterleavedMemoryOpCost(
255255
CostKind);
256256
}
257257

258-
InstructionCost HexagonTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
259-
Type *CondTy,
260-
CmpInst::Predicate VecPred,
261-
TTI::TargetCostKind CostKind,
262-
const Instruction *I) {
258+
InstructionCost HexagonTTIImpl::getCmpSelInstrCost(
259+
unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred,
260+
TTI::TargetCostKind CostKind, TTI::OperandValueInfo Op1Info,
261+
TTI::OperandValueInfo Op2Info, const Instruction *I) {
263262
if (ValTy->isVectorTy() && CostKind == TTI::TCK_RecipThroughput) {
264263
if (!isHVXVectorType(ValTy) && ValTy->isFPOrFPVectorTy())
265264
return InstructionCost::getMax();
266265
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
267266
if (Opcode == Instruction::FCmp)
268267
return LT.first + FloatFactor * getTypeNumElements(ValTy);
269268
}
270-
return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
269+
return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
270+
Op1Info, Op2Info, I);
271271
}
272272

273273
InstructionCost HexagonTTIImpl::getArithmeticInstrCost(

llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h

+6-4
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,12 @@ class HexagonTTIImpl : public BasicTTIImplBase<HexagonTTIImpl> {
132132
unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
133133
Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
134134
bool UseMaskForCond = false, bool UseMaskForGaps = false);
135-
InstructionCost getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
136-
CmpInst::Predicate VecPred,
137-
TTI::TargetCostKind CostKind,
138-
const Instruction *I = nullptr);
135+
InstructionCost getCmpSelInstrCost(
136+
unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred,
137+
TTI::TargetCostKind CostKind,
138+
TTI::OperandValueInfo Op1Info = {TTI::OK_AnyValue, TTI::OP_None},
139+
TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
140+
const Instruction *I = nullptr);
139141
InstructionCost getArithmeticInstrCost(
140142
unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
141143
TTI::OperandValueInfo Op1Info = {TTI::OK_AnyValue, TTI::OP_None},

llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp

+6-7
Original file line numberDiff line numberDiff line change
@@ -655,18 +655,17 @@ InstructionCost PPCTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
655655
return Cost;
656656
}
657657

658-
InstructionCost PPCTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
659-
Type *CondTy,
660-
CmpInst::Predicate VecPred,
661-
TTI::TargetCostKind CostKind,
662-
const Instruction *I) {
658+
InstructionCost PPCTTIImpl::getCmpSelInstrCost(
659+
unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred,
660+
TTI::TargetCostKind CostKind, TTI::OperandValueInfo Op1Info,
661+
TTI::OperandValueInfo Op2Info, const Instruction *I) {
663662
InstructionCost CostFactor =
664663
vectorCostAdjustmentFactor(Opcode, ValTy, nullptr);
665664
if (!CostFactor.isValid())
666665
return InstructionCost::getMax();
667666

668-
InstructionCost Cost =
669-
BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
667+
InstructionCost Cost = BaseT::getCmpSelInstrCost(
668+
Opcode, ValTy, CondTy, VecPred, CostKind, Op1Info, Op2Info, I);
670669
// TODO: Handle other cost kinds.
671670
if (CostKind != TTI::TCK_RecipThroughput)
672671
return Cost;

llvm/lib/Target/PowerPC/PPCTargetTransformInfo.h

+6-4
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,12 @@ class PPCTTIImpl : public BasicTTIImplBase<PPCTTIImpl> {
118118
const Instruction *I = nullptr);
119119
InstructionCost getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind,
120120
const Instruction *I = nullptr);
121-
InstructionCost getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
122-
CmpInst::Predicate VecPred,
123-
TTI::TargetCostKind CostKind,
124-
const Instruction *I = nullptr);
121+
InstructionCost getCmpSelInstrCost(
122+
unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred,
123+
TTI::TargetCostKind CostKind,
124+
TTI::OperandValueInfo Op1Info = {TTI::OK_AnyValue, TTI::OP_None},
125+
TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
126+
const Instruction *I = nullptr);
125127
using BaseT::getVectorInstrCost;
126128
InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
127129
TTI::TargetCostKind CostKind,

0 commit comments

Comments
 (0)