Skip to content

Commit 7a9dfcb

Browse files
committed
[AArch64] Extend usage of XAR instruction for fixed-length operations
Resolves #139229 In #137162, support for `v2i64` was implemented for vector rotate transformation, although types like `v4i32`, `v8i16` and `v16i8` do not have Neon SHA3, we can use SVE operations if sve2-sha3 is available.
1 parent ee91f9b commit 7a9dfcb

File tree

2 files changed

+231
-23
lines changed

2 files changed

+231
-23
lines changed

llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4606,7 +4606,36 @@ bool AArch64DAGToDAGISel::trySelectXAR(SDNode *N) {
46064606
return false;
46074607
}
46084608

4609-
if (!Subtarget->hasSHA3())
4609+
// We have Neon SHA3 XAR operation for v2i64 but for types
4610+
// v4i32, v8i16, v16i8 we can use SVE operations when SVE2-SHA3
4611+
// is available.
4612+
EVT SVT;
4613+
switch (VT.getSimpleVT().SimpleTy) {
4614+
case MVT::v4i32:
4615+
SVT = MVT::nxv4i32;
4616+
break;
4617+
case MVT::v8i16:
4618+
SVT = MVT::nxv8i16;
4619+
break;
4620+
case MVT::v16i8:
4621+
SVT = MVT::nxv16i8;
4622+
break;
4623+
case MVT::v1i64:
4624+
case MVT::v2i32:
4625+
case MVT::v4i16:
4626+
case MVT::v8i8:
4627+
// Widen type to v2i64.
4628+
SVT = MVT::v2i64;
4629+
break;
4630+
default:
4631+
if (VT != MVT::v2i64)
4632+
return false;
4633+
SVT = MVT::v2i64;
4634+
break;
4635+
}
4636+
4637+
if ((!SVT.isScalableVector() && !Subtarget->hasSHA3()) ||
4638+
(SVT.isScalableVector() && !Subtarget->hasSVE2()))
46104639
return false;
46114640

