Fold slice+copy_ into index_put_ #1901
Force-pushed from 7772040 to 8be8dee.
The t5small_torchscript_0227_transformers4.26.0.mlir: https://gist.github.com/AmosLewis/20895fa06e8b61ce83e9a2b9c2ee8ca6
I think it's failing because this implementation doesn't support negative values in the indices for slice. I'm going to make a separate PR with that so that this one doesn't get too big.
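As an illustration of what supporting negative slice indices involves, here is a minimal pure-Python sketch. The helper name `normalize_slice_index` is hypothetical and is not taken from the PR; it only mirrors the usual Python/PyTorch convention that a negative bound counts from the end of the dimension and that out-of-range bounds are clamped.

```python
def normalize_slice_index(index: int, dim_size: int) -> int:
    """Map a possibly negative slice bound onto the range [0, dim_size].

    Hypothetical helper for illustration only; not the PR's implementation.
    """
    if index < 0:
        index += dim_size  # e.g. -1 refers to the last element
    # Clamp out-of-range bounds, as Python slicing does for step 1.
    return max(0, min(index, dim_size))


print(normalize_slice_index(-1, 4))   # 3
print(normalize_slice_index(-10, 4))  # 0
print(normalize_slice_index(2, 4))    # 2
```

Normalizing the bounds first means the later index computation only ever sees non-negative values.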
Force-pushed from 28f796b to 73b7720.
I see it failing on LTC. Should be a matter of adding it to that list.
Force-pushed from c656a34 to db56192.
Negative indices fix: #1917
Okay, this should be good to merge now @ramiro050
Force-pushed from b2ae471 to acdfeb5.
Ramiro Two thoughts:
Force-pushed from fffe445 to 1217d15.
ramiro050 left a comment
Looking good! A few comments
t5_small_torchscript_test2.mlir
@AmosLewis @ramiro050 Since this is green on the tests, can we land this and move the 9223372036854775807 thing to a separate issue?
ramiro050 left a comment
LGTM, just a small change request and a question
@ramiro050 @gpetters94 Here is a small Python script to reproduce this INT_MAX bug. It looks like a make_fx error: https://gist.github.com/AmosLewis/b53dfecbe4918618031cac01c8a88fb9#file-test_intmax-py-L41
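For context on the 9223372036854775807 value: it equals 2**63 - 1, the largest signed 64-bit integer. The sketch below assumes (the thread does not confirm this) that the value shows up as a "slice to the end" bound, and shows how such a bound could be clamped to the dimension size before the indices are counted. The helper name `slice_length` is hypothetical.

```python
INT64_MAX = 2**63 - 1  # 9223372036854775807, the value discussed in the thread


def slice_length(start: int, end: int, step: int, dim_size: int) -> int:
    """Number of elements selected by start:end:step on a dimension of dim_size.

    Hypothetical helper; assumes non-negative start/end and a positive step.
    """
    end = min(end, dim_size)      # an INT64_MAX end behaves like "until the end"
    start = min(start, dim_size)
    return max(0, (end - start + step - 1) // step)  # ceiling division


data = [0, 1, 2, 3]
# Python slicing itself treats the huge bound as an open-ended slice.
assert data[1:INT64_MAX] == data[1:]
print(slice_length(1, INT64_MAX, 1, len(data)))  # 3
```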
@gpetters94 @ramiro050 I think the implementation of the indices for index_put is wrong. Here is the Python code I just used to test.
I think this is a Python error
I reproduced this Python error by hand from what the recompose pass generates. I found this bug while trying to lower Aten_IndexPutImpl_Op to TOSA scatter, and the indices do not make sense to me.
What is the original |
https://gist.github.com/AmosLewis/c6007c2154fedd51081faaee903a1b2c
    /*pin_memory=*/noneVal);

    SmallVector<Value> indicesVector;
    for (auto i = 0; i < dim - 1; i++)
this should be i < dim not i < dim - 1
I think the correct indices should be (torch.tensor([0, 0, 0]), torch.tensor([1, 2, 3])), the first for dim 0 and the second for dim 1, not (torch.tensor([1, 2, 3])). In Python code this is a[[0, 0, 0], [1, 2, 3]]. I don't think changing dim - 1 to dim would fix this issue.
After changing dim - 1 to dim, I get indices like this:
%6 = torch.prim.ListConstruct %none, %5 : (!torch.none, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
compared to the previous dim - 1 version:
%6 = torch.prim.ListConstruct %5 : (!torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
In [1]: import torch
In [2]: torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]),
   ...:                          (None, torch.tensor([1, 2, 3]),),
   ...:                          torch.tensor([[4, 5, 6]]))
Out[2]: tensor([[0, 4, 5, 6]])
That should work as expected.
And as you can see in the previous comment, this is the slice, which is also equal to
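The difference between looping to dim and to dim - 1 can be sketched in plain Python. Both helpers below are hypothetical stand-ins written for illustration (they are not the PR's C++ code, and `index_put_2d` only handles the single index pattern discussed in the thread): one None placeholder is emitted for each dimension before the sliced one, followed by the explicit indices for the sliced dimension.

```python
def build_index_list(dim: int, start: int, end: int, step: int):
    """Hypothetical helper: one None per dimension before `dim`, then the slice indices."""
    return [None] * dim + [list(range(start, end, step))]


def index_put_2d(rows, indices, values):
    """Pure-Python stand-in for index_put on a 2-D nested list.

    Supports only the pattern (None, [cols]): every row is kept and the
    listed columns are overwritten.
    """
    row_sel, col_sel = indices
    assert row_sel is None, "this sketch only handles a None entry for dim 0"
    out = [row[:] for row in rows]  # copy so the input is left untouched
    for r, row in enumerate(out):
        for c, v in zip(col_sel, values[r]):
            row[c] = v
    return out


indices = build_index_list(dim=1, start=1, end=4, step=1)
print(indices)  # [None, [1, 2, 3]]
print(index_put_2d([[0, 1, 2, 3]], indices, [[4, 5, 6]]))  # [[0, 4, 5, 6]]
```

With dim=1 the list has two entries, matching the ListConstruct with %none and %5 quoted earlier in the thread.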
This folds patterns of slice and then copy_ into index_put_, since the initial pattern doesn't work with MaximizeValueSemantics.
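The equivalence behind this fold can be checked on a 1-D list in plain Python. This is a sketch under simplifying assumptions (non-negative bounds, positive step, hypothetical function names), not the torch-mlir rewrite itself: writing into a slice gives the same result as writing to the explicit indices that the slice selects.

```python
def slice_then_copy(dest, start, end, step, src):
    """Slice assignment: the slice + copy_ pattern before the fold."""
    out = list(dest)
    out[start:end:step] = src
    return out


def index_put(dest, indices, src):
    """Explicit-index write: the index_put_ form after the fold."""
    out = list(dest)
    for i, v in zip(indices, src):
        out[i] = v
    return out


dest = [0, 1, 2, 3, 4, 5]
start, end, step = 1, 6, 2
indices = list(range(start, end, step))  # the positions the slice selects: [1, 3, 5]

folded = index_put(dest, indices, [7, 8, 9])
assert slice_then_copy(dest, start, end, step, [7, 8, 9]) == folded
print(folded)  # [0, 7, 2, 8, 4, 9]
```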