Skip to content

Commit eb7addd

Browse files
committed
[AArch64] Spare N2I roundtrip when splatting float comparison
Transform `select_cc t1, t2, -1, 0` for floats into a vector comparison that generates a mask, which is later combined with potential vectorized DUPs.
1 parent fd452da commit eb7addd

File tree

4 files changed

+543
-54
lines changed

4 files changed

+543
-54
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 160 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -11002,9 +11002,126 @@ SDValue AArch64TargetLowering::LowerSETCCCARRY(SDValue Op,
1100211002
Cmp.getValue(1));
1100311003
}
1100411004

11005+
/// Emit vector comparison for floating-point values, producing a mask.
11006+
static SDValue emitVectorComparison(SDValue LHS, SDValue RHS,
11007+
AArch64CC::CondCode CC, bool NoNans, EVT VT,
11008+
const SDLoc &DL, SelectionDAG &DAG) {
11009+
EVT SrcVT = LHS.getValueType();
11010+
assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
11011+
"function only supposed to emit natural comparisons");
11012+
11013+
switch (CC) {
11014+
default:
11015+
return SDValue();
11016+
case AArch64CC::NE: {
11017+
SDValue Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, DL, VT, LHS, RHS);
11018+
// Use vector semantics for the inversion to potentially save a copy between
11019+
// SIMD and regular registers.
11020+
if (!LHS.getValueType().isVector()) {
11021+
EVT VecVT =
11022+
EVT::getVectorVT(*DAG.getContext(), VT, 128 / VT.getSizeInBits());
11023+
SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
11024+
SDValue MaskVec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT,
11025+
DAG.getUNDEF(VecVT), Fcmeq, Zero);
11026+
SDValue InvertedMask = DAG.getNOT(DL, MaskVec, VecVT);
11027+
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, InvertedMask, Zero);
11028+
}
11029+
return DAG.getNOT(DL, Fcmeq, VT);
11030+
}
11031+
case AArch64CC::EQ:
11032+
return DAG.getNode(AArch64ISD::FCMEQ, DL, VT, LHS, RHS);
11033+
case AArch64CC::GE:
11034+
return DAG.getNode(AArch64ISD::FCMGE, DL, VT, LHS, RHS);
11035+
case AArch64CC::GT:
11036+
return DAG.getNode(AArch64ISD::FCMGT, DL, VT, LHS, RHS);
11037+
case AArch64CC::LE:
11038+
if (!NoNans)
11039+
return SDValue();
11040+
// If we ignore NaNs then we can use to the LS implementation.
11041+
[[fallthrough]];
11042+
case AArch64CC::LS:
11043+
return DAG.getNode(AArch64ISD::FCMGE, DL, VT, RHS, LHS);
11044+
case AArch64CC::LT:
11045+
if (!NoNans)
11046+
return SDValue();
11047+
// If we ignore NaNs then we can use to the MI implementation.
11048+
[[fallthrough]];
11049+
case AArch64CC::MI:
11050+
return DAG.getNode(AArch64ISD::FCMGT, DL, VT, RHS, LHS);
11051+
}
11052+
}
11053+
11054+
/// For SELECT_CC, when the true/false values are (-1, 0) and the compared
11055+
/// values are scalars, try to emit a mask generating vector instruction.
11056+
static SDValue emitFloatCompareMask(SDValue LHS, SDValue RHS, SDValue TVal,
11057+
SDValue FVal, ISD::CondCode CC, bool NoNaNs,
11058+
const SDLoc &DL, SelectionDAG &DAG) {
11059+
assert(!LHS.getValueType().isVector());
11060+
assert(!RHS.getValueType().isVector());
11061+
11062+
auto *CTVal = dyn_cast<ConstantSDNode>(TVal);
11063+
auto *CFVal = dyn_cast<ConstantSDNode>(FVal);
11064+
if (!CTVal || !CFVal)
11065+
return {};
11066+
if (!(CTVal->isAllOnes() && CFVal->isZero()) &&
11067+
!(CTVal->isZero() && CFVal->isAllOnes()))
11068+
return {};
11069+
11070+
if (CTVal->isZero())
11071+
CC = ISD::getSetCCInverse(CC, LHS.getValueType());
11072+
11073+
EVT VT = TVal.getValueType();
11074+
if (VT.getSizeInBits() != LHS.getValueType().getSizeInBits())
11075+
return {};
11076+
11077+
if (!NoNaNs && (CC == ISD::SETUO || CC == ISD::SETO)) {
11078+
bool OneNaN = false;
11079+
if (LHS == RHS) {
11080+
OneNaN = true;
11081+
} else if (DAG.isKnownNeverNaN(RHS)) {
11082+
OneNaN = true;
11083+
RHS = LHS;
11084+
} else if (DAG.isKnownNeverNaN(LHS)) {
11085+
OneNaN = true;
11086+
LHS = RHS;
11087+
}
11088+
if (OneNaN)
11089+
CC = (CC == ISD::SETUO) ? ISD::SETUNE : ISD::SETOEQ;
11090+
}
11091+
11092+
AArch64CC::CondCode CC1;
11093+
AArch64CC::CondCode CC2;
11094+
bool ShouldInvert = false;
11095+
changeVectorFPCCToAArch64CC(CC, CC1, CC2, ShouldInvert);
11096+
SDValue Cmp = emitVectorComparison(LHS, RHS, CC1, NoNaNs, VT, DL, DAG);
11097+
SDValue Cmp2;
11098+
if (CC2 != AArch64CC::AL) {
11099+
Cmp2 = emitVectorComparison(LHS, RHS, CC2, NoNaNs, VT, DL, DAG);
11100+
if (!Cmp2)
11101+
return {};
11102+
}
11103+
if (!Cmp2 && !ShouldInvert)
11104+
return Cmp;
11105+
11106+
EVT VecVT = EVT::getVectorVT(*DAG.getContext(), VT, 128 / VT.getSizeInBits());
11107+
SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
11108+
Cmp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT, DAG.getUNDEF(VecVT), Cmp,
11109+
Zero);
11110+
if (Cmp2) {
11111+
Cmp2 = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT, DAG.getUNDEF(VecVT),
11112+
Cmp2, Zero);
11113+
Cmp = DAG.getNode(ISD::OR, DL, VecVT, Cmp, Cmp2);
11114+
}
11115+
if (ShouldInvert)
11116+
Cmp = DAG.getNOT(DL, Cmp, VecVT);
11117+
Cmp = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Cmp, Zero);
11118+
return Cmp;
11119+
}
11120+
1100511121
SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
1100611122
SDValue RHS, SDValue TVal,
11007-
SDValue FVal, const SDLoc &dl,
11123+
SDValue FVal, bool HasNoNaNs,
11124+
const SDLoc &dl,
1100811125
SelectionDAG &DAG) const {
1100911126
// Handle f128 first, because it will result in a comparison of some RTLIB
1101011127
// call result against zero.
@@ -11188,6 +11305,18 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
1118811305
LHS.getValueType() == MVT::f64);
1118911306
assert(LHS.getValueType() == RHS.getValueType());
1119011307
EVT VT = TVal.getValueType();
11308+
11309+
// If the purpose of the comparison is to select between all ones
11310+
// or all zeros, try to use a vector comparison because the operands are
11311+
// already stored in SIMD registers.
11312+
if (Subtarget->isNeonAvailable()) {
11313+
bool NoNaNs = getTargetMachine().Options.NoNaNsFPMath || HasNoNaNs;
11314+
SDValue VectorCmp =
11315+
emitFloatCompareMask(LHS, RHS, TVal, FVal, CC, NoNaNs, dl, DAG);
11316+
if (VectorCmp)
11317+
return VectorCmp;
11318+
}
11319+
1119111320
SDValue Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
1119211321

