Skip to content

Commit

Permalink
add llm train
Browse files Browse the repository at this point in the history
  • Loading branch information
aluminumbox committed Feb 7, 2025
1 parent 2a3e033 commit 79b7dff
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 3 deletions.
65 changes: 65 additions & 0 deletions cosyvoice/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import Dict, Optional, Callable, List, Generator
import torch
from torch import nn
Expand All @@ -21,6 +22,7 @@
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.common import th_accuracy
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.mask import make_pad_mask


class TransformerLM(torch.nn.Module):
Expand Down Expand Up @@ -226,6 +228,17 @@ def __init__(self, pretrain_path):
super().__init__()
self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)

def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
T = xs.size(1)
masks = ~make_pad_mask(xs_lens, T)
outs = self.model(
inputs_embeds=xs,
attention_mask=masks,
output_hidden_states=True,
return_dict=True,
)
return outs.hidden_states[-1], masks.unsqueeze(1)

def forward_one_step(self, xs, masks, cache=None):
input_masks = masks[:, -1, :]
outs = self.model(
Expand Down Expand Up @@ -280,6 +293,58 @@ def __init__(
self.sampling = sampling
self.mix_ratio = mix_ratio

def pad_unpad_sequence(self, sos_eos_emb, text_token, text_token_len, task_id_emb, speech_token, speech_token_len, bistream):
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
for i in range(len(text_token))]
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
return lm_input, lm_input_len

def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
"""
Args:
text: (B, L, D)
text_lengths: (B,)
audio: (B, T, N) or (B, T)
audio_lengths: (B,)
"""
text_token = batch['text_token'].to(device)
text_token_len = batch['text_token_len'].to(device)
speech_token = batch['speech_token'].to(device)
speech_token_len = batch['speech_token_len'].to(device)

# 1. prepare llm_target
bistream = True if random.random() < 0.5 else False
lm_target = [torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
[self.speech_token_size]) for i in range(text_token.size(0))]
lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)

# 1. encode text_token
text_token = self.llm.model.model.embed_tokens(text_token)

# 3. eos and task_id
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)

# 4. encode speech_token
speech_token = self.speech_embedding(speech_token)

# 5. unpad and pad
lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, text_token, text_token_len, task_id_emb, speech_token, speech_token_len, bistream)

# 6. run lm forward
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
logits = self.llm_decoder(lm_output)
loss = self.criterion_ce(logits, lm_target)
acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
return {'loss': loss, 'acc': acc}

@torch.inference_mode()
def inference(
self,
Expand Down
2 changes: 1 addition & 1 deletion examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ sort: !name:cosyvoice.dataset.processor.sort
sort_size: 500 # sort_size should be less than shuffle_size
batch: !name:cosyvoice.dataset.processor.batch
batch_type: 'dynamic'
max_frames_in_batch: 2500
max_frames_in_batch: 2000
padding: !name:cosyvoice.dataset.processor.padding
use_spk_embedding: False # change to True during sft

Expand Down
4 changes: 2 additions & 2 deletions examples/libritts/cosyvoice2/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ stop_stage=3

data_url=www.openslr.org/resources/60
data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
pretrained_model_dir=/mnt/lyuxiang.lx/data/tts/models/IIC/CosyVoice2-0.5B/
pretrained_model_dir=../../../pretrained_models/CosyVoice2-0.5B

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo "Data Download"
Expand Down Expand Up @@ -86,7 +86,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
# NOTE will update llm/hift training later
for model in flow; do
for model in llm flow; do
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
cosyvoice/bin/train.py \
Expand Down

0 comments on commit 79b7dff

Please sign in to comment.