
[DAGCombiner] Add DAG combine for PARTIAL_REDUCE_MLA when no mul op #131326
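This adds a combine for PARTIAL_REDUCE_UMLA/SMLA nodes that accumulate a plain extend with no multiply: the extend is folded into the node by rebuilding the splat-of-1 operand in the narrower input type, letting targets such as AArch64 SVE select a single udot/sdot against an all-ones vector, as the updated tests below show.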

Merged
49 changes: 48 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -618,6 +618,8 @@ namespace {
SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
const TargetLowering &TLI);
SDValue foldPartialReduceMLAMulOp(SDNode *N);
SDValue foldPartialReduceAdd(SDNode *N);

SDValue CombineExtLoad(SDNode *N);
SDValue CombineZExtLogicopShiftLoad(SDNode *N);
@@ -12601,12 +12603,20 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
return SDValue();
}

SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
if (SDValue Res = foldPartialReduceMLAMulOp(N))
return Res;
if (SDValue Res = foldPartialReduceAdd(N))
return Res;
return SDValue();
}

// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
// -> partial_reduce_*mla(acc, a, b)
//
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
-SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
+SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDLoc DL(N);
auto *Context = DAG.getContext();
SDValue Acc = N->getOperand(0);
Expand Down Expand Up @@ -12672,6 +12682,43 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
RHSExtOp);
}

// partial.reduce.umla(acc, zext(op), splat(1))
// -> partial.reduce.umla(acc, op, splat(trunc(1)))
Collaborator: nit: maybe rename this to foldPartialReduceAdd?

// partial.reduce.smla(acc, sext(op), splat(1))
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
SDLoc DL(N);
SDValue Acc = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);

APInt ConstantOne;
if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
!ConstantOne.isOne())
return SDValue();

unsigned Op1Opcode = Op1.getOpcode();
if (!ISD::isExtOpcode(Op1Opcode))
return SDValue();

Collaborator: nit: remove newline.

SDValue UnextOp1 = Op1.getOperand(0);
EVT UnextOp1VT = UnextOp1.getValueType();
if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
return SDValue();
Collaborator: This can just be DAG.getConstant(1, DL, UnextOp1VT) (and better to just inline that in the use below).
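A sketch of the inlined form the reviewer is suggesting; the final revision of the function below does exactly this:

    // Build the splat-of-1 operand directly at its single use instead of
    // materializing it in a named temporary first.
    return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
                       DAG.getConstant(1, DL, UnextOp1VT));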


bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
Collaborator: nit: remove newline.

EVT AccElemVT = Acc.getValueType().getVectorElementType();
if (Op1IsSigned != NodeIsSigned &&
Op1.getValueType().getVectorElementType() != AccElemVT)
return SDValue();

Collaborator: I don't think you need to test the Op2 case here, because the type of Op1 must match that of Op2.

unsigned NewOpcode =
Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
DAG.getConstant(1, DL, UnextOp1VT));
}

SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
auto *SLD = cast<VPStridedLoadSDNode>(N);
EVT EltVT = SLD->getValueType(0).getVectorElementType();
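Concretely, for the udot_no_bin_op test below, the fold rewrites PARTIAL_REDUCE_UMLA(nxv4i32 acc, zext(nxv16i8 a) to nxv16i32, splat(i32 1)) into PARTIAL_REDUCE_UMLA(acc, a, splat(i8 1)), which SVE then selects as a single udot against an all-ones byte vector.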
48 changes: 8 additions & 40 deletions llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -516,16 +516,8 @@ define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16
;
; CHECK-NEWLOWERING-LABEL: udot_no_bin_op:
; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z1.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z1.s
-; CHECK-NEWLOWERING-NEXT: add z1.s, z4.s, z3.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z1.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT: mov z2.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: udot z0.s, z1.b, z2.b
; CHECK-NEWLOWERING-NEXT: ret
%a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
@@ -541,16 +533,8 @@ define <vscale x 4 x i32> @sdot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16
;
; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op:
; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z1.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z1.s
-; CHECK-NEWLOWERING-NEXT: add z1.s, z4.s, z3.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z1.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT: mov z2.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: sdot z0.s, z1.b, z2.b
; CHECK-NEWLOWERING-NEXT: ret
%a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
@@ -566,16 +550,8 @@ define <vscale x 2 x i64> @udot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale
;
; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_wide:
; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z1.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z1.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z4.d, z3.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z1.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-NEXT: mov z2.h, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: udot z0.d, z1.h, z2.h
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
@@ -592,16 +568,8 @@ define <vscale x 2 x i64> @sdot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale
;
; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_wide:
; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z1.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z1.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z4.d, z3.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z1.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-NEXT: mov z2.h, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: sdot z0.d, z1.h, z2.h
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>