Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add e2e test for onnx.ScatterElements / torch.scatter.reduce #363

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

AmosLewis
Copy link
Collaborator

@AmosLewis AmosLewis commented Oct 4, 2024

test torch.scatter.reduce linalg lowering llvm/torch-mlir#3754

torch.scatter.reduce step by step example:

src = [1, 2, 3, 4, 5, 6]
index = [0, 1, 0, 1, 2, 1]
self = [1, 2, 3, 4]
Step 0:
self[index[0]] += src[0]
self[0] += 1  = 1+1 = 2
1+1 = 2
self = [2, 2, 3, 4])

Step 1:
self[index[1]] += src[1]
self[1] += 2  = 2+2 = 4
self = [2, 4, 3, 4])

Step 2:
self[index[2]] += src[2]
self[0] += 3  = 2+3 = 5
self = [5, 4, 3, 4])

Step 3:
self[index[3]] += src[3]
self[1] += 4  = 4+4 = 8
self = [5, 8, 3, 4])

Step 4:
self[index[4]] += src[4]
self[2] += 5  = 3+5 = 8
self = [5, 8, 8, 4])

Step 5:
self[index[5]] += src[5]
self[1] += 6  = 8+6 = 14
self = [5, 14, 8, 4])

@AmosLewis
Copy link
Collaborator Author

AmosLewis commented Oct 4, 2024

python -m torch_mlir.tools.import_onnx --opset-version=21 model.onnx -o ScatterElements.default.torch-onnx.mlir ScatterElements.default.torch-onnx.mlir

module {
  func.func @scatter_graph(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<[0, 1, 0, 1, 2, 1]> : tensor<6xsi64>} : () -> !torch.vtensor<[6],si64> 
    %1 = torch.operator "onnx.ScatterElements"(%arg0, %0, %arg1) {torch.onnx.axis = 0 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[4],f32>, !torch.vtensor<[6],si64>, !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> 
    return %1 : !torch.vtensor<[4],f32>
  }
}

torch-mlir-opt -pass-pipeline='builtin.module(func.func(convert-torch-onnx-to-torch),torch-lower-to-backend-contract,func.func(cse,canonicalize))' ScatterElements.default.torch-onnx.mlir > ScatterElements.default.onnx.torch.mlir ScatterElements.default.onnx.torch.mlir

module {
  func.func @scatter_graph(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %true = torch.constant.bool true
    %str = torch.constant.str "sum"
    %0 = torch.vtensor.literal(dense<[0, 1, 0, 1, 2, 1]> : tensor<6xsi64>) : !torch.vtensor<[6],si64>
    %int0 = torch.constant.int 0
    %1 = torch.aten.scatter_reduce.two %arg0, %int0, %0, %arg1, %str, %true : !torch.vtensor<[4],f32>, !torch.int, !torch.vtensor<[6],si64>, !torch.vtensor<[6],f32>, !torch.str, !torch.bool -> !torch.vtensor<[4],f32>
    return %1 : !torch.vtensor<[4],f32>
  }
}

torch-mlir-opt --convert-torch-onnx-to-torch --torch-decompose-complex-ops --cse --canonicalize --convert-torch-to-linalg ScatterElements.default.onnx.torch.mlir > linalg.mlir linalg.mlir

