-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[SDAG] Split the partial reduce legalize table by opcode [nfc] #141970
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
On it's own, this change should be non-functional. This is a preparatory change for llvm#141267 which adds a new form of PARTIAL_REDUCE_*MLA. As noted in the discussion on that review, AArch64 needs a different set of legal and custom types for the PARTIAL_REDUCE_SUMLA variant than the currently existing PARTIAL_REDUCE_UMLA/SMLA.
@llvm/pr-subscribers-backend-risc-v @llvm/pr-subscribers-backend-aarch64 Author: Philip Reames (preames) ChangesOn it's own, this change should be non-functional. This is a preparatory change for #141267 which adds a new form of PARTIAL_REDUCE_*MLA. As noted in the discussion on that review, AArch64 needs a different set of legal and custom types for the PARTIAL_REDUCE_SUMLA variant than the currently existing PARTIAL_REDUCE_UMLA/SMLA. Full diff: https://github.com/llvm/llvm-project/pull/141970.diff 5 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index b818f4768c2c3..9c453f51e129d 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1659,17 +1659,20 @@ class LLVM_ABI TargetLoweringBase {
/// InputVT should be treated. Either it's legal, needs to be promoted to a
/// larger size, needs to be expanded to some other code sequence, or the
/// target has a custom expander for it.
- LegalizeAction getPartialReduceMLAAction(EVT AccVT, EVT InputVT) const {
- PartialReduceActionTypes TypePair = {AccVT.getSimpleVT().SimpleTy,
- InputVT.getSimpleVT().SimpleTy};
- auto It = PartialReduceMLAActions.find(TypePair);
+ LegalizeAction getPartialReduceMLAAction(unsigned Opc, EVT AccVT,
+ EVT InputVT) const {
+ assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA);
+ PartialReduceActionTypes Key = {Opc, AccVT.getSimpleVT().SimpleTy,
+ InputVT.getSimpleVT().SimpleTy};
+ auto It = PartialReduceMLAActions.find(Key);
return It != PartialReduceMLAActions.end() ? It->second : Expand;
}
/// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
/// legal or custom for this target.
- bool isPartialReduceMLALegalOrCustom(EVT AccVT, EVT InputVT) const {
- LegalizeAction Action = getPartialReduceMLAAction(AccVT, InputVT);
+ bool isPartialReduceMLALegalOrCustom(unsigned Opc, EVT AccVT,
+ EVT InputVT) const {
+ LegalizeAction Action = getPartialReduceMLAAction(Opc, AccVT, InputVT);
return Action == Legal || Action == Custom;
}
@@ -2754,12 +2757,18 @@ class LLVM_ABI TargetLoweringBase {
/// type InputVT should be treated by the target. Either it's legal, needs to
/// be promoted to a larger size, needs to be expanded to some other code
/// sequence, or the target has a custom expander for it.
- void setPartialReduceMLAAction(MVT AccVT, MVT InputVT,
+ void setPartialReduceMLAAction(unsigned Opc, MVT AccVT, MVT InputVT,
LegalizeAction Action) {
+ assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA);
assert(AccVT.isValid() && InputVT.isValid() &&
"setPartialReduceMLAAction types aren't valid");
- PartialReduceActionTypes TypePair = {AccVT.SimpleTy, InputVT.SimpleTy};
- PartialReduceMLAActions[TypePair] = Action;
+ PartialReduceActionTypes Key = {Opc, AccVT.SimpleTy, InputVT.SimpleTy};
+ PartialReduceMLAActions[Key] = Action;
+ }
+ void setPartialReduceMLAAction(ArrayRef<unsigned> Opcodes, MVT AccVT,
+ MVT InputVT, LegalizeAction Action) {
+ for (unsigned Opc : Opcodes)
+ setPartialReduceMLAAction(Opc, AccVT, InputVT, Action);
}
/// If Opc/OrigVT is specified as being promoted, the promotion code defaults
@@ -3751,10 +3760,10 @@ class LLVM_ABI TargetLoweringBase {
uint32_t CondCodeActions[ISD::SETCC_INVALID][(MVT::VALUETYPE_SIZE + 7) / 8];
using PartialReduceActionTypes =
- std::pair<MVT::SimpleValueType, MVT::SimpleValueType>;
- /// For each result type and input type for the ISD::PARTIAL_REDUCE_U/SMLA
- /// nodes, keep a LegalizeAction which indicates how instruction selection
- /// should deal with this operation.
+ std::tuple<unsigned, MVT::SimpleValueType, MVT::SimpleValueType>;
+ /// For each partial reduce opcode, result type and input type combination,
+ /// keep a LegalizeAction which indicates how instruction selection should
+ /// deal with this operation.
DenseMap<PartialReduceActionTypes, LegalizeAction> PartialReduceMLAActions;
ValueTypeActionImpl ValueTypeActions;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index e05f85ea3bd8e..be2209a2f8faf 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12673,17 +12673,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue LHSExtOp = LHS->getOperand(0);
EVT LHSExtOpVT = LHSExtOp.getValueType();
+ bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
+ unsigned NewOpcode =
+ ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+
// Only perform these combines if the target supports folding
// the extends into the operation.
if (!TLI.isPartialReduceMLALegalOrCustom(
- TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
+ NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
- bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
- unsigned NewOpcode =
- ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
-
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
@@ -12737,14 +12737,6 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
if (!ISD::isExtOpcode(Op1Opcode))
return SDValue();
- SDValue UnextOp1 = Op1.getOperand(0);
- EVT UnextOp1VT = UnextOp1.getValueType();
- auto *Context = DAG.getContext();
- if (!TLI.isPartialReduceMLALegalOrCustom(
- TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
- TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
- return SDValue();
-
bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
EVT AccElemVT = Acc.getValueType().getVectorElementType();
@@ -12754,6 +12746,15 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
unsigned NewOpcode =
Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+
+ SDValue UnextOp1 = Op1.getOperand(0);
+ EVT UnextOp1VT = UnextOp1.getValueType();
+ auto *Context = DAG.getContext();
+ if (!TLI.isPartialReduceMLALegalOrCustom(
+ NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
+ TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
+ return SDValue();
+
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
DAG.getConstant(1, DL, UnextOp1VT));
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index affcd78ea61b0..910a40e5b5141 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -530,8 +530,9 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
}
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
- Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
- Node->getOperand(1).getValueType());
+ Action =
+ TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0),
+ Node->getOperand(1).getValueType());
break;
#define BEGIN_REGISTER_VP_SDNODE(VPID, LEGALPOS, ...) \
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a07afea963e20..f18d325148742 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1458,9 +1458,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FADD, VT, Custom);
if (EnablePartialReduceNodes && Subtarget->hasDotProd()) {
- setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
- setPartialReduceMLAAction(MVT::v2i32, MVT::v8i8, Legal);
- setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
+ static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
+ ISD::PARTIAL_REDUCE_UMLA};
+
+ setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Legal);
+ setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v8i8, Legal);
+ setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
}
} else /* !isNeonAvailable */ {
@@ -1881,16 +1884,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
if (EnablePartialReduceNodes && Subtarget->isSVEorStreamingSVEAvailable()) {
// Mark known legal pairs as 'Legal' (these will expand to UDOT or SDOT).
// Other pairs will default to 'Expand'.
- setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
- setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
+ static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
+ ISD::PARTIAL_REDUCE_UMLA};
+ setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv8i16, Legal);
+ setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv16i8, Legal);
- setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
+ setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv16i8, Custom);
// Wide add types
if (Subtarget->hasSVE2() || Subtarget->hasSME()) {
- setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i32, Legal);
- setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i16, Legal);
- setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i8, Legal);
+ setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv4i32, Legal);
+ setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv8i16, Legal);
+ setPartialReduceMLAAction(MLAOps, MVT::nxv8i16, MVT::nxv16i8, Legal);
}
}
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 43c81b97a0e05..567f4c5b47d30 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1573,11 +1573,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
// zve32x is broken for partial_reduce_umla, but let's not make it worse.
if (Subtarget.hasStdExtZvqdotq() && Subtarget.getELen() >= 64) {
- setPartialReduceMLAAction(MVT::nxv1i32, MVT::nxv4i8, Custom);
- setPartialReduceMLAAction(MVT::nxv2i32, MVT::nxv8i8, Custom);
- setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);
- setPartialReduceMLAAction(MVT::nxv8i32, MVT::nxv32i8, Custom);
- setPartialReduceMLAAction(MVT::nxv16i32, MVT::nxv64i8, Custom);
+ static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
+ ISD::PARTIAL_REDUCE_UMLA};
+ setPartialReduceMLAAction(MLAOps, MVT::nxv1i32, MVT::nxv4i8, Custom);
+ setPartialReduceMLAAction(MLAOps, MVT::nxv2i32, MVT::nxv8i8, Custom);
+ setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv16i8, Custom);
+ setPartialReduceMLAAction(MLAOps, MVT::nxv8i32, MVT::nxv32i8, Custom);
+ setPartialReduceMLAAction(MLAOps, MVT::nxv16i32, MVT::nxv64i8, Custom);
if (Subtarget.useRVVForFixedLengthVectors()) {
for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) {
@@ -1586,7 +1588,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
continue;
ElementCount EC = VT.getVectorElementCount();
MVT ArgVT = MVT::getVectorVT(MVT::i8, EC.multiplyCoefficientBy(4));
- setPartialReduceMLAAction(VT, ArgVT, Custom);
+ setPartialReduceMLAAction(MLAOps, VT, ArgVT, Custom);
}
}
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…141970) On it's own, this change should be non-functional. This is a preparatory change for llvm#141267 which adds a new form of PARTIAL_REDUCE_*MLA. As noted in the discussion on that review, AArch64 needs a different set of legal and custom types for the PARTIAL_REDUCE_SUMLA variant than the currently existing PARTIAL_REDUCE_UMLA/SMLA.
…141970) On it's own, this change should be non-functional. This is a preparatory change for llvm#141267 which adds a new form of PARTIAL_REDUCE_*MLA. As noted in the discussion on that review, AArch64 needs a different set of legal and custom types for the PARTIAL_REDUCE_SUMLA variant than the currently existing PARTIAL_REDUCE_UMLA/SMLA.
On it's own, this change should be non-functional. This is a preparatory change for #141267 which adds a new form of PARTIAL_REDUCE_*MLA. As noted in the discussion on that review, AArch64 needs a different set of legal and custom types for the PARTIAL_REDUCE_SUMLA variant than the currently existing PARTIAL_REDUCE_UMLA/SMLA.