diff --git a/src/common/memory_storage.hpp b/src/common/memory_storage.hpp
index 822cce0391f..8b4bb39f8ea 100644
--- a/src/common/memory_storage.hpp
+++ b/src/common/memory_storage.hpp
@@ -75,6 +75,13 @@ struct memory_storage_t : public c_compatible {
     /** returns shallow copy */
    virtual std::unique_ptr<memory_storage_t> clone() const = 0;
 
+    /** returns a shallow copy with an offset applied to the accessor pointer;
+     * used for buffers to avoid sub-buffers where possible */
+    virtual std::unique_ptr<memory_storage_t> clone_ptr_off(size_t offset) const {
+        assert(!"not expected");
+        return nullptr;
+    }
+
     /** returns true if the pointer associated with the storage is NULL */
     bool is_null() const {
         void *ptr;
diff --git a/src/gpu/generic/sycl/README.md b/src/gpu/generic/sycl/README.md
index a8bd3238b9a..6cb63a42fb5 100644
--- a/src/gpu/generic/sycl/README.md
+++ b/src/gpu/generic/sycl/README.md
@@ -179,3 +179,11 @@ The implementation supports both forward and backward propagations.
 
 * Supported formats: plain formats with up to 6 dimensions
 * Supported data types: `f32`, `bf16`, `f16`, `s8`, `u8`
+
+## RNN
+
+The implementation supports forward propagation and the vanilla RNN cell kind.
+
+* Supported formats: `ldigo`, `ldgoi`
+* Supported data types: `f32`, `bf16`, `f16`, `s8`, `u8`
+* Supported direction: `left2right`
diff --git a/src/gpu/generic/sycl/rnn/cell_common.cpp b/src/gpu/generic/sycl/rnn/cell_common.cpp
new file mode 100644
index 00000000000..65ad0b85b10
--- /dev/null
+++ b/src/gpu/generic/sycl/rnn/cell_common.cpp
@@ -0,0 +1,64 @@
+/*******************************************************************************
+* Copyright 2019-2024 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/ + +// Common for RNN and LSTM cell execution + +#include "gpu/generic/sycl/rnn/ref_rnn.hpp" + +namespace dnnl { +namespace impl { +namespace gpu { +namespace generic { +namespace sycl { + +using namespace dnnl::impl::utils; +using namespace rnn_utils; + +status_t _ref_rnn_common_t::cell_execution(const cell_ctx_t &cell_struct) { + + auto cell_layer = cell_struct.workspace.states_range(cell_struct.lay - 1, + cell_struct.lay - 1, cell_struct.dir, cell_struct.dir, + cell_struct.iter - 1, cell_struct.iter); + + auto cell_iter = cell_struct.workspace.states_range(cell_struct.lay, + cell_struct.lay, cell_struct.dir, cell_struct.dir, + cell_struct.iter - 2, cell_struct.iter - 1); + + auto scratch_gates = cell_struct.scratch.gates(0); + + auto wei_layer + = cell_struct.user_data.wei_layer(cell_struct.lay, cell_struct.dir); + auto wei_iter + = cell_struct.user_data.wei_iter(cell_struct.lay, cell_struct.dir); + + CHECK(gemm_primitive(cell_struct.engine, cell_struct.ctx, wei_layer, + cell_layer, scratch_gates, gemm_layer_fwd)); + + CHECK(gemm_primitive(cell_struct.engine, cell_struct.ctx, wei_iter, + cell_iter, scratch_gates, gemm_iter_fwd)); + + CHECK(rnn_bias(cell_struct.ctx, cell_struct.rnn.mb, cell_struct.rnn.dhc, + cell_struct.iter, cell_struct.lay, cell_struct.dir, + cell_struct.workspace, cell_struct.scratch, cell_struct.user_data)); + + return status::success; +} + +} // namespace sycl +} // namespace generic +} // namespace gpu +} // namespace impl +} // namespace dnnl diff --git a/src/gpu/generic/sycl/rnn/ref_rnn.cpp b/src/gpu/generic/sycl/rnn/ref_rnn.cpp new file mode 100644 index 00000000000..105178adcd1 --- /dev/null +++ b/src/gpu/generic/sycl/rnn/ref_rnn.cpp @@ -0,0 +1,702 @@ +/******************************************************************************* +* Copyright 2019-2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +// General architecture +// +// for diff states, we have n_states + 1 as we have n_states diff +// to propagate to the previous iteration and 1 states to propagate +// to the previous layer +// index 0 is dh for cell(t-1, l) to consume +// index 1 is dc for cell(t-1, l) to consume +// index 2 is dh for cell(t, l-1) to consume +// this indexing enables to have the same indexing for states in elemwise +// function +// only the cell execution function should be impacted + +#include "gpu/generic/sycl/rnn/ref_rnn.hpp" +#include "common/primitive.hpp" +#include "common/primitive_desc.hpp" + +#include "common/matmul_pd.hpp" +#include "common/stream.hpp" +#include "common/type_helpers.hpp" +#include "gpu/generic/sycl/rnn/rnn_kernels.hpp" + +#include + +#define DPRINT(fmt, ...) 
\ + do { \ + if (get_verbose_dev_mode(verbose_t::debuginfo) >= 2) { \ + printf(fmt, __VA_ARGS__); \ + fflush(nullptr); \ + } \ + } while (0) + +namespace dnnl { +namespace impl { +namespace gpu { +namespace generic { +namespace sycl { + +using namespace dnnl::impl::utils; +using namespace dnnl::impl::math; +using namespace prop_kind; +using namespace alg_kind; +using namespace rnn_utils; +using namespace dnnl::impl::memory_tracking::names; + +status_t _ref_rnn_common_t::pd_t::set_default_params() { + using namespace format_tag; + if (src_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_layer_md_, tnc)); + if (dst_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc)); + + // Optional parameters + if ((!types::is_zero_md(&src_iter_md_)) + && (src_iter_md_.format_kind == format_kind::any)) + CHECK(memory_desc_init_by_tag(src_iter_md_, ldnc)); + if ((!types::is_zero_md(&bias_md_)) + && (bias_md_.format_kind == format_kind::any)) + CHECK(memory_desc_init_by_tag(bias_md_, ldgo)); + if ((!types::is_zero_md(&dst_iter_md_)) + && (dst_iter_md_.format_kind == format_kind::any)) + CHECK(memory_desc_init_by_tag(dst_iter_md_, ldnc)); + + return status::success; +} + +status_t _ref_rnn_common_t::pd_t::init(impl::engine_t *engine) { + using namespace prop_kind; + using namespace utils; + using namespace rnn_utils; + using namespace format_tag; + + assert(engine->kind() == engine_kind::gpu); + + const alg_kind_t cell_kind = this->desc()->cell_kind; + + data_type_t src_layer_dt = this->desc()->src_layer_desc.data_type; + data_type_t weights_iter_dt = this->desc()->weights_iter_desc.data_type; + data_type_t weights_layer_dt = this->desc()->weights_layer_desc.data_type; + data_type_t bias_dt = this->desc()->bias_desc.data_type; + + acc_data_t = data_type::f32; + + src_type = src_layer_dt; + weights_type = weights_layer_dt; + + VDISPATCH_RNN( + one_of(cell_kind, alg_kind::vanilla_rnn), VERBOSE_BAD_ALGORITHM); + VDISPATCH_RNN(weights_iter_dt == weights_layer_dt, VERBOSE_UNSUPPORTED_DT); + VDISPATCH_RNN_SC(this->set_default_params(), VERBOSE_UNSUPPORTED_TAG); + VDISPATCH_RNN(this->with_bias(), VERBOSE_UNSUPPORTED_BIAS_CFG); + VDISPATCH_RNN(this->desc()->prop_kind == forward_inference + || this->desc()->prop_kind == forward_training, + VERBOSE_UNSUPPORTED_DT_CFG); + + init_rnn_conf(rnn_conf, this, acc_data_t); + + // Check that only supported attr have been passed. 
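+    // Only rnn_tparams is allowed by default; for s8 weights the RNN
+    // data/weights quantization parameters and fpmath_mode are allowed too.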
+ primitive_attr_t::skip_mask_t attr_mask + = primitive_attr_t::skip_mask_t::rnn_tparams; + if (weights_layer_dt == data_type::s8) { + attr_mask = attr_mask | primitive_attr_t::skip_mask_t::rnn_data_qparams + | primitive_attr_t::skip_mask_t::rnn_weights_qparams + | primitive_attr_t::skip_mask_t::fpmath_mode; + } + VDISPATCH_RNN(this->attr()->has_default_values(attr_mask), + VERBOSE_UNSUPPORTED_ATTR); + + // Set weights descriptors to desired format + VDISPATCH_RNN_SC(set_weights_desc(this->weights_layer_md_, rnn_conf), + "unsupported weights layer memory descriptor"); + VDISPATCH_RNN_SC(set_weights_desc(this->weights_iter_md_, rnn_conf), + "unsupported weights iter memory descriptor"); + + // Currently only run L2R + VDISPATCH_RNN(this->direction() == dnnl_unidirectional_left2right, + VERBOSE_BAD_ALGORITHM); + // Check dimensions consistency + VDISPATCH_RNN((this->SIC() == this->DHC() || (this->T() == 1)), + VERBOSE_INCONSISTENT_DIM, "SIC", (int)this->SIC(), "DHC", + (int)this->DHC()); + + set_rnn_conf(rnn_conf, *this->desc()); + + dim_t workspace_size = get_workspace_size(rnn_conf); + + // initialize the workspace_pd if needed + if (rnn_conf.use_workspace) { + dims_t ws_dims = {workspace_size}; + VDISPATCH_RNN_SC(memory_desc_init_by_tag( + this->ws_md_, 1, ws_dims, data_type::u8, x), + "memory_desc_init_by_tag()"); + } + + memory_desc_t state_md; + dims_t state_dims = {rnn_conf.n_layer, rnn_conf.n_dir, rnn_conf.n_iter + 1, + rnn_conf.mb, rnn_conf.states_ws_ld}; + + CHECK(memory_desc_init_by_tag(state_md, 5, state_dims, + rnn_conf.src_data_type, format_tag::abcde)); + + copy_init_layer_conf_ = sycl_rnn_copy_conf_t { + xpu::sycl::md_t(this->src_md(0)), xpu::sycl::md_t(&state_md), + rnn_conf.slc, rnn_conf.n_dir, rnn_conf.n_layer, rnn_conf.n_iter, + rnn_conf.mb, rnn_conf.states_ws_ld, true, true}; + + xpu::sycl::md_t src_iter_md = this->src_md(1)->data_type == data_type::undef + ? copy_init_iter_conf_.src_md = xpu::sycl::md_t() + : copy_init_iter_conf_.src_md = xpu::sycl::md_t(this->src_md(1)); + + copy_init_iter_conf_ = sycl_rnn_copy_conf_t {src_iter_md, + xpu::sycl::md_t(&state_md), rnn_conf.sic, rnn_conf.n_dir, + rnn_conf.n_layer, rnn_conf.n_iter, rnn_conf.mb, + rnn_conf.states_ws_ld, false, true}; + + copy_res_layer_conf_ = sycl_rnn_copy_conf_t {xpu::sycl::md_t(&state_md), + xpu::sycl::md_t(this->dst_md(0)), rnn_conf.dhc, rnn_conf.n_dir, + rnn_conf.n_layer, rnn_conf.n_iter, rnn_conf.mb, + rnn_conf.states_ws_ld, true, false}; + + xpu::sycl::md_t dst_iter_md = this->dst_md(1)->data_type == data_type::undef + ? copy_init_iter_conf_.src_md = xpu::sycl::md_t() + : copy_init_iter_conf_.src_md = xpu::sycl::md_t(this->dst_md(1)); + + copy_res_iter_conf_ = sycl_rnn_copy_conf_t {xpu::sycl::md_t(&state_md), + dst_iter_md, rnn_conf.dhc, rnn_conf.n_dir, rnn_conf.n_layer, + rnn_conf.n_iter, rnn_conf.mb, rnn_conf.states_ws_ld, false, false}; + + sycl_rnn_bias_conf_t_ = sycl_rnn_bias_conf_t(); + sycl_rnn_bias_conf_t_.dst_md = xpu::sycl::md_t(this->dst_md(0)); + sycl_rnn_bias_conf_t_.bias_type = bias_dt; + sycl_rnn_bias_conf_t_.batch = rnn_conf.mb; + sycl_rnn_bias_conf_t_.dhc = rnn_conf.dhc; + sycl_rnn_bias_conf_t_.gates_ws_ld = rnn_conf.gates_ws_ld; + sycl_rnn_bias_conf_t_.states_ws_ld = rnn_conf.states_ws_ld; + sycl_rnn_bias_conf_t_.activation_kind = this->activation_kind(); + sycl_rnn_bias_conf_t_.alpha = this->desc()->alpha; + + auto fpmath_mode = this->attr()->fpmath_.mode_; + + // The inputs of create_gemm_pd describe a gemm in column major. 
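+    // (A column-major GEMM C = A x B is equivalent to the row-major GEMM
+    //  C^T = B^T x A^T, which is why the a/b operands and the m/n sizes are
+    //  swapped.)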
+ // Below, we have to transpose the a and b descriptor to describe + // the GEMM as a row major problem. + auto create_gemm_pd = + [&](std::shared_ptr &gemm_pd, dim_t m, dim_t n, + dim_t k, strides_t<2> a_strides, strides_t<2> b_strides, + strides_t<2> c_strides, data_type_t a_dt, data_type_t b_dt, + data_type_t c_dt, float beta) -> status_t { + memory_desc_t a_md, b_md, c_md, bias_md; + + dims_t a_dims = {n, k}, b_dims = {k, m}, c_dims = {n, m}; + + dims_t b_strides_md = {b_strides[0], b_strides[1]}; + CHECK(memory_desc_init_by_strides( + b_md, 2, b_dims, rnn_conf.wei_layer_type, b_strides_md)); + dims_t a_strides_md = {a_strides[0], a_strides[1]}; + CHECK(memory_desc_init_by_strides( + a_md, 2, a_dims, rnn_conf.src_data_type, a_strides_md)); + dims_t c_strides_md = {c_strides[0], c_strides[1]}; + CHECK(memory_desc_init_by_strides( + c_md, 2, c_dims, rnn_conf.dst_data_type, c_strides_md)); + + primitive_attr_t attr; + CHECK(attr.post_ops_.append_sum(beta)); + CHECK(attr.set_fpmath_mode(fpmath_mode)); + attr.deterministic_ = this->attr()->deterministic_; + + matmul_desc_t matmul_desc; + dnnl::impl::matmul_desc_init( + &matmul_desc, &a_md, &b_md, &bias_md, &c_md); + + primitive_desc_iterator_t it(engine, + reinterpret_cast(&matmul_desc), &attr, nullptr); + + while (++it != it.end()) { + if (*it) { + gemm_pd = *it; + return status::success; + break; + } + } + return status::unimplemented; + }; + + float gemm_iter_fwd_beta = this->is_lbr() ? 0.0f : 1.0f; + + // Setup gemm PDs + + dim_t batch = rnn_conf.mb; + dim_t n_gates = rnn_conf.n_gates; + dim_t slc = rnn_conf.slc; + dim_t sic = rnn_conf.sic; + dim_t dhc = rnn_conf.dhc; + + strides_t<5> wei_layer_strides = get_outer_strides(this->weights_md(0)); + strides_t<5> wei_iter_strides = get_outer_strides(this->weights_md(1)); + + VDISPATCH_RNN_SC(create_gemm_pd(gemm_layer_fwd_pd_, n_gates * dhc, batch, + slc, {rnn_conf.states_ws_ld, 1}, + {wei_layer_strides[2], wei_layer_strides[4]}, + {rnn_conf.scratch_gates_ld, 1}, weights_type, + src_type, rnn_conf.acc_data_type, 0.0), + "create_gemm_pd(gemm_layer_fwd_pd_)"); + + VDISPATCH_RNN_SC(create_gemm_pd(gemm_iter_fwd_pd_, n_gates * dhc, batch, + sic, {rnn_conf.states_ws_ld, 1}, + {wei_iter_strides[2], wei_iter_strides[4]}, + {rnn_conf.gates_ws_ld, 1}, weights_type, src_type, + rnn_conf.acc_data_type, gemm_iter_fwd_beta), + "create_gemm_pd(gemm_iter_fwd_pd_)"); + + init_scratchpad(rnn_conf.use_workspace ? 
0 : workspace_size); + return status::success; +} + +status_t _ref_rnn_common_t::init(impl::engine_t *engine) { + using namespace rnn_utils; + + switch (pd()->cell_kind()) { + case dnnl_vanilla_rnn: + cell_func = [this](const cell_ctx_t &cell_struct) -> status_t { + return this->cell_execution(cell_struct); + }; + break; + default: break; + } + grid_func = [this](const grid_ctx_t &grid_struct) -> status_t { + return this->linear_execution(grid_struct); + }; + + const conf_t &rnn = pd()->rnn_conf; + rnn_utils::set_workspace_offsets(rnn, ws_gates_offset_, ws_states_offset_); + + // IMPORTANT SYCL STUFF + const auto copy_kid = ::sycl::get_kernel_id(); + this->create_kernel(engine, copy_kid, ©_kernel_); + const auto bias_kid = ::sycl::get_kernel_id(); + this->create_kernel(engine, bias_kid, &bias_kernel_); + + bool gemm_ok = true; + auto create_nested_gemm = + [&](const std::shared_ptr &prim_desc, + std::shared_ptr &prim) { + std::pair, cache_state_t> + pair; + bool gemm_ok = prim_desc->create_primitive_nested(pair, engine) + == status::success; + prim = pair.first; + return gemm_ok; + }; + + gemm_ok = gemm_ok + && create_nested_gemm(pd()->gemm_layer_fwd_pd_, gemm_layer_fwd_); + gemm_ok = gemm_ok + && create_nested_gemm(pd()->gemm_iter_fwd_pd_, gemm_iter_fwd_); + + if (!gemm_ok) return status::runtime_error; + + return status::success; +} // namespace sycl + +status_t _ref_rnn_common_t::gemm_primitive(impl::engine_t *engine, + const exec_ctx_t &ctx, std::unique_ptr &a, + std::unique_ptr &b, + std::unique_ptr &c, gemm_kind_t gemm_kind) const { + std::unique_ptr arg1, arg2, arg3; + exec_args_t gemm_args; + std::shared_ptr gemm_pd; + + switch (gemm_kind) { + case gemm_iter_fwd: gemm_pd = pd()->gemm_iter_fwd_pd_; break; + case gemm_layer_fwd: gemm_pd = pd()->gemm_layer_fwd_pd_; break; + } + + CHECK(safe_ptr_assign(arg2, + new memory_t( + ctx.stream()->engine(), gemm_pd->src_md(0), a->clone()))); + CHECK(safe_ptr_assign(arg1, + new memory_t(ctx.stream()->engine(), gemm_pd->weights_md(0), + b->clone()))); + CHECK(safe_ptr_assign(arg3, + new memory_t( + ctx.stream()->engine(), gemm_pd->dst_md(0), c->clone()))); + + gemm_args[DNNL_ARG_SRC] = memory_arg_t {arg1.get(), true}; + gemm_args[DNNL_ARG_WEIGHTS] = memory_arg_t {arg2.get(), true}; + gemm_args[DNNL_ARG_DST] = memory_arg_t {arg3.get(), false}; + + exec_ctx_t gemm_ctx(ctx, std::move(gemm_args)); + + std::unique_ptr ns; + const auto init_gemm_nested_scratchpad + = [&](const std::shared_ptr &gemm, int key) { + ns = utils::make_unique(ctx, key, gemm); + gemm_ctx.set_scratchpad_grantor(ns->grantor()); + }; + + switch (gemm_kind) { + case gemm_iter_fwd: + init_gemm_nested_scratchpad( + gemm_iter_fwd_, rnn_utils::scratch_t::key_gemm_iter_fwd); + CHECK(gemm_iter_fwd_->execute(gemm_ctx)); + break; + case gemm_layer_fwd: + init_gemm_nested_scratchpad( + gemm_layer_fwd_, rnn_utils::scratch_t::key_gemm_layer_fwd); + CHECK(gemm_layer_fwd_->execute(gemm_ctx)); + break; + + default: assert(!"unknown gemm_kind"); return status::runtime_error; + } + + return status::success; +} + +//*************** Grid computations strategy: linear ***************// +status_t _ref_rnn_common_t::linear_execution(const grid_ctx_t &grid_struct) { + + dim_t n_layer = grid_struct.rnn.n_layer; + dim_t n_dir = grid_struct.rnn.n_dir; + dim_t n_iter = grid_struct.rnn.n_iter; + + for (dim_t dir = 0; dir < n_dir; dir++) { + for (dim_t j = 0; j < n_layer; j++) { + dim_t lay = j; + for (dim_t i = 0; i < n_iter; i += grid_struct.rnn.iter_loop) { + dim_t iter = i; + const cell_ctx_t c_struct + 
= {grid_struct.engine, grid_struct.ctx, dir, lay, iter, + grid_struct.user_data, grid_struct.workspace, + grid_struct.scratch, grid_struct.rnn}; + CHECK(cell_func(c_struct)); + } + } + } + return status::success; +} +//********* GRID computations strategy: utility functions **********// + +status_t _ref_rnn_common_t::copy_init_layer(const exec_ctx_t &ctx, dim_t batch, + dim_t dhc, dim_t slc, dim_t n_iter, dim_t n_layer, dim_t n_dir, + dim_t n_states, dim_t states_ws_ld, const rnn_utils::workspace_t &ws, + const memory_storage_t &input) const { + + auto max_wg_size_per_dim = calc_local_range(ctx); + + parallel_for(ctx, copy_kernel_, [&](::sycl::handler &cgh) { + auto src_mem_arg + = utils::downcast( + &input) + ->get_in_memory_arg(ctx.stream(), cgh); + auto dst_mem_arg + = utils::downcast( + &ws.states()) + ->get_out_memory_arg(ctx.stream(), cgh); + + ref_rnn_copy_t copy_kernel( + pd()->copy_init_layer_conf_, src_mem_arg, dst_mem_arg); + size_t local_batch = max_wg_size_per_dim; + size_t local_iter = max_wg_size_per_dim; + size_t local_channel = max_wg_size_per_dim; + size_t global_batch = calc_global_range( + static_cast(local_batch), static_cast(batch)); + size_t global_iter = calc_global_range( + static_cast(local_iter), static_cast(n_iter)); + size_t global_channels = calc_global_range( + static_cast(local_channel), static_cast(slc)); + cgh.parallel_for( + ::sycl::nd_range<3>(::sycl::range<3>(global_iter, global_batch, + global_channels), + ::sycl::range<3>( + local_iter, local_batch, local_channel)), + copy_kernel); + }); + + return status::success; +} + +status_t _ref_rnn_common_t::copy_init_iter(const exec_ctx_t &ctx, dim_t batch, + dim_t dhc, dim_t sic, dim_t n_iter, dim_t n_layer, dim_t n_dir, + dim_t n_states, dim_t states_ws_ld, const rnn_utils::workspace_t &ws, + const memory_storage_t &firstit_states) const { + + auto max_wg_size_per_dim = calc_local_range(ctx); + + parallel_for(ctx, copy_kernel_, [&](::sycl::handler &cgh) { + auto src_iter_mem_arg = firstit_states + ? 
utils::downcast( + &firstit_states) + ->get_in_memory_arg(ctx.stream(), cgh) + : xpu::sycl::memory_storage_base_t::empty_in_memory_arg( + ctx.stream(), cgh); + auto ws_mem_arg + = utils::downcast( + &ws.states()) + ->get_out_memory_arg(ctx.stream(), cgh); + + ref_rnn_copy_t copy_kernel( + pd()->copy_init_iter_conf_, src_iter_mem_arg, ws_mem_arg); + size_t local_batch = max_wg_size_per_dim; + size_t local_channel = max_wg_size_per_dim; + size_t local_lay_dir = max_wg_size_per_dim; + size_t global_batch + = calc_global_range(static_cast(max_wg_size_per_dim), + static_cast(batch)); + size_t global_channels = calc_global_range( + static_cast(max_wg_size_per_dim), + std::max(static_cast(sic), static_cast(dhc))); + size_t global_lay_dir + = calc_global_range(static_cast(max_wg_size_per_dim), + static_cast(n_layer * n_dir)); + cgh.parallel_for( + ::sycl::nd_range<3>(::sycl::range<3>(global_lay_dir, + global_batch, global_channels), + ::sycl::range<3>( + local_lay_dir, local_batch, local_channel)), + copy_kernel); + }); + return status::success; +} + +status_t _ref_rnn_common_t::copy_res_layer(const exec_ctx_t &ctx, dim_t batch, + dim_t dhc, dim_t slc, dim_t n_iter, dim_t n_layer, dim_t n_dir, + dim_t n_states, dim_t states_ws_ld, + const memory_storage_t &dst_last_layer, + const rnn_utils::workspace_t &ws) const { + + auto max_wg_size_per_dim = calc_local_range(ctx); + + parallel_for(ctx, copy_kernel_, [&](::sycl::handler &cgh) { + auto ws_mem_arg + = utils::downcast( + &ws.states()) + ->get_in_memory_arg(ctx.stream(), cgh); + auto dst_mem_arg + = utils::downcast( + &dst_last_layer) + ->get_out_memory_arg(ctx.stream(), cgh); + + ref_rnn_copy_t copy_kernel( + pd()->copy_res_layer_conf_, ws_mem_arg, dst_mem_arg); + size_t local_batch = max_wg_size_per_dim; + size_t local_iter = max_wg_size_per_dim; + size_t local_channel = max_wg_size_per_dim; + size_t global_batch + = calc_global_range(static_cast(max_wg_size_per_dim), + static_cast(batch)); + size_t global_iter + = calc_global_range(static_cast(max_wg_size_per_dim), + static_cast(n_iter)); + size_t global_channels + = calc_global_range(static_cast(max_wg_size_per_dim), + static_cast(n_states * dhc)); + cgh.parallel_for( + ::sycl::nd_range<3>(::sycl::range<3>(global_iter, global_batch, + global_channels), + ::sycl::range<3>( + local_iter, local_batch, local_channel)), + copy_kernel); + }); + return status::success; +} + +status_t _ref_rnn_common_t::copy_res_iter(const exec_ctx_t &ctx, dim_t batch, + dim_t dhc, dim_t sic, dim_t n_iter, dim_t n_layer, dim_t n_dir, + dim_t n_states, dim_t states_ws_ld, + const memory_storage_t &dst_last_iter, + const rnn_utils::workspace_t &ws) const { + + auto max_wg_size_per_dim = calc_local_range(ctx); + + parallel_for(ctx, copy_kernel_, [&](::sycl::handler &cgh) { + auto src_iter + = utils::downcast( + &ws.states()) + ->get_in_memory_arg(ctx.stream(), cgh); + auto dst_iter = dst_last_iter + ? 
utils::downcast( + &dst_last_iter) + ->get_out_memory_arg(ctx.stream(), cgh) + : xpu::sycl::memory_storage_base_t::empty_out_memory_arg( + ctx.stream(), cgh); + ref_rnn_copy_t copy_kernel( + pd()->copy_res_iter_conf_, src_iter, dst_iter); + + size_t local_batch = max_wg_size_per_dim; + size_t local_channel = max_wg_size_per_dim; + size_t local_lay_dir = max_wg_size_per_dim; + size_t global_batch + = calc_global_range(static_cast(max_wg_size_per_dim), + static_cast(batch)); + size_t global_channels + = calc_global_range(static_cast(max_wg_size_per_dim), + static_cast(dhc)); + size_t global_lay_dir + = calc_global_range(static_cast(max_wg_size_per_dim), + static_cast(n_layer * n_dir)); + cgh.parallel_for( + ::sycl::nd_range<3>(::sycl::range<3>(global_lay_dir, + global_batch, global_channels), + ::sycl::range<3>( + local_lay_dir, local_batch, local_channel)), + copy_kernel); + }); + + return status::success; +} + +status_t _ref_rnn_common_t::rnn_bias(const exec_ctx_t &ctx, dim_t batch, + dim_t dhc, dim_t iter, dim_t lay, dim_t dir, + const rnn_utils::workspace_t &ws, const rnn_utils::scratch_t &scratch, + const rnn_utils ::user_data_t &user_data) const { + + auto max_wg_size_per_dim = calc_local_range(ctx); + + parallel_for(ctx, bias_kernel_, [&](::sycl::handler &cgh) { + auto src_mem_arg + = utils::downcast( + scratch.gates(0).get()) + ->get_inout_memory_arg(ctx.stream(), cgh); + auto bias_mem_arg + = utils::downcast( + user_data.bias(lay, dir).get()) + ->get_in_memory_arg(ctx.stream(), cgh); + + auto dst_mem_arg + = utils::downcast( + ws.states(lay, dir, iter - 1).get()) + ->get_out_memory_arg(ctx.stream(), cgh); + ref_rnn_bias bias_kernel(pd()->sycl_rnn_bias_conf_t_, src_mem_arg, + bias_mem_arg, dst_mem_arg); + + size_t local_batch = max_wg_size_per_dim; + size_t local_channel = max_wg_size_per_dim; + size_t global_batch + = calc_global_range(static_cast(max_wg_size_per_dim), + static_cast(batch)); + size_t global_channels + = calc_global_range(static_cast(max_wg_size_per_dim), + static_cast(dhc)); + cgh.parallel_for( + ::sycl::nd_range<3>( + ::sycl::range<3>(global_channels, global_batch, 1), + ::sycl::range<3>(local_channel, local_batch, 1)), + bias_kernel); + }); + + return status::success; +} + +// //********************* Execution function *********************// + +status_t _ref_rnn_common_t::execute_(const exec_ctx_t &ctx) const { + + impl::engine_t *engine = ctx.stream()->engine(); + + auto rnn_pd = this->pd(); + + const conf_t &rnn = this->pd()->rnn_conf; + + dim_t n_layer = rnn.n_layer; + dim_t n_dir = rnn.n_dir; + dim_t n_states = rnn.n_states; + dim_t n_iter = rnn.n_iter; + dim_t n_gates = rnn.n_gates; + dim_t n_bias = rnn.n_bias; + dim_t batch = rnn.mb; + dim_t slc = rnn.slc; + dim_t sic = rnn.sic; + dim_t dhc = rnn.dhc; + dim_t dlc = rnn.dlc; + + auto &src_layer_native_ = CTX_IN_STORAGE(DNNL_ARG_SRC_LAYER); + auto &src_iter_native_ = CTX_IN_STORAGE(DNNL_ARG_SRC_ITER); + auto &wei_layer_native_ = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS_LAYER); + auto &wei_iter_native_ = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS_ITER); + auto &bias_native_ = CTX_IN_STORAGE(DNNL_ARG_BIAS); + + auto &dst_last_layer_native_ = CTX_OUT_STORAGE(DNNL_ARG_DST_LAYER); + auto &dst_last_iter_native_ = CTX_OUT_STORAGE(DNNL_ARG_DST_ITER); + + auto scratch_workspace + = ctx.get_scratchpad_grantor().get_memory_storage(key_rnn_space); + auto &workspace_ = rnn.is_training ? 
CTX_OUT_STORAGE(DNNL_ARG_WORKSPACE) + : *scratch_workspace; + const auto &workspace = rnn_utils::workspace_t(workspace_, rnn); + + const auto scratch + = rnn_utils::scratch_t(rnn, ctx.get_scratchpad_grantor()); + + //const rnn_utils::user_data_t user_data(src_layer_native_, wei_layer_native_, + // wei_iter_native_, bias_native_, rnn, pd()->off); + const rnn_utils::user_data_t user_data(wei_layer_native_, + {pd()->weights_md(0)}, wei_iter_native_, {pd()->weights_md(1)}, + bias_native_, {pd()->weights_md(2)}, rnn); + + DPRINT("\n%s\n", "+++++++++++++++"); + DPRINT("%s\n", "+++++++++++++++"); + DPRINT(" n_layer = %lld\n", static_cast(n_layer)); + DPRINT(" n_dir = %lld\n", static_cast(n_dir)); + DPRINT(" n_iter = %lld\n", static_cast(n_iter)); + DPRINT(" n_gates = %lld\n", static_cast(n_gates)); + DPRINT(" n_bias = %lld\n", static_cast(n_bias)); + DPRINT(" n_states = %lld\n", static_cast(n_states)); + DPRINT(" n_weights_layer = %lld\n", static_cast(rnn_pd->SLC())); + DPRINT(" n_weights_iter = %lld\n", static_cast(rnn_pd->SIC())); + DPRINT(" batch = %lld\n", static_cast(batch)); + DPRINT(" slc = %lld\n", static_cast(slc)); + DPRINT(" sic = %lld\n", static_cast(sic)); + DPRINT(" dhc = %lld\n", static_cast(dhc)); + DPRINT(" dlc = %lld\n", static_cast(dlc)); + DPRINT("%s\n", "+++++++++++++++"); + DPRINT(" use_workspace = %s\n", rnn.use_workspace ? "yes" : "no"); + DPRINT("%s\n", "+++++++++++++++"); + DPRINT(" with_bias = %s\n", rnn_pd->with_bias() ? "yes" : "no"); + DPRINT(" with_dst_iter = %s\n", rnn_pd->with_dst_iter() ? "yes" : "no"); + DPRINT("%s\n", "+++++++++++++++"); + + CHECK(copy_init_layer(ctx, batch, dhc, slc, n_iter, n_layer, n_dir, + n_states, rnn.states_ws_ld, workspace, src_layer_native_)); + + CHECK(copy_init_iter(ctx, batch, dhc, sic, n_iter, n_layer, n_dir, n_states, + rnn.states_ws_ld, workspace, src_iter_native_)); + + // run the execution on the grid + const grid_ctx_t &grid_struct { + engine, ctx, user_data, workspace, scratch, pd()->rnn_conf}; + CHECK(this->grid_func(grid_struct)); + + // Finally we copy the results to the result buffers + + CHECK(copy_res_layer(ctx, batch, dhc, slc, n_iter, n_layer, n_dir, n_states, + rnn.states_ws_ld, dst_last_layer_native_, workspace)); + + CHECK(copy_res_iter(ctx, batch, dhc, sic, n_iter, n_layer, n_dir, n_states, + rnn.states_ws_ld, dst_last_iter_native_, workspace)); + + return status::success; +}; + +struct _ref_rnn_common_t; + +} // namespace sycl +} // namespace generic +} // namespace gpu +} // namespace impl +} // namespace dnnl diff --git a/src/gpu/generic/sycl/rnn/ref_rnn.hpp b/src/gpu/generic/sycl/rnn/ref_rnn.hpp new file mode 100644 index 00000000000..8ec720d6f7c --- /dev/null +++ b/src/gpu/generic/sycl/rnn/ref_rnn.hpp @@ -0,0 +1,177 @@ +/******************************************************************************* +* Copyright 2019-2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef GPU_GENERIC_SYCL_RNN_REF_RNN_HPP +#define GPU_GENERIC_SYCL_RNN_REF_RNN_HPP + +#include + +#include "common/c_types_map.hpp" +#include "common/primitive.hpp" +#include "common/primitive_desc_iterator.hpp" +#include "common/utils.hpp" +#include "gpu/generic/sycl/rnn/rnn_utils.hpp" +#include "gpu/generic/sycl/sycl_gpu_primitive.hpp" +#include "gpu/gpu_rnn_pd.hpp" + +#include "gpu/generic/sycl/sycl_gpu_primitive.hpp" +#include "gpu/generic/sycl/sycl_primitive_conf.hpp" + +#include "gpu/generic/sycl/sycl_gpu_kernel.hpp" + +namespace dnnl { +namespace impl { +namespace gpu { +namespace generic { +namespace sycl { + +enum gemm_kind_t { gemm_iter_fwd, gemm_layer_fwd }; + +struct _ref_rnn_common_t : public primitive_t { + using primitive_t::primitive_t; + + using base_pd_t = gpu_rnn_fwd_pd_t; + + struct cell_ctx_t { + impl::engine_t *engine; + const exec_ctx_t &ctx; + dim_t dir; + dim_t lay; + dim_t iter; + const rnn_utils::user_data_t &user_data; + const rnn_utils::workspace_t &workspace; + const rnn_utils::scratch_t &scratch; + rnn_utils::conf_t rnn; + }; + + struct grid_ctx_t { + impl::engine_t *engine; + const exec_ctx_t &ctx; + const rnn_utils::user_data_t &user_data; + const rnn_utils::workspace_t &workspace; + const rnn_utils::scratch_t &scratch; + rnn_utils::conf_t rnn; + }; + + struct pd_t : public base_pd_t { + + using base_pd_t::base_pd_t; + + pd_t(const pd_t &other) = default; + + DECLARE_COMMON_PD_T("ref:any", _ref_rnn_common_t); + + status_t init(impl::engine_t *engine); + + status_t set_default_params(); + + rnn_utils::conf_t rnn_conf = {}; + data_type_t acc_data_t = data_type::undef; + data_type_t src_type = data_type::undef; + data_type_t weights_type = data_type::undef; + + std::shared_ptr vanilla_cell_act_pd_; + std::shared_ptr gemm_iter_fwd_pd_; + std::shared_ptr gemm_layer_fwd_pd_; + + sycl_rnn_copy_conf_t copy_init_layer_conf_; + sycl_rnn_copy_conf_t copy_init_iter_conf_; + sycl_rnn_copy_conf_t copy_res_layer_conf_; + sycl_rnn_copy_conf_t copy_res_iter_conf_; + sycl_rnn_bias_conf_t sycl_rnn_bias_conf_t_; + + private: + void init_scratchpad(dim_t workspace_size) { + using namespace memory_tracking::names; + auto scratchpad = this->scratchpad_registry().registrar(); + scratchpad.book(key_rnn_space, workspace_size, 1); + rnn_utils::scratch_t::book(scratchpad, rnn_conf, + {gemm_iter_fwd_pd_.get(), gemm_layer_fwd_pd_.get()}); + } + }; + + status_t init(impl::engine_t *engine) override; + + status_t execute(const exec_ctx_t &ctx) const override { + return execute_(ctx); + } + +private: + status_t execute_(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + + status_t linear_execution(const grid_ctx_t &grid_struct); + + status_t cell_execution(const cell_ctx_t &cell_struct); + + status_t gemm_primitive(impl::engine_t *engine, const exec_ctx_t &ctx, + std::unique_ptr &a, + std::unique_ptr &b, + std::unique_ptr &c, gemm_kind_t gemm_kind) const; + + status_t copy_init_layer(const exec_ctx_t &ctx, dim_t n_iter, dim_t batch, + dim_t slc, dim_t dhc, dim_t n_layer, dim_t n_dir, dim_t n_states, + dim_t states_ws_ld, const rnn_utils::workspace_t &ws, + const memory_storage_t &input) const; + status_t copy_init_iter(const exec_ctx_t &ctx, dim_t n_layer, dim_t n_dir, + dim_t batch, dim_t sic, dim_t dhc, dim_t n_iter, dim_t n_states, + dim_t states_ws_ld, const rnn_utils::workspace_t &ws, + const memory_storage_t &firstit_states) const; + status_t 
copy_res_layer(const exec_ctx_t &ctx, dim_t n_iter, dim_t batch, + dim_t slc, dim_t dhc, dim_t n_layer, dim_t n_dir, dim_t n_states, + dim_t states_ws_ld, const memory_storage_t &dst_last_layer, + const rnn_utils::workspace_t &ws) const; + status_t copy_res_iter(const exec_ctx_t &ctx, dim_t n_layer, dim_t n_dir, + dim_t batch, dim_t sic, dim_t dhc, dim_t n_iter, dim_t n_states, + dim_t states_ws_ld, const memory_storage_t &dst_last_iter, + const rnn_utils::workspace_t &ws) const; + status_t rnn_bias(const exec_ctx_t &ctx, dim_t batch, dim_t dhc, dim_t iter, + dim_t lay, dim_t dir, const rnn_utils::workspace_t &ws, + const rnn_utils::scratch_t &scratch, + const rnn_utils ::user_data_t &user_data) const; + + // ptrs to GEMM primitives + std::shared_ptr gemm_layer_fwd_; + std::shared_ptr gemm_iter_fwd_; + + // offset variables set in workspace and used in offset calculations for + // grid & cell execution and fwd & bwd kernel macros + dim_t ws_gates_offset_ = 0; + dim_t ws_states_offset_ = 0; + dim_t ws_c_states_offset_ = 0; + dim_t ws_grid_comp_offset_ = 0; + dim_t ws_bias_offset_ = 0; + + // ptrs for storing weight offsets which are pre-calculated in + // in grid execution as weights_*_assing_func + std::vector wei_layer_offsets; + std::vector wei_iter_offsets; + + std::function cell_func; + std::function grid_func; + + kernel_t copy_kernel_; + kernel_t bias_kernel_; +}; + +using ref_rnn_fwd_t = _ref_rnn_common_t; + +} // namespace sycl +} // namespace generic +} // namespace gpu +} // namespace impl +} // namespace dnnl +#endif diff --git a/src/gpu/generic/sycl/rnn/rnn_kernels.hpp b/src/gpu/generic/sycl/rnn/rnn_kernels.hpp new file mode 100644 index 00000000000..3cba4ec7d78 --- /dev/null +++ b/src/gpu/generic/sycl/rnn/rnn_kernels.hpp @@ -0,0 +1,177 @@ +/******************************************************************************* +* Copyright 2023-2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef SRC_GPU_GENERIC_SYCL_RNN_RNN_KERNELS_HPP +#define SRC_GPU_GENERIC_SYCL_RNN_RNN_KERNELS_HPP + +#include "common/c_types_map.hpp" +#include "gpu/generic/sycl/sycl_io_helper.hpp" +#include "gpu/generic/sycl/sycl_math_utils.hpp" +#include "gpu/generic/sycl/sycl_primitive_conf.hpp" +#include "xpu/sycl/types.hpp" + +namespace dnnl { +namespace impl { +namespace gpu { +namespace generic { +namespace sycl { + +inline int off_ker_bias(int dhc, int i0, int i1, int n_gates) { + return i0 * dhc + i1; +} + +inline int cell_ws_state(int states_ws_ld, int i, int j) { + return i * states_ws_ld + j; +} + +inline int cell_scratch_mem( + int scratch_gates_ld, int dhc, int i, int n, int j) { + return i * scratch_gates_ld + n * dhc + j; +} + +struct ref_rnn_copy_t { + ref_rnn_copy_t(const sycl_rnn_copy_conf_t &conf, + const xpu::sycl::in_memory_arg_t &src, + xpu::sycl::out_memory_arg_t &dst) + : src_ {src}, dst_ {dst}, conf_ {conf} {} + + void operator()(::sycl::nd_item<3> item) const { + const dim_t tl = item.get_global_id(0) / conf_.n_dir; // timestep/layer + const dim_t dir = item.get_global_id(0) % conf_.n_dir; // direction + const dim_t n = item.get_global_id(1); // batch + const dim_t c = item.get_global_id(2); // channel + + if (dir >= conf_.n_dir || n >= conf_.batch || c >= conf_.range) return; + + dim_t src_offset = 0; + dim_t dst_offset = 0; + if (conf_.layer) { // layer + if (tl >= conf_.n_iter) return; + if (conf_.to_state) { // init + src_offset = conf_.src_md.off(tl, n, c); + dst_offset = conf_.dst_md.off(0, dir, tl, n, c); + } else { // res + src_offset = conf_.src_md.off(conf_.n_layer, dir, tl, n, c); + dst_offset = conf_.dst_md.off(tl, n, dir * conf_.range + c); + } + } else { // iter + if (tl >= conf_.n_layer) return; + if (conf_.to_state) { // init + src_offset = conf_.src_md.off(tl, dir, n, c); + dst_offset = conf_.dst_md.off(tl, dir, conf_.n_iter, n, c); + } else { // res + src_offset + = conf_.src_md.off(tl + 1, dir, conf_.n_iter - 1, n, c); + dst_offset = conf_.dst_md.off(tl, dir, n, c); + } + } + if (src_ptr()) { + auto src = load_float_value( + src_md().data_type(), src_ptr(), src_offset); + if (dst_ptr()) { + store_float_value( + src_md().data_type(), src, dst_ptr(), dst_offset); + } + } else { + if (dst_ptr()) { + store_float_value( + src_md().data_type(), 0.0f, dst_ptr(), dst_offset); + } + } + } + + xpu::sycl::in_memory_arg_t src_; + xpu::sycl::out_memory_arg_t dst_; + sycl_rnn_copy_conf_t conf_; + + const xpu::sycl::md_t &src_md() const { return conf_.src_md; } + void *src_ptr() const { return src_.get_pointer(); } + void *dst_ptr() const { return dst_.get_pointer(); } +}; + +struct ref_rnn_bias { + ref_rnn_bias(const sycl_rnn_bias_conf_t &conf, + const xpu::sycl::inout_memory_arg_t &src_base, + const xpu::sycl::in_memory_arg_t &bias, + const xpu::sycl::out_memory_arg_t &dst_base) + : src_ {src_base}, bias_ {bias}, dst_ {dst_base}, conf_ {conf} {} + void operator()(::sycl::nd_item<3> item) const { + + const int b = item.get_global_id(1); + const int c = item.get_global_id(0); + + if (b >= conf_.batch || c >= conf_.dhc) return; + + auto src = src_ptr(); + auto bias = bias_ptr(); + auto dst = dst_ptr(); + + auto src_offset = src_data_offset(b, c); + auto bias_offset = bias_data_offset(b, c); + auto dst_offset = dst_data_offset(b, c); + + auto src_val + = load_float_value(conf_.dst_md.data_type(), src, src_offset); + auto bias_val = load_float_value(conf_.bias_type, bias, bias_offset); + + 
auto g = compute_gates(src_val, bias_val); + + store_float_value(conf_.dst_md.data_type(), g, dst, dst_offset); + store_float_value(conf_.dst_md.data_type(), g, src, src_offset); + } + + inline dim_t src_data_offset(int b, int c) const { + return cell_scratch_mem(conf_.gates_ws_ld, conf_.dhc, b, 0, c); + } + + inline dim_t bias_data_offset(int b, int c) const { + return off_ker_bias(conf_.dhc, 0, c, 0); + } + + inline dim_t dst_data_offset(int b, int c) const { + return cell_ws_state(conf_.states_ws_ld, b, c); + } + + float compute_gates(float in_val, float bias_val) const { + switch (conf_.activation_kind) { + case alg_kind::eltwise_relu: + return (float)(math::relu_fwd( + (float)(in_val + bias_val), conf_.alpha)); + case alg_kind::eltwise_tanh: + return (float)(math::tanh_fwd((float)(in_val + bias_val))); + case alg_kind::eltwise_logistic: + return (float)(math::logistic_fwd((float)(in_val + bias_val))); + default: return 0; + } + } + + void *src_ptr() const { return src_.get_pointer(); } + void *dst_ptr() const { return dst_.get_pointer(); } + void *bias_ptr() const { return bias_.get_pointer(); } + + xpu::sycl::inout_memory_arg_t src_; + xpu::sycl::in_memory_arg_t bias_; + xpu::sycl::out_memory_arg_t dst_; + sycl_rnn_bias_conf_t conf_; +}; + +} // namespace sycl +} // namespace generic +} // namespace gpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/gpu/generic/sycl/rnn/rnn_utils.cpp b/src/gpu/generic/sycl/rnn/rnn_utils.cpp new file mode 100644 index 00000000000..b6663f22465 --- /dev/null +++ b/src/gpu/generic/sycl/rnn/rnn_utils.cpp @@ -0,0 +1,202 @@ +/******************************************************************************* +* Copyright 2019-2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "gpu/generic/sycl/rnn/rnn_utils.hpp" + +#include "common/c_types_map.hpp" +#include "gpu/intel/utils.hpp" + +namespace dnnl { +namespace impl { +namespace gpu { +namespace generic { +namespace sycl { + +using namespace dnnl::impl::utils; +using namespace prop_kind; +using namespace data_type; + +void rnn_utils::init_rnn_conf( + conf_t &rnn, const rnn_pd_t *rnn_pd, data_type_t acc_data_t) { + + rnn = utils::zero(); + rnn.is_fwd = utils::one_of(rnn_pd->desc()->prop_kind, + prop_kind::forward_training, prop_kind::forward_inference); + rnn.is_training = utils::one_of(rnn_pd->desc()->prop_kind, + prop_kind::forward_training, prop_kind::backward); + + rnn.aux_data_type + = acc_data_t == data_type::f16 ? data_type::f16 : data_type::f32; + + rnn.acc_data_type = acc_data_t; + + rnn.wei_layer_type = rnn_pd->weights_md(0)->data_type; + rnn.wei_iter_type = rnn_pd->weights_md(1)->data_type; + + rnn.n_layer = rnn_pd->weights_md(0)->dims[0]; + rnn.n_iter = rnn_pd->src_md(0)->dims[0]; + rnn.n_dir = rnn_pd->weights_md(0)->dims[1]; + rnn.n_gates = rnn_pd->weights_md(0)->dims[3]; + rnn.n_states = rnn_pd->desc()->cell_kind == dnnl_vanilla_lstm ? 
2 : 1; + rnn.n_bias = rnn.n_gates + 1; + rnn.mb = rnn_pd->src_md(0)->dims[1]; + rnn.sic = rnn_pd->weights_md(1)->dims[2]; + rnn.slc = rnn_pd->weights_md(0)->dims[2]; + rnn.dhc = rnn_pd->weights_md(0)->dims[4]; + rnn.dlc = rnn_pd->dst_md(0)->dims[2]; + + rnn.gates_ld = rnn.dhc * rnn.n_gates; + + rnn.n_parts_bias = 1; + rnn.parts_bias[0] = rnn.n_bias; + rnn.parts_bias[1] = 0; + rnn.iter_loop = 1; + + rnn.use_workspace = rnn.is_training; + + rnn.src_data_type = rnn_pd->src_md(0)->data_type; + rnn.input_data_type = rnn_pd->src_md(1)->data_type; + rnn.bias_data_type = rnn_pd->weights_md(2)->data_type; + rnn.dst_data_type = rnn_pd->dst_md(0)->data_type; + rnn.output_data_type = rnn_pd->dst_md(1)->data_type; + + // Assign types for optional parameters for improved kernel reuse. + if (rnn.input_data_type == data_type::undef) + rnn.input_data_type = rnn.src_data_type; + if (rnn.output_data_type == data_type::undef) + rnn.output_data_type = rnn.dst_data_type; +} + +void rnn_utils::set_rnn_conf(conf_t &rnn, const rnn_desc_t &rd) { + + const bool is_fwd = rnn.is_fwd; + + dim_t aux_elsz + = static_cast(types::data_type_size(rnn.aux_data_type)); + rnn.ws_states_elsz = types::data_type_size(rnn.src_data_type); + + rnn.scratch_gates_elsz = types::data_type_size(rnn.acc_data_type); + + // Set workspace sizes to store: + // states to compute a pass + // intermediate results from the gates + rnn.states_ws_ld = nstl::max(rnn.slc, nstl::max(rnn.sic, rnn.dhc)); + rnn.gates_ws_ld = rnn.gates_ld; + rnn.scratch_gates_ld = rnn.gates_ld; + + rnn.ws_states_cell_size = rnn.mb * rnn.states_ws_ld * rnn.ws_states_elsz; + rnn.ws_states_size = (rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) + * rnn.ws_states_cell_size; + + rnn.ws_gates_cell_size = rnn.mb * rnn.gates_ws_ld * aux_elsz; + rnn.ws_gates_size = rnn.ws_gates_cell_size; + rnn.scratch_gates_size + = rnn.mb * rnn.scratch_gates_ld * rnn.scratch_gates_elsz; + + rnn.ws_bias_size + = rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dhc * aux_elsz; + + // For intermediate step in post-gemm fwd lbr gru + rnn.scratch_cell_size = [&]() { + if (is_fwd) { + return rnn.mb * rnn.scratch_gates_ld * rnn.scratch_gates_elsz; + } else { + return static_cast(0); + } + }(); + + // Used for storing the intermediate value from fwd pass in training lbr gru + rnn.ws_per_cell = rnn.mb * rnn.dhc * aux_elsz; + + set_workspace_offsets(rnn, rnn.ws_gates_offset, rnn.ws_states_offset); +} + +dim_t rnn_utils::set_workspace_offsets( + const conf_t &rnn, dim_t &ws_gates_offset, dim_t &ws_states_offset) { + + const dim_t page_size = 4096; + dim_t current_offset = 0; + +#define register_space(a) \ + do { \ + current_offset = utils::rnd_up(current_offset, page_size); \ + CONCAT2(a, _offset) = current_offset; \ + current_offset += rnn.CONCAT2(a, _size); \ + } while (false) + + // Mandatory workspaces: go to workspace if use_workspace, scratchpad + // otherwise assumes the workspace base pointer is page aligned + register_space(ws_states); + register_space(ws_gates); + + return current_offset; +} + +dim_t rnn_utils::get_workspace_size(const conf_t &rnn) { + dim_t ws_gates_offset, ws_states_offset; + return set_workspace_offsets(rnn, ws_gates_offset, ws_states_offset); +} + +status_t rnn_utils::set_good_strides( + memory_desc_t &weights_md, format_tag_t tag) { + auto &strides = weights_md.format_desc.blocking.strides; + auto dims = weights_md.dims; + using namespace format_tag; + + if (tag == ldigo) { + strides[1] = dims[2] * strides[2]; + strides[0] = dims[1] * strides[1]; + } else if (tag == ldgoi) { + 
strides[3] = dims[4] * strides[4]; + strides[1] = dims[3] * strides[3]; + strides[0] = dims[1] * strides[1]; + } else + return status::unimplemented; + + return status::success; +} + +status_t rnn_utils::set_weights_desc( + memory_desc_t &weights_md, const conf_t &rnn) { + using namespace format_tag; + if (weights_md.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(weights_md, rnn.is_fwd ? ldigo : ldgoi)); + + // Adjust strides for good leading dimension in GEMM + CHECK(set_good_strides(weights_md, rnn.is_fwd ? ldigo : ldgoi)); + + return status::success; + } else if (weights_md.format_kind != format_kind::blocked) { + // This implementation only supports blocked memory + return status::unimplemented; + } + return status::success; +} + +const memory_storage_t &rnn_utils::get_storage( + const memory_storage_t *storage) { + return storage ? *storage : memory_storage_t::empty_storage(); +} +const memory_storage_t &rnn_utils::get_storage( + const std::unique_ptr &storage) { + return rnn_utils::get_storage(storage.get()); +} + +} // namespace sycl +} // namespace generic +} // namespace gpu +} // namespace impl +} // namespace dnnl diff --git a/src/gpu/generic/sycl/rnn/rnn_utils.hpp b/src/gpu/generic/sycl/rnn/rnn_utils.hpp new file mode 100644 index 00000000000..9adb62c4ddc --- /dev/null +++ b/src/gpu/generic/sycl/rnn/rnn_utils.hpp @@ -0,0 +1,392 @@ +/******************************************************************************* +* Copyright 2019-2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef GPU_GENERIC_SYCL_RNN_RNN_REF_UTILS_HPP +#define GPU_GENERIC_SYCL_RNN_RNN_REF_UTILS_HPP + +#include "common/c_types_map.hpp" +#include "common/memory_storage.hpp" +#include "common/memory_tracking.hpp" +#include "common/primitive_desc.hpp" +#include "common/stream.hpp" +#include "gpu/generic/sycl/sycl_gpu_primitive.hpp" +#include "gpu/gpu_rnn_pd.hpp" + +inline int calc_4d_off(int i0, int i1, int d1, int i2, int d2, int i3, int d3, + int i4, int d4) { + return ((((i0) * (d1) + (i1)) * (d2) + (i2)) * (d3) + (i3)) * (d4) + (i4); +} + +namespace dnnl { +namespace impl { +namespace gpu { +namespace generic { +namespace sycl { + +template +using strides_t = std::array; + +namespace rnn_utils { + +enum ws_part_t { gates, states, cell, grid, bias }; + +namespace kernel_id { +constexpr size_t copy_init_layer = 0; +constexpr size_t copy_init_iter = 1; +constexpr size_t copy_res_layer = 2; +constexpr size_t copy_res_iter = 3; +constexpr size_t bias_fwd = 4; +constexpr size_t cell_fwd = 5; +} // namespace kernel_id + +struct conf_t { + dim_t n_layer, n_iter, n_dir, n_gates, n_states; + dim_t mb; + dim_t slc, sic, dhc, dlc; + + dim_t gates_ld, gates_ws_ld; + + dim_t n_bias, n_parts_bias, parts_bias[DNNL_RNN_MAX_N_PARTS]; + + dim_t iter_loop; + + dim_t states_ws_ld; + bool is_fwd, is_training; + bool use_workspace; + + // Size of workspace for each tensor in bytes + dim_t ws_states_cell_size, ws_gates_cell_size; + dim_t ws_gates_size, ws_states_size, scratch_cell_size, ws_per_cell, + ws_bias_size; + + dim_t ws_gates_offset; + dim_t ws_states_offset; + dim_t ws_bias_offset; + + // Element size of each workspace part in bytes + dim_t ws_gates_elsz, ws_states_elsz, ws_bias_elsz; + + dim_t n_iter_scratch_gates; + dim_t scratch_gates_size, scratch_gates_elsz, scratch_gates_ld; + dims_t local_ranges; + + data_type_t acc_data_type; + data_type_t aux_data_type; + data_type_t input_data_type; + data_type_t output_data_type; + data_type_t src_data_type; + data_type_t dst_data_type; + data_type_t wei_layer_type; + data_type_t wei_iter_type; + data_type_t bias_data_type; +}; + +dim_t get_good_ld( + dim_t arch_ld, dim_t dim, dim_t sizeof_dt, bool ignore_assoc = false); +void init_rnn_conf( + conf_t &rnn, const rnn_pd_t *rnn_pd, data_type_t acc_data_type); +void set_rnn_conf(conf_t &rnn, const rnn_desc_t &rd); +dim_t set_workspace_offsets( + const conf_t &rnn, dim_t &ws_gates_offset, dim_t &ws_h_state_offset); +dim_t get_workspace_size(const conf_t &rnn); +status_t set_weights_desc(memory_desc_t &weights_md, const conf_t &rnn); +status_t set_good_strides(memory_desc_t &weights_md, format_tag_t tag); +const memory_storage_t &get_storage(const memory_storage_t *storage); +const memory_storage_t &get_storage( + const std::unique_ptr &storage); + +struct data_helper_t { + static dim_t type_size(data_type_t d) { + return static_cast(types::data_type_size(d)); + } +}; + +struct user_data_t : public data_helper_t { + using mst = memory_storage_t; + user_data_t(const mst &wei_layer, const memory_desc_wrapper &wei_layer_wrap, + const mst &wei_iter, const memory_desc_wrapper &wei_iter_wrap, + const mst &bias, const memory_desc_wrapper &bias_wrap, + const conf_t &conf) + : wei_layer_(wei_layer) + , wei_layer_wrap_(wei_layer_wrap) + , wei_iter_(wei_iter) + , wei_iter_wrap_(wei_iter_wrap) + , bias_(bias) + , bias_wrap_(bias_wrap) + , conf_(conf) {} + + const mst &wei_layer() const { return wei_layer_; } + std::unique_ptr wei_layer(dim_t 
lay, dim_t dir) const { + + dim_t t = type_size(conf_.wei_layer_type); + // wei_layer dimension order: layer, dir, src c, gate, dst c + dim_t offset = wei_layer_wrap_.off(lay, dir, 0, 0, 0) * t; + + return wei_layer_.clone_ptr_off(offset); + } + + const mst &wei_iter() const { return wei_iter_; } + std::unique_ptr wei_iter(dim_t lay, dim_t dir) const { + dim_t t = type_size(conf_.wei_iter_type); + // wei_iter dimension order: layer, dir, src c, gate, dst c + dim_t offset = wei_iter_wrap_.off(lay, dir, 0, 0, 0) * t; + + return wei_iter_.clone_ptr_off(offset); + } + + const mst &bias() const { return bias_; } + + std::unique_ptr bias(dim_t lay, dim_t dir) const { + if (bias().data_handle() == nullptr) return {}; + auto t = type_size(conf_.bias_data_type); + // bia dimension order: lay, dir, gates, dhc + auto offset = bias_wrap_.off(lay, dir, 0, 0) * t; + + return bias_.clone_ptr_off(offset); + } + + const mst &wei_layer_; + const memory_desc_wrapper &wei_layer_wrap_; + const mst &wei_iter_; + const memory_desc_wrapper &wei_iter_wrap_; + const mst &bias_; + const memory_desc_wrapper &bias_wrap_; + const conf_t &conf_; +}; + +struct workspace_t : public data_helper_t { + using mst = memory_storage_t; + workspace_t(const mst &ws, const conf_t &conf) + : ws_(ws) + , conf_(conf) + , gates_(conf.ws_gates_size > 0 ? ws.clone() : nullptr) + , gates_strides_ {0} + , states_(conf.ws_states_size > 0 ? ws.clone() : nullptr) + , states_strides_ {0} + , bias_(conf.ws_bias_size > 0 ? ws.clone() : nullptr) { + if (gates_) { + gates_->set_offset(gates_->offset() + conf.ws_gates_offset); + const int n_b = conf_.mb; + const int n_tb = conf_.n_iter * n_b; + const int n_dtb = conf_.n_dir * n_tb; + gates_strides_ + = {n_dtb * conf_.gates_ws_ld, n_tb * conf_.gates_ws_ld, + n_b * conf_.gates_ws_ld, conf_.gates_ws_ld}; + } + if (states_) { + states_->set_offset(states_->offset() + conf.ws_states_offset); + const int n_b = conf_.mb; + const int n_tb = (conf_.n_iter + 1) * n_b; + const int n_dtb = conf_.n_dir * n_tb; + states_strides_ = {n_dtb * conf_.states_ws_ld, + n_tb * conf_.states_ws_ld, n_b * conf_.states_ws_ld, 1}; + } + bias_->set_offset(bias_->offset() + conf.ws_bias_offset); + } + + template + static dim_t get_offset(const strides_t &strides, + const std::array &dims) { + dim_t offset = 0; + for (size_t i = 0; i < ndims; i++) { + offset += strides[i] * dims[i]; + } + return offset; + } + + dim_t calc_off_ws_state( + dim_t i0_, dim_t i1, dim_t i2_, dim_t i3, dim_t i4) const { + //lay,dir,time + // Logical index into workspace grid + auto i0 = i0_ + 1; + auto i2 = i2_ + 1; + + assert(i0 >= 0); + + return calc_4d_off(i0, i1, conf_.n_dir, i2, conf_.n_iter + 1, i3, + conf_.mb, i4, conf_.states_ws_ld); + } + + dim_t calc_off_ws_c_state( + dim_t i0_, dim_t i1, dim_t i2_, dim_t i3, dim_t i4) const { + // Logical index into workspace grid + auto i0 = i0_; + auto i2 = i2_ + 1; + + assert(i0 >= 0); + + return calc_4d_off(i0, i1, conf_.n_dir, i2, conf_.n_iter + 1, i3, + conf_.mb, i4, conf_.states_ws_ld); + } + + dim_t calc_off_ws_grid_offset( + dim_t i0, dim_t i1, dim_t i2, dim_t i3, dim_t i4) const { + return calc_4d_off(i0, i1, conf_.n_dir, i2, conf_.n_iter, i3, conf_.mb, + i4, conf_.dhc); + } + + const mst &ws() const { return ws_; } + const mst &gates() const { return get_storage(gates_); } + const mst &states() const { return get_storage(states_); } + + std::unique_ptr states(dim_t layer, dim_t dir, dim_t time) const { + if (!states_) return {}; + + auto i0 = layer + 1; + auto i2 = time + 1; + auto off_ = 
get_offset(states_strides(), {i0, dir, i2, 0}) + * conf_.ws_states_elsz; + return states().clone_ptr_off(off_); + } + + const strides_t<4> &states_strides() const { return states_strides_; } + + std::unique_ptr states_range(dim_t layer_start, dim_t layer_end, + dim_t dir_start, dim_t dir_end, dim_t time_start, + dim_t time_end) const { + auto off_start + = calc_off_ws_state(layer_start, dir_start, time_start, 0, 0) + * conf_.ws_states_elsz; + return states().clone_ptr_off(off_start); + } + + std::unique_ptr gates( + dim_t layer, dim_t dir, dim_t time, dim_t mb = 0) const { + auto off = get_offset(gates_strides(), {layer, dir, time, mb}) + * type_size(conf_.aux_data_type); + return gates().clone_ptr_off(off); + } + const strides_t<4> &gates_strides() const { return gates_strides_; } + + std::unique_ptr grid_comp(dim_t layer, dim_t dir, dim_t time) const { + if (!grid_comp_) return {}; + + auto off = calc_off_ws_grid_offset(layer, dir, time, 0, 0) + * type_size(conf_.aux_data_type); + + return grid_comp().clone_ptr_off(off); + } + + const mst &c_states() const { return get_storage(c_states_); } + const mst &bias() const { return get_storage(bias_); } + const mst &grid_comp() const { return get_storage(grid_comp_); } + +private: + const mst &ws_; + const conf_t &conf_; + std::unique_ptr gates_; + strides_t<4> gates_strides_; + std::unique_ptr states_; + strides_t<4> states_strides_; + std::unique_ptr c_states_; + std::unique_ptr bias_; + std::unique_ptr grid_comp_; +}; + +struct scratch_t : public data_helper_t { + using mst = memory_storage_t; + + enum { + key_gemm_iter_fwd = memory_tracking::names::key_nested_multiple, + key_gemm_layer_fwd, + }; + + scratch_t(const conf_t &conf, const memory_tracking::grantor_t &scratchpad) + : conf_(conf) { + using namespace memory_tracking::names; + gates_ = scratchpad.get_memory_storage(key_rnn_gates); + cell_ = scratchpad.get_memory_storage(key_rnn_cell); + } + + struct gemm_pds { + const primitive_desc_t *iter_fwd_pd; + const primitive_desc_t *layer_fwd_pd; + }; + + static void book(memory_tracking::registrar_t &scratchpad, + const conf_t &rnn_conf, const gemm_pds &gemms) { + using namespace memory_tracking::names; + if (rnn_conf.scratch_gates_size > 0) + scratchpad.book(key_rnn_gates, rnn_conf.scratch_gates_size, 1); + scratchpad.book(key_rnn_cell, rnn_conf.scratch_cell_size, 1); + // book scratchpad for nested primitives + if (gemms.layer_fwd_pd) { + scratchpad.book(key_gemm_layer_fwd, + gemms.layer_fwd_pd->scratchpad_registry()); + } + if (gemms.iter_fwd_pd) { + scratchpad.book(key_gemm_iter_fwd, + gemms.iter_fwd_pd->scratchpad_registry()); + } + } + + dim_t calc_off_gates(dim_t iter) const { + return conf_.n_iter_scratch_gates != 1 + ? iter * conf_.mb * conf_.scratch_gates_ld + : 0; + }; + + const mst *gates() const { + assert(gates_); + return (conf_.is_fwd) ? (gates_ ? 
gates_.get() : diff_gates_.get()) + : nullptr; + } + std::unique_ptr gates(dim_t iter) const { + auto g = gates(); + if (g == nullptr) return {}; + + auto off = calc_off_gates(iter) * conf_.scratch_gates_elsz; + return g->clone_ptr_off(off); + } + + const mst *cell() const { return cell_.get(); } + + const mst *diff_ht() const { return diff_ht_.get(); } + +private: + const conf_t &conf_; + + std::unique_ptr gates_; + std::unique_ptr diff_gates_; + std::unique_ptr cell_; + std::unique_ptr diff_states_; + std::unique_ptr diff_ht_; +}; + +inline size_t calc_global_range(const size_t lc_range, size_t gl_range) { + return ((gl_range + (lc_range - 1)) / lc_range) * lc_range; +} + +inline size_t calc_local_range(const exec_ctx_t &ctx) { + // Check the device for the supported max worgroup size + // TODO: 256 is an arbitrary ceiling to ensure we do not use too + // many registers, can be improved in future. + return std::floor(std::cbrt(std::min(256, + static_cast(ctx.stream()->impl()) + ->queue() + ->get_device() + .get_info<::sycl::info::device::max_work_group_size>()))); +} + +} // namespace rnn_utils + +} // namespace sycl +} // namespace generic +} // namespace gpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/gpu/generic/sycl/sycl_post_ops.hpp b/src/gpu/generic/sycl/sycl_post_ops.hpp index 5e95cdbfb50..c11b5147073 100644 --- a/src/gpu/generic/sycl/sycl_post_ops.hpp +++ b/src/gpu/generic/sycl/sycl_post_ops.hpp @@ -19,6 +19,7 @@ #include "common/c_types_map.hpp" #include "common/primitive_attr.hpp" +#include "common/primitive_exec_types.hpp" #include "common/utils.hpp" #include "gpu/generic/sycl/sycl_io_helper.hpp" #include "gpu/generic/sycl/sycl_math_utils.hpp" diff --git a/src/gpu/generic/sycl/sycl_primitive_conf.hpp b/src/gpu/generic/sycl/sycl_primitive_conf.hpp index ff9a921cbe8..fe2264cd0b7 100644 --- a/src/gpu/generic/sycl/sycl_primitive_conf.hpp +++ b/src/gpu/generic/sycl/sycl_primitive_conf.hpp @@ -460,6 +460,61 @@ struct sycl_reduction_conf_t { static constexpr int local_col_wg = 8; }; +struct sycl_rnn_copy_conf_t { + xpu::sycl::md_t src_md; + xpu::sycl::md_t dst_md; + dim_t range; + dim_t n_dir; + dim_t n_layer; + dim_t n_iter; + dim_t batch; + dim_t states_ws_ld; + bool layer; + bool to_state; +}; + +struct sycl_rnn_bias_conf_t { + xpu::sycl::md_t dst_md; + data_type_t bias_type; + dim_t batch; + dim_t dhc; + dim_t gates_ws_ld; + dim_t states_ws_ld; + dnnl_alg_kind_t activation_kind; + float alpha; +}; + +template +using strides_t = std::array; +struct outer_strides_getter_t { + template + operator strides_t() const { + strides_t ret; + assert(static_cast(ndims) >= md.ndims()); + for (int d = ndims - 1; d >= 0; d--) { + // Assumes size 1 dimensions are dense with respect to the neighboring + // dimension so they can be used for size calculations in some layouts + ret[d] = [&]() { + if (d >= md.ndims()) + return static_cast(0); + else if (md.padded_dims()[d] > 1) + return md.strides()[d]; + else if (d == md.ndims() - 1) + return static_cast(1); + else + return ret[d + 1] * md.padded_dims()[d + 1]; + }(); + } + return ret; + } + + const memory_desc_wrapper &md; +}; + +inline outer_strides_getter_t get_outer_strides(const memory_desc_wrapper &md) { + return {md}; +} + CHECK_SYCL_KERNEL_ARG_TYPE(sycl_binary_conf_t); CHECK_SYCL_KERNEL_ARG_TYPE(sycl_prelu_conf_t); CHECK_SYCL_KERNEL_ARG_TYPE(sycl_shuffle_conf_t); @@ -478,7 +533,8 @@ CHECK_SYCL_KERNEL_ARG_TYPE(sycl_convolution_bwd_data_conf_t); CHECK_SYCL_KERNEL_ARG_TYPE(sycl_convolution_bwd_weights_conf_t); 
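// A usage sketch for get_outer_strides() above (illustrative only; `md` here
// stands for any memory_desc_wrapper over a dense, plain-layout tensor):
//
//     strides_t<6> strides = get_outer_strides(md);
//
// The getter converts to whatever strides_t<ndims> it is bound to: dimensions
// past md.ndims() get stride 0, padded dimensions larger than 1 keep their
// real stride, and size-1 dimensions borrow a dense stride from the next
// inner dimension (1 for the innermost) so they stay usable in offset
// arithmetic.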
CHECK_SYCL_KERNEL_ARG_TYPE(sycl_simple_reduction_conf_t); CHECK_SYCL_KERNEL_ARG_TYPE(sycl_reduction_conf_t); - +CHECK_SYCL_KERNEL_ARG_TYPE(sycl_rnn_copy_conf_t); +CHECK_SYCL_KERNEL_ARG_TYPE(sycl_rnn_bias_conf_t); } // namespace sycl } // namespace generic } // namespace gpu diff --git a/src/gpu/gpu_rnn_list.cpp b/src/gpu/gpu_rnn_list.cpp index 7e9526abf61..9055273ff9e 100644 --- a/src/gpu/gpu_rnn_list.cpp +++ b/src/gpu/gpu_rnn_list.cpp @@ -20,6 +20,10 @@ #include "gpu/intel/ocl/rnn/rnn_grid.hpp" #endif +#ifdef GENERIC_SYCL_KERNELS_ENABLED +#include "gpu/generic/sycl/rnn/ref_rnn.hpp" +#endif + namespace dnnl { namespace impl { namespace gpu { @@ -32,6 +36,7 @@ const std::map> impl_list_map REG_RNN_P({ {{forward}, { GPU_INSTANCE_INTEL(intel::ocl::simple_rnn_fwd_t) + GPU_INSTANCE_GENERIC_SYCL(generic::sycl::ref_rnn_fwd_t) nullptr, }}, {{backward}, REG_BWD_PK({ diff --git a/src/xpu/sycl/buffer_memory_storage.cpp b/src/xpu/sycl/buffer_memory_storage.cpp index e10cd90b751..3a2978f8f17 100644 --- a/src/xpu/sycl/buffer_memory_storage.cpp +++ b/src/xpu/sycl/buffer_memory_storage.cpp @@ -39,7 +39,10 @@ memory_arg_t get_memory_arg(const buffer_memory_storage_t *storage, = utils::downcast(stream->impl()); return {sycl_stream_impl->get_dummy_accessor(cgh)}; } - return {storage->buffer().get_access(cgh)}; + ::sycl::id<1> offset(storage->offset()); + ::sycl::range<1> range(storage->buffer().size() - storage->offset()); + + return {storage->buffer().get_access(cgh, range, offset)}; } } // namespace @@ -124,6 +127,16 @@ std::unique_ptr buffer_memory_storage_t::clone() const { storage->buffer_ = buffer_; storage->base_offset_ = base_offset_; + storage->set_offset(offset()); + + return storage; +} + +std::unique_ptr buffer_memory_storage_t::clone_ptr_off( + size_t offset) const { + auto storage = clone(); + storage->set_offset(offset + this->offset()); + return storage; } diff --git a/src/xpu/sycl/buffer_memory_storage.hpp b/src/xpu/sycl/buffer_memory_storage.hpp index 9453cf4a24f..927662b7ab2 100644 --- a/src/xpu/sycl/buffer_memory_storage.hpp +++ b/src/xpu/sycl/buffer_memory_storage.hpp @@ -68,6 +68,8 @@ class buffer_memory_storage_t : public memory_storage_base_t { std::unique_ptr clone() const override; + std::unique_ptr clone_ptr_off(size_t offset) const override; + in_memory_arg_t get_in_memory_arg( stream_t *stream, ::sycl::handler &cgh) const override; out_memory_arg_t get_out_memory_arg( diff --git a/src/xpu/sycl/memory_storage_helper.hpp b/src/xpu/sycl/memory_storage_helper.hpp index 2c3d50f8825..88f735ef5fe 100644 --- a/src/xpu/sycl/memory_storage_helper.hpp +++ b/src/xpu/sycl/memory_storage_helper.hpp @@ -63,17 +63,19 @@ class interop_memory_arg_t { interop_memory_arg_t(memory_storage_t *raw_mem, ::sycl::handler &cgh) { if (!raw_mem || raw_mem->is_null()) { return; } auto *mem = static_cast(raw_mem); + dim_t offset = mem->offset(); switch (mem->memory_kind()) { case sycl::memory_kind::buffer: { auto *buffer_storage = utils::downcast(mem); acc_.emplace(buffer_storage->buffer(), cgh); - offset_ = buffer_storage->base_offset(); + offset_ = buffer_storage->base_offset() + offset; break; } case sycl::memory_kind::usm: { raw_ptr_ = utils::downcast(mem) ->usm_ptr(); + offset_ = offset; break; } default: assert(!"unexpected memory kind"); @@ -107,17 +109,20 @@ class interop_memory_arg_t { ih.get_native_mem(acc_.value())) + offset_); } else { - raw_ptr = raw_ptr_; + raw_ptr = reinterpret_cast( + reinterpret_cast(raw_ptr_) + offset_); } return reinterpret_cast(raw_ptr); } - bool empty() const { return 
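    // clone_ptr_off() records an extra offset on the cloned storage; for SYCL
    // buffers that offset is folded into the accessor offset / base_offset_
    // above, while for USM it is added to the raw pointer in
    // get_native_pointer(), so both memory kinds resolve to the same shifted
    // address.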
!raw_ptr_ && !acc_.has_value(); } + bool empty() const { + return !raw_ptr_ && !acc_.has_value(); + } private: void *raw_ptr_ = nullptr; std::optional<::sycl::accessor> acc_; - size_t offset_; + size_t offset_ = 0; }; } // namespace sycl diff --git a/src/xpu/sycl/types.hpp b/src/xpu/sycl/types.hpp index 81e6d387aef..6caaeec8376 100644 --- a/src/xpu/sycl/types.hpp +++ b/src/xpu/sycl/types.hpp @@ -101,7 +101,8 @@ struct memory_arg_t { if (usm_) return usm_; return const_cast( acc_.template get_multi_ptr<::sycl::access::decorated::no>() - .get()); + .get() + + acc_.get_offset()); } bool empty() const { return empty_; } diff --git a/src/xpu/sycl/usm_memory_storage.hpp b/src/xpu/sycl/usm_memory_storage.hpp index 33da2f432b0..b7c7e9e20b1 100644 --- a/src/xpu/sycl/usm_memory_storage.hpp +++ b/src/xpu/sycl/usm_memory_storage.hpp @@ -111,10 +111,15 @@ class usm_memory_storage_t : public memory_storage_base_t { storage->usm_ptr_ = decltype(usm_ptr_)(usm_ptr_.get(), [](void *) {}); storage->usm_kind_ = usm_kind_; + storage->set_offset(offset()); return storage; } + std::unique_ptr clone_ptr_off(size_t offset) const override { + return get_sub_storage(offset, 0); + } + in_memory_arg_t get_in_memory_arg( stream_t *stream, ::sycl::handler &cgh) const override; out_memory_arg_t get_out_memory_arg( diff --git a/tests/gtests/test_rnn_forward.cpp b/tests/gtests/test_rnn_forward.cpp index f885df10c74..7017712b0d4 100644 --- a/tests/gtests/test_rnn_forward.cpp +++ b/tests/gtests/test_rnn_forward.cpp @@ -169,6 +169,19 @@ class rnn_forward_test_t : public ::testing::TestWithParam { void SetUp() override { auto p = ::testing::TestWithParam::GetParam(); + SKIP_IF_GENERIC(!is_vanilla_rnn, "Unsupported cell type"); + SKIP_IF_GENERIC( + !(p.direction == rnn_direction::unidirectional_left2right), + "Unsupported direction"); + SKIP_IF_GENERIC( + is_lstm || is_gru || is_lbr_gru || is_augru || is_lbr_augru, + "Unsupported cell type"); + SKIP_IF_CUDA(!is_vanilla_rnn, "Unsupported cell type"); + SKIP_IF_CUDA(!(p.direction == rnn_direction::unidirectional_left2right), + "Unsupported direction cuda"); + SKIP_IF_CUDA( + is_lstm || is_gru || is_lbr_gru || is_augru || is_lbr_augru, + "Unsupported cell type"); catch_expected_failures( [&]() { Test(); }, p.expect_to_fail, p.expected_status, false); } @@ -732,6 +745,65 @@ CPU_INSTANTIATE_TEST_SUITE_P(TestRnn, rnn_forward_test_f32, fmt::undef}, test_rnn_sizes_t {3, 1, 5, 1, 4, 4, 4, 4}})); +TEST_P(rnn_forward_test_f32, TestsRnnGPU) {} +GPU_INSTANTIATE_TEST_SUITE_P(TestRnn, rnn_forward_test_f32, + ::testing::Values( + cfg_f32 {PLAIN_RNN(alg::eltwise_tanh), + prop_kind::forward_inference, + dir::unidirectional_left2right, + {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, + fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, + fmt::ldnc}, + test_rnn_sizes_t {1, 1, 10, 16, 100, 100, 100, 100}}, + /* Check for invalid parameters: unsupported unrolling */ + cfg_f32 {PLAIN_RNN(alg::eltwise_tanh), + prop_kind::forward_inference, + dir::unidirectional_left2right, + {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, + fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, + fmt::ldnc}, + test_rnn_sizes_t {2, 1, 10, 16, 200, 100, 100, 100}, + true, dnnl_invalid_arguments}, + cfg_f32 {PLAIN_RNN(alg::eltwise_tanh), + prop_kind::forward_inference, + dir::unidirectional_left2right, + {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, + fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, + fmt::ldnc}, + test_rnn_sizes_t {2, 1, 10, 16, 100, 200, 100, 100}, + true, dnnl_invalid_arguments}, + /* Check for invalid parameters: 
inconsistent dimensions */ + cfg_f32 {PLAIN_RNN(alg::eltwise_tanh), + prop_kind::forward_inference, + dir::unidirectional_left2right, + {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, + fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, + fmt::ldnc}, + test_rnn_sizes_t {2, 1, 10, 16, 100, 100, 50, 100}, + true, dnnl_invalid_arguments}, + /* Check if passing {src,dst}_iter impacts results */ + cfg_f32 {PLAIN_RNN(alg::eltwise_tanh), + prop_kind::forward_inference, + dir::unidirectional_left2right, + {fmt::tnc, fmt::undef, fmt::ldigo, fmt::ldigo, + fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, + fmt::ldnc}, + test_rnn_sizes_t {3, 1, 5, 1, 4, 4, 4, 4}}, + cfg_f32 {PLAIN_RNN(alg::eltwise_tanh), + prop_kind::forward_inference, + dir::unidirectional_left2right, + {fmt::tnc, fmt::ldnc, fmt::ldigo, fmt::ldigo, + fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, + fmt::undef}, + test_rnn_sizes_t {3, 1, 5, 1, 4, 4, 4, 4}}, + cfg_f32 {PLAIN_RNN(alg::eltwise_tanh), + prop_kind::forward_inference, + dir::unidirectional_left2right, + {fmt::tnc, fmt::undef, fmt::ldigo, fmt::ldigo, + fmt::undef, fmt::undef, fmt::ldgo, fmt::tnc, + fmt::undef}, + test_rnn_sizes_t {3, 1, 5, 1, 4, 4, 4, 4}})); + TEST_P(lstm_forward_test_f32, TestsLSTM) {} CPU_INSTANTIATE_TEST_SUITE_P(TestLSTM, lstm_forward_test_f32, ::testing::Values(