
Commit 95b4377

SINGA-386 Implement RNN operation for autograd
- redesign some APIs to adapt to autograd
1 parent: b176cb4

3 files changed (+141, -70 lines)

python/singa/autograd.py

Lines changed: 113 additions & 12 deletions
@@ -943,34 +943,135 @@ class _RNN(Operation):
     def __init__(self, handle):
         self.handle = handle
 
-    def forward(self, X, W):
+    def forward(self, X, h0, c0, W):
+        # X of shape (seq_len, batch, input_size)
+        # h0_c0: (h0, c0) if lstm, else (h0,)
+        # h0, c0 of shape (num_layers * num_directions, batch, hidden_size)
+        if c0 is None:
+            assert self.rnn_mode != 'lstm'
+            c0 = CTensor([])  # CTensor([]) and Tensor cx are the same?
 
         if self.handle.device_id == -1:
             raise NotImplementedError
         else:
             if training:
-                out, self.cache = singa.GpuRNNForwardTraining(
-                    self.handle, X, W)
+                Y, hout, cout = singa.GpuRNNForwardTraining(
+                    self.handle, X, h0, c0, W)
+                self.cache = (X, Y, h0, c0, W)
             else:
-                out = singa.GpuRNNForwardInference(self.handle, X, W)
-        return out
+                Y, hout, cout = singa.GpuRNNForwardInference(
+                    self.handle, X, h0, c0, W)
+
+        # Y of shape (seq_len, batch, hidden_size * num_directions)
+        # hout_cout: (hout, cout) if lstm, else (hout,)
+        # hout, cout of shape (num_layers * num_directions, batch,
+        # hidden_size)
+        outputs = _1dTo3d(Y)
+
+        if self.rnn_mode != 'lstm':
+            return outputs, hout
+        else:
+            return outputs, hout, cout
 
-    def backward(self, dY):
+    def backward(self, dY, dh, dc=CTensor([])):
         assert training is True and hasattr(
             self, 'cache'), 'Please set training as True before do BP. '
 
-        if dY.device().id() != self.handle.device_id:
-            dY.ToDevice(self.inputs[0].device())
+        dY_1d = _3dTo1d(dY)
+
+        if dY_1d.device().id() != self.handle.device_id:
+            dY_1d.ToDevice(self.cache[0].device())
 
         if self.handle.device_id == -1:
             raise NotImplementedError
         else:
-            dX, dW = singa.GpuRNNBackward(self.handle, dY, self.cache)
-            return dX, dW
+            dX_1d, dhout, dcout, dW = singa.GpuRNNBackward(
+                self.handle, dY_1d, dh, dc, self.cache)
 
+            dX = _1dTo3d(dX_1d)
 
-def rnn():
-    pass
+            if self.rnn_mode != 'lstm':
+                return dX, dhout, dW
+            else:
+                return dX, dhout, dcout, dW
+
+
+def rnn(handle, x, h0, c0, W):
+    return _RNN(handle)(x, h0, c0, W)
 
 
 class RNN(Layer):
+
+    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False, rnn_mode='tanh'):
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.bias = bias
+        self.batch_first = batch_first
+        self.dropout = dropout
+        self.bidirectional = bidirectional
+        self.rnn_mode = rnn_mode
+
+        if bias is not True or batch_first is not False:
+            raise NotImplementedError
+
+        mult = 1
+        if self.rnn_mode == 'tanh' or self.rnn_mode == 'relu':
+            mult *= 1
+        elif self.rnn_mode == 'lstm':
+            mult *= 4
+        elif self.rnn_mode == 'gru':
+            mult *= 3
+        else:
+            raise ValueError
+
+        if self.bidirectional:
+            mult *= 2
+
+        W_Size = 0
+        for k in range(num_layers):
+            if k == 0:
+                w_size = self.hidden_size * \
+                    (self.input_size + self.hidden_size + 2)
+            else:
+                w_size = self.hidden_size * \
+                    (self.hidden_size + self.hidden_size + 2)
+            W_Size += mult * w_size
+
+        self.W_Size = W_Size
+        self.W = Tensor(shape=(W_Size,), requires_grad=True, stores_grad=True)
+        self.W.uniform(0.0, 1.0)
+
+    def __call__(self, inputs, h0, c0=None):
+        # inputs of shape (seq_len, batch, input_size)
+        # h0_c0: (h0, c0) if lstm, else (h0,)
+        # h0, c0 of shape (num_layers * num_directions, batch, hidden_size)
+
+        self.device_check(inputs, h0, self.W)
+
+        if self.rnn_mode == 'lstm':
+            assert c0 is not None, 'Please input c0.'
+            self.device_check(h0, c0)
+
+        self.handle = singa.CudnnRNNHandle(inputs.data, *SOME_PARAMETERS*)
+        self.handle.device_id = inputs.device.id()
+
+        X = self._3dTo1d(inputs)
+        outputs = rnn(self.handle, X, h0, c0, self.W)
+        return outputs
+
+    def _3dTo1d(self, inputs):
+        pass
+
+    def _1dTo3d(self, *args):
+        pass
+
+
+class LSTM(RNN):
+
+    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False):
+        super(LSTM, self).__init__(input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional, rnn_mode='lstm')
+
+
+class GRU(RNN):
+
+    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False):
+        super(GRU, self).__init__(input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional, rnn_mode='gru')
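
To make the shape and weight-size conventions used by the new _RNN operation and RNN layer easier to follow, here is a small, self-contained sketch in plain Python (no SINGA imports; rnn_param_count and the example dimensions are illustrative assumptions that mirror the per-layer estimate in RNN.__init__ above, not cuDNN's exact parameter layout):

# Illustrative sketch only: mirrors the sizing logic of RNN.__init__ and the
# tensor shapes documented in _RNN.forward; not part of the SINGA API.

def rnn_param_count(input_size, hidden_size, num_layers=1,
                    rnn_mode='tanh', bidirectional=False):
    # gate multiplier: 1 for tanh/relu, 4 for lstm, 3 for gru
    mult = {'tanh': 1, 'relu': 1, 'lstm': 4, 'gru': 3}[rnn_mode]
    if bidirectional:
        mult *= 2
    total = 0
    for k in range(num_layers):
        in_size = input_size if k == 0 else hidden_size
        # input weights + recurrent weights + two bias vectors per layer
        total += mult * hidden_size * (in_size + hidden_size + 2)
    return total

seq_len, batch, input_size, hidden_size = 5, 3, 8, 16
num_layers, num_directions = 1, 1

X_shape = (seq_len, batch, input_size)                        # layer input
h0_shape = (num_layers * num_directions, batch, hidden_size)  # h0 (and c0 for lstm)
Y_shape = (seq_len, batch, hidden_size * num_directions)      # sequence output

print(rnn_param_count(input_size, hidden_size, rnn_mode='lstm'))  # 1664
print(X_shape, h0_shape, Y_shape)

cuDNN ultimately determines the exact parameter layout once the handle is created, so a count like this is only the rough estimate the layer uses to allocate the flattened self.W tensor.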

src/model/operation/rnn.cc

mode changed from 100644 to 100755
Lines changed: 25 additions & 54 deletions
@@ -263,24 +263,21 @@ vector<Tensor> SplitOutput(size_t num, size_t dim,
   return outputs;
 };
 
-std::vector<std::vector<Tensor>> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const vector<Tensor> &inputs, const Tensor &W) {
-  DataType dtype = inputs.at(0).data_type();
-  auto dev = inputs.at(0).device();
+std::vector<Tensor> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W) {
+  DataType dtype = input.data_type();
+  auto dev = input.device();
 
-  CHECK_GT(inputs.size(), 1u + crh.has_cell_);
-  size_t num_x = inputs.size() - crh.has_cell_ - 1;
-  Tensor input = MergeInputs(num_x, inputs);
 
   Shape outshape{input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_};
   Tensor output(outshape, dev, dtype);
   // LOG(INFO) << "output size " << output.Size();
-  Tensor hx = inputs.at(num_x);
+
   Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
+  CHECK_EQ(hx.shape(), state_shape);
   Tensor hy(state_shape, dev, dtype);
 
-  Tensor cy, cx;
+  Tensor cy;
   if (crh.has_cell_) {
-    cx = inputs.at(num_x + 1);
     cy.ResetLike(hy);
   }
 
@@ -330,39 +327,23 @@ std::vector<std::vector<Tensor>> GpuRNNForwardTraining(const CudnnRNNHandle &crh
   },
   {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace});
 
-  auto outputs =
-      SplitOutput(num_x, crh.hidden_size_ * crh.num_directions_, inputs, output);
-  outputs.push_back(hy);
-  if (crh.has_cell_) outputs.push_back(cy);
-
-  std::vector<Tensor> cache;
-  cache.push_back(input);
-  cache.push_back(output);
-  cache.push_back(hx);
-  cache.push_back(cx);
-  cache.push_back(W);
-
-  return {outputs, cache};
+  return {output, hy, cy};
 };
 
-std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const vector<Tensor> &inputs, const Tensor &W) {
-  DataType dtype = inputs.at(0).data_type();
-  auto dev = inputs.at(0).device();
-
-  CHECK_GT(inputs.size(), 1u + crh.has_cell_);
-  size_t num_x = inputs.size() - crh.has_cell_ - 1;
-  Tensor input = MergeInputs(num_x, inputs);
+std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W) {
+  DataType dtype = input.data_type();
+  auto dev = input.device();
 
   Shape outshape{input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_};
   Tensor output(outshape, dev, dtype);
   // LOG(INFO) << "output size " << output.Size();
-  Tensor hx = inputs.at(num_x);
+
   Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
+  CHECK_EQ(hx.shape(), state_shape);
   Tensor hy(state_shape, dev, dtype);
 
-  Tensor cy, cx;
+  Tensor cy;
   if (crh.has_cell_) {
-    cx = inputs.at(num_x + 1);
     cy.ResetLike(hy);
   }
 
@@ -405,15 +386,10 @@ std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const vect
   // clang-format on
   }, {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace});
 
-  auto outputs =
-      SplitOutput(num_x, crh.hidden_size_ * crh.num_directions_, inputs, output);
-  outputs.push_back(hy);
-  if (crh.has_cell_) outputs.push_back(cy);
-
-  return outputs;
+  return {output, hy, cy};
 };
 
-std::pair<vector<Tensor>, Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector<Tensor> &grads, const vector<Tensor> &cache) {
+std::vector<Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector<Tensor> &dY, const Tensor &dh, const Tensor &dc, const vector<Tensor> &cache) {
   const Tensor x = cache[0];
   const Tensor y = cache[1];
   const Tensor hx = cache[2];
@@ -423,24 +399,24 @@ std::pair<vector<Tensor>, Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, cons
   auto dev = y.device();
   auto dtype = y.data_type();
 
-  CHECK_GT(grads.size(), 1u + crh.has_cell_);
-  size_t num_dy = grads.size() - crh.has_cell_ - 1;
-  CHECK_EQ(num_dy, crh.seq_length_);
-  const Tensor dy = MergeInputs(num_dy, grads);
-  CHECK_EQ(dy.Size(), y.Size());
-  const Tensor dhy = grads.at(num_dy);
-  Tensor dcy;
-  if (crh.has_cell_)
-    dcy = grads.at(num_dy + 1);
+
+  CHECK_EQ(dY.Size(), y.Size());
+
 
   Shape xshape{y.Size() * crh.input_size_ / crh.hidden_size_ / crh.num_directions_};
+  CHECK_EQ(x.shape(), xshape);
   Tensor dx(xshape, dev, dtype);
+
   Tensor dw(W.shape(), dev, dtype);
+
   Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
+  CHECK_EQ(hx.shape(), state_shape);
   Tensor dhx(state_shape, dev, dtype);
+
   Tensor dcx;
   if (crh.has_cell_)
     dcx.ResetLike(dhx);
+
   dw.SetValue(0.0f);
   Block *yb = y.block(), *dyb = dy.block(), *dhyb = dhy.block(),
         *dcyb = dcy.block(), *xb = x.block(), *cxb = cx.block(),
@@ -483,12 +459,7 @@ std::pair<vector<Tensor>, Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, cons
   {yb, dyb, dhyb, dcyb, xb, wb, wspace, rspace},
   {dxb, dwb, dhxb, dcxb, wspace, rspace});
 
-  auto data_grads = SplitOutput(num_dy, crh.input_size_, grads, dx);
-  data_grads.push_back(dhx);
-  if (crh.has_cell_)
-    data_grads.push_back(dcx);
-
-  return std::make_pair(data_grads, dw);
+  return {dx, dhx, dcx, dw};
 };
 
 #endif  // USE_CUDNN
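The 1-D buffer sizes in GpuRNNForwardTraining/GpuRNNForwardInference and GpuRNNBackward follow from a simple proportion: outshape holds input.Size() * hidden_size_ / input_size_ * num_directions_ elements, and xshape inverts the same relation. A quick plain-Python sanity check of that arithmetic (the concrete dimensions are made up for illustration):

# Illustrative check of the flattened-buffer arithmetic used above.
seq_len, batch, input_size, hidden_size, num_directions = 5, 3, 8, 16, 1

x_size = seq_len * batch * input_size                           # input.Size() after merging
y_size = x_size * hidden_size // input_size * num_directions    # outshape in GpuRNNForward*
dx_size = y_size * input_size // hidden_size // num_directions  # xshape in GpuRNNBackward

assert y_size == seq_len * batch * hidden_size * num_directions
assert dx_size == x_size
print(x_size, y_size, dx_size)  # 120 240 120
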
src/model/operation/rnn.h

mode changed from 100644 to 100755
Lines changed: 3 additions & 4 deletions
@@ -69,18 +69,17 @@ class CudnnRNNHandle: public RNNHandle {
   Tensor reserve_space_;
   Tensor dropout_state_;
 };
-
 Tensor MergeInputs(size_t num, const vector<Tensor> &in);
 
 vector<Tensor> SplitOutput(size_t num, size_t dim,
                            const vector<Tensor> &in,
                            const Tensor output);
 
-std::vector<std::vector<Tensor>> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const vector<Tensor> &inputs, const Tensor &W);
+std::vector<Tensor> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W);
 
-std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const vector<Tensor> &inputs, const Tensor &W);
+std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W);
 
-std::pair<vector<Tensor>, Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector<Tensor> &grads, const vector<Tensor> &cache);
+std::vector<Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector<Tensor> &dY, const Tensor &dh, const Tensor &dc, const vector<Tensor> &cache);
 
 #endif  // USE_CUDNN
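MergeInputs and SplitOutput stay declared in the header even though the redesigned entry points now take a pre-merged 1-D tensor. Conceptually they pack a list of per-step tensors into one contiguous buffer and slice it back; the following is a rough, list-based Python illustration of that idea (an assumption about intent, not the SINGA Tensor implementation, and SplitOutput's extra `in` argument is omitted):

# Rough illustration only of the merge/split idea behind MergeInputs/SplitOutput.

def merge_inputs(num, tensors):
    # concatenate the first `num` flattened tensors into one 1-D buffer
    merged = []
    for t in tensors[:num]:
        merged.extend(t)
    return merged

def split_output(num, dim, merged):
    # slice the 1-D buffer back into `num` chunks of `dim` elements each
    return [merged[i * dim:(i + 1) * dim] for i in range(num)]

steps = [[1, 2], [3, 4], [5, 6]]   # three per-step tensors, dim 2
flat = merge_inputs(3, steps)      # [1, 2, 3, 4, 5, 6]
assert split_output(3, 2, flat) == steps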