Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add declarations of explicit specializations and instantiations. #430

Merged
merged 1 commit into from
Mar 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/cpu/cpu_reducer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,10 @@ struct cpu_reducer_t {
void reduce_nolock(int ithr, data_t *dst);
};

// Explicit instantiation declarations for cpu_reducer_t.
// `extern template` asks the compiler not to implicitly instantiate these
// specializations in translation units that include this header; the matching
// explicit instantiation definitions are in cpu_reducer.cpp.
extern template struct cpu_reducer_t<data_type::f32>;
extern template struct cpu_reducer_t<data_type::s32>;

template <impl::data_type_t data_type>
struct cpu_reducer_2d_t {
typedef typename prec_traits<data_type>::type data_t;
Expand Down Expand Up @@ -254,6 +258,10 @@ struct cpu_reducer_2d_t {
void reduce_nolock(int ithr, data_t *dst);
};

// Explicit instantiation declarations for cpu_reducer_2d_t.
// `extern template` asks the compiler not to implicitly instantiate these
// specializations in translation units that include this header; the matching
// explicit instantiation definitions are in cpu_reducer.cpp.
extern template struct cpu_reducer_2d_t<data_type::f32>;
extern template struct cpu_reducer_2d_t<data_type::s32>;

/** simple 1d accumulator: y[:] += x[:] */
template <impl::data_type_t data_type>
struct cpu_accumulator_1d_t {
Expand All @@ -266,6 +274,10 @@ struct cpu_accumulator_1d_t {
reducer_2d_driver_t<data_type> *drv_;
};

// Explicit instantiation declarations for cpu_accumulator_1d_t.
// `extern template` asks the compiler not to implicitly instantiate these
// specializations in translation units that include this header; the matching
// explicit instantiation definitions are in cpu_reducer.cpp.
extern template struct cpu_accumulator_1d_t<data_type::f32>;
extern template struct cpu_accumulator_1d_t<data_type::s32>;

}
}
}
Expand Down
12 changes: 12 additions & 0 deletions src/cpu/gemm/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,18 @@ void ref_gemm(const char *transa, const char *transb, const int *M,
const int *N, const int *K, const data_t *alpha, const data_t *A,
const int *lda, const data_t *B, const int *ldb, const data_t *beta,
data_t *C, const int *ldc, const data_t *bias);

// Explicit instantiation declarations for ref_gemm<float> and ref_gemm<double>.
// The matching explicit instantiation definitions are in ref_gemm.cpp, so
// including translation units reference those instead of instantiating the
// template themselves.
// Note: the parameter names below carry a trailing underscore, unlike the
// primary template declaration; parameter names in a declaration have no
// effect on the signature, so the two forms declare the same functions.
extern template
void ref_gemm<float>(const char *transa_, const char *transb_,
const int *M_, const int *N_, const int *K_, const float *alpha_,
const float *A, const int *lda_, const float *B, const int *ldb_,
const float *beta_, float *C, const int *ldc_, const float *bias);
extern template
void ref_gemm<double>(const char *transa_, const char *transb_,
const int *M_, const int *N_, const int *K_, const double *alpha_,
const double *A, const int *lda_, const double *B, const int *ldb_,
const double *beta_, double *C, const int *ldc_, const double *bias);
#ifdef USE_CBLAS
#define GEMM_IMPL_STR "gemm:blas"
#else
Expand Down
8 changes: 8 additions & 0 deletions src/cpu/gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ template <typename data_type>
void sum_two_matrices(
int m, int n, data_type *p_src, int ld_src, data_type *p_dst, int ld_dst);

// Explicit instantiation declarations for sum_two_matrices<float> and
// sum_two_matrices<double>. The matching explicit instantiation definitions
// are provided in gemm_utils.cpp, so including translation units reference
// those instead of instantiating the template themselves.
extern template
void sum_two_matrices<float>(
int m, int n, float *p_src, int ld_src, float *p_dst, int ld_dst);
extern template
void sum_two_matrices<double>(
int m, int n, double *p_src, int ld_src, double *p_dst, int ld_dst);

