
Request i1 type support in Interpreter #2487

Open
qingyunqu opened this issue Aug 10, 2024 · 3 comments
qingyunqu (Contributor) commented Aug 10, 2024

Request description

It fails when I run a case like this:

module {
  func.func @forward(%arg0: tensor<4xf32>) -> tensor<4xi1> {
    %0 = stablehlo.convert %arg0 : (tensor<4xf32>) -> tensor<4xi1>
    return %0 : tensor<4xi1>
  }
}
sdasgup3 (Member) commented Aug 12, 2024

Can you please elaborate on the error message?

Actually, the i1 type is supported in the interpreter. For example,

$ cat test.mlir
module {
  func.func @forward() -> tensor<4xi1> {
    %cst = stablehlo.constant dense<3.0> : tensor<4xf32>
    %0 = stablehlo.convert %cst : (tensor<4xf32>) -> tensor<4xi1>
    return %0 : tensor<4xi1>
  }
}
$ stablehlo-translate --interpret test.mlir
tensor<4xi1> {
  [true, true, true, true]
}
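(As an aside: converting a float tensor to i1 acts like a nonzero test, which is why dense<3.0> yields all true above. A minimal pure-Python sketch of that semantics, not the interpreter's actual code:)

```python
# Sketch of float -> i1 conversion semantics: a value converts to
# true exactly when it is nonzero (so 3.0 -> true, 0.0 -> false).
# This mirrors the interpreter output above; it is NOT StableHLO code.
def convert_f32_to_i1(values):
    return [v != 0.0 for v in values]

print(convert_f32_to_i1([3.0, 3.0, 3.0, 3.0]))  # [True, True, True, True]
```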

Note: I modified the program above to give the convert op a constant input.

If I run your program as-is, we get something like:

error: incorrect number of arguments specified, provided 0 inputs but function expected 1

which means the interpreter was not given constant inputs equal to the number of program arguments.
Please refer to

auto results = evalModule(*module, {inputValue1, inputValue2}, config);

and

actual = np.array(stablehlo.eval_module(m, args)[0])

for how to use the interpreter programmatically via the C++ and Python APIs, respectively.

Please let me know if this helps.

@sdasgup3 sdasgup3 self-assigned this Aug 12, 2024
qingyunqu (Contributor, Author) commented:

Hi, here is a failing example:

module attributes {torch.debug_module_name = "ElementwiseAtenLogicalAndOpPromoteBroadcastModule"} {
  func.func @forward(%arg0: tensor<?xf32>, %arg1: tensor<?x?xi64>) -> tensor<?x?xi1> {
    %c = stablehlo.constant dense<1> : tensor<1xi32>
    %0 = stablehlo.convert %arg0 : (tensor<?xf32>) -> tensor<?xi1>
    %1 = stablehlo.convert %arg1 : (tensor<?x?xi64>) -> tensor<?x?xi1>
    %2 = stablehlo.get_dimension_size %0, dim = 0 : (tensor<?xi1>) -> tensor<i32>
    %3 = stablehlo.reshape %2 : (tensor<i32>) -> tensor<1xi32>
    %4 = stablehlo.concatenate %3, dim = 0 : (tensor<1xi32>) -> tensor<1xi32>
    %5 = stablehlo.get_dimension_size %1, dim = 0 : (tensor<?x?xi1>) -> tensor<i32>
    %6 = stablehlo.reshape %5 : (tensor<i32>) -> tensor<1xi32>
    %7 = stablehlo.get_dimension_size %1, dim = 1 : (tensor<?x?xi1>) -> tensor<i32>
    %8 = stablehlo.reshape %7 : (tensor<i32>) -> tensor<1xi32>
    %9 = stablehlo.concatenate %6, %8, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %10 = stablehlo.concatenate %c, %4, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %11 = stablehlo.maximum %10, %9 : tensor<2xi32>
    %12 = stablehlo.dynamic_broadcast_in_dim %0, %11, dims = [1] : (tensor<?xi1>, tensor<2xi32>) -> tensor<?x?xi1>
    %13 = stablehlo.dynamic_broadcast_in_dim %1, %11, dims = [0, 1] : (tensor<?x?xi1>, tensor<2xi32>) -> tensor<?x?xi1>
    %14 = stablehlo.and %12, %13 : tensor<?x?xi1>
    return %14 : tensor<?x?xi1>
  }
}

The error message is LLVM ERROR: Element is not an integer.
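For intuition, the program above just converts both inputs to booleans, broadcasts the 1-D tensor across the rows of the 2-D tensor, and takes an elementwise logical AND. A rough pure-Python sketch of those semantics (the helper name is made up for illustration and is not part of any StableHLO API):

```python
# Pure-Python sketch of what the StableHLO program above computes:
# convert both inputs to booleans (nonzero test), broadcast the 1-D
# tensor across rows (dynamic_broadcast_in_dim with dims=[1]), then
# take an elementwise logical AND (stablehlo.and on i1).
def logical_and_broadcast(row_f32, matrix_i64):
    row_bool = [v != 0.0 for v in row_f32]                 # convert %arg0 to i1
    mat_bool = [[v != 0 for v in r] for r in matrix_i64]   # convert %arg1 to i1
    # broadcast the row against each matrix row, then AND elementwise
    return [[a and b for a, b in zip(row_bool, r)] for r in mat_bool]

print(logical_and_broadcast([1.0, 0.0], [[1, 1], [0, 1]]))
# [[True, False], [False, False]]
```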

ghpvnist (Member) commented:
Hi, I'm able to interpret this fine on my machine. (Note the 1.0 for the float input; maybe that was the issue?)

// RUN: stablehlo-translate %s --interpret --args="[dense<1.0> : tensor<4xf32>, dense<1> : tensor<4x4xi64>]"

module attributes {torch.debug_module_name = "ElementwiseAtenLogicalAndOpPromoteBroadcastModule"} {
  func.func @forward(%arg0: tensor<?xf32>, %arg1: tensor<?x?xi64>) -> tensor<?x?xi1> {
    %c = stablehlo.constant dense<1> : tensor<1xi32>
    %0 = stablehlo.convert %arg0 : (tensor<?xf32>) -> tensor<?xi1>
    %1 = stablehlo.convert %arg1 : (tensor<?x?xi64>) -> tensor<?x?xi1>
    %2 = stablehlo.get_dimension_size %0, dim = 0 : (tensor<?xi1>) -> tensor<i32>
    %3 = stablehlo.reshape %2 : (tensor<i32>) -> tensor<1xi32>
    %4 = stablehlo.concatenate %3, dim = 0 : (tensor<1xi32>) -> tensor<1xi32>
    %5 = stablehlo.get_dimension_size %1, dim = 0 : (tensor<?x?xi1>) -> tensor<i32>
    %6 = stablehlo.reshape %5 : (tensor<i32>) -> tensor<1xi32>
    %7 = stablehlo.get_dimension_size %1, dim = 1 : (tensor<?x?xi1>) -> tensor<i32>
    %8 = stablehlo.reshape %7 : (tensor<i32>) -> tensor<1xi32>
    %9 = stablehlo.concatenate %6, %8, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %10 = stablehlo.concatenate %c, %4, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %11 = stablehlo.maximum %10, %9 : tensor<2xi32>
    %12 = stablehlo.dynamic_broadcast_in_dim %0, %11, dims = [1] : (tensor<?xi1>, tensor<2xi32>) -> tensor<?x?xi1>
    %13 = stablehlo.dynamic_broadcast_in_dim %1, %11, dims = [0, 1] : (tensor<?x?xi1>, tensor<2xi32>) -> tensor<?x?xi1>
    %14 = stablehlo.and %12, %13 : tensor<?x?xi1>
    return %14 : tensor<?x?xi1>
  }
}
