Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPU] F4_E3M0 Reference Enable #2463

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion src/gpu/intel/ocl/ocl_math_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ int rnd_down(int a, unsigned int b) {
#define MATH_UTILS_DECLARE_F4_E2M1 1
#endif

#if DT_F4_E3M0 || SRC_DT_F4_E3M0 || WEI_DT_F4_E3M0 || DST_DT_F4_E3M0 \
|| BIA_DT_F4_E3M0 || A_DT_F4_E3M0 || A_DT_F4_E3M0 || B_DT_F4_E3M0 \
|| C_DT_F4_E3M0 || DATA_DT_F4_E3M0 || POST_OP_USING_F4_E3M0
#define MATH_UTILS_DECLARE_F4_E3M0 1
#endif

#if DT_S4 || SRC_DT_S4 || WEI_DT_S4 || DST_DT_S4 || BIA_DT_S4 || A_DT_S4 \
|| B_DT_S4 || C_DT_S4 || DATA_DT_S4 || WEI_ZP_DT_S4 || SRC_ZP_DT_S4
#define MATH_UTILS_DECLARE_S4 1
Expand Down Expand Up @@ -690,9 +696,51 @@ float __attribute__((overloadable)) cvt_f4_e2m1_to_f32(uchar a) {
return as_float((sign << 28) | (exp << 23) | (mant << 22));
}

#endif
#if MATH_UTILS_DECLARE_F4_E3M0

// OCL translation of common fp4 methods.
// Converts an f32 value to the 4-bit e3m0 encoding, returned in the low
// nibble of a byte: bit 3 carries the sign, bits 0-2 index the magnitude
// table {0, 0.25, 0.5, 1, 2, 4, 8, 16}.
// - Finite inputs are rounded to the nearest table entry; exact midpoints go
//   to the entry with the even index.
// - Magnitudes above 16 saturate to 16.
// - NaN and infinity have no e3m0 representation and map to the maximum
//   magnitude, keeping the sign bit of the input.
uchar __attribute__((overloadable)) cvt_f32_to_f4_e3m0(float f) {
    const uint f_raw = as_uint(f);
    const uint sign = f_raw & 0x80000000;

    // There is no NaN or infinity in e3m0, we return maxval. The sign is
    // moved from bit 31 to bit 3 so that e.g. -inf saturates to -16 instead
    // of +16, consistent with how finite negative inputs are handled below.
    const uint naninf_mask = 0x7f800000;
    if ((f_raw & naninf_mask) == naninf_mask)
        return (uchar)(0x7 | (sign >> 28));

    // we convert with naive closest value computation out of 8
    const float e3m0_val_table[8]
            = {0.0f, .25f, .5f, 1.0f, 2.0f, 4.0f, 8.0f, 16.0f};

    // Clear the sign bit to get the magnitude.
    const float abs_f = as_float(f_raw ^ sign);

    int idx = 0;
    float min_diff = fabs(e3m0_val_table[idx] - abs_f);
    uchar raw_bits = idx;
    for (++idx; idx < 8; ++idx) {
        const float diff = fabs(e3m0_val_table[idx] - abs_f);
        if (diff < min_diff) {
            min_diff = diff;
            raw_bits = idx;
        }
        // Special case for midpoint, we round to even (so even index)
        if ((diff == min_diff) && !(idx & 1)) raw_bits = idx;
    }
    // reapply sign
    if (sign) raw_bits = raw_bits | 0x08;
    return raw_bits;
}

// Converts a 4-bit e3m0 value to f32 via table lookup. Only the low nibble of
// `a` is interpreted: bit 3 is the sign, bits 0-2 select the magnitude.
float __attribute__((overloadable)) cvt_f4_e3m0_to_f32(uchar a) {
    // List of e3m0 values. The index of each value maps to its encoding.
    const float e3m0_table[16] = {0.0f, .25f, .5f, 1.0f, 2.0f, 4.0f, 8.0f,
            16.0f, -0.0f, -.25f, -.5f, -1.0f, -2.0f, -4.0f, -8.0f, -16.0f};
    // Mask to 4 bits so a byte with upper-nibble bits set cannot index past
    // the end of the 16-entry table.
    return e3m0_table[a & 0x0F];
}

#endif

#if MATH_UTILS_DECLARE_S4 || MATH_UTILS_DECLARE_U4 || MATH_UTILS_DECLARE_F4_E2M1
#if MATH_UTILS_DECLARE_S4 || MATH_UTILS_DECLARE_U4 \
|| MATH_UTILS_DECLARE_F4_E2M1 || MATH_UTILS_DECLARE_F4_E3M0
#define GET_HALF_BYTE(x, y) get_half_byte(x, y)

uchar __attribute__((overloadable)) get_half_byte(__global uchar *x, off_t y) {
Expand Down
120 changes: 118 additions & 2 deletions src/gpu/intel/ocl/ocl_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,68 @@
#define FLT_ACC_DATA_T float
#define TO_FLT_ACC_DATA_T(v) convert_float(cvt_f8_e4m3_to_hf(v))

#elif DT_F4_E3M0 == 1
#define DATA_T uchar
#define DATA2_T uchar2
#define DATA4_T uchar4
#define DATA8_T uchar8
#define DATA16_T uchar16
#define DATA_MAX (uchar)0x07
#define DATA_MIN (uchar)0x08
#define DATA_ZERO (uchar)0x00
#define DATA_ONE (uchar)0x03
#define DEF_ACC_DATA_T float
#define DEF_ACC_DATA2_T float2
#define DEF_ACC_DATA4_T float4
#define DEF_ACC_DATA8_T float8
#define POST_OP_DATA_T float
#define TO_DATA_T(v) cvt_f32_to_f4_e3m0((v))
#define TO_DEF_ACC_DATA_T(v) (cvt_f4_e3m0_to_f32(v))
#define TO_DEF_ACC_DATA2_T(v) (cvt_f4_e3m0_to_f32(v))
#define TO_DEF_ACC_DATA4_T(v) (cvt_f4_e3m0_to_f32(v))
#define TO_DEF_ACC_DATA8_T(v) (cvt_f4_e3m0_to_f32(v))
#define DATA_TO_REF(v) (cvt_f4_e3m0_to_f32(v))
#define CONVERT_DATA_T(v) cvt_f32_to_f4_e3m0(v)
#define CONVERT_DATA2_T(v) cvt_f32_to_f4_e3m0(v)
#define CONVERT_DATA4_T(v) cvt_f32_to_f4_e3m0(v)
#define CONVERT_DATA8_T(v) cvt_f32_to_f4_e3m0(v)
#define CONVERT_FLOAT_T(v) cvt_f4_e3m0_to_f32(v)
#define CONVERT_FLOAT2_T(v) cvt_f4_e3m0_to_f32(v)
#define CONVERT_FLOAT4_T(v) cvt_f4_e3m0_to_f32(v)
#define CONVERT_FLOAT8_T(v) cvt_f4_e3m0_to_f32(v)

#define BLOCK_READ intel_sub_group_block_read_uc
#define BLOCK_WRITE intel_sub_group_block_write_uc
#define BLOCK_READ2 intel_sub_group_block_read_uc2
#define BLOCK_READ4 intel_sub_group_block_read_uc4
#define BLOCK_READ8 intel_sub_group_block_read_uc8
#define BLOCK_WRITE2 intel_sub_group_block_write_uc2
#define BLOCK_WRITE4 intel_sub_group_block_write_uc4
#define BLOCK_WRITE8 intel_sub_group_block_write_uc8
#define AS_DATA_T as_uchar
#define AS_DATA2_T as_uchar2
#define AS_DATA4_T as_uchar4
#define AS_DATA8_T as_uchar8
#define AS_DATA16_T as_uchar16

#define AS_UINT_T as_uchar
#define AS_UINT2_T as_uchar2
#define AS_UINT4_T as_uchar4
#define AS_UINT8_T as_uchar8
#define AS_INT8_T as_uint8

#define BLOCK_DATA_T uchar
#define BLOCK_DATA2_T uchar2
#define BLOCK_DATA4_T uchar4
#define BLOCK_DATA8_T uchar8
#define AS_BLOCK_DATA_T as_uchar
#define AS_BLOCK_DATA2_T as_uchar2
#define AS_BLOCK_DATA4_T as_uchar4
#define AS_BLOCK_DATA8_T as_uchar8

#define FLT_ACC_DATA_T float
#define TO_FLT_ACC_DATA_T(v) (cvt_f4_e3m0_to_f32(v))

#elif DT_F4_E2M1 == 1
#define DATA_T uchar
#define DATA2_T uchar2
Expand Down Expand Up @@ -866,6 +928,10 @@
#define SRC_TO_REF(x) cvt_f4_e2m1_to_f32(x)
#define SRC_TO_REF8(x) cvt_f4_e2m1_to_f32(x)
#define REF_TO_SRC(x) cvt_f32_to_f4_e2m1(x)
#elif SRC_DT_F4_E3M0
#define SRC_TO_REF(x) cvt_f4_e3m0_to_f32(x)
#define SRC_TO_REF8(x) cvt_f4_e3m0_to_f32(x)
#define REF_TO_SRC(x) cvt_f32_to_f4_e3m0(x)
#elif SRC_DT_U4
#define SRC_TO_REF(x) convert_float(x)
#elif SRC_DT_S4
Expand All @@ -885,6 +951,8 @@
#define TO_SRC(x) cvt_hf_to_f8_e4m3(convert_half(x))
#elif SRC_DT_F4_E2M1
#define TO_SRC(x) cvt_f32_to_f4_e2m1(x)
#elif SRC_DT_F4_E3M0
#define TO_SRC(x) cvt_f32_to_f4_e3m0(x)
#elif SRC_DT_U8
#define TO_SRC(x) convert_uchar_sat_rte(x)
#elif SRC_DT_S8
Expand Down Expand Up @@ -913,6 +981,10 @@
#define A_TO_REF(x) cvt_f4_e2m1_to_f32(x)
#define A_TO_REF8(x) cvt_f4_e2m1_to_f32(x)
#define REF_TO_A(x) cvt_f32_to_f4_e2m1(x)
#elif A_DT_F4_E3M0
#define A_TO_REF(x) cvt_f4_e3m0_to_f32(x)
#define A_TO_REF8(x) cvt_f4_e3m0_to_f32(x)
#define REF_TO_A(x) cvt_f32_to_f4_e3m0(x)
#else
#define A_TO_REF(x) (x)
#define A_TO_REF8(x) (x)
Expand All @@ -926,6 +998,8 @@
#define TO_A(x) cvt_hf_to_f8_e4m3(convert_half(x))
#elif A_DT_F4_E2M1
#define TO_A(x) cvt_f32_to_f4_e2m1(x)
#elif A_DT_F4_E3M0
#define TO_A(x) cvt_f32_to_f4_e3m0(x)
#elif A_DT_U8
#define TO_A(x) convert_uchar_sat_rte(x)
#elif A_DT_S8
Expand All @@ -950,6 +1024,9 @@
#elif WEI_DT_F4_E2M1
#define WEI_TO_REF(x) cvt_f4_e2m1_to_f32(x)
#define REF_TO_WEI(x) cvt_f32_to_f4_e2m1(x)
#elif WEI_DT_F4_E3M0
#define WEI_TO_REF(x) cvt_f4_e3m0_to_f32(x)
#define REF_TO_WEI(x) cvt_f32_to_f4_e3m0(x)
#elif WEI_DT_S8
#define WEI_TO_REF(x) convert_int_sat_rte(x)
#define REF_TO_WEI(x) convert_char_sat_rte(x)
Expand All @@ -972,6 +1049,8 @@
#define TO_WEI(x) cvt_hf_to_f8_e4m3(convert_half(x))
#elif WEI_DT_F4_E2M1
#define TO_WEI(x) cvt_f32_to_f4_e2m1(x)
#elif WEI_DT_F4_E3M0
#define TO_WEI(x) cvt_f32_to_f4_e3m0(x)
#elif WEI_DT_U8
#define TO_WEI(x) convert_uchar_sat_rte(x)
#elif WEI_DT_S8
Expand Down Expand Up @@ -1076,6 +1155,10 @@
#define B_TO_REF(x) cvt_f4_e2m1_to_f32(x)
#define REF_TO_B(x) cvt_f32_to_f4_e2m1(x)
#define TO_B(x) cvt_f32_to_f4_e2m1(x)
#elif B_DT_F4_E3M0
#define B_TO_REF(x) cvt_f4_e3m0_to_f32(x)
#define REF_TO_B(x) cvt_f32_to_f4_e3m0(x)
#define TO_B(x) cvt_f32_to_f4_e3m0(x)
#elif B_DT_U8
#define B_TO_REF(x) (x)
#define REF_TO_B(x) (x)
Expand Down Expand Up @@ -1109,6 +1192,9 @@
#elif BIA_DT_F4_E2M1
#define BIA_TO_REF(x) cvt_f4_e2m1_to_f32(x)
#define REF_TO_BIA(x) cvt_f32_to_f4_e2m1(x)
#elif BIA_DT_F4_E3M0
#define BIA_TO_REF(x) cvt_f4_e3m0_to_f32(x)
#define REF_TO_BIA(x) cvt_f32_to_f4_e3m0(x)
#else
#define BIA_TO_REF(x) (x)
#define REF_TO_BIA(x) (x)
Expand All @@ -1122,6 +1208,8 @@
#define TO_BIA(x) cvt_hf_to_f8_e4m3(convert_half(x))
#elif BIA_DT_F4_E2M1
#define TO_BIA(x) cvt_f32_to_f4_e2m1(x)
#elif BIA_DT_F4_E3M0
#define TO_BIA(x) cvt_f32_to_f4_e3m0(x)
#elif BIA_DT_U8
#define TO_BIA(x) convert_uchar_sat_rte(x)
#elif BIA_DT_S8
Expand Down Expand Up @@ -1280,8 +1368,16 @@
#define DST_TO_REF8(x) cvt_f4_e2m1_to_f32(x)
#define REF_TO_DST(x) cvt_f32_to_f4_e2m1(x)
#define REF_TO_DST8(x) cvt_f32_to_f4_e2m1(x)
#define DST_DATA_MAX (uchar)0x7B
#define DST_DATA_MIN (uchar)0xFB
#define DST_DATA_MAX (uchar)0x07
#define DST_DATA_MIN (uchar)0x01
// f4_e3m0 destination: convert through f32 with the e3m0 helpers.
#elif DST_DT_F4_E3M0
#define DST_TO_REF(x) cvt_f4_e3m0_to_f32(x)
#define DST_TO_REF2(x) cvt_f4_e3m0_to_f32(x)
#define DST_TO_REF8(x) cvt_f4_e3m0_to_f32(x)
#define REF_TO_DST(x) cvt_f32_to_f4_e3m0(x)
#define REF_TO_DST8(x) cvt_f32_to_f4_e3m0(x)
#define DST_DATA_MAX (uchar)0x07
#define DST_DATA_MIN (uchar)0x08
#elif DST_DT_F16
#define REF_TO_DST(x) convert_half(x)
#define DST_TO_REF(x) convert_float(x)
Expand Down Expand Up @@ -1371,6 +1467,16 @@
#define DST_DATA_FMAX 6.0
#define DST_DATA_FMIN 1.0
#define DST_DATA_FLOW -6.0
#elif DST_DT_F4_E3M0
#define SET_DOUBLE_HALF_BYTE(x, y, z) set_double_half_byte(x, y, z)
#define TO_DST(x) cvt_f32_to_f4_e3m0(convert_float(x))
#define TO_DST2(x) cvt_f32_to_f4_e3m0(convert_float2(x))
#define TO_DST4(x) cvt_f32_to_f4_e3m0(convert_float4(x))
#define TO_DST8(x) cvt_f32_to_f4_e3m0(convert_float8(x))
#define TO_DST16(x) cvt_f32_to_f4_e3m0(convert_float16(x))
#define DST_DATA_FMAX 16.0
#define DST_DATA_FMIN 0.25
#define DST_DATA_FLOW -16.0
#elif DST_DT_U8
#define TO_DST(x) convert_uchar_sat_rte(x)
#define TO_DST2(x) convert_uchar2_sat_rte(x)
Expand Down Expand Up @@ -1440,6 +1546,11 @@
#define C_TO_REF8(x) cvt_f4_e2m1_to_f32(x)
#define REF_TO_C(x) cvt_f32_to_f4_e2m1(x)
#define REF_TO_C8(x) cvt_f32_to_f4_e2m1(x)
// f4_e3m0 C tensor: convert through f32 with the e3m0 helpers.
#elif C_DT_F4_E3M0
#define C_TO_REF(x) cvt_f4_e3m0_to_f32(x)
#define C_TO_REF8(x) cvt_f4_e3m0_to_f32(x)
#define REF_TO_C(x) cvt_f32_to_f4_e3m0(x)
#define REF_TO_C8(x) cvt_f32_to_f4_e3m0(x)
#else
#define C_TO_REF(x) (x)
#define C_TO_REF8(x) (x)
Expand All @@ -1458,6 +1569,9 @@
#elif C_DT_F4_E2M1
#define TO_C(x) cvt_f32_to_f4_e2m1(x)
#define TO_C8(x) cvt_f32_to_f4_e2m1(x)
// f4_e3m0 C tensor: down-convert from f32 with the e3m0 helper.
#elif C_DT_F4_E3M0
#define TO_C(x) cvt_f32_to_f4_e3m0(x)
#define TO_C8(x) cvt_f32_to_f4_e3m0(x)
#elif C_DT_F16
#define TO_C(x) convert_half(x)
#define TO_C8(x) convert_half8(x)
Expand Down Expand Up @@ -1517,6 +1631,8 @@
#define SUM_TO_REF(x) convert_float(cvt_f8_e4m3_to_hf(x))
#elif SUM_DT_F4_E2M1
#define SUM_TO_REF(x) cvt_f4_e2m1_to_f32(x)
#elif SUM_DT_F4_E3M0
#define SUM_TO_REF(x) cvt_f4_e3m0_to_f32(x)
#else
#define SUM_TO_REF
#endif
Expand Down
4 changes: 2 additions & 2 deletions src/gpu/intel/ocl/ref_matmul.cl
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,13 @@ __kernel void ref_matmul(__global SRC_DATA_T *A, __global WEI_DATA_T *B,
+ src_zp_stride_m * m;
src_zp = SRC_ZP_TO_REF(a0, src_zp_off);
#endif
#if SRC_DT_F4_E2M1
#if SRC_DT_F4_E2M1 || SRC_DT_F4_E3M0
ACC_DATA_T s = TO_ACC(
SRC_TO_REF(GET_HALF_BYTE(A, src_off)) - src_zp);
#else
ACC_DATA_T s = TO_ACC(SRC_TO_REF(A[src_off]) - src_zp);
#endif
#if WEI_DT_S4 || WEI_DT_U4 || WEI_DT_F4_E2M1
#if WEI_DT_S4 || WEI_DT_U4 || WEI_DT_F4_E2M1 || WEI_DT_F4_E3M0
ACC_DATA_T w_raw = WEI_TO_REF(GET_HALF_BYTE(B, wei_off));
#else
ACC_DATA_T w_raw = WEI_TO_REF(B[wei_off]);
Expand Down
11 changes: 7 additions & 4 deletions src/gpu/intel/ocl/ref_matmul.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2024 Intel Corporation
* Copyright 2019-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -84,8 +84,10 @@ struct ref_matmul_t : public gpu_primitive_t {
= (utils::one_of(src_dt_, f8_e5m2, f8_e4m3)
|| utils::one_of(wei_dt_, f8_e5m2, f8_e4m3))
&& utils::one_of(dst_dt_, f32, bf16, f16, src_dt_);
const bool is_f4 = (utils::everyone_is(f4_e2m1, src_dt_, wei_dt_)
&& utils::one_of(dst_dt_, f32, bf16, f16, src_dt_));
// f4 configuration: source or weights is one of the f4 types, and the
// destination is f32/bf16/f16 or matches the source type.
// one_of() is required for the weights check: everyone_is(wei_dt_, a, b)
// would demand wei_dt_ equal two distinct types and is never true.
const bool is_f4
        = ((utils::one_of(src_dt_, f4_e2m1, f4_e3m0)
                   || utils::one_of(wei_dt_, f4_e2m1, f4_e3m0))
                && utils::one_of(dst_dt_, f32, bf16, f16, src_dt_));
const bool is_bf16 = src_dt_ == bf16
&& utils::one_of(wei_dt_, bf16, s8, u8, s4, u4)
&& utils::one_of(dst_dt_, bf16, f32);
Expand Down Expand Up @@ -116,7 +118,8 @@ struct ref_matmul_t : public gpu_primitive_t {
IMPLICATION(utils::one_of(f64, src_dt_, wei_dt_, dst_dt_),
dev_info_->has_native(f64)),
VERBOSE_UNSUPPORTED_DT);
subbyte_pack_ = (dst_dt_ == data_type::f4_e2m1);
subbyte_pack_ = utils::one_of(
dst_dt_, data_type::f4_e2m1, data_type::f4_e3m0);
if (subbyte_pack_) {
using namespace dnnl::impl::memory_tracking::names;
const memory_desc_wrapper dst_mdw(dst_md(0));
Expand Down
2 changes: 1 addition & 1 deletion src/gpu/intel/ocl/ref_reorder.cl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ __kernel void ref_reorder(__global SRC_DATA_T *restrict src,
#if WITH_DST_SCALE
dst_scale = dst_scales[SCALE_OFF(DST, d0, d1, d2, d3, d4, d5)];
#endif
#if FROM_I4 || SRC_DT_F4_E2M1
#if FROM_I4 || SRC_DT_F4_E2M1 || SRC_DT_F4_E3M0
SRC_DATA_T src_value = GET_HALF_BYTE(src, src_off);
#else
SRC_DATA_T src_value = src[src_off];
Expand Down
3 changes: 2 additions & 1 deletion src/gpu/intel/ocl/ref_reorder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ status_t ref_reorder_t::pd_t::init_conf(impl::engine_t *engine) {

auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
conf.dispatch = compute_engine->create_dispatch(dst_mdw.md_);
conf.subbyte_pack = utils::one_of(dst_mdw.data_type(), u4, s4, f4_e2m1);
conf.subbyte_pack
= utils::one_of(dst_mdw.data_type(), u4, s4, f4_e2m1, f4_e3m0);

dim_t blocks[MAX_NDIMS] = {1, 1, 0, 0, 0, 0};
for (int i = 0; i < MAX_NDIMS; ++i) {
Expand Down
14 changes: 8 additions & 6 deletions src/gpu/intel/ocl/ref_reorder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,24 @@ struct ref_reorder_t : public gpu_primitive_t {

VDISPATCH_REORDER(
utils::one_of(sdt, f32, f16, bf16, f8_e5m2, f8_e4m3,
f4_e2m1, s32, s8, u8, s4, u4, f64),
f4_e2m1, f4_e3m0, s32, s8, u8, s4, u4, f64),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_REORDER(
utils::one_of(ddt, f32, f16, bf16, f8_e5m2, f8_e4m3,
f4_e2m1, s32, s8, u8, s4, u4, f64),
f4_e2m1, f4_e3m0, s32, s8, u8, s4, u4, f64),
VERBOSE_UNSUPPORTED_DT);

VDISPATCH_REORDER(
IMPLICATION(utils::one_of(ddt, f8_e4m3, f8_e5m2, f4_e2m1),
IMPLICATION(utils::one_of(ddt, f8_e4m3, f8_e5m2, f4_e2m1,
f4_e3m0),
utils::one_of(sdt, f64, f32, f16, bf16, f8_e5m2,
f8_e4m3, f4_e2m1, ddt)),
f8_e4m3, f4_e2m1, f4_e3m0, ddt)),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_REORDER(
IMPLICATION(utils::one_of(sdt, f8_e4m3, f8_e5m2, f4_e2m1),
IMPLICATION(utils::one_of(sdt, f8_e4m3, f8_e5m2, f4_e2m1,
f4_e3m0),
utils::one_of(ddt, f64, f32, f16, bf16, f8_e5m2,
f8_e4m3, f4_e2m1, sdt)),
f8_e4m3, f4_e2m1, f4_e3m0, sdt)),
VERBOSE_UNSUPPORTED_DT);

auto *compute_engine = utils::downcast<compute::compute_engine_t *>(
Expand Down
Loading
Loading