-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[AArch64][SVE] Add lowering for PARTIAL_REDUCE_U/SMLA to USDOT #131327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
NickGuy-Arm
wants to merge
5
commits into
llvm:main
Choose a base branch
from
NickGuy-Arm:JamesChesterman/legal-partial-reduction/4
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
a20bced
[AArch64][SVE] Add lowering for PARTIAL_REDUCE_U/SMLA to USDOT
NickGuy-Arm 7a62406
Add calls to setPartialReduceMLAAction.
NickGuy-Arm 7868964
Rebase and update tests
NickGuy-Arm d40773d
Adjust how usdot cases are lowered
NickGuy-Arm 22636ac
Adjust how usdot cases are lowered
NickGuy-Arm File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't
LegalizeVectorOps
handle this?getPartialReduceMLAAction()
is already hooked up there and should be able to call into the custom lowering?Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this done to bypass type legalization for the
usdot_8to64
case? Could we handle that instead by adding a combine that reduces accumulators of<vscale x 4 x i64>
to<vscale x 4 x i32>
followed by a extend (for i8 inputs)?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@NickGuy-Arm I suspect you did this to work around type legalisation? At the point of doing Custom lowering, all the types must be legal. If the extends would be the same, then as @MacDue says it would be handled in LegalizeVectorOps. It's just that the operands (to be sign/zero-extended) have not been folded into the operation yet, because UMLA/SMLA doesn't support mixed extends, hence why the types can't be legalised the normal way.
The way to handle this case is to either:
(1) Implement this mapping to an AArch64ISD node with an AArch64 DAG combine that runs before type legalisation.
or:
(2)Create a separate
PARTIAL_REDUCE_USMLA
node, which would go through the regular flow of type legalisation.The downside of (1) is that we don't get any type-legalisation, so any unsupported types would need to be handled in that particular DAG combine basically requiring it to do type-legalisation. I think (2) can piggy-back on most of the type legalisation added for UMLA/SMLA, with some small changes.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type coming in via
VT
is the type of the operand of thepartial_reduce_umla
node, which is an extend, so it effectively hides the actual operand type at this stage. We need to use the pre-extended type to figure out whetherUSDOT
is valid to emit, and the type legalization step obscures this type by splitting across multiple sets ofpartial_reduce_umla
andextract_subvector
nodes, meaning we'd have to check significantly more nodes/paths to verify the validity.I don't think the pre-legalization DAG combine would work for the reasons you pointed out, but in trying to implement the separate node, I encountered the exact same issues as we hit without the above call to
getPartialReduceMLAAction
.I've added an operation action for
ISD::PARTIAL_REDUCE_UMLA
withnxv16i32
, which is the post-extended type ofnxv16i8
, and we can have the existing validation withinLowerPARTIAL_REDUCE_MLAToUSDOT
decide whether it can actually be lowered to USDOT (falling back to unpacks andmla
s if not).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've reimplemented this check, as I believe it is the simplest solution to this problem. For
USDOT
lowering to function, it needs to happen pre-legalization because it deals with illegal intermediate types (which are then flattened out by replacing the nodes with theUSDOT
ISD node).As the partialReduceMLA LegalizeActions are handled differently from the standard operation actions, we need to check the relevant action to take.
This check is simply the required plumbing to have the legalizer respect when a target says that it has custom lowering for a given partial reduction. If we try to pack the information into the operation actions, we lose the ability to filter based on what the partial reduction is reducing from. And trying to move the check to post-legalization we lose direct access to the pre-extend type, as the nodes required to legalize the type obscure it through multiple extends or AArch64ISD unpack nodes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What reasons were you referring to here? I would expect this pre-type-legalization DAG combine to recognise the pattern
partial.reduce.add(a, mul(sext(b), zext(c)), splat(1)) -> AArch64ISD::sudot(a, b, c)
. At this point, there shouldn't be any uunpklo/hi instructions yet.Are you talking about option (2), create a new ISD::PARTIAL_REDUCE_USMLA node? If so, can you elaborate on the issues you encountered? (I'd expect it to function roughly the same as the PARTIAL_REDUCE_UMLA node for example)