-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[AArch64] Add custom lowering of nxv32i1 get.active.lane.mask nodes #141969
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
performActiveLaneMaskCombine already tries to combine a single get.active.lane.mask where the low and high halves of the result are extracted into a single whilelo which operates on a predicate pair. If the get.active.lane.mask node requires splitting, multiple nodes are created with saturating adds to increment the starting index. We cannot combine these into a single whilelo_x2 at this point unless we know the add will not overflow. This patch adds custom lowering for the node if the return type is nxv32xi1, as this can be replaced with a whilelo_x2 using legal types. Anything wider than nxv32i1 will still require splitting first.
@llvm/pr-subscribers-backend-aarch64 Author: Kerry McLaughlin (kmclaughlin-arm) ChangesperformActiveLaneMaskCombine already tries to combine a single If the get.active.lane.mask node requires splitting, multiple nodes are This patch adds custom lowering for the node if the return type is Full diff: https://github.com/llvm/llvm-project/pull/141969.diff 3 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f0a703be35207..4eb49b9fe025e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1501,6 +1501,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Legal);
}
+ setOperationAction(ISD::GET_ACTIVE_LANE_MASK, MVT::nxv32i1, Custom);
+
for (auto VT : {MVT::v16i8, MVT::v8i8, MVT::v4i16, MVT::v2i32})
setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Custom);
}
@@ -27328,6 +27330,29 @@ void AArch64TargetLowering::ReplaceExtractSubVectorResults(
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, Half));
}
+void AArch64TargetLowering::ReplaceGetActiveLaneMaskResults(
+ SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
+ if (!Subtarget->hasSVE2p1())
+ return;
+
+ SDLoc DL(N);
+ SDValue Idx = N->getOperand(0);
+ SDValue TC = N->getOperand(1);
+ if (Idx.getValueType() != MVT::i64) {
+ Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
+ TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+ }
+
+ SDValue ID =
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+ EVT HalfVT = N->getValueType(0).getHalfNumVectorElementsVT(*DAG.getContext());
+ auto WideMask =
+ DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {HalfVT, HalfVT}, {ID, Idx, TC});
+
+ Results.push_back(DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0),
+ {WideMask.getValue(0), WideMask.getValue(1)}));
+}
+
// Create an even/odd pair of X registers holding integer value V.
static SDValue createGPRPairNode(SelectionDAG &DAG, SDValue V) {
SDLoc dl(V.getNode());
@@ -27714,6 +27739,9 @@ void AArch64TargetLowering::ReplaceNodeResults(
// CONCAT_VECTORS -- but delegate to common code for result type
// legalisation
return;
+ case ISD::GET_ACTIVE_LANE_MASK:
+ ReplaceGetActiveLaneMaskResults(N, Results, DAG);
+ return;
case ISD::INTRINSIC_WO_CHAIN: {
EVT VT = N->getValueType(0);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index b59526bf01888..4c6358034af02 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1318,6 +1318,9 @@ class AArch64TargetLowering : public TargetLowering {
void ReplaceExtractSubVectorResults(SDNode *N,
SmallVectorImpl<SDValue> &Results,
SelectionDAG &DAG) const;
+ void ReplaceGetActiveLaneMaskResults(SDNode *N,
+ SmallVectorImpl<SDValue> &Results,
+ SelectionDAG &DAG) const;
bool shouldNormalizeToSelectSequence(LLVMContext &, EVT) const override;
diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
index 2d84a69f3144e..0b78dd963cbb0 100644
--- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -111,7 +111,7 @@ define void @test_partial_extract(i64 %i, i64 %n) #0 {
ret void
}
-;; Negative test for when extracting a fixed-length vector.
+; Negative test for when extracting a fixed-length vector.
define void @test_fixed_extract(i64 %i, i64 %n) #0 {
; CHECK-SVE-LABEL: test_fixed_extract:
; CHECK-SVE: // %bb.0:
@@ -151,6 +151,155 @@ define void @test_fixed_extract(i64 %i, i64 %n) #0 {
ret void
}
+; Illegal Types
+
+define void @test_2x16bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0 {
+; CHECK-SVE-LABEL: test_2x16bit_mask_with_32bit_index_and_trip_count:
+; CHECK-SVE: // %bb.0:
+; CHECK-SVE-NEXT: rdvl x8, #1
+; CHECK-SVE-NEXT: adds w8, w0, w8
+; CHECK-SVE-NEXT: csinv w8, w8, wzr, lo
+; CHECK-SVE-NEXT: whilelo p0.b, w0, w1
+; CHECK-SVE-NEXT: whilelo p1.b, w8, w1
+; CHECK-SVE-NEXT: b use
+;
+; CHECK-SVE2p1-LABEL: test_2x16bit_mask_with_32bit_index_and_trip_count:
+; CHECK-SVE2p1: // %bb.0:
+; CHECK-SVE2p1-NEXT: mov w8, w1
+; CHECK-SVE2p1-NEXT: mov w9, w0
+; CHECK-SVE2p1-NEXT: whilelo { p0.b, p1.b }, x9, x8
+; CHECK-SVE2p1-NEXT: b use
+ %r = call <vscale x 32 x i1> @llvm.get.active.lane.mask.nxv32i1.i32(i32 %i, i32 %n)
+ %v0 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 0)
+ %v1 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 16)
+ tail call void @use(<vscale x 16 x i1> %v0, <vscale x 16 x i1> %v1)
+ ret void
+}
+
+define void @test_2x32bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0 {
+; CHECK-SVE-LABEL: test_2x32bit_mask_with_32bit_index_and_trip_count:
+; CHECK-SVE: // %bb.0:
+; CHECK-SVE-NEXT: rdvl x8, #2
+; CHECK-SVE-NEXT: rdvl x9, #1
+; CHECK-SVE-NEXT: adds w8, w0, w8
+; CHECK-SVE-NEXT: csinv w8, w8, wzr, lo
+; CHECK-SVE-NEXT: adds w10, w8, w9
+; CHECK-SVE-NEXT: csinv w10, w10, wzr, lo
+; CHECK-SVE-NEXT: whilelo p3.b, w10, w1
+; CHECK-SVE-NEXT: adds w9, w0, w9
+; CHECK-SVE-NEXT: csinv w9, w9, wzr, lo
+; CHECK-SVE-NEXT: whilelo p0.b, w0, w1
+; CHECK-SVE-NEXT: whilelo p1.b, w9, w1
+; CHECK-SVE-NEXT: whilelo p2.b, w8, w1
+; CHECK-SVE-NEXT: b use
+;
+; CHECK-SVE2p1-LABEL: test_2x32bit_mask_with_32bit_index_and_trip_count:
+; CHECK-SVE2p1: // %bb.0:
+; CHECK-SVE2p1-NEXT: rdvl x8, #2
+; CHECK-SVE2p1-NEXT: mov w9, w1
+; CHECK-SVE2p1-NEXT: mov w10, w0
+; CHECK-SVE2p1-NEXT: adds w8, w0, w8
+; CHECK-SVE2p1-NEXT: csinv w8, w8, wzr, lo
+; CHECK-SVE2p1-NEXT: whilelo { p0.b, p1.b }, x10, x9
+; CHECK-SVE2p1-NEXT: whilelo { p2.b, p3.b }, x8, x9
+; CHECK-SVE2p1-NEXT: b use
+ %r = call <vscale x 64 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i32 %i, i32 %n)
+ %v0 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 0)
+ %v1 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 16)
+ %v2 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 32)
+ %v3 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 48)
+ tail call void @use(<vscale x 16 x i1> %v0, <vscale x 16 x i1> %v1, <vscale x 16 x i1> %v2, <vscale x 16 x i1> %v3)
+ ret void
+}
+
+define void @test_2x16bit_mask_with_32bit_index_and_trip_count_ext8(i32 %i, i32 %n) #0 {
+; CHECK-SVE-LABEL: test_2x16bit_mask_with_32bit_index_and_trip_count_ext8:
+; CHECK-SVE: // %bb.0:
+; CHECK-SVE-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-SVE-NEXT: rdvl x8, #1
+; CHECK-SVE-NEXT: adds w8, w0, w8
+; CHECK-SVE-NEXT: csinv w8, w8, wzr, lo
+; CHECK-SVE-NEXT: whilelo p0.b, w0, w1
+; CHECK-SVE-NEXT: whilelo p4.b, w8, w1
+; CHECK-SVE-NEXT: punpklo p1.h, p0.b
+; CHECK-SVE-NEXT: punpkhi p3.h, p0.b
+; CHECK-SVE-NEXT: punpklo p0.h, p1.b
+; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT: punpklo p2.h, p3.b
+; CHECK-SVE-NEXT: punpkhi p3.h, p3.b
+; CHECK-SVE-NEXT: bl use
+; CHECK-SVE-NEXT: punpklo p1.h, p4.b
+; CHECK-SVE-NEXT: punpkhi p3.h, p4.b
+; CHECK-SVE-NEXT: punpklo p0.h, p1.b
+; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT: punpklo p2.h, p3.b
+; CHECK-SVE-NEXT: punpkhi p3.h, p3.b
+; CHECK-SVE-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-SVE-NEXT: b use
+;
+; CHECK-SVE2p1-LABEL: test_2x16bit_mask_with_32bit_index_and_trip_count_ext8:
+; CHECK-SVE2p1: // %bb.0:
+; CHECK-SVE2p1-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-SVE2p1-NEXT: mov w8, w1
+; CHECK-SVE2p1-NEXT: mov w9, w0
+; CHECK-SVE2p1-NEXT: whilelo { p4.b, p5.b }, x9, x8
+; CHECK-SVE2p1-NEXT: punpklo p1.h, p4.b
+; CHECK-SVE2p1-NEXT: punpkhi p3.h, p4.b
+; CHECK-SVE2p1-NEXT: punpklo p0.h, p1.b
+; CHECK-SVE2p1-NEXT: punpkhi p1.h, p1.b
+; CHECK-SVE2p1-NEXT: punpklo p2.h, p3.b
+; CHECK-SVE2p1-NEXT: punpkhi p3.h, p3.b
+; CHECK-SVE2p1-NEXT: bl use
+; CHECK-SVE2p1-NEXT: punpklo p1.h, p5.b
+; CHECK-SVE2p1-NEXT: punpkhi p3.h, p5.b
+; CHECK-SVE2p1-NEXT: punpklo p0.h, p1.b
+; CHECK-SVE2p1-NEXT: punpkhi p1.h, p1.b
+; CHECK-SVE2p1-NEXT: punpklo p2.h, p3.b
+; CHECK-SVE2p1-NEXT: punpkhi p3.h, p3.b
+; CHECK-SVE2p1-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-SVE2p1-NEXT: b use
+ %r = call <vscale x 32 x i1> @llvm.get.active.lane.mask.nxv32i1.i32(i32 %i, i32 %n)
+ %v0 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 0)
+ %v1 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 4)
+ %v2 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 8)
+ %v3 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 12)
+ tail call void @use(<vscale x 4 x i1> %v0, <vscale x 4 x i1> %v1, <vscale x 4 x i1> %v2, <vscale x 4 x i1> %v3)
+ %v4 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 16)
+ %v5 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 20)
+ %v6 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 24)
+ %v7 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 28)
+ tail call void @use(<vscale x 4 x i1> %v4, <vscale x 4 x i1> %v5, <vscale x 4 x i1> %v6, <vscale x 4 x i1> %v7)
+ ret void
+}
+
+; Negative test for when not extracting exactly two halves of the source vector
+define void @test_illegal_type_with_partial_extracts(i32 %i, i32 %n) #0 {
+; CHECK-SVE-LABEL: test_illegal_type_with_partial_extracts:
+; CHECK-SVE: // %bb.0:
+; CHECK-SVE-NEXT: rdvl x8, #1
+; CHECK-SVE-NEXT: adds w8, w0, w8
+; CHECK-SVE-NEXT: csinv w8, w8, wzr, lo
+; CHECK-SVE-NEXT: whilelo p0.b, w0, w1
+; CHECK-SVE-NEXT: whilelo p1.b, w8, w1
+; CHECK-SVE-NEXT: punpkhi p0.h, p0.b
+; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT: b use
+;
+; CHECK-SVE2p1-LABEL: test_illegal_type_with_partial_extracts:
+; CHECK-SVE2p1: // %bb.0:
+; CHECK-SVE2p1-NEXT: mov w8, w1
+; CHECK-SVE2p1-NEXT: mov w9, w0
+; CHECK-SVE2p1-NEXT: whilelo { p2.b, p3.b }, x9, x8
+; CHECK-SVE2p1-NEXT: punpkhi p0.h, p2.b
+; CHECK-SVE2p1-NEXT: punpkhi p1.h, p3.b
+; CHECK-SVE2p1-NEXT: b use
+ %r = call <vscale x 32 x i1> @llvm.get.active.lane.mask.nxv32i1.i32(i32 %i, i32 %n)
+ %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 8)
+ %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 24)
+ tail call void @use(<vscale x 8 x i1> %v0, <vscale x 8 x i1> %v1)
+ ret void
+}
+
declare void @use(...)
attributes #0 = { nounwind }
|
TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64); | ||
} | ||
|
||
SDValue ID = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's look like this code is expecting the result type to <vscale x 32 x i1>. Is it worth adding an assert for this?
SDValue Idx = N->getOperand(0); | ||
SDValue TC = N->getOperand(1); | ||
if (Idx.getValueType() != MVT::i64) { | ||
Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it worth having a test with a i128 index type to test the trunc case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking at this again, I think there should also be an assert for the operand types here. The reason is that shouldExpandGetActiveLaneMask
returns true when the operand types are bigger than i64, so we should never reach this function if Idx/TC are something like i128.
ret void | ||
} | ||
|
||
; Negative test for when not extracting exactly two halves of the source vector |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This actually doesn't look like a negative test, since we do generate the while-pair instruction if SVE2.1 is available, right? Couldn't this test just be named @test_2x16bit_mask_with_32bit_index_and_trip_count_part_ext
…askResults - Rename new test with partial extracts
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with minor nit!
; CHECK-SVE2p1-NEXT: whilelo { p0.b, p1.b }, x10, x9 | ||
; CHECK-SVE2p1-NEXT: whilelo { p2.b, p3.b }, x8, x9 | ||
; CHECK-SVE2p1-NEXT: b use | ||
%r = call <vscale x 64 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i32 %i, i32 %n) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Perhaps change the intrinsic to @llvm.get.active.lane.mask.nxv64i1.i32
?
ret void | ||
} | ||
|
||
define void @test_2x16bit_mask_with_32bit_index_and_trip_count_ext8(i32 %i, i32 %n) #0 { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does test_2x16bit_mask_with_32bit_index_and_trip_count_ext8
and test_2x16bit_mask_with_32bit_index_and_trip_count_part_extracts
test that's not already covered by the previous two tests?
To me they look more like llvm.vector.extract
tests which should already be covered elsewhere?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've removed these two tests. The first was added as there wasn't a test with more than two extracts, but this is already covered by test_2x32bit_mask_with_32bit_index_and_trip_count
and the second is already covered by the existing test_partial_extract
.
@@ -27328,6 +27330,36 @@ void AArch64TargetLowering::ReplaceExtractSubVectorResults( | |||
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, Half)); | |||
} | |||
|
|||
void AArch64TargetLowering::ReplaceGetActiveLaneMaskResults( | |||
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const { | |||
if (!Subtarget->hasSVE2p1()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should protect the setOperationAction()
call and then be an assert here, that way the function is only called when the necessary instructions are available.
Do you mind also extending the PR to cover Subtarget.hasSME2() && Subtarget.isStreaming()
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A final question and it looks like you've missed Dave's "incorrectly named intrinsic" comment, but otherwise this looks good to me.
@@ -1,6 +1,7 @@ | |||
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4 | |||
; RUN: llc -mattr=+sve < %s | FileCheck %s -check-prefix CHECK-SVE | |||
; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s -check-prefix CHECK-SVE2p1 | |||
; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s -check-prefix CHECK-SVE2p1-SME2 -check-prefix CHECK-SVE2p1 | |||
; RUN: llc -mattr=+sve -mattr=+sme2 -force-streaming < %s | FileCheck %s -check-prefix CHECK-SVE2p1-SME2 -check-prefix CHECK-SME2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is -mattr=+sve
required for the SME2 test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It shouldn't be needed for the get_active_lane_mask, but without SVE the tests fail with Don't know how to legalize this scalable vector type
because of the extract_subvectors.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To follow up on this, the assert I mentioned above was only happening when I did not also pass the -force-streaming
flag. Dropping only -mattr=+sve
just results in poor codgen because the get.active.lane.mask is expanded.
performActiveLaneMaskCombine already tries to combine a single
get.active.lane.mask where the low and high halves of the result are
extracted into a single whilelo which operates on a predicate pair.
If the get.active.lane.mask node requires splitting, multiple nodes are
created with saturating adds to increment the starting index. We cannot
combine these into a single whilelo_x2 at this point unless we know
the add will not overflow.
This patch adds custom lowering for the node if the return type is
nxv32xi1, as this can be replaced with a whilelo_x2 using legal types.
Anything wider than nxv32i1 will still require splitting first.