
Commit c6deada

Author: Ryan Butler (committed)
Commit message: Fixed flatten_parameters() warning for models/tacotron.py
1 parent 94302af · commit c6deada

1 file changed (+18 −0 lines)

models/tacotron.py

Lines changed: 18 additions & 0 deletions
@@ -206,17 +206,30 @@ def forward(self, encoder_seq_proj, query, t):
 class Decoder(nn.Module):
     def __init__(self, n_mels, decoder_dims, lstm_dims):
         super().__init__()
+
+        # List of all rnns to call `flatten_parameters()` on
+        self._to_flatten = []
+
         self.max_r = 20
         self.r = None
         self.generating = False
         self.n_mels = n_mels
         self.prenet = PreNet(n_mels)
         self.attn_net = LSA(decoder_dims)
+
         self.attn_rnn = nn.GRUCell(decoder_dims + decoder_dims // 2, decoder_dims)
+        self._to_flatten.append(self.attn_rnn)
+
         self.rnn_input = nn.Linear(2 * decoder_dims, lstm_dims)
+
         self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
         self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
+        self._to_flatten += [self.res_rnn1, self.res_rnn2]
+
         self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
+
+        # Avoid fragmentation of RNN parameters and associated warning
+        self._flatten_parameters()

     def zoneout(self, prev, current, p=0.1):
         device = next(self.parameters()).device  # Use same device as parameters
@@ -225,6 +238,11 @@ def zoneout(self, prev, current, p=0.1):

     def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
                 hidden_states, cell_states, context_vec, t):
+
+        # Although we `_flatten_parameters()` on init, when using DataParallel
+        # the model gets replicated, making it no longer guaranteed that the
+        # weights are contiguous in GPU memory. Hence, we must call it again
+        self._flatten_parameters()

         # Need this for reshaping mels
         batch_size = encoder_seq.size(0)

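The comment in the second hunk carries the key design point: nn.DataParallel replicates the module onto each GPU for every forward pass, so weights flattened once in __init__ are not guaranteed to stay contiguous on the replicas, and the helper has to be called again at the top of forward(). A self-contained toy illustration of that pattern (ToyDecoder and its dimensions are hypothetical, not part of this commit):

    import torch
    import torch.nn as nn

    class ToyDecoder(nn.Module):
        """Hypothetical module mirroring the pattern from this commit: collect
        the RNNs at construction time and re-flatten at the start of forward()."""
        def __init__(self, dims=32):
            super().__init__()
            self._to_flatten = []
            self.rnn = nn.LSTM(dims, dims, batch_first=True)
            self._to_flatten.append(self.rnn)
            self._flatten_parameters()  # flatten once at init

        def _flatten_parameters(self):
            for rnn in self._to_flatten:
                rnn.flatten_parameters()

        def forward(self, x):
            # DataParallel replicates this module on each call, so the
            # replicas' weights may no longer be contiguous; flatten again
            # before running the RNN.
            self._flatten_parameters()
            out, _ = self.rnn(x)
            return out

    model = ToyDecoder()
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model.cuda())
        x = torch.randn(8, 10, 32).cuda()
    else:
        x = torch.randn(8, 10, 32)
    y = model(x)
    print(y.shape)

On a single device the extra call should be close to a no-op, since the weights are already contiguous after the flattening done in __init__.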