@@ -619,7 +619,7 @@ namespace {
619
619
SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
620
620
const TargetLowering &TLI);
621
621
SDValue foldPartialReduceMLAMulOp(SDNode *N);
622
- SDValue foldPartialReduceMLANoMulOp (SDNode *N);
622
+ SDValue foldPartialReduceAdd (SDNode *N);
623
623
624
624
SDValue CombineExtLoad(SDNode *N);
625
625
SDValue CombineZExtLogicopShiftLoad(SDNode *N);
@@ -12606,7 +12606,7 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
12606
12606
// Combine entry point for a PARTIAL_REDUCE_MLA node. Tries the available folds
// in order and returns the first result that tests true; returns an empty
// SDValue() when no fold applies. In this diff the second fold is renamed from
// foldPartialReduceMLANoMulOp ('-' line) to foldPartialReduceAdd ('+' line).
SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
12607
12607
if (SDValue Res = foldPartialReduceMLAMulOp(N))
12608
12608
return Res;
12609
- if (SDValue Res = foldPartialReduceMLANoMulOp (N))
12609
+ if (SDValue Res = foldPartialReduceAdd (N))
12610
12610
return Res;
12611
12611
return SDValue();
12612
12612
}
@@ -12682,11 +12682,11 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
12682
12682
RHSExtOp);
12683
12683
}
12684
12684
12685
- // Makes partial.reduce.umla(acc, zext(op1 ), splat(1)) into
12686
- // partial.reduce.umla(acc, op, splat(trunc(1)))
12687
- // Makes partial.reduce.smla(acc, sext(op1 ), splat(1)) into
12688
- // partial.reduce.smla(acc, op, splat(trunc(1)))
12689
- SDValue DAGCombiner::foldPartialReduceMLANoMulOp (SDNode *N) {
12685
+ // partial.reduce.umla(acc, zext(op), splat(1))
12686
+ // -> partial.reduce.umla(acc, op, splat(trunc(1)))
12687
+ // partial.reduce.smla(acc, sext(op), splat(1))
12688
+ // -> partial.reduce.smla(acc, op, splat(trunc(1)))
12689
+ SDValue DAGCombiner::foldPartialReduceAdd (SDNode *N) {
12690
12690
SDLoc DL(N);
12691
12691
SDValue Acc = N->getOperand(0);
12692
12692
SDValue Op1 = N->getOperand(1);
@@ -12703,25 +12703,20 @@ SDValue DAGCombiner::foldPartialReduceMLANoMulOp(SDNode *N) {
12703
12703
12704
12704
SDValue UnextOp1 = Op1.getOperand(0);
12705
12705
EVT UnextOp1VT = UnextOp1.getValueType();
12706
-
12707
12706
if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
12708
12707
return SDValue();
12709
12708
12710
- SDValue TruncOp2 = DAG.getNode(ISD::TRUNCATE, DL, UnextOp1VT, Op2);
12711
-
12712
12709
bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
12713
-
12714
12710
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
12715
12711
EVT AccElemVT = Acc.getValueType().getVectorElementType();
12716
12712
if (Op1IsSigned != NodeIsSigned &&
12717
- (Op1.getValueType().getVectorElementType() != AccElemVT ||
12718
- Op2.getValueType().getVectorElementType() != AccElemVT))
12713
+ Op1.getValueType().getVectorElementType() != AccElemVT)
12719
12714
return SDValue();
12720
12715
12721
12716
unsigned NewOpcode =
12722
12717
Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12723
12718
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
12724
- TruncOp2 );
12719
+ DAG.getConstant(1, DL, UnextOp1VT) );
12725
12720
}
12726
12721
12727
12722
SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
0 commit comments