Skip to content

Commit

Permalink
fix problem with rnn layer
Browse files Browse the repository at this point in the history
  • Loading branch information
allnes committed Feb 13, 2024
1 parent 4f0012d commit dfb3fa6
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ void ACLScheduler::set_num_threads(unsigned int num_threads) {}

void ACLScheduler::schedule_custom(ICPPKernel *kernel, const Hints &hints, const Window &window, ITensorPack &tensors) {
const Window & max_window = window;
const unsigned int num_iterations =
max_window.num_iterations(hints.split_dimension()) == 1 ? 1 : max_window.num_iterations_total();
const unsigned int num_iterations = max_window.num_iterations(hints.split_dimension());
const auto _num_threads = std::min(num_iterations, static_cast<unsigned int>(parallel_get_num_threads()));

if (num_iterations < 1) {
Expand Down
4 changes: 4 additions & 0 deletions src/plugins/intel_cpu/src/nodes/rnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,10 @@ void RNN::configurePortDataTypes() {
// onednn doesn't have fp16 instance
inDataTypes[xIdx] = outDataTypes[yIdx] = outDataTypes[hoIdx] = inDataTypes[hIdx] = memory::data_type::f32; // required by oneDNN.

// OneDNN unsupported fp16 precision for this layer
if (cell_type == dnnl::algorithm::vanilla_augru && inDataTypes[aIdx] == memory::data_type::f16)
inDataTypes[aIdx] = memory::data_type::f32;

if (outDataTypes[yIdx] == memory::data_type::bf16 && one_of(inDataTypes[xIdx], memory::data_type::s8, memory::data_type::u8))
outDataTypes[yIdx] = memory::data_type::f32; // oneDNN does not support bf16 output precision for quantized rnn primitive yet
}
Expand Down

0 comments on commit dfb3fa6

Please sign in to comment.