
shape-legalize-to-stablehlo emits a segmentation fault when there is tensor.extract in the IR #2488

Open
qingyunqu opened this issue Aug 10, 2024 · 3 comments


qingyunqu commented Aug 10, 2024

What happened?

Testcase:

module attributes {torch.debug_module_name = "UnsafeViewCollapseDynamicWithAtenSizeIntModule"} {
  func.func @forward(%arg0: tensor<?x?x?x?x?x?xf32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<?x?x?x?x384xf32> {
    %c384_i64 = arith.constant 384 : i64
    %c0 = arith.constant 0 : index
    %c3 = arith.constant 3 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?x?x?x?x?x?xf32>
    %0 = arith.index_cast %dim : index to i64
    %extracted = tensor.extract %arg1[] : tensor<i64>
    %extracted_0 = tensor.extract %arg2[] : tensor<i64>
    %dim_1 = tensor.dim %arg0, %c3 : tensor<?x?x?x?x?x?xf32>
    %1 = arith.index_cast %dim_1 : index to i64
    %from_elements = tensor.from_elements %0, %extracted, %extracted_0, %1, %c384_i64 : tensor<5xi64>
    %2 = stablehlo.dynamic_reshape %arg0, %from_elements : (tensor<?x?x?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x?x?x?x384xf32>
    return %2 : tensor<?x?x?x?x384xf32>
  }
}

Running stablehlo-opt --shape-legalize-to-stablehlo reproduces the error.
I would prefer the pass to fail with a diagnostic instead of segfaulting.


Version information

No response


sdasgup3 commented Aug 12, 2024

Thanks @qingyunqu for pointing out this issue.

IMO, extracting a value from a scalar tensor using empty indices is not supported yet (from cs).

IMO this is a case of data-dependent dynamism, which is not supported (ref Dynamism-RFC::Out Of Scope [O6]). At the least we should error out instead of crashing. In any case, I will get back to you soon.
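To illustrate what "data-dependent" means here, the following is a minimal NumPy model (names are illustrative, not StableHLO APIs) of the original testcase: the target shape of stablehlo.dynamic_reshape is assembled partly from runtime values extracted out of %arg1/%arg2, not only from the shape of %arg0, so the output shape cannot be derived from input shapes alone.

```python
import numpy as np

# NumPy sketch of the testcase (illustrative only, not StableHLO semantics):
def forward(arg0, arg1, arg2):
    d0 = arg0.shape[0]              # tensor.dim %arg0, %c0
    d3 = arg0.shape[3]              # tensor.dim %arg0, %c3
    e1 = int(arg1)                  # tensor.extract %arg1[] -- a data value
    e2 = int(arg2)                  # tensor.extract %arg2[] -- a data value
    target = (d0, e1, e2, d3, 384)  # tensor.from_elements
    return arg0.reshape(target)     # stablehlo.dynamic_reshape

x = np.zeros((2, 3, 4, 5, 6, 64), dtype=np.float32)
out = forward(x, np.array(3, dtype=np.int64), np.array(4, dtype=np.int64))
print(out.shape)  # (2, 3, 4, 5, 384)
```

Because e1 and e2 are tensor contents rather than dimension sizes, a pure shape-legalization pass has no StableHLO shape computation to lower them to, hence the request to error out rather than crash.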

@sdasgup3 sdasgup3 self-assigned this Aug 12, 2024
@qingyunqu (Contributor, Author) commented

Hi, I have another case that fails on shape-legalize-to-stablehlo:

module attributes {torch.debug_module_name = "ElementwiseDivTensorUnsignedIntegerModule"} {
  func.func @forward(%arg0: tensor<?x?xui8>, %arg1: tensor<?x?xui8>) -> tensor<?x?xf32> {
    %0 = stablehlo.convert %arg0 : (tensor<?x?xui8>) -> tensor<?x?xf32>
    %1 = stablehlo.convert %arg1 : (tensor<?x?xui8>) -> tensor<?x?xf32>
    %2 = shape.shape_of %0 : tensor<?x?xf32> -> tensor<2xindex>
    %3 = shape.shape_of %1 : tensor<?x?xf32> -> tensor<2xindex>
    %4 = shape.broadcast %2, %3 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
    %5 = stablehlo.dynamic_broadcast_in_dim %0, %4, dims = [0, 1] : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
    %6 = stablehlo.dynamic_broadcast_in_dim %1, %4, dims = [0, 1] : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
    %7 = stablehlo.divide %5, %6 : tensor<?x?xf32>
    return %7 : tensor<?x?xf32>
  }
}

The error message is:

note: see current operation: %0 = "stablehlo.custom_call"(<<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>) <{call_target_name = "stablehlo.shape_refinement_operand_wrapper"}> {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<3x4xi8>, tensor<2xi64>) -> tensor<3x4xui8>
note: - use: %2 = "stablehlo.convert"(<<UNKNOWN SSA VALUE>>) : (tensor<3x4xui8>) -> tensor<?x?xf32>

@ghpvnist (Member) commented

The <<UNKNOWN SSA VALUE>> looks like a bug in your code. The MLIR class object likely doesn't have visibility into the SSA values.

Here's what I get from running this:

// RUN: stablehlo-opt --shape-legalize-to-stablehlo --split-input-file --verify-diagnostics %s

module attributes {torch.debug_module_name = "ElementwiseDivTensorUnsignedIntegerModule"} {
  func.func @forward(%arg0: tensor<?x?xui8>, %arg1: tensor<?x?xui8>) -> tensor<?x?xf32> {
    %0 = stablehlo.convert %arg0 : (tensor<?x?xui8>) -> tensor<?x?xf32>
    %1 = stablehlo.convert %arg1 : (tensor<?x?xui8>) -> tensor<?x?xf32>
    %2 = shape.shape_of %0 : tensor<?x?xf32> -> tensor<2xindex>
    %3 = shape.shape_of %1 : tensor<?x?xf32> -> tensor<2xindex>
    %4 = shape.broadcast %2, %3 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
    %5 = stablehlo.dynamic_broadcast_in_dim %0, %4, dims = [0, 1] : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
    %6 = stablehlo.dynamic_broadcast_in_dim %1, %4, dims = [0, 1] : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
    %7 = stablehlo.divide %5, %6 : tensor<?x?xf32>
    return %7 : tensor<?x?xf32>
  }
}

Output

module attributes {torch.debug_module_name = "ElementwiseDivTensorUnsignedIntegerModule"} {
  func.func @forward(%arg0: tensor<?x?xui8>, %arg1: tensor<?x?xui8>) -> tensor<?x?xf32> {
    %0 = stablehlo.convert %arg0 : (tensor<?x?xui8>) -> tensor<?x?xf32>
    %1 = stablehlo.convert %arg1 : (tensor<?x?xui8>) -> tensor<?x?xf32>
    %2 = stablehlo.get_dimension_size %0, dim = 0 : (tensor<?x?xf32>) -> tensor<i32>
    %3 = stablehlo.reshape %2 : (tensor<i32>) -> tensor<1xi32>
    %4 = stablehlo.get_dimension_size %0, dim = 1 : (tensor<?x?xf32>) -> tensor<i32>
    %5 = stablehlo.reshape %4 : (tensor<i32>) -> tensor<1xi32>
    %6 = stablehlo.concatenate %3, %5, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %7 = builtin.unrealized_conversion_cast %6 : tensor<2xi32> to tensor<2xindex>
    %8 = stablehlo.get_dimension_size %1, dim = 0 : (tensor<?x?xf32>) -> tensor<i32>
    %9 = stablehlo.reshape %8 : (tensor<i32>) -> tensor<1xi32>
    %10 = stablehlo.get_dimension_size %1, dim = 1 : (tensor<?x?xf32>) -> tensor<i32>
    %11 = stablehlo.reshape %10 : (tensor<i32>) -> tensor<1xi32>
    %12 = stablehlo.concatenate %9, %11, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %13 = builtin.unrealized_conversion_cast %12 : tensor<2xi32> to tensor<2xindex>
    %14 = builtin.unrealized_conversion_cast %7 : tensor<2xindex> to tensor<2xi32>
    %15 = builtin.unrealized_conversion_cast %13 : tensor<2xindex> to tensor<2xi32>
    %16 = stablehlo.maximum %14, %15 : tensor<2xi32>
    %17 = builtin.unrealized_conversion_cast %16 : tensor<2xi32> to tensor<2xindex>
    %18 = builtin.unrealized_conversion_cast %17 : tensor<2xindex> to tensor<2xi32>
    %19 = stablehlo.dynamic_broadcast_in_dim %0, %18, dims = [0, 1] : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
    %20 = builtin.unrealized_conversion_cast %17 : tensor<2xindex> to tensor<2xi32>
    %21 = stablehlo.dynamic_broadcast_in_dim %1, %20, dims = [0, 1] : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
    %22 = stablehlo.divide %19, %21 : tensor<?x?xf32>
    return %22 : tensor<?x?xf32>
  }
}
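The shape arithmetic in this output can be summarized in a short NumPy sketch (illustrative only, not a StableHLO API): each shape.shape_of is lowered to per-dimension stablehlo.get_dimension_size + reshape + concatenate, and shape.broadcast of two equal-rank shapes reduces to an elementwise stablehlo.maximum, which is valid when corresponding dimensions are equal or one of them is 1.

```python
import numpy as np

def broadcast_shapes(a_shape, b_shape):
    # Model of the lowered IR: each input is a concatenation of
    # get_dimension_size results...
    a = np.asarray(a_shape, dtype=np.int32)
    b = np.asarray(b_shape, dtype=np.int32)
    # ...and shape.broadcast becomes an elementwise stablehlo.maximum.
    return np.maximum(a, b)

print(broadcast_shapes([3, 1], [1, 4]))  # [3 4]
print(broadcast_shapes([3, 4], [3, 4]))  # [3 4]
```

The builtin.unrealized_conversion_cast ops in the output only bridge tensor<2xindex> and tensor<2xi32> during the dialect conversion; they carry no computation of their own.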
