diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
index b55ba1b1e..c978e4fa5 100644
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -223,3 +223,11 @@ def preprocess(self, y, y_lengths, y_max_length, attn=None):
 
     def store_inverse(self):
         self.decoder.store_inverse()
+
+    def load_checkpoint(self, config, checkpoint_path, eval=False):
+        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
+        self.load_state_dict(state['model'])
+        if eval:
+            self.eval()
+            self.store_inverse()
+            assert not self.training
diff --git a/TTS/tts/models/speedy_speech.py b/TTS/tts/models/speedy_speech.py
index 2e7d0a5f7..7f5c660e9 100644
--- a/TTS/tts/models/speedy_speech.py
+++ b/TTS/tts/models/speedy_speech.py
@@ -190,3 +190,10 @@ def inference(self, x, x_lengths, g=None):  # pylint: disable=unused-argument
         y_lengths = o_dr.sum(1)
         o_de, attn= self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
         return o_de, attn
+
+    def load_checkpoint(self, config, checkpoint_path, eval=False):
+        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
+        self.load_state_dict(state['model'])
+        if eval:
+            self.eval()
+            assert not self.training
diff --git a/TTS/tts/models/tacotron_abstract.py b/TTS/tts/models/tacotron_abstract.py
index 54c46be2d..0a63b8717 100644
--- a/TTS/tts/models/tacotron_abstract.py
+++ b/TTS/tts/models/tacotron_abstract.py
@@ -121,6 +121,14 @@ def forward(self):
     def inference(self):
         pass
 
+    def load_checkpoint(self, config, checkpoint_path, eval=False):
+        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
+        self.load_state_dict(state['model'])
+        self.decoder.set_r(state['r'])
+        if eval:
+            self.eval()
+            assert not self.training
+
     #############################
     # COMMON COMPUTE FUNCTIONS
     #############################
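
The hunks above all follow the same pattern: load a checkpoint dict onto the CPU, restore the weights stored under the `'model'` key, and optionally switch the model into inference mode (GlowTTS additionally calls `store_inverse()`, and the Tacotron base class restores the decoder reduction factor from `state['r']`). Below is a minimal, self-contained sketch of that pattern; the `ToyModel` class and the `/tmp/toy_checkpoint.pth` path are placeholders for illustration, not part of this diff.

```python
import torch


class ToyModel(torch.nn.Module):
    """Stand-in for the TTS models above; only the checkpoint logic is shown."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 4)

    def load_checkpoint(self, config, checkpoint_path, eval=False):
        # Checkpoints are dicts with the weights under the 'model' key,
        # mirroring the methods added in this diff. `config` is accepted for
        # interface compatibility but not used here.
        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.load_state_dict(state['model'])
        if eval:
            self.eval()
            assert not self.training


# Round-trip: save a checkpoint in the expected format, then reload it.
model = ToyModel()
torch.save({'model': model.state_dict()}, '/tmp/toy_checkpoint.pth')
model.load_checkpoint(config=None, checkpoint_path='/tmp/toy_checkpoint.pth', eval=True)
```

With `eval=True` the model is left in evaluation mode and the assertion guards against accidentally resuming in training mode after loading for synthesis.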