Fold slice+copy_ into index_put_ #1901
Conversation
The t5small_torchscript_0227_transformers4.26.0.mlir: https://gist.github.com/AmosLewis/20895fa06e8b61ce83e9a2b9c2ee8ca6
(Resolved review comment on python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py)
I think it's failing because this implementation doesn't support negative values in the indices for slice. I'm going to make a separate PR with that so that this one doesn't get too big.
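For context, a hedged sketch (my own, not this PR's code) of what negative-index support involves: `aten.slice` accepts negative start/end values that wrap around the dimension size, exactly like Python slicing, so the lowering has to normalize them first.

```python
import torch

a = torch.arange(6).reshape(2, 3)
# a[:, :-1] lowers to aten.slice with end = -1, which normalizes to
# -1 + 3 = 2 for a dimension of size 3.
assert torch.equal(a[:, :-1], a[:, 0:2])

def normalize_index(idx: int, dim_size: int) -> int:
    """Map a possibly negative slice index into [0, dim_size]."""
    return idx + dim_size if idx < 0 else idx

assert normalize_index(-1, 3) == 2
```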
I see it failing on LTC. Should be a matter of adding it to that list.
Negative indices fix: #1917
Okay, this should be good to merge now @ramiro050
Ramiro, two thoughts:
Looking good! A few comments
(Resolved review comment on python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py)
t5_small_torchscript_test2.mlir
@AmosLewis @ramiro050 Since this is green on the tests, can we land this and move the 9223372036854775807 thing to a separate issue?
LGTM, just a small change request and a question
@ramiro050 @gpetters94 Here is a small Python script to reproduce this INTMAX bug. It looks like a make_fx error: https://gist.github.com/AmosLewis/b53dfecbe4918618031cac01c8a88fb9#file-test_intmax-py-L41
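For anyone reproducing this without the gist, a minimal sketch (assuming the bug matches the gist above): PyTorch uses INT64_MAX (9223372036854775807) as the default `end` for an open-ended slice, so a `make_fx` trace of `a[:, 1:]` contains an `aten.slice` op with end = 2**63 - 1.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(a):
    return a[:, 1:]  # open-ended slice: end defaults to INT64_MAX

graph = make_fx(f)(torch.zeros(2, 4))
print(graph.code)  # contains slice(..., 1, 1, 9223372036854775807)
```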
@gpetters94 @ramiro050 I think the implementation of the indices for index_put is wrong. Here is the Python code I just used to test.
I think this is a Python error.
This Python error is manually created from what the recompose pass generates. I found this bug because I was trying to lower Aten_IndexPutImpl_Op to tosa.scatter, and the indices don't make sense to me.
What is the original?
https://gist.github.com/AmosLewis/c6007c2154fedd51081faaee903a1b2c |
/*pin_memory=*/noneVal);

SmallVector<Value> indicesVector;
for (auto i = 0; i < dim - 1; i++)
This should be `i < dim`, not `i < dim - 1`.
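To see why the bound matters, here is a small Python illustration (mine, not the PR's C++): indexing along `dim` needs one None entry per preceding dimension, i.e. `dim` of them, not `dim - 1`.

```python
import torch

a = torch.zeros(2, 3, 4)
dim = 2
idx = torch.tensor([1, 3])

# `dim` leading Nones: indexes along dim 2, as intended.
indices_ok = [None] * dim + [idx]            # [None, None, idx]
r = torch.ops.aten.index_put(a, indices_ok, torch.ones(2))
assert r[:, :, [1, 3]].eq(1).all()

# `dim - 1` leading Nones would index along dim 1 instead.
indices_bad = [None] * (dim - 1) + [idx]     # [None, idx]
```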
I think the correct indices should be (torch.tensor([0, 0, 0]), torch.tensor([1, 2, 3])), the first for dim 0 and the second for dim 1, not (torch.tensor([1, 2, 3])). In Python code this is a[[0, 0, 0], [1, 2, 3]]. I don't think converting dim - 1 to dim would fix this issue.
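A quick eager-mode check (my own example) that the two forms agree: a leading None broadcasts over dim 0, which matches spelling out the dim-0 indices explicitly.

```python
import torch

a = torch.tensor([[0, 1, 2, 3]])
v = torch.tensor([4, 5, 6])

# None for dim 0, explicit indices only for dim 1
r1 = torch.ops.aten.index_put(a, (None, torch.tensor([1, 2, 3])), v)

# Explicit indices for both dims: a[[0, 0, 0], [1, 2, 3]]
r2 = torch.ops.aten.index_put(
    a, (torch.tensor([0, 0, 0]), torch.tensor([1, 2, 3])), v)

assert torch.equal(r1, r2)  # both give tensor([[0, 4, 5, 6]])
```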
After converting dim - 1 to dim, the indices look like this:
%6 = torch.prim.ListConstruct %none, %5 : (!torch.none, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
compared to the previous dim - 1 version:
%6 = torch.prim.ListConstruct %5 : (!torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
That's correct, right?
In [1]: import torch
In [2]: torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]),
...: (None, torch.tensor([1, 2, 3]),),
...: torch.tensor([[4, 5, 6]]))
Out[2]: tensor([[0, 4, 5, 6]])
That should work as expected.
Yes, it's correct. And as you can see in the previous comment, this is the slice, which is also equal to
This folds patterns of `slice` and then `copy_` into `index_put_`, since the initial pattern doesn't work with `MaximizeValueSemantics`.
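A minimal illustration in eager PyTorch (my own sketch, not the PR's test code) of the equivalence this fold relies on: an in-place copy into a slice writes the same elements as an `index_put_` along that dimension, and the latter avoids the aliasing view, which is friendlier to value-semantics passes.

```python
import torch

a1 = torch.tensor([[0, 1, 2, 3]])
a2 = a1.clone()
b = torch.tensor([[4, 5, 6]])

# slice + copy_: the pattern being matched
a1[:, 1:].copy_(b)

# the equivalent index_put_ it gets folded into
a2.index_put_((None, torch.tensor([1, 2, 3])), b)

assert torch.equal(a1, a2)  # both are tensor([[0, 4, 5, 6]])
```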