[RISCV] Fix a bug in partial.reduce lowering for zvqdotq .vx forms #142185


Merged (1 commit, May 30, 2025)

Conversation

preames
Collaborator

@preames preames commented May 30, 2025

I'd missed a bitcast in the lowering. Unfortunately, that bitcast happens to be semantically required here, as the partial_reduce_* source expects an i8 element type, but the pseudos and patterns expect an i32 element type.

This appears to affect only the .vx matching in the cases I've found so far, and LV does not yet generate anything that will exercise this. The reduce path (instead of the partial.reduce one) used by SLP currently constructs the i32 value manually and then goes directly to the pseudos with their i32 arguments, not through the partial_reduce nodes.

We're basically losing the .vx matching on this path until we teach splat matching to manually splat the i8 value into an i32 via LUI/ADDI.

@llvmbot
Member

llvmbot commented May 30, 2025

@llvm/pr-subscribers-backend-risc-v

Author: Philip Reames (preames)

Changes



Full diff: https://github.com/llvm/llvm-project/pull/142185.diff

3 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+8-3)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll (+16-12)
  • (modified) llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll (+53)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f2311e94252e9..b7fd0c93fa93f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8412,13 +8412,18 @@ SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
   assert(ArgVT == B.getSimpleValueType() &&
          ArgVT.getVectorElementType() == MVT::i8);
 
+  // The zvqdotq pseudos are defined with sources and destination both
+  // being i32.  This cast is needed for correctness to avoid incorrect
+  // .vx matching of i8 splats.
+  A = DAG.getBitcast(VT, A);
+  B = DAG.getBitcast(VT, B);
+
   MVT ContainerVT = VT;
   if (VT.isFixedLengthVector()) {
     ContainerVT = getContainerForFixedLengthVector(VT);
     Accum = convertToScalableVector(ContainerVT, Accum, DAG, Subtarget);
-    MVT ArgContainerVT = getContainerForFixedLengthVector(ArgVT);
-    A = convertToScalableVector(ArgContainerVT, A, DAG, Subtarget);
-    B = convertToScalableVector(ArgContainerVT, B, DAG, Subtarget);
+    A = convertToScalableVector(ContainerVT, A, DAG, Subtarget);
+    B = convertToScalableVector(ContainerVT, B, DAG, Subtarget);
   }
 
   bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index 5e9bbe6c1ebce..0237faea9efb7 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -598,7 +598,6 @@ entry:
   ret <1 x i32> %res
 }
 
-; FIXME: This case is wrong.  We should be splatting 128 to each i8 lane!
 define <1 x i32> @vqdotu_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
 ; NODOT-LABEL: vqdotu_vx_partial_reduce:
 ; NODOT:       # %bb.0: # %entry
@@ -618,10 +617,13 @@ define <1 x i32> @vqdotu_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
 ;
 ; DOT-LABEL: vqdotu_vx_partial_reduce:
 ; DOT:       # %bb.0: # %entry
-; DOT-NEXT:    vsetivli zero, 1, e32, mf2, ta, ma
+; DOT-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
 ; DOT-NEXT:    vmv.s.x v9, zero
 ; DOT-NEXT:    li a0, 128
-; DOT-NEXT:    vqdotu.vx v9, v8, a0
+; DOT-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
+; DOT-NEXT:    vmv.v.x v10, a0
+; DOT-NEXT:    vsetivli zero, 1, e32, mf2, ta, ma
+; DOT-NEXT:    vqdotu.vv v9, v8, v10
 ; DOT-NEXT:    vmv1r.v v8, v9
 ; DOT-NEXT:    ret
 entry:
@@ -631,7 +633,6 @@ entry:
   ret <1 x i32> %res
 }
 
-; FIXME: This case is wrong.  We should be splatting 128 to each i8 lane!
 define <1 x i32> @vqdot_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
 ; NODOT-LABEL: vqdot_vx_partial_reduce:
 ; NODOT:       # %bb.0: # %entry
@@ -652,10 +653,13 @@ define <1 x i32> @vqdot_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
 ;
 ; DOT-LABEL: vqdot_vx_partial_reduce:
 ; DOT:       # %bb.0: # %entry
-; DOT-NEXT:    vsetivli zero, 1, e32, mf2, ta, ma
+; DOT-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
 ; DOT-NEXT:    vmv.s.x v9, zero
 ; DOT-NEXT:    li a0, 128
-; DOT-NEXT:    vqdot.vx v9, v8, a0
+; DOT-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
+; DOT-NEXT:    vmv.v.x v10, a0
+; DOT-NEXT:    vsetivli zero, 1, e32, mf2, ta, ma
+; DOT-NEXT:    vqdot.vv v9, v8, v10
 ; DOT-NEXT:    vmv1r.v v8, v9
 ; DOT-NEXT:    ret
 entry:
@@ -1372,7 +1376,6 @@ entry:
 }
 
 
