-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[CostModel] Make sure getCmpSelInstrCost is passed a CondTy #135535
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
It is already required along certain code paths that the CondTy is valid. Fix some of the uses to make sure it is passed.
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-vectorizers Author: David Green (davemgreen) ChangesIt is already required along certain code paths that the CondTy is valid. Fix some of the uses to make sure it is passed. Full diff: https://github.com/llvm/llvm-project/pull/135535.diff 4 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index eacf75c24695f..983fb16f255ec 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -1384,11 +1384,9 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
return InstructionCost::getInvalid();
unsigned Num = cast<FixedVectorType>(ValVTy)->getNumElements();
- if (CondTy)
- CondTy = CondTy->getScalarType();
- InstructionCost Cost =
- thisT()->getCmpSelInstrCost(Opcode, ValVTy->getScalarType(), CondTy,
- VecPred, CostKind, Op1Info, Op2Info, I);
+ InstructionCost Cost = thisT()->getCmpSelInstrCost(
+ Opcode, ValVTy->getScalarType(), CondTy->getScalarType(), VecPred,
+ CostKind, Op1Info, Op2Info, I);
// Return the cost of multiple scalar invocation plus the cost of
// inserting and extracting the values.
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index eac7e7c209c95..6d9fd98cb20a8 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -3094,7 +3094,8 @@ static bool validateAndCostRequiredSelects(BasicBlock *BB, BasicBlock *ThenBB,
if (ThenV == OrigV)
continue;
- Cost += TTI.getCmpSelInstrCost(Instruction::Select, PN.getType(), nullptr,
+ Cost += TTI.getCmpSelInstrCost(Instruction::Select, PN.getType(),
+ CmpInst::makeCmpResultType(PN.getType()),
CmpInst::BAD_ICMP_PREDICATE, CostKind);
// Don't convert to selects if we could remove undefined behavior instead.
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 0acca63503afa..2b61d0c5441ed 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -6974,10 +6974,10 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I,
}
VectorTy = toVectorTy(ValTy, VF);
- return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy, nullptr,
- cast<CmpInst>(I)->getPredicate(), CostKind,
- {TTI::OK_AnyValue, TTI::OP_None},
- {TTI::OK_AnyValue, TTI::OP_None}, I);
+ return TTI.getCmpSelInstrCost(
+ I->getOpcode(), VectorTy, CmpInst::makeCmpResultType(VectorTy),
+ cast<CmpInst>(I)->getPredicate(), CostKind,
+ {TTI::OK_AnyValue, TTI::OP_None}, {TTI::OK_AnyValue, TTI::OP_None}, I);
}
case Instruction::Store:
case Instruction::Load: {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 2cff343d915cf..ebedea1d65a9a 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -1753,10 +1753,10 @@ InstructionCost VPWidenRecipe::computeCost(ElementCount VF,
case Instruction::FCmp: {
Instruction *CtxI = dyn_cast_or_null<Instruction>(getUnderlyingValue());
Type *VectorTy = toVectorTy(Ctx.Types.inferScalarType(getOperand(0)), VF);
- return Ctx.TTI.getCmpSelInstrCost(Opcode, VectorTy, nullptr, getPredicate(),
- Ctx.CostKind,
- {TTI::OK_AnyValue, TTI::OP_None},
- {TTI::OK_AnyValue, TTI::OP_None}, CtxI);
+ return Ctx.TTI.getCmpSelInstrCost(
+ Opcode, VectorTy, CmpInst::makeCmpResultType(VectorTy), getPredicate(),
+ Ctx.CostKind, {TTI::OK_AnyValue, TTI::OP_None},
+ {TTI::OK_AnyValue, TTI::OP_None}, CtxI);
}
default:
llvm_unreachable("Unsupported opcode for instruction");
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM - cheers
It is already required along certain code paths that the CondTy is valid. Fix some of the uses to make sure it is passed.