Skip to content

Commit

Permalink
x64: brgemm: add flag to brgattr to indicate if f16 b matrix has vnni…
Browse files Browse the repository at this point in the history
… layout
  • Loading branch information
ankalinin committed Sep 24, 2024
1 parent 0185923 commit 0aa315e
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 31 deletions.
3 changes: 2 additions & 1 deletion src/cpu/x64/brgemm/brgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ status_t brgemm_desc_set_attr(
|| brgattr.hint_ld_block != 0 || brgattr.hint_ld_block2 != 0
|| brgattr.hint_load_nt_A != brgemm_hint_nt_undef
|| brgattr.hint_load_nt_B != brgemm_hint_nt_undef
|| brgattr.hint_bs_group > 1);
|| brgattr.hint_bs_group > 1 || brgattr.b_is_vnni);
if (brgattr.use_uker || brg->is_bf16_tmm || hint_blocking_set
|| brgattr.bd_mask_level
|| brgattr.fpmath_mode != fpmath_mode::strict || max_vpad > 0) {
Expand Down Expand Up @@ -768,6 +768,7 @@ int brgemm_cmp(const brgemm_desc_t &lhs, const brgemm_desc_t &rhs) {
CMP_BRGEMM_FIELD(brgattr.bd_mask_level);
CMP_BRGEMM_FIELD(brgattr.use_uker);
CMP_BRGEMM_FIELD(brgattr.use_interleave_stores);
CMP_BRGEMM_FIELD(brgattr.b_is_vnni);
CMP_BRGEMM_FIELD(brgattr.fpmath_mode);
CMP_BRGEMM_FIELD(brgattr.LDA2);
CMP_BRGEMM_FIELD(brgattr.LDB2);
Expand Down
10 changes: 9 additions & 1 deletion src/cpu/x64/brgemm/brgemm_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ struct DNNL_API brgemm_attr_t {
// interleave stores or not
bool use_interleave_stores;
impl::fpmath_mode_t fpmath_mode = fpmath_mode::strict;
bool b_is_vnni {false};
// Second level leading dimension describing distance between 16-line
// blocks in case of blocked layout. Used to calculate address of next
// bd block. By default are equal to regular leading dimension parameters
Expand Down Expand Up @@ -393,14 +394,21 @@ struct brgemm_desc_t {
switch (dt_b) {
case f32: return false;
// Note: `dt_a == f32` means implicit up-conversion of B to f32.
case f16: return (isa_impl != avx512_core_fp16) && (dt_a != f32);
case f16:
return brgattr.b_is_vnni
|| ((isa_impl != avx512_core_fp16) && (dt_a != f32));
// Note: `dt_a == f32` means implicit up-conversion of B to f32.
case bf16: return dt_a != f32;
default: return true;
}
}
// True when the compute type is a 16-bit float format (bf16 or f16).
bool is_xf16() const noexcept { return is_bf16 || is_f16; }

// Returns true when the B matrix holds f16 data in vnni layout but the
// kernel does not run on an AMX-fp16 ISA, i.e. the vnni layout has to be
// handled with vector instructions rather than AMX tiles.
bool is_f16_b_non_amx_vnni() const {
    if (dt_b != data_type::f16) return false;
    if (!brgattr.b_is_vnni) return false;
    return !is_superset(isa_impl, avx512_core_amx_fp16);
}

bool operator==(const brgemm_desc_t &rhs) const;
bool operator<(const brgemm_desc_t &rhs) const;

Expand Down
26 changes: 16 additions & 10 deletions src/cpu/x64/brgemm/brgemm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ int calculate_ldb_params(brgemm_desc_t *brg, const int try_ld_block2) {

int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) {

// TODO: Calculating the number of available registers should be re-factored
// to use one code here and in brgemm kernel generator on
// "max_effective_vregs" calculation
constexpr int max_bcst_regs = 1;
const bool req_compensation = brg->req_s8s8_compensation
|| brg->zp_type_a != brgemm_broadcast_t::none;
Expand All @@ -190,12 +193,15 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) {
|| brg->brgattr.max_bottom_vpad > 0)
&& brg->zp_type_a != brgemm_broadcast_t::none;
const int beta_regs = !one_of(brg->beta, 1.f, 0.f);
// To support the f16 vnni B matrix on non-AMX we need to use two Vmm
// registers for permutation in brgemm kernel
const int b_vnni_regs = brg->is_f16_b_non_amx_vnni() ? 2 : 0;

const int max_isa_regs = isa_num_vregs(brg->isa_impl);
// note: the 'adj_ld_block2' already removes the necessary registers
// for 'embd_bcst'
auto max_reg_count = max_isa_regs - max_bcst_regs - beta_regs
- req_compensation - req_zp_a_comp_pads;
- req_compensation - req_zp_a_comp_pads - b_vnni_regs;
if (req_zp_a_comp_pads)
max_reg_count
= nstl::min(max_reg_count, max_isa_regs - max_bcst_regs - 5);
Expand Down Expand Up @@ -224,6 +230,15 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) {
}

status_t brgemm_blocking(brgemm_desc_t *brg) {
const data_type_t ld_step_compute_dt
= get_mac_emu_data_type(brg->dt_b, brg->isa_impl,
brg->isa_impl != avx2_vnni_2 && !brg->is_fp8_via_convert());
brg->ld_step = brg->is_f16_b_non_amx_vnni()
? 2
: data_type_vnni_granularity(ld_step_compute_dt);
const data_type_t rd_step_compute_dt = get_mac_emu_data_type(
brg->dt_b, brg->isa_impl, !brg->is_fp8_via_convert());
brg->rd_step = data_type_vnni_granularity(rd_step_compute_dt);

set_isa_impl(brg);
if (brg->isa_impl == isa_undef) return status::unimplemented;
Expand Down Expand Up @@ -875,15 +890,6 @@ void init_brgemm_conf(brgemm_desc_t *brg, cpu_isa_t isa,
brg->bd_block2 = 0;
brg->bdb2 = 0;
brg->bdb2_tail = 0;

const data_type_t ld_step_compute_dt
= get_mac_emu_data_type(brg->dt_b, brg->isa_impl,
brg->isa_impl != avx2_vnni_2 && !brg->is_fp8_via_convert());
brg->ld_step = data_type_vnni_granularity(ld_step_compute_dt);

const data_type_t rd_step_compute_dt = get_mac_emu_data_type(
brg->dt_b, brg->isa_impl, !brg->is_fp8_via_convert());
brg->rd_step = data_type_vnni_granularity(rd_step_compute_dt);
}

void init_brdgmm_conf(brgemm_desc_t *brg, cpu_isa_t isa,
Expand Down
98 changes: 79 additions & 19 deletions src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,7 @@ struct jit_brgemm_kernel_t : public jit_generator {
: jit_generator(jit_name(), abrg.isa_impl)
, brg(abrg)
, postops_injector_(nullptr)
, max_effective_vregs(isa_num_vregs(brg.isa_impl)
- (brg.is_int8 && !brg.has_int8_vnni
? 2
: (brg.is_fp8_via_convert() ? 5 : 0))) {
, max_effective_vregs(get_max_effective_vregs(brg)) {

// The implementation uses is_superset(), is_subset() utilities.
// So avoid isa_all, isa_undef in these comparisions.
Expand Down Expand Up @@ -149,6 +146,8 @@ struct jit_brgemm_kernel_t : public jit_generator {

Xbyak::Label avx_tail_mask_;
Xbyak::Label sum_zp_scale_data_;
Xbyak::Label f16_perm_even_table_;
Xbyak::Label f16_perm_odd_table_;
using reg64_t = const Xbyak::Reg64;

// Register decomposition
Expand Down Expand Up @@ -276,6 +275,17 @@ struct jit_brgemm_kernel_t : public jit_generator {
Xbyak::Opmask fp8_col_mask = Xbyak::Opmask(4);
Xbyak::Opmask kmask_fp8_aux = Xbyak::Opmask(5);

// Number of vector registers the kernel may use for accumulation/loads:
// the ISA's full register count minus those reserved for auxiliary data
// (int8-without-vnni multiplicands, fp8 conversion scratch, or the f16
// vnni even/odd permutation tables). Exactly one reservation applies.
static int get_max_effective_vregs(const brgemm_desc_t &brg) {
    const int reserved_vregs = (brg.is_int8 && !brg.has_int8_vnni)
            ? 2
            : brg.is_fp8_via_convert() ? 5
                                       : brg.is_f16_b_non_amx_vnni() ? 2 : 0;
    return isa_num_vregs(brg.isa_impl) - reserved_vregs;
}

// Accumulator register for position (bd, ld) in the current block:
// accumulators are allocated downwards from the top of the effective
// vreg range, row-major over the ld dimension.
Vmm accm(int ld_block, int bd, int ld) {
    const int flat_idx = bd * ld_block + ld;
    return Vmm(max_effective_vregs - 1 - flat_idx);
}
Expand Down Expand Up @@ -336,6 +346,9 @@ struct jit_brgemm_kernel_t : public jit_generator {
return Vmm(isa_num_vregs(brg.isa_impl) - 2);
}

// Registers pinned to the top of the vreg range that hold the even/odd
// word-permutation tables used to repack f16 vnni B data; loaded from
// f16_perm_even_table_ / f16_perm_odd_table_ in generate().
Zmm f16_perm_even_vreg_ = Zmm(isa_num_vregs(brg.isa_impl) - 1);
Zmm f16_perm_odd_vreg_ = Zmm(isa_num_vregs(brg.isa_impl) - 2);

Vmm vmm_mask(const Vmm vmm_in, bool mask_flag, bool store,
Xbyak::Opmask ktail_mask) const;
Vmm_lower_t vmm_lower_mask(const Vmm_lower_t vmm_lower_in, bool mask_flag,
Expand Down Expand Up @@ -448,14 +461,13 @@ int jit_brgemm_kernel_t<Wmm>::B_offset(
if (is_amx) {
return brg.typesize_B * (brg.rd_step * ld * brg.ld_block);
} else {
const int data_vnni_granularity = brg.ld_step;
const int rdb0 = rd / data_vnni_granularity;
const int rdb0 = rd / brg.ld_step;
// Note: Offsets for elements within vnni_granularity are expected to be
// handled within gemm_microkernel (for ex: odd-even converts).
// hence no `rd % data_vnni_granularity`
// hence no `rd % brg.ld_step`
return brg.typesize_B
* (rdb0 * data_vnni_granularity * brg.LDB
+ data_vnni_granularity * ld * brg.ld_block);
* (rdb0 * brg.ld_step * brg.LDB
+ brg.ld_step * ld * brg.ld_block);
}
}

Expand Down Expand Up @@ -766,7 +778,7 @@ void jit_brgemm_kernel_t<Wmm>::ldb_regs_shift(int ld_block2, bool is_tail) {
add(reg_aux_D, D_offset);

add(reg_b_offset,
(is_tail) ? ldb_B_offset(1, true) : ldb_B_offset(ld_block2));
(is_tail) ? ldb_B_offset(0, true) : ldb_B_offset(ld_block2));

if (brg.with_bias) {
mov(reg_aux_bias, ptr[rsp + reg_aux_bias_offs_]);
Expand Down Expand Up @@ -2090,8 +2102,7 @@ void jit_brgemm_kernel_t<Wmm>::compute_int8_compensation(int rd_loop, int bd_b,
auto vmm_store = vmm_mask(load(), is_tail, false, ld_tail_mask);
uni_vmovups(vmm_store, addr);
} else {
load_bytes(
load(), addr, brg.typesize_B * brg.ldb_tail * brg.ld_step);
load_bytes(load(), addr, ldb_B_offset(0, true));
}

if (brg.req_cal_comp_pads) {
Expand Down Expand Up @@ -2194,8 +2205,20 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel(int bd_block2, bool is_bdb_tail,
vcvtneeph2ps(vmm_load, addr);
else
vcvtneoph2ps(vmm_load, addr);
} else
vcvtph2psx(vmm_load, addr);
} else {
if (brg.is_f16_b_non_amx_vnni()) {
const auto vnni_addr = ptr[reg_aux_B
+ B_offset(ld, utils::rnd_dn(rd, 2))];
vmovups(vmm_load, vnni_addr);
if (rd % 2 == 0)
vpermw(vmm_load, f16_perm_even_vreg_, vmm_load);
else
vpermw(vmm_load, f16_perm_odd_vreg_, vmm_load);
vcvtph2psx(
vmm_load, Vmm_lower_t(vmm_load.getIdx()));
} else
vcvtph2psx(vmm_load, addr);
}
} else if (brg.dt_b == data_type::bf16
&& brg.isa_impl == avx2_vnni_2) {
if (rd % 2 == 0)
Expand All @@ -2206,8 +2229,7 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel(int bd_block2, bool is_bdb_tail,
if (is_superset(brg.isa_impl, avx512_core)) {
uni_vmovups(vmm_load, addr);
} else {
load_bytes(vmm_load, addr,
brg.typesize_B * brg.ldb_tail * brg.ld_step);
load_bytes(vmm_load, addr, ldb_B_offset(0, true));
}
} else {
uni_vmovups(vmm_load, addr);
Expand Down Expand Up @@ -2239,7 +2261,20 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel(int bd_block2, bool is_bdb_tail,
else
vcvtneoph2ps(vmm_load, addr);
} else {
vcvtph2psx(vmm_load, addr);
if (brg.is_f16_b_non_amx_vnni()) {
const auto actual_B_offset
= B_offset(ld, utils::rnd_dn(rd, 2));
const auto vnni_addr
= ptr[reg_aux_B + actual_B_offset];
vmovups(vmm_load, vnni_addr);
if (rd % 2 == 0)
vpermw(vmm_load, f16_perm_even_vreg_, vmm_load);
else
vpermw(vmm_load, f16_perm_odd_vreg_, vmm_load);
vcvtph2psx(
vmm_load, Vmm_lower_t(vmm_load.getIdx()));
} else
vcvtph2psx(vmm_load, addr);
}
} else if (brg.dt_b == data_type::bf16
&& brg.isa_impl == avx2_vnni_2) {
Expand All @@ -2251,8 +2286,7 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel(int bd_block2, bool is_bdb_tail,
if (is_superset(brg.isa_impl, avx512_core)) {
uni_vmovups(vmm_load, addr);
} else {
load_bytes(vmm_load, addr,
brg.typesize_B * brg.ldb_tail * brg.ld_step);
load_bytes(vmm_load, addr, ldb_B_offset(0, true));
}
} else {
uni_vmovups(vmm_load, addr);
Expand Down Expand Up @@ -2739,6 +2773,13 @@ void jit_brgemm_kernel_t<Wmm>::generate() {
vpbroadcastw(int8_ones_words(), reg_tmp_gpr.cvt16());
}

if (brg.is_f16_b_non_amx_vnni()) {
mov(reg_tmp_gpr, f16_perm_even_table_);
vmovups(f16_perm_even_vreg_, ptr[reg_tmp_gpr]);
mov(reg_tmp_gpr, f16_perm_odd_table_);
vmovups(f16_perm_odd_vreg_, ptr[reg_tmp_gpr]);
}

read_params();

bdb_loop();
Expand Down Expand Up @@ -2771,6 +2812,25 @@ void jit_brgemm_kernel_t<Wmm>::generate() {

if (brg.with_eltwise)
postops_injector_->prepare_table(/* generate = */ true);

if (brg.is_f16_b_non_amx_vnni()) {
// convert interleaved vnni data with holes to packed.
align(64);
L(f16_perm_even_table_);
for (int i = 0; i < 32; ++i) {
if (i < 16)
dw(uint16_t(2 * i));
else
dw(uint16_t(0));
}
align(64);
L(f16_perm_odd_table_);
for (int i = 0; i < 32; ++i)
if (i < 16)
dw(uint16_t(2 * i + 1));
else
dw(uint16_t(0));
}
}

brgemm_attr_t::brgemm_attr_t()
Expand Down
1 change: 1 addition & 0 deletions src/cpu/x64/rnn/rnn_brgemm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ status_t init_brgemm_kernel(x64::brgemm_desc_t *desc, x64::cpu_isa_t isa,
brgattr.max_bs = max_bs;
brgattr.max_top_vpad = 0;
brgattr.max_bottom_vpad = 0;
brgattr.b_is_vnni = true;
CHECK(brgemm_desc_set_attr(desc, brgattr));

x64::brgemm_kernel_t *_t_ptr;
Expand Down

0 comments on commit 0aa315e

Please sign in to comment.