Skip to content

[SDAG] Add partial_reduce_sumla node #141267

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions llvm/include/llvm/CodeGen/ISDOpcodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1493,8 +1493,9 @@ enum NodeType {
VECREDUCE_UMIN,

// PARTIAL_REDUCE_[U|S]MLA(Accumulator, Input1, Input2)
// The partial reduction nodes sign or zero extend Input1 and Input2 to the
// element type of Accumulator before multiplying their results.
// The partial reduction nodes sign or zero extend Input1 and Input2
// (with the extension kind noted below) to the element type of
// Accumulator before multiplying their results.
// This result is concatenated to the Accumulator, and this is then reduced,
// using addition, to the result type.
// The output is only expected to either be given to another partial reduction
Expand All @@ -1506,8 +1507,9 @@ enum NodeType {
// multiple of the number of elements in the Accumulator / output type.
// Input1 and Input2 must have an element type which is the same as or smaller
// than the element type of the Accumulator and output.
PARTIAL_REDUCE_SMLA,
PARTIAL_REDUCE_UMLA,
PARTIAL_REDUCE_SMLA, // sext, sext
PARTIAL_REDUCE_UMLA, // zext, zext
PARTIAL_REDUCE_SUMLA, // sext, zext

// The `llvm.experimental.stackmap` intrinsic.
// Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
Expand Down
6 changes: 4 additions & 2 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -1661,7 +1661,8 @@ class LLVM_ABI TargetLoweringBase {
/// target has a custom expander for it.
LegalizeAction getPartialReduceMLAAction(unsigned Opc, EVT AccVT,
EVT InputVT) const {
assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA);
assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
Opc == ISD::PARTIAL_REDUCE_SUMLA);
PartialReduceActionTypes Key = {Opc, AccVT.getSimpleVT().SimpleTy,
InputVT.getSimpleVT().SimpleTy};
auto It = PartialReduceMLAActions.find(Key);
Expand Down Expand Up @@ -2759,7 +2760,8 @@ class LLVM_ABI TargetLoweringBase {
/// sequence, or the target has a custom expander for it.
void setPartialReduceMLAAction(unsigned Opc, MVT AccVT, MVT InputVT,
LegalizeAction Action) {
assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA);
assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
Opc == ISD::PARTIAL_REDUCE_SUMLA);
assert(AccVT.isValid() && InputVT.isValid() &&
"setPartialReduceMLAAction types aren't valid");
PartialReduceActionTypes Key = {Opc, AccVT.SimpleTy, InputVT.SimpleTy};
Expand Down
66 changes: 44 additions & 22 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1992,6 +1992,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
return visitPARTIAL_REDUCE_MLA(N);
case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
case ISD::LIFETIME_END: return visitLIFETIME_END(N);
Expand Down Expand Up @@ -12680,26 +12681,27 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue LHSExtOp = LHS->getOperand(0);
EVT LHSExtOpVT = LHSExtOp.getValueType();

bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
unsigned NewOpcode =
ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;

// Only perform these combines if the target supports folding
// the extends into the operation.
if (!TLI.isPartialReduceMLALegalOrCustom(
NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();

// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
// TODO: Make use of partial_reduce_sumla here
APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(LHSBits) != C) &&
(LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
return SDValue();

unsigned NewOpcode = LHSOpcode == ISD::SIGN_EXTEND
? ISD::PARTIAL_REDUCE_SMLA
: ISD::PARTIAL_REDUCE_UMLA;

// Only perform these combines if the target supports folding
// the extends into the operation.
if (!TLI.isPartialReduceMLALegalOrCustom(
NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();

return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
DAG.getConstant(CTrunc, DL, LHSExtOpVT));
}
Expand All @@ -12709,26 +12711,46 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
return SDValue();

SDValue RHSExtOp = RHS->getOperand(0);
if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
if (LHSExtOpVT != RHSExtOp.getValueType())
return SDValue();

unsigned NewOpc;
if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::SIGN_EXTEND)
NewOpc = ISD::PARTIAL_REDUCE_SMLA;
else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
NewOpc = ISD::PARTIAL_REDUCE_UMLA;
else if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate on what you mean by this TODO? I'm not sure I follow why we'd want to handle a zext as a sext in this case.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A zext nonneg is a zext for which the high bit is known to be zero, and thus is equivalent to a sext. We canonicalize such cases to zext nonneg. As such, handling zext nonneg would allow us to recognize more parial_reduce_smla cases which we'd currently miss. Note that for partial_reduce_sumla enabled targets, this might not matter since we'd just chose an alternate instruction.

std::swap(LHSExtOp, RHSExtOp);
} else
return SDValue();

// For a 2-stage extend the signedness of both of the extends must be the
// same. This is so the node can be folded into only a signed or unsigned
// node.
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
// For a 2-stage extend the signedness of both of the extends must match
// If the mul has the same type, there is no outer extend, and thus we
// can simply use the inner extends to pick the result node.
// TODO: extend to handle nonneg zext as sext
EVT AccElemVT = Acc.getValueType().getVectorElementType();
if (ExtIsSigned != NodeIsSigned &&
Op1.getValueType().getVectorElementType() != AccElemVT)
if (Op1.getValueType().getVectorElementType() != AccElemVT &&
NewOpc != N->getOpcode())
return SDValue();

return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
RHSExtOp);
// Only perform these combines if the target supports folding
// the extends into the operation.
if (!TLI.isPartialReduceMLALegalOrCustom(
NewOpc, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();

return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
}

// partial.reduce.umla(acc, zext(op), splat(1))
// -> partial.reduce.umla(acc, op, splat(trunc(1)))
// partial.reduce.smla(acc, sext(op), splat(1))
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
// partial.reduce.sumla(acc, sext(op), splat(1))
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
SDLoc DL(N);
SDValue Acc = N->getOperand(0);
Expand All @@ -12745,7 +12767,7 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
return SDValue();

bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
EVT AccElemVT = Acc.getValueType().getVectorElementType();
if (Op1IsSigned != NodeIsSigned &&
Op1.getValueType().getVectorElementType() != AccElemVT)
Expand Down
15 changes: 13 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {

case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
Res = PromoteIntRes_PARTIAL_REDUCE_MLA(N);
break;

Expand Down Expand Up @@ -2093,6 +2094,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
break;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
Res = PromoteIntOp_PARTIAL_REDUCE_MLA(N);
break;
}
Expand Down Expand Up @@ -2886,12 +2888,21 @@ SDValue DAGTypeLegalizer::PromoteIntOp_GET_ACTIVE_LANE_MASK(SDNode *N) {

SDValue DAGTypeLegalizer::PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N) {
SmallVector<SDValue, 1> NewOps(N->ops());
if (N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA) {
switch (N->getOpcode()) {
case ISD::PARTIAL_REDUCE_SMLA:
NewOps[1] = SExtPromotedInteger(N->getOperand(1));
NewOps[2] = SExtPromotedInteger(N->getOperand(2));
} else {
break;
case ISD::PARTIAL_REDUCE_UMLA:
NewOps[1] = ZExtPromotedInteger(N->getOperand(1));
NewOps[2] = ZExtPromotedInteger(N->getOperand(2));
break;
case ISD::PARTIAL_REDUCE_SUMLA:
NewOps[1] = SExtPromotedInteger(N->getOperand(1));
NewOps[2] = ZExtPromotedInteger(N->getOperand(2));
break;
default:
llvm_unreachable("unexpected opcode");
}
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
}
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
}
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
Action =
TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0),
Node->getOperand(1).getValueType());
Expand Down Expand Up @@ -1211,6 +1212,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
return;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
return;
case ISD::VECREDUCE_SEQ_FADD:
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1387,6 +1387,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
break;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
break;
case ISD::GET_ACTIVE_LANE_MASK:
Expand Down Expand Up @@ -3473,6 +3474,7 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
break;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
break;
}
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8041,7 +8041,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
break;
}
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA: {
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA: {
[[maybe_unused]] EVT AccVT = N1.getValueType();
[[maybe_unused]] EVT Input1VT = N2.getValueType();
[[maybe_unused]] EVT Input2VT = N3.getValueType();
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
return "partial_reduce_umla";
case ISD::PARTIAL_REDUCE_SMLA:
return "partial_reduce_smla";
case ISD::PARTIAL_REDUCE_SUMLA:
return "partial_reduce_sumla";

// Vector Predication
#define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...) \
Expand Down
14 changes: 9 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11893,13 +11893,17 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
EVT ExtMulOpVT =
EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(),
MulOpVT.getVectorElementCount());
unsigned ExtOpc = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
? ISD::SIGN_EXTEND
: ISD::ZERO_EXTEND;

unsigned ExtOpcLHS = N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA
? ISD::ZERO_EXTEND
: ISD::SIGN_EXTEND;
unsigned ExtOpcRHS = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
? ISD::SIGN_EXTEND
: ISD::ZERO_EXTEND;

if (ExtMulOpVT != MulOpVT) {
MulLHS = DAG.getNode(ExtOpc, DL, ExtMulOpVT, MulLHS);
MulRHS = DAG.getNode(ExtOpc, DL, ExtMulOpVT, MulRHS);
MulLHS = DAG.getNode(ExtOpcLHS, DL, ExtMulOpVT, MulLHS);
MulRHS = DAG.getNode(ExtOpcRHS, DL, ExtMulOpVT, MulRHS);
}
SDValue Input = MulLHS;
APInt ConstantOne;
Expand Down
20 changes: 17 additions & 3 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1574,7 +1574,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
// zve32x is broken for partial_reduce_umla, but let's not make it worse.
if (Subtarget.hasStdExtZvqdotq() && Subtarget.getELen() >= 64) {
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
ISD::PARTIAL_REDUCE_UMLA};
ISD::PARTIAL_REDUCE_UMLA,
ISD::PARTIAL_REDUCE_SUMLA};
setPartialReduceMLAAction(MLAOps, MVT::nxv1i32, MVT::nxv4i8, Custom);
setPartialReduceMLAAction(MLAOps, MVT::nxv2i32, MVT::nxv8i8, Custom);
setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv16i8, Custom);
Expand Down Expand Up @@ -8318,6 +8319,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return lowerADJUST_TRAMPOLINE(Op, DAG);
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
return lowerPARTIAL_REDUCE_MLA(Op, DAG);
}
}
Expand Down Expand Up @@ -8486,8 +8488,20 @@ SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
B = convertToScalableVector(ContainerVT, B, DAG, Subtarget);
}

bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
unsigned Opc;
switch (Op.getOpcode()) {
case ISD::PARTIAL_REDUCE_SMLA:
Opc = RISCVISD::VQDOT_VL;
break;
case ISD::PARTIAL_REDUCE_UMLA:
Opc = RISCVISD::VQDOTU_VL;
break;
case ISD::PARTIAL_REDUCE_SUMLA:
Opc = RISCVISD::VQDOTSU_VL;
break;
default:
llvm_unreachable("Unexpected opcode");
}
auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
SDValue Res = DAG.getNode(Opc, DL, ContainerVT, {A, B, Accum, Mask, VL});
if (VT.isFixedLengthVector())
Expand Down
Loading
Loading