[AArch64][SVE] Add lowering for PARTIAL_REDUCE_U/SMLA to USDOT #131327

Status: Draft — wants to merge 5 commits into base: main
16 changes: 13 additions & 3 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
Expand Up @@ -924,8 +924,19 @@ SDValue DAGTypeLegalizer::CreateStackStoreLoad(SDValue Op,
/// illegal ResNo in that case.
bool DAGTypeLegalizer::CustomLowerNode(SDNode *N, EVT VT, bool LegalizeResult) {
// See if the target wants to custom lower this node.
if (TLI.getOperationAction(N->getOpcode(), VT) != TargetLowering::Custom)
return false;
unsigned Opcode = N->getOpcode();
bool IsPRMLAOpcode =
Opcode == ISD::PARTIAL_REDUCE_UMLA || Opcode == ISD::PARTIAL_REDUCE_SMLA;

if (IsPRMLAOpcode) {
if (TLI.getPartialReduceMLAAction(N->getValueType(0),
N->getOperand(1).getValueType()) !=
TargetLowering::Custom)
return false;
} else {
if (TLI.getOperationAction(Opcode, VT) != TargetLowering::Custom)
return false;
}
Comment on lines +927 to +939

A Member commented:
Can't LegalizeVectorOps handle this? getPartialReduceMLAAction() is already hooked up there and should be able to call into the custom lowering?

@MacDue (Member) commented on Apr 23, 2025:
Is this done to bypass type legalization for the usdot_8to64 case? Could we handle that instead by adding a combine that reduces accumulators of <vscale x 4 x i64> to <vscale x 4 x i32> followed by a extend (for i8 inputs)?

A Collaborator commented:
@NickGuy-Arm I suspect you did this to work around type legalisation? At the point of doing Custom lowering, all the types must be legal. If the extends would be the same, then as @MacDue says it would be handled in LegalizeVectorOps. It's just that the operands (to be sign/zero-extended) have not been folded into the operation yet, because UMLA/SMLA doesn't support mixed extends, hence why the types can't be legalised the normal way.

The way to handle this case is to either:
(1) Implement this mapping to an AArch64ISD node with an AArch64 DAG combine that runs before type legalisation.
or:
(2) Create a separate PARTIAL_REDUCE_USMLA node, which would go through the regular flow of type legalisation.

The downside of (1) is that we don't get any type-legalisation, so any unsupported types would need to be handled in that particular DAG combine basically requiring it to do type-legalisation. I think (2) can piggy-back on most of the type legalisation added for UMLA/SMLA, with some small changes.

@NickGuy-Arm (Contributor, Author) commented on May 8, 2025:
The type coming in via VT is the type of the operand of the partial_reduce_umla node, which is an extend, so it effectively hides the actual operand type at this stage. We need to use the pre-extended type to figure out whether USDOT is valid to emit, and the type legalization step obscures this type by splitting across multiple sets of partial_reduce_umla and extract_subvector nodes, meaning we'd have to check significantly more nodes/paths to verify the validity.
I don't think the pre-legalization DAG combine would work for the reasons you pointed out, but in trying to implement the separate node, I encountered the exact same issues as we hit without the above call to getPartialReduceMLAAction.
I've added an operation action for ISD::PARTIAL_REDUCE_UMLA with nxv16i32, which is the post-extended type of nxv16i8, and we can have the existing validation within LowerPARTIAL_REDUCE_MLAToUSDOT decide whether it can actually be lowered to USDOT (falling back to unpacks and mlas if not).

@NickGuy-Arm (Contributor, Author) commented:
I've reimplemented this check, as I believe it is the simplest solution to this problem. For USDOT lowering to function, it needs to happen pre-legalization because it deals with illegal intermediate types (which are then flattened out by replacing the nodes with the USDOT ISD node).
As the partialReduceMLA LegalizeActions are handled differently from the standard operation actions, we need to check the relevant action to take.

This check is simply the required plumbing to have the legalizer respect when a target says that it has custom lowering for a given partial reduction. If we try to pack the information into the operation actions, we lose the ability to filter based on what the partial reduction is reducing from. And if we move the check to post-legalization, we lose direct access to the pre-extend type, as the nodes required to legalize the type obscure it through multiple extends or AArch64ISD unpack nodes.

The Collaborator replied:
> I don't think the pre-legalization DAG combine would work for the reasons you pointed out

What reasons were you referring to here? I would expect this pre-type-legalization DAG combine to recognise the pattern partial.reduce.add(a, mul(sext(b), zext(c)), splat(1)) -> AArch64ISD::sudot(a, b, c). At this point, there shouldn't be any uunpklo/hi instructions yet.

> but in trying to implement the separate node, I encountered the exact same issues as we hit without the above call to getPartialReduceMLAAction.

Are you talking about option (2), create a new ISD::PARTIAL_REDUCE_USMLA node? If so, can you elaborate on the issues you encountered? (I'd expect it to function roughly the same as the PARTIAL_REDUCE_UMLA node for example)


SmallVector<SDValue, 8> Results;
if (LegalizeResult)
Expand All @@ -946,7 +957,6 @@ bool DAGTypeLegalizer::CustomLowerNode(SDNode *N, EVT VT, bool LegalizeResult) {
return true;
}


/// Widen the node's results with custom code provided by the target and return
/// "true", or do nothing and return "false".
bool DAGTypeLegalizer::CustomWidenLowerNode(SDNode *N, EVT VT) {
Expand Down
102 changes: 94 additions & 8 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1869,7 +1869,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);

// 8to64
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);

// USDOT
if (Subtarget->hasMatMulInt8())
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i32, Custom);
}

