Skip to content

Commit ffd894f

Browse files
committed
Address nits
1 parent a708b96 commit ffd894f

File tree

1 file changed

+11
-17
lines changed

1 file changed

+11
-17
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ namespace {
619619
SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
620620
const TargetLowering &TLI);
621621
SDValue foldPartialReduceMLAMulOp(SDNode *N);
622-
SDValue foldPartialReduceMLANoMulOp(SDNode *N);
622+
SDValue foldPartialReduceAdd(SDNode *N);
623623

624624
SDValue CombineExtLoad(SDNode *N);
625625
SDValue CombineZExtLogicopShiftLoad(SDNode *N);
@@ -12617,7 +12617,7 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
1261712617
SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
1261812618
if (SDValue Res = foldPartialReduceMLAMulOp(N))
1261912619
return Res;
12620-
if (SDValue Res = foldPartialReduceMLANoMulOp(N))
12620+
if (SDValue Res = foldPartialReduceAdd(N))
1262112621
return Res;
1262212622
return SDValue();
1262312623
}
@@ -12635,10 +12635,9 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1263512635
SDValue Op1 = N->getOperand(1);
1263612636
SDValue Op2 = N->getOperand(2);
1263712637

12638-
APInt ConstantOne;
12638+
APInt C;
1263912639
if (Op1->getOpcode() != ISD::MUL ||
12640-
!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
12641-
!ConstantOne.isOne())
12640+
!ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
1264212641
return SDValue();
1264312642

1264412643
SDValue LHS = Op1->getOperand(0);
@@ -12679,11 +12678,11 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1267912678
RHSExtOp);
1268012679
}
1268112680

12682-
// Makes partial.reduce.umla(acc, zext(op1), splat(1)) into
12683-
// partial.reduce.umla(acc, op, splat(trunc(1)))
12684-
// Makes partial.reduce.smla(acc, sext(op1), splat(1)) into
12685-
// partial.reduce.smla(acc, op, splat(trunc(1)))
12686-
SDValue DAGCombiner::foldPartialReduceMLANoMulOp(SDNode *N) {
12681+
// partial.reduce.umla(acc, zext(op), splat(1))
12682+
// -> partial.reduce.umla(acc, op, splat(trunc(1)))
12683+
// partial.reduce.smla(acc, sext(op), splat(1))
12684+
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
12685+
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
1268712686
SDLoc DL(N);
1268812687
SDValue Acc = N->getOperand(0);
1268912688
SDValue Op1 = N->getOperand(1);
@@ -12700,25 +12699,20 @@ SDValue DAGCombiner::foldPartialReduceMLANoMulOp(SDNode *N) {
1270012699

1270112700
SDValue UnextOp1 = Op1.getOperand(0);
1270212701
EVT UnextOp1VT = UnextOp1.getValueType();
12703-
1270412702
if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
1270512703
return SDValue();
1270612704

12707-
SDValue TruncOp2 = DAG.getNode(ISD::TRUNCATE, DL, UnextOp1VT, Op2);
12708-
1270912705
bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
12710-
1271112706
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
1271212707
EVT AccElemVT = Acc.getValueType().getVectorElementType();
1271312708
if (Op1IsSigned != NodeIsSigned &&
12714-
(Op1.getValueType().getVectorElementType() != AccElemVT ||
12715-
Op2.getValueType().getVectorElementType() != AccElemVT))
12709+
Op1.getValueType().getVectorElementType() != AccElemVT)
1271612710
return SDValue();
1271712711

1271812712
unsigned NewOpcode =
1271912713
Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
1272012714
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
12721-
TruncOp2);
12715+
DAG.getConstant(1, DL, UnextOp1VT));
1272212716
}
1272312717

1272412718
SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {

0 commit comments

Comments
 (0)