[AArch64] Spare N2I roundtrip when splatting float comparison #141806
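This patch teaches AArch64 lowering to keep a scalar floating-point compare in SIMD registers when its all-ones/all-zeros result is only used to build a vector splat, instead of materializing the flag in a general-purpose register and moving it back into the vector unit (the N2I roundtrip in the title). A minimal, hypothetical IR sketch of the kind of pattern this targets (function and value names are illustrative, not taken from the patch):

define <4 x i32> @splat_fcmp_mask(float %a, float %b) {
  %cmp = fcmp ogt float %a, %b
  %mask = select i1 %cmp, i32 -1, i32 0   ; all-ones or all-zeros scalar
  %vec = insertelement <4 x i32> poison, i32 %mask, i64 0
  %splat = shufflevector <4 x i32> %vec, <4 x i32> poison, <4 x i32> zeroinitializer
  ret <4 x i32> %splat
}

Previously such a splat went through a scalar compare plus csetm into a GPR followed by a dup from that GPR; with the vector-compare path added below, the mask is produced by fcmeq/fcmgt directly in a SIMD register and broadcast with a lane dup, as the updated arm64-neon-v1i1-setcc.ll test at the end of the diff shows.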

Merged (1 commit) on Jun 6, 2025
225 changes: 173 additions & 52 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -11002,10 +11002,126 @@ SDValue AArch64TargetLowering::LowerSETCCCARRY(SDValue Op,
Cmp.getValue(1));
}

SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
SDValue RHS, SDValue TVal,
SDValue FVal, const SDLoc &dl,
SelectionDAG &DAG) const {
/// Emit vector comparison for floating-point values, producing a mask.
static SDValue emitVectorComparison(SDValue LHS, SDValue RHS,
AArch64CC::CondCode CC, bool NoNans, EVT VT,
const SDLoc &DL, SelectionDAG &DAG) {
EVT SrcVT = LHS.getValueType();
assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
"function only supposed to emit natural comparisons");

switch (CC) {
default:
return SDValue();
case AArch64CC::NE: {
SDValue Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, DL, VT, LHS, RHS);
// Use vector semantics for the inversion to potentially save a copy between
// SIMD and regular registers.
if (!LHS.getValueType().isVector()) {
EVT VecVT =
EVT::getVectorVT(*DAG.getContext(), VT, 128 / VT.getSizeInBits());
SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
SDValue MaskVec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT,
DAG.getUNDEF(VecVT), Fcmeq, Zero);
SDValue InvertedMask = DAG.getNOT(DL, MaskVec, VecVT);
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, InvertedMask, Zero);
}
return DAG.getNOT(DL, Fcmeq, VT);
}
case AArch64CC::EQ:
return DAG.getNode(AArch64ISD::FCMEQ, DL, VT, LHS, RHS);
case AArch64CC::GE:
return DAG.getNode(AArch64ISD::FCMGE, DL, VT, LHS, RHS);
case AArch64CC::GT:
return DAG.getNode(AArch64ISD::FCMGT, DL, VT, LHS, RHS);
case AArch64CC::LE:
if (!NoNans)
return SDValue();
// If we ignore NaNs then we can use the LS implementation.
[[fallthrough]];
case AArch64CC::LS:
return DAG.getNode(AArch64ISD::FCMGE, DL, VT, RHS, LHS);
case AArch64CC::LT:
if (!NoNans)
return SDValue();
// If we ignore NaNs then we can use the MI implementation.
[[fallthrough]];
case AArch64CC::MI:
return DAG.getNode(AArch64ISD::FCMGT, DL, VT, RHS, LHS);
}
}

/// For SELECT_CC, when the true/false values are (-1, 0) and the compared
/// values are scalars, try to emit a mask generating vector instruction.
static SDValue emitFloatCompareMask(SDValue LHS, SDValue RHS, SDValue TVal,
SDValue FVal, ISD::CondCode CC, bool NoNaNs,
const SDLoc &DL, SelectionDAG &DAG) {
assert(!LHS.getValueType().isVector());
assert(!RHS.getValueType().isVector());

auto *CTVal = dyn_cast<ConstantSDNode>(TVal);
auto *CFVal = dyn_cast<ConstantSDNode>(FVal);
if (!CTVal || !CFVal)
return {};
if (!(CTVal->isAllOnes() && CFVal->isZero()) &&
!(CTVal->isZero() && CFVal->isAllOnes()))
return {};

if (CTVal->isZero())
CC = ISD::getSetCCInverse(CC, LHS.getValueType());

EVT VT = TVal.getValueType();
if (VT.getSizeInBits() != LHS.getValueType().getSizeInBits())
return {};

if (!NoNaNs && (CC == ISD::SETUO || CC == ISD::SETO)) {
bool OneNaN = false;
if (LHS == RHS) {
OneNaN = true;
} else if (DAG.isKnownNeverNaN(RHS)) {
OneNaN = true;
RHS = LHS;
} else if (DAG.isKnownNeverNaN(LHS)) {
OneNaN = true;
LHS = RHS;
}
if (OneNaN)
CC = (CC == ISD::SETUO) ? ISD::SETUNE : ISD::SETOEQ;
}

AArch64CC::CondCode CC1;
AArch64CC::CondCode CC2;
bool ShouldInvert = false;
changeVectorFPCCToAArch64CC(CC, CC1, CC2, ShouldInvert);
SDValue Cmp = emitVectorComparison(LHS, RHS, CC1, NoNaNs, VT, DL, DAG);
SDValue Cmp2;
if (CC2 != AArch64CC::AL) {
Cmp2 = emitVectorComparison(LHS, RHS, CC2, NoNaNs, VT, DL, DAG);
if (!Cmp2)
return {};
}
if (!Cmp2 && !ShouldInvert)
return Cmp;

EVT VecVT = EVT::getVectorVT(*DAG.getContext(), VT, 128 / VT.getSizeInBits());
SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
Cmp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT, DAG.getUNDEF(VecVT), Cmp,
Zero);
if (Cmp2) {
Cmp2 = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT, DAG.getUNDEF(VecVT),
Cmp2, Zero);
Cmp = DAG.getNode(ISD::OR, DL, VecVT, Cmp, Cmp2);
}
if (ShouldInvert)
Cmp = DAG.getNOT(DL, Cmp, VecVT);
Cmp = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Cmp, Zero);
return Cmp;
}

SDValue AArch64TargetLowering::LowerSELECT_CC(
ISD::CondCode CC, SDValue LHS, SDValue RHS, SDValue TVal, SDValue FVal,
iterator_range<SDNode::user_iterator> Users, bool HasNoNaNs,
const SDLoc &dl, SelectionDAG &DAG) const {
// Handle f128 first, because it will result in a comparison of some RTLIB
// call result against zero.
if (LHS.getValueType() == MVT::f128) {
@@ -11188,6 +11304,27 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
LHS.getValueType() == MVT::f64);
assert(LHS.getValueType() == RHS.getValueType());
EVT VT = TVal.getValueType();

// If the purpose of the comparison is to select between all ones
// or all zeros, try to use a vector comparison because the operands are
// already stored in SIMD registers.
if (Subtarget->isNeonAvailable() && all_of(Users, [](const SDNode *U) {
switch (U->getOpcode()) {
default:
return false;
case ISD::INSERT_VECTOR_ELT:
case ISD::SCALAR_TO_VECTOR:
case AArch64ISD::DUP:
return true;
}
})) {
bool NoNaNs = getTargetMachine().Options.NoNaNsFPMath || HasNoNaNs;
SDValue VectorCmp =
emitFloatCompareMask(LHS, RHS, TVal, FVal, CC, NoNaNs, dl, DAG);
if (VectorCmp)
return VectorCmp;
}

SDValue Cmp = emitComparison(LHS, RHS, CC, dl, DAG);

// Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
@@ -11274,15 +11411,18 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(SDValue Op,
SDValue RHS = Op.getOperand(1);
SDValue TVal = Op.getOperand(2);
SDValue FVal = Op.getOperand(3);
bool HasNoNans = Op->getFlags().hasNoNaNs();
SDLoc DL(Op);
return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, Op->users(), HasNoNans, DL,
DAG);
}

SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
SelectionDAG &DAG) const {
SDValue CCVal = Op->getOperand(0);
SDValue TVal = Op->getOperand(1);
SDValue FVal = Op->getOperand(2);
bool HasNoNans = Op->getFlags().hasNoNaNs();
SDLoc DL(Op);

EVT Ty = Op.getValueType();
@@ -11349,7 +11489,8 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
DAG.getUNDEF(MVT::f32), FVal);
}

SDValue Res = LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
SDValue Res =
LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, Op->users(), HasNoNans, DL, DAG);

if ((Ty == MVT::f16 || Ty == MVT::bf16) && !Subtarget->hasFullFP16()) {
return DAG.getTargetExtractSubreg(AArch64::hsub, DL, Ty, Res);
@@ -15602,47 +15743,6 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
llvm_unreachable("unexpected shift opcode");
}

static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
AArch64CC::CondCode CC, bool NoNans, EVT VT,
const SDLoc &dl, SelectionDAG &DAG) {
EVT SrcVT = LHS.getValueType();
assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
"function only supposed to emit natural comparisons");

if (SrcVT.getVectorElementType().isFloatingPoint()) {
switch (CC) {
default:
return SDValue();
case AArch64CC::NE: {
SDValue Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
return DAG.getNOT(dl, Fcmeq, VT);
}
case AArch64CC::EQ:
return DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
case AArch64CC::GE:
return DAG.getNode(AArch64ISD::FCMGE, dl, VT, LHS, RHS);
case AArch64CC::GT:
return DAG.getNode(AArch64ISD::FCMGT, dl, VT, LHS, RHS);
case AArch64CC::LE:
if (!NoNans)
return SDValue();
// If we ignore NaNs then we can use the LS implementation.
[[fallthrough]];
case AArch64CC::LS:
return DAG.getNode(AArch64ISD::FCMGE, dl, VT, RHS, LHS);
case AArch64CC::LT:
if (!NoNans)
return SDValue();
// If we ignore NaNs then we can use the MI implementation.
[[fallthrough]];
case AArch64CC::MI:
return DAG.getNode(AArch64ISD::FCMGT, dl, VT, RHS, LHS);
}
}

return SDValue();
}

SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
SelectionDAG &DAG) const {
if (Op.getValueType().isScalableVector())
@@ -15691,15 +15791,14 @@ SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
bool ShouldInvert;
changeVectorFPCCToAArch64CC(CC, CC1, CC2, ShouldInvert);

bool NoNaNs = getTargetMachine().Options.NoNaNsFPMath || Op->getFlags().hasNoNaNs();
SDValue Cmp =
EmitVectorComparison(LHS, RHS, CC1, NoNaNs, CmpVT, dl, DAG);
bool NoNaNs =
getTargetMachine().Options.NoNaNsFPMath || Op->getFlags().hasNoNaNs();
SDValue Cmp = emitVectorComparison(LHS, RHS, CC1, NoNaNs, CmpVT, dl, DAG);
if (!Cmp.getNode())
return SDValue();

if (CC2 != AArch64CC::AL) {
SDValue Cmp2 =
EmitVectorComparison(LHS, RHS, CC2, NoNaNs, CmpVT, dl, DAG);
SDValue Cmp2 = emitVectorComparison(LHS, RHS, CC2, NoNaNs, CmpVT, dl, DAG);
if (!Cmp2.getNode())
return SDValue();

@@ -25456,6 +25555,28 @@ static SDValue performDUPCombine(SDNode *N,
}

if (N->getOpcode() == AArch64ISD::DUP) {
// If the instruction is known to produce a scalar in SIMD registers, we can
// duplicate it across the vector lanes using DUPLANE instead of moving it
// to a GPR first. For example, this allows us to handle:
// v4i32 = DUP (i32 (FCMGT (f32, f32)))
SDValue Op = N->getOperand(0);
// FIXME: Ideally, we should be able to handle all instructions that
// produce a scalar value in FPRs.
if (Op.getOpcode() == AArch64ISD::FCMEQ ||
Op.getOpcode() == AArch64ISD::FCMGE ||
Op.getOpcode() == AArch64ISD::FCMGT) {
EVT ElemVT = VT.getVectorElementType();
EVT ExpandedVT = VT;
// Insert into a 128-bit vector to match DUPLANE's pattern.
if (VT.getSizeInBits() != 128)
ExpandedVT = EVT::getVectorVT(*DCI.DAG.getContext(), ElemVT,
128 / ElemVT.getSizeInBits());
SDValue Zero = DCI.DAG.getConstant(0, DL, MVT::i64);
SDValue Vec = DCI.DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ExpandedVT,
DCI.DAG.getUNDEF(ExpandedVT), Op, Zero);
return DCI.DAG.getNode(getDUPLANEOp(ElemVT), DL, VT, Vec, Zero);
}

if (DCI.isAfterLegalizeDAG()) {
// If scalar dup's operand is extract_vector_elt, try to combine them into
// duplane. For example,
4 changes: 3 additions & 1 deletion llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -643,7 +643,9 @@ class AArch64TargetLowering : public TargetLowering {
SDValue LowerSELECT(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSELECT_CC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSELECT_CC(ISD::CondCode CC, SDValue LHS, SDValue RHS,
SDValue TVal, SDValue FVal, const SDLoc &dl,
SDValue TVal, SDValue FVal,
iterator_range<SDNode::user_iterator> Users,
bool HasNoNans, const SDLoc &dl,
SelectionDAG &DAG) const;
SDValue LowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
6 changes: 3 additions & 3 deletions llvm/test/CodeGen/AArch64/arm64-neon-v1i1-setcc.ll
@@ -174,9 +174,9 @@ define <1 x i16> @test_select_f16_i16(half %i105, half %in, <1 x i16> %x, <1 x i
; CHECK-LABEL: test_select_f16_i16:
; CHECK: // %bb.0:
; CHECK-NEXT: fcvt s0, h0
; CHECK-NEXT: fcmp s0, s0
; CHECK-NEXT: csetm w8, vs
; CHECK-NEXT: dup v0.4h, w8
; CHECK-NEXT: fcmeq s0, s0, s0
; CHECK-NEXT: mvn v0.16b, v0.16b
; CHECK-NEXT: dup v0.4h, v0.h[0]
; CHECK-NEXT: bsl v0.8b, v2.8b, v3.8b
; CHECK-NEXT: ret
%i179 = fcmp uno half %i105, zeroinitializer