Skip to content

Commit

Permalink
Merge pull request #861 from AznamirWoW/training_fix
Browse files Browse the repository at this point in the history
minor bugfixes and code cleanup
  • Loading branch information
blaisewf authored Nov 4, 2024
2 parents eec1a83 + 730daf5 commit b5b5b73
Showing 1 changed file with 48 additions and 122 deletions.
170 changes: 48 additions & 122 deletions rvc/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,6 @@
load_wav_to_torch,
)

from data_utils import (
DistributedBucketSampler,
TextAudioCollateMultiNSFsid,
TextAudioLoaderMultiNSFsid,
)

from losses import (
discriminator_loss,
feature_loss,
Expand All @@ -54,9 +48,6 @@
from rvc.train.process.extract_model import extract_model

from rvc.lib.algorithm import commons
from rvc.lib.algorithm.discriminators import MultiPeriodDiscriminator
from rvc.lib.algorithm.discriminators import MultiPeriodDiscriminatorV2
from rvc.lib.algorithm.synthesizers import Synthesizer

# Parse command line arguments
model_name = sys.argv[1]
Expand Down Expand Up @@ -134,7 +125,7 @@ def verify_checkpoint_shapes(checkpoint_path, model):
else:
model_state_dict = model.load_state_dict(checkpoint_state_dict)
except RuntimeError:
print("The sample rate of the pretrain doesn't match the selected one")
print("The parameters of the pretrain model such as the sample rate or architecture do not match the selected model.")
sys.exit(1)
else:
del checkpoint
Expand Down Expand Up @@ -313,6 +304,8 @@ def run(
if rank == 0:
writer = SummaryWriter(log_dir=experiment_dir)
writer_eval = SummaryWriter(log_dir=os.path.join(experiment_dir, "eval"))
else:
writer, writer_eval = None, None

dist.init_process_group(
backend="gloo",
Expand All @@ -327,6 +320,12 @@ def run(
torch.cuda.set_device(rank)

# Create datasets and dataloaders
from data_utils import (
DistributedBucketSampler,
TextAudioCollateMultiNSFsid,
TextAudioLoaderMultiNSFsid,
)

train_dataset = TextAudioLoaderMultiNSFsid(config.data)
collate_fn = TextAudioCollateMultiNSFsid()
train_sampler = DistributedBucketSampler(
Expand All @@ -350,6 +349,10 @@ def run(
)

# Initialize models and optimizers
from rvc.lib.algorithm.discriminators import MultiPeriodDiscriminator
from rvc.lib.algorithm.discriminators import MultiPeriodDiscriminatorV2
from rvc.lib.algorithm.synthesizers import Synthesizer

net_g = Synthesizer(
config.data.filter_length // 2 + 1,
config.train.segment_size // config.data.hop_length,
Expand Down Expand Up @@ -430,9 +433,6 @@ def run(
optim_d, gamma=config.train.lr_decay, last_epoch=epoch_str - 2
)

optim_g.step()
optim_d.step()

scaler = GradScaler(enabled=config.train.fp16_run and device.type == "cuda")

cache = []
Expand All @@ -449,42 +449,25 @@ def run(
break

for epoch in range(epoch_str, total_epoch + 1):
if rank == 0:
train_and_evaluate(
rank,
epoch,
config,
[net_g, net_d],
[optim_g, optim_d],
scaler,
[train_loader, None],
[writer, writer_eval],
cache,
custom_save_every_weights,
custom_total_epoch,
device,
reference,
)
else:
train_and_evaluate(
rank,
epoch,
config,
[net_g, net_d],
[optim_g, optim_d],
scaler,
[train_loader, None],
None,
cache,
custom_save_every_weights,
custom_total_epoch,
device,
reference,
)
train_and_evaluate(
rank,
epoch,
config,
[net_g, net_d],
[optim_g, optim_d],
scaler,
[train_loader, None],
[writer, writer_eval],
cache,
custom_save_every_weights,
custom_total_epoch,
device,
reference,
)

scheduler_g.step()
scheduler_d.step()


def train_and_evaluate(
rank,
epoch,
Expand Down Expand Up @@ -539,41 +522,9 @@ def train_and_evaluate(
data_iterator = cache
if cache == []:
for batch_idx, info in enumerate(train_loader):
(
phone,
phone_lengths,
pitch,
pitchf,
spec,
spec_lengths,
wave,
wave_lengths,
sid,
) = info
cache.append(
(
batch_idx,
(
phone.cuda(rank, non_blocking=True),
phone_lengths.cuda(rank, non_blocking=True),
(
pitch.cuda(rank, non_blocking=True)
if pitch_guidance
else None
),
(
pitchf.cuda(rank, non_blocking=True)
if pitch_guidance
else None
),
spec.cuda(rank, non_blocking=True),
spec_lengths.cuda(rank, non_blocking=True),
wave.cuda(rank, non_blocking=True),
wave_lengths.cuda(rank, non_blocking=True),
sid.cuda(rank, non_blocking=True),
),
)
)
# phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, wave_lengths, sid
info = [tensor.cuda(rank, non_blocking=True) for tensor in info]
cache.append((batch_idx, info))
else:
shuffle(cache)
else:
Expand All @@ -582,50 +533,22 @@ def train_and_evaluate(
epoch_recorder = EpochRecorder()
with tqdm(total=len(train_loader), leave=False) as pbar:
for batch_idx, info in data_iterator:
(
phone,
phone_lengths,
pitch,
pitchf,
spec,
spec_lengths,
wave,
wave_lengths,
sid,
) = info
if device.type == "cuda" and not cache_data_in_gpu:
phone = phone.cuda(rank, non_blocking=True)
phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
pitch = pitch.cuda(rank, non_blocking=True) if pitch_guidance else None
pitchf = (
pitchf.cuda(rank, non_blocking=True) if pitch_guidance else None
)
sid = sid.cuda(rank, non_blocking=True)
spec = spec.cuda(rank, non_blocking=True)
spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
wave = wave.cuda(rank, non_blocking=True)
wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
else:
phone = phone.to(device)
phone_lengths = phone_lengths.to(device)
pitch = pitch.to(device) if pitch_guidance else None
pitchf = pitchf.to(device) if pitch_guidance else None
sid = sid.to(device)
spec = spec.to(device)
spec_lengths = spec_lengths.to(device)
wave = wave.to(device)
wave_lengths = wave_lengths.to(device)
info = [tensor.cuda(rank, non_blocking=True) for tensor in info]
elif device.type != "cuda":
info = [tensor.to(device) for tensor in info]
# else: the iterator is going through a cached list whose tensors already have a device assigned

phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, wave_lengths, sid = info
pitch = pitch if pitch_guidance else None
pitchf = pitchf if pitch_guidance else None

# Forward pass
use_amp = config.train.fp16_run and device.type == "cuda"
with autocast(enabled=use_amp):
(
y_hat,
ids_slice,
x_mask,
z_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
model_output = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = model_output
# used for tensorboard chart - all/mel
mel = spec_to_mel_torch(
spec,
config.data.filter_length,
Expand All @@ -634,12 +557,14 @@ def train_and_evaluate(
config.data.mel_fmin,
config.data.mel_fmax,
)
# used for tensorboard chart - slice/mel_org
y_mel = commons.slice_segments(
mel,
ids_slice,
config.train.segment_size // config.data.hop_length,
dim=3,
)
# used for tensorboard chart - slice/mel_gen
with autocast(enabled=False):
y_hat_mel = mel_spectrogram_torch(
y_hat.float().squeeze(1),
Expand All @@ -653,6 +578,7 @@ def train_and_evaluate(
)
if use_amp:
y_hat_mel = y_hat_mel.half()
# slice of the original waveform to match the generated slice
wave = commons.slice_segments(
wave,
ids_slice * config.data.hop_length,
Expand Down Expand Up @@ -714,8 +640,8 @@ def train_and_evaluate(
"loss/g/total": loss_gen_all,
"loss/d/total": loss_disc,
"learning_rate": lr,
"grad_norm_d": grad_norm_d,
"grad_norm_g": grad_norm_g,
"grad/norm_d": grad_norm_d,
"grad/norm_g": grad_norm_g,
"loss/g/fm": loss_fm,
"loss/g/mel": loss_mel,
"loss/g/kl": loss_kl,
Expand Down

0 comments on commit b5b5b73

Please sign in to comment.