Skip to content

Assertion Error on CustomCall with input/output aliasing #21312

Open
@wsmoses

Description

// Reproducer: a stablehlo.custom_call whose two results are each aliased to
// one of its operands (output_operand_aliases), after which one element of
// the second result is sliced out and returned.
module {
  // Wrapper around the custom call. Note that %arg1 is accepted but never
  // used anywhere in the body.
  func.func private @main_hlo_call_205bd746421f59dc(%arg0: tensor<17xf64>, %arg1: tensor<16xf64>, %arg2: tensor<1x34x34xf64>) -> tensor<f64> {
    // %c, %c_0 and %c_1 are defined but not referenced by any op below.
    %c = stablehlo.constant dense<0> : tensor<i64>
    %c_0 = stablehlo.constant dense<20> : tensor<i64>
    %c_1 = stablehlo.constant dense<1> : tensor<i64>
    // Two results, each aliased to an operand:
    //   result #0 (tensor<1x34x34xf64>) <-> operand 0 (%arg2)
    //   result #1 (tensor<17xf64>)      <-> operand 1 (%arg0)
    // At the StableHLO level, the types within each aliased pair are identical.
    // NOTE(review): the reported error comes from the HLO verifier after
    // layout assignment. A layout mismatch within an aliased pair is therefore
    // a plausible cause (operand 0 is produced by a transpose in @main), but
    // this is unconfirmed.
    %0:2 = stablehlo.custom_call @enzymexla_compile_gpu(%arg2, %arg0) {api_version = 4 : i32, backend_config = {attr = "P0\C1P\B3q\00\00\000\C1P\B3q\00\00"}, output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 1, operand_tuple_indices = []>]} : (tensor<1x34x34xf64>, tensor<17xf64>) -> (tensor<1x34x34xf64>, tensor<17xf64>)
    // Take element 0 of result #1 and return it as a rank-0 tensor.
    %1 = stablehlo.slice %0#1 [0:1] : (tensor<17xf64>) -> tensor<1xf64>
    %2 = stablehlo.reshape %1 : (tensor<1xf64>) -> tensor<f64>
    return %2 : tensor<f64>
  }
  // Entry point. It transposes %arg2 from 34x34x1 to 1x34x34
  // (dims = [2, 1, 0]); the transpose result becomes the custom call's
  // aliased operand 0 inside the callee.
  func.func @main(%arg0: tensor<17xf64>, %arg1: tensor<16xf64>, %arg2: tensor<34x34x1xf64>) -> tensor<f64> {
    %0 = stablehlo.transpose %arg2, dims = [2, 1, 0] : (tensor<34x34x1xf64>) -> tensor<1x34x34xf64>
    %1 = call @main_hlo_call_205bd746421f59dc(%arg0, %arg1, %0) : (tensor<17xf64>, tensor<16xf64>, tensor<1x34x34xf64>) -> tensor<f64>
    return %1 : tensor<f64>
  }
}

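Compiling the preceding MLIR module triggers an assertion error with the following stack trace. The error is raised by the HLO verifier (`xla::ShapeVerifier::HandleCustomCall`), which runs inside `xla::gpu::GpuCompiler::OptimizeHloPostLayoutAssignment`.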

        xla::status_macros::MakeErrorStream::Impl::GetStatus()
        xla::ShapeVerifier::HandleCustomCall(xla::HloInstruction*)
        absl::lts_20230802::Status xla::HloInstruction::Visit<xla::HloInstruction*>(xla::DfsHloVisitorBase<xla::HloInstruction*>*)

        absl::lts_20230802::Status xla::HloInstruction::Accept<xla::HloInstruction*>(xla::DfsHloVisitorBase<xla::HloInstruction*>*, bool, bool, bool)
        absl::lts_20230802::Status xla::HloComputation::Accept<xla::HloInstruction*>(xla::DfsHloVisitorBase<xla::HloInstruction*>*) const
        xla::HloVerifier::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set<std::basic_string_view<char, std::char_traits<char> >, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator<std::basic_string_view<char, std::char_traits<char> > > > const&)
        absl::lts_20230802::StatusOr<bool> xla::HloPassPipeline::RunPassesInternal<xla::HloModule>(xla::HloModule*, xla::DebugOptions const&, absl::lts_20230802::flat_hash_set<std::basic_string_view<char, std::char_traits<char> >, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator<std::basic_string_view<char, std::char_traits<char> > > > const&)
        xla::HloPassPipeline::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set<std::basic_string_view<char, std::char_traits<char> >, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator<std::basic_string_view<char, std::char_traits<char> > > > const&)
        xla::HloPassInterface::Run(xla::HloModule*)
        xla::gpu::GpuCompiler::OptimizeHloPostLayoutAssignment(xla::HloModule*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, xla::Compiler::TargetConfig const&, tsl::thread::ThreadPool*)
        xla::gpu::NVPTXCompiler::OptimizeHloPostLayoutAssignment(xla::HloModule*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, xla::Compiler::TargetConfig const&, tsl::thread::ThreadPool*)
        xla::gpu::GpuCompiler::OptimizeHloModule(xla::HloModule*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, xla::Compiler::TargetConfig const&)
        xla::gpu::GpuCompiler::RunHloPasses(std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule> >, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&)
        xla::Service::BuildExecutable(xla::HloModuleProto const&, std::unique_ptr<xla::HloModuleConfig, std::default_delete<xla::HloModuleConfig> >, xla::Backend*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, bool)
        xla::LocalService::CompileExecutables(xla::XlaComputation const&, absl::lts_20230802::Span<xla::Shape const* const>, xla::ExecutableBuildOptions const&)
        xla::LocalClient::Compile(xla::XlaComputation const&, absl::lts_20230802::Span<xla::Shape const* const>, xla::ExecutableBuildOptions const&)
        xla::PjRtStreamExecutorClient::CompileInternal(xla::XlaComputation const&, std::vector<xla::Shape const*, std::allocator<xla::Shape const*> > const&, std::function<absl::lts_20230802::StatusOr<std::pair<std::vector<xla::Shape, std::allocator<xla::Shape> >, xla::Shape> > (xla::HloModule const&)>, xla::CompileOptions)
        xla::PjRtStreamExecutorClient::Compile(mlir::ModuleOp, xla::CompileOptions)

Metadata

Assignees

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions