add sft contents
Spico197 committed Feb 25, 2024
1 parent c6ba9b9 commit 27cf936
Showing 11 changed files with 1,494 additions and 469 deletions.
28 changes: 22 additions & 6 deletions README.md
@@ -3,7 +3,7 @@
<img src="docs/imgs/title-favicon.png" width="200" alt="LLaMA-MoE favicon" style="border-radius: 5%;"><br />
<span style="color:red">📢 <strong><i>A SMALLER AFFORDABLE MoE MODEL FOR EVERYONE!!</i></strong></span>
<div>
<a href="https://huggingface.co/llama-moe" target="_blank">🤗 Model Weights</a> | <a href="#quick-start">🚀 Quick Start</a> | <a href="#installation">⚙️ Installation Guide</a> | <a href="#expert-construction">🚧 Expert Construction</a> | <a href="#continual-pretraining">🚅 Continual Pre-training</a> | <a href="#evaluation">💎 Evaluation</a>
<a href="https://huggingface.co/llama-moe" target="_blank">🤗 Model Weights</a> | <a href="#quick-start">🚀 Quick Start</a> | <a href="#installation">⚙️ Installation Guide</a> | <a href="#expert-construction">🚧 Expert Construction</a> | <a href="#continual-pretraining">🚅 Continual Pre-training</a> | <a href="#evaluation">💎 Evaluation</a> | <a href="#sft">💬 Supervised Fine-Tuning (SFT)</a>
</div>
<a href="docs/LLaMA_MoE.pdf" target="_blank"><strong>📃 Technical Report</strong></a>
</div>
@@ -84,11 +84,13 @@ print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
<h2 id="performance">📊 Model Performance</h2>
| Model | \#Activated Experts | \#Experts | \#Activated Params | Links |
| :------------------------ | :-----------------: | :-------: | :----------------: | :-----------------------------------------------------------------------: |
| **LLaMA-MoE-3.0B** | 2 | 16 | 3.0B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_0B-2_16) |
| **LLaMA-MoE-3.5B (4/16)** | 4 | 16 | 3.5B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-4_16) |
| **LLaMA-MoE-3.5B (2/8)** | 2 | 8 | 3.5B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-2_8) |
| Model | \#Activated Experts | \#Experts | \#Activated Params | Foundation Model | SFT Model |
| :------------------------ | :-----------------: | :-------: | :----------------: | :---------------------------------------------------------------: | :------------------------------------------------------------------: |
| **LLaMA-MoE-3.0B** | 2 | 16 | 3.0B | [🤗 base](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_0B-2_16) | [🤗 SFT](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_0B-2_16-sft) |
| **LLaMA-MoE-3.5B (4/16)** | 4 | 16 | 3.5B | [🤗 base](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-4_16) | [🤗 SFT](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-4_16-sft) |
| **LLaMA-MoE-3.5B (2/8)** | 2 | 8 | 3.5B | [🤗 base](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-2_8) | [🤗 SFT](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-2_8-sft) |
- Foundation models
| Model | Average | SciQ | PIQA | WinoGrande | ARC-e | ARC-c (25) | HellaSwag (10) | LogiQA | BoolQ (32) | LAMBADA | NQ (32) | MMLU (5) |
| :------------------------------------------------------------------------------------ | :------: | :------: | :------: | :--------: | :------: | :--------: | :------------: | :------: | :--------: | :------: | :------: | :------: |
@@ -101,6 +103,15 @@ print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
| **LLaMA-MoE-3.5B (4/16)** | **57.7** | 87.6 | **77.9** | 65.5 | **65.6** | **44.2** | **73.3** | 29.7 | **75.0** | **69.5** | **20.3** | 26.8 |
| **LLaMA-MoE-3.5B (2/8)** | 57.6 | **88.4** | 77.6 | **66.7** | 65.3 | 43.1 | **73.3** | 29.6 | 73.9 | 69.4 | 19.8 | 27.0 |
- SFT models
| Model                                  | MMLU  | ARC-c | HellaSwag | TruthfulQA | MT-Bench |
| :------------------------------------- | :---: | :---: | :-------: | :--------: | :------: |
| Sheared LLaMA-2.7B ShareGPT | 28.41 | 41.04 | 71.21 | 47.65 | 3.79 |
| Sheared LLaMA-2.7B Deita6K (Our Impl.) | 25.24 | 43.69 | 71.70 | 49.00 | 4.06 |
| LLaMA-MoE-v1-3.0B (2/16) | 23.61 | 43.43 | 72.28 | 44.24 | 4.15 |
| LLaMA-MoE-v1-3.5B (4/16) | 26.49 | 48.29 | 75.10 | 45.91 | 4.60 |
| LLaMA-MoE-v1-3.5B (2/8) | 25.53 | 45.99 | 74.95 | 44.39 | 4.72 |
<h2 id="expert-construction">🚧 Expert Construction</h2>
@@ -152,6 +163,11 @@ python -m smoe.utils.tokenize \
- For evaluation on Natural Questions (NQ), please refer to [opencompass](https://github.com/Spico197/opencompass/tree/main).
- For other tasks, please refer to [lm-eval-harness](https://github.com/spico197/smoe-eval).
<h2 id="sft">💬 Supervised Fine-Tuning (SFT)</h2>
We provide simple SFT examples for building chatbots.
Please refer to the [SFT docs](docs/supervised_fine_tuning/SFT.md) and the scripts under `scripts/sft` for more details.
<h2 id="citation">📑 Citation</h2>
```bibtex
Binary file modified docs/LLaMA_MoE.pdf
Binary file not shown.
39 changes: 39 additions & 0 deletions docs/supervised_fine_tuning/SFT.md
@@ -0,0 +1,39 @@
# Supervised Fine-Tuning (SFT)

## Data Preparation

Download [Deita 6K](https://huggingface.co/datasets/hkust-nlp/deita-6k-v0) to `data/deita/deita_6k.jsonl`.
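
If you prefer to fetch the data programmatically, the snippet below is a minimal sketch using the 🤗 `datasets` library (it assumes the default `train` split and simply dumps the records unchanged to the path the SFT scripts expect):

```python
from pathlib import Path

from datasets import load_dataset

# Download Deita 6K from the Hugging Face Hub and write it as JSON lines
# to the location used by the training scripts.
Path("data/deita").mkdir(parents=True, exist_ok=True)
ds = load_dataset("hkust-nlp/deita-6k-v0", split="train")
ds.to_json("data/deita/deita_6k.jsonl", force_ascii=False)
```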

## Training

Start training on a Slurm cluster: `sbatch scripts/sft/2_8.sh` (see `scripts/sft` for the 2/16 and 4/16 variants).

## Inference

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# The Conversation prompt helper is assumed to live in this repo's `smoe` package.
from smoe.utils.conversation import Conversation

# Build a single-turn prompt: one "human" message plus an empty "gpt" slot.
conv = Conversation()
conv.append_message("human", "Give me a three-day plan in Suzhou.")
conv.append_message("gpt", None)
prompt = conv.get_prompt()
print(prompt)
print(prompt[-1] == " ")  # sanity check: does the template end with a trailing space?

model_dir = "llama-moe/LLaMA-MoE-v1-3_5B-2_8-sft"

tok = AutoTokenizer.from_pretrained(model_dir)
m = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True)
m.eval()
m.cuda()

inputs = tok(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].cuda()

# Sample a short continuation (max_length counts the prompt tokens as well).
output = m.generate(input_ids, max_length=100, temperature=1.0, do_sample=True, use_cache=True)
response = tok.decode(output[0], skip_special_tokens=True)
print(response)
```
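
For a second turn, one simple pattern (a sketch, not part of the released scripts) is to rebuild the conversation with the model's reply and the next user message appended; this assumes the `Conversation` helper accepts a completed `gpt` message the same way it accepts a `human` one:

```python
# Keep only the newly generated tokens as the assistant's reply.
reply = tok.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)

# Rebuild the history with the reply and a follow-up question appended.
conv2 = Conversation()
conv2.append_message("human", "Give me a three-day plan in Suzhou.")
conv2.append_message("gpt", reply)  # assumption: completed gpt turns are accepted here
conv2.append_message("human", "Make day two focus on classical gardens.")
conv2.append_message("gpt", None)

next_ids = tok(conv2.get_prompt(), return_tensors="pt")["input_ids"].cuda()
next_out = m.generate(next_ids, max_new_tokens=256, temperature=1.0, do_sample=True, use_cache=True)
print(tok.decode(next_out[0][next_ids.shape[1]:], skip_special_tokens=True))
```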
80 changes: 80 additions & 0 deletions scripts/sft/2_16.sh
@@ -0,0 +1,80 @@
#!/usr/bin/bash

#SBATCH --job-name=llama_moe_2_16_deita
#SBATCH --output=logs/%x-%j.log
#SBATCH --error=logs/%x-%j.log

#SBATCH --partition=MoE
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=16
#SBATCH --mem=64G

#SBATCH --nodes=1
#SBATCH --gres=gpu:4
#SBATCH --quotatype=auto

export WANDB_PROJECT="llama_moe_sft"
num_gpus=4

{
task_name="llama_moe_2_16_deita"
model_type="auto"
model_name_or_path="/mnt/petrelfs/zhutong/llama-moe-models/LLaMA-MoE-v1-3_0B-2_16"
dataset_dir_or_path="data/deita/deita_6k.jsonl"

comment="llama-moe 2/16, deita, w/ balance loss, w/ freeze gate, w/ gate noise"
base_dir="outputs/llama_moe_sft"
output_dir="${base_dir}/${task_name}/$SLURM_JOB_NAME-$SLURM_JOB_ID"
mkdir -p $output_dir
scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh
git diff > $output_dir/diff.patch
env > $output_dir/env
echo -e "Job ID: ${SLURM_JOB_ID}\n\nLog: logs/llama_moe_2_16_deita-$SLURM_JOB_ID.log\n\nGit commit: $(git log -1 --oneline)\n\nGit branch: $(git branch | grep "*")\n\nComment: ${comment}" > $output_dir/comment.txt
ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $output_dir/log.log
echo "$SLURM_JOB_ID" > $base_dir/latest.jobid
ln -snf $output_dir $base_dir/latest.dir
ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log

nodes=($(scontrol show hostnames $SLURM_JOB_NODELIST))
nodes_array=($nodes)
head_node=${nodes_array[0]}
echo "Node: $head_node"

torchrun \
--nnodes 1 \
--nproc_per_node $num_gpus \
--node_rank $SLURM_NODEID \
--rdzv_id $RANDOM \
--rdzv_backend c10d \
--rdzv_endpoint $head_node:29522 \
-m smoe.entrypoint.sft.train_sft \
--do_train \
--freeze_gate True \
--evaluation_strategy no \
--run_name $task_name \
--model_type $model_type \
--model_name_or_path $model_name_or_path \
--dataset_dir_or_path $dataset_dir_or_path \
--output_dir $output_dir \
--deepspeed conf/ds_bf16_zero1.json \
--seed 12306 \
--bf16 True \
--tf32 True \
--torch_dtype bfloat16 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 8 \
--num_train_epochs 2 \
--save_strategy steps \
--save_steps 9999999999999 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type cosine \
--logging_steps 1 \
--model_max_length 2048 \
--gradient_checkpointing True \
--report_to wandb

}
80 changes: 80 additions & 0 deletions scripts/sft/2_8.sh
@@ -0,0 +1,80 @@
#!/usr/bin/bash

#SBATCH --job-name=llama_moe_2_8_deita
#SBATCH --output=logs/%x-%j.log
#SBATCH --error=logs/%x-%j.log

#SBATCH --partition=MoE
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=16
#SBATCH --mem=64G

#SBATCH --nodes=1
#SBATCH --gres=gpu:4
#SBATCH --quotatype=auto

export WANDB_PROJECT="llama_moe_sft"
num_gpus=4

{
task_name="llama_moe_2_8_deita"
model_type="auto"
model_name_or_path="/mnt/petrelfs/zhutong/llama-moe-models/LLaMA-MoE-v1-3_5B-2_8-new"
dataset_dir_or_path="data/deita/deita_6k.jsonl"

comment="llama-moe 2/8, deita, w/ balance loss, w/ freeze gate, w/ gate noise"
base_dir="outputs/llama_moe_sft"
output_dir="${base_dir}/${task_name}/$SLURM_JOB_NAME-$SLURM_JOB_ID"
mkdir -p $output_dir
scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh
git diff > $output_dir/diff.patch
env > $output_dir/env
echo -e "Job ID: ${SLURM_JOB_ID}\n\nLog: logs/llama_moe_2_8_deita-$SLURM_JOB_ID.log\n\nGit commit: $(git log -1 --oneline)\n\nGit branch: $(git branch | grep "*")\n\nComment: ${comment}" > $output_dir/comment.txt
ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $output_dir/log.log
echo "$SLURM_JOB_ID" > $base_dir/latest.jobid
ln -snf $output_dir $base_dir/latest.dir
ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log

nodes=($(scontrol show hostnames $SLURM_JOB_NODELIST))
nodes_array=($nodes)
head_node=${nodes_array[0]}
echo "Node: $head_node"

torchrun \
--nnodes 1 \
--nproc_per_node $num_gpus \
--node_rank $SLURM_NODEID \
--rdzv_id $RANDOM \
--rdzv_backend c10d \
--rdzv_endpoint $head_node:29522 \
-m smoe.entrypoint.sft.train_sft \
--do_train \
--freeze_gate True \
--evaluation_strategy no \
--run_name $task_name \
--model_type $model_type \
--model_name_or_path $model_name_or_path \
--dataset_dir_or_path $dataset_dir_or_path \
--output_dir $output_dir \
--deepspeed conf/deepspeed/bf16_zero1.json \
--seed 12306 \
--bf16 True \
--tf32 True \
--torch_dtype bfloat16 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 8 \
--num_train_epochs 2 \
--save_strategy steps \
--save_steps 9999999999999 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type cosine \
--logging_steps 1 \
--model_max_length 2048 \
--gradient_checkpointing True \
--report_to wandb

}
80 changes: 80 additions & 0 deletions scripts/sft/4_16.sh
@@ -0,0 +1,80 @@
#!/usr/bin/bash

#SBATCH --job-name=llama_moe_4_16_deita
#SBATCH --output=logs/%x-%j.log
#SBATCH --error=logs/%x-%j.log

#SBATCH --partition=MoE
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=16
#SBATCH --mem=64G

#SBATCH --nodes=1
#SBATCH --gres=gpu:4
#SBATCH --quotatype=auto

export WANDB_PROJECT="llama_moe_sft"
num_gpus=4

{
task_name="llama_moe_4_16_deita"
model_type="auto"
model_name_or_path="/mnt/petrelfs/zhutong/llama-moe-models/LLaMA-MoE-v1-3_5B-4_16-new"
dataset_dir_or_path="data/deita/deita_6k.jsonl"

comment="llama-moe 4/16, deita, w/ balance loss, w/ freeze gate, w/ gate noise"
base_dir="outputs/llama_moe_sft"
output_dir="${base_dir}/${task_name}/$SLURM_JOB_NAME-$SLURM_JOB_ID"
mkdir -p $output_dir
scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh
git diff > $output_dir/diff.patch
env > $output_dir/env
echo -e "Job ID: ${SLURM_JOB_ID}\n\nLog: logs/llama_moe_4_16_deita-$SLURM_JOB_ID.log\n\nGit commit: $(git log -1 --oneline)\n\nGit branch: $(git branch | grep "*")\n\nComment: ${comment}" > $output_dir/comment.txt
ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $output_dir/log.log
echo "$SLURM_JOB_ID" > $base_dir/latest.jobid
ln -snf $output_dir $base_dir/latest.dir
ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log

nodes=($(scontrol show hostnames $SLURM_JOB_NODELIST))
nodes_array=($nodes)
head_node=${nodes_array[0]}
echo "Node: $head_node"

torchrun \
--nnodes 1 \
--nproc_per_node $num_gpus \
--node_rank $SLURM_NODEID \
--rdzv_id $RANDOM \
--rdzv_backend c10d \
--rdzv_endpoint $head_node:29522 \
-m smoe.entrypoint.sft.train_sft \
--do_train \
--freeze_gate True \
--evaluation_strategy no \
--run_name $task_name \
--model_type $model_type \
--model_name_or_path $model_name_or_path \
--dataset_dir_or_path $dataset_dir_or_path \
--output_dir $output_dir \
--deepspeed conf/ds_bf16_zero1.json \
--seed 12306 \
--bf16 True \
--tf32 True \
--torch_dtype bfloat16 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 8 \
--num_train_epochs 2 \
--save_strategy steps \
--save_steps 9999999999999 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type cosine \
--logging_steps 1 \
--model_max_length 2048 \
--gradient_checkpointing True \
--report_to wandb

}
Empty file added smoe/entrypoint/sft/__init__.py
Empty file.