Skip to content

Commit

Permalink
cpu: x64: enable bf16 matmul with weights decompression on avx512_bf16
Browse files Browse the repository at this point in the history
  • Loading branch information
xuxinzen committed Mar 1, 2024
1 parent bf8b4b4 commit 4d6bd3c
Show file tree
Hide file tree
Showing 4 changed files with 313 additions and 46 deletions.
21 changes: 18 additions & 3 deletions src/cpu/x64/matmul/brgemm_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
= everyone_is(bf16, src_dt, wei_dt) && one_of(dst_dt, bf16, f32);
const bool is_f16
= everyone_is(f16, src_dt, wei_dt) && one_of(dst_dt, f16, f32);
const bool is_bf16_with_int_wei = src_dt == bf16 && one_of(wei_dt, s8, u8)
&& one_of(dst_dt, bf16, f32);

auto check_bias = [&]() -> bool {
const auto bia_dt = weights_md(1)->data_type;
Expand Down Expand Up @@ -84,7 +86,8 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {

auto check_attr_zero_points
= [&]() -> bool { return attr()->zero_points_.common(); };
const bool problem_dt_correct = is_int8 || is_bf16 || is_f32 || is_f16;
const bool problem_dt_correct = one_of(
true, is_int8, is_bf16, is_f32, is_f16, is_bf16_with_int_wei);

auto src_d = memory_desc_wrapper(src_md_);
auto weights_d = memory_desc_wrapper(weights_md_);
Expand Down Expand Up @@ -1693,7 +1696,13 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
// Translates an M block index into an index relative to
// M_tail_block_start_.
//
// Without runtime M the raw offset is returned unchecked; it may be negative,
// which is_M_tail_processing() interprets as "not a tail block".
// With runtime M the offset must stay below m_tail_processing_.size().
// Exceeding it is treated as a programming error: the assert fires when
// assertions are enabled, and index 0 is returned otherwise.
int get_M_tail_block_idx(int m_block_idx) const {
    const int tail_idx = m_block_idx - M_tail_block_start_;
    if (!bgmmc_.is_runtime_M) return tail_idx;

    // Only the upper bound is validated; negative offsets pass through so
    // callers can still detect non-tail blocks.
    const bool is_index_within_range
            = tail_idx < static_cast<int>(m_tail_processing_.size());
    if (!is_index_within_range) {
        assert(!"Error in M_tail_block index, not within range.");
        return 0;
    }
    return tail_idx;
}
bool is_M_tail_processing(int m_block_idx) const {
return get_M_tail_block_idx(m_block_idx) >= 0;
Expand All @@ -1713,7 +1722,13 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
// Translates an N block index into an index relative to
// N_tail_block_start_.
//
// Without runtime N the raw offset is returned unchecked; it may be negative,
// which is_N_tail_processing() interprets as "not a tail block".
// With runtime N the offset must stay below n_tail_processing_.size().
// Exceeding it is treated as a programming error: the assert fires when
// assertions are enabled, and index 0 is returned otherwise.
int get_N_tail_block_idx(int n_block_idx) const {
    const int tail_idx = n_block_idx - N_tail_block_start_;
    if (!bgmmc_.is_runtime_N) return tail_idx;

    // Only the upper bound is validated; negative offsets pass through so
    // callers can still detect non-tail blocks.
    const bool is_index_within_range
            = tail_idx < static_cast<int>(n_tail_processing_.size());
    if (!is_index_within_range) {
        assert(!"Error in N_tail_block index, not within range.");
        return 0;
    }
    return tail_idx;
}
bool is_N_tail_processing(int n_block_idx) const {
return get_N_tail_block_idx(n_block_idx) >= 0;
Expand Down
Loading

0 comments on commit 4d6bd3c

Please sign in to comment.