void calc_nthr_nocopy_avx512_common(int m,
int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k,
int *BM, int *BN, int *BK);
Expand Down
14 changes: 14 additions & 0 deletions src/cpu/jit_avx512_common_convolution_winograd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,20 @@ struct _jit_avx512_common_convolution_winograd_t {
const primitive_attr_t *attr_;
};

// Explicit instantiation declarations for two member functions of
// _jit_avx512_common_convolution_winograd_t, each for both values (true and
// false) of the class's bool template argument:
//   - _execute_data_W_S_G_D
//   - _execute_data_W_SGD
// Only these member functions are declared `extern template` here, not the
// whole class. The matching explicit instantiation definitions are in
// jit_avx512_common_convolution_winograd.cpp.
extern template void
_jit_avx512_common_convolution_winograd_t<true>::_execute_data_W_S_G_D(
float *, float *, float *, float *);
extern template void
_jit_avx512_common_convolution_winograd_t<false>::_execute_data_W_S_G_D(
float *, float *, float *, float *);
extern template void
_jit_avx512_common_convolution_winograd_t<true>::_execute_data_W_SGD(
float *, float *, float *, float *);
extern template void
_jit_avx512_common_convolution_winograd_t<false>::_execute_data_W_SGD(
float *, float *, float *, float *);

template <bool with_relu>
struct _jit_avx512_common_convolution_winograd_fwd_t
: _jit_avx512_common_convolution_winograd_t<true>
Expand Down
6 changes: 6 additions & 0 deletions src/cpu/jit_uni_dw_conv_kernel_f32.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,12 @@ struct jit_uni_dw_conv_bwd_weights_kernel_f32 : public jit_generator {

void generate();
};

