Skip to content

Commit

Permalink
switched optimizer to RAdam, added some preliminary changes for non-randomized finetuning
Browse files Browse the repository at this point in the history
  • Loading branch information
AznamirWoW committed Jan 1, 2025
1 parent bf2e6fd commit 6f0b733
Showing 1 changed file with 27 additions and 15 deletions.
42 changes: 27 additions & 15 deletions rvc/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@
# Positional command-line arguments 15-17; the two flags are parsed with strtobool.
cleanup = strtobool(sys.argv[15])
vocoder = sys.argv[16]
checkpointing = strtobool(sys.argv[17])
# Hard-coded switch (not read from sys.argv). It is forwarded as `randomized=`
# to a model constructor call in run() (presumably the generator - confirm) and
# gates the slice_segments calls on `wave` and `mel` in train_and_evaluate();
# when False, `y_mel` is set to the full `mel` and `wave` is left unsliced.
randomized=True
# Optimizer selected by name. run() declares this name `global` and rebinds it
# from the string to the matching class (torch.optim.AdamW or torch.optim.RAdam)
# before building optim_g and optim_d.
# NOTE(review): any other string matches neither branch in run(), so the name
# stays a str and the later `optimizer(...)` call would fail - confirm whether
# validation is wanted.
optimizer="RAdam" # "AdamW"

# Per-experiment directory: <current working directory>/logs/<model_name>.
current_dir = os.getcwd()
experiment_dir = os.path.join(current_dir, "logs", model_name)
Expand Down Expand Up @@ -316,7 +318,7 @@ def run(
config (object): Configuration object containing training parameters.
device (torch.device): The device to use for training (CPU or GPU).
"""
global global_step, smoothed_value_gen, smoothed_value_disc
global global_step, smoothed_value_gen, smoothed_value_disc, optimizer

smoothed_value_gen = 0
smoothed_value_disc = 0
Expand Down Expand Up @@ -380,6 +382,7 @@ def run(
sr=sample_rate,
vocoder=vocoder,
checkpointing=checkpointing,
randomized=randomized,
)

net_d = MultiPeriodDiscriminator(
Expand All @@ -393,13 +396,18 @@ def run(
net_g.to(device)
net_d.to(device)

optim_g = torch.optim.AdamW(
if optimizer == "AdamW":
optimizer = torch.optim.AdamW
elif optimizer == "RAdam":
optimizer = torch.optim.RAdam

optim_g = optimizer(
net_g.parameters(),
config.train.learning_rate,
betas=config.train.betas,
eps=config.train.eps,
)
optim_d = torch.optim.AdamW(
optim_d = optimizer(
net_d.parameters(),
config.train.learning_rate,
betas=config.train.betas,
Expand Down Expand Up @@ -631,12 +639,13 @@ def train_and_evaluate(
model_output
)
# slice of the original waveform to match a generate slice
wave = commons.slice_segments(
wave,
ids_slice * config.data.hop_length,
config.train.segment_size,
dim=3,
)
if randomized:
wave = commons.slice_segments(
wave,
ids_slice * config.data.hop_length,
config.train.segment_size,
dim=3,
)
y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
with autocast(enabled=False):
# if vocoder == "HiFi-GAN":
Expand Down Expand Up @@ -751,12 +760,15 @@ def train_and_evaluate(
config.data.mel_fmax,
)
# used for tensorboard chart - slice/mel_org
y_mel = commons.slice_segments(
mel,
ids_slice,
config.train.segment_size // config.data.hop_length,
dim=3,
)
if randomized:
y_mel = commons.slice_segments(
mel,
ids_slice,
config.train.segment_size // config.data.hop_length,
dim=3,
)
else:
y_mel = mel
# used for tensorboard chart - slice/mel_gen
with autocast(enabled=False):
y_hat_mel = mel_spectrogram_torch(
Expand Down

0 comments on commit 6f0b733

Please sign in to comment.