
Missing support for handling negative step values for Onnx.Slice op #824

Closed
pdhirajkumarprasad opened this issue Sep 9, 2024 · 13 comments

@pdhirajkumarprasad

Getting the error `'linalg.conv_2d_nchw_fchw' op inferred input/output operand #1 has shape's dimension #1 to be 4, but found 3` for the following IR:

module {
  func.func @torch_jit(%arg0: !torch.vtensor<[1,3,240,240],f32>, %arg1: !torch.vtensor<[1],si64> , %arg2:!torch.vtensor<[32,3,3,3],f32>, %arg3: !torch.vtensor<[32],f32>) -> !torch.vtensor<[?,32,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.12.1"} {
    %233 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_onnx__Concat_1461> : tensor<4xsi64>} : () -> !torch.vtensor<[4],si64> 
    %238 = torch.operator "onnx.ConstantOfShape"(%arg1) {torch.onnx.value = dense_resource<_> : tensor<1xsi64>} : (!torch.vtensor<[1],si64>) -> !torch.vtensor<[4],si64> 
    %239 = torch.operator "onnx.Concat"(%233, %238) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[8],si64> 
    %240 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__1> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64> 
    %241 = torch.operator "onnx.Reshape"(%239, %240) : (!torch.vtensor<[8],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,2],si64> 
    %242 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__2> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %243 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__3> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %244 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__4> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %245 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__5> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %246 = torch.operator "onnx.Slice"(%241, %243, %arg1, %242, %245) : (!torch.vtensor<[4,2],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,2],si64> 
    %247 = torch.operator "onnx.Transpose"(%246) {torch.onnx.perm = [1 : si64, 0 : si64]} : (!torch.vtensor<[4,2],si64>) -> !torch.vtensor<[2,4],si64> 
    %248 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__6> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %249 = torch.operator "onnx.Reshape"(%247, %248) : (!torch.vtensor<[2,4],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[8],si64> 
    %250 = torch.operator "onnx.Cast"(%249) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[8],si64>) -> !torch.vtensor<[8],si64> 
    %251 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__7> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %252 = torch.operator "onnx.Pad"(%arg0, %250, %251) {torch.onnx.mode = "constant"} : (!torch.vtensor<[1,3,240,240],f32>, !torch.vtensor<[8],si64>, !torch.vtensor<[],f32>) -> !torch.vtensor<[?,?,?,?],f32> 
    %253 = torch.operator "onnx.Conv"(%252, %arg2, %arg3) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[32,3,3,3],f32>, !torch.vtensor<[32],f32>) -> !torch.vtensor<[?,32,?,?],f32> 
    return %253 : !torch.vtensor<[?,32,?,?],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      _onnx__Concat_1461: "0x080000000000000000000000010000000000000000000000000000000100000000000000",
      _: "0x080000000000000000000000",
      __1: "0x08000000FFFFFFFFFFFFFFFF0200000000000000",
      __2: "0x080000000000000000000000",
      __3: "0x08000000FFFFFFFFFFFFFFFF",
      __4: "0x080000000100000000000080",
      __5: "0x08000000FFFFFFFFFFFFFFFF",
      __6: "0x08000000FFFFFFFFFFFFFFFF",
      __7: "0x0800000000000000"
    }
  }
#-}

command:

iree-compile --iree-hal-target-backends=llvm-cpu model.torch_onnx.mlir

error:

model.torch_onnx.mlir:19:12: error: 'linalg.conv_2d_nchw_fchw' op inferred input/output operand #1 has shape's dimension #1 to be 4, but found 3
    %253 = torch.operator "onnx.Conv"(%252, %arg2, %arg3) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[32,3,3,3],f32>, !torch.vtensor<[32],f32>) -> !torch.vtensor<[?,32,?,?],f32> 
           ^
model.torch_onnx.mlir:19:12: note: see current operation: 
%13 = "linalg.conv_2d_nchw_fchw"(%9, %5, %12) <{dilations = dense<1> : vector<2xi64>, operandSegmentSizes = array<i32: 2, 1>, strides = dense<2> : vector<2xi64>}> ({
^bb0(%arg9: f32, %arg10: f32, %arg11: f32):
  %22 = "arith.mulf"(%arg9, %arg10) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
  %23 = "arith.addf"(%arg11, %22) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
  "linalg.yield"(%23) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 * 2 + d5, d3 * 2 + d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>]} : (tensor<2x4x240x240xf32>, tensor<32x3x3x3xf32>, tensor<2x32x119x119xf32>) -> tensor<2x32x119x119xf32>
@jinchen62
Contributor

I created an issue on IREE for this earlier (iree-org/iree#18387); it fails in iree-compile.

@pdhirajkumarprasad
Author

model.onnx.txt

@vivekkhandelwal1
Contributor

It seems that the issue is with the IR itself, since all the op lowerings are working as expected. Below is a line-by-line analysis of the IR along with the inputs/outputs:

module {
  func.func @torch_jit(%arg0: !torch.vtensor<[1,3,240,240],f32>, %arg1: !torch.vtensor<[1],si64> , %arg2:!torch.vtensor<[32,3,3,3],f32>, %arg3: !torch.vtensor<[32],f32>) -> !torch.vtensor<[?,32,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.12.1"} {
    %233 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_onnx__Concat_1461> : tensor<4xsi64>} : () -> !torch.vtensor<[4],si64> 
    // %233 = [0, 1, 0, 1]
    %238 = torch.operator "onnx.ConstantOfShape"(%arg1) {torch.onnx.value = dense_resource<_> : tensor<1xsi64>} : (!torch.vtensor<[1],si64>) -> !torch.vtensor<[4],si64> 
    // %238 = [0, 0, 0, 0]
    %239 = torch.operator "onnx.Concat"(%233, %238) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[8],si64> 
    // %239 = [0, 1, 0, 1, 0, 0, 0, 0]
    %240 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__1> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64> 
    // %240 = [4, 2]
    %241 = torch.operator "onnx.Reshape"(%239, %240) : (!torch.vtensor<[8],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,2],si64> 
    // %241 = [
    //   [0, 1],
    //   [0, 1],
    //   [0, 0],
    //   [0, 0],
    // ]
    %242 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__2> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    // %242 = [0]
    %243 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__3> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    // %243 = [-1]
    %244 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__4> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    // %244 = [1]
    %245 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__5> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    // %245 = [-1]
    %246 = torch.operator "onnx.Slice"(%241, %243, %arg1, %242, %245) : (!torch.vtensor<[4,2],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,2],si64> 
    // Data, starts, ends, axes, steps
    // Data = 4x2xi64=[
      // [0, 1],
      // [0, 1],
      // [0, 0],
      // [0, 0],
    // ]
    // Starts = -1
    // ends = %arg1 (a runtime value; a large negative number in practice, so the slice runs through to the front)
    // axes = 0
    // steps = -1
    // %246 = Result = 4x2xi64=[
      // [0, 0],
      // [0, 0],
      // [0, 1],
      // [0, 1],
    // ]
    %247 = torch.operator "onnx.Transpose"(%246) {torch.onnx.perm = [1 : si64, 0 : si64]} : (!torch.vtensor<[4,2],si64>) -> !torch.vtensor<[2,4],si64> 
    // Onnx.Transpose(Result, 1, 0)
    // Input = 4x2xi64, Result = 2x4xi64
    // Transpose_Result = [
    //    [0, 0, 0, 0],
    //    [0, 0, 1, 1]
    // ]
    %248 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__6> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    // %248 = [-1]
    %249 = torch.operator "onnx.Reshape"(%247, %248) : (!torch.vtensor<[2,4],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[8],si64> 
    // Reshape(Transpose_result, %248)
    // Input = 2x4xi64, Result = 8xi64
    // %249 = Result_reshape = [0, 0, 0, 0, 0, 0, 1, 1]
    %250 = torch.operator "onnx.Cast"(%249) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[8],si64>) -> !torch.vtensor<[8],si64> 
    %251 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__7> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    // %251 = 0.0 (scalar pad value)
    %252 = torch.operator "onnx.Pad"(%arg0, %250, %251) {torch.onnx.mode = "constant"} : (!torch.vtensor<[1,3,240,240],f32>, !torch.vtensor<[8],si64>, !torch.vtensor<[],f32>) -> !torch.vtensor<[?,?,?,?],f32> 
    // %252: Input = 1x3x240x240xf32; expected Output = 1x3x241x241xf32 (the observed 2x4x240x240xf32 is wrong)
    %253 = torch.operator "onnx.Conv"(%252, %arg2, %arg3) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[32,3,3,3],f32>, !torch.vtensor<[32],f32>) -> !torch.vtensor<[?,32,?,?],f32> 
    return %253 : !torch.vtensor<[?,32,?,?],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      _onnx__Concat_1461: "0x080000000000000000000000010000000000000000000000000000000100000000000000",
      _: "0x080000000000000000000000",
      __1: "0x08000000FFFFFFFFFFFFFFFF0200000000000000",
      __2: "0x080000000000000000000000",
      __3: "0x08000000FFFFFFFFFFFFFFFF",
      __4: "0x080000000100000000000080",
      __5: "0x08000000FFFFFFFFFFFFFFFF",
      __6: "0x08000000FFFFFFFFFFFFFFFF",
      __7: "0x0800000000000000"
    }
  }
#-}

Link to the gist containing above analysis: https://gist.github.com/vivekkhandelwal1/c581d7c2a09b14f19519d3d6c10f7004
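The analysis above can be checked numerically. The following numpy sketch (illustrative only, not part of the model; variable names mirror the SSA values) replays the constant computation feeding onnx.Pad and shows the pads the graph intends to produce:

```python
import numpy as np

# Replays the constant subgraph feeding onnx.Pad in the IR above.
v233 = np.array([0, 1, 0, 1], dtype=np.int64)   # %233
v238 = np.zeros(4, dtype=np.int64)              # %238 (ConstantOfShape)
v239 = np.concatenate([v233, v238])             # %239 (Concat)
v241 = v239.reshape(4, 2)                       # %241 (Reshape)
v246 = v241[::-1]                               # %246: Slice, axis 0, step -1
v249 = v246.T.reshape(-1)                       # %247 (Transpose) + %249 (Reshape)
print(v249.tolist())                            # → [0, 0, 0, 0, 0, 0, 1, 1]

# Apply onnx.Pad semantics: first half = begin pads, second half = end pads.
x = np.zeros((1, 3, 240, 240), dtype=np.float32)
begin, end = v249[:4], v249[4:]
padded = np.pad(x, list(zip(begin, end)))       # %252
print(padded.shape)                             # → (1, 3, 241, 241)
```

This reproduces the gold output shape 1x3x241x241; dropping the `[::-1]` reversal (i.e. ignoring the negative step) yields pads `[0, 0, 0, 0, 1, 1, 0, 0]` and the wrong 2x4x240x240 shape seen in the error.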

@zjgarvey
Collaborator

zjgarvey commented Sep 25, 2024

Oh man, this must have been exported from pytorch.

All of the garbage leading into %250 is literally to re-arrange a constant list of ints, which we have to re-arrange back to the original when lowering torch-onnx to torch.

@zjgarvey
Collaborator

Here is a repro of the real issue:

module {
  func.func @torch_jit(%arg0: !torch.vtensor<[1,3,240,240],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.12.1"} {
    %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_onnx__ConstantOfShape_1460> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_onnx__Concat_1461> : tensor<4xsi64>} : () -> !torch.vtensor<[4],si64> 
    %none = torch.constant.none
    %2 = torch.operator "onnx.ConstantOfShape"(%0) {torch.onnx.value = dense_resource<_> : tensor<1xsi64>} : (!torch.vtensor<[1],si64>) -> !torch.vtensor<[4],si64> 
    %3 = torch.operator "onnx.Concat"(%1, %2) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[8],si64> 
    %4 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__1> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64> 
    %5 = torch.operator "onnx.Reshape"(%3, %4) : (!torch.vtensor<[8],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,2],si64> 
    %6 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__2> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %7 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__3> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__4> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %9 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__5> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %10 = torch.operator "onnx.Slice"(%5, %7, %8, %6, %9) : (!torch.vtensor<[4,2],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,2],si64> 
    %11 = torch.operator "onnx.Transpose"(%10) {torch.onnx.perm = [1 : si64, 0 : si64]} : (!torch.vtensor<[4,2],si64>) -> !torch.vtensor<[2,4],si64> 
    %12 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__6> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %13 = torch.operator "onnx.Reshape"(%11, %12) : (!torch.vtensor<[2,4],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[8],si64> 
    %14 = torch.operator "onnx.Cast"(%13) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[8],si64>) -> !torch.vtensor<[8],si64> 
    %15 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__7> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %16 = torch.operator "onnx.Pad"(%arg0, %14, %15) {torch.onnx.mode = "constant"} : (!torch.vtensor<[1,3,240,240],f32>, !torch.vtensor<[8],si64>, !torch.vtensor<[],f32>) -> !torch.vtensor<[?,?,?,?],f32> 
    return %16 : !torch.vtensor<[?,?,?,?],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      _onnx__ConstantOfShape_1460: "0x080000000400000000000000",
      _onnx__Concat_1461: "0x080000000000000000000000010000000000000000000000000000000100000000000000",
      _: "0x080000000000000000000000",
      __1: "0x08000000FFFFFFFFFFFFFFFF0200000000000000",
      __2: "0x080000000000000000000000",
      __3: "0x08000000FFFFFFFFFFFFFFFF",
      __4: "0x080000000100000000000080",
      __5: "0x08000000FFFFFFFFFFFFFFFF",
      __6: "0x08000000FFFFFFFFFFFFFFFF",
      __7: "0x0800000000000000"
    }
  }
#-}

I ran this in the test suite by truncating the model "efficientnet_b1_pruned.in1k" at Pad node 0, then printed out the output and gold output shapes:

shapes : gold: torch.Size([1, 3, 241, 241]) , iree: torch.Size([2, 4, 240, 240])

We need to figure out a better way to handle the insanity generated by pytorch exports of the pad op.

@vivekkhandelwal1
Contributor

vivekkhandelwal1 commented Sep 26, 2024

> Here is a repro of the real issue: [...]
>
> I ran this in the test suite by truncating the model "efficientnet_b1_pruned.in1k" at Pad node 0, then printed out the output and gold output shapes:
>
> shapes : gold: torch.Size([1, 3, 241, 241]) , iree: torch.Size([2, 4, 240, 240])

Hi @zjgarvey, can you please let me know how you did this? I would like to compare the result of the onnx.Reshape just before the Pad op with the golden result. I was doing this with a manual script that generates the .onnx file and runs it e2e through alt_e2eshark, with the help of @pdhirajkumarprasad.

@zjgarvey
Collaborator

@vivekkhandelwal1 This test was generated by the following python code (added to azure_models.py, for example).

from onnx_tests.helper_classes import TruncatedModel, get_trucated_constructor

class UpdateInit(TruncatedModel):
    def __init__(self, n: int, op_type: str, *args, **kwargs):
        super().__init__(n, op_type, *args, **kwargs)
        self.opset_version = 21
        self.update_opset_version_and_overwrite()

const = get_trucated_constructor(UpdateInit, AzureDownloadableModel, "efficientnet_b1_pruned.in1k")

register_test(const(0, "Pad"), f"efficientnet_pad_repro_0")

You could add tests for the first few reshape ops by

for i in range(0,3):
   register_test(const(i,"Reshape"),f'efficientnet_Reshape_{i}')

@vivekkhandelwal1
Contributor

> @vivekkhandelwal1 This test was generated by the following python code (added to azure_models.py, for example). [...]

I followed these steps but I'm getting the following error:

Failed test at stage setup with exception:
Error parsing message with type 'onnx.ModelProto'
Traceback (most recent call last):
  File "/home/azureuser/work/SHARK-TestSuite/alt_e2eshark/./run.py", line 161, in run_tests
    inst = t.model_constructor(t.unique_name, log_dir)
  File "/home/azureuser/work/SHARK-TestSuite/alt_e2eshark/onnx_tests/helper_classes.py", line 161, in <lambda>
    lambda *args, **kwargs: truncated_class(
  File "/home/azureuser/work/SHARK-TestSuite/alt_e2eshark/onnx_tests/models/azure_models.py", line 37, in __init__
    self.update_opset_version_and_overwrite()
  File "/home/azureuser/work/SHARK-TestSuite/alt_e2eshark/e2e_testing/framework.py", line 120, in update_opset_version_and_overwrite
    self.construct_model()
  File "/home/azureuser/work/SHARK-TestSuite/alt_e2eshark/onnx_tests/helper_classes.py", line 129, in construct_model
    og_model = onnx.load(self.sibling_inst.model)
  File "/home/azureuser/work/mlir_venv/lib/python3.10/site-packages/onnx/__init__.py", line 210, in load_model
    model = _get_serializer(format, f).deserialize_proto(_load_bytes(f), ModelProto())
  File "/home/azureuser/work/mlir_venv/lib/python3.10/site-packages/onnx/serialization.py", line 118, in deserialize_proto
    decoded = typing.cast(Optional[int], proto.ParseFromString(serialized))
google.protobuf.message.DecodeError: Error parsing message with type 'onnx.ModelProto'

@vivekkhandelwal1
Contributor

The root cause is the negative step value in the Onnx.Slice op. That negative step propagates down the pipeline and produces the following IR:

module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @torch_jit() -> tensor<4x2xi64> {
    %cst = arith.constant dense<[[0, 1], [0, 1], [0, 0], [0, 0]]> : tensor<4x2xi64>
    %extracted_slice = tensor.extract_slice %cst[3, 0] [4, 2] [-1, 1] : tensor<4x2xi64> to tensor<4x2xi64>
    return %extracted_slice : tensor<4x2xi64>
  }
}

Now the issue is that MLIR does not support negative strides in tensor.extract_slice (AFAIK), so it silently produces a wrong result (ideally it should have errored out instead).

Ideally the result of this IR should be:

[
  [0, 0],
  [0, 0],
  [0, 1],
  [0, 1]
]

but the result generated through iree-compilation and execution is:

[
  [0, 1],
  [0, 1],
  [0, 0],
  [0, 0]
]

Because of this, the Onnx.Pad op, which takes an input of shape 1x3x240x240, outputs a tensor of shape 2x4x240x240 instead of 1x3x241x241, making it incompatible with the subsequent Onnx.Conv op.
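To see concretely why the wrong pad output breaks the conv, here is an illustrative shape-inference helper for an NCHW input / FCHW weight convolution (a hypothetical sketch, not IREE or torch-mlir code):

```python
# Hypothetical NCHW/FCHW conv2d shape inference (illustrative only).
def conv2d_out_shape(inp, w, stride=2, dilation=1):
    n, cin, h, wd = inp
    f, cin_w, kh, kw = w
    # The input channel count (dim 1) must match the weight's
    # in-channel count (dim 1) -- this is the check that fails here.
    if cin != cin_w:
        raise ValueError(f"input channels {cin} != weight in-channels {cin_w}")
    oh = (h - dilation * (kh - 1) - 1) // stride + 1
    ow = (wd - dilation * (kw - 1) - 1) // stride + 1
    return (n, f, oh, ow)

# Correctly padded input works:
print(conv2d_out_shape((1, 3, 241, 241), (32, 3, 3, 3)))  # → (1, 32, 120, 120)

# The buggy 2x4x240x240 pad result fails the channel check:
try:
    conv2d_out_shape((2, 4, 240, 240), (32, 3, 3, 3))
except ValueError as e:
    print(e)
```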

Solution: The Onnx.Slice and torch.aten.slice lowerings have to be extended to handle negative step values.

P.S. The torch.aten.slice lowering does not handle negative step values because the PyTorch slice op does not allow a negative step, so there was previously no need to handle it.
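One way to sketch such a fix (an assumed approach for illustration, not the exact torch-mlir patch): rewrite a negative-step slice as a flip along the axis followed by an ordinary positive-step slice. For the case in this issue (start = -1, step = -1, axis = 0, end running past the front) that reduces to a full reversal of axis 0:

```python
import numpy as np

# Lowering a negative-step slice as flip + positive-step slice
# (assumed approach for illustration, not the exact torch-mlir patch).
data = np.array([[0, 1], [0, 1], [0, 0], [0, 0]], dtype=np.int64)

# Reference: numpy's native negative-step slice, start=3, through the front.
reference = data[3::-1]

# Rewrite: flip along the axis, then an ordinary positive-step slice.
rewritten = np.flip(data, axis=0)[0:4:1]

print(reference.tolist())  # → [[0, 0], [0, 0], [0, 1], [0, 1]]
print(rewritten.tolist())  # same result
```

For general start/end values, the start and end indices would also need to be remapped into the flipped coordinate system, but the equivalence is the same.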

@vivekkhandelwal1 vivekkhandelwal1 changed the title 'linalg.conv_2d_nchw_fchw' op inferred input/output operand #1 has shape's dimension #1 to be 4, but found 3 Missing support for handling negative step values for Onnx.Slice op Sep 27, 2024
@nirvedhmeshram
Contributor

nirvedhmeshram commented Sep 28, 2024

@MaheshRavishankar to check whether we should support negative strides, in which case this is also a bug in the lowering. I think negative strides (and offsets) should actually be allowed by the semantics of the tensor.extract_slice op. But fixing this in the Onnx.Slice and/or torch.aten.slice lowering is also a reasonable solution if negative strides expect "too much" from the compiler. The concern is: if the stride value is determined at runtime and turns out to be negative, the compiler will have to support it anyway.

The other broken thing: with an offset of [3, 0], the output should just go out of bounds and be garbage if the negative stride is not understood, so I'm not sure how we got the non-flipped output.

@vivekkhandelwal1
Contributor

> @MaheshRavishankar to check that maybe we should support negative strides and that is also a bug in the lowering? [...]

For now, I'm adding the support in the Onnx->Torch lowering. The rest may need broader discussion.

@vivekkhandelwal1
Contributor

Fixed by llvm/torch-mlir#3763.

CC: @pdhirajkumarprasad

@vivekkhandelwal1
Contributor

Closing this since the fix llvm/torch-mlir#3763 is merged.
