diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py
index 876abc870470b4..7f3e90194246bf 100644
--- a/src/transformers/models/vits/modeling_vits.py
+++ b/src/transformers/models/vits/modeling_vits.py
@@ -357,9 +357,14 @@ def __init__(self, config: VitsConfig, num_layers: int):
         self.res_skip_layers = torch.nn.ModuleList()
         self.dropout = nn.Dropout(config.wavenet_dropout)
 
+        if hasattr(nn.utils.parametrizations, "weight_norm"):
+            weight_norm = nn.utils.parametrizations.weight_norm
+        else:
+            weight_norm = nn.utils.weight_norm
+
         if config.speaker_embedding_size != 0:
             cond_layer = torch.nn.Conv1d(config.speaker_embedding_size, 2 * config.hidden_size * num_layers, 1)
-            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
+            self.cond_layer = weight_norm(cond_layer, name="weight")
 
         for i in range(num_layers):
             dilation = config.wavenet_dilation_rate**i
@@ -371,7 +376,7 @@ def __init__(self, config: VitsConfig, num_layers: int):
                 dilation=dilation,
                 padding=padding,
             )
-            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
+            in_layer = weight_norm(in_layer, name="weight")
             self.in_layers.append(in_layer)
 
             # last one is not necessary
@@ -381,7 +386,7 @@ def __init__(self, config: VitsConfig, num_layers: int):
                 res_skip_channels = config.hidden_size
 
             res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
-            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
+            res_skip_layer = weight_norm(res_skip_layer, name="weight")
             self.res_skip_layers.append(res_skip_layer)
 
     def forward(self, inputs, padding_mask, global_conditioning=None):
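
Context for the `hasattr` check introduced above: PyTorch 2.1 added a parametrization-based `torch.nn.utils.parametrizations.weight_norm` and deprecated the original `torch.nn.utils.weight_norm`, so the diff binds whichever variant is available to a local `weight_norm` and uses it at all three call sites. Below is a minimal, self-contained sketch of that pattern outside the diff; the `Conv1d` sizes are arbitrary placeholders for illustration, not values from the VITS config.

```python
import torch
from torch import nn

# Pick the parametrization-based weight_norm when this PyTorch build has it
# (>= 2.1); otherwise fall back to the legacy, now-deprecated function.
if hasattr(nn.utils.parametrizations, "weight_norm"):
    weight_norm = nn.utils.parametrizations.weight_norm
else:
    weight_norm = nn.utils.weight_norm

# Both variants share the signature weight_norm(module, name="weight", dim=0),
# so downstream code is identical regardless of which one was selected.
conv = nn.Conv1d(in_channels=4, out_channels=8, kernel_size=3)
conv = weight_norm(conv, name="weight")

# Either way, `weight` is reparametrized as a magnitude times a direction;
# the new API stores the factors under conv.parametrizations.weight, while
# the legacy one exposes them as weight_g / weight_v attributes.
x = torch.randn(1, 4, 16)
print(conv(x).shape)  # torch.Size([1, 8, 14])
```

Resolving the function once in `__init__` keeps the three wrapped layers consistent with each other and avoids the deprecation warning the legacy function emits on recent PyTorch, while still supporting older releases that lack the parametrizations variant.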