Skip to content

Commit bc54324

Browse files
committed
Address nits
1 parent 505eda7 commit bc54324

File tree

1 file changed

+9
-14
lines changed

1 file changed

+9
-14
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ namespace {
619619
SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
620620
const TargetLowering &TLI);
621621
SDValue foldPartialReduceMLAMulOp(SDNode *N);
622-
SDValue foldPartialReduceMLANoMulOp(SDNode *N);
622+
SDValue foldPartialReduceAdd(SDNode *N);
623623

624624
SDValue CombineExtLoad(SDNode *N);
625625
SDValue CombineZExtLogicopShiftLoad(SDNode *N);
@@ -12606,7 +12606,7 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
1260612606
SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
1260712607
if (SDValue Res = foldPartialReduceMLAMulOp(N))
1260812608
return Res;
12609-
if (SDValue Res = foldPartialReduceMLANoMulOp(N))
12609+
if (SDValue Res = foldPartialReduceAdd(N))
1261012610
return Res;
1261112611
return SDValue();
1261212612
}
@@ -12682,11 +12682,11 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1268212682
RHSExtOp);
1268312683
}
1268412684

12685-
// Makes partial.reduce.umla(acc, zext(op1), splat(1)) into
12686-
// partial.reduce.umla(acc, op, splat(trunc(1)))
12687-
// Makes partial.reduce.smla(acc, sext(op1), splat(1)) into
12688-
// partial.reduce.smla(acc, op, splat(trunc(1)))
12689-
SDValue DAGCombiner::foldPartialReduceMLANoMulOp(SDNode *N) {
12685+
// partial.reduce.umla(acc, zext(op), splat(1))
12686+
// -> partial.reduce.umla(acc, op, splat(trunc(1)))
12687+
// partial.reduce.smla(acc, sext(op), splat(1))
12688+
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
12689+
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
1269012690
SDLoc DL(N);
1269112691
SDValue Acc = N->getOperand(0);
1269212692
SDValue Op1 = N->getOperand(1);
@@ -12703,25 +12703,20 @@ SDValue DAGCombiner::foldPartialReduceMLANoMulOp(SDNode *N) {
1270312703

1270412704
SDValue UnextOp1 = Op1.getOperand(0);
1270512705
EVT UnextOp1VT = UnextOp1.getValueType();
12706-
1270712706
if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
1270812707
return SDValue();
1270912708

12710-
SDValue TruncOp2 = DAG.getNode(ISD::TRUNCATE, DL, UnextOp1VT, Op2);
12711-
1271212709
bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
12713-
1271412710
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
1271512711
EVT AccElemVT = Acc.getValueType().getVectorElementType();
1271612712
if (Op1IsSigned != NodeIsSigned &&
12717-
(Op1.getValueType().getVectorElementType() != AccElemVT ||
12718-
Op2.getValueType().getVectorElementType() != AccElemVT))
12713+
Op1.getValueType().getVectorElementType() != AccElemVT)
1271912714
return SDValue();
1272012715

1272112716
unsigned NewOpcode =
1272212717
Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
1272312718
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
12724-
TruncOp2);
12719+
DAG.getConstant(1, DL, UnextOp1VT));
1272512720
}
1272612721

1272712722
SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {

0 commit comments

Comments
 (0)