[RISCV] Fix a bug in partial.reduce lowering for zvqdotq .vx forms #142185
Conversation
I'd missed a bitcast in the lowering. Unfortunately, that bitcast happens to be semantically required here, as the partial_reduce_* source expects an i8 element type, but the pseudos and patterns expect an i32 element type.

This appears to only influence the .vx matching in the cases I've found so far, and LV does not yet generate anything which will exercise this. The reduce path (instead of the partial.reduce one) used by SLP currently constructs the i32 value manually, and then goes directly to the pseudos with their i32 arguments, not the partial_reduce nodes.

We're basically losing the .vx matching on this path until we teach splat matching to splat the i8 value into an i32 manually via LUI/ADDI.
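For reference, a minimal IR sketch of the kind of input that exercises the buggy path, reconstructed from the test name and the FIXME comments removed in the diff below (the exact test bodies are elided there, so treat this body as an assumption):

```llvm
; Hypothetical reduction of vqdotu_vx_partial_reduce: the multiplicand is
; a splat whose value fits in i8, so lowering previously matched vqdotu.vx
; with the wrong element type (128 per i32 lane instead of per i8 lane).
define <1 x i32> @vqdotu_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
entry:
  %a.ext = zext <4 x i8> %a to <4 x i32>
  %mul = mul <4 x i32> %a.ext, splat (i32 128)
  %res = call <1 x i32> @llvm.experimental.vector.partial.reduce.add(<1 x i32> zeroinitializer, <4 x i32> %mul)
  ret <1 x i32> %res
}
```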
@llvm/pr-subscribers-backend-risc-v
Author: Philip Reames (preames)
Full diff: https://github.com/llvm/llvm-project/pull/142185.diff
3 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f2311e94252e9..b7fd0c93fa93f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8412,13 +8412,18 @@ SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
assert(ArgVT == B.getSimpleValueType() &&
ArgVT.getVectorElementType() == MVT::i8);
+ // The zvqdotq pseudos are defined with sources and destination both
+ // being i32. This cast is needed for correctness to avoid incorrect
+ // .vx matching of i8 splats.
+ A = DAG.getBitcast(VT, A);
+ B = DAG.getBitcast(VT, B);
+
MVT ContainerVT = VT;
if (VT.isFixedLengthVector()) {
ContainerVT = getContainerForFixedLengthVector(VT);
Accum = convertToScalableVector(ContainerVT, Accum, DAG, Subtarget);
- MVT ArgContainerVT = getContainerForFixedLengthVector(ArgVT);
- A = convertToScalableVector(ArgContainerVT, A, DAG, Subtarget);
- B = convertToScalableVector(ArgContainerVT, B, DAG, Subtarget);
+ A = convertToScalableVector(ContainerVT, A, DAG, Subtarget);
+ B = convertToScalableVector(ContainerVT, B, DAG, Subtarget);
}
bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index 5e9bbe6c1ebce..0237faea9efb7 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -598,7 +598,6 @@ entry:
ret <1 x i32> %res
}
-; FIXME: This case is wrong. We should be splatting 128 to each i8 lane!
define <1 x i32> @vqdotu_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
; NODOT-LABEL: vqdotu_vx_partial_reduce:
; NODOT: # %bb.0: # %entry
@@ -618,10 +617,13 @@ define <1 x i32> @vqdotu_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
;
; DOT-LABEL: vqdotu_vx_partial_reduce:
; DOT: # %bb.0: # %entry
-; DOT-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
+; DOT-NEXT: vsetivli zero, 1, e32, m1, ta, ma
; DOT-NEXT: vmv.s.x v9, zero
; DOT-NEXT: li a0, 128
-; DOT-NEXT: vqdotu.vx v9, v8, a0
+; DOT-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
+; DOT-NEXT: vmv.v.x v10, a0
+; DOT-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
+; DOT-NEXT: vqdotu.vv v9, v8, v10
; DOT-NEXT: vmv1r.v v8, v9
; DOT-NEXT: ret
entry:
@@ -631,7 +633,6 @@ entry:
ret <1 x i32> %res
}
-; FIXME: This case is wrong. We should be splatting 128 to each i8 lane!
define <1 x i32> @vqdot_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
; NODOT-LABEL: vqdot_vx_partial_reduce:
; NODOT: # %bb.0: # %entry
@@ -652,10 +653,13 @@ define <1 x i32> @vqdot_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
;
; DOT-LABEL: vqdot_vx_partial_reduce:
; DOT: # %bb.0: # %entry
-; DOT-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
+; DOT-NEXT: vsetivli zero, 1, e32, m1, ta, ma
; DOT-NEXT: vmv.s.x v9, zero
; DOT-NEXT: li a0, 128
-; DOT-NEXT: vqdot.vx v9, v8, a0
+; DOT-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
+; DOT-NEXT: vmv.v.x v10, a0
+; DOT-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
+; DOT-NEXT: vqdot.vv v9, v8, v10
; DOT-NEXT: vmv1r.v v8, v9
; DOT-NEXT: ret
entry:
@@ -1372,7 +1376,6 @@ entry:
}
-; FIXME: This case is wrong. We should be splatting 128 to each i8 lane!
define <4 x i32> @partial_of_sext(<16 x i8> %a) {
; NODOT-LABEL: partial_of_sext:
; NODOT: # %bb.0: # %entry
@@ -1393,10 +1396,11 @@ define <4 x i32> @partial_of_sext(<16 x i8> %a) {
;
; DOT-LABEL: partial_of_sext:
; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 1
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
; DOT-NEXT: vmv.v.i v9, 0
-; DOT-NEXT: li a0, 1
-; DOT-NEXT: vqdot.vx v9, v8, a0
+; DOT-NEXT: vqdot.vv v9, v8, v10
; DOT-NEXT: vmv.v.v v8, v9
; DOT-NEXT: ret
entry:
@@ -1405,7 +1409,6 @@ entry:
ret <4 x i32> %res
}
-; FIXME: This case is wrong. We should be splatting 128 to each i8 lane!
define <4 x i32> @partial_of_zext(<16 x i8> %a) {
; NODOT-LABEL: partial_of_zext:
; NODOT: # %bb.0: # %entry
@@ -1426,10 +1429,11 @@ define <4 x i32> @partial_of_zext(<16 x i8> %a) {
;
; DOT-LABEL: partial_of_zext:
; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 1
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
; DOT-NEXT: vmv.v.i v9, 0
-; DOT-NEXT: li a0, 1
-; DOT-NEXT: vqdotu.vx v9, v8, a0
+; DOT-NEXT: vqdotu.vv v9, v8, v10
; DOT-NEXT: vmv.v.v v8, v9
; DOT-NEXT: ret
entry:
diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
index 2bd2ef2878fd5..d0fc915a0d07e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
@@ -957,3 +957,56 @@ entry:
%res = call <vscale x 1 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 1 x i32> zeroinitializer, <vscale x 4 x i32> %mul)
ret <vscale x 1 x i32> %res
}
+
+
+define <vscale x 4 x i32> @partial_of_sext(<vscale x 16 x i8> %a) {
+; NODOT-LABEL: partial_of_sext:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma
+; NODOT-NEXT: vsext.vf4 v16, v8
+; NODOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; NODOT-NEXT: vadd.vv v8, v22, v16
+; NODOT-NEXT: vadd.vv v10, v18, v20
+; NODOT-NEXT: vadd.vv v8, v10, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: partial_of_sext:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 1
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdot.vv v10, v8, v12
+; DOT-NEXT: vmv.v.v v8, v10
+; DOT-NEXT: ret
+entry:
+ %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %res = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %a.ext)
+ ret <vscale x 4 x i32> %res
+}
+
+define <vscale x 4 x i32> @partial_of_zext(<vscale x 16 x i8> %a) {
+; NODOT-LABEL: partial_of_zext:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma
+; NODOT-NEXT: vzext.vf4 v16, v8
+; NODOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; NODOT-NEXT: vadd.vv v8, v22, v16
+; NODOT-NEXT: vadd.vv v10, v18, v20
+; NODOT-NEXT: vadd.vv v8, v10, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: partial_of_zext:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 1
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotu.vv v10, v8, v12
+; DOT-NEXT: vmv.v.v v8, v10
+; DOT-NEXT: ret
+entry:
+ %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %res = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %a.ext)
+ ret <vscale x 4 x i32> %res
+}
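As the updated checks above show, these cases now go through vmv.v.x plus the .vv form rather than .vx. A hypothetical sketch of the LUI/ADDI materialization mentioned in the description, which would let .vx matching recover this case (instruction sequence assumed, not taken from an actual patch):

```asm
# Splat the i8 value 128 into all four bytes of an i32 scalar
# (0x80808080), so vqdotu.vx sees 128 in every i8 lane:
lui   a0, 0x80808          # a0 = 0x80808000
addi  a0, a0, 0x80         # a0 = 0x80808080
vqdotu.vx v9, v8, a0       # every byte of a0 is 128, as required
```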
LGTM
…tNode. After seeing the bug that llvm#142185 fixed, I thought it might be a good idea to start verifying that nodes are formed correctly. This patch introduces the verifyTargetNode function and adds these opcodes. More opcodes can be added in later patches.
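A minimal sketch of what such a target-node verifier can check for the zvqdotq nodes; the function shape and assertions here are assumptions for illustration, not copied from the actual commit:

```cpp
#include "RISCVISelLowering.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include <cassert>

using namespace llvm;

// Hypothetical verifier: every zvqdotq VL node produces an i32-element
// vector, and both sources plus the accumulator share that type. The
// missing bitcast fixed in #142185 would have tripped these asserts.
static void verifyTargetNode(const SDNode *N) {
  switch (N->getOpcode()) {
  default:
    return;
  case RISCVISD::VQDOT_VL:
  case RISCVISD::VQDOTU_VL:
  case RISCVISD::VQDOTSU_VL: {
    MVT VT = N->getSimpleValueType(0);
    assert(VT.isVector() && VT.getVectorElementType() == MVT::i32 &&
           "zvqdotq nodes must produce an i32 element vector");
    // Operands 0-2 are the two sources and the accumulator; the mask
    // and VL operands follow and are deliberately not checked here.
    for (unsigned I = 0; I < 3; ++I)
      assert(N->getOperand(I).getSimpleValueType() == VT &&
             "zvqdotq sources/accumulator must match the result type");
    break;
  }
  }
}
```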
@preames @topperc I've landed 4a7b53f to fix a warning from this PR. Now, your PR removed the last use of