
support aten._trilinear and improve einsum decomposition #3784

Open · wants to merge 17 commits into main

Conversation

@stbaione (Contributor) commented Oct 11, 2024

Tracking

Issue: TorchToLinalg Op Support

Description

Aten_TrilinearOp is an implementation of a "trilinear Einstein sum": essentially, an einsum across three tensors.

There are a few inputs:

Tensor Inputs

  • i1, i2, i3 - The three input tensors for the _trilinear op.

Expands

These inputs allow you to unsqueeze an input tensor at the specified dims as a pre-processing step to make the shapes compatible for the rest of the op:

  • expand1: List[int], expand2: List[int], expand3: List[int]
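For example (illustrative shapes, not from the patch), expand1 = [1] unsqueezes i1 at dim 1 so it can broadcast against the other inputs:

>>> import torch
>>> i1 = torch.rand(2, 3)
>>> i1.unsqueeze(1).shape  # expand1 = [1]
torch.Size([2, 1, 3])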

sumdim

  • sumdim: List[int] - After the element-wise multiplication, each value in sumdim denotes a dimension to collapse by summing over it
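Concretely (again with illustrative shapes), the product broadcasts to a common shape, and sumdim picks which dims of that product to collapse:

>>> a = torch.rand(2, 1, 3); b = torch.rand(2, 3, 1); c = torch.rand(1, 3, 3)
>>> (a * b * c).shape            # broadcasted element-wise product
torch.Size([2, 3, 3])
>>> (a * b * c).sum([2]).shape   # sumdim = [2] collapses dim 2
torch.Size([2, 3])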

unroll_dim

  • unroll_dim: int - In the PyTorch implementation, this specifies a dimension along which you could slice the input tensors, multiply and sum the slices, then concatenate the results into an output tensor. This complicates the implementation significantly without changing the result, so I opted against it. Along with that, a previously accepted path for solving this involved reusing AtenEinsumOp, which would also ignore this input. A sketch of the batched evaluation follows.
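Here is a hypothetical sketch of that batched evaluation (my own helper, not the PyTorch source; it assumes the three tensors already share a shape after the expands, that dims are non-negative, and that unroll_dim is not in sumdim):

import torch

def trilinear_unrolled(i1, i2, i3, sumdim, unroll_dim, block=1):
    # Slice all inputs along unroll_dim, multiply and sum each slice,
    # then concatenate the partial results. Produces the same output as
    # (i1 * i2 * i3).sum(sumdim), just with smaller intermediate tensors.
    outs = []
    for start in range(0, i1.size(unroll_dim), block):
        idx = [slice(None)] * i1.dim()
        idx[unroll_dim] = slice(start, start + block)
        idx = tuple(idx)
        outs.append((i1[idx] * i2[idx] * i3[idx]).sum(sumdim))
    # After summing, unroll_dim shifts left past any summed dims before it.
    out_dim = unroll_dim - sum(d < unroll_dim for d in sumdim)
    return torch.cat(outs, dim=out_dim)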

Solution

After trying a bunch of more complicated approaches, this op actually ended up being quite simple: see _trilinear

_trilinear = (i1.unsqueeze(expand1) * i2.unsqueeze(expand2) * i3.unsqueeze(expand3)).sum(sumdim)

Wish I'd seen this earlier, but whatcha gonna do 🙃
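Spelled out as a runnable sketch (the helper name is mine; following the one-liner above, each expand list is applied via repeated unsqueeze and the summed dims are dropped):

import torch

def trilinear_ref(i1, i2, i3, expand1, expand2, expand3, sumdim):
    # Unsqueeze each input at its expand dims (ascending order keeps
    # later indices valid), broadcast-multiply, then sum over sumdim.
    for d in sorted(expand1):
        i1 = i1.unsqueeze(d)
    for d in sorted(expand2):
        i2 = i2.unsqueeze(d)
    for d in sorted(expand3):
        i3 = i3.unsqueeze(d)
    return (i1 * i2 * i3).sum(sumdim)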

Not Reusing AtenEinsumOp

Frankly, I found multiple cases where valid inputs produced numerical mismatches with EinsumOp, even when running tests against EinsumOp directly. I think it has something to do with the singleton dimensions. This needs further investigation, but once I realized the simplified approach existed, it proved more reliable and much simpler.

Either way (credit to @zjgarvey), there are improvements to the einsum op here. When I was originally trying to use the op, intermediate tensors were being flattened properly, but their 0th dimension was then cast from a static dim to a dynamic dim because integers were not folding correctly in the MLIR. These improvements are worth keeping for future reusers of EinsumOp.

The zero'd out dim "bug"

For some reason, if you specify the same dimension in all three expand lists, e.g.

[expand1=[0], expand2=[0], expand3=[0]]
[expand1=[1], expand2=[1], expand3=[1]]

the _trilinear op would put 0 for that dimension in the output shape, unless the dimension was also included in sumdim. This goes against the behavior of torch.einsum:

>>> a, b, c = [torch.rand(1, 3, 3, 3) for i in range(3)] # Simulate expand at dim=0 for all input tensors
>>> torch.einsum('abcd,abcd,abcd->abcd', a, b, c).shape
torch.Size([1, 3, 3, 3])

It is also just mathematically incorrect. I considered "replacing" singleton dims with zeroed-out dims, but that seemed like carrying over a bug. Instead, I included a test for the case, verified that the singleton dimensions were handled the way torch.einsum handles them rather than the way torch._trilinear does, and xfailed it with a note as to why.
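A hedged repro of the mismatch (shapes mirror the einsum example above; the _trilinear output shape is as reported in this description, not re-verified here):

>>> a, b, c = [torch.rand(3, 3, 3) for i in range(3)]
>>> # einsum semantics after simulating expand at dim=0:
>>> torch.einsum('abcd,abcd,abcd->abcd',
...              a.unsqueeze(0), b.unsqueeze(0), c.unsqueeze(0)).shape
torch.Size([1, 3, 3, 3])
>>> # same dim in all three expand lists, not in sumdim:
>>> torch._trilinear(a, b, c, [0], [0], [0], []).shape
torch.Size([0, 3, 3, 3])  # the reported zero'd-out dim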

…ts a trilinear einstein sum.

WIP: it currently builds, but fails at lowering to linalg.
Lowers to torch backend, but unable to lower to linalg. There's a discrepancy between the way that _trilinear and the einsum op handle the second test case (in torch python); troubleshooting this discrepancy to figure out why/where the two ops differ.
Add more test cases.
Add PyTorch _trilinear "bug" to xfail set.
@stbaione stbaione changed the title Implementation of torch.ops.aten._trilinear [TorchToLinalg] Implementation of torch.ops.aten._trilinear Oct 17, 2024
@stbaione stbaione marked this pull request as ready for review October 18, 2024 00:34
@zjgarvey (Collaborator) left a comment

Glad you got something working Stephen! The major review points here:

  1. We need to fail the conversion in the cases we don't support. It's not sufficient to xfail the tests for unsupported cases, because a downstream user of the tool isn't going to run a big model and say "Oh, this random shape is messed up, it must be that one esoteric e2e test for this one op that I saw one day". We need to report a match failure so that the op actually doesn't get converted and model-support people don't spend days debugging a silently failing conversion.
  2. Related to 1: what does unrollDim do? It needs to be included, or, if unrollDim != 0, we also need to report an "unimplemented" match failure.
  3. Not really major, but glad we don't have to use einsum. I think the einsum changes are generally good, but it might be better to move them into a different patch. I'm fine leaving them in here, but the commit messaging will seem odd if anyone wants to trace back the history.

@stbaione
Copy link
Contributor Author

stbaione commented Oct 18, 2024

> Glad you got something working Stephen! The major review points here:
>
> 1. We need to fail the conversion in the cases we don't support. It's not sufficient to xfail the tests for unsupported cases, because a downstream user of the tool isn't going to run a big model and say "Oh, this random shape is messed up, it must be that one esoteric e2e test for this one op that I saw one day". We need to report a match failure so that the op actually doesn't get converted and model-support people don't spend days debugging a silently failing conversion.
> 2. Related to 1: what does unrollDim do? It needs to be included, or, if unrollDim != 0, we also need to report an "unimplemented" match failure.
> 3. Not really major, but glad we don't have to use einsum. I think the einsum changes are generally good, but it might be better to move them into a different patch. I'm fine leaving them in here, but the commit messaging will seem odd if anyone wants to trace back the history.

@zjgarvey

  1. Is this referring to: "What about the case where something lies in the triple intersection of the expand sets? I thought we were going to handle that case."? If so, I can go ahead and add a match failure. I explained the reasoning for leaving it as-is above, but it makes sense to be more explicit about that case so as not to cause downstream confusion.
  2. Copying my reply from above to make it easier to track responses:

     "The unrollDim allows slicing along a dimension across all tensors. Then you can do (slice1 * slice2 * slice3).sum(sumdim), and concat the result to the output tensor. It doesn't change the output of the function, and wasn't used in the EinsumOp approach, but its intent is to save space by processing the tensors in batches instead of the entire tensors at once."

     I can look into extending the solution to use this.
  3. I agree, it's way more straightforward doing it this way. Maybe I should edit my PR title to include the changes for the einsum op? After merging main, the changes actually seem to have fixed 5 tests that were xfailed in the `fx_importer_stablehlo` pipeline.

@zjgarvey (Collaborator) commented

> 1. Is this referring to: "What about the case where something lies in the triple intersection of the expand sets? I thought we were going to handle that case."? If so, I can go ahead and add a match failure. I explained the reasoning for leaving it as-is above, but it makes sense to be more explicit about that case so as not to cause downstream confusion.

If it is a genuine bug, let's at least file an issue in pytorch and emit a warning.
If it is not a genuine bug, then we need to mimic the pytorch behavior.

> 2. Copying my reply from above to make it easier to track responses:
>
>    "The unrollDim allows slicing along a dimension across all tensors. Then you can do (slice1 * slice2 * slice3).sum(sumdim), and concat the result to the output tensor. It doesn't change the output of the function, and wasn't used in the EinsumOp approach, but its intent is to save space by processing the tensors in batches instead of the entire tensors at once."
>
>    I can look into extending the solution to use this.

I don't think we will need to implement this, but there's no point in reporting a match failure if the unrollDim is non-constant. Just leave a comment somewhere in the conversion that the unrollDim does not change the result of the operation, so we do not use it.

> 3. I agree, it's way more straightforward doing it this way. Maybe I should edit my PR title to include the changes for the einsum op? After merging main, the changes actually seem to have fixed 5 tests that were xfailed in the `fx_importer_stablehlo` pipeline.

Ah, good that it resolves some failing tests. We should rename the title to something like "support aten._trilinear and improve einsum decomposition". No TorchToLinalg flag, since this is a decomposition that affects other backends too.

@stbaione stbaione changed the title [TorchToLinalg] Implementation of torch.ops.aten._trilinear support aten._trilinear and improve einsum decomposition Oct 18, 2024
…s not included in sumDim, add note in func description that `unrollDim` is unused
@stbaione (Contributor, Author) commented

> > 1. Is this referring to: "What about the case where something lies in the triple intersection of the expand sets? I thought we were going to handle that case."? If so, I can go ahead and add a match failure. I explained the reasoning for leaving it as-is above, but it makes sense to be more explicit about that case so as not to cause downstream confusion.
>
> If it is a genuine bug, let's at least file an issue in pytorch and emit a warning. If it is not a genuine bug, then we need to mimic the pytorch behavior.

PyTorch bug filed here and emitWarning added.

> > 2. Copying my reply from above to make it easier to track responses:
> >
> >    "The unrollDim allows slicing along a dimension across all tensors. Then you can do (slice1 * slice2 * slice3).sum(sumdim), and concat the result to the output tensor. It doesn't change the output of the function, and wasn't used in the EinsumOp approach, but its intent is to save space by processing the tensors in batches instead of the entire tensors at once."
> >
> >    I can look into extending the solution to use this.
>
> I don't think we will need to implement this, but there's no point in reporting a match failure if the unrollDim is non-constant. Just leave a comment somewhere in the conversion that the unrollDim does not change the result of the operation, so we do not use it.

A comment that unrollDim does not impact the output and is unused has been added to the function description.

> > 3. I agree, it's way more straightforward doing it this way. Maybe I should edit my PR title to include the changes for the einsum op? After merging main, the changes actually seem to have fixed 5 tests that were xfailed in the `fx_importer_stablehlo` pipeline.
>
> Ah, good that it resolves some failing tests. We should rename the title to something like "support aten._trilinear and improve einsum decomposition". No TorchToLinalg flag, since this is a decomposition that affects other backends too.

Updated title of PR to: "support aten._trilinear and improve einsum decomposition"