// Handle operations that are only available in non-streaming SVE mode.
Expand Down Expand Up @@ -27533,6 +27538,10 @@ void AArch64TargetLowering::ReplaceNodeResults(
if (SDValue Res = LowerVECTOR_COMPRESS(SDValue(N, 0), DAG))
Results.push_back(Res);
return;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
Results.push_back(LowerPARTIAL_REDUCE_MLA(SDValue(N, 0), DAG));
return;
case ISD::ADD:
case ISD::FADD:
ReplaceAddWithADDP(N, Results, DAG, Subtarget);
Expand Down Expand Up @@ -29481,21 +29490,24 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
return Scatter;
}

/// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type pairing
/// of nxv2i64/nxv16i8, we cannot directly lower it to a (u|s)dot. We can
/// however still make use of the dot product instruction by instead
/// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64.
SDValue
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
if (SDValue UsdotNode = LowerPARTIAL_REDUCE_MLAToUSDOT(Op, DAG))
return UsdotNode;

SDValue Acc = Op.getOperand(0);
SDValue LHS = Op.getOperand(1);
SDValue RHS = Op.getOperand(2);
EVT ResultVT = Op.getValueType();
assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
/// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type
/// pairing of nxv2i64/nxv16i8, we cannot directly lower it to a (u|s)dot. We
/// can however still make use of the dot product instruction by instead
/// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64.
if (ResultVT != MVT::nxv2i64 || LHS.getValueType() != MVT::nxv16i8)
return SDValue();

SDLoc DL(Op);
SDValue Acc = Op.getOperand(0);
SDValue RHS = Op.getOperand(2);
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);

Expand All @@ -29515,6 +29527,80 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
}

// partial.reduce.umla(acc, mul(zext(mulOpLHS), sext(mulOpRHS)), splat(1))
// -> USDOT(acc, mulOpLHS, mulOpRHS)
// partial.reduce.smla(acc, mul(sext(mulOpLHS), zext(mulOpRHS)), splat(1))
// -> USDOT(acc, mulOpRHS, mulOpLHS)
SDValue
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLAToUSDOT(SDValue Op,
SelectionDAG &DAG) const {
bool Scalable = Op.getValueType().isScalableVector();
auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
if (Scalable && !Subtarget.isSVEorStreamingSVEAvailable())
return SDValue();
if (!Scalable && (!Subtarget.isNeonAvailable() || !Subtarget.hasDotProd()))
return SDValue();
if (!Subtarget.hasMatMulInt8())
return SDValue();
SDLoc DL(Op);

if (Op.getOperand(1).getOpcode() != ISD::MUL)
return SDValue();

SDValue Acc = Op.getOperand(0);
SDValue Mul = Op.getOperand(1);

APInt ConstantOne;
if (!ISD::isConstantSplatVector(Op.getOperand(2).getNode(), ConstantOne) ||
!ConstantOne.isOne())
return SDValue();

SDValue ExtMulOpLHS = Mul.getOperand(0);
SDValue ExtMulOpRHS = Mul.getOperand(1);
unsigned ExtMulOpLHSOpcode = ExtMulOpLHS.getOpcode();
unsigned ExtMulOpRHSOpcode = ExtMulOpRHS.getOpcode();
if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
!ISD::isExtOpcode(ExtMulOpRHSOpcode))
return SDValue();

