PyTorch 2.1 Updates (Weight Norm and TorchAudio I/O) (#3176)
* Replaced PyTorch weight_norm with parametrizations.weight_norm (see the first sketch below)

* TorchAudio: migrated the I/O functions to use the dispatcher mechanism (see the I/O sketch below)

* Corrected code style

---------

Co-authored-by: Eren Gölge <erogol@hotmail.com>
MattyB95 and erogol authored Nov 9, 2023
1 parent 66a1e24 commit 1b9c400
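For context: `torch.nn.utils.weight_norm` is deprecated as of PyTorch 2.1 in favor of the parametrization API that this commit adopts. A minimal before/after sketch, assuming PyTorch >= 2.1 (the layer and shapes are illustrative, not taken from the diff):

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize

conv = nn.Conv1d(16, 32, kernel_size=3, padding=1)  # illustrative layer

# Old API (deprecated): adds conv.weight_g / conv.weight_v via hooks.
#   conv = nn.utils.weight_norm(conv)
#   nn.utils.remove_weight_norm(conv)

# New API: registers a parametrization; the decomposed tensors live at
# conv.parametrizations.weight.original0 (norm) and .original1 (direction).
conv = nn.utils.parametrizations.weight_norm(conv)
y = conv(torch.randn(1, 16, 100))

# Bake the normalized weight back into a plain conv.weight tensor; this is
# what the remove_weight_norm() methods in the diffs below now call.
parametrize.remove_parametrizations(conv, "weight")
```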
Showing 24 changed files with 147 additions and 129 deletions.
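The TorchAudio half of the change lands in files that are not rendered below. The pattern: TorchAudio 2.1's dispatcher chooses an I/O backend per call instead of relying on the global `torchaudio.set_audio_backend()`. A hedged sketch of the new call style, assuming TorchAudio >= 2.1 with the soundfile backend installed (file names are hypothetical):

```python
import torchaudio

# Dispatcher mode: pick the backend ("ffmpeg", "sox", or "soundfile") per
# call; backend=None lets TorchAudio fall back to whatever is available.
waveform, sample_rate = torchaudio.load("input.wav", backend="soundfile")
torchaudio.save("output.wav", waveform, sample_rate, backend="soundfile")
```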
13 changes: 7 additions & 6 deletions TTS/tts/layers/delightful_tts/conv_layers.py
@@ -3,6 +3,7 @@
 import torch
 import torch.nn as nn  # pylint: disable=consider-using-from-import
 import torch.nn.functional as F
+from torch.nn.utils import parametrize

 from TTS.tts.layers.delightful_tts.kernel_predictor import KernelPredictor

@@ -73,7 +74,7 @@ def __init__(
         )
         nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain))
         if self.use_weight_norm:
-            self.conv = nn.utils.weight_norm(self.conv)
+            self.conv = nn.utils.parametrizations.weight_norm(self.conv)

     def forward(self, signal, mask=None):
         conv_signal = self.conv(signal)
@@ -113,7 +114,7 @@ def __init__(
                 dilation=1,
                 w_init_gain="relu",
             )
-            conv_layer = nn.utils.weight_norm(conv_layer.conv, name="weight")
+            conv_layer = nn.utils.parametrizations.weight_norm(conv_layer.conv, name="weight")
             convolutions.append(conv_layer)

         self.convolutions = nn.ModuleList(convolutions)
@@ -567,7 +568,7 @@ def __init__( # pylint: disable=dangerous-default-value

         self.convt_pre = nn.Sequential(
             nn.LeakyReLU(lReLU_slope),
-            nn.utils.weight_norm(
+            nn.utils.parametrizations.weight_norm(
                 nn.ConvTranspose1d(
                     in_channels,
                     in_channels,
@@ -584,7 +585,7 @@ def __init__( # pylint: disable=dangerous-default-value
             self.conv_blocks.append(
                 nn.Sequential(
                     nn.LeakyReLU(lReLU_slope),
-                    nn.utils.weight_norm(
+                    nn.utils.parametrizations.weight_norm(
                         nn.Conv1d(
                             in_channels,
                             in_channels,
@@ -665,6 +666,6 @@ def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=25

     def remove_weight_norm(self):
         self.kernel_predictor.remove_weight_norm()
-        nn.utils.remove_weight_norm(self.convt_pre[1])
+        parametrize.remove_parametrizations(self.convt_pre[1], "weight")
         for block in self.conv_blocks:
-            nn.utils.remove_weight_norm(block[1])
+            parametrize.remove_parametrizations(block[1], "weight")
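A practical note on the swap above (an observation about the API, not part of this commit): the parametrized module serializes its weights under new state-dict keys, so checkpoints written with one API do not load into the other without key remapping. A quick way to see the difference, assuming PyTorch >= 2.1:

```python
import torch.nn as nn

conv = nn.utils.parametrizations.weight_norm(nn.Conv1d(8, 8, 3))
print(list(conv.state_dict().keys()))
# ['bias', 'parametrizations.weight.original0', 'parametrizations.weight.original1']
# The old nn.utils.weight_norm hook stored 'weight_g' and 'weight_v' instead.
```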
23 changes: 13 additions & 10 deletions TTS/tts/layers/delightful_tts/kernel_predictor.py
@@ -1,4 +1,5 @@
 import torch.nn as nn  # pylint: disable=consider-using-from-import
+from torch.nn.utils import parametrize


 class KernelPredictor(nn.Module):
@@ -36,7 +37,9 @@ def __init__( # pylint: disable=dangerous-default-value
         kpnet_bias_channels = conv_out_channels * conv_layers  # l_b

         self.input_conv = nn.Sequential(
-            nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
+            nn.utils.parametrizations.weight_norm(
+                nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)
+            ),
             getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
         )

@@ -46,7 +49,7 @@ def __init__( # pylint: disable=dangerous-default-value
             self.residual_convs.append(
                 nn.Sequential(
                     nn.Dropout(kpnet_dropout),
-                    nn.utils.weight_norm(
+                    nn.utils.parametrizations.weight_norm(
                         nn.Conv1d(
                             kpnet_hidden_channels,
                             kpnet_hidden_channels,
@@ -56,7 +59,7 @@ def __init__( # pylint: disable=dangerous-default-value
                         )
                     ),
                     getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
-                    nn.utils.weight_norm(
+                    nn.utils.parametrizations.weight_norm(
                         nn.Conv1d(
                             kpnet_hidden_channels,
                             kpnet_hidden_channels,
@@ -68,7 +71,7 @@ def __init__( # pylint: disable=dangerous-default-value
                     getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
                 )
             )
-        self.kernel_conv = nn.utils.weight_norm(
+        self.kernel_conv = nn.utils.parametrizations.weight_norm(
             nn.Conv1d(
                 kpnet_hidden_channels,
                 kpnet_kernel_channels,
@@ -77,7 +80,7 @@ def __init__( # pylint: disable=dangerous-default-value
                 bias=True,
             )
         )
-        self.bias_conv = nn.utils.weight_norm(
+        self.bias_conv = nn.utils.parametrizations.weight_norm(
             nn.Conv1d(
                 kpnet_hidden_channels,
                 kpnet_bias_channels,
@@ -117,9 +120,9 @@ def forward(self, c):
         return kernels, bias

     def remove_weight_norm(self):
-        nn.utils.remove_weight_norm(self.input_conv[0])
-        nn.utils.remove_weight_norm(self.kernel_conv)
-        nn.utils.remove_weight_norm(self.bias_conv)
+        parametrize.remove_parametrizations(self.input_conv[0], "weight")
+        parametrize.remove_parametrizations(self.kernel_conv, "weight")
+        parametrize.remove_parametrizations(self.bias_conv, "weight")
         for block in self.residual_convs:
-            nn.utils.remove_weight_norm(block[1])
-            nn.utils.remove_weight_norm(block[3])
+            parametrize.remove_parametrizations(block[1], "weight")
+            parametrize.remove_parametrizations(block[3], "weight")
13 changes: 7 additions & 6 deletions TTS/tts/layers/generic/wavenet.py
@@ -1,5 +1,6 @@
 import torch
 from torch import nn
+from torch.nn.utils import parametrize


 @torch.jit.script
@@ -62,7 +63,7 @@ def __init__(
         # init conditioning layer
         if c_in_channels > 0:
             cond_layer = torch.nn.Conv1d(c_in_channels, 2 * hidden_channels * num_layers, 1)
-            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
+            self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight")
         # intermediate layers
         for i in range(num_layers):
             dilation = dilation_rate**i
@@ -75,7 +76,7 @@ def __init__(
             in_layer = torch.nn.Conv1d(
                 hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
             )
-            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
+            in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
             self.in_layers.append(in_layer)

             if i < num_layers - 1:
@@ -84,7 +85,7 @@ def __init__(
                 res_skip_channels = hidden_channels

             res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
-            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
+            res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
             self.res_skip_layers.append(res_skip_layer)
         # setup weight norm
         if not weight_norm:
@@ -115,11 +116,11 @@ def forward(self, x, x_mask=None, g=None, **kwargs):  # pylint: disable=unused-a

     def remove_weight_norm(self):
         if self.c_in_channels != 0:
-            torch.nn.utils.remove_weight_norm(self.cond_layer)
+            parametrize.remove_parametrizations(self.cond_layer, "weight")
         for l in self.in_layers:
-            torch.nn.utils.remove_weight_norm(l)
+            parametrize.remove_parametrizations(l, "weight")
         for l in self.res_skip_layers:
-            torch.nn.utils.remove_weight_norm(l)
+            parametrize.remove_parametrizations(l, "weight")


 class WNBlocks(nn.Module):
2 changes: 1 addition & 1 deletion TTS/tts/layers/glow_tts/glow.py
@@ -186,7 +186,7 @@ def __init__(
         self.sigmoid_scale = sigmoid_scale
         # input layer
         start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
-        start = torch.nn.utils.weight_norm(start)
+        start = torch.nn.utils.parametrizations.weight_norm(start)
         self.start = start
         # output layer
         # Initializing last layer to 0 makes the affine coupling layers
42 changes: 23 additions & 19 deletions TTS/tts/layers/tortoise/vocoder.py
@@ -1,11 +1,11 @@
 import json
 from dataclasses import dataclass
 from enum import Enum
-from typing import Callable, Optional

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.nn.utils.parametrize as parametrize

 MAX_WAV_VALUE = 32768.0

@@ -44,7 +44,9 @@ def __init__(
         kpnet_bias_channels = conv_out_channels * conv_layers  # l_b

         self.input_conv = nn.Sequential(
-            nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
+            nn.utils.parametrizations.weight_norm(
+                nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)
+            ),
             getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
         )

@@ -54,7 +56,7 @@ def __init__(
             self.residual_convs.append(
                 nn.Sequential(
                     nn.Dropout(kpnet_dropout),
-                    nn.utils.weight_norm(
+                    nn.utils.parametrizations.weight_norm(
                         nn.Conv1d(
                             kpnet_hidden_channels,
                             kpnet_hidden_channels,
@@ -64,7 +66,7 @@ def __init__(
                         )
                     ),
                     getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
-                    nn.utils.weight_norm(
+                    nn.utils.parametrizations.weight_norm(
                         nn.Conv1d(
                             kpnet_hidden_channels,
                             kpnet_hidden_channels,
@@ -76,7 +78,7 @@ def __init__(
                     getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
                 )
             )
-        self.kernel_conv = nn.utils.weight_norm(
+        self.kernel_conv = nn.utils.parametrizations.weight_norm(
             nn.Conv1d(
                 kpnet_hidden_channels,
                 kpnet_kernel_channels,
@@ -85,7 +87,7 @@ def __init__(
                 bias=True,
             )
         )
-        self.bias_conv = nn.utils.weight_norm(
+        self.bias_conv = nn.utils.parametrizations.weight_norm(
             nn.Conv1d(
                 kpnet_hidden_channels,
                 kpnet_bias_channels,
@@ -125,12 +127,12 @@ def forward(self, c):
         return kernels, bias

     def remove_weight_norm(self):
-        nn.utils.remove_weight_norm(self.input_conv[0])
-        nn.utils.remove_weight_norm(self.kernel_conv)
-        nn.utils.remove_weight_norm(self.bias_conv)
+        parametrize.remove_parametrizations(self.input_conv[0], "weight")
+        parametrize.remove_parametrizations(self.kernel_conv, "weight")
+        parametrize.remove_parametrizations(self.bias_conv, "weight")
         for block in self.residual_convs:
-            nn.utils.remove_weight_norm(block[1])
-            nn.utils.remove_weight_norm(block[3])
+            parametrize.remove_parametrizations(block[1], "weight")
+            parametrize.remove_parametrizations(block[3], "weight")


 class LVCBlock(torch.nn.Module):
@@ -169,7 +171,7 @@ def __init__(

         self.convt_pre = nn.Sequential(
             nn.LeakyReLU(lReLU_slope),
-            nn.utils.weight_norm(
+            nn.utils.parametrizations.weight_norm(
                 nn.ConvTranspose1d(
                     in_channels,
                     in_channels,
@@ -186,7 +188,7 @@ def __init__(
             self.conv_blocks.append(
                 nn.Sequential(
                     nn.LeakyReLU(lReLU_slope),
-                    nn.utils.weight_norm(
+                    nn.utils.parametrizations.weight_norm(
                         nn.Conv1d(
                             in_channels,
                             in_channels,
@@ -267,9 +269,9 @@ def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=25

     def remove_weight_norm(self):
         self.kernel_predictor.remove_weight_norm()
-        nn.utils.remove_weight_norm(self.convt_pre[1])
+        parametrize.remove_parametrizations(self.convt_pre[1], "weight")
         for block in self.conv_blocks:
-            nn.utils.remove_weight_norm(block[1])
+            parametrize.remove_parametrizations(block[1], "weight")


 class UnivNetGenerator(nn.Module):
@@ -314,11 +316,13 @@ def __init__(
                 )
             )

-        self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect"))
+        self.conv_pre = nn.utils.parametrizations.weight_norm(
+            nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect")
+        )

         self.conv_post = nn.Sequential(
             nn.LeakyReLU(lReLU_slope),
-            nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
+            nn.utils.parametrizations.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
             nn.Tanh(),
         )

@@ -346,11 +350,11 @@ def eval(self, inference=False):
             self.remove_weight_norm()

     def remove_weight_norm(self):
-        nn.utils.remove_weight_norm(self.conv_pre)
+        parametrize.remove_parametrizations(self.conv_pre, "weight")

         for layer in self.conv_post:
             if len(layer.state_dict()) != 0:
-                nn.utils.remove_weight_norm(layer)
+                parametrize.remove_parametrizations(layer, "weight")

         for res_block in self.res_stack:
             res_block.remove_weight_norm()
2 changes: 1 addition & 1 deletion TTS/tts/layers/vits/discriminator.py
@@ -14,7 +14,7 @@ class DiscriminatorS(torch.nn.Module):

     def __init__(self, use_spectral_norm=False):
         super().__init__()
-        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
+        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
         self.convs = nn.ModuleList(
             [
                 norm_f(Conv1d(1, 16, 15, 1, padding=7)),
19 changes: 10 additions & 9 deletions TTS/tts/layers/xtts/hifigan_decoder.py
@@ -3,7 +3,8 @@
 from torch import nn
 from torch.nn import Conv1d, ConvTranspose1d
 from torch.nn import functional as F
-from torch.nn.utils import remove_weight_norm, weight_norm
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations

 from TTS.utils.io import load_fsspec

@@ -120,9 +121,9 @@ def forward(self, x):

     def remove_weight_norm(self):
         for l in self.convs1:
-            remove_weight_norm(l)
+            remove_parametrizations(l, "weight")
         for l in self.convs2:
-            remove_weight_norm(l)
+            remove_parametrizations(l, "weight")


 class ResBlock2(torch.nn.Module):
@@ -176,7 +177,7 @@ def forward(self, x):

     def remove_weight_norm(self):
         for l in self.convs:
-            remove_weight_norm(l)
+            remove_parametrizations(l, "weight")


 class HifiganGenerator(torch.nn.Module):
@@ -251,10 +252,10 @@ def __init__(
             self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)

         if not conv_pre_weight_norm:
-            remove_weight_norm(self.conv_pre)
+            remove_parametrizations(self.conv_pre, "weight")

         if not conv_post_weight_norm:
-            remove_weight_norm(self.conv_post)
+            remove_parametrizations(self.conv_post, "weight")

         if self.cond_in_each_up_layer:
             self.conds = nn.ModuleList()
@@ -317,11 +318,11 @@ def inference(self, c):
     def remove_weight_norm(self):
         print("Removing weight norm...")
         for l in self.ups:
-            remove_weight_norm(l)
+            remove_parametrizations(l, "weight")
         for l in self.resblocks:
             l.remove_weight_norm()
-        remove_weight_norm(self.conv_pre)
-        remove_weight_norm(self.conv_post)
+        remove_parametrizations(self.conv_pre, "weight")
+        remove_parametrizations(self.conv_post, "weight")

     def load_checkpoint(
         self, config, checkpoint_path, eval=False, cache=False
3 changes: 1 addition & 2 deletions TTS/tts/models/xtts.py
@@ -1,5 +1,4 @@
 import os
-from contextlib import contextmanager
 from dataclasses import dataclass

 import librosa
@@ -8,7 +7,7 @@
 import torchaudio
 from coqpit import Coqpit

-from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
+from TTS.tts.layers.tortoise.audio_utils import wav_to_univnet_mel
 from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.stream_generator import init_stream_support