Skip to content

Commit

Permalink
xe: reorder: support f4_e2m1
Browse files Browse the repository at this point in the history
  • Loading branch information
atkassen committed Sep 25, 2024
1 parent 60a0cae commit 1692a13
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 56 deletions.
35 changes: 12 additions & 23 deletions src/gpu/intel/ocl/custom_reorder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,46 +60,35 @@ struct custom_reorder_t : public gpu_primitive_t {
auto *compute_engine = utils::downcast<compute::compute_engine_t *>(
dst_engine->kind() == engine_kind::gpu ? dst_engine
: src_engine);
using namespace data_type;
auto sdt = src_md()->data_type;
auto ddt = dst_md()->data_type;
VDISPATCH_REORDER(
IMPLICATION(utils::one_of(dst_md()->data_type,
data_type::f8_e4m3, data_type::f8_e5m2),
utils::one_of(src_md()->data_type, data_type::f32,
data_type::f16, data_type::bf16,
data_type::f64))
&& IMPLICATION(utils::one_of(src_md()->data_type,
data_type::f8_e4m3,
data_type::f8_e5m2),
utils::one_of(dst_md()->data_type,
data_type::f32, data_type::f16,
data_type::bf16, data_type::f64)),
utils::one_of(sdt, f32, f16, f8_e5m2, f8_e4m3, s32, s8, u8),
VERBOSE_UNSUPPORTED_DT);

VDISPATCH_REORDER(
!(utils::one_of(data_type::s4, dst_md()->data_type,
src_md()->data_type)
|| utils::one_of(data_type::u4, dst_md()->data_type,
src_md()->data_type)),
utils::one_of(ddt, f32, f16, f8_e5m2, f8_e4m3, s32, s8, u8),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_REORDER(IMPLICATION(utils::one_of(ddt, f8_e4m3, f8_e5m2),
utils::one_of(sdt, f32, f16, bf16)),
VERBOSE_UNSUPPORTED_DT_CFG);
VDISPATCH_REORDER(IMPLICATION(utils::one_of(sdt, f8_e4m3, f8_e5m2),
utils::one_of(ddt, f32, f16, bf16)),
VERBOSE_UNSUPPORTED_DT_CFG);

VDISPATCH_REORDER(!memory_desc_ndims_ok(src_md(), dst_md()),
VERBOSE_INCONSISTENT_NDIMS, "src", "dst");
VDISPATCH_REORDER(compute_engine->mayiuse(
compute::device_ext_t::intel_subgroups),
VERBOSE_UNSUPPORTED_DEVICE_FEATURE, "subgroups");
VDISPATCH_REORDER(
IMPLICATION(
utils::one_of(data_type::f16, src_md()->data_type,
dst_md()->data_type),
IMPLICATION(utils::one_of(f16, sdt, ddt),
compute_engine->mayiuse(
compute::device_ext_t::khr_fp16)
&& compute_engine->mayiuse(
compute::device_ext_t::
intel_subgroups_short)),
VERBOSE_UNSUPPORTED_DT_CFG);
VDISPATCH_REORDER(
(!utils::one_of(data_type::f64, src_md()->data_type,
dst_md()->data_type)),
VERBOSE_UNSUPPORTED_DT);

VDISPATCH_REORDER_SC(init_conf(engine),
VERBOSE_PRIMITIVE_CREATION_FAIL, "reorder");
Expand Down
43 changes: 43 additions & 0 deletions src/gpu/intel/ocl/ocl_math_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ int rnd_down(int a, unsigned int b) {
#define MATH_UTILS_DECLARE_HF8 1
#endif

// Enable the f4_e2m1 helpers below when any of the *_DT_F4_E2M1 build options
// (or POST_OP_USING_F4_E2M1) is set. Undefined macros evaluate to 0 here.
#if DT_F4_E2M1 || SRC_DT_F4_E2M1 || WEI_DT_F4_E2M1 || DST_DT_F4_E2M1 \
        || BIA_DT_F4_E2M1 || A_DT_F4_E2M1 || B_DT_F4_E2M1 || C_DT_F4_E2M1 \
        || DATA_DT_F4_E2M1 || POST_OP_USING_F4_E2M1
#define MATH_UTILS_DECLARE_F4_E2M1 1
#endif

#if DT_S4 || SRC_DT_S4 || WEI_DT_S4 || DST_DT_S4 || BIA_DT_S4 || A_DT_S4 \
|| B_DT_S4 || C_DT_S4 || DATA_DT_S4 || WEI_ZP_DT_S4 || SRC_ZP_DT_S4
#define MATH_UTILS_DECLARE_S4 1
Expand Down Expand Up @@ -738,6 +744,38 @@ float __attribute__((overloadable)) cvt_s4_to_s32(char a) {
return convert_int_sat_rte(val);
}

#endif

#if MATH_UTILS_DECLARE_F4_E2M1

// Converts an f32 value to a 4-bit code returned in the low nibble of a uchar:
// bit 3 is the sign, bits 2..0 are the magnitude code (0..7). The magnitude is
// obtained by counting how many rounding boundaries the input exceeds.
// Inf/NaN inputs (exponent field all ones) produce 0.
uchar __attribute__((overloadable)) cvt_f32_to_f4_e2m1(float a) {
// Rounding boundaries, expressed as the top 11 bits of an f32 magnitude
// (8 exponent bits followed by the 3 leading mantissa bits):
// 0x3e8 = 0.25, 0x3f4 = 0.75, 0x3fa = 1.25, 0x3fe = 1.75,
// 0x402 = 2.5, 0x406 = 3.5, 0x40a = 5.0.
// Lane 0 (0x800) can never be exceeded or matched by an 11-bit value, so it
// only pads the vector to 8 lanes.
const ushort8 boundaries
= {0x800, 0x3e8, 0x3f4, 0x3fa, 0x3fe, 0x402, 0x406, 0x40a};
// Keep the top 12 bits of the bit pattern: sign, exponent, 3 mantissa bits.
ushort b = as_uint(a) >> 20;
// Drop the sign bit, leaving the 11-bit magnitude key.
ushort val = b & 0x7ff;
// 0 when the exponent field is all ones (Inf/NaN), 1 otherwise.
short cmp_mask = (val & 0x7f8) != 0x7f8;
// One per boundary strictly exceeded (zeroed for Inf/NaN).
short8 gt = (val > boundaries) & cmp_mask;
// Extra increment when the key equals the 0.75, 1.75 or 3.5 boundary; these
// are the lanes where a tie is resolved upwards.
// NOTE(review): the comparison only sees the top 3 mantissa bits, so inputs
// slightly above a boundary compare equal to it. For the boundaries without
// an eq lane (0.25, 1.25, 2.5, 5.0) such inputs are not counted and round
// down -- confirm this is the intended rounding.
short4 eq = (val == boundaries.s0246) & cmp_mask;
// Horizontal add of all gt/eq lanes gives the 3-bit magnitude code.
short4 r0 = gt.s0123 + gt.s4567 + eq;
short2 r1 = r0.s01 + r0.s23;
// Move the f32 sign into bit 3; the mask clears it for Inf/NaN.
uchar sign = (b >> 8) & (cmp_mask << 3);
return sign | (uchar)(r1.s0 + r1.s1);
}

// Expands a 4-bit value stored in the low nibble of `a` (bit 3 = sign,
// bits 2..1 = exponent, bit 0 = mantissa) into an f32. Bits 4..7 of `a` are
// ignored.
float __attribute__((overloadable)) cvt_f4_e2m1_to_f32(uchar a) {
    const uint sign_bit = a & 0x08;
    const uint magnitude = a & 0x07;
    uint biased_exp = magnitude >> 1;
    // The mantissa bit is only carried over when the 2-bit exponent is
    // non-zero; code 1 is emitted as a power of two with a zero mantissa.
    uint frac = 0x0;
    if (biased_exp) frac = a & 0x01;
    // Every non-zero code is a normal f32 number, so rebias unconditionally;
    // code 0 keeps an all-zero exponent and yields (signed) zero.
    if (magnitude) biased_exp += 126;
    const uint bits = (sign_bit << 28) | (biased_exp << 23) | (frac << 22);
    return as_float(bits);
}

#endif

#if MATH_UTILS_DECLARE_S4 || MATH_UTILS_DECLARE_U4 || MATH_UTILS_DECLARE_F4_E2M1

uchar __attribute__((overloadable)) get_half_byte(__global uchar *x, off_t y) {
uchar ret = 0;
if (y % 2) {
Expand All @@ -747,20 +785,25 @@ uchar __attribute__((overloadable)) get_half_byte(__global uchar *x, off_t y) {
}
return ret;
}

// Reads the 4-bit element with index `y` from the packed buffer `x`: even
// indices map to the low nibble of byte y / 2, odd indices to its high nibble.
// The nibble is returned unsigned-extended (0..15), not sign-extended.
char __attribute__((overloadable)) get_half_byte(__global char *x, off_t y) {
    const int is_high_nibble = y % 2;
    return is_high_nibble ? (x[y / 2] & 0xf0) >> 4 : x[y / 2] & 0x0f;
}

// Overwrites the whole byte that contains half-byte element `y` with `z`.
// Presumably `z` already holds two packed 4-bit values -- confirm at callers.
void __attribute__((overloadable))
set_double_half_byte(__global uchar *x, off_t y, uchar z) {
    __global uchar *packed_byte = x + y / 2;
    *packed_byte = z;
}

// Signed-buffer overload: overwrites the whole byte that contains half-byte
// element `y` with `z` (implicitly converted to char).
// Presumably `z` already holds two packed 4-bit values -- confirm at callers.
void __attribute__((overloadable))
set_double_half_byte(__global char *x, off_t y, uchar z) {
    __global char *packed_byte = x + y / 2;
    *packed_byte = z;
}

#endif

#endif
12 changes: 12 additions & 0 deletions src/gpu/intel/ocl/ocl_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,9 @@
#define SRC_TO_REF(x) convert_float(x)
#elif SRC_DT_S4
#define SRC_TO_REF(x) convert_float(cvt_s4_to_f32(x))
#elif SRC_DT_F4_E2M1
#define GET_HALF_BYTE(x, y) get_half_byte(x, y)
#define SRC_TO_REF(x) convert_float(cvt_f4_e2m1_to_f32(x))
#else
#define SRC_TO_REF(x) (x)
#define SRC_TO_REF8(x) (x)
Expand Down Expand Up @@ -1339,6 +1342,15 @@
#define TO_DST16(x) cvt_f32_to_s4(convert_float16(x))
#define DST_DATA_FMAX 7.0
#define DST_DATA_FMIN -8.0
#elif DST_DT_F4_E2M1
#define SET_DOUBLE_HALF_BYTE(x, y, z) set_double_half_byte(x, y, z)
#define TO_DST(x) cvt_f32_to_f4_e2m1(convert_float(x))
#define TO_DST2(x) cvt_f32_to_f4_e2m1(convert_float2(x))
#define TO_DST4(x) cvt_f32_to_f4_e2m1(convert_float4(x))
#define TO_DST8(x) cvt_f32_to_f4_e2m1(convert_float8(x))
#define TO_DST16(x) cvt_f32_to_f4_e2m1(convert_float16(x))
#define DST_DATA_FMAX 6.0
#define DST_DATA_FMIN -6.0
#elif DST_DT_U8
#define TO_DST(x) convert_uchar_sat_rte(x)
#define TO_DST2(x) convert_uchar2_sat_rte(x)
Expand Down
3 changes: 1 addition & 2 deletions src/gpu/intel/ocl/ref_reorder.cl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include "gpu/intel/ocl/reorder_common.h"
#include "gpu/intel/ocl/types_interop.h"

#define TO_I4 ((DST_DT_U4 || DST_DT_S4) && (!SRC_DT_U4 && !SRC_DT_S4))
#define FROM_I4 ((SRC_DT_U4 || SRC_DT_S4) && (!DST_DT_U4 && !DST_DT_S4))
#define GWS_GET_THREAD_ID(index) \
(off_t)(get_global_id(index) + offset.array[index])
Expand Down Expand Up @@ -94,7 +93,7 @@ __kernel void ref_reorder(__global SRC_DATA_T *restrict src,
#if WITH_DST_SCALE
dst_scale = dst_scales[SCALE_OFF(DST, d0, d1, d2, d3, d4, d5)];
#endif
#if FROM_I4
#if FROM_I4 || SRC_DT_F4_E2M1
SRC_DATA_T src_value = GET_HALF_BYTE(src, src_off);
#else
SRC_DATA_T src_value = src[src_off];
Expand Down
10 changes: 3 additions & 7 deletions src/gpu/intel/ocl/ref_reorder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ using namespace dnnl::impl::memory_tracking::names;

status_t ref_reorder_t::pd_t::init_conf(impl::engine_t *engine) {
using namespace format_tag;
using namespace data_type;

const memory_desc_wrapper src_mdw(src_md());
const memory_desc_wrapper dst_mdw(dst_md());
Expand All @@ -52,15 +53,10 @@ status_t ref_reorder_t::pd_t::init_conf(impl::engine_t *engine) {
if (conf.nelems == 0) return status::success;

auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);

dim_t blocks[MAX_NDIMS] = {1, 1, 1, 1, 1, 1};

conf.dispatch = compute_engine->create_dispatch(dst_mdw.md_);
conf.subbyte_pack
= utils::one_of(dst_mdw.data_type(), data_type::u4, data_type::s4);

blocks[2] = blocks[3] = blocks[4] = blocks[5] = 0;
conf.subbyte_pack = utils::one_of(dst_mdw.data_type(), u4, s4, f4_e2m1);

dim_t blocks[MAX_NDIMS] = {1, 1, 0, 0, 0, 0};
for (int i = 0; i < MAX_NDIMS; ++i) {
auto dim_str = utils::format("D%d", i);
if (i < dst_mdw.ndims()) {
Expand Down
50 changes: 26 additions & 24 deletions src/gpu/intel/ocl/ref_reorder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,27 @@ struct ref_reorder_t : public gpu_primitive_t {
VERBOSE_RUNTIMEDIM_UNSUPPORTED);

using namespace data_type;
using compute::device_ext_t;
const auto sdt = src_md()->data_type;
const auto ddt = dst_md()->data_type;
VDISPATCH_REORDER(utils::one_of(sdt, f32, f16, bf16, f8_e5m2,
f8_e4m3, s32, s8, u8, s4, u4, f64),
VDISPATCH_REORDER(
utils::one_of(sdt, f32, f16, bf16, f8_e5m2, f8_e4m3,
f4_e2m1, s32, s8, u8, s4, u4, f64),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_REORDER(utils::one_of(ddt, f32, f16, bf16, f8_e5m2,
f8_e4m3, s32, s8, u8, s4, u4, f64),
VDISPATCH_REORDER(
utils::one_of(ddt, f32, f16, bf16, f8_e5m2, f8_e4m3,
f4_e2m1, s32, s8, u8, s4, u4, f64),
VERBOSE_UNSUPPORTED_DT);

VDISPATCH_REORDER(
IMPLICATION(utils::one_of(ddt, f8_e4m3, f8_e5m2),
utils::one_of(sdt, f32, f16, bf16, f64, ddt))
&& IMPLICATION(utils::one_of(sdt, f8_e4m3, f8_e5m2),
utils::one_of(
ddt, f32, f16, bf16, f64, sdt)),
IMPLICATION(utils::one_of(ddt, f8_e4m3, f8_e5m2, f4_e2m1),
utils::one_of(sdt, f64, f32, f16, bf16, f8_e5m2,
f8_e4m3, f4_e2m1, ddt)),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_REORDER(
IMPLICATION(utils::one_of(sdt, f8_e4m3, f8_e5m2, f4_e2m1),
utils::one_of(ddt, f64, f32, f16, bf16, f8_e5m2,
f8_e4m3, f4_e2m1, sdt)),
VERBOSE_UNSUPPORTED_DT);

auto *compute_engine = utils::downcast<compute::compute_engine_t *>(
Expand All @@ -83,25 +89,21 @@ struct ref_reorder_t : public gpu_primitive_t {
compute::device_ext_t::intel_subgroups),
VERBOSE_UNSUPPORTED_DEVICE_FEATURE, "subgroups");
VDISPATCH_REORDER(
IMPLICATION(
utils::one_of(data_type::f16, src_md()->data_type,
dst_md()->data_type),
compute_engine->mayiuse(
compute::device_ext_t::khr_fp16)
&& compute_engine->mayiuse(
compute::device_ext_t::
IMPLICATION(utils::one_of(data_type::f16, sdt, ddt),
compute_engine->mayiuse(device_ext_t::khr_fp16)
&& compute_engine->mayiuse(device_ext_t::
intel_subgroups_short)),
VERBOSE_UNSUPPORTED_DT_CFG);
VDISPATCH_REORDER(
IMPLICATION(utils::one_of(data_type::f64, sdt, ddt),
compute_engine->mayiuse(
compute::device_ext_t::khr_fp64)
&& attr()->post_ops_.has_default_values())
&& IMPLICATION(
(utils::one_of(data_type::u4, sdt, ddt)
|| utils::one_of(
data_type::s4, sdt, ddt)),
attr()->post_ops_.has_default_values()),
compute_engine->mayiuse(device_ext_t::khr_fp64)
&& attr()->post_ops_.has_default_values()),
VERBOSE_UNSUPPORTED_DT_CFG);
VDISPATCH_REORDER(
IMPLICATION(
(utils::one_of(data_type::u4, sdt, ddt)
|| utils::one_of(data_type::s4, sdt, ddt)),
attr()->post_ops_.has_default_values()),
VERBOSE_UNSUPPORTED_DT_CFG);

VDISPATCH_REORDER_SC(init_conf(engine), "init_conf()");
Expand Down
5 changes: 5 additions & 0 deletions src/gpu/intel/primitive_conf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,7 @@ inline void def_data_type(compute::kernel_ctx_t &kernel_ctx, data_type_t dt,
const char *bf16_name = with_punning ? "ushort" : "bf16";
const char *bf8_name = with_punning ? "uchar" : "f8_e5m2";
const char *hf8_name = with_punning ? "uchar" : "f8_e4m3";
const char *f4_e2m1_name = with_punning ? "uchar" : "f4_e2m1";
const char *e8m0_name = with_punning ? "uchar" : "e8m0";
const char *u4_name = with_punning ? "uchar" : "u4";
const char *s4_name = with_punning ? "uchar" : "s4";
Expand Down Expand Up @@ -1050,6 +1051,10 @@ inline void def_data_type(compute::kernel_ctx_t &kernel_ctx, data_type_t dt,
kernel_ctx.add_option(utils::format(
"-D%s_DATA_T=%s -D%s_DT_BF8", str, bf8_name, str));
break;
case data_type::f4_e2m1:
kernel_ctx.add_option(utils::format(
"-D%s_DATA_T=%s -D%s_DT_F4_E2M1", str, f4_e2m1_name, str));
break;
case data_type::e8m0:
kernel_ctx.add_option(utils::format(
"-D%s_DATA_T=%s -D%s_DT_E8M0", str, e8m0_name, str));
Expand Down
3 changes: 3 additions & 0 deletions tests/benchdnn/reorder/cfg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ REG(f16, -f16_max_exact, f16_max_exact);
REG(bf16, -int_max_exact, int_max_exact);
REG(f8_e5m2, -f16_max_exact, f16_max_exact);
REG(f8_e4m3, -f16_max_exact, f16_max_exact);
REG(f4_e2m1, -f16_max_exact, f16_max_exact);
// Do not exceed max float value representable in integer. Otherwise, we get
// a correctness issue caused by different computations in reference and the
// library.
Expand All @@ -59,6 +60,7 @@ dt_conf_t dt2cfg(dnnl_data_type_t dt) {
CASE(bf16);
CASE(f8_e5m2);
CASE(f8_e4m3);
CASE(f4_e2m1);
CASE(s32);
CASE(s8);
CASE(u8);
Expand All @@ -78,6 +80,7 @@ dnnl_data_type_t cfg2dt(dt_conf_t cfg) {
CASE(bf16);
CASE(f8_e5m2);
CASE(f8_e4m3);
CASE(f4_e2m1);
CASE(s32);
CASE(s8);
CASE(u8);
Expand Down

0 comments on commit 1692a13

Please sign in to comment.