#map = affine_map<(d0) -> (d0, 0)>
#map1 = affine_map<(d0) -> (d0)>
module {
  func.func @scatter_graph(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %0 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[6],f32> -> tensor<6xf32>
    %1 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[4],f32> -> tensor<4xf32>
    %true = torch.constant.bool true
    %str = torch.constant.str "sum"
    %2 = torch.vtensor.literal(dense<[0, 1, 0, 1, 2, 1]> : tensor<6xsi64>) : !torch.vtensor<[6],si64>
    %3 = torch_c.to_builtin_tensor %2 : !torch.vtensor<[6],si64> -> tensor<6xi64>
    %int0 = torch.constant.int 0
    %c0 = arith.constant 0 : index
    %c6 = arith.constant 6 : index
    %c1 = arith.constant 1 : index
    %4 = arith.muli %c1, %c6 : index
    %5 = arith.index_cast %4 : index to i64
    %6 = arith.index_cast %5 : i64 to index
    %c0_0 = arith.constant 0 : index
    %c6_1 = arith.constant 6 : index
    %c1_2 = arith.constant 1 : index
    %7 = tensor.empty(%6) : tensor<?x1xi32>
    %c0_i32 = arith.constant 0 : i32
    %8 = linalg.fill ins(%c0_i32 : i32) outs(%7 : tensor<?x1xi32>) -> tensor<?x1xi32>
    %9 = tensor.empty(%6) : tensor<?xf32>
    %cst = arith.constant 0.000000e+00 : f32
    %10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<?xf32>) -> tensor<?xf32>
    %11:2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} outs(%8, %10 : tensor<?x1xi32>, tensor<?xf32>) {
    ^bb0(%out: i32, %out_13: f32):
      %16 = linalg.index 0 : index
      %17 = arith.remsi %16, %c6_1 : index
      %18 = arith.divsi %16, %c6_1 : index
      %extracted = tensor.extract %3[%17] : tensor<6xi64>
      %extracted_14 = tensor.extract %0[%17] : tensor<6xf32>
      %19 = arith.index_cast %17 : index to i64
      %20 = arith.trunci %19 : i64 to i32
      %21 = arith.trunci %extracted : i64 to i32
      linalg.yield %21, %extracted_14 : i32, f32
    } -> (tensor<?x1xi32>, tensor<?xf32>)
    %c0_3 = arith.constant 0 : index
    %c0_4 = arith.constant 0 : index
    %c1_5 = arith.constant 1 : index
    %c1_6 = arith.constant 1 : index
    %c1_7 = arith.constant 1 : index
    %12 = tensor.empty(%6) : tensor<?x1xi32>
    %c0_i32_8 = arith.constant 0 : i32
    %13 = linalg.fill ins(%c0_i32_8 : i32) outs(%12 : tensor<?x1xi32>) -> tensor<?x1xi32>
    %c0_9 = arith.constant 0 : index
    %dim = tensor.dim %11#0, %c0_9 : tensor<?x1xi32>
    %c1_10 = arith.constant 1 : index
    %c1_11 = arith.constant 1 : index
    %inserted_slice = tensor.insert_slice %11#0 into %13[0, 0] [%dim, 1] [1, 1] : tensor<?x1xi32> into tensor<?x1xi32>
    %c1_12 = arith.constant 1 : index
    %14 = tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(false) ins(%11#1, %inserted_slice : tensor<?xf32>, tensor<?x1xi32>) outs(%1 : tensor<4xf32>) {
    ^bb0(%arg2: f32, %arg3: f32):
      %16 = arith.addf %arg2, %arg3 : f32
      tm_tensor.yield %16 : f32
    } -> tensor<4xf32>
    %cast = tensor.cast %14 : tensor<4xf32> to tensor<4xf32>
    %15 = torch_c.from_builtin_tensor %cast : tensor<4xf32> -> !torch.vtensor<[4],f32>
    return %15 : !torch.vtensor<[4],f32>
  }
}

@AmosLewis
Copy link
Collaborator Author

AmosLewis commented Oct 4, 2024

Pass by most recent patch

Status report for run: test-run using mode:onnx todtype:default backend:llvm-cpu

| tests                          | model-run   | onnx-import   | torch-mlir   | iree-compile   | inference   |
|:-------------------------------|:------------|:--------------|:-------------|:---------------|:------------|
| onnx/operators/ScatterElements | passed      | passed        | passed       | passed         | passed      |

AmosLewis added a commit to llvm/torch-mlir that referenced this pull request Oct 6, 2024
…to tm_tensor/linalg_ext dialect (#3754)

- To fix issue onnx.ScatterElements: nod-ai/SHARK-ModelDev#823
- E2E test: nod-ai/SHARK-TestSuite#363
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant