[AArch64] Spare N2I roundtrip when splatting float comparison #141806

Open
guy-david wants to merge 1 commit into main from users/guy-david/aarch64-dup-lane-simd

Conversation

guy-david
Contributor

Transform `select_cc t1, t2, -1, 0` for floats into a vector comparison that produces a mask, which can later be combined with a vectorized DUP.

For GlobalISel, it seems that an equivalent of SELECT_CC does not exist yet?
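
As a concrete illustration, here is a sketch simplified from the dup32 case in the new build-vector-dup-simd.ll test below (the "before" assembly is an assumption, inferred from the pre-patch lowering visible in the arm64-neon-v1i1-setcc.ll change). Splatting the mask of a scalar float comparison:

  %cmp = fcmp ogt float %a, %b
  %mask = sext i1 %cmp to i32                ; all-ones or zero
  %vec = insertelement <4 x i32> poison, i32 %mask, i64 0
  %splat = shufflevector <4 x i32> %vec, <4 x i32> poison, <4 x i32> zeroinitializer

previously bounced through a GPR:

  fcmp s0, s1
  csetm w8, gt        // materialize the mask in a GPR
  dup v0.4s, w8       // move it back into a SIMD register

whereas with this patch the value stays in SIMD registers throughout:

  fcmgt s0, s0, s1    // comparison producing the mask directly in an FPR
  dup v0.4s, v0.s[0]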

@llvmbot
Member

llvmbot commented May 28, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Guy David (guy-david)

Changes

Transform `select_cc t1, t2, -1, 0` for floats into a vector comparison that produces a mask, which can later be combined with a vectorized DUP.

For GlobalISel, it seems that an equivalent of SELECT_CC does not exist yet?


Full diff: https://github.com/llvm/llvm-project/pull/141806.diff

4 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+89-44)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+2-2)
  • (modified) llvm/test/CodeGen/AArch64/arm64-neon-v1i1-setcc.ll (+2-3)
  • (added) llvm/test/CodeGen/AArch64/build-vector-dup-simd.ll (+32)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a817ed5f0e917..da5117292e269 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -10906,9 +10906,48 @@ SDValue AArch64TargetLowering::LowerSETCCCARRY(SDValue Op,
                      Cmp.getValue(1));
 }
 
+/// Emit vector comparison for floating-point values, producing a mask.
+static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
+                                    AArch64CC::CondCode CC, bool NoNans, EVT VT,
+                                    const SDLoc &dl, SelectionDAG &DAG) {
+  EVT SrcVT = LHS.getValueType();
+  assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
+         "function only supposed to emit natural comparisons");
+
+  switch (CC) {
+  default:
+    return SDValue();
+  case AArch64CC::NE: {
+    SDValue Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
+    return DAG.getNOT(dl, Fcmeq, VT);
+  }
+  case AArch64CC::EQ:
+    return DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
+  case AArch64CC::GE:
+    return DAG.getNode(AArch64ISD::FCMGE, dl, VT, LHS, RHS);
+  case AArch64CC::GT:
+    return DAG.getNode(AArch64ISD::FCMGT, dl, VT, LHS, RHS);
+  case AArch64CC::LE:
+    if (!NoNans)
+      return SDValue();
+    // If we ignore NaNs then we can use the LS implementation.
+    [[fallthrough]];
+  case AArch64CC::LS:
+    return DAG.getNode(AArch64ISD::FCMGE, dl, VT, RHS, LHS);
+  case AArch64CC::LT:
+    if (!NoNans)
+      return SDValue();
+    // If we ignore NaNs then we can use the MI implementation.
+    [[fallthrough]];
+  case AArch64CC::MI:
+    return DAG.getNode(AArch64ISD::FCMGT, dl, VT, RHS, LHS);
+  }
+}
+
 SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
                                               SDValue RHS, SDValue TVal,
-                                              SDValue FVal, const SDLoc &dl,
+                                              SDValue FVal, bool HasNoNaNs,
+                                              const SDLoc &dl,
                                               SelectionDAG &DAG) const {
   // Handle f128 first, because it will result in a comparison of some RTLIB
   // call result against zero.
@@ -11092,6 +11131,29 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
          LHS.getValueType() == MVT::f64);
   assert(LHS.getValueType() == RHS.getValueType());
   EVT VT = TVal.getValueType();
+
+  // If the purpose of the comparison is to select between all ones
+  // or all zeros, use a vector comparison because the operands are already
+  // stored in SIMD registers.
+  auto *CTVal = dyn_cast<ConstantSDNode>(TVal);
+  auto *CFVal = dyn_cast<ConstantSDNode>(FVal);
+  if (Subtarget->isNeonAvailable() &&
+      (VT.getSizeInBits() == LHS.getValueType().getSizeInBits()) && CTVal &&
+      CFVal &&
+      ((CTVal->isAllOnes() && CFVal->isZero()) ||
+       ((CTVal->isZero()) && CFVal->isAllOnes()))) {
+    AArch64CC::CondCode CC1;
+    AArch64CC::CondCode CC2;
+    bool ShouldInvert = false;
+    changeVectorFPCCToAArch64CC(CC, CC1, CC2, ShouldInvert);
+    if (CTVal->isZero() ^ ShouldInvert)
+      std::swap(TVal, FVal);
+    bool NoNaNs = getTargetMachine().Options.NoNaNsFPMath || HasNoNaNs;
+    SDValue Res = EmitVectorComparison(LHS, RHS, CC1, NoNaNs, VT, dl, DAG);
+    if (Res)
+      return Res;
+  }
+
   SDValue Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
 
   // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
@@ -11178,8 +11240,9 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(SDValue Op,
   SDValue RHS = Op.getOperand(1);
   SDValue TVal = Op.getOperand(2);
   SDValue FVal = Op.getOperand(3);
+  bool HasNoNans = Op->getFlags().hasNoNaNs();
   SDLoc DL(Op);
-  return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
+  return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, HasNoNans, DL, DAG);
 }
 
 SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
@@ -11187,6 +11250,7 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
   SDValue CCVal = Op->getOperand(0);
   SDValue TVal = Op->getOperand(1);
   SDValue FVal = Op->getOperand(2);
+  bool HasNoNans = Op->getFlags().hasNoNaNs();
   SDLoc DL(Op);
 
   EVT Ty = Op.getValueType();
@@ -11253,7 +11317,7 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
                                      DAG.getUNDEF(MVT::f32), FVal);
   }
 
-  SDValue Res = LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
+  SDValue Res = LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, HasNoNans, DL, DAG);
 
   if ((Ty == MVT::f16 || Ty == MVT::bf16) && !Subtarget->hasFullFP16()) {
     return DAG.getTargetExtractSubreg(AArch64::hsub, DL, Ty, Res);
@@ -15506,47 +15570,6 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
   llvm_unreachable("unexpected shift opcode");
 }
 
-static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
-                                    AArch64CC::CondCode CC, bool NoNans, EVT VT,
-                                    const SDLoc &dl, SelectionDAG &DAG) {
-  EVT SrcVT = LHS.getValueType();
-  assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
-         "function only supposed to emit natural comparisons");
-
-  if (SrcVT.getVectorElementType().isFloatingPoint()) {
-    switch (CC) {
-    default:
-      return SDValue();
-    case AArch64CC::NE: {
-      SDValue Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
-      return DAG.getNOT(dl, Fcmeq, VT);
-    }
-    case AArch64CC::EQ:
-      return DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
-    case AArch64CC::GE:
-      return DAG.getNode(AArch64ISD::FCMGE, dl, VT, LHS, RHS);
-    case AArch64CC::GT:
-      return DAG.getNode(AArch64ISD::FCMGT, dl, VT, LHS, RHS);
-    case AArch64CC::LE:
-      if (!NoNans)
-        return SDValue();
-      // If we ignore NaNs then we can use to the LS implementation.
-      [[fallthrough]];
-    case AArch64CC::LS:
-      return DAG.getNode(AArch64ISD::FCMGE, dl, VT, RHS, LHS);
-    case AArch64CC::LT:
-      if (!NoNans)
-        return SDValue();
-      // If we ignore NaNs then we can use to the MI implementation.
-      [[fallthrough]];
-    case AArch64CC::MI:
-      return DAG.getNode(AArch64ISD::FCMGT, dl, VT, RHS, LHS);
-    }
-  }
-
-  return SDValue();
-}
-
 SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
                                            SelectionDAG &DAG) const {
   if (Op.getValueType().isScalableVector())
@@ -25365,6 +25388,28 @@ static SDValue performDUPCombine(SDNode *N,
   }
 
   if (N->getOpcode() == AArch64ISD::DUP) {
+    // If the instruction is known to produce a scalar in SIMD registers, we
+    // can duplicate it across the vector lanes using DUPLANE instead of moving
+    // it to a GPR first. For example, this allows us to handle:
+    //   v4i32 = DUP (i32 (FCMGT (f32, f32)))
+    SDValue Op = N->getOperand(0);
+    // FIXME: Ideally, we should be able to handle all instructions that
+    // produce a scalar value in FPRs.
+    if (Op.getOpcode() == AArch64ISD::FCMEQ ||
+        Op.getOpcode() == AArch64ISD::FCMGE ||
+        Op.getOpcode() == AArch64ISD::FCMGT) {
+      EVT ElemVT = VT.getVectorElementType();
+      EVT ExpandedVT = VT;
+      // Insert into a 128-bit vector to match DUPLANE's pattern.
+      if (VT.getSizeInBits() != 128)
+        ExpandedVT = EVT::getVectorVT(*DCI.DAG.getContext(), ElemVT,
+                                      128 / ElemVT.getSizeInBits());
+      SDValue Zero = DCI.DAG.getConstant(0, DL, MVT::i64);
+      SDValue Vec = DCI.DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ExpandedVT,
+                                    DCI.DAG.getUNDEF(ExpandedVT), Op, Zero);
+      return DCI.DAG.getNode(getDUPLANEOp(ElemVT), DL, VT, Vec, Zero);
+    }
+
     if (DCI.isAfterLegalizeDAG()) {
       // If scalar dup's operand is extract_vector_elt, try to combine them into
       // duplane. For example,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 1924d20f67f49..e2e2150133e82 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -645,8 +645,8 @@ class AArch64TargetLowering : public TargetLowering {
   SDValue LowerSELECT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerSELECT_CC(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerSELECT_CC(ISD::CondCode CC, SDValue LHS, SDValue RHS,
-                         SDValue TVal, SDValue FVal, const SDLoc &dl,
-                         SelectionDAG &DAG) const;
+                         SDValue TVal, SDValue FVal, bool HasNoNans,
+                         const SDLoc &dl, SelectionDAG &DAG) const;
   SDValue LowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/AArch64/arm64-neon-v1i1-setcc.ll b/llvm/test/CodeGen/AArch64/arm64-neon-v1i1-setcc.ll
index 6c70d19a977a5..05178c1dc291c 100644
--- a/llvm/test/CodeGen/AArch64/arm64-neon-v1i1-setcc.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-neon-v1i1-setcc.ll
@@ -174,9 +174,8 @@ define <1 x i16> @test_select_f16_i16(half %i105, half %in, <1 x i16> %x, <1 x i
 ; CHECK-LABEL: test_select_f16_i16:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    fcvt s0, h0
-; CHECK-NEXT:    fcmp s0, s0
-; CHECK-NEXT:    csetm w8, vs
-; CHECK-NEXT:    dup v0.4h, w8
+; CHECK-NEXT:    fcmgt s0, s0, s0
+; CHECK-NEXT:    dup v0.4h, v0.h[0]
 ; CHECK-NEXT:    bsl v0.8b, v2.8b, v3.8b
 ; CHECK-NEXT:    ret
   %i179 = fcmp uno half %i105, zeroinitializer
diff --git a/llvm/test/CodeGen/AArch64/build-vector-dup-simd.ll b/llvm/test/CodeGen/AArch64/build-vector-dup-simd.ll
new file mode 100644
index 0000000000000..c52b8817ab6f8
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/build-vector-dup-simd.ll
@@ -0,0 +1,32 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64 | FileCheck %s
+
+define <4 x float> @dup32(float %a, float %b) {
+; CHECK-LABEL: dup32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fcmgt s0, s0, s1
+; CHECK-NEXT:    dup v0.4s, v0.s[0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = fcmp ogt float %a, %b
+  %vcmpd.i = sext i1 %0 to i32
+  %vecinit.i = insertelement <4 x i32> poison, i32 %vcmpd.i, i64 0
+  %1 = bitcast <4 x i32> %vecinit.i to <4 x float>
+  %2 = shufflevector <4 x float> %1, <4 x float> poison, <4 x i32> zeroinitializer
+  ret <4 x float> %2
+}
+
+define <2 x double> @dup64(double %a, double %b) {
+; CHECK-LABEL: dup64:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fcmgt d0, d0, d1
+; CHECK-NEXT:    dup v0.2d, v0.d[0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = fcmp ogt double %a, %b
+  %vcmpd.i = sext i1 %0 to i64
+  %vecinit.i = insertelement <2 x i64> poison, i64 %vcmpd.i, i64 0
+  %1 = bitcast <2 x i64> %vecinit.i to <2 x double>
+  %2 = shufflevector <2 x double> %1, <2 x double> poison, <2 x i32> zeroinitializer
+  ret <2 x double> %2
+}

@guy-david force-pushed the users/guy-david/aarch64-dup-lane-simd branch from 1472468 to 8623154 on May 29, 2025 14:06
@guy-david force-pushed the users/guy-david/aarch64-dup-lane-simd branch from 8623154 to 1bc6d26 on May 31, 2025 20:15

github-actions bot commented May 31, 2025

✅ With the latest revision this PR passed the undef deprecator.

@guy-david force-pushed the users/guy-david/aarch64-dup-lane-simd branch from 1bc6d26 to e79057c on May 31, 2025 20:20
@guy-david requested a review from davemgreen on June 4, 2025 13:18
@guy-david force-pushed the users/guy-david/aarch64-dup-lane-simd branch from e79057c to 1beafba on June 5, 2025 13:40
@guy-david requested a review from davemgreen on June 5, 2025 13:41
@guy-david force-pushed the users/guy-david/aarch64-dup-lane-simd branch from 1beafba to eb7addd on June 5, 2025 22:39
Transform `select_cc t1, t2, -1, 0` for floats into a vector comparison that
produces a mask, which can later be combined with a vectorized DUP.
@guy-david force-pushed the users/guy-david/aarch64-dup-lane-simd branch from eb7addd to 8fa8f13 on June 6, 2025 12:45
@guy-david requested a review from davemgreen on June 6, 2025 12:46