@@ -263,24 +263,21 @@ vector<Tensor> SplitOutput(size_t num, size_t dim,
263263 return outputs;
264264};
265265
266- std::vector<std::vector<Tensor>> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const vector<Tensor> &inputs, const Tensor &W) {
267- DataType dtype = inputs.at(0).data_type();
268- auto dev = inputs.at(0).device();
266+ std::vector<Tensor> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W) {
267+ DataType dtype = input.data_type();
268+ auto dev = input.device();
269269
270- CHECK_GT(inputs.size(), 1u + crh.has_cell_);
271- size_t num_x = inputs.size() - crh.has_cell_ - 1;
272- Tensor input = MergeInputs(num_x, inputs);
273270
274271 Shape outshape{input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_};
275272 Tensor output(outshape, dev, dtype);
276273 // LOG(INFO) << "output size " << output.Size();
277- Tensor hx = inputs.at(num_x);
274+
278275 Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
276+ CHECK_EQ(hx.shape(), state_shape);
279277 Tensor hy(state_shape, dev, dtype);
280278
281- Tensor cy, cx;
279+ Tensor cy;
282280 if (crh.has_cell_) {
283- cx = inputs.at(num_x + 1);
284281 cy.ResetLike(hy);
285282 }
286283
@@ -330,39 +327,23 @@ std::vector<std::vector<Tensor>> GpuRNNForwardTraining(const CudnnRNNHandle &crh
330327 },
331328 {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace});
332329
333- auto outputs =
334- SplitOutput(num_x, crh.hidden_size_ * crh.num_directions_, inputs, output);
335- outputs.push_back(hy);
336- if (crh.has_cell_) outputs.push_back(cy);
337-
338- std::vector<Tensor> cache;
339- cache.push_back(input);
340- cache.push_back(output);
341- cache.push_back(hx);
342- cache.push_back(cx);
343- cache.push_back(W);
344-
345- return {outputs, cache};
330+ return {output, hy, cy};
346331};
347332
348- std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const vector<Tensor> &inputs, const Tensor &W) {
349- DataType dtype = inputs.at(0).data_type();
350- auto dev = inputs.at(0).device();
351-
352- CHECK_GT(inputs.size(), 1u + crh.has_cell_);
353- size_t num_x = inputs.size() - crh.has_cell_ - 1;
354- Tensor input = MergeInputs(num_x, inputs);
333+ std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W) {
334+ DataType dtype = input.data_type();
335+ auto dev = input.device();
355336
356337 Shape outshape{input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_};
357338 Tensor output(outshape, dev, dtype);
358339 // LOG(INFO) << "output size " << output.Size();
359- Tensor hx = inputs.at(num_x);
340+
360341 Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
342+ CHECK_EQ(hx.shape(), state_shape);
361343 Tensor hy(state_shape, dev, dtype);
362344
363- Tensor cy, cx;
345+ Tensor cy;
364346 if (crh.has_cell_) {
365- cx = inputs.at(num_x + 1);
366347 cy.ResetLike(hy);
367348 }
368349
@@ -405,15 +386,10 @@ std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const vect
405386 // clang-format on
406387 }, {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace});
407388
408- auto outputs =
409- SplitOutput(num_x, crh.hidden_size_ * crh.num_directions_, inputs, output);
410- outputs.push_back(hy);
411- if (crh.has_cell_) outputs.push_back(cy);
412-
413- return outputs;
389+ return {output, hy, cy};
414390};
415391
416- std::pair<vector<Tensor>, Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector<Tensor> &grads, const vector<Tensor> &cache) {
392+ std::vector<Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const Tensor &dY, const Tensor &dh, const Tensor &dc, const vector<Tensor> &cache) {
417393 const Tensor x = cache[0];
418394 const Tensor y = cache[1];
419395 const Tensor hx = cache[2];
@@ -423,24 +399,24 @@ std::pair<vector<Tensor>, Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, cons
423399 auto dev = y.device();
424400 auto dtype = y.data_type();
425401
426- CHECK_GT(grads.size(), 1u + crh.has_cell_);
427- size_t num_dy = grads.size() - crh.has_cell_ - 1;
428- CHECK_EQ(num_dy, crh.seq_length_);
429- const Tensor dy = MergeInputs(num_dy, grads);
430- CHECK_EQ(dy.Size(), y.Size());
431- const Tensor dhy = grads.at(num_dy);
432- Tensor dcy;
433- if (crh.has_cell_)
434- dcy = grads.at(num_dy + 1);
402+
403+ CHECK_EQ(dY.Size(), y.Size());
404+
435405
436406 Shape xshape{y.Size() * crh.input_size_ / crh.hidden_size_ / crh.num_directions_};
407+ CHECK_EQ(x.shape(), xshape);
437408 Tensor dx(xshape, dev, dtype);
409+
438410 Tensor dw(W.shape(), dev, dtype);
411+
439412 Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
413+ CHECK_EQ(hx.shape(), state_shape);
440414 Tensor dhx(state_shape, dev, dtype);
415+
441416 Tensor dcx;
442417 if (crh.has_cell_)
443418 dcx.ResetLike(dhx);
419+
444420 dw.SetValue(0.0f);
445421 Block *yb = y.block(), *dyb = dy.block(), *dhyb = dhy.block(),
446422 *dcyb = dcy.block(), *xb = x.block(), *cxb = cx.block(),
@@ -483,12 +459,7 @@ std::pair<vector<Tensor>, Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, cons
483459 {yb, dyb, dhyb, dcyb, xb, wb, wspace, rspace},
484460 {dxb, dwb, dhxb, dcxb, wspace, rspace});
485461
486- auto data_grads = SplitOutput(num_dy, crh.input_size_, grads, dx);
487- data_grads.push_back(dhx);
488- if (crh.has_cell_)
489- data_grads.push_back(dcx);
490-
491- return std::make_pair(data_grads, dw);
462+ return {dx, dhx, dcx, dw};
492463};
493464
494465#endif // USE_CUDNN
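
For orientation, here is a minimal caller-side sketch of the refactored interface as it appears in this diff. The tensor names (x, hx, cx, W, dY, dhy, dcy) and the caller-assembled cache are illustrative assumptions, not part of this commit; the cache ordering follows the order the old training path pushed it (input, output, hx, cx, W).

// Hypothetical usage sketch, not part of this commit.
// Forward now takes one merged input tensor plus the hidden/cell states
// and returns {output, hy, cy} instead of per-step outputs.
std::vector<Tensor> outs = GpuRNNForwardTraining(crh, x, hx, cx, W);
Tensor y = outs[0], hy = outs[1], cy = outs[2];

// The backward pass reads a caller-built cache {x, y, hx, cx, W}
// and returns {dx, dhx, dcx, dw} in a single vector.
std::vector<Tensor> cache = {x, y, hx, cx, W};
std::vector<Tensor> grads = GpuRNNBackward(crh, dY, dhy, dcy, cache);
Tensor dx = grads[0], dhx = grads[1], dcx = grads[2], dw = grads[3];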