-; FIXME: This case is wrong.  We should be splatting 128 to each i8 lane!
 define <4 x i32> @partial_of_sext(<16 x i8> %a) {
 ; NODOT-LABEL: partial_of_sext:
 ; NODOT:       # %bb.0: # %entry
@@ -1393,10 +1396,11 @@ define <4 x i32> @partial_of_sext(<16 x i8> %a) {
 ;
 ; DOT-LABEL: partial_of_sext:
 ; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 1
 ; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
 ; DOT-NEXT:    vmv.v.i v9, 0
-; DOT-NEXT:    li a0, 1
-; DOT-NEXT:    vqdot.vx v9, v8, a0
+; DOT-NEXT:    vqdot.vv v9, v8, v10
 ; DOT-NEXT:    vmv.v.v v8, v9
 ; DOT-NEXT:    ret
 entry:
@@ -1405,7 +1409,6 @@ entry:
   ret <4 x i32> %res
 }
 
-; FIXME: This case is wrong.  We should be splatting 128 to each i8 lane!
 define <4 x i32> @partial_of_zext(<16 x i8> %a) {
 ; NODOT-LABEL: partial_of_zext:
 ; NODOT:       # %bb.0: # %entry
@@ -1426,10 +1429,11 @@ define <4 x i32> @partial_of_zext(<16 x i8> %a) {
 ;
 ; DOT-LABEL: partial_of_zext:
 ; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 1
 ; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
 ; DOT-NEXT:    vmv.v.i v9, 0
-; DOT-NEXT:    li a0, 1
-; DOT-NEXT:    vqdotu.vx v9, v8, a0
+; DOT-NEXT:    vqdotu.vv v9, v8, v10
 ; DOT-NEXT:    vmv.v.v v8, v9
 ; DOT-NEXT:    ret
 entry:
diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
index 2bd2ef2878fd5..d0fc915a0d07e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
@@ -957,3 +957,56 @@ entry:
   %res = call <vscale x 1 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 1 x i32> zeroinitializer, <vscale x 4 x i32> %mul)
   ret <vscale x 1 x i32> %res
 }
+
+
+define <vscale x 4 x i32> @partial_of_sext(<vscale x 16 x i8> %a) {
+; NODOT-LABEL: partial_of_sext:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetvli a0, zero, e32, m8, ta, ma
+; NODOT-NEXT:    vsext.vf4 v16, v8
+; NODOT-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; NODOT-NEXT:    vadd.vv v8, v22, v16
+; NODOT-NEXT:    vadd.vv v10, v18, v20
+; NODOT-NEXT:    vadd.vv v8, v10, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: partial_of_sext:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
+; DOT-NEXT:    vmv.v.i v12, 1
+; DOT-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdot.vv v10, v8, v12
+; DOT-NEXT:    vmv.v.v v8, v10
+; DOT-NEXT:    ret
+entry:
+  %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %res = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %a.ext)
+  ret <vscale x 4 x i32> %res
+}
+
+define <vscale x 4 x i32> @partial_of_zext(<vscale x 16 x i8> %a) {
+; NODOT-LABEL: partial_of_zext:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetvli a0, zero, e32, m8, ta, ma
+; NODOT-NEXT:    vzext.vf4 v16, v8
+; NODOT-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; NODOT-NEXT:    vadd.vv v8, v22, v16
+; NODOT-NEXT:    vadd.vv v10, v18, v20
+; NODOT-NEXT:    vadd.vv v8, v10, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: partial_of_zext:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
+; DOT-NEXT:    vmv.v.i v12, 1
+; DOT-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdotu.vv v10, v8, v12
+; DOT-NEXT:    vmv.v.v v8, v10
+; DOT-NEXT:    ret
+entry:
+  %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %res = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %a.ext)
+  ret <vscale x 4 x i32> %res
+}

Collaborator

@topperc topperc left a comment

LGTM

@preames preames merged commit 443cdd0 into llvm:main May 30, 2025
10 of 12 checks passed
@preames preames deleted the pr-zvqdotq-vx-partial-reduce-fix branch May 30, 2025 18:05
topperc added a commit to topperc/llvm-project that referenced this pull request May 30, 2025
…tNode.

After seeing the bug that llvm#142185 fixed, I thought it might be a
good idea to start verifying that nodes are formed correctly.

This patch introduces the verifyTargetNode function and adds these
opcodes. More opcodes can be added in later patches.
@kazutakahirata
Contributor

@preames @topperc I've landed 4a7b53f to fix a warning from this PR.

Now, your PR removed the last use of ArgVT aside from the one in the assert. Should we remove the assert, or should we keep it? Thanks!

  MVT ArgVT = A.getSimpleValueType();
  assert(ArgVT == B.getSimpleValueType() &&
         ArgVT.getVectorElementType() == MVT::i8);

topperc added a commit that referenced this pull request May 31, 2025
…tNode. (#142202)

After seeing the bug that #142185 fixed, I thought it might be a good
idea to start verifying that nodes are formed correctly.

This patch introduces the verifyTargetNode function and adds these
opcodes. More opcodes can be added in later patches.