@@ -619,7 +619,7 @@ namespace {
619
619
SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
620
620
const TargetLowering &TLI);
621
621
SDValue foldPartialReduceMLAMulOp(SDNode *N);
622
- SDValue foldPartialReduceMLANoMulOp (SDNode *N);
622
+ SDValue foldPartialReduceAdd (SDNode *N);
623
623
624
624
SDValue CombineExtLoad(SDNode *N);
625
625
SDValue CombineZExtLogicopShiftLoad(SDNode *N);
@@ -12617,7 +12617,7 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
12617
12617
SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
12618
12618
if (SDValue Res = foldPartialReduceMLAMulOp(N))
12619
12619
return Res;
12620
- if (SDValue Res = foldPartialReduceMLANoMulOp (N))
12620
+ if (SDValue Res = foldPartialReduceAdd (N))
12621
12621
return Res;
12622
12622
return SDValue();
12623
12623
}
@@ -12635,10 +12635,9 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
12635
12635
SDValue Op1 = N->getOperand(1);
12636
12636
SDValue Op2 = N->getOperand(2);
12637
12637
12638
- APInt ConstantOne ;
12638
+ APInt C ;
12639
12639
if (Op1->getOpcode() != ISD::MUL ||
12640
- !ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
12641
- !ConstantOne.isOne())
12640
+ !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
12642
12641
return SDValue();
12643
12642
12644
12643
SDValue LHS = Op1->getOperand(0);
@@ -12679,11 +12678,11 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
12679
12678
RHSExtOp);
12680
12679
}
12681
12680
12682
- // Makes partial.reduce.umla(acc, zext(op1 ), splat(1)) into
12683
- // partial.reduce.umla(acc, op, splat(trunc(1)))
12684
- // Makes partial.reduce.smla(acc, sext(op1 ), splat(1)) into
12685
- // partial.reduce.smla(acc, op, splat(trunc(1)))
12686
- SDValue DAGCombiner::foldPartialReduceMLANoMulOp (SDNode *N) {
12681
+ // partial.reduce.umla(acc, zext(op ), splat(1))
12682
+ // -> partial.reduce.umla(acc, op, splat(trunc(1)))
12683
+ // partial.reduce.smla(acc, sext(op ), splat(1))
12684
+ // -> partial.reduce.smla(acc, op, splat(trunc(1)))
12685
+ SDValue DAGCombiner::foldPartialReduceAdd (SDNode *N) {
12687
12686
SDLoc DL(N);
12688
12687
SDValue Acc = N->getOperand(0);
12689
12688
SDValue Op1 = N->getOperand(1);
@@ -12700,25 +12699,20 @@ SDValue DAGCombiner::foldPartialReduceMLANoMulOp(SDNode *N) {
12700
12699
12701
12700
SDValue UnextOp1 = Op1.getOperand(0);
12702
12701
EVT UnextOp1VT = UnextOp1.getValueType();
12703
-
12704
12702
if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
12705
12703
return SDValue();
12706
12704
12707
- SDValue TruncOp2 = DAG.getNode(ISD::TRUNCATE, DL, UnextOp1VT, Op2);
12708
-
12709
12705
bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
12710
-
12711
12706
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
12712
12707
EVT AccElemVT = Acc.getValueType().getVectorElementType();
12713
12708
if (Op1IsSigned != NodeIsSigned &&
12714
- (Op1.getValueType().getVectorElementType() != AccElemVT ||
12715
- Op2.getValueType().getVectorElementType() != AccElemVT))
12709
+ Op1.getValueType().getVectorElementType() != AccElemVT)
12716
12710
return SDValue();
12717
12711
12718
12712
unsigned NewOpcode =
12719
12713
Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12720
12714
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
12721
- TruncOp2 );
12715
+ DAG.getConstant(1, DL, UnextOp1VT) );
12722
12716
}
12723
12717
12724
12718
SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
0 commit comments