Skip to content

Commit

Permalink
Use spk_embedding during SFT
Browse files Browse the repository at this point in the history
  • Loading branch information
aluminumbox committed Jul 10, 2024
1 parent a723ea3 commit 0fd15bb
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 2 deletions.
2 changes: 1 addition & 1 deletion cosyvoice/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def forward(
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device)
feat_len = batch['speech_feat_len'].to(device)
embedding = batch['utt_embedding'].to(device)
embedding = batch['embedding'].to(device)

# xvec projection
embedding = F.normalize(embedding, dim=1)
Expand Down
2 changes: 1 addition & 1 deletion cosyvoice/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def forward(
text_token_len = batch['text_token_len'].to(device)
speech_token = batch['speech_token'].to(device)
speech_token_len = batch['speech_token_len'].to(device)
embedding = batch['utt_embedding'].to(device)
embedding = batch['embedding'].to(device)

# 1. prepare llm_target
lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + [self.speech_token_size]) for i in range(text_token.size(0))]
Expand Down
4 changes: 4 additions & 0 deletions cosyvoice/utils/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data
info_dict["batch_idx"] = batch_idx
if cosyvoice_join(group_join, info_dict):
break
if info_dict["use_spk_embedding"] is True:
batch_dict["embedding"] = batch_dict["spk_embedding"]
else:
batch_dict["embedding"] = batch_dict["utt_embedding"]

# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ train_conf:
scheduler: warmuplr
scheduler_conf:
warmup_steps: 25000
use_spk_embedding: False # change to True during sft
max_epoch: 200
grad_clip: 5
accum_grad: 2
Expand Down
1 change: 1 addition & 0 deletions examples/libritts/cosyvoice/conf/cosyvoice.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ train_conf:
scheduler: warmuplr # change to constantlr during sft
scheduler_conf:
warmup_steps: 2500
use_spk_embedding: False # change to True during sft
max_epoch: 200
grad_clip: 5
accum_grad: 2
Expand Down

0 comments on commit 0fd15bb

Please sign in to comment.