Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
michal-miotk committed Aug 1, 2024
1 parent 54ee912 commit dfdd052
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 14 deletions.
6 changes: 0 additions & 6 deletions src/plugins/intel_gpu/include/intel_gpu/primitives/lstm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,12 +280,6 @@ struct lstm_seq : public primitive_base<lstm_seq> {
protected:
std::vector<input_info> get_dependencies() const override {
std::vector<input_info> ret;
/*
if (!cell.empty())
ret.push_back(cell);
*/
//ret.push_back(out1_prim_id);
//ret.push_back(out2_prim_id);
return ret;
}
};
Expand Down
8 changes: 4 additions & 4 deletions src/plugins/intel_gpu/src/graph/lstm_seq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ std::vector<layout> lstm_seq_inst::calc_output_layouts(lstm_seq_node const& node
if (impl_param.desc->output_data_types.size() > 0) {
OPENVINO_ASSERT(static_cast<bool>(impl_param.desc->output_data_types[0]) == false, "Output data type forcing is not supported for lstm_seq_node!");
}
OPENVINO_ASSERT(input_pshape_x.rank().get_length() == 4, "input_layout rank should be 4 on dynamic shape.");

if (input_pshape_x.is_static()) {
OPENVINO_ASSERT(input_pshape_x.rank().get_length() == 4, "input_layout rank should be 4 on static shape.");
}
int lstm_batch_size, lstm_seq_length, lstm_hidden_size;
if (input_pshape_x[input_pshape_x.size() - 3].is_static()) {
lstm_batch_size = input_pshape_x[0].get_length();
Expand Down Expand Up @@ -95,7 +96,6 @@ lstm_seq_inst::typed_primitive_inst(network& network, lstm_seq_node const& node)
"input format",
input_size.format.value,
"expected format",
format::bfyx,
format::fyxb);
format::bfyx);
}
} // namespace cldnn
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@ KERNEL(lstm_seq)(
cur_history_idx = real_seq_length - 1 - i ;
}
hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] = (OUTPUT_TYPE)(gate_output[3]*ACTIVATION_H(cell_state[OUTPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)], ACTIVATION_PARAMS_H));
barrier(CLK_LOCAL_MEM_FENCE);
hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, cur_history_idx, hidden_idx)] = hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)];
} barrier(CLK_LOCAL_MEM_FENCE);
barrier(CLK_LOCAL_MEM_FENCE);
}

//printf("R is %p B is %p ; hidden history %p cell state %p batch %d\n", &R[0], &B[0], &hidden_history[0], &cell_state[0], b);
for(int i=0;i<real_seq_length;i++){
Expand Down
4 changes: 1 addition & 3 deletions src/plugins/intel_gpu/src/plugin/ops/rnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,7 @@ static void CreateLSTMSequenceOp(ProgramBuilder& p, const std::shared_ptr<ov::op
cldnn::input_info recurrent = inputs[5];
cldnn::input_info bias = inputs[6];

if (op->get_input_shape(0).size() != 3 ||
op->get_input_shape(1).size() != 3 ||
op->get_input_shape(2).size() != 3 || op->get_input_shape(3).size() != 1 \
if (op->get_input_shape(2).size() != 3 || op->get_input_shape(3).size() != 1 \
|| op->get_input_shape(4).size() != 3 || op->get_input_shape(5).size() != 3 || op->get_input_shape(6).size() != 2)
OPENVINO_THROW("Wrong input shapes for LSTMSequence op ", op->get_friendly_name());

Expand Down

0 comments on commit dfdd052

Please sign in to comment.