
Commit ac63453

JamesChesterman authored and NickGuy-Arm committed
[DAGCombiner] Add DAG combine for PARTIAL_REDUCE_MLA when no mul op
Generic DAG combine for ISD::PARTIAL_REDUCE_U/SMLA to convert:
  PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1))
    into PARTIAL_REDUCE_UMLA(Acc, UnextOp1, TRUNC(Splat(1))), and
  PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1))
    into PARTIAL_REDUCE_SMLA(Acc, UnextOp1, TRUNC(Splat(1))).
1 parent e4a8969 commit ac63453
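
The pattern this targets is a partial reduction whose input is only extended, with no multiply feeding it. A minimal IR reproducer, assembled from the udot_no_bin_op test updated below (the %a.ext and %partial.reduce lines are verbatim from the test; the function signature, declare and ret are filled in from the zext's types and are assumptions):

define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
  %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
  ret <vscale x 4 x i32> %partial.reduce
}

declare <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32>, <vscale x 16 x i32>)

With no mul in sight this becomes PARTIAL_REDUCE_UMLA(Acc, ZEXT(%a), Splat(1)), which the existing mul-based fold cannot match; the new fold drops the extend and truncates the splat so the AArch64 backend can select a single udot against an all-ones vector, as the test changes below show.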

File tree: 2 files changed (+73, -83 lines)


llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 53 additions & 1 deletion
@@ -618,6 +618,8 @@ namespace {
     SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
     SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
                                  const TargetLowering &TLI);
+    SDValue foldPartialReduceMLAMulOp(SDNode *N);
+    SDValue foldPartialReduceMLANoMulOp(SDNode *N);
 
     SDValue CombineExtLoad(SDNode *N);
     SDValue CombineZExtLogicopShiftLoad(SDNode *N);
@@ -12498,13 +12500,21 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
+  if (SDValue Res = foldPartialReduceMLAMulOp(N))
+    return Res;
+  if (SDValue Res = foldPartialReduceMLANoMulOp(N))
+    return Res;
+  return SDValue();
+}
+
 // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(LHSExtOp), ZEXT(RHSExtOp)),
 // Splat(1)) into
 // PARTIAL_REDUCE_UMLA(Acc, LHSExtOp, RHSExtOp).
 // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(LHSExtOp), SEXT(RHSExtOp)),
 // Splat(1)) into
 // PARTIAL_REDUCE_SMLA(Acc, LHSExtOp, RHSExtOp).
-SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
+SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDLoc DL(N);
 
   SDValue Acc = N->getOperand(0);
@@ -12550,6 +12560,48 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
                      RHSExtOp);
 }
 
+// Makes PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) into
+// PARTIAL_REDUCE_UMLA(Acc, Op, TRUNC(Splat(1)))
+// Makes PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1)) into
+// PARTIAL_REDUCE_SMLA(Acc, Op, TRUNC(Splat(1)))
+SDValue DAGCombiner::foldPartialReduceMLANoMulOp(SDNode *N) {
+  SDLoc DL(N);
+  SDValue Acc = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  SDValue Op2 = N->getOperand(2);
+
+  APInt ConstantOne;
+  if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
+      !ConstantOne.isOne())
+    return SDValue();
+
+  unsigned Op1Opcode = Op1.getOpcode();
+  if (!ISD::isExtOpcode(Op1Opcode))
+    return SDValue();
+
+  SDValue UnextOp1 = Op1.getOperand(0);
+  EVT UnextOp1VT = UnextOp1.getValueType();
+
+  if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
+    return SDValue();
+
+  SDValue TruncOp2 = DAG.getNode(ISD::TRUNCATE, DL, UnextOp1VT, Op2);
+
+  bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
+
+  bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
+  EVT AccElemVT = Acc.getValueType().getVectorElementType();
+  if (Op1IsSigned != NodeIsSigned &&
+      (Op1.getValueType().getVectorElementType() != AccElemVT ||
+       Op2.getValueType().getVectorElementType() != AccElemVT))
+    return SDValue();
+
+  unsigned NewOpcode =
+      Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+  return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
+                     TruncOp2);
+}
+
 SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
   auto *SLD = cast<VPStridedLoadSDNode>(N);
   EVT EltVT = SLD->getValueType(0).getVectorElementType();
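
A detail worth noting in foldPartialReduceMLANoMulOp above: the replacement opcode follows the extend (SIGN_EXTEND picks PARTIAL_REDUCE_SMLA, ZERO_EXTEND picks PARTIAL_REDUCE_UMLA). When that disagrees with the signedness of the node being combined, the fold only proceeds if the operands' element types already match the accumulator's element type, and bails out otherwise. The signed counterpart of the reproducer above, again a sketch assembled from the sdot_no_bin_op test below with an assumed signature and ret (intrinsic declaration as in the previous sketch):

define <vscale x 4 x i32> @sdot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
  %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
  ret <vscale x 4 x i32> %partial.reduce
}

With this change it selects to a mov of an all-ones byte vector followed by a single sdot, as the CHECK-NEWLOWERING lines in the test diff show.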

llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll

Lines changed: 20 additions & 82 deletions
@@ -662,16 +662,8 @@ define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16
 ;
 ; CHECK-NEWLOWERING-LABEL: udot_no_bin_op:
 ; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: uunpklo z2.h, z1.b
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z3.s
-; CHECK-NEWLOWERING-NEXT: add z1.s, z2.s, z1.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z4.s, z0.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT: mov z2.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: udot z0.s, z1.b, z2.b
 ; CHECK-NEWLOWERING-NEXT: ret
 %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
 %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
@@ -687,16 +679,8 @@ define <vscale x 4 x i32> @sdot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op:
 ; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: sunpklo z2.h, z1.b
-; CHECK-NEWLOWERING-NEXT: sunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z3.s
-; CHECK-NEWLOWERING-NEXT: add z1.s, z2.s, z1.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z4.s, z0.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT: mov z2.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: sdot z0.s, z1.b, z2.b
 ; CHECK-NEWLOWERING-NEXT: ret
 %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
 %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
@@ -712,16 +696,8 @@ define <vscale x 2 x i64> @udot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale
 ;
 ; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_wide:
 ; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: uunpklo z2.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z1.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z3.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z2.d, z1.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z4.d, z0.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: mov z2.h, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: udot z0.d, z1.h, z2.h
 ; CHECK-NEWLOWERING-NEXT: ret
 entry:
 %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
@@ -738,16 +714,8 @@ define <vscale x 2 x i64> @sdot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_wide:
 ; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: sunpklo z2.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z1.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z3.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z2.d, z1.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z4.d, z0.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: mov z2.h, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: sdot z0.d, z1.h, z2.h
 ; CHECK-NEWLOWERING-NEXT: ret
 entry:
 %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
@@ -769,28 +737,13 @@ define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale
 ;
 ; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_8to64:
 ; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-NEXT: mov z3.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: uunpklo z5.h, z2.b
 ; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z4.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z5.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z7.d, z5.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z24.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z25.d, z3.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z4.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z6.d
-; CHECK-NEWLOWERING-NEXT: add z4.d, z25.d, z24.d
-; CHECK-NEWLOWERING-NEXT: add z2.d, z3.d, z2.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z5.d, z0.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z7.d, z1.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z4.d, z0.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z2.d, z1.d
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.h, z3.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-NEXT: udot z0.d, z5.h, z4.h
+; CHECK-NEWLOWERING-NEXT: udot z1.d, z2.h, z3.h
 ; CHECK-NEWLOWERING-NEXT: ret
 %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
 %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
@@ -811,28 +764,13 @@ define <vscale x 4 x i64> @sdot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_8to64:
 ; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-NEXT: mov z3.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: sunpklo z5.h, z2.b
 ; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z4.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z5.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z7.d, z5.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z24.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z25.d, z3.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z4.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z6.d
-; CHECK-NEWLOWERING-NEXT: add z4.d, z25.d, z24.d
-; CHECK-NEWLOWERING-NEXT: add z2.d, z3.d, z2.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z5.d, z0.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z7.d, z1.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z4.d, z0.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z2.d, z1.d
+; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z3.b
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-NEXT: sdot z0.d, z5.h, z4.h
+; CHECK-NEWLOWERING-NEXT: sdot z1.d, z2.h, z3.h
 ; CHECK-NEWLOWERING-NEXT: ret
 %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
 %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
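
The i8-to-i64 cases above are worth a second look: the input is widened all the way from i8 to i64, so a single dot product cannot cover it, yet the combine still pays off, replacing the long unpack-and-add chain with one unpack to halves and two udot/sdot instructions against an all-ones vector. The corresponding input, assembled from udot_no_bin_op_8to64 (body lines verbatim from the test; signature, declare and ret assumed):

define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a) {
  %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
  ret <vscale x 4 x i64> %partial.reduce
}

declare <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64>, <vscale x 16 x i64>)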
