Skip to content

Commit

Permalink
x64: conv: move post_ops_ok after tags init in bf16 gemm conv
Browse files Browse the repository at this point in the history
  • Loading branch information
tczeszun authored and vpirogov committed Aug 1, 2024
1 parent d53e3ce commit d6c216a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 20 deletions.
20 changes: 19 additions & 1 deletion src/cpu/gemm_convolution_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,8 @@ void col2im(const conv_gemm_conf_t &jcp, const float *col, float *im,
status_t init_conf(conv_gemm_conf_t &jcp,
memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd,
memory_desc_t &src_md, memory_desc_t &weights_md, memory_desc_t &dst_md,
memory_desc_t &bias_md, primitive_attr_t &attr, int max_threads) {
memory_desc_t &bias_md, primitive_attr_t &attr, int max_threads,
bool check_postops) {
const memory_desc_wrapper src_d(&src_md);
const memory_desc_wrapper weights_d(&weights_md);
const memory_desc_wrapper dst_d(&dst_md);
Expand Down Expand Up @@ -1154,6 +1155,23 @@ status_t init_conf(conv_gemm_conf_t &jcp,

CHECK(attr.set_default_formats(&dst_md));

#if DNNL_X64
// for x64 we need to check post-ops after tags init
if (check_postops) {
using namespace x64::injector;
static constexpr bool sum_at_pos_0_only = true;
static constexpr bool sum_requires_scale_one = true;
static constexpr bool sum_requires_zp_zero = true;

VDISPATCH_CONV_IC(
post_ops_ok(post_ops_ok_args_t(x64::avx512_core,
{binary, eltwise, sum}, attr.post_ops_, &dst_d,
sum_at_pos_0_only, sum_requires_scale_one,
sum_requires_zp_zero)),
VERBOSE_UNSUPPORTED_POSTOP);
}
#endif

jcp.post_ops = attr.post_ops_;

const int eltwise_ind = jcp.post_ops.find(primitive_kind::eltwise);
Expand Down
5 changes: 3 additions & 2 deletions src/cpu/gemm_convolution_utils.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2016-2022 Intel Corporation
* Copyright 2016-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -117,7 +117,8 @@ void col2im(const conv_gemm_conf_t &jcp, const float *col, float *im,
status_t init_conf(conv_gemm_conf_t &jcp,
memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd,
memory_desc_t &src_md, memory_desc_t &weights_md, memory_desc_t &dst_md,
memory_desc_t &bias_md, primitive_attr_t &attr, int max_threads);
memory_desc_t &bias_md, primitive_attr_t &attr, int max_threads,
bool check_postops = false);

void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, int &ithr_g,
int &nthr_g, int &ithr_mb, int &nthr_mb);
Expand Down
19 changes: 2 additions & 17 deletions src/cpu/x64/gemm_bf16_convolution.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2023 Intel Corporation
* Copyright 2019-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -70,25 +70,10 @@ struct gemm_bf16_convolution_fwd_t : public primitive_t {
dst_data_type),
VERBOSE_UNSUPPORTED_ATTR);

{
using namespace x64::injector;
static constexpr bool sum_at_pos_0_only = true;
static constexpr bool sum_requires_scale_one = true;
static constexpr bool sum_requires_zp_zero = true;
const auto dst_md = memory_desc_wrapper(dst_md_);

VDISPATCH_CONV(
post_ops_ok(post_ops_ok_args_t(avx512_core,
{binary, eltwise, sum}, attr()->post_ops_,
&dst_md, sum_at_pos_0_only,
sum_requires_scale_one, sum_requires_zp_zero)),
VERBOSE_UNSUPPORTED_POSTOP);
}

auto scratchpad = scratchpad_registry().registrar();
return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
*desc(), src_md_, weights_md_, dst_md_, bias_md_, attr_,
dnnl_get_max_threads());
dnnl_get_max_threads(), true /* check_postops */);
}

bool is_postprocess_required() const {
Expand Down

0 comments on commit d6c216a

Please sign in to comment.