Commit 338ee9e: update mixtral support
Spico197 committed Dec 15, 2023 · 1 parent 9bb6ec6
Showing 16 changed files with 3,437 additions and 12 deletions.
9 changes: 8 additions & 1 deletion .vscode/launch.json
@@ -4,6 +4,13 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "tokenize",
"type": "python",
"request": "launch",
"module": "smoe.utils.tokenize",
"justMyCode": true
},
{
"name": "Python: Remote Attach",
"type": "python",
@@ -21,4 +28,4 @@
"justMyCode": false
}
]
}
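For reference, the new "tokenize" launch entry corresponds to running the module directly from the shell. A minimal sketch with placeholder paths, reusing the flags that scripts/tokenize/slimpajama_convert.sh passes below:

python -m smoe.utils.tokenize \
    -f jsonl \
    -c input_ids \
    -s /path/to/source_tokenizer \
    -t /path/to/target_tokenizer \
    -i /path/to/input_parts \
    -o /path/to/output_parts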
1 change: 1 addition & 0 deletions requirements.txt
@@ -38,3 +38,4 @@ numpy==1.25.0
opencv-python==4.8.1.78
pynvml==11.5.0
PyYaml==6.0.1
pandas<2.1.0
149 changes: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
#!/usr/bin/bash

#SBATCH --job-name=mxitral_random_split_112gpus_8_2
#SBATCH --output=/mnt/petrelfs/share_data/zhutong/runs/mxitral_random_split_112gpus_8_2/%x-%j.log
#SBATCH --error=/mnt/petrelfs/share_data/zhutong/runs/mxitral_random_split_112gpus_8_2/%x-%j.log

#SBATCH --partition=MoE_T
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=64
#SBATCH --mem=0

#SBATCH --nodes=2
#SBATCH --gres=gpu:8
#SBATCH --quotatype=reserved

# quotatype choices: reserved, spot

source ~/anaconda3/bin/activate smoe

{
num_nodes=2 # should match with --nodes
num_gpu_per_node=8 # should match with --gres

# #cpu/#num_gpu_per_node
export OMP_NUM_THREADS=32
export LOGLEVEL=INFO
# export NCCL_DEBUG=INFO
# export TORCH_DISTRIBUTED_DEBUG=DETAIL
# export TORCH_SHOW_CPP_STACKTRACES=1
# export CUDA_LAUNCH_BLOCKING=1

model_type="mixtral"
comment="mistral 7B, random 2/8, sheared llama data portion"
pretrained_model=/mnt/hwfile/share_data/zhutong/models/Mixtral-8x7B-v0.1-Random-8Select2
tokenizer_path=/mnt/hwfile/share_data/zhutong/models/Mixtral-8x7B-v0.1-Random-8Select2
dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama-fluency-processed-agg
validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized

lr=2e-4
final_lr_portion=0.1
per_device_train_batch_size=8
per_device_eval_batch_size=8
gradient_accumulation_steps=4
block_size=4096
num_tokens="200*10^9"
warmup_tokens="15*10^8"
# warmup_tokens="0"
eval_tokens="2.5*10^9"
seed=1227
deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json

num_selects=2
scale_factor=4.0

max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc)
max_train_samples=$(echo "${num_tokens} / ($block_size)" | bc)
echo "max_steps: $max_steps"
echo "max_train_samples: $max_train_samples"
global_bs=$(echo "$per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node" | bc)
echo "global batch size: $global_bs"
tokens_per_batch=$(echo "$global_bs * $block_size" | bc)
echo "#tokens/batch: $tokens_per_batch"
# warmup_steps=$(echo "$warmup_tokens / ($tokens_per_batch)" | bc)
warmup_steps=100
echo "warmup tokens: $warmup_tokens, warmup steps: $warmup_steps"
# eval_steps=$(echo "$eval_tokens / ($tokens_per_batch)" | bc)
eval_steps=340
echo "eval interval (tokens): $eval_tokens, steps: $eval_steps"

data_cache=resources/cache
base_dir="/mnt/petrelfs/share_data/zhutong/runs/mxitral_random_split_112gpus_8_2"
output_dir=$base_dir/outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID
mkdir -p $output_dir
echo "output_dir: $output_dir"
scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh
git diff > $output_dir/diff.patch
env > $output_dir/env
echo -e "Job ID: ${SLURM_JOB_ID}\n\nGit commit: $(git log -1 --oneline)\n\nGit branch: $(git branch | grep "*")\n\nComment: ${comment}" > $output_dir/comment.txt
echo "$SLURM_JOB_ID" > $base_dir/latest.jobid
ln -snf $output_dir $base_dir/latest.dir
ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log

nodes=($(scontrol show hostnames $SLURM_JOB_NODELIST))
head_node=${nodes[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
echo "Node: $head_node"
echo "Node IP: $head_node_ip"
echo "Node list: $SLURM_JOB_NODELIST"

# rendezvous via the c10d backend: every worker contacts the head node on
# port 29518; $RANDOM gives this job a unique rendezvous id
srun torchrun \
--nnodes ${num_nodes} \
--nproc_per_node ${num_gpu_per_node} \
--node_rank $SLURM_NODEID \
--rdzv_id $RANDOM \
--rdzv_backend c10d \
--rdzv_endpoint $head_node:29518 \
smoe/entrypoint/cpt/cpt_fpt.py \
--prob_map "sheared_llama" \
--num_selects ${num_selects} \
--moe_calculator_score_scale_factor ${scale_factor} \
--deepspeed ${deepspeed_config_file} \
--model_name_or_path ${pretrained_model} \
--model_type ${model_type} \
--tokenizer_name_or_path ${tokenizer_path} \
--dataset_dir ${dataset_dir} \
--data_cache_dir ${data_cache} \
--validation_dir ${validation_dir} \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--do_train \
--evaluation_strategy steps \
--eval_steps ${eval_steps} \
--seed ${seed} \
--bf16 \
--num_train_epochs 1 \
--final_lr_portion ${final_lr_portion} \
--optim adamw_torch \
--adam_beta1 0.9 \
--adam_beta2 0.95 \
--learning_rate ${lr} \
--weight_decay 0.1 \
--max_grad_norm 1.0 \
--warmup_steps ${warmup_steps} \
--max_steps ${max_steps} \
--max_train_samples ${max_train_samples} \
--save_strategy steps \
--save_total_limit 1 \
--save_steps ${eval_steps} \
--dataloader_num_workers 0 \
--dataloader_pin_memory True \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--block_size ${block_size} \
--output_dir ${output_dir} \
--overwrite_output_dir \
--ddp_timeout 3600 \
--ddp_find_unused_parameters False \
--torch_dtype bfloat16 \
--gradient_checkpointing \
--logging_first_step True \
--logging_strategy steps \
--logging_steps 5 \
--log_level info \
--log_level_replica warning \
--log_on_each_node False \
--report_to none \
--gate_type "TopKBalancedNoisyGate" \
--calculator_type "UniversalCalculator"
}
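Working the step arithmetic through with the values above: one optimizer step consumes 8 × 4 × 2 × 8 = 512 sequences of 4,096 tokens, i.e. 2,097,152 tokens per global batch, so 200×10^9 training tokens yield max_steps = 95367. The commented-out formulas would give warmup_steps = 715 and eval_steps = 1192; the hardcoded 100 and 340 override them. A standalone check with bc:

global_bs=$((8 * 4 * 2 * 8))             # 512 sequences per optimizer step
tokens_per_batch=$((global_bs * 4096))   # 2097152 tokens per optimizer step
echo "max_steps:    $(echo "200*10^9 / $tokens_per_batch" | bc)"  # 95367
echo "warmup_steps: $(echo "15*10^8 / $tokens_per_batch" | bc)"   # 715
echo "eval_steps:   $(echo "2.5*10^9 / $tokens_per_batch" | bc)"  # 1192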
52 changes: 52 additions & 0 deletions scripts/tokenize/slimpajama_convert.sh
@@ -0,0 +1,52 @@
#!/usr/bin/bash

# set -vx

content_column=input_ids
src_tokenizer_dir=/mnt/petrelfs/share_data/zhutong/models/llama2_7B
tokenizer_dir=/mnt/petrelfs/share_data/zhutong/models/Mistral-7B-v0.1

data_dir=/mnt/petrelfs/share_data/zhutong/data/slimpajama_fluency_llama_middle_parts
out_dir=/mnt/petrelfs/share_data/zhutong/data/slimpajama_fluency_mistral_middle_parts
# data_dir=/mnt/petrelfs/share_data/zhutong/data/llama1_7B_val_set_tokenized
# out_dir=/mnt/petrelfs/share_data/zhutong/data/mixtral_val_set_tokenized


logs_dir=logs

mkdir -p $logs_dir

# full set: en_arxiv, en_book, en_c4, en_cc, en_stack, en_wikipedia, github
# (note: en_cc is not in the explicit list below)
# for data_type in $(ls $data_dir)
for data_type in "en_arxiv" "en_book" "en_c4" "en_stack" "en_wikipedia" "github"
do
# get all parts from source data dir
for part in $(ls $data_dir/$data_type)
do
echo "tokenizing $data_dir/$data_type/$part - $(ls $data_dir/$data_type/$part | wc -l)"
log_path=logs/tokenize-$data_type-$part.log
nohup srun -p MoE_T -N1 -n1 --cpus-per-task=32 \
python -m smoe.utils.tokenize \
-f jsonl \
-c $content_column \
-s $src_tokenizer_dir \
-t $tokenizer_dir \
-i $data_dir/$data_type/$part \
-o $out_dir/$data_type/$part \
1>$log_path 2>&1 &
# echo "$data_type/$part > $log_path"
sleep 3
done

# log_path=logs/tokenize_$data_type.log
# nohup srun -p MoE_T -N1 -n1 --cpus-per-task=32 \
# python -m smoe.utils.tokenize \
# -f jsonl \
# -s $src_tokenizer_dir \
# -c $content_column \
# -t $tokenizer_dir \
# -i $data_dir/$data_type \
# -o $out_dir/$data_type \
# 1>$logs_dir/tokenize_$data_type.log 2>&1 &
# echo "$data_type > $logs_dir/tokenize_$data_type.log"
done
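Because the loop backgrounds one srun job per part, it can be worth tokenizing a single part in the foreground first to validate flags and paths before fanning out. A sketch under the same variables (the output path is a placeholder):

part=$(ls $data_dir/en_arxiv | head -n 1)
python -m smoe.utils.tokenize \
    -f jsonl \
    -c $content_column \
    -s $src_tokenizer_dir \
    -t $tokenizer_dir \
    -i $data_dir/en_arxiv/$part \
    -o /tmp/tokenize_smoke_test/en_arxiv/$part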
5 changes: 5 additions & 0 deletions smoe/entrypoint/cpt/cpt_fpt.py
@@ -36,6 +36,8 @@
LlamaMoEResidualConfig,
LlamaMoEResidualForCausalLM,
)
from smoe.models.mixtral.configuration_mixtral import MixtralConfig
from smoe.models.mixtral.modeling_mixtral import MixtralForCausalLM
from smoe.modules.flash_attn import replace_xformers
from smoe.trainer.llama_lr_scheduling import LlamaLrSchedulingTrainer
from smoe.utils.config import (
@@ -51,13 +53,15 @@
"llama": LlamaForCausalLM,
"llama_moe": LlamaMoEForCausalLM,
"llama_moe_residual": LlamaMoEResidualForCausalLM,
"mixtral": MixtralForCausalLM,
}

CONFIG_MAPPING.update(
{
"llama": LlamaConfig,
"llama_moe": LlamaMoEConfig,
"llama_moe_residual": LlamaMoEResidualConfig,
"mixtral": MixtralConfig,
}
)

@@ -276,6 +280,7 @@ def main():
# model.half()
# model.to(torch_dtype)

# TODO (tzhu): add flash-attn for mixtral
model: LlamaForCausalLM | LlamaMoEForCausalLM | LlamaMoEResidualForCausalLM = (
ModelClass.from_pretrained(
model_args.model_name_or_path,
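With both mappings extended, Mixtral training is selected purely by flag. A minimal single-node sketch (paths are placeholders; the remaining flags follow the SBATCH script above):

torchrun --nproc_per_node 8 smoe/entrypoint/cpt/cpt_fpt.py \
    --model_type mixtral \
    --model_name_or_path /path/to/Mixtral-8x7B-v0.1 \
    --tokenizer_name_or_path /path/to/Mixtral-8x7B-v0.1 \
    --dataset_dir /path/to/tokenized_dataset \
    --output_dir /tmp/mixtral_smoke_run \
    --do_train \
    --bf16 \
    --torch_dtype bfloat16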
62 changes: 62 additions & 0 deletions smoe/models/mixtral/__init__.py
@@ -0,0 +1,62 @@
# Copyright 2023 Mixtral AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from transformers.utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)


_import_structure = {
"configuration_mixtral": ["MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP", "MixtralConfig"],
}


try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_mixtral"] = [
"MixtralForCausalLM",
"MixtralModel",
"MixtralPreTrainedModel",
"MixtralForSequenceClassification",
]


if TYPE_CHECKING:
from .configuration_mixtral import MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP, MixtralConfig

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_mixtral import (
MixtralForCausalLM,
MixtralForSequenceClassification,
MixtralModel,
MixtralPreTrainedModel,
)


else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
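This mirrors the Hugging Face lazy-import pattern: at import time the package's module object is replaced by a _LazyModule, so modeling_mixtral (and its torch dependency) is only loaded when one of the exported names is first accessed. A quick import check, assuming torch is installed:

python -c "from smoe.models.mixtral import MixtralConfig, MixtralForCausalLM; print(MixtralConfig, MixtralForCausalLM)"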