From a97f021e890065858f28dbacb409fecd6dde5f4a Mon Sep 17 00:00:00 2001 From: "Xuxin, Zeng" Date: Mon, 20 May 2024 23:27:46 -0700 Subject: [PATCH] cpu: x64: fix assertion in bf16 conv for relo on AMX --- src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp | 3 ++- src/cpu/x64/jit_brgemm_conv.cpp | 1 + src/cpu/x64/jit_primitive_conf.hpp | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp index 5cf314ad117..f19de8e57ca 100644 --- a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp @@ -696,7 +696,8 @@ void jit_avx512_core_amx_copy_to_pbuffer_t::copy_row(int icb) { void jit_avx512_core_amx_copy_to_pbuffer_t::copy_row_reduced_lowering() { assert(jcp.nb_ic_int == 1); - assert(jcp.ic_block_int * jcp.typesize_in == 64); + assert((jcp.is_bf32 ? jcp.ic_block : jcp.ic_block_int) * jcp.typesize_in + == 64); assert(jcp.is_nspc); auto load_mask = [this](int tail, Opmask kmask) { diff --git a/src/cpu/x64/jit_brgemm_conv.cpp b/src/cpu/x64/jit_brgemm_conv.cpp index f6d8598e752..b8b11bf096e 100644 --- a/src/cpu/x64/jit_brgemm_conv.cpp +++ b/src/cpu/x64/jit_brgemm_conv.cpp @@ -977,6 +977,7 @@ status_t brgemm_convolution_fwd_t::init(engine_t *engine) { ajcp.is_relo = true; ajcp.nb_ic_int = 1; ajcp.is_nspc = true; + ajcp.is_bf32 = jcp.is_bf32; ajcp.typesize_in = jcp.src_dsz; ajcp.ic_block_int = jcp.amx_w; diff --git a/src/cpu/x64/jit_primitive_conf.hpp b/src/cpu/x64/jit_primitive_conf.hpp index 0361cf9d69b..c3beb63cfed 100644 --- a/src/cpu/x64/jit_primitive_conf.hpp +++ b/src/cpu/x64/jit_primitive_conf.hpp @@ -166,6 +166,7 @@ struct jit_conv_conf_t { data_type_t ddst_dt; data_type_t dsrc_dt; data_type_t dwei_dt; + bool is_bf32 {false}; bool expl_bcast; bool large_spatial, large_w_filter; int is_ic_scale, is_oc_scale;