### Description

Lowering a Pallas kernel that casts a `float8_e4m3fn` input to `bfloat16` on the Triton backend crashes with `LLVM ERROR: Unsupported rounding mode for conversion.` Minimal repro:
```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

dqt_dtype = jnp.bfloat16
key = jax.random.key(4352)

# Random FP8 inputs.
a = jax.random.normal(key, (32, 64), jnp.float8_e4m3fn)
b = jax.random.normal(key, (64, 128), jnp.float8_e4m3fn)
scale = jnp.array(2, dtype=jnp.float32)

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((32, 64), dqt_dtype),
    debug=True,  # prints the kernel jaxpr and the Triton IR shown below
)
def dot_kernel(x_ref, y_ref, scale_ref, o_ref):
    # The f8 -> bf16 cast of x_ref is what fails to lower; y_ref is unused.
    o_ref[()] = x_ref[()].astype(dqt_dtype) * scale_ref[()].astype(dqt_dtype)

result = dot_kernel(a, b, scale)
print(result)
```
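The multiply and the second operand look incidental; a kernel that performs only the cast should hit the same failure. A minimal sketch reusing `a` from above (`cast_kernel` is a hypothetical name):

```python
@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((32, 64), jnp.bfloat16),
)
def cast_kernel(x_ref, o_ref):
    # A single f8 -> bf16 cast: the convert_element_type that fails to lower.
    o_ref[()] = x_ref[()].astype(jnp.bfloat16)

cast_kernel(a)  # expected to crash with the same LLVM error
```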
The resulting jaxpr (printed by `debug=True`):

```
{ lambda ; a:MemRef<None>{float8_e4m3fn[32,64]} b:MemRef<None>{float8_e4m3fn[64,128]}
c:MemRef<None>{float32[]} d:MemRef<None>{bfloat16[32,64]}. let
e:f8_e4m3fn[32,64] <- a[:,:]
f:bf16[32,64] = convert_element_type[new_dtype=bfloat16 weak_type=False] e
g:f32[] <- c[]
h:bf16[] = convert_element_type[new_dtype=bfloat16 weak_type=False] g
i:bf16[32,64] = mul f h
d[:,:] <- i
in () }
```
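The tensor cast `f:bf16[32,64] = convert_element_type[new_dtype=bfloat16] e` on the f8 input is the equation that fails to lower; the scalar conversion `g:f32[] -> h:bf16[]` is fine (it becomes `arith.truncf` in the Triton IR below).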
The Triton IR:

```mlir
module @dot_kernel {
tt.func public @dot_kernel(%arg0: !tt.ptr<f8E4M3FN> {tt.divisibility = 32 : i32} [unknown], %arg1: !tt.ptr<f8E4M3FN> {tt.divisibility = 32 : i32} [unknown], %arg2: !tt.ptr<f32> {tt.divisibility = 32 : i32} [unknown], %arg3: !tt.ptr<bf16> {tt.divisibility = 32 : i32} [unknown]) {
%c0_i32 = arith.constant 0 : i32 [unknown]
%c0_i32_0 = arith.constant 0 : i32 [unknown]
%c32_i32 = arith.constant 32 : i32 [unknown]
%0 = arith.muli %c0_i32, %c32_i32 : i32 [unknown]
%c64_i32 = arith.constant 64 : i32 [unknown]
%1 = arith.muli %c0_i32_0, %c64_i32 : i32 [unknown]
%c0_i32_1 = arith.constant 0 : i32 [unknown]
%c0_i32_2 = arith.constant 0 : i32 [unknown]
%c64_i32_3 = arith.constant 64 : i32 [unknown]
%2 = arith.muli %c0_i32_1, %c64_i32_3 : i32 [unknown]
%c128_i32 = arith.constant 128 : i32 [unknown]
%3 = arith.muli %c0_i32_2, %c128_i32 : i32 [unknown]
%c0_i32_4 = arith.constant 0 : i32 [unknown]
%c0_i32_5 = arith.constant 0 : i32 [unknown]
%c32_i32_6 = arith.constant 32 : i32 [unknown]
%4 = arith.muli %c0_i32_4, %c32_i32_6 : i32 [unknown]
%c64_i32_7 = arith.constant 64 : i32 [unknown]
%5 = arith.muli %c0_i32_5, %c64_i32_7 : i32 [unknown]
%c0_i32_8 = arith.constant 0 : i32 "/get"(#loc39)
%6 = tt.splat %c0_i32_8 : i32 -> tensor<32x64xi32> "/get"(#loc39)
%7 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> "/get"(#loc39)
%8 = tt.expand_dims %7 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> "/get"(#loc39)
%9 = tt.broadcast %8 : tensor<32x1xi32> -> tensor<32x64xi32> "/get"(#loc39)
%10 = tt.splat %0 : i32 -> tensor<32x64xi32> "/get"(#loc39)
%11 = arith.addi %9, %10 : tensor<32x64xi32> "/get"(#loc39)
%c64_i32_9 = arith.constant 64 : i32 "/get"(#loc39)
%12 = tt.splat %c64_i32_9 : i32 -> tensor<32x64xi32> "/get"(#loc39)
%13 = arith.muli %11, %12 : tensor<32x64xi32> "/get"(#loc39)
%14 = arith.addi %6, %13 : tensor<32x64xi32> "/get"(#loc39)
%15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> "/get"(#loc39)
%16 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> "/get"(#loc39)
%17 = tt.broadcast %16 : tensor<1x64xi32> -> tensor<32x64xi32> "/get"(#loc39)
%18 = tt.splat %1 : i32 -> tensor<32x64xi32> "/get"(#loc39)
%19 = arith.addi %17, %18 : tensor<32x64xi32> "/get"(#loc39)
%c1_i32 = arith.constant 1 : i32 "/get"(#loc39)
%20 = tt.splat %c1_i32 : i32 -> tensor<32x64xi32> "/get"(#loc39)
%21 = arith.muli %19, %20 : tensor<32x64xi32> "/get"(#loc39)
%22 = arith.addi %14, %21 : tensor<32x64xi32> "/get"(#loc39)
%23 = tt.splat %arg0 : !tt.ptr<f8E4M3FN> -> tensor<32x64x!tt.ptr<f8E4M3FN>> "/get"(#loc39)
%24 = tt.addptr %23, %22 : tensor<32x64x!tt.ptr<f8E4M3FN>>, tensor<32x64xi32> "/get"(#loc39)
%25 = tt.load %24 : tensor<32x64x!tt.ptr<f8E4M3FN>> "/get"(#loc39)
%26 = tt.fp_to_fp %25, rounding = rtne : tensor<32x64xf8E4M3FN> -> tensor<32x64xbf16> "/convert_element_type"(#loc40)
%c0_i32_10 = arith.constant 0 : i32 "/get"(#loc41)
%27 = tt.addptr %arg2, %c0_i32_10 : !tt.ptr<f32>, i32 "/get"(#loc41)
%28 = tt.load %27 : !tt.ptr<f32> "/get"(#loc41)
%29 = arith.truncf %28 : f32 to bf16 "/convert_element_type"(#loc42)
%30 = tt.splat %29 : bf16 -> tensor<32x64xbf16> "/mul"(#loc43)
%31 = arith.mulf %26, %30 : tensor<32x64xbf16> "/mul"(#loc43)
%c0_i32_11 = arith.constant 0 : i32 "/swap"(#loc44)
%32 = tt.splat %c0_i32_11 : i32 -> tensor<32x64xi32> "/swap"(#loc44)
%33 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> "/swap"(#loc44)
%34 = tt.expand_dims %33 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> "/swap"(#loc44)
%35 = tt.broadcast %34 : tensor<32x1xi32> -> tensor<32x64xi32> "/swap"(#loc44)
%36 = tt.splat %4 : i32 -> tensor<32x64xi32> "/swap"(#loc44)
%37 = arith.addi %35, %36 : tensor<32x64xi32> "/swap"(#loc44)
%c64_i32_12 = arith.constant 64 : i32 "/swap"(#loc44)
%38 = tt.splat %c64_i32_12 : i32 -> tensor<32x64xi32> "/swap"(#loc44)
%39 = arith.muli %37, %38 : tensor<32x64xi32> "/swap"(#loc44)
%40 = arith.addi %32, %39 : tensor<32x64xi32> "/swap"(#loc44)
%41 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> "/swap"(#loc44)
%42 = tt.expand_dims %41 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> "/swap"(#loc44)
%43 = tt.broadcast %42 : tensor<1x64xi32> -> tensor<32x64xi32> "/swap"(#loc44)
%44 = tt.splat %5 : i32 -> tensor<32x64xi32> "/swap"(#loc44)
%45 = arith.addi %43, %44 : tensor<32x64xi32> "/swap"(#loc44)
%c1_i32_13 = arith.constant 1 : i32 "/swap"(#loc44)
%46 = tt.splat %c1_i32_13 : i32 -> tensor<32x64xi32> "/swap"(#loc44)
%47 = arith.muli %45, %46 : tensor<32x64xi32> "/swap"(#loc44)
%48 = arith.addi %40, %47 : tensor<32x64xi32> "/swap"(#loc44)
%49 = tt.splat %arg3 : !tt.ptr<bf16> -> tensor<32x64x!tt.ptr<bf16>> "/swap"(#loc44)
%50 = tt.addptr %49, %48 : tensor<32x64x!tt.ptr<bf16>>, tensor<32x64xi32> "/swap"(#loc44)
%51 = tt.load %50 : tensor<32x64x!tt.ptr<bf16>> "/swap"(#loc44)
tt.store %50, %31 : tensor<32x64x!tt.ptr<bf16>> "/swap"(#loc44)
tt.return [unknown]
} [unknown]
} [unknown]
```
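The op the backend rejects is `%26 = tt.fp_to_fp %25, rounding = rtne : tensor<32x64xf8E4M3FN> -> tensor<32x64xbf16>`. An f8E4M3FN-to-bf16 upcast is exact (3 mantissa bits fit in bf16's 7), so presumably no rounding attribute should be emitted for it; attaching `rtne` here appears to be what the TritonGPU-to-LLVM elementwise lowering chokes on.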
And the stack trace:

```
Unsupported conversion from f8E4M3FN to bf16 with rounding mode rtne
LLVM ERROR: Unsupported rounding mode for conversion.
PC: @ 0x7fadfef0b981 (unknown) gsignal
@ 0x556dfa268b7b 1904 FailureSignalHandler()
@ 0x7fadff0a3e80 (unknown) (unknown)
@ 0x556df78d988a 208 llvm::report_fatal_error()
@ 0x556df78d9635 64 llvm::report_fatal_error()
@ 0x556de0d88510 1840 mlir::triton::gpu::ElementwiseOpConversionBase<>::matchAndRewrite()
@ 0x556de0d88a04 256 mlir::ConvertOpToLLVMPattern<>::matchAndRewrite()
@ 0x556de0d8635d 192 mlir::ConvertOpToLLVMPattern<>::matchAndRewrite()
@ 0x556df678e057 240 mlir::ConversionPattern::matchAndRewrite()
@ 0x556df67e273a 64 llvm::function_ref<>::callback_fn<>()
@ 0x556df67dfb09 736 mlir::PatternApplicator::matchAndRewrite()
@ 0x556df678ee1e 320 (anonymous namespace)::OperationLegalizer::legalize()
@ 0x556df678e129 304 mlir::OperationConverter::convert()
@ 0x556df678f228 1104 mlir::OperationConverter::convertOperations()
@ 0x556df6793b54 272 mlir::applyPartialConversion()
@ 0x556de0dcc020 3344 (anonymous namespace)::ConvertTritonGPUToLLVM::runOnOperation()
@ 0x556df69b881c 48 llvm::function_ref<>::callback_fn<>()
@ 0x556df69b24fb 384 mlir::detail::OpToOpPassAdaptor::run()
@ 0x556df69b4999 352 mlir::PassManager::run()
```
### System info (python version, jaxlib version, accelerator, etc.)
jax version: latest.
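As a stopgap, the cast can be kept out of the kernel entirely by converting on the XLA side; hopping through `f32` inside the kernel might also avoid the direct f8-to-bf16 `tt.fp_to_fp` (the second variant is an untested sketch; `kernel_body` is a hypothetical name, other names reuse the repro above):

```python
# Workaround 1 (reliable): cast outside Pallas and feed bf16 into the kernel,
# so the Triton backend never has to lower an f8 conversion at all.
a_bf16 = a.astype(jnp.bfloat16)

# Workaround 2 (untested sketch): widen f8 -> f32 first, hoping the upcast
# lowers without a rounding attribute, and f32 -> bf16 then lowers to
# arith.truncf (as the scalar scale cast does in the IR above).
def kernel_body(x_ref, y_ref, scale_ref, o_ref):
    x = x_ref[()].astype(jnp.float32).astype(dqt_dtype)
    o_ref[()] = x * scale_ref[()].astype(dqt_dtype)
```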