Skip to content

Commit 406041f

Browse files
Split the DAG combine into two.
Also make sure the DAG combine is only done when the action for partial reductions have a type combination which is either Legal or Custom. This ensures that the combines are not performed only for the resulting DAG to be expanded, as this leads to worse Code Gen.
1 parent 5053db6 commit 406041f

File tree

6 files changed

+165
-112
lines changed

6 files changed

+165
-112
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1639,6 +1639,25 @@ class TargetLoweringBase {
16391639
getCondCodeAction(CC, VT) == Custom;
16401640
}
16411641

1642+
/// Return how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input type
1643+
/// InputVT should be treated. Either it's legal, needs to be promoted to a
1644+
/// larger size, needs to be expanded to some other code sequence, or the
1645+
/// target has a custom expander for it.
1646+
LegalizeAction getPartialReduceMLAAction(EVT AccVT, EVT InputVT) const {
1647+
unsigned AccI = (unsigned)AccVT.getSimpleVT().SimpleTy;
1648+
unsigned InputI = (unsigned)InputVT.getSimpleVT().SimpleTy;
1649+
assert(AccI < MVT::VALUETYPE_SIZE && InputI < MVT::VALUETYPE_SIZE &&
1650+
"Table isn't big enough!");
1651+
return PartialReduceMLAActions[AccI][InputI];
1652+
}
1653+
1654+
/// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
1655+
/// legal or custom for this target.
1656+
bool isPartialReduceMLALegalOrCustom(EVT AccVT, EVT InputVT) const {
1657+
return getPartialReduceMLAAction(AccVT, InputVT) == Legal ||
1658+
getPartialReduceMLAAction(AccVT, InputVT) == Custom;
1659+
}
1660+
16421661
/// If the action for this operation is to promote, this method returns the
16431662
/// ValueType to promote to.
16441663
MVT getTypeToPromoteTo(unsigned Op, MVT VT) const {
@@ -2704,6 +2723,16 @@ class TargetLoweringBase {
27042723
setCondCodeAction(CCs, VT, Action);
27052724
}
27062725

2726+
/// Indicate how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input
2727+
/// type InputVT should be treated by the target. Either it's legal, needs to
2728+
/// be promoted to a larger size, needs to be expanded to some other code
2729+
/// sequence, or the target has a custom expander for it.
2730+
void setPartialReduceMLAAction(MVT AccVT, MVT InputVT,
2731+
LegalizeAction Action) {
2732+
assert(AccVT.isValid() && InputVT.isValid() && "Table isn't big enough!");
2733+
PartialReduceMLAActions[AccVT.SimpleTy][InputVT.SimpleTy] = Action;
2734+
}
2735+
27072736
/// If Opc/OrigVT is specified as being promoted, the promotion code defaults
27082737
/// to trying a larger integer/fp until it can find one that works. If that
27092738
/// default is insufficient, this method can be used by the target to override
@@ -3650,6 +3679,12 @@ class TargetLoweringBase {
36503679
/// up the MVT::VALUETYPE_SIZE value to the next multiple of 8.
36513680
uint32_t CondCodeActions[ISD::SETCC_INVALID][(MVT::VALUETYPE_SIZE + 7) / 8];
36523681

3682+
/// For each result type and input type for the ISD::PARTIAL_REDUCE_U/SMLA
3683+
/// nodes, keep a LegalizeAction which indicates how instruction selection
3684+
/// should deal with this operation.
3685+
LegalizeAction PartialReduceMLAActions[MVT::VALUETYPE_SIZE]
3686+
[MVT::VALUETYPE_SIZE];
3687+
36533688
ValueTypeActionImpl ValueTypeActions;
36543689

36553690
private:

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,8 @@ namespace {
622622
SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
623623
SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
624624
const TargetLowering &TLI);
625+
SDValue foldMulPARTIAL_REDUCE_MLA(SDNode *N);
626+
SDValue foldExtendPARTIAL_REDUCE_MLA(SDNode *N);
625627

626628
SDValue CombineExtLoad(SDNode *N);
627629
SDValue CombineZExtLogicopShiftLoad(SDNode *N);
@@ -12502,18 +12504,45 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
1250212504
}
1250312505

1250412506
SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
12505-
// Makes PARTIAL_REDUCE_MLA(Acc, MUL(EXT(MulOpLHS), EXT(MulOpRHS)), Splat(1))
12506-
// into PARTIAL_REDUCE_MLA(Acc, MulOpLHS, MulOpRHS)
12507+
// Only perform the DAG combine if there is custom lowering provided by the
12508+
// target.
12509+
if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0),
12510+
N->getOperand(1).getValueType()))
12511+
return SDValue();
12512+
12513+
if (SDValue Res = foldMulPARTIAL_REDUCE_MLA(N))
12514+
return Res;
12515+
if (SDValue Res = foldExtendPARTIAL_REDUCE_MLA(N))
12516+
return Res;
12517+
return SDValue();
12518+
}
12519+
12520+
SDValue DAGCombiner::foldMulPARTIAL_REDUCE_MLA(SDNode *N) {
12521+
// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(MulOpLHS, MulOpRHS), Splat(1)) into
12522+
// PARTIAL_REDUCE_*MLA(Acc, MulOpLHS, MulOpRHS)
1250712523
SDLoc DL(N);
12508-
SDValue Op0 = N->getOperand(0);
12509-
SDValue Op1 = N->getOperand(1);
12510-
SDValue Op2 = N->getOperand(2);
1251112524

12525+
SDValue Op1 = N->getOperand(1);
1251212526
if (Op1->getOpcode() != ISD::MUL)
1251312527
return SDValue();
1251412528

12515-
SDValue ExtMulOpLHS = Op1->getOperand(0);
12516-
SDValue ExtMulOpRHS = Op1->getOperand(1);
12529+
APInt ConstantOne;
12530+
if (!ISD::isConstantSplatVector(N->getOperand(2).getNode(), ConstantOne) ||
12531+
!ConstantOne.isOne())
12532+
return SDValue();
12533+
12534+
return DAG.getNode(N->getOpcode(), DL, N->getValueType(0), N->getOperand(0),
12535+
Op1->getOperand(0), Op1->getOperand(1));
12536+
}
12537+
12538+
SDValue DAGCombiner::foldExtendPARTIAL_REDUCE_MLA(SDNode *N) {
12539+
// Makes PARTIAL_REDUCE_*MLA(Acc, ZEXT(MulOpLHS), ZEXT(MulOpRHS)) into
12540+
// PARTIAL_REDUCE_UMLA(Acc, MulOpLHS, MulOpRHS) and
12541+
// PARTIAL_REDUCE_*MLA(Acc, SEXT(MulOpLHS), SEXT(MulOpRHS)) into
12542+
// PARTIAL_REDUCE_SMLA(Acc, MulOpLHS, MulOpRHS)
12543+
SDLoc DL(N);
12544+
SDValue ExtMulOpLHS = N->getOperand(1);
12545+
SDValue ExtMulOpRHS = N->getOperand(2);
1251712546
unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
1251812547
unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
1251912548
if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
@@ -12526,23 +12555,15 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
1252612555
if (MulOpLHSVT != MulOpRHS.getValueType())
1252712556
return SDValue();
1252812557

12529-
if (!TLI.isTypeLegal(MulOpLHSVT) || !TLI.isTypeLegal(N->getValueType(0)))
12530-
return SDValue();
12531-
12532-
APInt ConstantOne;
12533-
if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
12534-
!ConstantOne.isOne())
12535-
return SDValue();
12536-
1253712558
bool LHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
1253812559
bool RHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
1253912560
if (LHSIsSigned != RHSIsSigned)
1254012561
return SDValue();
1254112562

