Skip to content

Commit

Permalink
generic: sycl: Vanilla RNN FWD l2r
Browse files Browse the repository at this point in the history
  • Loading branch information
ShanoToni committed Jan 30, 2025
1 parent 2fba675 commit c2392c4
Show file tree
Hide file tree
Showing 17 changed files with 1,896 additions and 7 deletions.
7 changes: 7 additions & 0 deletions src/common/memory_storage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ struct memory_storage_t : public c_compatible {
/** returns shallow copy */
virtual std::unique_ptr<memory_storage_t> clone() const = 0;

/** returns shallow copy with a offset for accessor pointer for buffers
* to prevent use of sub-buffers where possible*/
virtual std::unique_ptr<memory_storage_t> clone_ptr_off(size_t offset) const {
assert(!"not expected");
return nullptr;
}

/** returns true if the pointer associated with the storage is NULL */
bool is_null() const {
void *ptr;
Expand Down
8 changes: 8 additions & 0 deletions src/gpu/generic/sycl/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,11 @@ The implementation supports both forward and backward propagations.

* Supported formats: plain formats with up to 6 dimensions
* Supported data types: `f32`, `bf16`, `f16`, `s8`, `u8`

## RNN

The implementation supports forward propagation and vanilla RNN cell kind.

* Supported formats: `ldigo`, `ldgoi`
* Supported data types: `f32`, `bf16`, `f16`, `s8`, `u8`
* Supported direction: `left2right`
64 changes: 64 additions & 0 deletions src/gpu/generic/sycl/rnn/cell_common.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*******************************************************************************
* Copyright 2019-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

// Common for RNN and LSTM cell execution

#include "gpu/generic/sycl/rnn/ref_rnn.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace generic {
namespace sycl {

using namespace dnnl::impl::utils;
using namespace rnn_utils;

status_t _ref_rnn_common_t::cell_execution(const cell_ctx_t &cell_struct) {

auto cell_layer = cell_struct.workspace.states_range(cell_struct.lay - 1,
cell_struct.lay - 1, cell_struct.dir, cell_struct.dir,
cell_struct.iter - 1, cell_struct.iter);

auto cell_iter = cell_struct.workspace.states_range(cell_struct.lay,
cell_struct.lay, cell_struct.dir, cell_struct.dir,
cell_struct.iter - 2, cell_struct.iter - 1);

auto scratch_gates = cell_struct.scratch.gates(0);

auto wei_layer
= cell_struct.user_data.wei_layer(cell_struct.lay, cell_struct.dir);
auto wei_iter
= cell_struct.user_data.wei_iter(cell_struct.lay, cell_struct.dir);

CHECK(gemm_primitive(cell_struct.engine, cell_struct.ctx, wei_layer,
cell_layer, scratch_gates, gemm_layer_fwd));

CHECK(gemm_primitive(cell_struct.engine, cell_struct.ctx, wei_iter,
cell_iter, scratch_gates, gemm_iter_fwd));

CHECK(rnn_bias(cell_struct.ctx, cell_struct.rnn.mb, cell_struct.rnn.dhc,
cell_struct.iter, cell_struct.lay, cell_struct.dir,
cell_struct.workspace, cell_struct.scratch, cell_struct.user_data));

return status::success;
}

} // namespace sycl
} // namespace generic
} // namespace gpu
} // namespace impl
} // namespace dnnl
Loading

0 comments on commit c2392c4

Please sign in to comment.