-
Notifications
You must be signed in to change notification settings - Fork 13.9k
[DAGCombiner] Add generic DAG combine for ISD::PARTIAL_REDUCE_MLA #127083
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Add generic DAG combine for ISD::PARTIAL_REDUCE_U/SMLA nodes. Transforms the DAG from: PARTIAL_REDUCE_MLA(Acc, MUL(EXT(MulOpLHS), EXT(MulOpRHS)), Splat(1)) to PARTIAL_REDUCE_MLA(Acc, MulOpLHS, MulOpRHS).
Also make sure the DAG combine is only done when the action for partial reductions have a type combination which is either Legal or Custom. This ensures that the combines are not performed only for the resulting DAG to be expanded, as this leads to worse Code Gen.
17fe9bd
to
406041f
Compare
@llvm/pr-subscribers-backend-aarch64 @llvm/pr-subscribers-llvm-selectiondag Author: James Chesterman (JamesChesterman) ChangesAdd generic DAG combine for ISD::PARTIAL_REDUCE_U/SMLA nodes. Transforms the DAG from: Full diff: https://github.com/llvm/llvm-project/pull/127083.diff 4 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index a4c3d042fe3a4..52e57365dceab 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1639,6 +1639,25 @@ class TargetLoweringBase {
getCondCodeAction(CC, VT) == Custom;
}
+ /// Return how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input type
+ /// InputVT should be treated. Either it's legal, needs to be promoted to a
+ /// larger size, needs to be expanded to some other code sequence, or the
+ /// target has a custom expander for it.
+ LegalizeAction getPartialReduceMLAAction(EVT AccVT, EVT InputVT) const {
+ unsigned AccI = (unsigned)AccVT.getSimpleVT().SimpleTy;
+ unsigned InputI = (unsigned)InputVT.getSimpleVT().SimpleTy;
+ assert(AccI < MVT::VALUETYPE_SIZE && InputI < MVT::VALUETYPE_SIZE &&
+ "Table isn't big enough!");
+ return PartialReduceMLAActions[AccI][InputI];
+ }
+
+ /// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
+ /// legal or custom for this target.
+ bool isPartialReduceMLALegalOrCustom(EVT AccVT, EVT InputVT) const {
+ return getPartialReduceMLAAction(AccVT, InputVT) == Legal ||
+ getPartialReduceMLAAction(AccVT, InputVT) == Custom;
+ }
+
/// If the action for this operation is to promote, this method returns the
/// ValueType to promote to.
MVT getTypeToPromoteTo(unsigned Op, MVT VT) const {
@@ -2704,6 +2723,16 @@ class TargetLoweringBase {
setCondCodeAction(CCs, VT, Action);
}
+ /// Indicate how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input
+ /// type InputVT should be treated by the target. Either it's legal, needs to
+ /// be promoted to a larger size, needs to be expanded to some other code
+ /// sequence, or the target has a custom expander for it.
+ void setPartialReduceMLAAction(MVT AccVT, MVT InputVT,
+ LegalizeAction Action) {
+ assert(AccVT.isValid() && InputVT.isValid() && "Table isn't big enough!");
+ PartialReduceMLAActions[AccVT.SimpleTy][InputVT.SimpleTy] = Action;
+ }
+
/// If Opc/OrigVT is specified as being promoted, the promotion code defaults
/// to trying a larger integer/fp until it can find one that works. If that
/// default is insufficient, this method can be used by the target to override
@@ -3650,6 +3679,12 @@ class TargetLoweringBase {
/// up the MVT::VALUETYPE_SIZE value to the next multiple of 8.
uint32_t CondCodeActions[ISD::SETCC_INVALID][(MVT::VALUETYPE_SIZE + 7) / 8];
+ /// For each result type and input type for the ISD::PARTIAL_REDUCE_U/SMLA
+ /// nodes, keep a LegalizeAction which indicates how instruction selection
+ /// should deal with this operation.
+ LegalizeAction PartialReduceMLAActions[MVT::VALUETYPE_SIZE]
+ [MVT::VALUETYPE_SIZE];
+
ValueTypeActionImpl ValueTypeActions;
private:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index bc7cdf38dbc2a..223260c43a38e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -545,6 +545,7 @@ namespace {
SDValue visitMGATHER(SDNode *N);
SDValue visitMSCATTER(SDNode *N);
SDValue visitMHISTOGRAM(SDNode *N);
+ SDValue visitPARTIAL_REDUCE_MLA(SDNode *N);
SDValue visitVPGATHER(SDNode *N);
SDValue visitVPSCATTER(SDNode *N);
SDValue visitVP_STRIDED_LOAD(SDNode *N);
@@ -621,6 +622,8 @@ namespace {
SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
const TargetLowering &TLI);
+ SDValue foldMulPARTIAL_REDUCE_MLA(SDNode *N);
+ SDValue foldExtendPARTIAL_REDUCE_MLA(SDNode *N);
SDValue CombineExtLoad(SDNode *N);
SDValue CombineZExtLogicopShiftLoad(SDNode *N);
@@ -1972,6 +1975,9 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::MSCATTER: return visitMSCATTER(N);
case ISD::MSTORE: return visitMSTORE(N);
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
+ case ISD::PARTIAL_REDUCE_SMLA:
+ case ISD::PARTIAL_REDUCE_UMLA:
+ return visitPARTIAL_REDUCE_MLA(N);
case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
case ISD::LIFETIME_END: return visitLIFETIME_END(N);
case ISD::FP_TO_FP16: return visitFP_TO_FP16(N);
@@ -12497,6 +12503,69 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
+ // Only perform the DAG combine if there is custom lowering provided by the
+ // target.
+ if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0),
+ N->getOperand(1).getValueType()))
+ return SDValue();
+
+ if (SDValue Res = foldMulPARTIAL_REDUCE_MLA(N))
+ return Res;
+ if (SDValue Res = foldExtendPARTIAL_REDUCE_MLA(N))
+ return Res;
+ return SDValue();
+}
+
+SDValue DAGCombiner::foldMulPARTIAL_REDUCE_MLA(SDNode *N) {
+ // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(MulOpLHS, MulOpRHS), Splat(1)) into
+ // PARTIAL_REDUCE_*MLA(Acc, MulOpLHS, MulOpRHS)
+ SDLoc DL(N);
+
+ SDValue Op1 = N->getOperand(1);
+ if (Op1->getOpcode() != ISD::MUL)
+ return SDValue();
+
+ APInt ConstantOne;
+ if (!ISD::isConstantSplatVector(N->getOperand(2).getNode(), ConstantOne) ||
+ !ConstantOne.isOne())
+ return SDValue();
+
+ return DAG.getNode(N->getOpcode(), DL, N->getValueType(0), N->getOperand(0),
+ Op1->getOperand(0), Op1->getOperand(1));
+}
+
+SDValue DAGCombiner::foldExtendPARTIAL_REDUCE_MLA(SDNode *N) {
+ // Makes PARTIAL_REDUCE_*MLA(Acc, ZEXT(MulOpLHS), ZEXT(MulOpRHS)) into
+ // PARTIAL_REDUCE_UMLA(Acc, MulOpLHS, MulOpRHS) and
+ // PARTIAL_REDUCE_*MLA(Acc, SEXT(MulOpLHS), SEXT(MulOpRHS)) into
+ // PARTIAL_REDUCE_SMLA(Acc, MulOpLHS, MulOpRHS)
+ SDLoc DL(N);
+ SDValue ExtMulOpLHS = N->getOperand(1);
+ SDValue ExtMulOpRHS = N->getOperand(2);
+ unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
+ unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
+ if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
+ !ISD::isExtOpcode(ExtMulOpRHSOpcode))
+ return SDValue();
+
+ SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
+ SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
+ EVT MulOpLHSVT = MulOpLHS.getValueType();
+ if (MulOpLHSVT != MulOpRHS.getValueType())
+ return SDValue();
+
+ bool LHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
+ bool RHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
+ if (LHSIsSigned != RHSIsSigned)
+ return SDValue();
+
+ unsigned NewOpcode =
+ LHSIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(NewOpcode, DL, N->getValueType(0), N->getOperand(0),
+ MulOpLHS, MulOpRHS);
+}
+
SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
auto *SLD = cast<VPStridedLoadSDNode>(N);
EVT EltVT = SLD->getValueType(0).getVectorElementType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index de4447fb0cf1a..e43b14a47e565 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -469,8 +469,6 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::VECTOR_COMPRESS:
case ISD::SCMP:
case ISD::UCMP:
- case ISD::PARTIAL_REDUCE_UMLA:
- case ISD::PARTIAL_REDUCE_SMLA:
Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
break;
case ISD::SMULFIX:
@@ -524,6 +522,11 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
Action = TLI.getOperationAction(Node->getOpcode(), OpVT);
break;
}
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
+ Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
+ Node->getOperand(1).getValueType());
+ break;
#define BEGIN_REGISTER_VP_SDNODE(VPID, LEGALPOS, ...) \
case ISD::VPID: { \
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index f5ea3c0b47d6a..af97ce20fdb10 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -836,9 +836,8 @@ void TargetLoweringBase::initActions() {
setOperationAction(ISD::SET_FPENV, VT, Expand);
setOperationAction(ISD::RESET_FPENV, VT, Expand);
- // PartialReduceMLA operations default to expand.
- setOperationAction({ISD::PARTIAL_REDUCE_UMLA, ISD::PARTIAL_REDUCE_SMLA}, VT,
- Expand);
+ for (MVT InputVT : MVT::all_valuetypes())
+ setPartialReduceMLAAction(VT, InputVT, Expand);
}
// Most targets ignore the @llvm.prefetch intrinsic.
|
This is so the MUL fold does not happen unless the extend fold can be performed. As otherwise a lot of code would need to be repeated to check that it can happen.
This makes it so the changes are reflected in the tests, so that we can tell the DAG combine is actually happening. It has been replaced with a FIXME note saying to potentially add it back in when the rest of the implementation is complete.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks 👍 I think the regressions should be fixed after the lowering changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This still LGTM, but wait for @sdesmalen-arm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Few more small comments, other than that it looks fine.
; CHECK-NODOT-NEXT: smull2 v2.8h, v2.16b, v3.16b | ||
; CHECK-NODOT-NEXT: sshll v3.4s, v4.4h, #0 | ||
; CHECK-NODOT-NEXT: sshll v5.4s, v2.4h, #0 | ||
; CHECK-NODOT-NEXT: sshll v4.8h, v3.8b, #0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know these are regressions, but they'll be addressed by follow-up patches that further improve this code-gen.
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/190/builds/15664 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/137/builds/14406 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/65/builds/13161 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/185/builds/14141 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/175/builds/14217 Here is the relevant piece of the build log for the reference
|
@JamesChesterman I've reverted this PR due to a failure in |
…TIAL_REDUCE_MLA (#127083)" This reverts commit 2bef21f. Multiple builtbot failures have been reported: llvm/llvm-project#127083
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/76/builds/7459 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/63/builds/4415 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/33/builds/12363 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/145/builds/5417 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/153/builds/24591 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/146/builds/2407 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/60/builds/21027 Here is the relevant piece of the build log for the reference
|
Relanded this PR successfully: e3c8e17 |
@JamesChesterman Thank you for fixing this! |
…vm#127083) Add generic DAG combine for ISD::PARTIAL_REDUCE_U/SMLA nodes. Transforms the DAG from: PARTIAL_REDUCE_MLA(Acc, MUL(EXT(MulOpLHS), EXT(MulOpRHS)), Splat(1)) to PARTIAL_REDUCE_MLA(Acc, MulOpLHS, MulOpRHS).
…_MLA (llvm#127083)" This reverts commit 2bef21f. Multiple builtbot failures have been reported: llvm#127083
…_MLA (llvm#127083)" This relands commit 7a06681.
Add generic DAG combine for ISD::PARTIAL_REDUCE_U/SMLA nodes. Transforms the DAG from:
PARTIAL_REDUCE_MLA(Acc, MUL(EXT(MulOpLHS), EXT(MulOpRHS)), Splat(1)) to
PARTIAL_REDUCE_MLA(Acc, MulOpLHS, MulOpRHS).