From 0aa315efb21a2646c58550504357c4c9d5954499 Mon Sep 17 00:00:00 2001
From: Andrey Kalinin
Date: Wed, 18 Sep 2024 10:51:40 -0700
Subject: [PATCH] x64: brgemm: add flag to brgattr to indicate if f16 B matrix has vnni layout

---
 src/cpu/x64/brgemm/brgemm.cpp            |  3 +-
 src/cpu/x64/brgemm/brgemm_types.hpp      | 10 ++-
 src/cpu/x64/brgemm/brgemm_utils.cpp      | 26 ++++---
 src/cpu/x64/brgemm/jit_brgemm_kernel.cpp | 98 +++++++++++++++++++-----
 src/cpu/x64/rnn/rnn_brgemm_utils.cpp     |  1 +
 5 files changed, 107 insertions(+), 31 deletions(-)

diff --git a/src/cpu/x64/brgemm/brgemm.cpp b/src/cpu/x64/brgemm/brgemm.cpp
index d022f76554b..ba6a4c27d1d 100644
--- a/src/cpu/x64/brgemm/brgemm.cpp
+++ b/src/cpu/x64/brgemm/brgemm.cpp
@@ -513,7 +513,7 @@ status_t brgemm_desc_set_attr(
             || brgattr.hint_ld_block != 0 || brgattr.hint_ld_block2 != 0
             || brgattr.hint_load_nt_A != brgemm_hint_nt_undef
             || brgattr.hint_load_nt_B != brgemm_hint_nt_undef
-            || brgattr.hint_bs_group > 1);
+            || brgattr.hint_bs_group > 1 || brgattr.b_is_vnni);
     if (brgattr.use_uker || brg->is_bf16_tmm || hint_blocking_set
             || brgattr.bd_mask_level
             || brgattr.fpmath_mode != fpmath_mode::strict || max_vpad > 0) {
@@ -768,6 +768,7 @@ int brgemm_cmp(const brgemm_desc_t &lhs, const brgemm_desc_t &rhs) {
     CMP_BRGEMM_FIELD(brgattr.bd_mask_level);
     CMP_BRGEMM_FIELD(brgattr.use_uker);
     CMP_BRGEMM_FIELD(brgattr.use_interleave_stores);
+    CMP_BRGEMM_FIELD(brgattr.b_is_vnni);
     CMP_BRGEMM_FIELD(brgattr.fpmath_mode);
     CMP_BRGEMM_FIELD(brgattr.LDA2);
     CMP_BRGEMM_FIELD(brgattr.LDB2);
diff --git a/src/cpu/x64/brgemm/brgemm_types.hpp b/src/cpu/x64/brgemm/brgemm_types.hpp
index 9054081576e..465a6b82a5c 100644
--- a/src/cpu/x64/brgemm/brgemm_types.hpp
+++ b/src/cpu/x64/brgemm/brgemm_types.hpp
@@ -152,6 +152,7 @@ struct DNNL_API brgemm_attr_t {
     // interleave stores or not
     bool use_interleave_stores;
     impl::fpmath_mode_t fpmath_mode = fpmath_mode::strict;
+    bool b_is_vnni {false};
     // Second level leading dimension describing distance between 16-line
     // blocks in case of blocked layout. Used to calculate address of next
     // bd block. By default are equal to regular leading dimension parameters
@@ -393,7 +394,9 @@ struct brgemm_desc_t {
         switch (dt_b) {
             case f32: return false;
             // Note: `dt_a == f32` means implicit up-conversion of B to f32.
-            case f16: return (isa_impl != avx512_core_fp16) && (dt_a != f32);
+            case f16:
+                return brgattr.b_is_vnni
+                        || ((isa_impl != avx512_core_fp16) && (dt_a != f32));
             // Note: `dt_a == f32` means implicit up-conversion of B to f32.
             case bf16: return dt_a != f32;
             default: return true;
@@ -401,6 +404,11 @@
     }
 
     bool is_xf16() const noexcept { return is_bf16 || is_f16; }
 
+    bool is_f16_b_non_amx_vnni() const {
+        return dt_b == data_type::f16 && brgattr.b_is_vnni
+                && !is_superset(isa_impl, avx512_core_amx_fp16);
+    }
+
     bool operator==(const brgemm_desc_t &rhs) const;
     bool operator<(const brgemm_desc_t &rhs) const;
diff --git a/src/cpu/x64/brgemm/brgemm_utils.cpp b/src/cpu/x64/brgemm/brgemm_utils.cpp
index d62d908f101..2f4e64191b4 100644
--- a/src/cpu/x64/brgemm/brgemm_utils.cpp
+++ b/src/cpu/x64/brgemm/brgemm_utils.cpp
@@ -182,6 +182,9 @@ int calculate_ldb_params(brgemm_desc_t *brg, const int try_ld_block2) {
 
 int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) {
+    // TODO: refactor the calculation of the number of available registers so
+    // that this code and the "max_effective_vregs" computation in the brgemm
+    // kernel generator share a single implementation
     constexpr int max_bcst_regs = 1;
     const bool req_compensation = brg->req_s8s8_compensation
             || brg->zp_type_a != brgemm_broadcast_t::none;
@@ -190,12 +193,15 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) {
                     || brg->brgattr.max_bottom_vpad > 0)
             && brg->zp_type_a != brgemm_broadcast_t::none;
     const int beta_regs = !one_of(brg->beta, 1.f, 0.f);
+    // To support the f16 vnni B matrix on non-AMX ISAs we need two Vmm
+    // registers for the permutation in the brgemm kernel
+    const int b_vnni_regs = brg->is_f16_b_non_amx_vnni() ? 2 : 0;
 
     const int max_isa_regs = isa_num_vregs(brg->isa_impl);
     // note: the 'adj_ld_block2' already removes the necessary registers
     // for 'embd_bcst'
     auto max_reg_count = max_isa_regs - max_bcst_regs - beta_regs
-            - req_compensation - req_zp_a_comp_pads;
+            - req_compensation - req_zp_a_comp_pads - b_vnni_regs;
     if (req_zp_a_comp_pads)
         max_reg_count = nstl::min(
                 max_reg_count, max_isa_regs - max_bcst_regs - 5);
@@ -224,6 +230,15 @@
 }
 
 status_t brgemm_blocking(brgemm_desc_t *brg) {
+    const data_type_t ld_step_compute_dt
+            = get_mac_emu_data_type(brg->dt_b, brg->isa_impl,
+                    brg->isa_impl != avx2_vnni_2 && !brg->is_fp8_via_convert());
+    brg->ld_step = brg->is_f16_b_non_amx_vnni() ? 2
+            : data_type_vnni_granularity(ld_step_compute_dt);
+    const data_type_t rd_step_compute_dt = get_mac_emu_data_type(
+            brg->dt_b, brg->isa_impl, !brg->is_fp8_via_convert());
+    brg->rd_step = data_type_vnni_granularity(rd_step_compute_dt);
+
     set_isa_impl(brg);
     if (brg->isa_impl == isa_undef) return status::unimplemented;
@@ -875,15 +890,6 @@ void init_brgemm_conf(brgemm_desc_t *brg, cpu_isa_t isa,
     brg->bd_block2 = 0;
     brg->bdb2 = 0;
     brg->bdb2_tail = 0;
-
-    const data_type_t ld_step_compute_dt
-            = get_mac_emu_data_type(brg->dt_b, brg->isa_impl,
-                    brg->isa_impl != avx2_vnni_2 && !brg->is_fp8_via_convert());
-    brg->ld_step = data_type_vnni_granularity(ld_step_compute_dt);
-
-    const data_type_t rd_step_compute_dt = get_mac_emu_data_type(
-            brg->dt_b, brg->isa_impl, !brg->is_fp8_via_convert());
-    brg->rd_step = data_type_vnni_granularity(rd_step_compute_dt);
 }
 
 void init_brdgmm_conf(brgemm_desc_t *brg, cpu_isa_t isa,
diff --git a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
index 0d26602aafd..764b75ac230 100644
--- a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
+++ b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
@@ -45,10 +45,7 @@ struct jit_brgemm_kernel_t : public jit_generator {
         : jit_generator(jit_name(), abrg.isa_impl)
         , brg(abrg)
        , postops_injector_(nullptr)
-        , max_effective_vregs(isa_num_vregs(brg.isa_impl)
-                  - (brg.is_int8 && !brg.has_int8_vnni
-                                  ? 2
-                                  : (brg.is_fp8_via_convert() ? 5 : 0))) {
+        , max_effective_vregs(get_max_effective_vregs(brg)) {
 
         // The implementation uses is_superset(), is_subset() utilities.
         // So avoid isa_all, isa_undef in these comparisions.
@@ -149,6 +146,8 @@
     Xbyak::Label avx_tail_mask_;
     Xbyak::Label sum_zp_scale_data_;
+    Xbyak::Label f16_perm_even_table_;
+    Xbyak::Label f16_perm_odd_table_;
 
     using reg64_t = const Xbyak::Reg64;
     // Register decomposition
@@ -276,6 +275,17 @@
     Xbyak::Opmask fp8_col_mask = Xbyak::Opmask(4);
     Xbyak::Opmask kmask_fp8_aux = Xbyak::Opmask(5);
 
+    static int get_max_effective_vregs(const brgemm_desc_t &brg) {
+        auto used_vregs = 0;
+        if (brg.is_int8 && !brg.has_int8_vnni)
+            used_vregs = 2;
+        else if (brg.is_fp8_via_convert())
+            used_vregs = 5;
+        else if (brg.is_f16_b_non_amx_vnni())
+            used_vregs = 2;
+        return isa_num_vregs(brg.isa_impl) - used_vregs;
+    }
+
     Vmm accm(int ld_block, int bd, int ld) {
         return Vmm(max_effective_vregs - 1 - (bd * ld_block + ld));
     }
@@ -336,6 +346,9 @@
         return Vmm(isa_num_vregs(brg.isa_impl) - 2);
     }
 
+    Zmm f16_perm_even_vreg_ = Zmm(isa_num_vregs(brg.isa_impl) - 1);
+    Zmm f16_perm_odd_vreg_ = Zmm(isa_num_vregs(brg.isa_impl) - 2);
+
     Vmm vmm_mask(const Vmm vmm_in, bool mask_flag, bool store,
             Xbyak::Opmask ktail_mask) const;
     Vmm_lower_t vmm_lower_mask(const Vmm_lower_t vmm_lower_in, bool mask_flag,
@@ -448,14 +461,13 @@ int jit_brgemm_kernel_t::B_offset(
     if (is_amx) {
         return brg.typesize_B * (brg.rd_step * ld * brg.ld_block);
     } else {
-        const int data_vnni_granularity = brg.ld_step;
-        const int rdb0 = rd / data_vnni_granularity;
+        const int rdb0 = rd / brg.ld_step;
         // Note: Offsets for elements within vnni_granularity are expected to be
        // handled within gemm_microkernel (for ex: odd-even converts).
-        // hence no `rd % data_vnni_granularity`
+        // hence no `rd % brg.ld_step`
         return brg.typesize_B
-                * (rdb0 * data_vnni_granularity * brg.LDB
-                        + data_vnni_granularity * ld * brg.ld_block);
+                * (rdb0 * brg.ld_step * brg.LDB
+                        + brg.ld_step * ld * brg.ld_block);
     }
 }
 
@@ -766,7 +778,7 @@ void jit_brgemm_kernel_t::ldb_regs_shift(int ld_block2, bool is_tail) {
 
     add(reg_aux_D, D_offset);
     add(reg_b_offset,
-            (is_tail) ? ldb_B_offset(1, true) : ldb_B_offset(ld_block2));
+            (is_tail) ? ldb_B_offset(0, true) : ldb_B_offset(ld_block2));
 
     if (brg.with_bias) {
         mov(reg_aux_bias, ptr[rsp + reg_aux_bias_offs_]);
@@ -2090,8 +2102,7 @@ void jit_brgemm_kernel_t::compute_int8_compensation(int rd_loop, int bd_b,
             auto vmm_store = vmm_mask(load(), is_tail, false, ld_tail_mask);
             uni_vmovups(vmm_store, addr);
         } else {
-            load_bytes(
-                    load(), addr, brg.typesize_B * brg.ldb_tail * brg.ld_step);
+            load_bytes(load(), addr, ldb_B_offset(0, true));
         }
 
         if (brg.req_cal_comp_pads) {
@@ -2194,8 +2205,20 @@ void jit_brgemm_kernel_t::gemm_microkernel(int bd_block2, bool is_bdb_tail,
                     vcvtneeph2ps(vmm_load, addr);
                 else
                     vcvtneoph2ps(vmm_load, addr);
-            } else
-                vcvtph2psx(vmm_load, addr);
+            } else {
+                if (brg.is_f16_b_non_amx_vnni()) {
+                    const auto vnni_addr = ptr[reg_aux_B
+                            + B_offset(ld, utils::rnd_dn(rd, 2))];
+                    vmovups(vmm_load, vnni_addr);
+                    if (rd % 2 == 0)
+                        vpermw(vmm_load, f16_perm_even_vreg_, vmm_load);
+                    else
+                        vpermw(vmm_load, f16_perm_odd_vreg_, vmm_load);
+                    vcvtph2psx(
+                            vmm_load, Vmm_lower_t(vmm_load.getIdx()));
+                } else
+                    vcvtph2psx(vmm_load, addr);
+            }
         } else if (brg.dt_b == data_type::bf16
                 && brg.isa_impl == avx2_vnni_2) {
             if (rd % 2 == 0)
@@ -2206,8 +2229,7 @@
                 if (is_superset(brg.isa_impl, avx512_core)) {
                     uni_vmovups(vmm_load, addr);
                 } else {
-                    load_bytes(vmm_load, addr,
-                            brg.typesize_B * brg.ldb_tail * brg.ld_step);
+                    load_bytes(vmm_load, addr, ldb_B_offset(0, true));
                 }
             } else {
                 uni_vmovups(vmm_load, addr);
@@ -2239,7 +2261,20 @@
                 else
                     vcvtneoph2ps(vmm_load, addr);
             } else {
-                vcvtph2psx(vmm_load, addr);
+                if (brg.is_f16_b_non_amx_vnni()) {
+                    const auto actual_B_offset
+                            = B_offset(ld, utils::rnd_dn(rd, 2));
+                    const auto vnni_addr
+                            = ptr[reg_aux_B + actual_B_offset];
+                    vmovups(vmm_load, vnni_addr);
+                    if (rd % 2 == 0)
+                        vpermw(vmm_load, f16_perm_even_vreg_, vmm_load);
+                    else
+                        vpermw(vmm_load, f16_perm_odd_vreg_, vmm_load);
+                    vcvtph2psx(
+                            vmm_load, Vmm_lower_t(vmm_load.getIdx()));
+                } else
+                    vcvtph2psx(vmm_load, addr);
             }
         } else if (brg.dt_b == data_type::bf16
                 && brg.isa_impl == avx2_vnni_2) {
@@ -2251,8 +2286,7 @@
                 if (is_superset(brg.isa_impl, avx512_core)) {
                     uni_vmovups(vmm_load, addr);
                 } else {
-                    load_bytes(vmm_load, addr,
-                            brg.typesize_B * brg.ldb_tail * brg.ld_step);
+                    load_bytes(vmm_load, addr, ldb_B_offset(0, true));
                 }
             } else {
                 uni_vmovups(vmm_load, addr);
@@ -2739,6 +2773,13 @@ void jit_brgemm_kernel_t::generate() {
         vpbroadcastw(int8_ones_words(), reg_tmp_gpr.cvt16());
     }
 
+    if (brg.is_f16_b_non_amx_vnni()) {
+        mov(reg_tmp_gpr, f16_perm_even_table_);
+        vmovups(f16_perm_even_vreg_, ptr[reg_tmp_gpr]);
+        mov(reg_tmp_gpr, f16_perm_odd_table_);
+        vmovups(f16_perm_odd_vreg_, ptr[reg_tmp_gpr]);
+    }
+
     read_params();
     bdb_loop();
 
@@ -2771,6 +2812,25 @@
 
     if (brg.with_eltwise)
         postops_injector_->prepare_table(/* generate = */ true);
+
+    if (brg.is_f16_b_non_amx_vnni()) {
+        // convert interleaved vnni data with holes to packed.
+        align(64);
+        L(f16_perm_even_table_);
+        for (int i = 0; i < 32; ++i) {
+            if (i < 16)
+                dw(uint16_t(2 * i));
+            else
+                dw(uint16_t(0));
+        }
+        align(64);
+        L(f16_perm_odd_table_);
+        for (int i = 0; i < 32; ++i)
+            if (i < 16)
+                dw(uint16_t(2 * i + 1));
+            else
+                dw(uint16_t(0));
+    }
 }
 
 brgemm_attr_t::brgemm_attr_t()
diff --git a/src/cpu/x64/rnn/rnn_brgemm_utils.cpp b/src/cpu/x64/rnn/rnn_brgemm_utils.cpp
index 0cc7d3bf74a..68080d6d320 100644
--- a/src/cpu/x64/rnn/rnn_brgemm_utils.cpp
+++ b/src/cpu/x64/rnn/rnn_brgemm_utils.cpp
@@ -564,6 +564,7 @@ status_t init_brgemm_kernel(x64::brgemm_desc_t *desc, x64::cpu_isa_t isa,
     brgattr.max_bs = max_bs;
     brgattr.max_top_vpad = 0;
     brgattr.max_bottom_vpad = 0;
+    brgattr.b_is_vnni = true;
     CHECK(brgemm_desc_set_attr(desc, brgattr));
 
     x64::brgemm_kernel_t *_t_ptr;
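
For reference, a minimal scalar sketch (illustrative only, not part of the
patch) of what the even/odd permutation tables plus vcvtph2psx achieve for one
64-byte chunk of the vnni-packed f16 B matrix: the even table gathers words
0, 2, ..., 30 and the odd table gathers words 1, 3, ..., 31 into the low 16
lanes, which are then up-converted from f16 to f32.

    #include <array>
    #include <cstdint>

    // Mirrors vpermw with f16_perm_even_table_ / f16_perm_odd_table_ on one
    // 32-word chunk of B: pick the even or odd 16-bit elements into the low
    // 16 lanes (vcvtph2psx then widens those 16 halves to f32 in the kernel).
    // gather_vnni_row is a hypothetical helper used only for illustration.
    std::array<uint16_t, 16> gather_vnni_row(
            const std::array<uint16_t, 32> &chunk, bool odd) {
        std::array<uint16_t, 16> out {};
        for (int i = 0; i < 16; ++i)
            out[i] = chunk[2 * i + (odd ? 1 : 0)]; // matches dw(2*i) / dw(2*i+1)
        return out;
    }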