1254212563
unsigned NewOpcode =
1254312564
LHSIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12544-
return DAG.getNode(NewOpcode, DL, Op0->getValueType(0), Op0, MulOpLHS,
12545-
MulOpRHS);
12565+
return DAG.getNode(NewOpcode, DL, N->getValueType(0), N->getOperand(0),
12566+
MulOpLHS, MulOpRHS);
1254612567
}
1254712568

1254812569
SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -469,8 +469,6 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
469469
case ISD::VECTOR_COMPRESS:
470470
case ISD::SCMP:
471471
case ISD::UCMP:
472-
case ISD::PARTIAL_REDUCE_UMLA:
473-
case ISD::PARTIAL_REDUCE_SMLA:
474472
Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
475473
break;
476474
case ISD::SMULFIX:
@@ -524,6 +522,11 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
524522
Action = TLI.getOperationAction(Node->getOpcode(), OpVT);
525523
break;
526524
}
525+
case ISD::PARTIAL_REDUCE_UMLA:
526+
case ISD::PARTIAL_REDUCE_SMLA:
527+
Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
528+
Node->getOperand(1).getValueType());
529+
break;
527530

528531
#define BEGIN_REGISTER_VP_SDNODE(VPID, LEGALPOS, ...) \
529532
case ISD::VPID: { \

llvm/lib/CodeGen/TargetLoweringBase.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -836,9 +836,8 @@ void TargetLoweringBase::initActions() {
836836
setOperationAction(ISD::SET_FPENV, VT, Expand);
837837
setOperationAction(ISD::RESET_FPENV, VT, Expand);
838838

839-
// PartialReduceMLA operations default to expand.
840-
setOperationAction({ISD::PARTIAL_REDUCE_UMLA, ISD::PARTIAL_REDUCE_SMLA}, VT,
841-
Expand);
839+
for (MVT InputVT : MVT::all_valuetypes())
840+
setPartialReduceMLAAction(VT, InputVT, Expand);
842841
}
843842

844843
// Most targets ignore the @llvm.prefetch intrinsic.

llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll

Lines changed: 33 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,13 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
1212
;
1313
; CHECK-NODOT-LABEL: udot:
1414
; CHECK-NODOT: // %bb.0:
15-
; CHECK-NODOT-NEXT: ushll v3.8h, v1.8b, #0
16-
; CHECK-NODOT-NEXT: ushll v4.8h, v2.8b, #0
17-
; CHECK-NODOT-NEXT: ushll2 v1.8h, v1.16b, #0
18-
; CHECK-NODOT-NEXT: ushll2 v2.8h, v2.16b, #0
19-
; CHECK-NODOT-NEXT: umlal v0.4s, v4.4h, v3.4h
20-
; CHECK-NODOT-NEXT: umull v5.4s, v2.4h, v1.4h
21-
; CHECK-NODOT-NEXT: umlal2 v0.4s, v2.8h, v1.8h
22-
; CHECK-NODOT-NEXT: umlal2 v5.4s, v4.8h, v3.8h
23-
; CHECK-NODOT-NEXT: add v0.4s, v5.4s, v0.4s
15+
; CHECK-NODOT-NEXT: umull v3.8h, v2.8b, v1.8b
16+
; CHECK-NODOT-NEXT: umull2 v1.8h, v2.16b, v1.16b
17+
; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
18+
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v3.4h
19+
; CHECK-NODOT-NEXT: uaddw2 v2.4s, v2.4s, v3.8h
20+
; CHECK-NODOT-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
21+
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
2422
; CHECK-NODOT-NEXT: ret
2523
%u.wide = zext <16 x i8> %u to <16 x i32>
2624
%s.wide = zext <16 x i8> %s to <16 x i32>
@@ -37,19 +35,17 @@ define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
3735
;
3836
; CHECK-NODOT-LABEL: udot_narrow:
3937
; CHECK-NODOT: // %bb.0:
40-
; CHECK-NODOT-NEXT: ushll v1.8h, v1.8b, #0
41-
; CHECK-NODOT-NEXT: ushll v2.8h, v2.8b, #0
38+
; CHECK-NODOT-NEXT: umull v1.8h, v2.8b, v1.8b
4239
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
43-
; CHECK-NODOT-NEXT: umull v3.4s, v2.4h, v1.4h
44-
; CHECK-NODOT-NEXT: umull2 v4.4s, v2.8h, v1.8h
45-
; CHECK-NODOT-NEXT: ext v5.16b, v1.16b, v1.16b, #8
46-
; CHECK-NODOT-NEXT: ext v6.16b, v2.16b, v2.16b, #8
47-
; CHECK-NODOT-NEXT: umlal v0.4s, v2.4h, v1.4h
40+
; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
41+
; CHECK-NODOT-NEXT: ushll2 v3.4s, v1.8h, #0
42+
; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
43+
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h
4844
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
49-
; CHECK-NODOT-NEXT: ext v1.16b, v4.16b, v4.16b, #8
50-
; CHECK-NODOT-NEXT: umlal v3.4s, v6.4h, v5.4h
51-
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
45+
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
5246
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
47+
; CHECK-NODOT-NEXT: uaddw v1.4s, v2.4s, v4.4h
48+
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
5349
; CHECK-NODOT-NEXT: ret
5450
%u.wide = zext <8 x i8> %u to <8 x i32>
5551
%s.wide = zext <8 x i8> %s to <8 x i32>
@@ -66,15 +62,13 @@ define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
6662
;
6763
; CHECK-NODOT-LABEL: sdot:
6864
; CHECK-NODOT: // %bb.0:
69-
; CHECK-NODOT-NEXT: sshll v3.8h, v1.8b, #0
70-
; CHECK-NODOT-NEXT: sshll v4.8h, v2.8b, #0
71-
; CHECK-NODOT-NEXT: sshll2 v1.8h, v1.16b, #0
72-
; CHECK-NODOT-NEXT: sshll2 v2.8h, v2.16b, #0
73-
; CHECK-NODOT-NEXT: smlal v0.4s, v4.4h, v3.4h
74-
; CHECK-NODOT-NEXT: smull v5.4s, v2.4h, v1.4h
75-
; CHECK-NODOT-NEXT: smlal2 v0.4s, v2.8h, v1.8h
76-
; CHECK-NODOT-NEXT: smlal2 v5.4s, v4.8h, v3.8h
77-
; CHECK-NODOT-NEXT: add v0.4s, v5.4s, v0.4s
65+
; CHECK-NODOT-NEXT: smull v3.8h, v2.8b, v1.8b
66+
; CHECK-NODOT-NEXT: smull2 v1.8h, v2.16b, v1.16b
67+
; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0
68+
; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v3.4h
69+
; CHECK-NODOT-NEXT: saddw2 v2.4s, v2.4s, v3.8h
70+
; CHECK-NODOT-NEXT: saddw2 v0.4s, v0.4s, v1.8h
71+
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
7872
; CHECK-NODOT-NEXT: ret
7973
%u.wide = sext <16 x i8> %u to <16 x i32>
8074
%s.wide = sext <16 x i8> %s to <16 x i32>
@@ -91,19 +85,17 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
9185
;
9286
; CHECK-NODOT-LABEL: sdot_narrow:
9387
; CHECK-NODOT: // %bb.0:
94-
; CHECK-NODOT-NEXT: sshll v1.8h, v1.8b, #0
95-
; CHECK-NODOT-NEXT: sshll v2.8h, v2.8b, #0
88+
; CHECK-NODOT-NEXT: smull v1.8h, v2.8b, v1.8b
9689
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
97-
; CHECK-NODOT-NEXT: smull v3.4s, v2.4h, v1.4h
98-
; CHECK-NODOT-NEXT: smull2 v4.4s, v2.8h, v1.8h
99-
; CHECK-NODOT-NEXT: ext v5.16b, v1.16b, v1.16b, #8
100-
; CHECK-NODOT-NEXT: ext v6.16b, v2.16b, v2.16b, #8
101-
; CHECK-NODOT-NEXT: smlal v0.4s, v2.4h, v1.4h
90+
; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0
91+
; CHECK-NODOT-NEXT: sshll2 v3.4s, v1.8h, #0
92+
; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
93+
; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v1.4h
10294
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
103-
; CHECK-NODOT-NEXT: ext v1.16b, v4.16b, v4.16b, #8
104-
; CHECK-NODOT-NEXT: smlal v3.4s, v6.4h, v5.4h
105-
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
95+
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
10696
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
97+
; CHECK-NODOT-NEXT: saddw v1.4s, v2.4s, v4.4h
98+
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
10799
; CHECK-NODOT-NEXT: ret
108100
%u.wide = sext <8 x i8> %u to <8 x i32>
109101
%s.wide = sext <8 x i8> %s to <8 x i32>
@@ -539,10 +531,9 @@ define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
539531
define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
540532
; CHECK-LABEL: not_udot:
541533
; CHECK: // %bb.0:
542-
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
543-
; CHECK-NEXT: ushll v2.8h, v2.8b, #0
544-
; CHECK-NEXT: umlal v0.4s, v2.4h, v1.4h
545-
; CHECK-NEXT: umlal2 v0.4s, v2.8h, v1.8h
534+
; CHECK-NEXT: umull v1.8h, v2.8b, v1.8b
535+
; CHECK-NEXT: uaddw v0.4s, v0.4s, v1.4h
536+
; CHECK-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
546537
; CHECK-NEXT: ret
547538
%u.wide = zext <8 x i8> %u to <8 x i32>
548539
%s.wide = zext <8 x i8> %s to <8 x i32>

0 commit comments

Comments
 (0)