Gather tests #40

Draft · wants to merge 1 commit into main
Conversation

ddilbazTT (Contributor)

The gather test is failing in SHLO, which seems to be related to the constant op. I wanted to share my progress and, if the constant op is indeed the cause of the failure, triage that.

The error I receive (with gdb):

module @jit_take_fn attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<32000x1024xf32>, %arg1: tensor<1x32xi16>) -> (tensor<1x32x1024xf32> {jax.result_info = ""}) {
    %0 = call @_take(%arg0, %arg1) : (tensor<32000x1024xf32>, tensor<1x32xi16>) -> tensor<1x32x1024xf32>
    return %0 : tensor<1x32x1024xf32>
  }
  func.func private @_take(%arg0: tensor<32000x1024xf32>, %arg1: tensor<1x32xi16>) -> tensor<1x32x1024xf32> {
    %cst = stablehlo.constant dense<0x7FC00000> : tensor<f32>
    %c = stablehlo.constant dense<true> : tensor<i1>
    %c_0 = stablehlo.constant dense<1> : tensor<i32>
    %c_1 = stablehlo.constant dense<2> : tensor<i32>
    %c_2 = stablehlo.constant dense<0> : tensor<i32>
    %c_3 = stablehlo.constant dense<1024> : tensor<i32>
    %c_4 = stablehlo.constant dense<32000> : tensor<i32>
    %c_5 = stablehlo.constant dense<0> : tensor<1xi32>
    %c_6 = stablehlo.constant dense<32000> : tensor<i16>
    %c_7 = stablehlo.constant dense<0> : tensor<i16>
    %0 = stablehlo.broadcast_in_dim %c_7, dims = [] : (tensor<i16>) -> tensor<1x32xi16>
    %1 = stablehlo.compare  LT, %arg1, %0,  SIGNED : (tensor<1x32xi16>, tensor<1x32xi16>) -> tensor<1x32xi1>
    %2 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor<i16>) -> tensor<1x32xi16>
    %3 = stablehlo.add %arg1, %2 : tensor<1x32xi16>
    %4 = call @_where(%1, %3, %arg1) : (tensor<1x32xi1>, tensor<1x32xi16>, tensor<1x32xi16>) -> tensor<1x32xi16>
    %5 = stablehlo.broadcast_in_dim %4, dims = [0, 1] : (tensor<1x32xi16>) -> tensor<1x32x1xi16>
    %6 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor<i32>) -> tensor<1xi32>
    %7 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor<i32>) -> tensor<1xi32>
    %8 = stablehlo.concatenate %6, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %9 = stablehlo.convert %5 : (tensor<1x32x1xi16>) -> tensor<1x32x1xi32>
    %10 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i32>) -> tensor<1xi32>
    %11 = stablehlo.compare  LT, %c_5, %10,  SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
    %12 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<i32>) -> tensor<1xi32>
    %13 = stablehlo.add %c_5, %12 : tensor<1xi32>
    %14 = stablehlo.select %11, %13, %c_5 : tensor<1xi1>, tensor<1xi32>
    %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32>
    %16 = "stablehlo.gather"(%8, %15) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, slice_sizes = array<i64: 1>}> : (tensor<2xi32>, tensor<1x1xi32>) -> tensor<1xi32>
    %17 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<1xi32>
    %18 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor<i32>) -> tensor<1xi32>
    %19 = stablehlo.concatenate %17, %18, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>                                                                                                                                                                    
    %20 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i32>) -> tensor<1xi32>
    %21 = stablehlo.compare  LT, %c_5, %20,  SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
    %22 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<i32>) -> tensor<1xi32>
    %23 = stablehlo.add %c_5, %22 : tensor<1xi32>
    %24 = stablehlo.select %21, %23, %c_5 : tensor<1xi1>, tensor<1xi32>
    %25 = stablehlo.broadcast_in_dim %24, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32>
    %26 = "stablehlo.gather"(%19, %25) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, slice_sizes = array<i64: 1>}> : (tensor<2xi32>, tensor<1x1xi32>) -> tensor<1xi32>
    %27 = stablehlo.subtract %16, %26 : tensor<1xi32>
    %28 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i32>) -> tensor<1x32x1xi32>
    %29 = stablehlo.compare  GE, %9, %28,  SIGNED : (tensor<1x32x1xi32>, tensor<1x32x1xi32>) -> tensor<1x32x1xi1>
    %30 = stablehlo.broadcast_in_dim %27, dims = [2] : (tensor<1xi32>) -> tensor<1x1x1xi32>
    %31 = stablehlo.broadcast_in_dim %30, dims = [0, 1, 2] : (tensor<1x1x1xi32>) -> tensor<1x32x1xi32>
    %32 = stablehlo.compare  LE, %9, %31,  SIGNED : (tensor<1x32x1xi32>, tensor<1x32x1xi32>) -> tensor<1x32x1xi1>
    %33 = stablehlo.and %29, %32 : tensor<1x32x1xi1>
    %34 = stablehlo.reduce(%33 init: %c) applies stablehlo.and across dimensions = [2] : (tensor<1x32x1xi1>, tensor<i1>) -> tensor<1x32xi1>
    %35 = "stablehlo.gather"(%arg0, %9) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, slice_sizes = array<i64: 1, 1024>}> : (tensor<32000x1024xf32>, tensor<1x32x1xi32>) -> tensor<1x32x1
024xf32>
    %36 = stablehlo.broadcast_in_dim %34, dims = [0, 1] : (tensor<1x32xi1>) -> tensor<1x32x1024xi1>
    %37 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<1x32x1024xf32>
    %38 = stablehlo.select %36, %35, %37 : tensor<1x32x1024xi1>, tensor<1x32x1024xf32>
    return %38 : tensor<1x32x1024xf32>
  }
  func.func private @_where(%arg0: tensor<1x32xi1>, %arg1: tensor<1x32xi16>, %arg2: tensor<1x32xi16>) -> tensor<1x32xi16> {
    %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor<1x32xi1>, tensor<1x32xi16>
    return %0 : tensor<1x32xi16>
  }
}
ElementsAttr does not provide iteration facilities for type `int`, see attribute: dense<true> : tensor<i1>
invalid `T` for ElementsAttr::getValues
UNREACHABLE executed at /opt/ttmlir-toolchain/include/mlir/IR/BuiltinAttributeInterfaces.h:306!

Thread 1 "python" received signal SIGABRT, Aborted.
0x00007ffff7ced9fc in pthread_kill () from /lib/x86_64-linux-gnu/libc.so.6
(gdb) bt
#0  0x00007ffff7ced9fc in pthread_kill () from /lib/x86_64-linux-gnu/libc.so.6
#1  0x00007ffff7c99476 in raise () from /lib/x86_64-linux-gnu/libc.so.6
#2  0x00007ffff7c7f7f3 in abort () from /lib/x86_64-linux-gnu/libc.so.6
#3  0x00007fff405bef99 in llvm::llvm_unreachable_internal(char const*, char const*, unsigned int) () from /opt/ttmlir-toolchain/lib/libLLVM.so.20.0git
#4  0x00007fff481d43d3 in std::enable_if<std::is_same<mlir::Attribute, int>::value||(!std::is_base_of<mlir::Attribute, int>::value), mlir::detail::ElementsAttrIterator<int> >::type mlir::ElementsAttr::value_begin<int>() const ()
   from /localdev/ddilbaz/tt-xla/tests/TTIR/../../build/src/tt/pjrt_plugin_tt.so
#5  0x00007fff481d3997 in (anonymous namespace)::StableHLOToTTIRConstantOpConversionPattern::matchAndRewrite(mlir::stablehlo::ConstantOp, mlir::stablehlo::ConstantOpAdaptor, mlir::ConversionPatternRewriter&) const ()
   from /localdev/ddilbaz/tt-xla/tests/TTIR/../../build/src/tt/pjrt_plugin_tt.so
#6  0x00007fff481d35fd in mlir::OpConversionPattern<mlir::stablehlo::ConstantOp>::matchAndRewrite(mlir::Operation*, llvm::ArrayRef<mlir::Value>, mlir::ConversionPatternRewriter&) const () from /localdev/ddilbaz/tt-xla/tests/TTIR/../../build/src/tt/pjrt_plugin_tt.so
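
For context, the abort happens because StableHLOToTTIRConstantOpConversionPattern calls ElementsAttr::value_begin<int>() (frame #4) on the i1 constant dense<true>; MLIR only provides bool iteration for i1 elements, so reading them as int hits the unreachable. A minimal sketch of one way to guard this, assuming the pattern widens boolean constants to i32 (the helper name and the widening choice are assumptions, not the actual tt-mlir fix):

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

// Hypothetical helper: rebuild an i1 DenseElementsAttr as i32 so later code
// can iterate it as integers; non-boolean attributes pass through untouched.
static mlir::DenseElementsAttr widenBoolAttr(mlir::DenseElementsAttr attr,
                                             mlir::OpBuilder &builder) {
  if (!attr.getElementType().isInteger(1))
    return attr; // the existing conversion path already handles these
  // i1 elements must be read as `bool`; reading them as `int` aborts, as in
  // the backtrace above.
  llvm::SmallVector<llvm::APInt> widened;
  for (bool v : attr.getValues<bool>())
    widened.push_back(llvm::APInt(/*numBits=*/32, v ? 1 : 0));
  auto newType = mlir::RankedTensorType::get(attr.getType().getShape(),
                                             builder.getI32Type());
  return mlir::DenseElementsAttr::get(newType, widened);
}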

tt-xla changes made:

  • update tt-mlir version to match ddilbaz/gather branch
  • add PJRT_Buffer_Type_S16 support (see the sketch after this list)
  • add createTTIRGatherPatternMatch pass
  • add test_gather_op.py which will eventually migrate to test_basic_ops.py after debugging
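
Regarding the PJRT_Buffer_Type_S16 item, a minimal sketch of what that mapping could look like, assuming the plugin converts PJRT buffer types to MLIR element types in a switch (the function name and the surrounding cases are illustrative, not the actual tt-xla code):

#include "mlir/IR/Builders.h"
#include "xla/pjrt/c/pjrt_c_api.h"

// Hypothetical converter from PJRT buffer types to MLIR element types;
// the S16 arm is the point here, the other cases are scaffolding.
static mlir::Type bufferTypeToElementType(PJRT_Buffer_Type type,
                                          mlir::Builder &builder) {
  switch (type) {
  case PJRT_Buffer_Type_S16: // newly added case
    return builder.getIntegerType(16);
  case PJRT_Buffer_Type_S32:
    return builder.getIntegerType(32);
  case PJRT_Buffer_Type_F32:
    return builder.getF32Type();
  default:
    return nullptr; // unsupported buffer type
  }
}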

Locally, I adjusted the llvm_project and stablehlo versions (third_party/tt-mlir/src/tt-mlir/env/CMakeLists.txt):

set(LLVM_PROJECT_VERSION "e813750354bbc08551cf23ff559a54b4a9ea1f29")
set(STABLEHLO_VERSION "d40285ef3db0687e3f1e2bb0d716d748485a9739")

I will update these in tt-mlir if that is OK.

Work in progress. Failing in SHLO
@@ -96,11 +96,15 @@ void ModuleBuilder::BuildModule(std::string_view code, std::string_view format,
  {
    throw std::runtime_error("Failed to run MLIR compiler pass pipeline.");
  }
  DLOG_F(LOG_DEBUG, "TTIR Module");
  shlo_pm.addPass(mlir::tt::ttir::createTTIRGatherPatternMatch());
Contributor
Instead of adding the pass manually, can you just invoke the pipeline, like we do for TTIRToTTNN below?
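
A minimal sketch of that suggestion, assuming a registered pipeline (the pipeline name stablehlo-to-ttir-pipeline and the mlir_module variable are assumptions, not confirmed tt-mlir/tt-xla identifiers):

#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

// Hypothetical: parse and run a registered pipeline by name instead of
// adding createTTIRGatherPatternMatch() to shlo_pm manually.
mlir::PassManager shlo_pm(mlir_module->getContext());
if (mlir::failed(mlir::parsePassPipeline("stablehlo-to-ttir-pipeline",
                                         shlo_pm)) ||
    mlir::failed(shlo_pm.run(mlir_module.get())))
  throw std::runtime_error("Failed to run MLIR compiler pass pipeline.");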

mmanzoorTT (Contributor) commented Oct 24, 2024

I can see three issues with this graph.

  1. Boolean scalar constants, which are not handled by the current constant op conversion. I am working on a fix; I have already handled boolean constant tensors.
  2. Function calls; @LPanosTT has implemented an inliner pass. It may handle this scenario (I haven't tested it).
  3. The select op; @uazizTT has a PR adding support for the select op. You can cherry-pick his commit and test your gather op.

I found another issue:

  4. The reduce op for logical_and is not yet supported. I am looking into it.

LPanosTT (Contributor) commented Oct 24, 2024

> I can see three issues with this graph. […]

The inliner pass runs as one of the very first passes in the TTIR-to-TTNN pipeline, so you wouldn't see its effect unless you followed the StableHLO-to-TTIR conversion with the TTIR-to-TTNN pipeline. It's merged now, so you can try it.
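
A minimal sketch of that ordering, with both pipeline names and the mlir_module variable as assumptions rather than confirmed identifiers:

#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

// Hypothetical: follow the StableHLO-to-TTIR conversion with the
// TTIR-to-TTNN backend pipeline so the early inliner pass takes effect.
mlir::PassManager pm(mlir_module->getContext());
if (mlir::failed(mlir::parsePassPipeline(
        "stablehlo-to-ttir-pipeline,ttir-to-ttnn-backend-pipeline", pm)) ||
    mlir::failed(pm.run(mlir_module.get())))
  throw std::runtime_error("Failed to run MLIR compiler pass pipeline.");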
