-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[DAGCombiner] Add DAG combine for PARTIAL_REDUCE_MLA when no mul op #131326
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -618,6 +618,8 @@ namespace { | |
SDValue CombineConsecutiveLoads(SDNode *N, EVT VT); | ||
SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG, | ||
const TargetLowering &TLI); | ||
SDValue foldPartialReduceMLAMulOp(SDNode *N); | ||
SDValue foldPartialReduceAdd(SDNode *N); | ||
|
||
SDValue CombineExtLoad(SDNode *N); | ||
SDValue CombineZExtLogicopShiftLoad(SDNode *N); | ||
|
@@ -12601,12 +12603,20 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) { | |
return SDValue(); | ||
} | ||
|
||
// Entry point for combining a PARTIAL_REDUCE_*MLA node: tries | ||
// foldPartialReduceMLAMulOp first, then foldPartialReduceAdd, and returns the | ||
// first result that tests true; otherwise returns an empty SDValue. | ||
SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) { | ||
if (SDValue Res = foldPartialReduceMLAMulOp(N)) | ||
return Res; | ||
if (SDValue Res = foldPartialReduceAdd(N)) | ||
return Res; | ||
return SDValue(); | ||
} | ||
|
||
// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1)) | ||
// -> partial_reduce_*mla(acc, a, b) | ||
// | ||
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1)) | ||
// -> partial_reduce_*mla(acc, x, C) | ||
SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) { | ||
SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { | ||
SDLoc DL(N); | ||
auto *Context = DAG.getContext(); | ||
SDValue Acc = N->getOperand(0); | ||
|
@@ -12672,6 +12682,43 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) { | |
RHSExtOp); | ||
} | ||
|
||
// partial.reduce.umla(acc, zext(op), splat(1)) | ||
// -> partial.reduce.umla(acc, op, splat(trunc(1))) | ||
// partial.reduce.smla(acc, sext(op), splat(1)) | ||
// -> partial.reduce.smla(acc, op, splat(trunc(1))) | ||
// Folds an extend feeding operand 1 into the partial-reduction node itself when | ||
// operand 2 is a constant splat of one. On success the rebuilt node keeps the | ||
// accumulator, takes the un-extended input as operand 1 and a constant 1 of the | ||
// un-extended type as operand 2; its opcode is PARTIAL_REDUCE_SMLA when the | ||
// extend is SIGN_EXTEND and PARTIAL_REDUCE_UMLA otherwise. Returns an empty | ||
// SDValue when any of the checks below fails. | ||
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) { | ||
SDLoc DL(N); | ||
SDValue Acc = N->getOperand(0); | ||
SDValue Op1 = N->getOperand(1); | ||
SDValue Op2 = N->getOperand(2); | ||
|
||
// The fold only applies when operand 2 is a constant splat equal to one. | ||
APInt ConstantOne; | ||
if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) || | ||
!ConstantOne.isOne()) | ||
return SDValue(); | ||
|
||
// Operand 1 must be an extend; its input is what the new node will consume. | ||
unsigned Op1Opcode = Op1.getOpcode(); | ||
if (!ISD::isExtOpcode(Op1Opcode)) | ||
return SDValue(); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: remove newline. |
||
// Ask the target (isPartialReduceMLALegalOrCustom) about the node result type | ||
// paired with the un-extended input type, and bail out if it declines. | ||
SDValue UnextOp1 = Op1.getOperand(0); | ||
EVT UnextOp1VT = UnextOp1.getValueType(); | ||
if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT)) | ||
return SDValue(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can just be |
||
|
||
// A signedness mismatch between the extend and the node is accepted only when | ||
// the extended element type equals the accumulator element type. | ||
bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND; | ||
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: remove newline. |
||
EVT AccElemVT = Acc.getValueType().getVectorElementType(); | ||
if (Op1IsSigned != NodeIsSigned && | ||
Op1.getValueType().getVectorElementType() != AccElemVT) | ||
return SDValue(); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think you need to test the |
||
unsigned NewOpcode = | ||
Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA; | ||
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1, | ||
DAG.getConstant(1, DL, UnextOp1VT)); | ||
} | ||
|
||
SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) { | ||
auto *SLD = cast<VPStridedLoadSDNode>(N); | ||
EVT EltVT = SLD->getValueType(0).getVectorElementType(); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: maybe rename this to `foldPartialReduceAdd`?