1119311322
// Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
@@ -11274,15 +11403,17 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(SDValue Op,
1127411403
SDValue RHS = Op.getOperand(1);
1127511404
SDValue TVal = Op.getOperand(2);
1127611405
SDValue FVal = Op.getOperand(3);
11406+
bool HasNoNans = Op->getFlags().hasNoNaNs();
1127711407
SDLoc DL(Op);
11278-
return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
11408+
return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, HasNoNans, DL, DAG);
1127911409
}
1128011410

1128111411
SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
1128211412
SelectionDAG &DAG) const {
1128311413
SDValue CCVal = Op->getOperand(0);
1128411414
SDValue TVal = Op->getOperand(1);
1128511415
SDValue FVal = Op->getOperand(2);
11416+
bool HasNoNans = Op->getFlags().hasNoNaNs();
1128611417
SDLoc DL(Op);
1128711418

1128811419
EVT Ty = Op.getValueType();
@@ -11349,7 +11480,7 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
1134911480
DAG.getUNDEF(MVT::f32), FVal);
1135011481
}
1135111482

11352-
SDValue Res = LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
11483+
SDValue Res = LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, HasNoNans, DL, DAG);
1135311484

1135411485
if ((Ty == MVT::f16 || Ty == MVT::bf16) && !Subtarget->hasFullFP16()) {
1135511486
return DAG.getTargetExtractSubreg(AArch64::hsub, DL, Ty, Res);
@@ -15602,47 +15733,6 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
1560215733
llvm_unreachable("unexpected shift opcode");
1560315734
}
1560415735

