Skip to content

Commit b7cfa4a

Browse files
[LLVM][CostModel][AArch64] Remove magic numbers from f16 vector compares.
The PR also extends the code to cover bfloat vector compares that are also promoted to float. NOTE: There is a bail out for the compares that are scalarised that will be removed by llvm#135398.
1 parent 637f352 commit b7cfa4a

File tree

2 files changed

+36
-12
lines changed

2 files changed

+36
-12
lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

+28-4
Original file line numberDiff line numberDiff line change
@@ -4236,10 +4236,34 @@ InstructionCost AArch64TTIImpl::getCmpSelInstrCost(
42364236
}
42374237

42384238
if (isa<FixedVectorType>(ValTy) && ISD == ISD::SETCC) {
4239-
auto LT = getTypeLegalizationCost(ValTy);
4240-
// Cost v4f16 FCmp without FP16 support via converting to v4f32 and back.
4241-
if (LT.second == MVT::v4f16 && !ST->hasFullFP16())
4242-
return LT.first * 4; // fcvtl + fcvtl + fcmp + xtn
4239+
Type *ValScalarTy = ValTy->getScalarType();
4240+
if ((ValScalarTy->isHalfTy() && !ST->hasFullFP16()) ||
4241+
ValScalarTy->isBFloatTy()) {
4242+
auto *ValVTy = cast<FixedVectorType>(ValTy);
4243+
4244+
// FIXME: We currently scalarise these.
4245+
if (ValVTy->getNumElements() > 4)
4246+
return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred,
4247+
CostKind, Op1Info, Op2Info, I);
4248+
4249+
// Without dedicated instructions we promote [b]f16 compares to f32.
4250+
auto *PromotedTy =
4251+
VectorType::get(Type::getFloatTy(ValTy->getContext()), ValVTy);
4252+
4253+
InstructionCost Cost = 0;
4254+
// Promte operands to float vectors.
4255+
Cost += 2 * getCastInstrCost(Instruction::FPExt, PromotedTy, ValTy,
4256+
TTI::CastContextHint::None, CostKind);
4257+
// Compare float vectors.
4258+
Cost += getCmpSelInstrCost(Opcode, PromotedTy, CondTy, VecPred, CostKind,
4259+
Op1Info, Op2Info);
4260+
// During codegen we'll truncate the vector result from i32 to i16.
4261+
Cost +=
4262+
getCastInstrCost(Instruction::Trunc, VectorType::getInteger(ValVTy),
4263+
VectorType::getInteger(PromotedTy),
4264+
TTI::CastContextHint::None, CostKind);
4265+
return Cost;
4266+
}
42434267
}
42444268

42454269
// Treat the icmp in icmp(and, 0) as free, as we can make use of ands.

llvm/test/Analysis/CostModel/AArch64/vector-select.ll

