[RISCV] Select zero splats of EEW=64 on RV32 #107205

Closed
lukel97 wants to merge 1 commit

Conversation

lukel97 (Contributor) commented Sep 4, 2024

An EEW=64 splat of zero on RV32 will have a bitcast in the middle, so it isn't picked up by findVSplat. We can add a special case for it, which allows some .vx and .vi patterns to be matched.

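For context, vmseq.vi and other .vi patterns reach selectVSplatImmHelper through thin selector wrappers that validate the immediate range. A minimal sketch of such a caller, written to mirror the in-tree style of RISCVISelDAGToDAG.cpp (an illustration, not a verbatim quote of the file):

// Illustrative wrapper: .vi patterns need a simm5 splat, so the immediate is
// range-checked before the splat value is accepted.
bool RISCVDAGToDAGISel::selectVSplatSimm5(SDValue N, SDValue &SplatVal) {
  return selectVSplatImmHelper(N, SplatVal, *CurDAG, *Subtarget,
                               [](int64_t Imm) { return isInt<5>(Imm); });
}
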
llvmbot (Member) commented Sep 4, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)

Changes

An EEW=64 splat of zero on RV32 will have a bitcast in the middle, so it isn't picked up by findVSplat. We can add a special case for it, which allows some .vx and .vi patterns to be matched.


Full diff: https://github.com/llvm/llvm-project/pull/107205.diff

3 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp (+27)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-load-int.ll (+19-37)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-store-int.ll (+41-81)
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 78006885421603..125edd24be5cab 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -3452,10 +3452,37 @@ bool RISCVDAGToDAGISel::selectVSplat(SDValue N, SDValue &SplatVal) {
   return true;
 }
 
+// Look for zero splats. On RV32 an EEW=64 splat may have a bitcast in between.
+//
+//   t72: nxv16i32 = RISCVISD::VMV_V_X_VL ...
+//   t73: v32i32 = extract_subvector t72, Constant:i32<0>
+//   t21: v16i64 = bitcast t73
+//   t42: nxv8i64 = insert_subvector undef:nxv8i64, t21, Constant:i32<0>
+static bool isZeroSplat(SDValue N) {
+  if (N.getOpcode() == ISD::INSERT_SUBVECTOR && N.getOperand(0).isUndef())
+    N = N.getOperand(1);
+  if (N.getOpcode() == ISD::BITCAST)
+    N = N.getOperand(0);
+  if (N.getOpcode() == ISD::EXTRACT_SUBVECTOR)
+    N = N.getOperand(0);
+
+  return (N.getOpcode() == RISCVISD::VMV_V_X_VL ||
+          N.getOpcode() == RISCVISD::VMV_S_X_VL) &&
+         isa<ConstantSDNode>(N.getOperand(1)) &&
+         N.getConstantOperandVal(1) == 0;
+}
+
 static bool selectVSplatImmHelper(SDValue N, SDValue &SplatVal,
                                   SelectionDAG &DAG,
                                   const RISCVSubtarget &Subtarget,
                                   std::function<bool(int64_t)> ValidateImm) {
+  // Explicitly look for zero splats to handle the EEW=64 on RV32 case.
+  if (isZeroSplat(N) && ValidateImm(0)) {
+    SplatVal =
+        DAG.getConstant(0, SDLoc(N), Subtarget.getXLenVT(), /*isTarget=*/true);
+    return true;
+  }
+
   SDValue Splat = findVSplat(N);
   if (!Splat || !isa<ConstantSDNode>(Splat.getOperand(1)))
     return false;
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-load-int.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-load-int.ll
index ad075e4b4e198c..2f20caa6eb1894 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-load-int.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-load-int.ll
@@ -397,43 +397,22 @@ define void @masked_load_v32i32(ptr %a, ptr %m_ptr, ptr %res_ptr) nounwind {
 declare <32 x i32> @llvm.masked.load.v32i32(ptr, i32, <32 x i1>, <32 x i32>)
 
 define void @masked_load_v32i64(ptr %a, ptr %m_ptr, ptr %res_ptr) nounwind {
-; RV32-LABEL: masked_load_v32i64:
-; RV32:       # %bb.0:
-; RV32-NEXT:    addi a3, a1, 128
-; RV32-NEXT:    vsetivli zero, 16, e64, m8, ta, ma
-; RV32-NEXT:    vle64.v v0, (a1)
-; RV32-NEXT:    vle64.v v24, (a3)
-; RV32-NEXT:    li a1, 32
-; RV32-NEXT:    vsetvli zero, a1, e32, m8, ta, ma
-; RV32-NEXT:    vmv.v.i v16, 0
-; RV32-NEXT:    vsetivli zero, 16, e64, m8, ta, ma
-; RV32-NEXT:    vmseq.vv v8, v0, v16
-; RV32-NEXT:    vmseq.vv v0, v24, v16
-; RV32-NEXT:    addi a1, a0, 128
-; RV32-NEXT:    vle64.v v16, (a1), v0.t
-; RV32-NEXT:    vmv1r.v v0, v8
-; RV32-NEXT:    vle64.v v8, (a0), v0.t
-; RV32-NEXT:    vse64.v v8, (a2)
-; RV32-NEXT:    addi a0, a2, 128
-; RV32-NEXT:    vse64.v v16, (a0)
-; RV32-NEXT:    ret
-;
-; RV64-LABEL: masked_load_v32i64:
-; RV64:       # %bb.0:
-; RV64-NEXT:    addi a3, a1, 128
-; RV64-NEXT:    vsetivli zero, 16, e64, m8, ta, ma
-; RV64-NEXT:    vle64.v v16, (a1)
-; RV64-NEXT:    vle64.v v24, (a3)
-; RV64-NEXT:    vmseq.vi v8, v16, 0
-; RV64-NEXT:    vmseq.vi v0, v24, 0
-; RV64-NEXT:    addi a1, a0, 128
-; RV64-NEXT:    vle64.v v16, (a1), v0.t
-; RV64-NEXT:    vmv1r.v v0, v8
-; RV64-NEXT:    vle64.v v8, (a0), v0.t
-; RV64-NEXT:    vse64.v v8, (a2)
-; RV64-NEXT:    addi a0, a2, 128
-; RV64-NEXT:    vse64.v v16, (a0)
-; RV64-NEXT:    ret
+; CHECK-LABEL: masked_load_v32i64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a3, a1, 128
+; CHECK-NEXT:    vsetivli zero, 16, e64, m8, ta, ma
+; CHECK-NEXT:    vle64.v v16, (a1)
+; CHECK-NEXT:    vle64.v v24, (a3)
+; CHECK-NEXT:    vmseq.vi v8, v16, 0
+; CHECK-NEXT:    vmseq.vi v0, v24, 0
+; CHECK-NEXT:    addi a1, a0, 128
+; CHECK-NEXT:    vle64.v v16, (a1), v0.t
+; CHECK-NEXT:    vmv1r.v v0, v8
+; CHECK-NEXT:    vle64.v v8, (a0), v0.t
+; CHECK-NEXT:    vse64.v v8, (a2)
+; CHECK-NEXT:    addi a0, a2, 128
+; CHECK-NEXT:    vse64.v v16, (a0)
+; CHECK-NEXT:    ret
   %m = load <32 x i64>, ptr %m_ptr
   %mask = icmp eq <32 x i64> %m, zeroinitializer
   %load = call <32 x i64> @llvm.masked.load.v32i64(ptr %a, i32 8, <32 x i1> %mask, <32 x i64> undef)
@@ -547,3 +526,6 @@ define void @masked_load_v256i8(ptr %a, ptr %m_ptr, ptr %res_ptr) nounwind {
   ret void
 }
 declare <256 x i8> @llvm.masked.load.v256i8(ptr, i32, <256 x i1>, <256 x i8>)
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; RV32: {{.*}}
+; RV64: {{.*}}
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-store-int.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-store-int.ll
index 86c28247e97ef1..90690bbc8e2085 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-store-int.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-store-int.ll
@@ -397,87 +397,44 @@ define void @masked_store_v32i32(ptr %val_ptr, ptr %a, ptr %m_ptr) nounwind {
 declare void @llvm.masked.store.v32i32.p0(<32 x i32>, ptr, i32, <32 x i1>)
 
 define void @masked_store_v32i64(ptr %val_ptr, ptr %a, ptr %m_ptr) nounwind {
-; RV32-LABEL: masked_store_v32i64:
-; RV32:       # %bb.0:
-; RV32-NEXT:    addi sp, sp, -16
-; RV32-NEXT:    csrr a3, vlenb
-; RV32-NEXT:    slli a3, a3, 4
-; RV32-NEXT:    sub sp, sp, a3
-; RV32-NEXT:    addi a3, a2, 128
-; RV32-NEXT:    vsetivli zero, 16, e64, m8, ta, ma
-; RV32-NEXT:    vle64.v v24, (a2)
-; RV32-NEXT:    vle64.v v8, (a3)
-; RV32-NEXT:    csrr a2, vlenb
-; RV32-NEXT:    slli a2, a2, 3
-; RV32-NEXT:    add a2, sp, a2
-; RV32-NEXT:    addi a2, a2, 16
-; RV32-NEXT:    vs8r.v v8, (a2) # Unknown-size Folded Spill
-; RV32-NEXT:    li a2, 32
-; RV32-NEXT:    vsetvli zero, a2, e32, m8, ta, ma
-; RV32-NEXT:    vmv.v.i v8, 0
-; RV32-NEXT:    vsetivli zero, 16, e64, m8, ta, ma
-; RV32-NEXT:    vmseq.vv v7, v24, v8
-; RV32-NEXT:    addi a2, a0, 128
-; RV32-NEXT:    vle64.v v24, (a2)
-; RV32-NEXT:    vle64.v v16, (a0)
-; RV32-NEXT:    addi a0, sp, 16
-; RV32-NEXT:    vs8r.v v16, (a0) # Unknown-size Folded Spill
-; RV32-NEXT:    csrr a0, vlenb
-; RV32-NEXT:    slli a0, a0, 3
-; RV32-NEXT:    add a0, sp, a0
-; RV32-NEXT:    addi a0, a0, 16
-; RV32-NEXT:    vl8r.v v16, (a0) # Unknown-size Folded Reload
-; RV32-NEXT:    vmseq.vv v0, v16, v8
-; RV32-NEXT:    addi a0, a1, 128
-; RV32-NEXT:    vse64.v v24, (a0), v0.t
-; RV32-NEXT:    vmv1r.v v0, v7
-; RV32-NEXT:    addi a0, sp, 16
-; RV32-NEXT:    vl8r.v v8, (a0) # Unknown-size Folded Reload
-; RV32-NEXT:    vse64.v v8, (a1), v0.t
-; RV32-NEXT:    csrr a0, vlenb
-; RV32-NEXT:    slli a0, a0, 4
-; RV32-NEXT:    add sp, sp, a0
-; RV32-NEXT:    addi sp, sp, 16
-; RV32-NEXT:    ret
-;
-; RV64-LABEL: masked_store_v32i64:
-; RV64:       # %bb.0:
-; RV64-NEXT:    addi sp, sp, -16
-; RV64-NEXT:    csrr a3, vlenb
-; RV64-NEXT:    slli a3, a3, 4
-; RV64-NEXT:    sub sp, sp, a3
-; RV64-NEXT:    vsetivli zero, 16, e64, m8, ta, ma
-; RV64-NEXT:    vle64.v v8, (a2)
-; RV64-NEXT:    addi a2, a2, 128
-; RV64-NEXT:    vle64.v v16, (a2)
-; RV64-NEXT:    csrr a2, vlenb
-; RV64-NEXT:    slli a2, a2, 3
-; RV64-NEXT:    add a2, sp, a2
-; RV64-NEXT:    addi a2, a2, 16
-; RV64-NEXT:    vs8r.v v16, (a2) # Unknown-size Folded Spill
-; RV64-NEXT:    vmseq.vi v0, v8, 0
-; RV64-NEXT:    vle64.v v24, (a0)
-; RV64-NEXT:    addi a0, a0, 128
-; RV64-NEXT:    vle64.v v8, (a0)
-; RV64-NEXT:    addi a0, sp, 16
-; RV64-NEXT:    vs8r.v v8, (a0) # Unknown-size Folded Spill
-; RV64-NEXT:    csrr a0, vlenb
-; RV64-NEXT:    slli a0, a0, 3
-; RV64-NEXT:    add a0, sp, a0
-; RV64-NEXT:    addi a0, a0, 16
-; RV64-NEXT:    vl8r.v v16, (a0) # Unknown-size Folded Reload
-; RV64-NEXT:    vmseq.vi v8, v16, 0
-; RV64-NEXT:    vse64.v v24, (a1), v0.t
-; RV64-NEXT:    addi a0, a1, 128
-; RV64-NEXT:    vmv1r.v v0, v8
-; RV64-NEXT:    addi a1, sp, 16
-; RV64-NEXT:    vl8r.v v8, (a1) # Unknown-size Folded Reload
-; RV64-NEXT:    vse64.v v8, (a0), v0.t
-; RV64-NEXT:    csrr a0, vlenb
-; RV64-NEXT:    slli a0, a0, 4
-; RV64-NEXT:    add sp, sp, a0
-; RV64-NEXT:    addi sp, sp, 16
-; RV64-NEXT:    ret
+; CHECK-LABEL: masked_store_v32i64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi sp, sp, -16
+; CHECK-NEXT:    csrr a3, vlenb
+; CHECK-NEXT:    slli a3, a3, 4
+; CHECK-NEXT:    sub sp, sp, a3
+; CHECK-NEXT:    vsetivli zero, 16, e64, m8, ta, ma
+; CHECK-NEXT:    vle64.v v8, (a2)
+; CHECK-NEXT:    addi a2, a2, 128
+; CHECK-NEXT:    vle64.v v16, (a2)
+; CHECK-NEXT:    csrr a2, vlenb
+; CHECK-NEXT:    slli a2, a2, 3
+; CHECK-NEXT:    add a2, sp, a2
+; CHECK-NEXT:    addi a2, a2, 16
+; CHECK-NEXT:    vs8r.v v16, (a2) # Unknown-size Folded Spill
+; CHECK-NEXT:    vmseq.vi v0, v8, 0
+; CHECK-NEXT:    vle64.v v24, (a0)
+; CHECK-NEXT:    addi a0, a0, 128
+; CHECK-NEXT:    vle64.v v8, (a0)
+; CHECK-NEXT:    addi a0, sp, 16
+; CHECK-NEXT:    vs8r.v v8, (a0) # Unknown-size Folded Spill
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    slli a0, a0, 3
+; CHECK-NEXT:    add a0, sp, a0
+; CHECK-NEXT:    addi a0, a0, 16
+; CHECK-NEXT:    vl8r.v v16, (a0) # Unknown-size Folded Reload
+; CHECK-NEXT:    vmseq.vi v8, v16, 0
+; CHECK-NEXT:    vse64.v v24, (a1), v0.t
+; CHECK-NEXT:    addi a0, a1, 128
+; CHECK-NEXT:    vmv1r.v v0, v8
+; CHECK-NEXT:    addi a1, sp, 16
+; CHECK-NEXT:    vl8r.v v8, (a1) # Unknown-size Folded Reload
+; CHECK-NEXT:    vse64.v v8, (a0), v0.t
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    slli a0, a0, 4
+; CHECK-NEXT:    add sp, sp, a0
+; CHECK-NEXT:    addi sp, sp, 16
+; CHECK-NEXT:    ret
   %m = load <32 x i64>, ptr %m_ptr
   %mask = icmp eq <32 x i64> %m, zeroinitializer
   %val = load <32 x i64>, ptr %val_ptr
@@ -683,3 +640,6 @@ define void @masked_store_v256i8(ptr %val_ptr, ptr %a, ptr %m_ptr) nounwind {
   ret void
 }
 declare void @llvm.masked.store.v256i8.p0(<256 x i8>, ptr, i32, <256 x i1>)
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; RV32: {{.*}}
+; RV64: {{.*}}

// t73: v32i32 = extract_subvector t72, Constant:i32<0>
// t21: v16i64 = bitcast t73
// t42: nxv8i64 = insert_subvector undef:nxv8i64, t21, Constant:i32<0>
static bool isZeroSplat(SDValue N) {
Contributor

Why can't we do this during lowering?

Contributor Author

Is it possible to have an nxv8i64 VMV_V_X_VL on RV32? I presume the scalar operand needs to be at least as wide as the vector element type if it's like the BUILD_VECTOR/SPLAT_VECTOR nodes.

Contributor

Oh, that makes sense!
I'm just worried about the test coverage of this function. Do we have cases where, for example, the bitcast doesn't exist?

Collaborator

It is possible to have an nxv8i64 VMV_V_X_VL on RV32. We model the sign-extending behavior.
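
For illustration only (a minimal sketch, not part of this patch): since VMV_V_X_VL sign-extends its XLEN-wide scalar into each element, sign-extending an i32 zero still yields i64 zero elements, so an nxv8i64 zero splat can be built directly on RV32. The helper name below is hypothetical; the (passthru, scalar, VL) operand order is the one isZeroSplat inspects above.

// Hypothetical helper, sketched for illustration: splat zero into VecVT from
// an XLEN-wide scalar. Sign-extending i32 0 gives i64 0, so no parts needed.
static SDValue buildZeroSplatVL(SelectionDAG &DAG, const SDLoc &DL, MVT VecVT,
                                SDValue VL, const RISCVSubtarget &Subtarget) {
  SDValue Zero = DAG.getConstant(0, DL, Subtarget.getXLenVT());
  return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VecVT, DAG.getUNDEF(VecVT),
                     Zero, VL);
}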

Collaborator

We make SPLAT_VECTOR of i64 legal on RV32 specifically so DAGCombine will turn BUILD_VECTOR into SPLAT_VECTOR.

        // Make SPLAT_VECTOR Legal so DAGCombine will convert splat vectors to   
        // it before type legalization for i64 vectors on RV32. It will then be  
        // type legalized to SPLAT_VECTOR_PARTS which we need to Custom handle.  
        // FIXME: Use SPLAT_VECTOR for all types? DAGCombine probably needs      
        // improvements first.                                                   
        if (!Subtarget.is64Bit() && VT.getVectorElementType() == MVT::i64) {     
          setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);                      
          setOperationAction(ISD::SPLAT_VECTOR_PARTS, VT, Custom);               
        }  

In the affected tests, the type is larger than a legal type, which I guess prevents the DAGCombine from triggering before type legalization. Type legalization then turns it into multiple i32 build_vectors. Maybe we should improve the combine or add a RISC-V combine?
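
For illustration, a minimal sketch of the kind of target combine suggested here, under the assumption that it runs before type legalization; the function name and exact placement are hypothetical, not taken from any patch:

// Hypothetical RISC-V combine sketch: fold an all-zeros i64 BUILD_VECTOR on
// RV32 into SPLAT_VECTOR, which already has custom handling for i64 on RV32.
static SDValue combineZeroBuildVector(SDNode *N, SelectionDAG &DAG,
                                      const RISCVSubtarget &Subtarget) {
  EVT VT = N->getValueType(0);
  if (Subtarget.is64Bit() || !VT.isVector() ||
      VT.getVectorElementType() != MVT::i64)
    return SDValue();
  if (N->getOpcode() != ISD::BUILD_VECTOR || !ISD::isBuildVectorAllZeros(N))
    return SDValue();
  SDLoc DL(N);
  return DAG.getNode(ISD::SPLAT_VECTOR, DL, VT,
                     DAG.getConstant(0, DL, MVT::i64));
}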

Collaborator

Proposed patch #107290

Contributor Author

That seems like a better approach.

static bool selectVSplatImmHelper(SDValue N, SDValue &SplatVal,
                                  SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget,
                                  std::function<bool(int64_t)> ValidateImm) {
  // Explicitly look for zero splats to handle the EEW=64 on RV32 case.
  if (isZeroSplat(N) && ValidateImm(0)) {
Contributor

Do we need to validate imm 0 here?