46124641
if (N0->getOpcode() != AArch64ISD::VSHL ||
@@ -4632,41 +4661,68 @@ bool AArch64DAGToDAGISel::trySelectXAR(SDNode *N) {
46324661
SDValue Imm = CurDAG->getTargetConstant(
46334662
ShAmt, DL, N0.getOperand(1).getValueType(), false);
46344663

4635-
if (ShAmt + HsAmt != 64)
4664+
unsigned VTSizeInBits = VT.getScalarSizeInBits();
4665+
if (ShAmt + HsAmt != VTSizeInBits)
46364666
return false;
46374667

46384668
if (!IsXOROperand) {
46394669
SDValue Zero = CurDAG->getTargetConstant(0, DL, MVT::i64);
4640-
SDNode *MOV =
4641-
CurDAG->getMachineNode(AArch64::MOVIv2d_ns, DL, MVT::v2i64, Zero);
4670+
SDNode *MOV = CurDAG->getMachineNode(AArch64::MOVIv2d_ns, DL, SVT, Zero);
46424671
SDValue MOVIV = SDValue(MOV, 0);
4672+
46434673
R1 = N1->getOperand(0);
46444674
R2 = MOVIV;
46454675
}
46464676

4647-
// If the input is a v1i64, widen to a v2i64 to use XAR.
4648-
assert((VT == MVT::v1i64 || VT == MVT::v2i64) && "Unexpected XAR type!");
4649-
if (VT == MVT::v1i64) {
4650-
EVT SVT = MVT::v2i64;
4677+
if (SVT.isScalableVector()) {
4678+
SDValue Undef =
4679+
SDValue(CurDAG->getMachineNode(TargetOpcode::IMPLICIT_DEF, DL, SVT), 0);
4680+
SDValue ZSub = CurDAG->getTargetConstant(AArch64::zsub, DL, MVT::i32);
4681+
4682+
R1 = SDValue(CurDAG->getMachineNode(AArch64::INSERT_SUBREG, DL, SVT, Undef,
4683+
R1, ZSub),
4684+
0);
4685+
R2 = SDValue(CurDAG->getMachineNode(AArch64::INSERT_SUBREG, DL, SVT, Undef,
4686+
R2, ZSub),
4687+
0);
4688+
}
4689+
4690+
if (!SVT.isScalableVector() && SVT != VT) {
46514691
SDValue Undef =
46524692
SDValue(CurDAG->getMachineNode(AArch64::IMPLICIT_DEF, DL, SVT), 0);
46534693
SDValue DSub = CurDAG->getTargetConstant(AArch64::dsub, DL, MVT::i32);
4694+
46544695
R1 = SDValue(CurDAG->getMachineNode(AArch64::INSERT_SUBREG, DL, SVT, Undef,
46554696
R1, DSub),
46564697
0);
4657-
if (R2.getValueType() == MVT::v1i64)
4698+
if (R2.getValueType() != SVT)
46584699
R2 = SDValue(CurDAG->getMachineNode(AArch64::INSERT_SUBREG, DL, SVT,
46594700
Undef, R2, DSub),
46604701
0);
46614702
}
46624703

46634704
SDValue Ops[] = {R1, R2, Imm};
4664-
SDNode *XAR = CurDAG->getMachineNode(AArch64::XAR, DL, MVT::v2i64, Ops);
4705+
SDNode *XAR = nullptr;
46654706

4666-
if (VT == MVT::v1i64) {
4707+
if (SVT.isScalableVector()) {
4708+
if (auto Opc = SelectOpcodeFromVT<SelectTypeKind::Int>(
4709+
SVT, {AArch64::XAR_ZZZI_B, AArch64::XAR_ZZZI_H, AArch64::XAR_ZZZI_S,
4710+
AArch64::XAR_ZZZI_D}))
4711+
XAR = CurDAG->getMachineNode(Opc, DL, VT, Ops);
4712+
} else {
4713+
XAR = CurDAG->getMachineNode(AArch64::XAR, DL, SVT, Ops);
4714+
}
4715+
4716+
assert(XAR && "Unexpected NULL value for XAR instruction in DAG");
4717+
4718+
if (!SVT.isScalableVector() && SVT != VT) {
46674719
SDValue DSub = CurDAG->getTargetConstant(AArch64::dsub, DL, MVT::i32);
46684720
XAR = CurDAG->getMachineNode(AArch64::EXTRACT_SUBREG, DL, VT,
46694721
SDValue(XAR, 0), DSub);
4722+
} else if (SVT.isScalableVector()) {
4723+
SDValue ZSub = CurDAG->getTargetConstant(AArch64::zsub, DL, MVT::i32);
4724+
XAR = CurDAG->getMachineNode(AArch64::EXTRACT_SUBREG, DL, VT,
4725+
SDValue(XAR, 0), ZSub);
46704726
}
46714727
ReplaceNode(N, XAR);
46724728
return true;

llvm/test/CodeGen/AArch64/xar.ll

Lines changed: 164 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
22
; RUN: llc -mtriple=aarch64 -mattr=+sha3 < %s | FileCheck --check-prefix=SHA3 %s
33
; RUN: llc -mtriple=aarch64 -mattr=-sha3 < %s | FileCheck --check-prefix=NOSHA3 %s
4+
; RUN: llc -mtriple=aarch64 -mattr=+sve2 < %s | FileCheck --check-prefix=SVE2SHA3 %s
5+
6+
/* 128-bit vectors */
47

58
define <2 x i64> @xar(<2 x i64> %x, <2 x i64> %y) {
69
; SHA3-LABEL: xar:
@@ -39,14 +42,14 @@ define <1 x i64> @xar_v1i64(<1 x i64> %a, <1 x i64> %b) {
3942
ret <1 x i64> %fshl
4043
}
4144

42-
define <2 x i64> @xar_instead_of_or1(<2 x i64> %r) {
43-
; SHA3-LABEL: xar_instead_of_or1:
45+
define <2 x i64> @xar_instead_of_or_v2i64(<2 x i64> %r) {
46+
; SHA3-LABEL: xar_instead_of_or_v2i64:
4447
; SHA3: // %bb.0: // %entry
4548
; SHA3-NEXT: movi v1.2d, #0000000000000000
4649
; SHA3-NEXT: xar v0.2d, v0.2d, v1.2d, #39
4750
; SHA3-NEXT: ret
4851
;
49-
; NOSHA3-LABEL: xar_instead_of_or1:
52+
; NOSHA3-LABEL: xar_instead_of_or_v2i64:
5053
; NOSHA3: // %bb.0: // %entry
5154
; NOSHA3-NEXT: shl v1.2d, v0.2d, #25
5255
; NOSHA3-NEXT: usra v1.2d, v0.2d, #39
@@ -76,63 +79,212 @@ define <1 x i64> @xar_instead_of_or_v1i64(<1 x i64> %v.val) {
7679
ret <1 x i64> %fshl
7780
}
7881

79-
define <4 x i32> @xar_instead_of_or2(<4 x i32> %r) {
80-
; SHA3-LABEL: xar_instead_of_or2:
82+
define <4 x i32> @xar_instead_of_or_v4i32(<4 x i32> %r) {
83+
; SHA3-LABEL: xar_instead_of_or_v4i32:
8184
; SHA3: // %bb.0: // %entry
8285
; SHA3-NEXT: shl v1.4s, v0.4s, #25
8386
; SHA3-NEXT: usra v1.4s, v0.4s, #7
8487
; SHA3-NEXT: mov v0.16b, v1.16b
8588
; SHA3-NEXT: ret
8689
;
87-
; NOSHA3-LABEL: xar_instead_of_or2:
90+
; NOSHA3-LABEL: xar_instead_of_or_v4i32:
8891
; NOSHA3: // %bb.0: // %entry
8992
; NOSHA3-NEXT: shl v1.4s, v0.4s, #25
9093
; NOSHA3-NEXT: usra v1.4s, v0.4s, #7
9194
; NOSHA3-NEXT: mov v0.16b, v1.16b
9295
; NOSHA3-NEXT: ret
96+
;
97+
; SVE2SHA3-LABEL: xar_instead_of_or_v4i32:
98+
; SVE2SHA3: // %bb.0: // %entry
99+
; SVE2SHA3-NEXT: movi v1.2d, #0000000000000000
100+
; SVE2SHA3-NEXT: // kill: def $q0 killed $q0 def $z0
101+
; SVE2SHA3-NEXT: xar z0.s, z0.s, z1.s, #7
102+
; SVE2SHA3-NEXT: // kill: def $q0 killed $q0 killed $z0
103+
; SVE2SHA3-NEXT: ret
93104
entry:
94105
%or = call <4 x i32> @llvm.fshl.v2i32(<4 x i32> %r, <4 x i32> %r, <4 x i32> splat (i32 25))
95106
ret <4 x i32> %or
96107
}
97108

98-
define <8 x i16> @xar_instead_of_or3(<8 x i16> %r) {
99-
; SHA3-LABEL: xar_instead_of_or3:
109+
define <8 x i16> @xar_instead_of_or_v8i16(<8 x i16> %r) {
110+
; SHA3-LABEL: xar_instead_of_or_v8i16:
100111
; SHA3: // %bb.0: // %entry
101112
; SHA3-NEXT: shl v1.8h, v0.8h, #9
102113
; SHA3-NEXT: usra v1.8h, v0.8h, #7
103114
; SHA3-NEXT: mov v0.16b, v1.16b
104115
; SHA3-NEXT: ret
105116
;
106-
; NOSHA3-LABEL: xar_instead_of_or3:
117+
; NOSHA3-LABEL: xar_instead_of_or_v8i16:
107118
; NOSHA3: // %bb.0: // %entry
108119
; NOSHA3-NEXT: shl v1.8h, v0.8h, #9
109120
; NOSHA3-NEXT: usra v1.8h, v0.8h, #7
110121
; NOSHA3-NEXT: mov v0.16b, v1.16b
111122
; NOSHA3-NEXT: ret
123+
;
124+
; SVE2SHA3-LABEL: xar_instead_of_or_v8i16:
125+
; SVE2SHA3: // %bb.0: // %entry
126+
; SVE2SHA3-NEXT: movi v1.2d, #0000000000000000
127+
; SVE2SHA3-NEXT: // kill: def $q0 killed $q0 def $z0
128+
; SVE2SHA3-NEXT: xar z0.h, z0.h, z1.h, #7
129+
; SVE2SHA3-NEXT: // kill: def $q0 killed $q0 killed $z0
130+
; SVE2SHA3-NEXT: ret
112131
entry:
113132
%or = call <8 x i16> @llvm.fshl.v2i16(<8 x i16> %r, <8 x i16> %r, <8 x i16> splat (i16 25))
114133
ret <8 x i16> %or
115134
}
116135

117-
define <16 x i8> @xar_instead_of_or4(<16 x i8> %r) {
118-
; SHA3-LABEL: xar_instead_of_or4:
136+
define <16 x i8> @xar_instead_of_or_v16i8(<16 x i8> %r) {
137+
; SHA3-LABEL: xar_instead_of_or_v16i8:
119138
; SHA3: // %bb.0: // %entry
120139
; SHA3-NEXT: add v1.16b, v0.16b, v0.16b
121140
; SHA3-NEXT: usra v1.16b, v0.16b, #7
122141
; SHA3-NEXT: mov v0.16b, v1.16b
123142
; SHA3-NEXT: ret
124143
;
125-
; NOSHA3-LABEL: xar_instead_of_or4:
144+
; NOSHA3-LABEL: xar_instead_of_or_v16i8:
126145
; NOSHA3: // %bb.0: // %entry
127146
; NOSHA3-NEXT: add v1.16b, v0.16b, v0.16b
128147
; NOSHA3-NEXT: usra v1.16b, v0.16b, #7
129148
; NOSHA3-NEXT: mov v0.16b, v1.16b
130149
; NOSHA3-NEXT: ret
150+
;
151+
; SVE2SHA3-LABEL: xar_instead_of_or_v16i8:
152+
; SVE2SHA3: // %bb.0: // %entry
153+
; SVE2SHA3-NEXT: movi v1.2d, #0000000000000000
154+
; SVE2SHA3-NEXT: // kill: def $q0 killed $q0 def $z0
155+
; SVE2SHA3-NEXT: xar z0.b, z0.b, z1.b, #7
156+
; SVE2SHA3-NEXT: // kill: def $q0 killed $q0 killed $z0
157+
; SVE2SHA3-NEXT: ret
131158
entry:
132159
%or = call <16 x i8> @llvm.fshl.v2i8(<16 x i8> %r, <16 x i8> %r, <16 x i8> splat (i8 25))
133160
ret <16 x i8> %or
134161
}
135162

163+
/* 64 bit vectors */
164+
165+
define <2 x i32> @xar_v2i32(<2 x i32> %x, <2 x i32> %y) {
166+
; SHA3-LABEL: xar_v2i32:
167+
; SHA3: // %bb.0: // %entry
168+
; SHA3-NEXT: // kill: def $d0 killed $d0 def $q0
169+
; SHA3-NEXT: // kill: def $d1 killed $d1 def $q1
170+
; SHA3-NEXT: xar v0.2d, v0.2d, v1.2d, #7
171+
; SHA3-NEXT: // kill: def $d0 killed $d0 killed $q0
172+
; SHA3-NEXT: ret
173+
;
174+
; NOSHA3-LABEL: xar_v2i32:
175+
; NOSHA3: // %bb.0: // %entry
176+
; NOSHA3-NEXT: eor v1.8b, v0.8b, v1.8b
177+
; NOSHA3-NEXT: shl v0.2s, v1.2s, #25
178+
; NOSHA3-NEXT: usra v0.2s, v1.2s, #7
179+
; NOSHA3-NEXT: ret
180+
entry:
181+
%a = xor <2 x i32> %x, %y
182+
%b = call <2 x i32> @llvm.fshl(<2 x i32> %a, <2 x i32> %a, <2 x i32> <i32 25, i32 25>)
183+
ret <2 x i32> %b
184+
}
185+
186+
define <2 x i32> @xar_instead_of_or_v2i32(<2 x i32> %r) {
187+
; SHA3-LABEL: xar_instead_of_or_v2i32:
188+
; SHA3: // %bb.0: // %entry
189+
; SHA3-NEXT: movi v1.2d, #0000000000000000
190+
; SHA3-NEXT: // kill: def $d0 killed $d0 def $q0
191+
; SHA3-NEXT: xar v0.2d, v0.2d, v1.2d, #7
192+
; SHA3-NEXT: // kill: def $d0 killed $d0 killed $q0
193+
; SHA3-NEXT: ret
194+
;
195+
; NOSHA3-LABEL: xar_instead_of_or_v2i32:
196+
; NOSHA3: // %bb.0: // %entry
197+
; NOSHA3-NEXT: shl v1.2s, v0.2s, #25
198+
; NOSHA3-NEXT: usra v1.2s, v0.2s, #7
199+
; NOSHA3-NEXT: fmov d0, d1
200+
; NOSHA3-NEXT: ret
201+
entry:
202+
%or = call <2 x i32> @llvm.fshl(<2 x i32> %r, <2 x i32> %r, <2 x i32> splat (i32 25))
203+
ret <2 x i32> %or
204+
}
205+
206+
define <4 x i16> @xar_v4i16(<4 x i16> %x, <4 x i16> %y) {
207+
; SHA3-LABEL: xar_v4i16:
208+
; SHA3: // %bb.0: // %entry
209+
; SHA3-NEXT: // kill: def $d0 killed $d0 def $q0
210+
; SHA3-NEXT: // kill: def $d1 killed $d1 def $q1
211+
; SHA3-NEXT: xar v0.2d, v0.2d, v1.2d, #7
212+
; SHA3-NEXT: // kill: def $d0 killed $d0 killed $q0
213+
; SHA3-NEXT: ret
214+
;
215+
; NOSHA3-LABEL: xar_v4i16:
216+
; NOSHA3: // %bb.0: // %entry
217+
; NOSHA3-NEXT: eor v1.8b, v0.8b, v1.8b
218+
; NOSHA3-NEXT: shl v0.4h, v1.4h, #9
219+
; NOSHA3-NEXT: usra v0.4h, v1.4h, #7
220+
; NOSHA3-NEXT: ret
221+
entry:
222+
%a = xor <4 x i16> %x, %y
223+
%b = call <4 x i16> @llvm.fshl(<4 x i16> %a, <4 x i16> %a, <4 x i16> splat (i16 25))
224+
ret <4 x i16> %b
225+
}
226+
227+
define <4 x i16> @xar_instead_of_or_v4i16(<4 x i16> %r) {
228+
; SHA3-LABEL: xar_instead_of_or_v4i16:
229+
; SHA3: // %bb.0: // %entry
230+
; SHA3-NEXT: movi v1.2d, #0000000000000000
231+
; SHA3-NEXT: // kill: def $d0 killed $d0 def $q0
232+
; SHA3-NEXT: xar v0.2d, v0.2d, v1.2d, #7
233+
; SHA3-NEXT: // kill: def $d0 killed $d0 killed $q0
234+
; SHA3-NEXT: ret
235+
;
236+
; NOSHA3-LABEL: xar_instead_of_or_v4i16:
237+
; NOSHA3: // %bb.0: // %entry
238+
; NOSHA3-NEXT: shl v1.4h, v0.4h, #9
239+
; NOSHA3-NEXT: usra v1.4h, v0.4h, #7
240+
; NOSHA3-NEXT: fmov d0, d1
241+
; NOSHA3-NEXT: ret
242+
entry:
243+
%or = call <4 x i16> @llvm.fshl(<4 x i16> %r, <4 x i16> %r, <4 x i16> splat (i16 25))
244+
ret <4 x i16> %or
245+
}
246+
247+
define <8 x i8> @xar_v8i8(<8 x i8> %x, <8 x i8> %y) {
248+
; SHA3-LABEL: xar_v8i8:
249+
; SHA3: // %bb.0: // %entry
250+
; SHA3-NEXT: // kill: def $d0 killed $d0 def $q0
251+
; SHA3-NEXT: // kill: def $d1 killed $d1 def $q1
252+
; SHA3-NEXT: xar v0.2d, v0.2d, v1.2d, #7
253+
; SHA3-NEXT: // kill: def $d0 killed $d0 killed $q0
254+
; SHA3-NEXT: ret
255+
;
256+
; NOSHA3-LABEL: xar_v8i8:
257+
; NOSHA3: // %bb.0: // %entry
258+
; NOSHA3-NEXT: eor v1.8b, v0.8b, v1.8b
259+
; NOSHA3-NEXT: add v0.8b, v1.8b, v1.8b
260+
; NOSHA3-NEXT: usra v0.8b, v1.8b, #7
261+
; NOSHA3-NEXT: ret
262+
entry:
263+
%a = xor <8 x i8> %x, %y
264+
%b = call <8 x i8> @llvm.fshl(<8 x i8> %a, <8 x i8> %a, <8 x i8> splat (i8 25))
265+
ret <8 x i8> %b
266+
}
267+
268+
define <8 x i8> @xar_instead_of_or_v8i8(<8 x i8> %r) {
269+
; SHA3-LABEL: xar_instead_of_or_v8i8:
270+
; SHA3: // %bb.0: // %entry
271+
; SHA3-NEXT: movi v1.2d, #0000000000000000
272+
; SHA3-NEXT: // kill: def $d0 killed $d0 def $q0
273+
; SHA3-NEXT: xar v0.2d, v0.2d, v1.2d, #7
274+
; SHA3-NEXT: // kill: def $d0 killed $d0 killed $q0
275+
; SHA3-NEXT: ret
276+
;
277+
; NOSHA3-LABEL: xar_instead_of_or_v8i8:
278+
; NOSHA3: // %bb.0: // %entry
279+
; NOSHA3-NEXT: add v1.8b, v0.8b, v0.8b
280+
; NOSHA3-NEXT: usra v1.8b, v0.8b, #7
281+
; NOSHA3-NEXT: fmov d0, d1
282+
; NOSHA3-NEXT: ret
283+
entry:
284+
%or = call <8 x i8> @llvm.fshl(<8 x i8> %r, <8 x i8> %r, <8 x i8> splat (i8 25))
285+
ret <8 x i8> %or
286+
}
287+
136288
declare <2 x i64> @llvm.fshl.v2i64(<2 x i64>, <2 x i64>, <2 x i64>)
137289
declare <4 x i32> @llvm.fshl.v4i32(<4 x i32>, <4 x i32>, <4 x i32>)
138290
declare <8 x i16> @llvm.fshl.v8i16(<8 x i16>, <8 x i16>, <8 x i16>)

0 commit comments

Comments
 (0)