+8-8
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ define <2 x double> @v2f64_select_ogt(<2 x double> %a, <2 x double> %b, <2 x dou
168168

169169
define <4 x bfloat> @v4bf16_select_ogt(<4 x bfloat> %a, <4 x bfloat> %b, <4 x bfloat> %c) {
170170
; COST-LABEL: 'v4bf16_select_ogt'
171-
; COST-NEXT: Cost Model: Found costs of 1 for: %cmp.1 = fcmp ogt <4 x bfloat> %a, %b
171+
; COST-NEXT: Cost Model: Found costs of RThru:4 CodeSize:1 Lat:1 SizeLat:1 for: %cmp.1 = fcmp ogt <4 x bfloat> %a, %b
172172
; COST-NEXT: Cost Model: Found costs of RThru:10 CodeSize:1 Lat:1 SizeLat:1 for: %s.1 = select <4 x i1> %cmp.1, <4 x bfloat> %a, <4 x bfloat> %c
173173
; COST-NEXT: Cost Model: Found costs of RThru:0 CodeSize:1 Lat:1 SizeLat:1 for: ret <4 x bfloat> %s.1
174174
;
@@ -255,7 +255,7 @@ define <2 x double> @v2f64_select_oge(<2 x double> %a, <2 x double> %b, <2 x dou
255255

256256
define <4 x bfloat> @v4bf16_select_oge(<4 x bfloat> %a, <4 x bfloat> %b, <4 x bfloat> %c) {
257257
; COST-LABEL: 'v4bf16_select_oge'
258-
; COST-NEXT: Cost Model: Found costs of 1 for: %cmp.1 = fcmp oge <4 x bfloat> %a, %b
258+
; COST-NEXT: Cost Model: Found costs of RThru:4 CodeSize:1 Lat:1 SizeLat:1 for: %cmp.1 = fcmp oge <4 x bfloat> %a, %b
259259
; COST-NEXT: Cost Model: Found costs of RThru:10 CodeSize:1 Lat:1 SizeLat:1 for: %s.1 = select <4 x i1> %cmp.1, <4 x bfloat> %a, <4 x bfloat> %c
260260
; COST-NEXT: Cost Model: Found costs of RThru:0 CodeSize:1 Lat:1 SizeLat:1 for: ret <4 x bfloat> %s.1
261261
;
@@ -342,7 +342,7 @@ define <2 x double> @v2f64_select_olt(<2 x double> %a, <2 x double> %b, <2 x dou
342342

343343
define <4 x bfloat> @v4bf16_select_olt(<4 x bfloat> %a, <4 x bfloat> %b, <4 x bfloat> %c) {
344344
; COST-LABEL: 'v4bf16_select_olt'
345-
; COST-NEXT: Cost Model: Found costs of 1 for: %cmp.1 = fcmp olt <4 x bfloat> %a, %b
345+
; COST-NEXT: Cost Model: Found costs of RThru:4 CodeSize:1 Lat:1 SizeLat:1 for: %cmp.1 = fcmp olt <4 x bfloat> %a, %b
346346
; COST-NEXT: Cost Model: Found costs of RThru:10 CodeSize:1 Lat:1 SizeLat:1 for: %s.1 = select <4 x i1> %cmp.1, <4 x bfloat> %a, <4 x bfloat> %c
347347
; COST-NEXT: Cost Model: Found costs of RThru:0 CodeSize:1 Lat:1 SizeLat:1 for: ret <4 x bfloat> %s.1
348348
;
@@ -429,7 +429,7 @@ define <2 x double> @v2f64_select_ole(<2 x double> %a, <2 x double> %b, <2 x dou
429429

430430
define <4 x bfloat> @v4bf16_select_ole(<4 x bfloat> %a, <4 x bfloat> %b, <4 x bfloat> %c) {
431431
; COST-LABEL: 'v4bf16_select_ole'
432-
; COST-NEXT: Cost Model: Found costs of 1 for: %cmp.1 = fcmp ole <4 x bfloat> %a, %b
432+
; COST-NEXT: Cost Model: Found costs of RThru:4 CodeSize:1 Lat:1 SizeLat:1 for: %cmp.1 = fcmp ole <4 x bfloat> %a, %b
433433
; COST-NEXT: Cost Model: Found costs of RThru:10 CodeSize:1 Lat:1 SizeLat:1 for: %s.1 = select <4 x i1> %cmp.1, <4 x bfloat> %a, <4 x bfloat> %c
434434
; COST-NEXT: Cost Model: Found costs of RThru:0 CodeSize:1 Lat:1 SizeLat:1 for: ret <4 x bfloat> %s.1
435435
;
@@ -516,7 +516,7 @@ define <2 x double> @v2f64_select_oeq(<2 x double> %a, <2 x double> %b, <2 x dou
516516

517517
define <4 x bfloat> @v4bf16_select_oeq(<4 x bfloat> %a, <4 x bfloat> %b, <4 x bfloat> %c) {
518518
; COST-LABEL: 'v4bf16_select_oeq'
519-
; COST-NEXT: Cost Model: Found costs of 1 for: %cmp.1 = fcmp oeq <4 x bfloat> %a, %b
519+
; COST-NEXT: Cost Model: Found costs of RThru:4 CodeSize:1 Lat:1 SizeLat:1 for: %cmp.1 = fcmp oeq <4 x bfloat> %a, %b
520520
; COST-NEXT: Cost Model: Found costs of RThru:10 CodeSize:1 Lat:1 SizeLat:1 for: %s.1 = select <4 x i1> %cmp.1, <4 x bfloat> %a, <4 x bfloat> %c
521521
; COST-NEXT: Cost Model: Found costs of RThru:0 CodeSize:1 Lat:1 SizeLat:1 for: ret <4 x bfloat> %s.1
522522
;
@@ -603,7 +603,7 @@ define <2 x double> @v2f64_select_one(<2 x double> %a, <2 x double> %b, <2 x dou
603603

604604
define <4 x bfloat> @v4bf16_select_one(<4 x bfloat> %a, <4 x bfloat> %b, <4 x bfloat> %c) {
605605
; COST-LABEL: 'v4bf16_select_one'
606-
; COST-NEXT: Cost Model: Found costs of 1 for: %cmp.1 = fcmp one <4 x bfloat> %a, %b
606+
; COST-NEXT: Cost Model: Found costs of RThru:4 CodeSize:1 Lat:1 SizeLat:1 for: %cmp.1 = fcmp one <4 x bfloat> %a, %b
607607
; COST-NEXT: Cost Model: Found costs of RThru:10 CodeSize:1 Lat:1 SizeLat:1 for: %s.1 = select <4 x i1> %cmp.1, <4 x bfloat> %a, <4 x bfloat> %c
608608
; COST-NEXT: Cost Model: Found costs of RThru:0 CodeSize:1 Lat:1 SizeLat:1 for: ret <4 x bfloat> %s.1
609609
;
@@ -690,7 +690,7 @@ define <2 x double> @v2f64_select_une(<2 x double> %a, <2 x double> %b, <2 x dou
690690

691691
define <4 x bfloat> @v4bf16_select_une(<4 x bfloat> %a, <4 x bfloat> %b, <4 x bfloat> %c) {
692692
; COST-LABEL: 'v4bf16_select_une'
693-
; COST-NEXT: Cost Model: Found costs of 1 for: %cmp.1 = fcmp une <4 x bfloat> %a, %b
693+
; COST-NEXT: Cost Model: Found costs of RThru:4 CodeSize:1 Lat:1 SizeLat:1 for: %cmp.1 = fcmp une <4 x bfloat> %a, %b
694694
; COST-NEXT: Cost Model: Found costs of RThru:10 CodeSize:1 Lat:1 SizeLat:1 for: %s.1 = select <4 x i1> %cmp.1, <4 x bfloat> %a, <4 x bfloat> %c
695695
; COST-NEXT: Cost Model: Found costs of RThru:0 CodeSize:1 Lat:1 SizeLat:1 for: ret <4 x bfloat> %s.1
696696
;
@@ -777,7 +777,7 @@ define <2 x double> @v2f64_select_ord(<2 x double> %a, <2 x double> %b, <2 x dou
777777

778778
define <4 x bfloat> @v4bf16_select_ord(<4 x bfloat> %a, <4 x bfloat> %b, <4 x bfloat> %c) {
779779
; COST-LABEL: 'v4bf16_select_ord'
780-
; COST-NEXT: Cost Model: Found costs of 1 for: %cmp.1 = fcmp ord <4 x bfloat> %a, %b
780+
; COST-NEXT: Cost Model: Found costs of RThru:4 CodeSize:1 Lat:1 SizeLat:1 for: %cmp.1 = fcmp ord <4 x bfloat> %a, %b
781781
; COST-NEXT: Cost Model: Found costs of RThru:10 CodeSize:1 Lat:1 SizeLat:1 for: %s.1 = select <4 x i1> %cmp.1, <4 x bfloat> %a, <4 x bfloat> %c
782782
; COST-NEXT: Cost Model: Found costs of RThru:0 CodeSize:1 Lat:1 SizeLat:1 for: ret <4 x bfloat> %s.1
783783
;

0 commit comments

Comments
 (0)