15605-
/// Emit a NEON comparison producing a lane mask of type \p VT for condition
/// \p CC on the floating-point vectors \p LHS and \p RHS. Returns an empty
/// SDValue for non-FP element types or unsupported conditions.
static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
                                    AArch64CC::CondCode CC, bool NoNans, EVT VT,
                                    const SDLoc &dl, SelectionDAG &DAG) {
  EVT SrcVT = LHS.getValueType();
  assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
         "function only supposed to emit natural comparisons");

  // Only floating-point element types are handled here.
  if (!SrcVT.getVectorElementType().isFloatingPoint())
    return SDValue();

  switch (CC) {
  default:
    return SDValue();
  case AArch64CC::EQ:
    return DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
  case AArch64CC::NE:
    return DAG.getNOT(dl, DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS),
                      VT);
  case AArch64CC::GE:
    return DAG.getNode(AArch64ISD::FCMGE, dl, VT, LHS, RHS);
  case AArch64CC::GT:
    return DAG.getNode(AArch64ISD::FCMGT, dl, VT, LHS, RHS);
  case AArch64CC::LE:
    // When NaNs can be ignored, LE can be emitted as LS (swapped FCMGE).
    if (!NoNans)
      return SDValue();
    [[fallthrough]];
  case AArch64CC::LS:
    return DAG.getNode(AArch64ISD::FCMGE, dl, VT, RHS, LHS);
  case AArch64CC::LT:
    // When NaNs can be ignored, LT can be emitted as MI (swapped FCMGT).
    if (!NoNans)
      return SDValue();
    [[fallthrough]];
  case AArch64CC::MI:
    return DAG.getNode(AArch64ISD::FCMGT, dl, VT, RHS, LHS);
  }
}
15645-
1564615736
SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
1564715737
SelectionDAG &DAG) const {
1564815738
if (Op.getValueType().isScalableVector())
@@ -15691,15 +15781,14 @@ SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
1569115781
bool ShouldInvert;
1569215782
changeVectorFPCCToAArch64CC(CC, CC1, CC2, ShouldInvert);
1569315783

15694-
bool NoNaNs = getTargetMachine().Options.NoNaNsFPMath || Op->getFlags().hasNoNaNs();
15695-
SDValue Cmp =
15696-
EmitVectorComparison(LHS, RHS, CC1, NoNaNs, CmpVT, dl, DAG);
15784+
bool NoNaNs =
15785+
getTargetMachine().Options.NoNaNsFPMath || Op->getFlags().hasNoNaNs();
15786+
SDValue Cmp = emitVectorComparison(LHS, RHS, CC1, NoNaNs, CmpVT, dl, DAG);
1569715787
if (!Cmp.getNode())
1569815788
return SDValue();
1569915789

1570015790
if (CC2 != AArch64CC::AL) {
15701-
SDValue Cmp2 =
15702-
EmitVectorComparison(LHS, RHS, CC2, NoNaNs, CmpVT, dl, DAG);
15791+
SDValue Cmp2 = emitVectorComparison(LHS, RHS, CC2, NoNaNs, CmpVT, dl, DAG);
1570315792
if (!Cmp2.getNode())
1570415793
return SDValue();
1570515794

@@ -25456,6 +25545,28 @@ static SDValue performDUPCombine(SDNode *N,
2545625545
}
2545725546

2545825547
if (N->getOpcode() == AArch64ISD::DUP) {
25548+
// If the instruction is known to produce a scalar in SIMD registers, we can
25549+
// duplicate it across the vector lanes using DUPLANE instead of moving it
25550+
// to a GPR first. For example, this allows us to handle:
25551+
// v4i32 = DUP (i32 (FCMGT (f32, f32)))
25552+
SDValue Op = N->getOperand(0);
25553+
// FIXME: Ideally, we should be able to handle all instructions that
25554+
// produce a scalar value in FPRs.
25555+
if (Op.getOpcode() == AArch64ISD::FCMEQ ||
25556+
Op.getOpcode() == AArch64ISD::FCMGE ||
25557+
Op.getOpcode() == AArch64ISD::FCMGT) {
25558+
EVT ElemVT = VT.getVectorElementType();
25559+
EVT ExpandedVT = VT;
25560+
// Insert into a 128-bit vector to match DUPLANE's pattern.
25561+
if (VT.getSizeInBits() != 128)
25562+
ExpandedVT = EVT::getVectorVT(*DCI.DAG.getContext(), ElemVT,
25563+
128 / ElemVT.getSizeInBits());
25564+
SDValue Zero = DCI.DAG.getConstant(0, DL, MVT::i64);
25565+
SDValue Vec = DCI.DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ExpandedVT,
25566+
DCI.DAG.getUNDEF(ExpandedVT), Op, Zero);
25567+
return DCI.DAG.getNode(getDUPLANEOp(ElemVT), DL, VT, Vec, Zero);
25568+
}
25569+
2545925570
if (DCI.isAfterLegalizeDAG()) {
2546025571
// If scalar dup's operand is extract_vector_elt, try to combine them into
2546125572
// duplane. For example,

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -643,8 +643,8 @@ class AArch64TargetLowering : public TargetLowering {
643643
SDValue LowerSELECT(SDValue Op, SelectionDAG &DAG) const;
644644
SDValue LowerSELECT_CC(SDValue Op, SelectionDAG &DAG) const;
645645
SDValue LowerSELECT_CC(ISD::CondCode CC, SDValue LHS, SDValue RHS,
646-
SDValue TVal, SDValue FVal, const SDLoc &dl,
647-
SelectionDAG &DAG) const;
646+
SDValue TVal, SDValue FVal, bool HasNoNans,
647+
const SDLoc &dl, SelectionDAG &DAG) const;
648648
SDValue LowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
649649
SDValue LowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
650650
SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const;

llvm/test/CodeGen/AArch64/arm64-neon-v1i1-setcc.ll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,9 @@ define <1 x i16> @test_select_f16_i16(half %i105, half %in, <1 x i16> %x, <1 x i
174174
; CHECK-LABEL: test_select_f16_i16:
175175
; CHECK: // %bb.0:
176176
; CHECK-NEXT: fcvt s0, h0
177-
; CHECK-NEXT: fcmp s0, s0
178-
; CHECK-NEXT: csetm w8, vs
179-
; CHECK-NEXT: dup v0.4h, w8
177+
; CHECK-NEXT: fcmeq s0, s0, s0
178+
; CHECK-NEXT: mvn v0.16b, v0.16b
179+
; CHECK-NEXT: dup v0.4h, v0.h[0]
180180
; CHECK-NEXT: bsl v0.8b, v2.8b, v3.8b
181181
; CHECK-NEXT: ret
182182
%i179 = fcmp uno half %i105, zeroinitializer

0 commit comments

Comments
 (0)