@@ -54,6 +54,9 @@ class CBHG(nn.Module):
    def __init__(self, K, in_channels, channels, proj_channels, num_highways):
        super().__init__()

+        # List of all rnns to call `flatten_parameters()` on
+        self._to_flatten = []
+
        self.bank_kernels = [i for i in range(1, K + 1)]
        self.conv1d_bank = nn.ModuleList()
        for k in self.bank_kernels:
@@ -78,8 +81,16 @@ def __init__(self, K, in_channels, channels, proj_channels, num_highways):
            self.highways.append(hn)

        self.rnn = nn.GRU(channels, channels, batch_first=True, bidirectional=True)
+        self._to_flatten.append(self.rnn)
+
+        # Avoid fragmentation of RNN parameters and associated warning
+        self._flatten_parameters()

    def forward(self, x):
+        # Although we `_flatten_parameters()` on init, when using DataParallel
+        # the model gets replicated, making it no longer guaranteed that the
+        # weights are contiguous in GPU memory. Hence, we must call it again
+        self._flatten_parameters()

        # Save these for later
        residual = x
@@ -114,6 +125,10 @@ def forward(self, x):
        x, _ = self.rnn(x)
        return x

+    def _flatten_parameters(self):
+        """Calls `flatten_parameters` on all the rnns used by the CBHG. Used
+        to improve efficiency and avoid PyTorch yelling at us."""
+        [m.flatten_parameters() for m in self._to_flatten]

class PreNet(nn.Module):
    def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
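The three additions above form one pattern: the CBHG collects its RNNs in `self._to_flatten`, compacts their weights once at construction, and compacts them again at the top of `forward()` because `DataParallel` replicas may end up with non-contiguous weight copies. A self-contained sketch of that pattern (the `RNNStack` module below is illustrative, not part of the diff):

```python
import torch
import torch.nn as nn

class RNNStack(nn.Module):
    """Illustrative module mirroring the CBHG's flatten_parameters bookkeeping."""
    def __init__(self, dims):
        super().__init__()
        self.rnn1 = nn.GRU(dims, dims, batch_first=True)
        self.rnn2 = nn.GRU(dims, dims, batch_first=True)
        # Keep a plain list of every RNN so they can all be re-flattened later
        self._to_flatten = [self.rnn1, self.rnn2]
        self._flatten_parameters()  # compact the weights once at construction

    def forward(self, x):
        # DataParallel replicas may hold non-contiguous weight copies,
        # so re-flatten at the start of every forward pass
        self._flatten_parameters()
        x, _ = self.rnn1(x)
        x, _ = self.rnn2(x)
        return x

    def _flatten_parameters(self):
        for rnn in self._to_flatten:
            rnn.flatten_parameters()

x = torch.randn(2, 5, 8)
print(RNNStack(8)(x).shape)  # torch.Size([2, 5, 8])
```

Re-flattening on every forward pass should be effectively a no-op when the weights are already contiguous, so the main effect is silencing the fragmentation warning under `DataParallel`.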
@@ -189,10 +204,12 @@ def forward(self, encoder_seq_proj, query, t):


class Decoder(nn.Module):
+    # Class variable because its value doesn't change between instances,
+    # yet ought to be scoped by the class because it is a property of a Decoder
+    max_r = 20
    def __init__(self, n_mels, decoder_dims, lstm_dims):
        super().__init__()
-        self.max_r = 20
-        self.r = None
+        self.register_buffer('r', torch.tensor(1, dtype=torch.int))
        self.generating = False
        self.n_mels = n_mels
        self.prenet = PreNet(n_mels)
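Registering `r` as a buffer (rather than a plain attribute or an `nn.Parameter`) means the reduction factor is saved in `state_dict`, follows the module across `.to(device)` calls, and is never handed to the optimizer. A minimal illustration with a toy module (not from the diff):

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # A buffer: no gradients, but persisted in state_dict and moved with .to(device)
        self.register_buffer('r', torch.tensor(1, dtype=torch.int))

m = Toy()
print('r' in m.state_dict())       # True  - saved and restored with the model
print(list(m.parameters()))        # []    - never seen by the optimizer
m.r = m.r.new_tensor(3)            # attribute assignment updates the buffer entry
print(m.state_dict()['r'].item())  # 3
```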
@@ -204,8 +221,7 @@ def __init__(self, n_mels, decoder_dims, lstm_dims):
        self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)

    def zoneout(self, prev, current, p=0.1):
-        device = prev.device
-        assert prev.device == current.device
+        device = next(self.parameters()).device  # Use same device as parameters
        mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
        return prev * mask + current * (1 - mask)

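For reference, `zoneout` keeps each hidden unit from the previous decoder step with probability `p` and takes the freshly computed value otherwise. A standalone toy run (tensors are made up for illustration):

```python
import torch

def zoneout(prev, current, p=0.1):
    # 1s in the mask keep the previous value, 0s take the new one
    mask = torch.zeros_like(prev).bernoulli_(p)
    return prev * mask + current * (1 - mask)

prev_h = torch.zeros(1, 6)  # "old" hidden state
new_h = torch.ones(1, 6)    # "new" hidden state
print(zoneout(prev_h, new_h, p=0.5))  # roughly half 0s (kept) and half 1s (updated)
```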
@@ -279,17 +295,15 @@ def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels, ff
        self.init_model()
        self.num_params()

-        # Unfortunately I have to put these settings into params in order to save
-        # if anyone knows a better way of doing this please open an issue in the repo
-        self.step = nn.Parameter(torch.zeros(1).long(), requires_grad=False)
-        self.r = nn.Parameter(torch.tensor(0).long(), requires_grad=False)
-
-    def set_r(self, r):
-        self.r.data = torch.tensor(r)
-        self.decoder.r = r
+        self.register_buffer('step', torch.zeros(1, dtype=torch.long))
+
+    @property
+    def r(self):
+        return self.decoder.r.item()

-    def get_r(self):
-        return self.r.item()
+    @r.setter
+    def r(self, value):
+        self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)

    def forward(self, x, m, generate_gta=False):
        device = next(self.parameters()).device  # use same device as parameters
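The new `r` property on the Tacotron class simply proxies the decoder's buffer, so the reduction factor has a single source of truth and the old `set_r`/`get_r` helpers disappear. A stripped-down sketch of that proxy pattern, keeping only the parts relevant to `r` (everything else from the real classes omitted):

```python
import torch
import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('r', torch.tensor(1, dtype=torch.int))

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder = Decoder()

    @property
    def r(self):
        return self.decoder.r.item()

    @r.setter
    def r(self, value):
        # Writing through the property updates the decoder's buffer,
        # which is the single source of truth for the reduction factor
        self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)

m = Model()
m.r = 7
print(m.r)                          # 7
print(m.decoder.state_dict()['r'])  # tensor(7, dtype=torch.int32)
```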
@@ -351,7 +365,7 @@ def forward(self, x, m, generate_gta=False):

        # For easy visualisation
        attn_scores = torch.cat(attn_scores, 1)
-        attn_scores = attn_scores.cpu().data.numpy()
+        # attn_scores = attn_scores.cpu().data.numpy()

        return mel_outputs, linear, attn_scores

@@ -430,11 +444,17 @@ def get_step(self):
        return self.step.data.item()

    def reset_step(self):
-        self.step = nn.Parameter(torch.zeros(1).long(), requires_grad=False)
+        assert self.step is not None
+        device = next(self.parameters()).device  # use same device as parameters
+        # assignment to parameters or buffers is overloaded, updates internal dict entry
+        self.step = torch.zeros(1, dtype=torch.long, device=device)

-    def checkpoint(self, path):
+    def checkpoint(self, path, optimizer):
+        # The optimizer is given as an argument because the checkpoint function is
+        # only useful in the context of an already existing training process.
        k_steps = self.get_step() // 1000
        self.save(f'{path}/checkpoint_{k_steps}k_steps.pyt')
+        torch.save(optimizer.state_dict(), f'{path}/checkpoint_{k_steps}k_steps_optim.pyt')

    def log(self, path, msg):
        with open(path, 'a') as f:
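With `checkpoint` now taking the optimizer, model weights and optimizer state land in two separate files, and resuming training restores both. A hedged sketch of that round trip with a standard `torch.optim` optimizer (model and file names are illustrative):

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 4)  # stand-in for the Tacotron model
optimizer = optim.Adam(model.parameters())

# Save: two files, mirroring checkpoint(path, optimizer)
torch.save(model.state_dict(), 'checkpoint_1k_steps.pyt')
torch.save(optimizer.state_dict(), 'checkpoint_1k_steps_optim.pyt')

# Resume: restore both before continuing to train
model.load_state_dict(torch.load('checkpoint_1k_steps.pyt'))
optimizer.load_state_dict(torch.load('checkpoint_1k_steps_optim.pyt'))
```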
@@ -447,17 +467,21 @@ def restore(self, path):
        else:
            print(f'\nLoading Weights: "{path}"\n')
            self.load(path)
-            self.decoder.r = self.r.item()

-    def load(self, path, device='cpu'):
-        # because PyTorch places on CPU by default, we follow those semantics by using CPU as default.
+    def load(self, path):
+        # Use device of model params as location for loaded state
+        device = next(self.parameters()).device
        self.load_state_dict(torch.load(path, map_location=device), strict=False)

    def save(self, path):
+        # No optimizer argument because saving a model should not include data
+        # only relevant in the training process - it should only be properties
+        # of the model itself. Let the caller take care of saving optimizer state.
        torch.save(self.state_dict(), path)

    def num_params(self, print_out=True):
        parameters = filter(lambda p: p.requires_grad, self.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        if print_out:
            print('Trainable Parameters: %.3fM' % parameters)
+        return parameters
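Likewise, the reworked `load` resolves `map_location` from wherever the model's parameters currently live, so a checkpoint saved on GPU restores cleanly on a CPU-only machine without the caller passing a device. Roughly (toy module, illustrative path):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # stand-in for the Tacotron model
torch.save(model.state_dict(), 'weights.pyt')

# Mirror of the new load(): resolve the target device from the model itself
device = next(model.parameters()).device
state = torch.load('weights.pyt', map_location=device)
model.load_state_dict(state, strict=False)
```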