
Unsupported conversion from f8E4M3FN to bf16 with rounding mode rtne #28416

Closed
@appujee

Description

    import functools

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    dqt_dtype = jnp.bfloat16
    key = jax.random.key(4352)

    # Random FP8 inputs.
    a = jax.random.normal(key, (32, 64), jnp.float8_e4m3fn)
    b = jax.random.normal(key, (64, 128), jnp.float8_e4m3fn)

    scale_aval = jnp.array(2, dtype=jnp.float32)

    @functools.partial(
        pl.pallas_call,
        #out_shape=jax.ShapeDtypeStruct((), jnp.float8_e4m3fn),
        out_shape=jax.ShapeDtypeStruct((32, 64), dqt_dtype),
        debug=True,
    )
    def dot_kernel(x_ref, y_ref, scale_aval, o_ref):
      o_ref[()] = x_ref[()].astype(dqt_dtype) * scale_aval[()].astype(dqt_dtype)

    result = dot_kernel(a, b, scale_aval)
    print(result)

Resulting jaxpr:

{ lambda ; a:MemRef<None>{float8_e4m3fn[32,64]} b:MemRef<None>{float8_e4m3fn[64,128]}
    c:MemRef<None>{float32[]} d:MemRef<None>{bfloat16[32,64]}. let
    e:f8_e4m3fn[32,64] <- a[:,:]
    f:bf16[32,64] = convert_element_type[new_dtype=bfloat16 weak_type=False] e
    g:f32[] <- c[]
    h:bf16[] = convert_element_type[new_dtype=bfloat16 weak_type=False] g
    i:bf16[32,64] = mul f h
    d[:,:] <- i
  in () }
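
The `convert_element_type` producing `f` is the fp8 -> bf16 cast that later fails in the Triton lowering. For comparison, the same cast outside a Pallas kernel goes through XLA rather than the Pallas/Triton path; a quick sanity check one might run (a sketch, assuming the backend supports fp8 arrays at all):

    import jax.numpy as jnp

    x = jnp.zeros((32, 64), dtype=jnp.float8_e4m3fn)
    print(x.astype(jnp.bfloat16).dtype)  # handled by XLA's convert, not the Triton lowering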

The Triton IR:

module @dot_kernel {
  tt.func public @dot_kernel(%arg0: !tt.ptr<f8E4M3FN> {tt.divisibility = 32 : i32} [unknown], %arg1: !tt.ptr<f8E4M3FN> {tt.divisibility = 32 : i32} [unknown], %arg2: !tt.ptr<f32> {tt.divisibility = 32 : i32} [unknown], %arg3: !tt.ptr<bf16> {tt.divisibility = 32 : i32} [unknown]) {
    %c0_i32 = arith.constant 0 : i32 [unknown]
    %c0_i32_0 = arith.constant 0 : i32 [unknown]
    %c32_i32 = arith.constant 32 : i32 [unknown]
    %0 = arith.muli %c0_i32, %c32_i32 : i32 [unknown]
    %c64_i32 = arith.constant 64 : i32 [unknown]
    %1 = arith.muli %c0_i32_0, %c64_i32 : i32 [unknown]
    %c0_i32_1 = arith.constant 0 : i32 [unknown]
    %c0_i32_2 = arith.constant 0 : i32 [unknown]
    %c64_i32_3 = arith.constant 64 : i32 [unknown]
    %2 = arith.muli %c0_i32_1, %c64_i32_3 : i32 [unknown]
    %c128_i32 = arith.constant 128 : i32 [unknown]
    %3 = arith.muli %c0_i32_2, %c128_i32 : i32 [unknown]
    %c0_i32_4 = arith.constant 0 : i32 [unknown]
    %c0_i32_5 = arith.constant 0 : i32 [unknown]
    %c32_i32_6 = arith.constant 32 : i32 [unknown]
    %4 = arith.muli %c0_i32_4, %c32_i32_6 : i32 [unknown]
    %c64_i32_7 = arith.constant 64 : i32 [unknown]
    %5 = arith.muli %c0_i32_5, %c64_i32_7 : i32 [unknown]
    %c0_i32_8 = arith.constant 0 : i32 "/get"(#loc39)
    %6 = tt.splat %c0_i32_8 : i32 -> tensor<32x64xi32> "/get"(#loc39)
    %7 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> "/get"(#loc39)
    %8 = tt.expand_dims %7 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> "/get"(#loc39)
    %9 = tt.broadcast %8 : tensor<32x1xi32> -> tensor<32x64xi32> "/get"(#loc39)
    %10 = tt.splat %0 : i32 -> tensor<32x64xi32> "/get"(#loc39)
    %11 = arith.addi %9, %10 : tensor<32x64xi32> "/get"(#loc39)
    %c64_i32_9 = arith.constant 64 : i32 "/get"(#loc39)
    %12 = tt.splat %c64_i32_9 : i32 -> tensor<32x64xi32> "/get"(#loc39)
    %13 = arith.muli %11, %12 : tensor<32x64xi32> "/get"(#loc39)
    %14 = arith.addi %6, %13 : tensor<32x64xi32> "/get"(#loc39)
    %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> "/get"(#loc39)
    %16 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> "/get"(#loc39)
    %17 = tt.broadcast %16 : tensor<1x64xi32> -> tensor<32x64xi32> "/get"(#loc39)
    %18 = tt.splat %1 : i32 -> tensor<32x64xi32> "/get"(#loc39)
    %19 = arith.addi %17, %18 : tensor<32x64xi32> "/get"(#loc39)
    %c1_i32 = arith.constant 1 : i32 "/get"(#loc39)
    %20 = tt.splat %c1_i32 : i32 -> tensor<32x64xi32> "/get"(#loc39)
    %21 = arith.muli %19, %20 : tensor<32x64xi32> "/get"(#loc39)
    %22 = arith.addi %14, %21 : tensor<32x64xi32> "/get"(#loc39)
    %23 = tt.splat %arg0 : !tt.ptr<f8E4M3FN> -> tensor<32x64x!tt.ptr<f8E4M3FN>> "/get"(#loc39)
    %24 = tt.addptr %23, %22 : tensor<32x64x!tt.ptr<f8E4M3FN>>, tensor<32x64xi32> "/get"(#loc39)
    %25 = tt.load %24 : tensor<32x64x!tt.ptr<f8E4M3FN>> "/get"(#loc39)
    %26 = tt.fp_to_fp %25, rounding = rtne : tensor<32x64xf8E4M3FN> -> tensor<32x64xbf16> "/convert_element_type"(#loc40)
    %c0_i32_10 = arith.constant 0 : i32 "/get"(#loc41)
    %27 = tt.addptr %arg2, %c0_i32_10 : !tt.ptr<f32>, i32 "/get"(#loc41)
    %28 = tt.load %27 : !tt.ptr<f32> "/get"(#loc41)
    %29 = arith.truncf %28 : f32 to bf16 "/convert_element_type"(#loc42)
    %30 = tt.splat %29 : bf16 -> tensor<32x64xbf16> "/mul"(#loc43)
    %31 = arith.mulf %26, %30 : tensor<32x64xbf16> "/mul"(#loc43)
    %c0_i32_11 = arith.constant 0 : i32 "/swap"(#loc44)
    %32 = tt.splat %c0_i32_11 : i32 -> tensor<32x64xi32> "/swap"(#loc44)
    %33 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> "/swap"(#loc44)
    %34 = tt.expand_dims %33 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> "/swap"(#loc44)
    %35 = tt.broadcast %34 : tensor<32x1xi32> -> tensor<32x64xi32> "/swap"(#loc44)
    %36 = tt.splat %4 : i32 -> tensor<32x64xi32> "/swap"(#loc44)
    %37 = arith.addi %35, %36 : tensor<32x64xi32> "/swap"(#loc44)
    %c64_i32_12 = arith.constant 64 : i32 "/swap"(#loc44)
    %38 = tt.splat %c64_i32_12 : i32 -> tensor<32x64xi32> "/swap"(#loc44)
    %39 = arith.muli %37, %38 : tensor<32x64xi32> "/swap"(#loc44)
    %40 = arith.addi %32, %39 : tensor<32x64xi32> "/swap"(#loc44)
    %41 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> "/swap"(#loc44)
    %42 = tt.expand_dims %41 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> "/swap"(#loc44)
    %43 = tt.broadcast %42 : tensor<1x64xi32> -> tensor<32x64xi32> "/swap"(#loc44)
    %44 = tt.splat %5 : i32 -> tensor<32x64xi32> "/swap"(#loc44)
    %45 = arith.addi %43, %44 : tensor<32x64xi32> "/swap"(#loc44)
    %c1_i32_13 = arith.constant 1 : i32 "/swap"(#loc44)
    %46 = tt.splat %c1_i32_13 : i32 -> tensor<32x64xi32> "/swap"(#loc44)
    %47 = arith.muli %45, %46 : tensor<32x64xi32> "/swap"(#loc44)
    %48 = arith.addi %40, %47 : tensor<32x64xi32> "/swap"(#loc44)
    %49 = tt.splat %arg3 : !tt.ptr<bf16> -> tensor<32x64x!tt.ptr<bf16>> "/swap"(#loc44)
    %50 = tt.addptr %49, %48 : tensor<32x64x!tt.ptr<bf16>>, tensor<32x64xi32> "/swap"(#loc44)
    %51 = tt.load %50 : tensor<32x64x!tt.ptr<bf16>> "/swap"(#loc44)
    tt.store %50, %31 : tensor<32x64x!tt.ptr<bf16>> "/swap"(#loc44)
    tt.return [unknown]
  } [unknown]
} [unknown]
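
The conversion that fails is `%26 = tt.fp_to_fp %25, rounding = rtne : tensor<32x64xf8E4M3FN> -> tensor<32x64xbf16>`: the elementwise lowering in ConvertTritonGPUToLLVM rejects an f8E4M3FN -> bf16 conversion carrying the rtne rounding mode, which is what the stack trace below shows.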

And the stack trace:

Unsupported conversion from f8E4M3FN to bf16 with rounding mode rtne
LLVM ERROR: Unsupported rounding mode for conversion.
PC: @     0x7fadfef0b981  (unknown)  gsignal
    @     0x556dfa268b7b       1904  FailureSignalHandler()
    @     0x7fadff0a3e80  (unknown)  (unknown)
    @     0x556df78d988a        208  llvm::report_fatal_error()
    @     0x556df78d9635         64  llvm::report_fatal_error()
    @     0x556de0d88510       1840  mlir::triton::gpu::ElementwiseOpConversionBase<>::matchAndRewrite()
    @     0x556de0d88a04        256  mlir::ConvertOpToLLVMPattern<>::matchAndRewrite()
    @     0x556de0d8635d        192  mlir::ConvertOpToLLVMPattern<>::matchAndRewrite()
    @     0x556df678e057        240  mlir::ConversionPattern::matchAndRewrite()
    @     0x556df67e273a         64  llvm::function_ref<>::callback_fn<>()
    @     0x556df67dfb09        736  mlir::PatternApplicator::matchAndRewrite()
    @     0x556df678ee1e        320  (anonymous namespace)::OperationLegalizer::legalize()
    @     0x556df678e129        304  mlir::OperationConverter::convert()
    @     0x556df678f228       1104  mlir::OperationConverter::convertOperations()
    @     0x556df6793b54        272  mlir::applyPartialConversion()
    @     0x556de0dcc020       3344  (anonymous namespace)::ConvertTritonGPUToLLVM::runOnOperation()
    @     0x556df69b881c         48  llvm::function_ref<>::callback_fn<>()
    @     0x556df69b24fb        384  mlir::detail::OpToOpPassAdaptor::run()
    @     0x556df69b4999        352  mlir::PassManager::run()
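
A possible workaround (an untested sketch, not a confirmed fix): split the cast so the fp8 value is first widened to float32 and only then truncated to bfloat16, avoiding a single fp8 -> bf16 `tt.fp_to_fp` that carries a rounding mode. The f32 -> bf16 step already lowers to `arith.truncf`, which succeeds in the IR above.

    def dot_kernel(x_ref, y_ref, scale_aval, o_ref):
      # Widen fp8 -> f32 first, then truncate f32 -> bf16 (sketch only).
      x = x_ref[()].astype(jnp.float32).astype(dqt_dtype)
      o_ref[()] = x * scale_aval[()].astype(dqt_dtype)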

### System info (python version, jaxlib version, accelerator, etc.)

jax version: latest.
