Skip to content

Commit c18fc6f

Browse files
committed
[DAG] Fold nested add(add(reduce(a), b), add(reduce(c), d))
This patch reassociates add(add(vecreduce(a), b), add(vecreduce(c), d)) into add(vecreduce(add(a, c)), add(b, d)), to combine the reductions into a single node. This comes up after unrolling vectorized loops. There is another small change to move reassociateReduction inside fadd outside of a AllowNewConst block, as new constants will not be created and it should be OK to perform the combine later after legalization.
1 parent debfd7b commit c18fc6f

File tree

3 files changed

+144
-208
lines changed

3 files changed

+144
-208
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1329,6 +1329,28 @@ SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc,
13291329
DAG.getNode(Opc, DL, N0.getOperand(0).getValueType(),
13301330
N0.getOperand(0), N1.getOperand(0)));
13311331
}
1332+
1333+
// Reassociate op(op(vecreduce(a), b), op(vecreduce(c), d)) into
1334+
// op(vecreduce(op(a, c)), op(b, d)), to combine the reductions into a
1335+
// single node.
1336+
SDValue A, B, C, D;
1337+
if (sd_match(N0,
1338+
m_OneUse(m_c_BinOp(Opc, m_OneUse(m_UnaryOp(RedOpc, m_Value(A))),
1339+
m_Value(B)))) &&
1340+
sd_match(N1,
1341+
m_OneUse(m_c_BinOp(Opc, m_OneUse(m_UnaryOp(RedOpc, m_Value(C))),
1342+
m_Value(D)))) &&
1343+
!sd_match(B, m_UnaryOp(RedOpc, m_Value())) &&
1344+
!sd_match(D, m_UnaryOp(RedOpc, m_Value())) &&
1345+
A.getValueType() == C.getValueType() &&
1346+
hasOperation(Opc, A.getValueType()) &&
1347+
TLI.shouldReassociateReduction(RedOpc, VT)) {
1348+
SelectionDAG::FlagInserter FlagsInserter(DAG, Flags);
1349+
SDValue Op = DAG.getNode(Opc, DL, A.getValueType(), A, C);
1350+
SDValue Red = DAG.getNode(RedOpc, DL, VT, Op);
1351+
SDValue Op2 = DAG.getNode(Opc, DL, VT, B, D);
1352+
return DAG.getNode(Opc, DL, VT, Red, Op2);
1353+
}
13321354
return SDValue();
13331355
}
13341356

@@ -17107,12 +17129,15 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
1710717129
DAG.getConstantFP(4.0, DL, VT));
1710817130
}
1710917131
}
17132+
} // enable-unsafe-fp-math && AllowNewConst
1711017133

17134+
if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
17135+
(Flags.hasAllowReassociation() && Flags.hasNoSignedZeros()))) {
1711117136
// Fold fadd(vecreduce(x), vecreduce(y)) -> vecreduce(fadd(x, y))
1711217137
if (SDValue SD = reassociateReduction(ISD::VECREDUCE_FADD, ISD::FADD, DL,
1711317138
VT, N0, N1, Flags))
1711417139
return SD;
17115-
} // enable-unsafe-fp-math
17140+
}
1711617141

1711717142
// FADD -> FMA combines:
1711817143
if (SDValue Fused = visitFADDForFMACombine<EmptyMatchContext>(N)) {

llvm/test/CodeGen/AArch64/double_reduct.ll

Lines changed: 52 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -288,13 +288,11 @@ define i32 @smax_i32(<8 x i32> %a, <4 x i32> %b) {
288288
define float @nested_fadd_f32(<4 x float> %a, <4 x float> %b, float %c, float %d) {
289289
; CHECK-LABEL: nested_fadd_f32:
290290
; CHECK: // %bb.0:
291-
; CHECK-NEXT: faddp v1.4s, v1.4s, v1.4s
291+
; CHECK-NEXT: fadd v0.4s, v0.4s, v1.4s
292+
; CHECK-NEXT: fadd s2, s2, s3
292293
; CHECK-NEXT: faddp v0.4s, v0.4s, v0.4s
293-
; CHECK-NEXT: faddp s1, v1.2s
294294
; CHECK-NEXT: faddp s0, v0.2s
295-
; CHECK-NEXT: fadd s1, s1, s3
296295
; CHECK-NEXT: fadd s0, s0, s2
297-
; CHECK-NEXT: fadd s0, s0, s1
298296
; CHECK-NEXT: ret
299297
%r1 = call fast float @llvm.vector.reduce.fadd.f32.v4f32(float -0.0, <4 x float> %a)
300298
%a1 = fadd fast float %r1, %c
@@ -332,15 +330,12 @@ define float @nested_fadd_f32_slow(<4 x float> %a, <4 x float> %b, float %c, flo
332330
define float @nested_mul_f32(<4 x float> %a, <4 x float> %b, float %c, float %d) {
333331
; CHECK-LABEL: nested_mul_f32:
334332
; CHECK: // %bb.0:
335-
; CHECK-NEXT: ext v4.16b, v1.16b, v1.16b, #8
336-
; CHECK-NEXT: ext v5.16b, v0.16b, v0.16b, #8
337-
; CHECK-NEXT: fmul v1.2s, v1.2s, v4.2s
338-
; CHECK-NEXT: fmul v0.2s, v0.2s, v5.2s
339-
; CHECK-NEXT: fmul s1, s1, v1.s[1]
333+
; CHECK-NEXT: fmul v0.4s, v0.4s, v1.4s
334+
; CHECK-NEXT: fmul s2, s2, s3
335+
; CHECK-NEXT: ext v1.16b, v0.16b, v0.16b, #8
336+
; CHECK-NEXT: fmul v0.2s, v0.2s, v1.2s
340337
; CHECK-NEXT: fmul s0, s0, v0.s[1]
341-
; CHECK-NEXT: fmul s1, s1, s3
342338
; CHECK-NEXT: fmul s0, s0, s2
343-
; CHECK-NEXT: fmul s0, s0, s1
344339
; CHECK-NEXT: ret
345340
%r1 = call fast float @llvm.vector.reduce.fmul.f32.v4f32(float 1.0, <4 x float> %a)
346341
%a1 = fmul fast float %r1, %c
@@ -353,12 +348,10 @@ define float @nested_mul_f32(<4 x float> %a, <4 x float> %b, float %c, float %d)
353348
define i32 @nested_add_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
354349
; CHECK-LABEL: nested_add_i32:
355350
; CHECK: // %bb.0:
356-
; CHECK-NEXT: addv s1, v1.4s
351+
; CHECK-NEXT: add v0.4s, v0.4s, v1.4s
352+
; CHECK-NEXT: add w8, w0, w1
357353
; CHECK-NEXT: addv s0, v0.4s
358-
; CHECK-NEXT: fmov w8, s1
359354
; CHECK-NEXT: fmov w9, s0
360-
; CHECK-NEXT: add w9, w9, w0
361-
; CHECK-NEXT: add w8, w8, w1
362355
; CHECK-NEXT: add w0, w9, w8
363356
; CHECK-NEXT: ret
364357
%r1 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a)
@@ -372,12 +365,10 @@ define i32 @nested_add_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
372365
define i32 @nested_add_c1_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
373366
; CHECK-LABEL: nested_add_c1_i32:
374367
; CHECK: // %bb.0:
375-
; CHECK-NEXT: addv s1, v1.4s
368+
; CHECK-NEXT: add v0.4s, v0.4s, v1.4s
369+
; CHECK-NEXT: add w8, w0, w1
376370
; CHECK-NEXT: addv s0, v0.4s
377-
; CHECK-NEXT: fmov w8, s1
378371
; CHECK-NEXT: fmov w9, s0
379-
; CHECK-NEXT: add w9, w0, w9
380-
; CHECK-NEXT: add w8, w8, w1
381372
; CHECK-NEXT: add w0, w9, w8
382373
; CHECK-NEXT: ret
383374
%r1 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a)
@@ -391,12 +382,10 @@ define i32 @nested_add_c1_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
391382
define i32 @nested_add_c2_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
392383
; CHECK-LABEL: nested_add_c2_i32:
393384
; CHECK: // %bb.0:
394-
; CHECK-NEXT: addv s1, v1.4s
385+
; CHECK-NEXT: add v0.4s, v0.4s, v1.4s
386+
; CHECK-NEXT: add w8, w0, w1
395387
; CHECK-NEXT: addv s0, v0.4s
396-
; CHECK-NEXT: fmov w8, s1
397388
; CHECK-NEXT: fmov w9, s0
398-
; CHECK-NEXT: add w9, w9, w0
399-
; CHECK-NEXT: add w8, w1, w8
400389
; CHECK-NEXT: add w0, w9, w8
401390
; CHECK-NEXT: ret
402391
%r1 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a)
@@ -429,19 +418,14 @@ define i32 @nested_add_manyreduct_i32(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c,
429418
define i32 @nested_mul_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
430419
; CHECK-LABEL: nested_mul_i32:
431420
; CHECK: // %bb.0:
432-
; CHECK-NEXT: ext v3.16b, v0.16b, v0.16b, #8
433-
; CHECK-NEXT: ext v2.16b, v1.16b, v1.16b, #8
434-
; CHECK-NEXT: mul v0.2s, v0.2s, v3.2s
435-
; CHECK-NEXT: mul v1.2s, v1.2s, v2.2s
436-
; CHECK-NEXT: mov w8, v0.s[1]
421+
; CHECK-NEXT: mul v0.4s, v0.4s, v1.4s
422+
; CHECK-NEXT: mul w8, w0, w1
423+
; CHECK-NEXT: ext v1.16b, v0.16b, v0.16b, #8
424+
; CHECK-NEXT: mul v0.2s, v0.2s, v1.2s
425+
; CHECK-NEXT: mov w9, v0.s[1]
437426
; CHECK-NEXT: fmov w10, s0
438-
; CHECK-NEXT: mov w9, v1.s[1]
439-
; CHECK-NEXT: mul w8, w10, w8
440-
; CHECK-NEXT: fmov w10, s1
441427
; CHECK-NEXT: mul w9, w10, w9
442-
; CHECK-NEXT: mul w8, w8, w0
443-
; CHECK-NEXT: mul w9, w9, w1
444-
; CHECK-NEXT: mul w0, w8, w9
428+
; CHECK-NEXT: mul w0, w9, w8
445429
; CHECK-NEXT: ret
446430
%r1 = call i32 @llvm.vector.reduce.mul.v4i32(<4 x i32> %a)
447431
%a1 = mul i32 %r1, %c
@@ -454,19 +438,14 @@ define i32 @nested_mul_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
454438
define i32 @nested_and_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
455439
; CHECK-LABEL: nested_and_i32:
456440
; CHECK: // %bb.0:
457-
; CHECK-NEXT: ext v2.16b, v1.16b, v1.16b, #8
458-
; CHECK-NEXT: ext v3.16b, v0.16b, v0.16b, #8
459-
; CHECK-NEXT: and v1.8b, v1.8b, v2.8b
460-
; CHECK-NEXT: and v0.8b, v0.8b, v3.8b
461-
; CHECK-NEXT: fmov x8, d1
441+
; CHECK-NEXT: and v0.16b, v0.16b, v1.16b
442+
; CHECK-NEXT: and w8, w0, w1
443+
; CHECK-NEXT: ext v1.16b, v0.16b, v0.16b, #8
444+
; CHECK-NEXT: and v0.8b, v0.8b, v1.8b
462445
; CHECK-NEXT: fmov x9, d0
463446
; CHECK-NEXT: lsr x10, x9, #32
464-
; CHECK-NEXT: lsr x11, x8, #32
465-
; CHECK-NEXT: and w9, w9, w0
466-
; CHECK-NEXT: and w8, w8, w1
467-
; CHECK-NEXT: and w9, w9, w10
468-
; CHECK-NEXT: and w8, w8, w11
469-
; CHECK-NEXT: and w0, w9, w8
447+
; CHECK-NEXT: and w8, w9, w8
448+
; CHECK-NEXT: and w0, w8, w10
470449
; CHECK-NEXT: ret
471450
%r1 = call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> %a)
472451
%a1 = and i32 %r1, %c
@@ -479,19 +458,14 @@ define i32 @nested_and_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
479458
define i32 @nested_or_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
480459
; CHECK-LABEL: nested_or_i32:
481460
; CHECK: // %bb.0:
482-
; CHECK-NEXT: ext v2.16b, v1.16b, v1.16b, #8
483-
; CHECK-NEXT: ext v3.16b, v0.16b, v0.16b, #8
484-
; CHECK-NEXT: orr v1.8b, v1.8b, v2.8b
485-
; CHECK-NEXT: orr v0.8b, v0.8b, v3.8b
486-
; CHECK-NEXT: fmov x8, d1
461+
; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b
462+
; CHECK-NEXT: orr w8, w0, w1
463+
; CHECK-NEXT: ext v1.16b, v0.16b, v0.16b, #8
464+
; CHECK-NEXT: orr v0.8b, v0.8b, v1.8b
487465
; CHECK-NEXT: fmov x9, d0
488466
; CHECK-NEXT: lsr x10, x9, #32
489-
; CHECK-NEXT: lsr x11, x8, #32
490-
; CHECK-NEXT: orr w9, w9, w0
491-
; CHECK-NEXT: orr w8, w8, w1
492-
; CHECK-NEXT: orr w9, w9, w10
493-
; CHECK-NEXT: orr w8, w8, w11
494-
; CHECK-NEXT: orr w0, w9, w8
467+
; CHECK-NEXT: orr w8, w9, w8
468+
; CHECK-NEXT: orr w0, w8, w10
495469
; CHECK-NEXT: ret
496470
%r1 = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> %a)
497471
%a1 = or i32 %r1, %c
@@ -504,19 +478,14 @@ define i32 @nested_or_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
504478
define i32 @nested_xor_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
505479
; CHECK-LABEL: nested_xor_i32:
506480
; CHECK: // %bb.0:
507-
; CHECK-NEXT: ext v2.16b, v1.16b, v1.16b, #8
508-
; CHECK-NEXT: ext v3.16b, v0.16b, v0.16b, #8
509-
; CHECK-NEXT: eor v1.8b, v1.8b, v2.8b
510-
; CHECK-NEXT: eor v0.8b, v0.8b, v3.8b
511-
; CHECK-NEXT: fmov x8, d1
481+
; CHECK-NEXT: eor v0.16b, v0.16b, v1.16b
482+
; CHECK-NEXT: eor w8, w0, w1
483+
; CHECK-NEXT: ext v1.16b, v0.16b, v0.16b, #8
484+
; CHECK-NEXT: eor v0.8b, v0.8b, v1.8b
512485
; CHECK-NEXT: fmov x9, d0
513486
; CHECK-NEXT: lsr x10, x9, #32
514-
; CHECK-NEXT: lsr x11, x8, #32
515-
; CHECK-NEXT: eor w9, w9, w0
516-
; CHECK-NEXT: eor w8, w8, w1
517-
; CHECK-NEXT: eor w9, w9, w10
518-
; CHECK-NEXT: eor w8, w8, w11
519-
; CHECK-NEXT: eor w0, w9, w8
487+
; CHECK-NEXT: eor w8, w9, w8
488+
; CHECK-NEXT: eor w0, w8, w10
520489
; CHECK-NEXT: ret
521490
%r1 = call i32 @llvm.vector.reduce.xor.v4i32(<4 x i32> %a)
522491
%a1 = xor i32 %r1, %c
@@ -529,14 +498,11 @@ define i32 @nested_xor_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
529498
define i32 @nested_smin_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
530499
; CHECK-LABEL: nested_smin_i32:
531500
; CHECK: // %bb.0:
501+
; CHECK-NEXT: smin v0.4s, v0.4s, v1.4s
502+
; CHECK-NEXT: cmp w0, w1
503+
; CHECK-NEXT: csel w8, w0, w1, lt
532504
; CHECK-NEXT: sminv s0, v0.4s
533-
; CHECK-NEXT: sminv s1, v1.4s
534505
; CHECK-NEXT: fmov w9, s0
535-
; CHECK-NEXT: fmov w8, s1
536-
; CHECK-NEXT: cmp w9, w0
537-
; CHECK-NEXT: csel w9, w9, w0, lt
538-
; CHECK-NEXT: cmp w8, w1
539-
; CHECK-NEXT: csel w8, w8, w1, lt
540506
; CHECK-NEXT: cmp w9, w8
541507
; CHECK-NEXT: csel w0, w9, w8, lt
542508
; CHECK-NEXT: ret
@@ -551,14 +517,11 @@ define i32 @nested_smin_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
551517
define i32 @nested_smax_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
552518
; CHECK-LABEL: nested_smax_i32:
553519
; CHECK: // %bb.0:
520+
; CHECK-NEXT: smax v0.4s, v0.4s, v1.4s
521+
; CHECK-NEXT: cmp w0, w1
522+
; CHECK-NEXT: csel w8, w0, w1, gt
554523
; CHECK-NEXT: smaxv s0, v0.4s
555-
; CHECK-NEXT: smaxv s1, v1.4s
556524
; CHECK-NEXT: fmov w9, s0
557-
; CHECK-NEXT: fmov w8, s1
558-
; CHECK-NEXT: cmp w9, w0
559-
; CHECK-NEXT: csel w9, w9, w0, gt
560-
; CHECK-NEXT: cmp w8, w1
561-
; CHECK-NEXT: csel w8, w8, w1, gt
562525
; CHECK-NEXT: cmp w9, w8
563526
; CHECK-NEXT: csel w0, w9, w8, gt
564527
; CHECK-NEXT: ret
@@ -573,14 +536,11 @@ define i32 @nested_smax_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
573536
define i32 @nested_umin_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
574537
; CHECK-LABEL: nested_umin_i32:
575538
; CHECK: // %bb.0:
539+
; CHECK-NEXT: umin v0.4s, v0.4s, v1.4s
540+
; CHECK-NEXT: cmp w0, w1
541+
; CHECK-NEXT: csel w8, w0, w1, lo
576542
; CHECK-NEXT: uminv s0, v0.4s
577-
; CHECK-NEXT: uminv s1, v1.4s
578543
; CHECK-NEXT: fmov w9, s0
579-
; CHECK-NEXT: fmov w8, s1
580-
; CHECK-NEXT: cmp w9, w0
581-
; CHECK-NEXT: csel w9, w9, w0, lo
582-
; CHECK-NEXT: cmp w8, w1
583-
; CHECK-NEXT: csel w8, w8, w1, lo
584544
; CHECK-NEXT: cmp w9, w8
585545
; CHECK-NEXT: csel w0, w9, w8, lo
586546
; CHECK-NEXT: ret
@@ -595,14 +555,11 @@ define i32 @nested_umin_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
595555
define i32 @nested_umax_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
596556
; CHECK-LABEL: nested_umax_i32:
597557
; CHECK: // %bb.0:
558+
; CHECK-NEXT: umax v0.4s, v0.4s, v1.4s
559+
; CHECK-NEXT: cmp w0, w1
560+
; CHECK-NEXT: csel w8, w0, w1, hi
598561
; CHECK-NEXT: umaxv s0, v0.4s
599-
; CHECK-NEXT: umaxv s1, v1.4s
600562
; CHECK-NEXT: fmov w9, s0
601-
; CHECK-NEXT: fmov w8, s1
602-
; CHECK-NEXT: cmp w9, w0
603-
; CHECK-NEXT: csel w9, w9, w0, hi
604-
; CHECK-NEXT: cmp w8, w1
605-
; CHECK-NEXT: csel w8, w8, w1, hi
606563
; CHECK-NEXT: cmp w9, w8
607564
; CHECK-NEXT: csel w0, w9, w8, hi
608565
; CHECK-NEXT: ret
@@ -617,11 +574,10 @@ define i32 @nested_umax_i32(<4 x i32> %a, <4 x i32> %b, i32 %c, i32 %d) {
617574
define float @nested_fmin_float(<4 x float> %a, <4 x float> %b, float %c, float %d) {
618575
; CHECK-LABEL: nested_fmin_float:
619576
; CHECK: // %bb.0:
620-
; CHECK-NEXT: fminnmv s1, v1.4s
577+
; CHECK-NEXT: fminnm v0.4s, v0.4s, v1.4s
578+
; CHECK-NEXT: fminnm s2, s2, s3
621579
; CHECK-NEXT: fminnmv s0, v0.4s
622-
; CHECK-NEXT: fminnm s1, s1, s3
623580
; CHECK-NEXT: fminnm s0, s0, s2
624-
; CHECK-NEXT: fminnm s0, s0, s1
625581
; CHECK-NEXT: ret
626582
%r1 = call float @llvm.vector.reduce.fmin.v4f32(<4 x float> %a)
627583
%a1 = call float @llvm.minnum.f32(float %r1, float %c)
@@ -634,11 +590,10 @@ define float @nested_fmin_float(<4 x float> %a, <4 x float> %b, float %c, float
634590
define float @nested_fmax_float(<4 x float> %a, <4 x float> %b, float %c, float %d) {
635591
; CHECK-LABEL: nested_fmax_float:
636592
; CHECK: // %bb.0:
637-
; CHECK-NEXT: fmaxnmv s1, v1.4s
593+
; CHECK-NEXT: fmaxnm v0.4s, v0.4s, v1.4s
594+
; CHECK-NEXT: fmaxnm s2, s2, s3
638595
; CHECK-NEXT: fmaxnmv s0, v0.4s
639-
; CHECK-NEXT: fmaxnm s1, s1, s3
640596
; CHECK-NEXT: fmaxnm s0, s0, s2
641-
; CHECK-NEXT: fmaxnm s0, s0, s1
642597
; CHECK-NEXT: ret
643598
%r1 = call float @llvm.vector.reduce.fmax.v4f32(<4 x float> %a)
644599
%a1 = call float @llvm.maxnum.f32(float %r1, float %c)

0 commit comments

Comments
 (0)