diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 8c8ad3dd30..fbae32162d 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -410,7 +410,7 @@ def tts( # run vocoder model # [1, T, C] waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device)) - if waveform.device != torch.device("cpu") and not use_gl: + if torch.is_tensor(waveform) and waveform.device != torch.device("cpu") and not use_gl: waveform = waveform.cpu() if not use_gl: waveform = waveform.numpy() @@ -474,7 +474,7 @@ def tts( # run vocoder model # [1, T, C] waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device)) - if waveform.device != torch.device("cpu"): + if torch.is_tensor(waveform) and waveform.device != torch.device("cpu"): waveform = waveform.cpu() if not use_gl: waveform = waveform.numpy()