[AArch64][NEON][SVE] Lower i8 to i64 partial reduction to a dot product #110220
Conversation
An i8 to i64 partial reduction can instead be done with an i8 to i32 dot product followed by a sign extension.
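As a rough illustration of why this rewrite is sound (this sketch is not part of the patch, and the helper names are made up): each i8 × i8 product fits in 16 bits and a sum of four such products fits comfortably in an i32, so accumulating four-element chunks with an i8 to i32 dot product and then sign-extending each i32 lane gives the same result as performing the whole partial reduction in i64. The chunk-per-lane association below is an assumption chosen to mirror the DOT instruction; the partial-reduce intrinsic itself permits any lane association.

#include <array>
#include <cassert>
#include <cstdint>

using Acc64 = std::array<int64_t, 4>;
using Bytes16 = std::array<int8_t, 16>;

// Reference: widen every i8 to i64 up front, multiply, and fold chunks of
// four consecutive products into the matching i64 accumulator lane.
static Acc64 partialReduceI64(Acc64 Acc, const Bytes16 &A, const Bytes16 &B) {
  for (int I = 0; I < 16; ++I)
    Acc[I / 4] += int64_t(A[I]) * int64_t(B[I]);
  return Acc;
}

// Lowered form: an i8 -> i32 dot product into a zeroed accumulator (the
// "movi v4, #0" + "sdot v4.4s" pair in the tests), then sign-extend each
// i32 lane and add it into the i64 accumulator (the "saddw"/"saddw2" pair).
static Acc64 dotThenExtend(Acc64 Acc, const Bytes16 &A, const Bytes16 &B) {
  std::array<int32_t, 4> Dot{};
  for (int Lane = 0; Lane < 4; ++Lane)
    for (int J = 0; J < 4; ++J)
      Dot[Lane] += int32_t(A[4 * Lane + J]) * int32_t(B[4 * Lane + J]);
  for (int Lane = 0; Lane < 4; ++Lane)
    Acc[Lane] += int64_t(Dot[Lane]); // sign extension never overflows here
  return Acc;
}

int main() {
  Acc64 Acc = {1, -2, 3, -4};
  Bytes16 A = {-128, 127, -1, 5, 7, -9, 100, -100,
               3, 2, 1, 0, -50, 50, -128, 127};
  Bytes16 B = {127, -128, -1, 5, -7, 9, -100, 100,
               3, -2, 1, 0, 50, -50, 127, -128};
  assert(partialReduceI64(Acc, A, B) == dotThenExtend(Acc, A, B));
  return 0;
}

This mirrors the shape of the DAG nodes built in tryLowerPartialReductionToDot below: a dot product into a zero constant of the halved type, a getSExtOrTrunc to the reduced type, and a final ISD::ADD with the narrow accumulator.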
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository, in which case you can instead tag reviewers by name in a comment. If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR with a comment saying "Ping". The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord, or on the forums.
@llvm/pr-subscribers-backend-aarch64

Author: James Chesterman (JamesChesterman)

Changes: An i8 to i64 partial reduction can instead be done with an i8 to i32 dot product followed by a sign extension.

Full diff: https://github.com/llvm/llvm-project/pull/110220.diff

3 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4166d9bd22bc01..af66b6b0e43b81 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1996,8 +1996,8 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
return true;
EVT VT = EVT::getEVT(I->getType());
- return VT != MVT::nxv4i32 && VT != MVT::nxv2i64 && VT != MVT::v4i32 &&
- VT != MVT::v2i32;
+ return VT != MVT::nxv4i64 && VT != MVT::nxv4i32 && VT != MVT::nxv2i64 &&
+ VT != MVT::v4i64 && VT != MVT::v4i32 && VT != MVT::v2i32;
}
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
@@ -21916,8 +21916,10 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
- if (!(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
+ if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
+ !(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
!(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
+ !(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
!(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
!(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
return SDValue();
@@ -21930,7 +21932,7 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
bool Scalable = N->getValueType(0).isScalableVT();
// There's no nxv2i64 version of usdot
- if (Scalable && ReducedType != MVT::nxv4i32)
+ if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
return SDValue();
Opcode = AArch64ISD::USDOT;
@@ -21942,6 +21944,20 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
else
Opcode = AArch64ISD::UDOT;
+ // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
+ // product followed by a zero / sign extension
+ if ((ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) ||
+ (ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8)) {
+ EVT ReducedTypeHalved = (ReducedType.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
+
+ auto Doti32 =
+ DAG.getNode(Opcode, DL, ReducedTypeHalved,
+ DAG.getConstant(0, DL, ReducedTypeHalved), A, B);
+ auto Extended = DAG.getSExtOrTrunc(Doti32, DL, ReducedType);
+ return DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(),
+ {NarrowOp, Extended});
+ }
+
return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
}
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 841da1f8ea57c1..c1b9a4c9dbb797 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -211,6 +211,162 @@ define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
ret <2 x i32> %partial.reduce
}
+define <4 x i64> @udot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
+; CHECK-DOT-LABEL: udot_8to64:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
+; CHECK-DOT-NEXT: udot v4.4s, v2.16b, v3.16b
+; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: udot_8to64:
+; CHECK-NODOT: // %bb.0: // %entry
+; CHECK-NODOT-NEXT: umull v4.8h, v2.8b, v3.8b
+; CHECK-NODOT-NEXT: umull2 v2.8h, v2.16b, v3.16b
+; CHECK-NODOT-NEXT: ushll v3.4s, v4.4h, #0
+; CHECK-NODOT-NEXT: ushll v5.4s, v2.4h, #0
+; CHECK-NODOT-NEXT: ushll2 v4.4s, v4.8h, #0
+; CHECK-NODOT-NEXT: ushll2 v2.4s, v2.8h, #0
+; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v3.4s
+; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v3.2s
+; CHECK-NODOT-NEXT: uaddl2 v3.2d, v4.4s, v5.4s
+; CHECK-NODOT-NEXT: uaddl v4.2d, v4.2s, v5.2s
+; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v2.2s
+; CHECK-NODOT-NEXT: add v1.2d, v3.2d, v1.2d
+; CHECK-NODOT-NEXT: add v0.2d, v4.2d, v0.2d
+; CHECK-NODOT-NEXT: ret
+entry:
+ %a.wide = zext <16 x i8> %a to <16 x i64>
+ %b.wide = zext <16 x i8> %b to <16 x i64>
+ %mult = mul nuw nsw <16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(
+ <4 x i64> %acc, <16 x i64> %mult)
+ ret <4 x i64> %partial.reduce
+}
+
+define <4 x i64> @sdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
+; CHECK-DOT-LABEL: sdot_8to64:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
+; CHECK-DOT-NEXT: sdot v4.4s, v2.16b, v3.16b
+; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: sdot_8to64:
+; CHECK-NODOT: // %bb.0: // %entry
+; CHECK-NODOT-NEXT: smull v4.8h, v2.8b, v3.8b
+; CHECK-NODOT-NEXT: smull2 v2.8h, v2.16b, v3.16b
+; CHECK-NODOT-NEXT: sshll v3.4s, v4.4h, #0
+; CHECK-NODOT-NEXT: sshll v5.4s, v2.4h, #0
+; CHECK-NODOT-NEXT: sshll2 v4.4s, v4.8h, #0
+; CHECK-NODOT-NEXT: sshll2 v2.4s, v2.8h, #0
+; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v3.4s
+; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v3.2s
+; CHECK-NODOT-NEXT: saddl2 v3.2d, v4.4s, v5.4s
+; CHECK-NODOT-NEXT: saddl v4.2d, v4.2s, v5.2s
+; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v2.2s
+; CHECK-NODOT-NEXT: add v1.2d, v3.2d, v1.2d
+; CHECK-NODOT-NEXT: add v0.2d, v4.2d, v0.2d
+; CHECK-NODOT-NEXT: ret
+entry:
+ %a.wide = sext <16 x i8> %a to <16 x i64>
+ %b.wide = sext <16 x i8> %b to <16 x i64>
+ %mult = mul nuw nsw <16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(
+ <4 x i64> %acc, <16 x i64> %mult)
+ ret <4 x i64> %partial.reduce
+}
+
+define <4 x i64> @usdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
+; CHECK-NOI8MM-LABEL: usdot_8to64:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: ushll v4.8h, v2.8b, #0
+; CHECK-NOI8MM-NEXT: sshll v5.8h, v3.8b, #0
+; CHECK-NOI8MM-NEXT: ushll2 v2.8h, v2.16b, #0
+; CHECK-NOI8MM-NEXT: sshll2 v3.8h, v3.16b, #0
+; CHECK-NOI8MM-NEXT: ushll v6.4s, v4.4h, #0
+; CHECK-NOI8MM-NEXT: sshll v7.4s, v5.4h, #0
+; CHECK-NOI8MM-NEXT: ushll2 v4.4s, v4.8h, #0
+; CHECK-NOI8MM-NEXT: sshll2 v5.4s, v5.8h, #0
+; CHECK-NOI8MM-NEXT: ushll2 v16.4s, v2.8h, #0
+; CHECK-NOI8MM-NEXT: sshll2 v17.4s, v3.8h, #0
+; CHECK-NOI8MM-NEXT: ushll v2.4s, v2.4h, #0
+; CHECK-NOI8MM-NEXT: sshll v3.4s, v3.4h, #0
+; CHECK-NOI8MM-NEXT: smlal2 v1.2d, v6.4s, v7.4s
+; CHECK-NOI8MM-NEXT: smlal v0.2d, v6.2s, v7.2s
+; CHECK-NOI8MM-NEXT: smull v18.2d, v4.2s, v5.2s
+; CHECK-NOI8MM-NEXT: smull2 v4.2d, v4.4s, v5.4s
+; CHECK-NOI8MM-NEXT: smlal2 v1.2d, v16.4s, v17.4s
+; CHECK-NOI8MM-NEXT: smlal v0.2d, v16.2s, v17.2s
+; CHECK-NOI8MM-NEXT: smlal2 v4.2d, v2.4s, v3.4s
+; CHECK-NOI8MM-NEXT: smlal v18.2d, v2.2s, v3.2s
+; CHECK-NOI8MM-NEXT: add v1.2d, v4.2d, v1.2d
+; CHECK-NOI8MM-NEXT: add v0.2d, v18.2d, v0.2d
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-I8MM-LABEL: usdot_8to64:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: movi v4.2d, #0000000000000000
+; CHECK-I8MM-NEXT: usdot v4.4s, v2.16b, v3.16b
+; CHECK-I8MM-NEXT: saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-I8MM-NEXT: saddw v0.2d, v0.2d, v4.2s
+; CHECK-I8MM-NEXT: ret
+entry:
+ %a.wide = zext <16 x i8> %a to <16 x i64>
+ %b.wide = sext <16 x i8> %b to <16 x i64>
+ %mult = mul nuw nsw <16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(
+ <4 x i64> %acc, <16 x i64> %mult)
+ ret <4 x i64> %partial.reduce
+}
+
+define <4 x i64> @sudot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
+; CHECK-NOI8MM-LABEL: sudot_8to64:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: sshll v4.8h, v2.8b, #0
+; CHECK-NOI8MM-NEXT: ushll v5.8h, v3.8b, #0
+; CHECK-NOI8MM-NEXT: sshll2 v2.8h, v2.16b, #0
+; CHECK-NOI8MM-NEXT: ushll2 v3.8h, v3.16b, #0
+; CHECK-NOI8MM-NEXT: sshll v6.4s, v4.4h, #0
+; CHECK-NOI8MM-NEXT: ushll v7.4s, v5.4h, #0
+; CHECK-NOI8MM-NEXT: sshll2 v4.4s, v4.8h, #0
+; CHECK-NOI8MM-NEXT: ushll2 v5.4s, v5.8h, #0
+; CHECK-NOI8MM-NEXT: sshll2 v16.4s, v2.8h, #0
+; CHECK-NOI8MM-NEXT: ushll2 v17.4s, v3.8h, #0
+; CHECK-NOI8MM-NEXT: sshll v2.4s, v2.4h, #0
+; CHECK-NOI8MM-NEXT: ushll v3.4s, v3.4h, #0
+; CHECK-NOI8MM-NEXT: smlal2 v1.2d, v6.4s, v7.4s
+; CHECK-NOI8MM-NEXT: smlal v0.2d, v6.2s, v7.2s
+; CHECK-NOI8MM-NEXT: smull v18.2d, v4.2s, v5.2s
+; CHECK-NOI8MM-NEXT: smull2 v4.2d, v4.4s, v5.4s
+; CHECK-NOI8MM-NEXT: smlal2 v1.2d, v16.4s, v17.4s
+; CHECK-NOI8MM-NEXT: smlal v0.2d, v16.2s, v17.2s
+; CHECK-NOI8MM-NEXT: smlal2 v4.2d, v2.4s, v3.4s
+; CHECK-NOI8MM-NEXT: smlal v18.2d, v2.2s, v3.2s
+; CHECK-NOI8MM-NEXT: add v1.2d, v4.2d, v1.2d
+; CHECK-NOI8MM-NEXT: add v0.2d, v18.2d, v0.2d
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-I8MM-LABEL: sudot_8to64:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: movi v4.2d, #0000000000000000
+; CHECK-I8MM-NEXT: usdot v4.4s, v3.16b, v2.16b
+; CHECK-I8MM-NEXT: saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-I8MM-NEXT: saddw v0.2d, v0.2d, v4.2s
+; CHECK-I8MM-NEXT: ret
+entry:
+ %a.wide = sext <16 x i8> %a to <16 x i64>
+ %b.wide = zext <16 x i8> %b to <16 x i64>
+ %mult = mul nuw nsw <16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(
+ <4 x i64> %acc, <16 x i64> %mult)
+ ret <4 x i64> %partial.reduce
+}
+
define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-LABEL: not_udot:
; CHECK: // %bb.0:
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 00e5ac479d02c9..66d6e0388bbf94 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -126,6 +126,196 @@ entry:
ret <vscale x 4 x i32> %partial.reduce
}
+define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: udot_8to64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEXT: udot z4.s, z2.b, z3.b
+; CHECK-NEXT: sunpklo z2.d, z4.s
+; CHECK-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEXT: add z1.d, z1.d, z3.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+ %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(
+ <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b){
+; CHECK-LABEL: sdot_8to64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEXT: sdot z4.s, z2.b, z3.b
+; CHECK-NEXT: sunpklo z2.d, z4.s
+; CHECK-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEXT: add z1.d, z1.d, z3.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+ %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(
+ <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @usdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b){
+; CHECK-I8MM-LABEL: usdot_8to64:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: mov z4.s, #0 // =0x0
+; CHECK-I8MM-NEXT: usdot z4.s, z2.b, z3.b
+; CHECK-I8MM-NEXT: sunpklo z2.d, z4.s
+; CHECK-I8MM-NEXT: sunpkhi z3.d, z4.s
+; CHECK-I8MM-NEXT: add z0.d, z0.d, z2.d
+; CHECK-I8MM-NEXT: add z1.d, z1.d, z3.d
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: usdot_8to64:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NOI8MM-NEXT: addvl sp, sp, #-2
+; CHECK-NOI8MM-NEXT: str z9, [sp] // 16-byte Folded Spill
+; CHECK-NOI8MM-NEXT: str z8, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NOI8MM-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x10, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 16 * VG
+; CHECK-NOI8MM-NEXT: .cfi_offset w29, -16
+; CHECK-NOI8MM-NEXT: .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
+; CHECK-NOI8MM-NEXT: .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
+; CHECK-NOI8MM-NEXT: uunpklo z4.h, z2.b
+; CHECK-NOI8MM-NEXT: sunpklo z5.h, z3.b
+; CHECK-NOI8MM-NEXT: uunpkhi z2.h, z2.b
+; CHECK-NOI8MM-NEXT: sunpkhi z3.h, z3.b
+; CHECK-NOI8MM-NEXT: ptrue p0.d
+; CHECK-NOI8MM-NEXT: uunpklo z6.s, z4.h
+; CHECK-NOI8MM-NEXT: uunpkhi z4.s, z4.h
+; CHECK-NOI8MM-NEXT: sunpklo z7.s, z5.h
+; CHECK-NOI8MM-NEXT: sunpkhi z5.s, z5.h
+; CHECK-NOI8MM-NEXT: uunpklo z24.s, z2.h
+; CHECK-NOI8MM-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NOI8MM-NEXT: sunpklo z25.s, z3.h
+; CHECK-NOI8MM-NEXT: sunpkhi z3.s, z3.h
+; CHECK-NOI8MM-NEXT: uunpkhi z26.d, z6.s
+; CHECK-NOI8MM-NEXT: uunpklo z6.d, z6.s
+; CHECK-NOI8MM-NEXT: uunpklo z27.d, z4.s
+; CHECK-NOI8MM-NEXT: sunpklo z28.d, z7.s
+; CHECK-NOI8MM-NEXT: sunpklo z29.d, z5.s
+; CHECK-NOI8MM-NEXT: uunpkhi z4.d, z4.s
+; CHECK-NOI8MM-NEXT: sunpkhi z7.d, z7.s
+; CHECK-NOI8MM-NEXT: sunpkhi z5.d, z5.s
+; CHECK-NOI8MM-NEXT: uunpkhi z30.d, z24.s
+; CHECK-NOI8MM-NEXT: uunpkhi z31.d, z2.s
+; CHECK-NOI8MM-NEXT: uunpklo z24.d, z24.s
+; CHECK-NOI8MM-NEXT: uunpklo z2.d, z2.s
+; CHECK-NOI8MM-NEXT: sunpkhi z8.d, z25.s
+; CHECK-NOI8MM-NEXT: sunpklo z25.d, z25.s
+; CHECK-NOI8MM-NEXT: sunpklo z9.d, z3.s
+; CHECK-NOI8MM-NEXT: mul z27.d, z27.d, z29.d
+; CHECK-NOI8MM-NEXT: mla z0.d, p0/m, z6.d, z28.d
+; CHECK-NOI8MM-NEXT: sunpkhi z3.d, z3.s
+; CHECK-NOI8MM-NEXT: mul z4.d, z4.d, z5.d
+; CHECK-NOI8MM-NEXT: mla z1.d, p0/m, z26.d, z7.d
+; CHECK-NOI8MM-NEXT: mla z0.d, p0/m, z2.d, z9.d
+; CHECK-NOI8MM-NEXT: movprfx z2, z27
+; CHECK-NOI8MM-NEXT: mla z2.d, p0/m, z24.d, z25.d
+; CHECK-NOI8MM-NEXT: ldr z9, [sp] // 16-byte Folded Reload
+; CHECK-NOI8MM-NEXT: mla z1.d, p0/m, z31.d, z3.d
+; CHECK-NOI8MM-NEXT: movprfx z3, z4
+; CHECK-NOI8MM-NEXT: mla z3.d, p0/m, z30.d, z8.d
+; CHECK-NOI8MM-NEXT: ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NOI8MM-NEXT: add z0.d, z2.d, z0.d
+; CHECK-NOI8MM-NEXT: add z1.d, z3.d, z1.d
+; CHECK-NOI8MM-NEXT: addvl sp, sp, #2
+; CHECK-NOI8MM-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NOI8MM-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+ %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(
+ <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @sudot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-I8MM-LABEL: sudot_8to64:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: mov z4.s, #0 // =0x0
+; CHECK-I8MM-NEXT: usdot z4.s, z3.b, z2.b
+; CHECK-I8MM-NEXT: sunpklo z2.d, z4.s
+; CHECK-I8MM-NEXT: sunpkhi z3.d, z4.s
+; CHECK-I8MM-NEXT: add z0.d, z0.d, z2.d
+; CHECK-I8MM-NEXT: add z1.d, z1.d, z3.d
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: sudot_8to64:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NOI8MM-NEXT: addvl sp, sp, #-2
+; CHECK-NOI8MM-NEXT: str z9, [sp] // 16-byte Folded Spill
+; CHECK-NOI8MM-NEXT: str z8, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NOI8MM-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x10, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 16 * VG
+; CHECK-NOI8MM-NEXT: .cfi_offset w29, -16
+; CHECK-NOI8MM-NEXT: .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
+; CHECK-NOI8MM-NEXT: .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
+; CHECK-NOI8MM-NEXT: sunpklo z4.h, z2.b
+; CHECK-NOI8MM-NEXT: uunpklo z5.h, z3.b
+; CHECK-NOI8MM-NEXT: sunpkhi z2.h, z2.b
+; CHECK-NOI8MM-NEXT: uunpkhi z3.h, z3.b
+; CHECK-NOI8MM-NEXT: ptrue p0.d
+; CHECK-NOI8MM-NEXT: sunpklo z6.s, z4.h
+; CHECK-NOI8MM-NEXT: sunpkhi z4.s, z4.h
+; CHECK-NOI8MM-NEXT: uunpklo z7.s, z5.h
+; CHECK-NOI8MM-NEXT: uunpkhi z5.s, z5.h
+; CHECK-NOI8MM-NEXT: sunpklo z24.s, z2.h
+; CHECK-NOI8MM-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NOI8MM-NEXT: uunpklo z25.s, z3.h
+; CHECK-NOI8MM-NEXT: uunpkhi z3.s, z3.h
+; CHECK-NOI8MM-NEXT: sunpkhi z26.d, z6.s
+; CHECK-NOI8MM-NEXT: sunpklo z6.d, z6.s
+; CHECK-NOI8MM-NEXT: sunpklo z27.d, z4.s
+; CHECK-NOI8MM-NEXT: uunpklo z28.d, z7.s
+; CHECK-NOI8MM-NEXT: uunpklo z29.d, z5.s
+; CHECK-NOI8MM-NEXT: sunpkhi z4.d, z4.s
+; CHECK-NOI8MM-NEXT: uunpkhi z7.d, z7.s
+; CHECK-NOI8MM-NEXT: uunpkhi z5.d, z5.s
+; CHECK-NOI8MM-NEXT: sunpkhi z30.d, z24.s
+; CHECK-NOI8MM-NEXT: sunpkhi z31.d, z2.s
+; CHECK-NOI8MM-NEXT: sunpklo z24.d, z24.s
+; CHECK-NOI8MM-NEXT: sunpklo z2.d, z2.s
+; CHECK-NOI8MM-NEXT: uunpkhi z8.d, z25.s
+; CHECK-NOI8MM-NEXT: uunpklo z25.d, z25.s
+; CHECK-NOI8MM-NEXT: uunpklo z9.d, z3.s
+; CHECK-NOI8MM-NEXT: mul z27.d, z27.d, z29.d
+; CHECK-NOI8MM-NEXT: mla z0.d, p0/m, z6.d, z28.d
+; CHECK-NOI8MM-NEXT: uunpkhi z3.d, z3.s
+; CHECK-NOI8MM-NEXT: mul z4.d, z4.d, z5.d
+; CHECK-NOI8MM-NEXT: mla z1.d, p0/m, z26.d, z7.d
+; CHECK-NOI8MM-NEXT: mla z0.d, p0/m, z2.d, z9.d
+; CHECK-NOI8MM-NEXT: movprfx z2, z27
+; CHECK-NOI8MM-NEXT: mla z2.d, p0/m, z24.d, z25.d
+; CHECK-NOI8MM-NEXT: ldr z9, [sp] // 16-byte Folded Reload
+; CHECK-NOI8MM-NEXT: mla z1.d, p0/m, z31.d, z3.d
+; CHECK-NOI8MM-NEXT: movprfx z3, z4
+; CHECK-NOI8MM-NEXT: mla z3.d, p0/m, z30.d, z8.d
+; CHECK-NOI8MM-NEXT: ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NOI8MM-NEXT: add z0.d, z2.d, z0.d
+; CHECK-NOI8MM-NEXT: add z1.d, z3.d, z1.d
+; CHECK-NOI8MM-NEXT: addvl sp, sp, #2
+; CHECK-NOI8MM-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NOI8MM-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+ %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(
+ <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
; CHECK-LABEL: not_udot:
; CHECK: // %bb.0: // %entry
✅ With the latest revision this PR passed the C/C++ code formatter.
Looks good to me!
LGTM, thanks for the fixup 👍
P.S. You may want to rebase (since I think CI is failing due to unrelated changes).
@JamesChesterman Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR. Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues. How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!
[AArch64][NEON][SVE] Lower i8 to i64 partial reduction to a dot product (llvm#110220): An i8 to i64 partial reduction can instead be done with an i8 to i32 dot product followed by a sign extension.