Support instruction tuning #15

Merged
merged 3 commits on Feb 9, 2024
Support instruction tuning with SFTTrainer
jubgjf committed Feb 8, 2024
commit 76e5e3196442ec6a9a9cc1788faf52b60c34a14e
36 changes: 33 additions & 3 deletions README.md
@@ -233,10 +233,10 @@ TRAIN_DATASETS=(
If you use the SLURM cluster management system, you can submit the job with `sbatch`:

```shell
$ sbatch scripts/train.sh
$ sbatch scripts/train-pt.sh
```

If you do not have SLURM or prefer to launch training from the command line, you can extract the `torchrun` command from `scripts/train.sh` and start training directly.
If you do not have SLURM or prefer to launch training from the command line, you can extract the `torchrun` command from `scripts/train-pt.sh` and start training directly.

</details>

@@ -247,7 +247,37 @@ $ sbatch scripts/train.sh

</summary>

The Chinese-Mixtral-8x7B released by this project is a base model and has not been fine-tuned. If you want to fine-tune Chinese-Mixtral-8x7B for downstream tasks or SFT, you can refer to the QLoRA fine-tuning script that HuggingFace provides for Mixtral-8x7B: [HuggingFace's official example code](https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py).
#### Dataset Preparation

The dataset format required for fine-tuning is similar to that of pre-training: the dataset file must be in jsonl format, with one JSON object per line. Each object must contain a `"text"` field in which the instruction, input, and output have all been concatenated according to the template you need.
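
For illustration, here is a minimal sketch of building one such record and appending it to a jsonl file. The sample fields, chat template, and output path are assumptions made for the example; only the `"text"` key is required. The `<|beginofutterance|>` / `<|endofutterance|>` markers are the special tokens added by `tokenizer/add_chatml_tokens.py` in this commit, but any concatenation scheme works.

```python
import json
import os

# Illustrative raw example; your source data may use any schema.
sample = {
    "instruction": "Translate the sentence into English.",
    "input": "今天天气很好。",
    "output": "The weather is nice today.",
}

# One possible template: a ChatML-style layout using the tokens added by
# tokenizer/add_chatml_tokens.py. The final string goes into the "text" field.
text = (
    f"<|beginofutterance|>user\n{sample['instruction']}\n{sample['input']}<|endofutterance|>\n"
    f"<|beginofutterance|>assistant\n{sample['output']}<|endofutterance|>"
)

# Output path is an assumption for the example; register the real path in data/datasets.toml below.
os.makedirs("data/sft/MyDataset", exist_ok=True)
with open("data/sft/MyDataset/MyDataset-train.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps({"text": text}, ensure_ascii=False) + "\n")
```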

Then register the dataset name and path in `data/datasets.toml`; the sketch after the block shows how the placeholders resolve to a file path:

```toml
[ShareGPT-Chinese] # dataset name
splits = ["train"] # dataset train/valid splits
root = "{DATA_DIR}/sft/{name}" # dataset root directory
doc = "{name}-{split}" # dataset file name
```
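
Reading off the entry above, the placeholders presumably resolve along these lines (a naming-convention sketch only; the value of `{DATA_DIR}` and the file extension are assumptions):

```python
DATA_DIR = "data"                          # assumed value of {DATA_DIR}
name, split = "ShareGPT-Chinese", "train"

root = f"{DATA_DIR}/sft/{name}"            # data/sft/ShareGPT-Chinese
doc = f"{name}-{split}"                    # ShareGPT-Chinese-train
path = f"{root}/{doc}.jsonl"               # assumed .jsonl suffix, matching the format described above
print(path)                                # data/sft/ShareGPT-Chinese/ShareGPT-Chinese-train.jsonl
```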

#### Start Training

The training launch script is `scripts/train-sft.sh`. You can change the training datasets and their sampling ratios by editing `TRAIN_DATASETS` in the script:

```shell
TRAIN_DATASETS=(
    1.0:ShareGPT-Chinese # use all of ShareGPT-Chinese
    0.5:ShareGPT-English # use 50% of the ShareGPT-English data
)
```

If you use the SLURM cluster management system, you can submit the job with `sbatch`:

```shell
$ sbatch scripts/train-sft.sh
```

If you do not have SLURM or prefer to launch training from the command line, you can extract the `torchrun` command from `scripts/train-sft.sh` and start training directly.

</details>

Empty file added data/sft/.gitkeep
Empty file.
16 changes: 0 additions & 16 deletions scripts/preprocess_datasets.sh

This file was deleted.

File renamed without changes.
68 changes: 68 additions & 0 deletions scripts/train-sft.sh
@@ -0,0 +1,68 @@
#!/usr/bin/bash

#SBATCH -J mixtral
#SBATCH -o logs/%j.log
#SBATCH -e logs/%j.err
#SBATCH -p gpu
#SBATCH -N 4
#SBATCH --ntasks-per-node=1
#SBATCH -c 64
#SBATCH --mem=800G
#SBATCH --gres=gpu:8

. "$HOME"/.miniconda/etc/profile.d/conda.sh
conda activate vocab-ext

# Training sets as ratio:name pairs; names must be registered in data/datasets.toml
TRAIN_DATASETS=(
    1:your-dataset-1
    1:your-dataset-2
)

# Validation sets, listed by name only
VALID_DATASETS=(
    your-val-dataset-1
    your-val-dataset-2
)

TRAIN_PARAMS=""
TRAIN_PARAMS+=" --mode sft"
TRAIN_PARAMS+=" --enable_lora"
TRAIN_PARAMS+=" --lora_alpha 128"
TRAIN_PARAMS+=" --lora_dropout 0.05"
TRAIN_PARAMS+=" --lora_rank 64"
TRAIN_PARAMS+=" --lora_target_modules q_proj v_proj k_proj o_proj w1 w2 w3"
TRAIN_PARAMS+=" --lora_modules_to_save embed_tokens lm_head"
TRAIN_PARAMS+=" --model_name_or_path models/sft-model"
TRAIN_PARAMS+=" --tokenizer_name_or_path tokenizer/sft-tokenizer"
TRAIN_PARAMS+=" --neftune_noise_alpha 5"
TRAIN_PARAMS+=" --train_datasets ${TRAIN_DATASETS[*]}"
TRAIN_PARAMS+=" --valid_datasets ${VALID_DATASETS[*]}"
TRAIN_PARAMS+=" --dataloader_drop_last"
TRAIN_PARAMS+=" --cache_dir hf-cache"
TRAIN_PARAMS+=" --output_dir outputs/$SLURM_JOB_ID"
TRAIN_PARAMS+=" --num_train_epochs 1"
TRAIN_PARAMS+=" --model_max_length 32768"
TRAIN_PARAMS+=" --per_device_train_batch_size 1"
TRAIN_PARAMS+=" --gradient_accumulation_steps 20"
TRAIN_PARAMS+=" --optim adamw_torch_fused"
TRAIN_PARAMS+=" --per_device_eval_batch_size 1"
TRAIN_PARAMS+=" --evaluation_strategy steps"
TRAIN_PARAMS+=" --eval_steps 500"
TRAIN_PARAMS+=" --save_strategy steps"
TRAIN_PARAMS+=" --save_steps 1000"
TRAIN_PARAMS+=" --learning_rate 1e-5"
TRAIN_PARAMS+=" --warmup_ratio 0.05"
TRAIN_PARAMS+=" --logging_dir logs/tb/$SLURM_JOB_ID"
TRAIN_PARAMS+=" --logging_strategy steps"
TRAIN_PARAMS+=" --logging_steps 1"
TRAIN_PARAMS+=" --lr_scheduler_type cosine"
TRAIN_PARAMS+=" --report_to tensorboard"
TRAIN_PARAMS+=" --gradient_checkpointing"
TRAIN_PARAMS+=" --bf16"
TRAIN_PARAMS+=" --deepspeed ds-config/config.json"

# Single quotes keep $SLURM_NODEID and the rendezvous variables unexpanded until each
# srun task's shell runs, so every node gets its own node_rank while sharing the same
# master address/port (set just below inside the srun command).
TORCHRUN_PARAMS='--nproc_per_node 8 --node_rank=$SLURM_NODEID --nnodes=$SLURM_JOB_NUM_NODES --rdzv_id=0 --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT'
srun --label --export=ALL bash -c "
MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1);
MASTER_PORT=55511;
torchrun $TORCHRUN_PARAMS train.py $TRAIN_PARAMS
"
27 changes: 27 additions & 0 deletions tokenizer/add_chatml_tokens.py
@@ -0,0 +1,27 @@
import fire
from transformers import AutoTokenizer


def main(
    tokenizer_name_or_path: str,
    tokenizer_save_path: str,
    additional_tokens: str = "<|beginofutterance|> <|endofutterance|>",
    sequence_length: int = 2048,
    cache_dir: str = "./hf-cache",
):
    # Space-separated token string -> list of new special tokens.
    additional_tokens = additional_tokens.split(" ")

    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name_or_path,
        cache_dir=cache_dir,
        model_max_length=sequence_length,
        padding_side="right",
        use_fast=True,
    )
    # Register the ChatML-style markers as additional special tokens so they are
    # kept as single ids and never split during tokenization.
    tokenizer.add_special_tokens(special_tokens_dict={"additional_special_tokens": additional_tokens})

    tokenizer.save_pretrained(tokenizer_save_path)


if __name__ == "__main__":
    fire.Fire(main)
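
After running the script, a quick sanity check might look like this (the load path is an assumption, chosen to match the `tokenizer/sft-tokenizer` path referenced in `scripts/train-sft.sh`):

```python
from transformers import AutoTokenizer

# Assumed save path; use whatever --tokenizer_save_path was passed to the script above.
tok = AutoTokenizer.from_pretrained("tokenizer/sft-tokenizer")

print(tok.additional_special_tokens)
# expected: ['<|beginofutterance|>', '<|endofutterance|>']

# Each marker should map to a single id and never be split during tokenization.
print(tok.convert_tokens_to_ids("<|beginofutterance|>"))
```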