Skip to content

Commit 053451c

Browse files
committed
[RISCV] Handle scalarized reductions in getArithmeticReductionCost
This fixes a crash reported at llvm#114250 (comment) If the vector type isn't legal at all, e.g. bfloat with +zvfbfmin, then the legalized type will be scalarized. So use getScalarType() instead of getVectorElement() when checking for f16/bf16.
1 parent 4853bf0 commit 053451c

File tree

2 files changed

+137
-35
lines changed

2 files changed

+137
-35
lines changed

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -1874,9 +1874,8 @@ RISCVTTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
18741874
break;
18751875
case ISD::FADD:
18761876
// We can't promote f16/bf16 fadd reductions.
1877-
if ((LT.second.getVectorElementType() == MVT::f16 &&
1878-
!ST->hasVInstructionsF16()) ||
1879-
LT.second.getVectorElementType() == MVT::bf16)
1877+
if ((LT.second.getScalarType() == MVT::f16 && !ST->hasVInstructionsF16()) ||
1878+
LT.second.getScalarType() == MVT::bf16)
18801879
return BaseT::getArithmeticReductionCost(Opcode, Ty, FMF, CostKind);
18811880
if (TTI::requiresOrderedReduction(FMF)) {
18821881
Opcodes.push_back(RISCV::VFMV_S_F);

0 commit comments

Comments
 (0)