Skip to content

Commit

Permalink
add load_checkpoint func to tts models
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed Jan 20, 2021
1 parent 5c87753 commit 1faf565
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 0 deletions.
8 changes: 8 additions & 0 deletions TTS/tts/models/glow_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,11 @@ def preprocess(self, y, y_lengths, y_max_length, attn=None):

def store_inverse(self):
self.decoder.store_inverse()

def load_checkpoint(self, config, checkpoint_path, eval=False):
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
if eval:
self.eval()
self.store_inverse()
assert not self.training
7 changes: 7 additions & 0 deletions TTS/tts/models/speedy_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,10 @@ def inference(self, x, x_lengths, g=None): # pylint: disable=unused-argument
y_lengths = o_dr.sum(1)
o_de, attn= self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
return o_de, attn

def load_checkpoint(self, config, checkpoint_path, eval=False):
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
if eval:
self.eval()
assert not self.training
8 changes: 8 additions & 0 deletions TTS/tts/models/tacotron_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,14 @@ def forward(self):
def inference(self):
pass

def load_checkpoint(self, config, checkpoint_path, eval=False):
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
self.decoder.set_r(state['r'])
if eval:
self.eval()
assert not self.training

#############################
# COMMON COMPUTE FUNCTIONS
#############################
Expand Down

0 comments on commit 1faf565

Please sign in to comment.