// Explicit instantiation declarations for
// jit_uni_dw_conv_bwd_weights_kernel_f32, one per ISA template argument listed
// below. `extern template` asks the compiler not to implicitly instantiate
// these specializations in including translation units; the matching explicit
// instantiation definitions are in jit_uni_dw_conv_kernel_f32.cpp.
extern template struct jit_uni_dw_conv_bwd_weights_kernel_f32<avx512_common>;
extern template struct jit_uni_dw_conv_bwd_weights_kernel_f32<avx2>;
extern template struct jit_uni_dw_conv_bwd_weights_kernel_f32<sse42>;

}
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/cpu/jit_uni_eltwise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ namespace mkldnn {
namespace impl {
namespace cpu {

// Explicit instantiation declarations for jit_uni_eltwise_injector_f32, one
// per ISA template argument listed below. The matching explicit instantiation
// definitions are in jit_uni_eltwise.cpp.
// NOTE(review): these declarations require the jit_uni_eltwise_injector_f32
// template to be declared earlier (presumably via an included header, not
// visible in this hunk) -- confirm.
extern template struct jit_uni_eltwise_injector_f32<avx512_common>;
extern template struct jit_uni_eltwise_injector_f32<avx2>;
extern template struct jit_uni_eltwise_injector_f32<sse42>;

struct jit_uni_eltwise_kernel_f32;

template <cpu_isa_t isa>
Expand Down
54 changes: 54 additions & 0 deletions src/cpu/jit_uni_lrn_kernel_f32.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,60 @@ struct jit_uni_lrn_bwd_kernel_f32 : public jit_generator {
void(*ker)(jit_args_bwd_t *);
};

// Declarations of explicit specializations of jit_uni_lrn_fwd_kernel_f32
// members for the avx2 and sse42 template arguments. The definitions of these
// specializations (and the explicit instantiations) are in
// jit_uni_lrn_kernel_f32.cpp.
//
// An explicit specialization has to be declared before any use that would
// otherwise cause an implicit instantiation of the same member; declaring
// them in the header makes them visible to every translation unit that
// includes it.

// Constructors taking a `const struct nchw8c_across &` descriptor.
template <>
jit_uni_lrn_fwd_kernel_f32<avx2>::jit_uni_lrn_fwd_kernel_f32(
const struct nchw8c_across &J, float A, float K, prop_kind_t pk,
void *code_ptr, size_t code_size);
template <>
jit_uni_lrn_fwd_kernel_f32<sse42>::jit_uni_lrn_fwd_kernel_f32(
const struct nchw8c_across &J, float A, float K, prop_kind_t pk,
void *code_ptr, size_t code_size);
// Constructors taking a `const struct nhwc_across &` descriptor.
template <>
jit_uni_lrn_fwd_kernel_f32<avx2>::jit_uni_lrn_fwd_kernel_f32(
const struct nhwc_across &J, float A, float K, prop_kind_t pk,
void *code_ptr, size_t code_size);
template <>
jit_uni_lrn_fwd_kernel_f32<sse42>::jit_uni_lrn_fwd_kernel_f32(
const struct nhwc_across &J, float A, float K, prop_kind_t pk,
void *code_ptr, size_t code_size);
// nchw_body: variant operating on Xbyak::Ymm registers.
template <>
void jit_uni_lrn_fwd_kernel_f32<sse42>::nchw_body(
int tail, int HW, prop_kind_t pk, Xbyak::Ymm ymask, Xbyak::Ymm ya,
Xbyak::Ymm yb, Xbyak::Ymm yc, Xbyak::Ymm yd, Xbyak::Ymm ye,
Xbyak::Ymm ysum);
template <>
void jit_uni_lrn_fwd_kernel_f32<avx2>::nchw_body(
int tail, int HW, prop_kind_t pk, Xbyak::Ymm ymask, Xbyak::Ymm ya,
Xbyak::Ymm yb, Xbyak::Ymm yc, Xbyak::Ymm yd, Xbyak::Ymm ye,
Xbyak::Ymm ysum);
// nchw_tail_sse42: takes a pair of Xbyak::Xmm registers (lo/hi).
template <>
void jit_uni_lrn_fwd_kernel_f32<avx2>::nchw_tail_sse42(
int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi);
template <>
void jit_uni_lrn_fwd_kernel_f32<sse42>::nchw_tail_sse42(
int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi);
// nchw_body_sse42: variant operating on lo/hi pairs of Xbyak::Xmm registers.
template <>
void jit_uni_lrn_fwd_kernel_f32<sse42>::nchw_body_sse42(
int tail, int HW, prop_kind_t pk, Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi,
Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi);
template <>
void jit_uni_lrn_fwd_kernel_f32<avx2>::nchw_body_sse42(
int tail, int HW, prop_kind_t pk, Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi,
Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi);
// Constructors taking a `struct nchw_across` descriptor by value (unlike the
// nchw8c_across / nhwc_across overloads above, which take a const reference).
template <>
jit_uni_lrn_fwd_kernel_f32<avx2>::jit_uni_lrn_fwd_kernel_f32(
struct nchw_across J, float A, float K, prop_kind_t pk, void *code_ptr,
size_t code_size);
template <>
jit_uni_lrn_fwd_kernel_f32<sse42>::jit_uni_lrn_fwd_kernel_f32(
struct nchw_across J, float A, float K, prop_kind_t pk, void *code_ptr,
size_t code_size);

// Explicit instantiation declarations for the LRN kernels: the forward kernel
// for sse42 and avx2, and the backward kernel for avx2 only. The matching
// explicit instantiation definitions are in jit_uni_lrn_kernel_f32.cpp.
extern template struct jit_uni_lrn_fwd_kernel_f32<sse42>;
extern template struct jit_uni_lrn_fwd_kernel_f32<avx2>;
extern template struct jit_uni_lrn_bwd_kernel_f32<avx2>;

}
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/cpu/jit_uni_pool_kernel_f32.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ struct jit_uni_pool_kernel_f32: public jit_generator {
}
};

// Explicit instantiation declarations for jit_uni_pool_kernel_f32, one per ISA
// template argument listed below (note: `avx`, not `avx2`). `extern template`
// asks the compiler not to implicitly instantiate these specializations in
// including translation units; the matching explicit instantiation
// definitions are in jit_uni_pool_kernel_f32.cpp.
extern template struct jit_uni_pool_kernel_f32<sse42>;
extern template struct jit_uni_pool_kernel_f32<avx>;
extern template struct jit_uni_pool_kernel_f32<avx512_common>;

}
}
}
Expand Down