refactor: use @torch.inference_mode() instead of @torch.no_grad()
eginhard committed Dec 26, 2024
1 parent 68aeb38 commit 62b657e
Showing 28 changed files with 41 additions and 40 deletions.
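
Background for the change: torch.inference_mode() (available since PyTorch 1.9) behaves like torch.no_grad() but additionally disables view tracking and version-counter updates, so inference runs with less autograd bookkeeping. The trade-off is that tensors created under it are "inference tensors" and cannot later take part in autograd-recorded operations. Since every method touched below is a dedicated inference, eval, or test entry point, the stricter mode should be a drop-in replacement. A minimal sketch of the difference, using an illustrative torch.nn.Linear stand-in and helper names rather than anything from this repository:

import torch

model = torch.nn.Linear(4, 2)  # toy stand-in for a TTS or vocoder model
x = torch.randn(1, 4)

@torch.no_grad()
def infer_no_grad(inp):
    # No gradients are recorded, but the result is an ordinary tensor
    # and may still be used in autograd-tracked code afterwards.
    return model(inp)

@torch.inference_mode()
def infer_inference_mode(inp):
    # Also skips view tracking and version-counter bumps; the result is an
    # inference tensor and raises an error if later used in autograd.
    return model(inp)

print(infer_no_grad(x).requires_grad)         # False
print(infer_inference_mode(x).requires_grad)  # False
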
2 changes: 1 addition & 1 deletion TTS/bin/compute_attention_masks.py
@@ -113,7 +113,7 @@

# compute attentions
file_paths = []
-with torch.no_grad():
+with torch.inference_mode():
for data in tqdm(loader):
# setup input data
text_input = data[0]
2 changes: 1 addition & 1 deletion TTS/bin/extract_tts_spectrograms.py
@@ -128,7 +128,7 @@ def format_data(data):
)


-@torch.no_grad()
+@torch.inference_mode()
def inference(
model_name: str,
model: BaseTTS,
2 changes: 1 addition & 1 deletion TTS/bin/train_encoder.py
@@ -87,7 +87,7 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False):
def evaluation(model, criterion, data_loader, global_step):
eval_loss = 0
for _, data in enumerate(data_loader):
-with torch.no_grad():
+with torch.inference_mode():
# setup input data
inputs, labels = data

4 changes: 2 additions & 2 deletions TTS/encoder/models/base_encoder.py
@@ -64,11 +64,11 @@ def get_torch_mel_spectrogram_class(self, audio_config):
),
)

-@torch.no_grad()
+@torch.inference_mode()
def inference(self, x, l2_norm=True):
return self.forward(x, l2_norm)

-@torch.no_grad()
+@torch.inference_mode()
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
"""
Generate embeddings for a batch of utterances
2 changes: 1 addition & 1 deletion TTS/tts/layers/delightful_tts/acoustic_model.py
@@ -421,7 +421,7 @@ def forward(
"spk_emb": speaker_embedding,
}

-@torch.no_grad()
+@torch.inference_mode()
def inference(
self,
tokens: torch.Tensor,
2 changes: 1 addition & 1 deletion TTS/tts/layers/xtts/hifigan_decoder.py
@@ -97,7 +97,7 @@ def forward(self, latents, g=None):
o = self.waveform_decoder(z, g=g)
return o

-@torch.no_grad()
+@torch.inference_mode()
def inference(self, c, g):
"""
Args:
4 changes: 2 additions & 2 deletions TTS/tts/layers/xtts/stream_generator.py
@@ -45,7 +45,7 @@ def __init__(self, **kwargs):


class NewGenerationMixin(GenerationMixin):
-@torch.no_grad()
+@torch.inference_mode()
def generate( # noqa: PLR0911
self,
inputs: Optional[torch.Tensor] = None,
@@ -662,7 +662,7 @@ def typeerror():
**model_kwargs,
)

-@torch.no_grad()
+@torch.inference_mode()
def sample_stream(
self,
input_ids: torch.LongTensor,
4 changes: 2 additions & 2 deletions TTS/tts/layers/xtts/trainer/gpt_trainer.py
@@ -225,7 +225,7 @@ def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels
)
return losses

-@torch.no_grad()
+@torch.inference_mode()
def test_run(self, assets) -> Tuple[Dict, Dict]: # pylint: disable=W0613
test_audios = {}
if self.config.test_sentences:
@@ -335,7 +335,7 @@ def on_init_end(self, trainer): # pylint: disable=W0613

WeightsFileHandler.add_pre_callback(callback_clearml_load_save)

-@torch.no_grad()
+@torch.inference_mode()
def inference(
self,
x,
2 changes: 1 addition & 1 deletion TTS/tts/models/align_tts.py
@@ -288,7 +288,7 @@ def forward(
}
return outputs

-@torch.no_grad()
+@torch.inference_mode()
def inference(self, x, aux_input={"d_vectors": None}): # pylint: disable=unused-argument
"""
Shapes:
6 changes: 3 additions & 3 deletions TTS/tts/models/delightful_tts.py
@@ -622,7 +622,7 @@ def forward(
model_outputs["slice_ids"] = slice_ids
return model_outputs

-@torch.no_grad()
+@torch.inference_mode()
def inference(
self, x, aux_input={"d_vectors": None, "speaker_ids": None}, pitch_transform=None, energy_transform=None
):
@@ -646,7 +646,7 @@ def inference(
model_outputs["model_outputs"] = vocoder_output
return model_outputs

-@torch.no_grad()
+@torch.inference_mode()
def inference_spec_decoder(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
encoder_outputs = self.acoustic_model.inference(
tokens=x,
@@ -1018,7 +1018,7 @@ def synthesize_with_gl(self, text: str, speaker_id, d_vector):
}
return return_dict

-@torch.no_grad()
+@torch.inference_mode()
def test_run(self, assets) -> Tuple[Dict, Dict]:
"""Generic test run for `tts` models used by `Trainer`.
2 changes: 1 addition & 1 deletion TTS/tts/models/forward_tts.py
@@ -628,7 +628,7 @@ def forward(
}
return outputs

-@torch.no_grad()
+@torch.inference_mode()
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument
"""Model's inference pass.
10 changes: 5 additions & 5 deletions TTS/tts/models/glow_tts.py
@@ -262,7 +262,7 @@ def forward(
}
return outputs

-@torch.no_grad()
+@torch.inference_mode()
def inference_with_MAS(
self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
): # pylint: disable=dangerous-default-value
@@ -318,7 +318,7 @@ def inference_with_MAS(
}
return outputs

-@torch.no_grad()
+@torch.inference_mode()
def decoder_inference(
self, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
): # pylint: disable=dangerous-default-value
@@ -341,7 +341,7 @@ def decoder_inference(
outputs["logdet"] = logdet
return outputs

-@torch.no_grad()
+@torch.inference_mode()
def inference(
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None}
): # pylint: disable=dangerous-default-value
@@ -464,7 +464,7 @@ def train_log(
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, self.ap.sample_rate)

-@torch.no_grad()
+@torch.inference_mode()
def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion)

@@ -473,7 +473,7 @@ def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, s
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, self.ap.sample_rate)

-@torch.no_grad()
+@torch.inference_mode()
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
"""Generic test run for `tts` models used by `Trainer`.
2 changes: 1 addition & 1 deletion TTS/tts/models/neuralhmm_tts.py
@@ -195,7 +195,7 @@ def _format_aux_input(self, aux_input: Dict, default_input_dict):
return format_aux_input(default_input_dict, aux_input)
return default_input_dict

-@torch.no_grad()
+@torch.inference_mode()
def inference(
self,
text: torch.Tensor,
2 changes: 1 addition & 1 deletion TTS/tts/models/overflow.py
@@ -209,7 +209,7 @@ def _format_aux_input(self, aux_input: Dict, default_input_dict):
return format_aux_input(default_input_dict, aux_input)
return default_input_dict

-@torch.no_grad()
+@torch.inference_mode()
def inference(
self,
text: torch.Tensor,
2 changes: 1 addition & 1 deletion TTS/tts/models/tacotron.py
@@ -218,7 +218,7 @@ def forward( # pylint: disable=dangerous-default-value
)
return outputs

-@torch.no_grad()
+@torch.inference_mode()
def inference(self, text_input, aux_input=None):
aux_input = self._format_aux_input(aux_input)
inputs = self.embedding(text_input)
2 changes: 1 addition & 1 deletion TTS/tts/models/tacotron2.py
@@ -238,7 +238,7 @@ def forward( # pylint: disable=dangerous-default-value
)
return outputs

-@torch.no_grad()
+@torch.inference_mode()
def inference(self, text, aux_input=None):
"""Forward pass for inference with no Teacher-Forcing.
8 changes: 4 additions & 4 deletions TTS/tts/models/vits.py
@@ -927,7 +927,7 @@ def _set_x_lengths(x, aux_input):
return aux_input["x_lengths"]
return torch.tensor(x.shape[1:2]).to(x.device)

-@torch.no_grad()
+@torch.inference_mode()
def inference(
self,
x,
@@ -1014,7 +1014,7 @@ def inference(
}
return outputs

-@torch.no_grad()
+@torch.inference_mode()
def inference_voice_conversion(
self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None
):
@@ -1209,7 +1209,7 @@ def train_log(
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, self.ap.sample_rate)

-@torch.no_grad()
+@torch.inference_mode()
def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
return self.train_step(batch, criterion, optimizer_idx)

@@ -1266,7 +1266,7 @@ def get_aux_input_from_test_sentences(self, sentence_info):
"language_name": language_name,
}

-@torch.no_grad()
+@torch.inference_mode()
def test_run(self, assets) -> Tuple[Dict, Dict]:
"""Generic test run for `tts` models used by `Trainer`.
1 change: 1 addition & 0 deletions TTS/tts/utils/managers.py
@@ -334,6 +334,7 @@ def init_encoder(
)
self.encoder_ap = AudioProcessor(**self.encoder_config.audio)

+@torch.inference_mode()
def compute_embedding_from_clip(
self, wav_file: Union[Union[str, os.PathLike[Any]], List[Union[str, os.PathLike[Any]]]]
) -> list:
2 changes: 1 addition & 1 deletion TTS/vc/models/freevc.py
@@ -389,7 +389,7 @@ def forward(

return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

-@torch.no_grad()
+@torch.inference_mode()
def inference(self, c, g=None, mel=None, c_lengths=None):
"""
Inference pass of the model
2 changes: 1 addition & 1 deletion TTS/vc/models/openvoice.py
@@ -228,7 +228,7 @@ def _set_x_lengths(x: torch.Tensor, aux_input: Mapping[str, Optional[torch.Tenso
return aux_input["x_lengths"]
return torch.tensor(x.shape[1:2]).to(x.device)

-@torch.no_grad()
+@torch.inference_mode()
def inference(
self,
x: torch.Tensor,
2 changes: 1 addition & 1 deletion TTS/vocoder/models/fullband_melgan_generator.py
@@ -24,7 +24,7 @@ def __init__(
num_res_blocks=num_res_blocks,
)

-@torch.no_grad()
+@torch.inference_mode()
def inference(self, cond_features):
cond_features = cond_features.to(self.layers[1].weight.device)
cond_features = torch.nn.functional.pad(
2 changes: 1 addition & 1 deletion TTS/vocoder/models/gan.py
@@ -212,7 +212,7 @@ def train_log(
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, self.ap.sample_rate)

-@torch.no_grad()
+@torch.inference_mode()
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Call `train_step()` with `no_grad()`"""
self.train_disc = True # Avoid a bug in the Training with the missing discriminator loss
2 changes: 1 addition & 1 deletion TTS/vocoder/models/hifigan_generator.py
@@ -280,7 +280,7 @@ def forward(self, x, g=None):
o = torch.tanh(o)
return o

-@torch.no_grad()
+@torch.inference_mode()
def inference(self, c):
"""
Args:
2 changes: 1 addition & 1 deletion TTS/vocoder/models/multiband_melgan_generator.py
@@ -32,7 +32,7 @@ def pqmf_analysis(self, x):
def pqmf_synthesis(self, x):
return self.pqmf_layer.synthesis(x)

-@torch.no_grad()
+@torch.inference_mode()
def inference(self, cond_features):
cond_features = cond_features.to(self.layers[1].weight.device)
cond_features = torch.nn.functional.pad(
2 changes: 1 addition & 1 deletion TTS/vocoder/models/parallel_wavegan_generator.py
@@ -127,7 +127,7 @@ def forward(self, c):

return x

-@torch.no_grad()
+@torch.inference_mode()
def inference(self, c):
c = c.to(self.first_conv.weight.device)
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
2 changes: 1 addition & 1 deletion TTS/vocoder/models/univnet_generator.py
@@ -139,7 +139,7 @@ def receptive_field_size(self):
"""Return receptive field size."""
return _get_receptive_field_size(self.layers, self.stacks, self.kernel_size)

-@torch.no_grad()
+@torch.inference_mode()
def inference(self, c):
"""Perform inference.
Args:
4 changes: 2 additions & 2 deletions TTS/vocoder/models/wavegrad.py
@@ -123,7 +123,7 @@ def load_noise_schedule(self, path):
beta = np.load(path, allow_pickle=True).item()["beta"] # pylint: disable=unexpected-keyword-arg
self.compute_noise_level(beta)

-@torch.no_grad()
+@torch.inference_mode()
def inference(self, x, y_n=None):
"""
Shapes:
@@ -262,7 +262,7 @@ def train_log( # pylint: disable=no-self-use
) -> Tuple[Dict, np.ndarray]:
pass

-@torch.no_grad()
+@torch.inference_mode()
def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
return self.train_step(batch, criterion)

2 changes: 1 addition & 1 deletion TTS/vocoder/models/wavernn.py
@@ -307,7 +307,7 @@ def inference(self, mels, batched=None, target=None, overlap=None):
rnn1 = self.get_gru_cell(self.rnn1)
rnn2 = self.get_gru_cell(self.rnn2)

-with torch.no_grad():
+with torch.inference_mode():
if isinstance(mels, np.ndarray):
mels = torch.FloatTensor(mels).to(str(next(self.parameters()).device))
