@@ -206,17 +206,30 @@ def forward(self, encoder_seq_proj, query, t):
class Decoder(nn.Module):
    def __init__(self, n_mels, decoder_dims, lstm_dims):
        super().__init__()
+
+        # List of all rnns to call `flatten_parameters()` on
+        self._to_flatten = []
+
        self.max_r = 20
        self.r = None
        self.generating = False
        self.n_mels = n_mels
        self.prenet = PreNet(n_mels)
        self.attn_net = LSA(decoder_dims)
+
        self.attn_rnn = nn.GRUCell(decoder_dims + decoder_dims // 2, decoder_dims)
+        self._to_flatten.append(self.attn_rnn)
+
        self.rnn_input = nn.Linear(2 * decoder_dims, lstm_dims)
+
        self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
        self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
+        self._to_flatten += [self.res_rnn1, self.res_rnn2]
+
        self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
+
+        # Avoid fragmentation of RNN parameters and associated warning
+        self._flatten_parameters()

    def zoneout(self, prev, current, p=0.1):
        device = next(self.parameters()).device  # Use same device as parameters
@@ -225,6 +238,11 @@ def zoneout(self, prev, current, p=0.1):

    def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
                hidden_states, cell_states, context_vec, t):
+
+        # Although we `_flatten_parameters()` on init, when using DataParallel
+        # the model gets replicated, making it no longer guaranteed that the
+        # weights are contiguous in GPU memory. Hence, we must call it again
+        self._flatten_parameters()

        # Need this for reshaping mels
        batch_size = encoder_seq.size(0)
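
The `_flatten_parameters()` helper that the added lines call is defined elsewhere in the file and is not part of this hunk. For context only, below is a minimal, self-contained sketch of the pattern, assuming the helper simply walks `self._to_flatten` and calls `flatten_parameters()` on each module that supports it. Note that `flatten_parameters()` only exists on `nn.RNNBase` subclasses (`nn.RNN`/`nn.GRU`/`nn.LSTM`), so the sketch guards with `hasattr`; the `FlattenDemo`, `rnn_a`, and `rnn_b` names are hypothetical and not from the repository.

import torch
import torch.nn as nn


class FlattenDemo(nn.Module):
    """Minimal sketch of the flatten-parameters pattern shown in the diff.
    Module names here are hypothetical stand-ins, not the Decoder's own."""

    def __init__(self):
        super().__init__()
        # Collect the rnns whose weights should stay contiguous in GPU memory
        self._to_flatten = []

        self.rnn_a = nn.GRU(32, 64, batch_first=True)
        self._to_flatten.append(self.rnn_a)

        self.rnn_b = nn.LSTM(64, 64, batch_first=True)
        self._to_flatten.append(self.rnn_b)

        # Avoid fragmentation of RNN parameters and the associated warning
        self._flatten_parameters()

    def _flatten_parameters(self):
        # `flatten_parameters()` exists only on nn.RNNBase subclasses
        # (nn.RNN / nn.GRU / nn.LSTM), so skip anything that lacks it.
        for m in self._to_flatten:
            if hasattr(m, "flatten_parameters"):
                m.flatten_parameters()

    def forward(self, x):
        # DataParallel replicates the module per device, so the replicas'
        # weights are no longer guaranteed to be contiguous; re-flatten here.
        self._flatten_parameters()
        out, _ = self.rnn_a(x)
        out, _ = self.rnn_b(out)
        return out


if __name__ == "__main__":
    model = FlattenDemo()
    print(model(torch.randn(2, 10, 32)).shape)  # torch.Size([2, 10, 64])

The `hasattr` guard is only there to keep the sketch safe for mixed module lists; the actual helper in the repository may rely on every entry in `_to_flatten` supporting the call.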