SDValue MulOpLHS = ExtMulOpLHS.getOperand(0);
SDValue MulOpRHS = ExtMulOpRHS.getOperand(0);
EVT MulOpLHSVT = MulOpLHS.getValueType();
if (MulOpLHSVT != MulOpRHS.getValueType())
return SDValue();

bool LHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
bool RHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
if (LHSIsSigned == RHSIsSigned)
return SDValue();

EVT AccVT = Acc.getValueType();
// There is no nxv2i64 version of usdot
if (Scalable && AccVT != MVT::nxv4i32 && AccVT != MVT::nxv4i64)
return SDValue();

// USDOT expects the signed operand to be last
if (!RHSIsSigned)
std::swap(MulOpLHS, MulOpRHS);

unsigned Opcode = AArch64ISD::USDOT;
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
// product followed by a zero / sign extension
// Don't want this to be split because there is no nxv2i64 version of usdot
if ((AccVT == MVT::nxv4i64 && MulOpLHSVT == MVT::nxv16i8) ||
(AccVT == MVT::v4i64 && MulOpLHSVT == MVT::v16i8)) {
EVT AccVTI32 = AccVT.isScalableVector() ? MVT::nxv4i32 : MVT::v4i32;

SDValue DotI32 =
DAG.getNode(Opcode, DL, AccVTI32, DAG.getConstant(0, DL, AccVTI32),
MulOpLHS, MulOpRHS);
SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, AccVT);
return DAG.getNode(ISD::ADD, DL, AccVT, Acc, Extended);
}

return DAG.getNode(Opcode, DL, AccVT, Acc, MulOpLHS, MulOpRHS);
}

SDValue
AArch64TargetLowering::LowerFixedLengthFPToIntToSVE(SDValue Op,
SelectionDAG &DAG) const {
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,7 @@ class AArch64TargetLowering : public TargetLowering {
SDValue LowerVECTOR_INTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVECTOR_HISTOGRAM(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerPARTIAL_REDUCE_MLAToUSDOT(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerDIV(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const;
Expand Down
38 changes: 3 additions & 35 deletions llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
; RUN: llc -mtriple=aarch64 -mattr=+sve,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE
; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE2
; RUN: llc -mtriple=aarch64 -mattr=+sme -force-streaming -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SME
; RUN: llc -mtriple=aarch64 -mattr=+sve,+sme,+i8mm -force-streaming -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SME

define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; CHECK-LABEL: udot:
Expand Down Expand Up @@ -106,23 +106,7 @@ define <vscale x 4 x i32> @usdot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
;
; CHECK-NEWLOWERING-LABEL: usdot:
; CHECK-NEWLOWERING: // %bb.0: // %entry
; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z1.b
; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z2.b
; CHECK-NEWLOWERING-NEXT: ptrue p0.s
; CHECK-NEWLOWERING-NEXT: uunpkhi z1.h, z1.b
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
; CHECK-NEWLOWERING-NEXT: uunpklo z5.s, z3.h
; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
; CHECK-NEWLOWERING-NEXT: uunpklo z5.s, z1.h
; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z2.h
; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z3.s, z4.s
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
; CHECK-NEWLOWERING-NEXT: usdot z0.s, z1.b, z2.b
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
Expand Down Expand Up @@ -161,23 +145,7 @@ define <vscale x 4 x i32> @sudot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
;
; CHECK-NEWLOWERING-LABEL: sudot:
; CHECK-NEWLOWERING: // %bb.0: // %entry
; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z1.b
; CHECK-NEWLOWERING-NEXT: uunpklo z4.h, z2.b
; CHECK-NEWLOWERING-NEXT: ptrue p0.s
; CHECK-NEWLOWERING-NEXT: sunpkhi z1.h, z1.b
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z2.b
; CHECK-NEWLOWERING-NEXT: sunpklo z5.s, z3.h
; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
; CHECK-NEWLOWERING-NEXT: sunpklo z5.s, z1.h
; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z2.h
; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z3.s, z4.s
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
; CHECK-NEWLOWERING-NEXT: usdot z0.s, z2.b, z1.b
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
Expand Down