
[RISCV] Allow folding vmerge into masked ops when mask is the same #97989


Merged
merged 3 commits into llvm:main on Jul 9, 2024

Conversation

Contributor

@lukel97 lukel97 commented Jul 8, 2024

We currently only fold a vmerge into a masked true operand if the vmerge has an all-ones mask, since we end up keeping the mask from the true operand.

But if the masks are the same then we can still fold, because vmerge and true have the same passthru. If an element was masked off in the original vmerge, it will also be masked off in the resulting true, and will have the same passthru value.

The motivation for this is to lower masked VP loads and stores with passthrus to masked RVV instructions. Normally you can express a masked RVV instruction with a mask-undisturbed passthru via a combination of a VP op with an all-ones mask and a vp.merge. But for loads and stores you need the same mask on the VP op as well as on the vp.merge.
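To illustrate the target pattern, here is a minimal LLVM IR sketch adapted from the vpload.ll test added in this patch (the function name and value names are illustrative): the vp.load and the vp.merge use the same mask %m, so with this change the pair can be folded into a single masked, tail-undisturbed vle8.v that loads directly into the passthru register.

define <vscale x 1 x i8> @masked_load_with_passthru(ptr %ptr, <vscale x 1 x i1> %m, <vscale x 1 x i8> %passthru, i32 zeroext %evl) {
  ; Masked VP load: lanes where %m is 0 are undefined here.
  %load = call <vscale x 1 x i8> @llvm.vp.load.nxv1i8.p0(ptr %ptr, <vscale x 1 x i1> %m, i32 %evl)
  ; A vp.merge with the same mask %m picks %passthru for those lanes,
  ; which is exactly the semantics of a mask-undisturbed masked load.
  %merge = call <vscale x 1 x i8> @llvm.vp.merge.nxv1i8(<vscale x 1 x i1> %m, <vscale x 1 x i8> %load, <vscale x 1 x i8> %passthru, i32 %evl)
  ret <vscale x 1 x i8> %merge
}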

lukel97 added 2 commits July 8, 2024 14:19
Member

llvmbot commented Jul 8, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)

Changes



Full diff: https://github.com/llvm/llvm-project/pull/97989.diff

4 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp (+23-8)
  • (modified) llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll (+11)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll (+2-5)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vpload.ll (+19-8)
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index ce6a396e9ced9..ce1e55a80fe62 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -3523,24 +3523,26 @@ bool RISCVDAGToDAGISel::doPeepholeSExtW(SDNode *N) {
   return false;
 }
 
-static bool usesAllOnesMask(SDValue MaskOp, SDValue GlueOp) {
+// After ISel, a vector pseudo's mask will be copied to V0, and the CopyToReg
+// will be glued to the pseudo. This tries to look up the value that was copied to
+// V0.
+static SDValue getMaskSetter(SDValue MaskOp, SDValue GlueOp) {
   // Check that we're using V0 as a mask register.
   if (!isa<RegisterSDNode>(MaskOp) ||
       cast<RegisterSDNode>(MaskOp)->getReg() != RISCV::V0)
-    return false;
+    return SDValue();
 
   // The glued user defines V0.
   const auto *Glued = GlueOp.getNode();
 
   if (!Glued || Glued->getOpcode() != ISD::CopyToReg)
-    return false;
+    return SDValue();
 
   // Check that we're defining V0 as a mask register.
   if (!isa<RegisterSDNode>(Glued->getOperand(1)) ||
       cast<RegisterSDNode>(Glued->getOperand(1))->getReg() != RISCV::V0)
-    return false;
+    return SDValue();
 
-  // Check the instruction defining V0; it needs to be a VMSET pseudo.
   SDValue MaskSetter = Glued->getOperand(2);
 
   // Sometimes the VMSET is wrapped in a COPY_TO_REGCLASS, e.g. if the mask came
@@ -3549,6 +3551,15 @@ static bool usesAllOnesMask(SDValue MaskOp, SDValue GlueOp) {
       MaskSetter->getMachineOpcode() == RISCV::COPY_TO_REGCLASS)
     MaskSetter = MaskSetter->getOperand(0);
 
+  return MaskSetter;
+}
+
+static bool usesAllOnesMask(SDValue MaskOp, SDValue GlueOp) {
+  // Check the instruction defining V0; it needs to be a VMSET pseudo.
+  SDValue MaskSetter = getMaskSetter(MaskOp, GlueOp);
+  if (!MaskSetter)
+    return false;
+
   const auto IsVMSet = [](unsigned Opc) {
     return Opc == RISCV::PseudoVMSET_M_B1 || Opc == RISCV::PseudoVMSET_M_B16 ||
            Opc == RISCV::PseudoVMSET_M_B2 || Opc == RISCV::PseudoVMSET_M_B32 ||
@@ -3755,12 +3766,16 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
       return false;
   }
 
-  // If True is masked then the vmerge must have an all 1s mask, since we're
-  // going to keep the mask from True.
+  // If True is masked then the vmerge must have either the same mask or an all
+  // 1s mask, since we're going to keep the mask from True.
   if (IsMasked && Mask) {
     // FIXME: Support mask agnostic True instruction which would have an
     // undef merge operand.
-    if (!usesAllOnesMask(Mask, Glue))
+    SDValue TrueMask =
+        getMaskSetter(True->getOperand(Info->MaskOpIdx),
+                      True->getOperand(True->getNumOperands() - 1));
+    assert(TrueMask);
+    if (!usesAllOnesMask(Mask, Glue) && getMaskSetter(Mask, Glue) != TrueMask)
       return false;
   }
 
diff --git a/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll b/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll
index 033a1d7e297f7..d26fd0ca26c72 100644
--- a/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll
@@ -240,3 +240,14 @@ define <vscale x 2 x i32> @vpmerge_viota(<vscale x 2 x i32> %passthru, <vscale x
   %b = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, <vscale x 2 x i1> splat (i1 -1), i64 %1)
   ret <vscale x 2 x i32> %b
 }
+
+define <vscale x 2 x i32> @vpmerge_vadd_same_mask(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl) {
+; CHECK-LABEL: vpmerge_vadd_same_mask:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e32, m1, tu, mu
+; CHECK-NEXT:    vadd.vv v8, v9, v10, v0.t
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i32> @llvm.riscv.vadd.mask.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl, i64 1)
+  %b = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, <vscale x 2 x i1> %m, i64 %vl)
+  ret <vscale x 2 x i32> %b
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll
index f9d992a40299c..f635b61fc3e5b 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll
@@ -85,11 +85,8 @@ define <vscale x 1 x float> @vfmacc_vv_nxv1f32_tu(<vscale x 1 x half> %a, <vscal
 define <vscale x 1 x float> @vfmacc_vv_nxv1f32_masked__tu(<vscale x 1 x half> %a, <vscale x 1 x half> %b, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
 ; ZVFH-LABEL: vfmacc_vv_nxv1f32_masked__tu:
 ; ZVFH:       # %bb.0:
-; ZVFH-NEXT:    vmv1r.v v11, v10
-; ZVFH-NEXT:    vsetvli zero, a0, e16, mf4, ta, ma
-; ZVFH-NEXT:    vfwmacc.vv v11, v8, v9, v0.t
-; ZVFH-NEXT:    vsetvli zero, zero, e32, mf2, tu, ma
-; ZVFH-NEXT:    vmerge.vvm v10, v10, v11, v0
+; ZVFH-NEXT:    vsetvli zero, a0, e16, mf4, tu, mu
+; ZVFH-NEXT:    vfwmacc.vv v10, v8, v9, v0.t
 ; ZVFH-NEXT:    vmv1r.v v8, v10
 ; ZVFH-NEXT:    ret
 ;
diff --git a/llvm/test/CodeGen/RISCV/rvv/vpload.ll b/llvm/test/CodeGen/RISCV/rvv/vpload.ll
index 1b1e9153a2fd5..c0a210e680c79 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vpload.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vpload.ll
@@ -26,6 +26,17 @@ define <vscale x 1 x i8> @vpload_nxv1i8_allones_mask(ptr %ptr, i32 zeroext %evl)
   ret <vscale x 1 x i8> %load
 }
 
+define <vscale x 1 x i8> @vpload_nxv1i8_passthru(ptr %ptr, <vscale x 1 x i1> %m, <vscale x 1 x i8> %passthru, i32 zeroext %evl) {
+; CHECK-LABEL: vpload_nxv1i8_passthru:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a1, e8, mf8, tu, mu
+; CHECK-NEXT:    vle8.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  %load = call <vscale x 1 x i8> @llvm.vp.load.nxv1i8.p0(ptr %ptr, <vscale x 1 x i1> %m, i32 %evl)
+  %merge = call <vscale x 1 x i8> @llvm.vp.merge.nxv1i8(<vscale x 1 x i1> %m, <vscale x 1 x i8> %load, <vscale x 1 x i8> %passthru, i32 %evl)
+  ret <vscale x 1 x i8> %merge
+}
+
 declare <vscale x 2 x i8> @llvm.vp.load.nxv2i8.p0(ptr, <vscale x 2 x i1>, i32)
 
 define <vscale x 2 x i8> @vpload_nxv2i8(ptr %ptr, <vscale x 2 x i1> %m, i32 zeroext %evl) {
@@ -450,10 +461,10 @@ define <vscale x 16 x double> @vpload_nxv16f64(ptr %ptr, <vscale x 16 x i1> %m,
 ; CHECK-NEXT:    add a4, a0, a4
 ; CHECK-NEXT:    vsetvli zero, a3, e64, m8, ta, ma
 ; CHECK-NEXT:    vle64.v v16, (a4), v0.t
-; CHECK-NEXT:    bltu a1, a2, .LBB37_2
+; CHECK-NEXT:    bltu a1, a2, .LBB38_2
 ; CHECK-NEXT:  # %bb.1:
 ; CHECK-NEXT:    mv a1, a2
-; CHECK-NEXT:  .LBB37_2:
+; CHECK-NEXT:  .LBB38_2:
 ; CHECK-NEXT:    vmv1r.v v0, v8
 ; CHECK-NEXT:    vsetvli zero, a1, e64, m8, ta, ma
 ; CHECK-NEXT:    vle64.v v8, (a0), v0.t
@@ -480,10 +491,10 @@ define <vscale x 16 x double> @vpload_nxv17f64(ptr %ptr, ptr %out, <vscale x 17
 ; CHECK-NEXT:    slli a5, a3, 1
 ; CHECK-NEXT:    vmv1r.v v8, v0
 ; CHECK-NEXT:    mv a4, a2
-; CHECK-NEXT:    bltu a2, a5, .LBB38_2
+; CHECK-NEXT:    bltu a2, a5, .LBB39_2
 ; CHECK-NEXT:  # %bb.1:
 ; CHECK-NEXT:    mv a4, a5
-; CHECK-NEXT:  .LBB38_2:
+; CHECK-NEXT:  .LBB39_2:
 ; CHECK-NEXT:    sub a6, a4, a3
 ; CHECK-NEXT:    sltu a7, a4, a6
 ; CHECK-NEXT:    addi a7, a7, -1
@@ -499,10 +510,10 @@ define <vscale x 16 x double> @vpload_nxv17f64(ptr %ptr, ptr %out, <vscale x 17
 ; CHECK-NEXT:    sltu a2, a2, a5
 ; CHECK-NEXT:    addi a2, a2, -1
 ; CHECK-NEXT:    and a2, a2, a5
-; CHECK-NEXT:    bltu a2, a3, .LBB38_4
+; CHECK-NEXT:    bltu a2, a3, .LBB39_4
 ; CHECK-NEXT:  # %bb.3:
 ; CHECK-NEXT:    mv a2, a3
-; CHECK-NEXT:  .LBB38_4:
+; CHECK-NEXT:  .LBB39_4:
 ; CHECK-NEXT:    slli a5, a3, 4
 ; CHECK-NEXT:    srli a6, a3, 2
 ; CHECK-NEXT:    vsetvli a7, zero, e8, mf2, ta, ma
@@ -510,10 +521,10 @@ define <vscale x 16 x double> @vpload_nxv17f64(ptr %ptr, ptr %out, <vscale x 17
 ; CHECK-NEXT:    add a5, a0, a5
 ; CHECK-NEXT:    vsetvli zero, a2, e64, m8, ta, ma
 ; CHECK-NEXT:    vle64.v v24, (a5), v0.t
-; CHECK-NEXT:    bltu a4, a3, .LBB38_6
+; CHECK-NEXT:    bltu a4, a3, .LBB39_6
 ; CHECK-NEXT:  # %bb.5:
 ; CHECK-NEXT:    mv a4, a3
-; CHECK-NEXT:  .LBB38_6:
+; CHECK-NEXT:  .LBB39_6:
 ; CHECK-NEXT:    vmv1r.v v0, v8
 ; CHECK-NEXT:    vsetvli zero, a4, e64, m8, ta, ma
 ; CHECK-NEXT:    vle64.v v8, (a0), v0.t

Comment on lines +3774 to +3776
    SDValue TrueMask =
        getMaskSetter(True->getOperand(Info->MaskOpIdx),
                      True->getOperand(True->getNumOperands() - 1));
Contributor Author

@lukel97 lukel97 Jul 8, 2024

This is a bit gnarly but I still hope to eventually move all this into RISCVFoldMasks.cpp #71764. Sorry for dropping the ball on that

Collaborator

@topperc topperc left a comment

LGTM

@lukel97 lukel97 merged commit 3f83a69 into llvm:main Jul 9, 2024
7 checks passed
aaryanshukla pushed a commit to aaryanshukla/llvm-project that referenced this pull request Jul 14, 2024
…lvm#97989)
