
[TOSA] Add aten.Index.Tensor support #1771

Merged: 1 commit merged into llvm:main from index_tensor on Jan 20, 2023

Conversation

AmosLewis (Collaborator) commented on Jan 3, 2023:

This error was found in nod-ai/SHARK-Studio#494, triggered by torch.ops.aten.index(input, (index,)):

  %530 = torch_c.from_builtin_tensor %24 : tensor<1xi64> -> !torch.vtensor<[1],si64>
  %531 = "tosa.slice"(%529) {size = [1, 1, 2], start = [0, 127, 0]} : (tensor<1x128x2xf32>) -> tensor<1x1x2xf32>
  %532 = "tosa.reshape"(%531) {new_shape = [1, 2]} : (tensor<1x1x2xf32>) -> tensor<1x2xf32>
  %533 = torch_c.from_builtin_tensor %532 : tensor<1x2xf32> -> !torch.vtensor<[1,2],f32>
  %534 = torch.prim.ListConstruct %530 : (!torch.vtensor<[1],si64>) -> !torch.list<vtensor>
  %535 = torch.aten.index.Tensor %533, %534 : !torch.vtensor<[1,2],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,2],f32>
  %536 = torch_c.to_builtin_tensor %535 : !torch.vtensor<[1,2],f32> -> tensor<1x2xf32>

The op can be reduced to this minimal test case:

func.func @torch.aten.index.Tensor(%arg0: !torch.vtensor<[1],si64>, %arg1: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,2],f32> {
  %0 = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[1],si64>) -> !torch.list<vtensor>
  %1 = torch.aten.index.Tensor %arg1, %0 : !torch.vtensor<[1,2],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,2],f32>
  return %1 : !torch.vtensor<[1,2],f32>
}

We can use tf.gather_nd (whose lowering to TOSA is already implemented) to express torch.aten.index. The trick is to expand/unsqueeze the last dimension of the index tensor. Python code illustrating the algorithm:

import torch

t = torch.tensor([[1, 2, 3, 4, 5],
                  [6, 7, 8, 9, 10],
                  [11, 12, 13, 14, 15],
                  [16, 17, 18, 19, 20]])  # shape 4x5
i = torch.tensor([[1, 2, 3], [3, 2, 1]])  # shape 2x3

# t[i] is exactly what torch.aten.index computes for a single index tensor:
# torch.ops.aten.index(t, (i,))
o = t[i]
# tensor([[[ 6,  7,  8,  9, 10],
#          [11, 12, 13, 14, 15],
#          [16, 17, 18, 19, 20]],
#         [[16, 17, 18, 19, 20],
#          [11, 12, 13, 14, 15],
#          [ 6,  7,  8,  9, 10]]])  # shape 2x3x5

########## the same function in TensorFlow
import tensorflow as tf

t = tf.constant([[1, 2, 3, 4, 5],
                 [6, 7, 8, 9, 10],
                 [11, 12, 13, 14, 15],
                 [16, 17, 18, 19, 20]])  # shape 4x5
i = tf.constant([[1, 2, 3], [3, 2, 1]])  # shape 2x3

i_expand = tf.expand_dims(i, axis=2)
# <tf.Tensor: shape=(2, 3, 1), dtype=int32, numpy=
#  array([[[1], [2], [3]], [[3], [2], [1]]], dtype=int32)>

io = tf.gather_nd(t, i_expand)
# <tf.Tensor: shape=(2, 3, 5), dtype=int32, numpy=
#  array([[[ 6,  7,  8,  9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20]],
#         [[16, 17, 18, 19, 20], [11, 12, 13, 14, 15], [ 6,  7,  8,  9, 10]]], dtype=int32)>

The final lowering should look like this:

// -----
// CHECK-LABEL:   func.func @torch.aten.index.Tensor(
// CHECK-SAME:                                       %[[VAL_0:.*]]: !torch.vtensor<[1],si64>,
// CHECK-SAME:                                       %[[VAL_1:.*]]: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,2],f32> {
// CHECK:           %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,2],f32> -> tensor<1x2xf32>
// CHECK:           %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_0]] : (!torch.vtensor<[1],si64>) -> !torch.list<vtensor>
// CHECK:           %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1],si64> -> tensor<1xi64>
// CHECK:           %[[VAL_5:.*]] = "tosa.cast"(%[[VAL_4]]) : (tensor<1xi64>) -> tensor<1xi32>
// CHECK:           %[[VAL_6:.*]] = "tosa.reshape"(%[[VAL_5]]) {new_shape = array<i64: 1, 1>} : (tensor<1xi32>) -> tensor<1x1xi32>
// CHECK:           %[[VAL_7:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = array<i64: 1, 1, 2>} : (tensor<1x2xf32>) -> tensor<1x1x2xf32>
// CHECK:           %[[VAL_8:.*]] = "tosa.reshape"(%[[VAL_6]]) {new_shape = array<i64: 1, 1>} : (tensor<1x1xi32>) -> tensor<1x1xi32>
// CHECK:           %[[VAL_9:.*]] = "tosa.const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK:           %[[VAL_10:.*]] = "tosa.mul"(%[[VAL_8]], %[[VAL_9]]) {shift = 0 : i32} : (tensor<1x1xi32>, tensor<1xi32>) -> tensor<1x1xi32>
// CHECK:           %[[VAL_11:.*]] = "tosa.reduce_sum"(%[[VAL_10]]) {axis = 1 : i64} : (tensor<1x1xi32>) -> tensor<1x1xi32>
// CHECK:           %[[VAL_12:.*]] = "tosa.reshape"(%[[VAL_11]]) {new_shape = array<i64: 1, 1>} : (tensor<1x1xi32>) -> tensor<1x1xi32>
// CHECK:           %[[VAL_13:.*]] = "tosa.gather"(%[[VAL_7]], %[[VAL_12]]) : (tensor<1x1x2xf32>, tensor<1x1xi32>) -> tensor<1x1x2xf32>
// CHECK:           %[[VAL_14:.*]] = "tosa.reshape"(%[[VAL_13]]) {new_shape = array<i64: 1, 2>} : (tensor<1x1x2xf32>) -> tensor<1x2xf32>
// CHECK:           %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<1x2xf32> -> !torch.vtensor<[1,2],f32>
// CHECK:           return %[[VAL_15]] : !torch.vtensor<[1,2],f32>
// CHECK:         }
func.func @torch.aten.index.Tensor(%arg0: !torch.vtensor<[1],si64>, %arg1: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,2],f32> {
  %0 = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[1],si64>) -> !torch.list<vtensor>
  %1 = torch.aten.index.Tensor %arg1, %0 : !torch.vtensor<[1,2],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,2],f32>
  return %1 : !torch.vtensor<[1,2],f32>
}
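To read the CHECK lines above: tosa.gather takes a 3-D values tensor<NxKxC> and 2-D indices tensor<NxW> and produces tensor<NxWxC>, which is why the 1x2 input is reshaped to 1x1x2 and the flattened index to 1x1, with a final reshape back to 1x2. A rough NumPy model of tosa.gather's semantics (assumed from the TOSA spec, for illustration only):

import numpy as np

def tosa_gather(values, indices):
    # values: (N, K, C), indices: (N, W) int32 -> output: (N, W, C)
    n_batch, _, channels = values.shape
    width = indices.shape[1]
    out = np.empty((n_batch, width, channels), dtype=values.dtype)
    for n in range(n_batch):
        out[n] = values[n][indices[n]]        # gather along K, per batch
    return out

vals = np.array([[[1.0, 2.0]]])               # tensor<1x1x2xf32>
idx = np.zeros((1, 1), dtype=np.int32)        # tensor<1x1xi32>
print(tosa_gather(vals, idx).shape)           # (1, 1, 2), reshaped to 1x2 at the end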

@AmosLewis AmosLewis self-assigned this Jan 3, 2023
@AmosLewis AmosLewis force-pushed the index_tensor branch 2 times, most recently from 6321ff5 to 4b57b24 on January 3, 2023 22:29
@AmosLewis AmosLewis marked this pull request as ready for review January 3, 2023 22:29
AmosLewis (Collaborator, Author) commented:

The e2e tests fail; working on a fix.

@AmosLewis AmosLewis marked this pull request as draft January 5, 2023 17:27
@AmosLewis AmosLewis force-pushed the index_tensor branch 2 times, most recently from fda98b3 to e8f8b7c on January 16, 2023 05:15
AmosLewis (Collaborator, Author) commented on Jan 16, 2023:

The e2e bug is fixed.
The new implementation is explained here:
https://gist.github.com/AmosLewis/c90c1148a96291db93408b3fa39f9ae2

@AmosLewis AmosLewis marked this pull request as ready for review January 16, 2023 05:24
@AmosLewis AmosLewis force-pushed the index_tensor branch 2 times, most recently from 40e46b4 to 6ec4143 on January 17, 2023 07:37
AmosLewis (Collaborator, Author) commented on Jan 17, 2023:

Added a check to fix the IndexTensorMultiInputContiguousCenter e2e test segfault with torch.ops.aten.index(x, (None, index1, index2, None)):

// TODO figure out why the index is empty for IndexTensorMultiInputContiguousCenter e2e test
  if (!index.getImpl())
    return rewriter.notifyMatchFailure(
        op, "Only list ranked tensor types index are supported");

@AmosLewis AmosLewis force-pushed the index_tensor branch 2 times, most recently from 90b3c34 to 0b89ed1 on January 17, 2023 08:22
@AmosLewis AmosLewis requested a review from ramiro050 January 18, 2023 20:16
@ramiro050 ramiro050 removed their request for review January 18, 2023 20:26
ramiro050 (Collaborator) commented:
Removing myself from reviewers. I will let @eric-k256 continue leading the review of this patch (happy to help if questions/issues arise)

@AmosLewis AmosLewis merged commit 2587b3f into llvm:main Jan 20, 2023
@AmosLewis AmosLewis deleted the index_tensor branch January 20, 2023 08:11
gpetters94 pushed a commit to gpetters94/mlir-npcomp that referenced this pull request May 10, 2023