Skip to content

Commit

Permalink
Merge pull request #862 from IAHispano/formatter/main
Browse files Browse the repository at this point in the history
chore(format): run black on main
  • Loading branch information
blaisewf authored Nov 4, 2024
2 parents b5b5b73 + 6427ff3 commit eb43ec7
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions rvc/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ def verify_checkpoint_shapes(checkpoint_path, model):
else:
model_state_dict = model.load_state_dict(checkpoint_state_dict)
except RuntimeError:
print("The parameters of the pretrain model such as the sample rate or architecture do not match the selected model.")
print(
"The parameters of the pretrain model such as the sample rate or architecture do not match the selected model."
)
sys.exit(1)
else:
del checkpoint
Expand Down Expand Up @@ -324,8 +326,8 @@ def run(
DistributedBucketSampler,
TextAudioCollateMultiNSFsid,
TextAudioLoaderMultiNSFsid,
)
)

train_dataset = TextAudioLoaderMultiNSFsid(config.data)
collate_fn = TextAudioCollateMultiNSFsid()
train_sampler = DistributedBucketSampler(
Expand All @@ -351,8 +353,8 @@ def run(
# Initialize models and optimizers
from rvc.lib.algorithm.discriminators import MultiPeriodDiscriminator
from rvc.lib.algorithm.discriminators import MultiPeriodDiscriminatorV2
from rvc.lib.algorithm.synthesizers import Synthesizer
from rvc.lib.algorithm.synthesizers import Synthesizer

net_g = Synthesizer(
config.data.filter_length // 2 + 1,
config.train.segment_size // config.data.hop_length,
Expand Down Expand Up @@ -468,6 +470,7 @@ def run(
scheduler_g.step()
scheduler_d.step()


def train_and_evaluate(
rank,
epoch,
Expand Down Expand Up @@ -537,17 +540,31 @@ def train_and_evaluate(
info = [tensor.cuda(rank, non_blocking=True) for tensor in info]
elif device.type != "cuda":
info = [tensor.to(device) for tensor in info]
# else iterator is going thru a cached list with a device already assigned

phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, wave_lengths, sid = info
# else iterator is going thru a cached list with a device already assigned

(
phone,
phone_lengths,
pitch,
pitchf,
spec,
spec_lengths,
wave,
wave_lengths,
sid,
) = info
pitch = pitch if pitch_guidance else None
pitchf = pitchf if pitch_guidance else None

# Forward pass
use_amp = config.train.fp16_run and device.type == "cuda"
with autocast(enabled=use_amp):
model_output = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = model_output
model_output = net_g(
phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid
)
y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = (
model_output
)
# used for tensorboard chart - all/mel
mel = spec_to_mel_torch(
spec,
Expand Down

0 comments on commit eb43ec7

Please sign in to comment.