From 27cf9360d6ad2d3dd4ef0c0c5e08da112a308fdc Mon Sep 17 00:00:00 2001 From: zhutong Date: Sun, 25 Feb 2024 23:51:14 +0800 Subject: [PATCH] add sft contents --- README.md | 28 +- docs/LLaMA_MoE.pdf | Bin 131 -> 131 bytes docs/supervised_fine_tuning/SFT.md | 39 + scripts/sft/2_16.sh | 80 ++ scripts/sft/2_8.sh | 80 ++ scripts/sft/4_16.sh | 80 ++ smoe/entrypoint/sft/__init__.py | 0 smoe/entrypoint/sft/train_sft.py | 491 ++++++++ .../models/llama_moe/modeling_llama_moe_hf.py | 1041 +++++++++-------- smoe/utils/conversation.py | 114 ++ smoe/utils/io.py | 10 + 11 files changed, 1494 insertions(+), 469 deletions(-) create mode 100644 docs/supervised_fine_tuning/SFT.md create mode 100644 scripts/sft/2_16.sh create mode 100644 scripts/sft/2_8.sh create mode 100644 scripts/sft/4_16.sh create mode 100644 smoe/entrypoint/sft/__init__.py create mode 100644 smoe/entrypoint/sft/train_sft.py create mode 100644 smoe/utils/conversation.py diff --git a/README.md b/README.md index f9e177a..d7dc97f 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ LLaMA-MoE favicon
📢 A SMALLER AFFORDABLE MoE MODEL FOR EVERYONE!!
- 🤗 Model Weights | 🚀 Quick Start | ⚙️ Installation Guide | 🚧 Expert Construction | 🚅 Continual Pre-training | 💎 Evaluation + 🤗 Model Weights | 🚀 Quick Start | ⚙️ Installation Guide | 🚧 Expert Construction | 🚅 Continual Pre-training | 💎 Evaluation | 💬 Supervised Fine-Tuning (SFT)
📃 Technical Report @@ -84,11 +84,13 @@ print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))

📊 Model Performance

-| Model | \#Activated Experts | \#Experts | \#Activated Params | Links |
-| :------------------------ | :-----------------: | :-------: | :----------------: | :-----------------------------------------------------------------------: |
-| **LLaMA-MoE-3.0B** | 2 | 16 | 3.0B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_0B-2_16) |
-| **LLaMA-MoE-3.5B (4/16)** | 4 | 16 | 3.5B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-4_16) |
-| **LLaMA-MoE-3.5B (2/8)** | 2 | 8 | 3.5B | [[🤗 HF Weights]](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-2_8) |
+| Model | \#Activated Experts | \#Experts | \#Activated Params | Foundation Model | SFT Model |
+| :------------------------ | :-----------------: | :-------: | :----------------: | :---------------------------------------------------------------: | :------------------------------------------------------------------: |
+| **LLaMA-MoE-3.0B** | 2 | 16 | 3.0B | [🤗 base](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_0B-2_16) | [🤗 SFT](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_0B-2_16-sft) |
+| **LLaMA-MoE-3.5B (4/16)** | 4 | 16 | 3.5B | [🤗 base](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-4_16) | [🤗 SFT](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-4_16-sft) |
+| **LLaMA-MoE-3.5B (2/8)** | 2 | 8 | 3.5B | [🤗 base](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-2_8) | [🤗 SFT](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-2_8-sft) |
+
+- Foundation models
 | Model | Average | SciQ | PIQA | WinoGrande | ARC-e | ARC-c (25) | HellaSwag (10) | LogiQA | BoolQ (32) | LAMBADA | NQ (32) | MMLU (5) |
 | :------------------------------------------------------------------------------------ | :------: | :------: | :------: | :--------: | :------: | :--------: | :------------: | :------: | :--------: | :------: | :------: | :------: |
@@ -101,6 +103,15 @@ print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
 | **LLaMA-MoE-3.5B (4/16)** | **57.7** | 87.6 | **77.9** | 65.5 | **65.6** | **44.2** | **73.3** | 29.7 | **75.0** | **69.5** | **20.3** | 26.8 |
 | **LLaMA-MoE-3.5B (2/8)** | 57.6 | **88.4** | 77.6 | **66.7** | 65.3 | 43.1 | **73.3** | 29.6 | 73.9 | 69.4 | 19.8 | 27.0 |
+
+- SFT models
+
+| Model | MMLU | ARC-c | HellaSwag | TruthfulQA | MT-Bench |
+| :------------------------------------- | :---: | :---: | :-------: | :--------: | :------: |
+| Sheared LLaMA-2.7B ShareGPT | 28.41 | 41.04 | 71.21 | 47.65 | 3.79 |
+| Sheared LLaMA-2.7B Deita6K (Our Impl.) | 25.24 | 43.69 | 71.70 | 49.00 | 4.06 |
+| LLaMA-MoE-v1-3.0B (2/16) | 23.61 | 43.43 | 72.28 | 44.24 | 4.15 |
+| LLaMA-MoE-v1-3.5B (4/16) | 26.49 | 48.29 | 75.10 | 45.91 | 4.60 |
+| LLaMA-MoE-v1-3.5B (2/8) | 25.53 | 45.99 | 74.95 | 44.39 | 4.72 |

🚧 Expert Construction

@@ -152,6 +163,11 @@ python -m smoe.utils.tokenize \
 - For evaluation on Natural Questions (NQ), please refer to [opencompass](https://github.com/Spico197/opencompass/tree/main).
 - For other tasks, please refer to [lm-eval-harness](https://github.com/spico197/smoe-eval).
+

💬 Supervised Fine-Tuning (SFT)

+
+We provide simple SFT training examples for building chatbots.
+Please refer to the [SFT docs](docs/supervised_fine_tuning/SFT.md) and the scripts under `scripts/sft` for more details.
+
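The resulting chat checkpoints can be loaded with `transformers` much like the base models in the Quick Start. The sketch below is illustrative only: the repository id is taken from the SFT column of the table above, and the Vicuna-style `USER:`/`ASSISTANT:` prompt is an assumption inferred from the separator handling in `smoe/entrypoint/sft/train_sft.py`; the exact template is defined in `smoe/utils/conversation.py`.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Repo id from the SFT table above; trust_remote_code is needed for the custom LLaMA-MoE modeling files.
model_dir = "llama-moe/LLaMA-MoE-v1-3_5B-2_8-sft"
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_dir, torch_dtype=torch.bfloat16, trust_remote_code=True
)
model.eval()
model.cuda()  # assumes a CUDA device is available

# Assumed Vicuna-style chat format; adjust to the template in smoe/utils/conversation.py if it differs.
prompt = "USER: Give me a short self-introduction. ASSISTANT:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=128, do_sample=False)
# Strip the prompt tokens and print only the generated reply.
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```

For training, the scripts under `scripts/sft` (e.g. `2_8.sh`) wrap the same entrypoint, launching `smoe.entrypoint.sft.train_sft` with `torchrun` under Slurm.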

📑 Citation

```bibtex diff --git a/docs/LLaMA_MoE.pdf b/docs/LLaMA_MoE.pdf index d6ffa8c64bfd93e5a26c6c2773be60f4dab3be53..91531c21c92bd456bf0e1cc790cbc35c9f1b5b6f 100644 GIT binary patch delta 84 zcmWN_u@S%^2mrvdb&8B2f}lfW2n4xHoi42lWaQ+_-)X0pPc6lxDG`~Ys4N{6wHK}? fxXPa)vdm{Qo<4ywngn#ORm~*weLd?KKnR#W3RoB1 delta 84 zcmWN`u@S%^2mrvdb&8AtqKHFe2vE37oi42lWaQ-A-)X0pPeSX3xHMfIs?AuXfd)1b fbuor+e-2e2jAjT $output_dir/diff.patch + env > $output_dir/env + echo -e "Job ID: ${SLURM_JOB_ID}\n\nLog: logs/llama_moe_2_16_deita-$SLURM_JOB_ID.log\n\nGit commit: $(git log -1 --oneline)\n\nGit branch: $(git branch | grep "*")\n\nComment: ${comment}" > $output_dir/comment.txt + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $output_dir/log.log + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + echo "Node: $head_node" + + torchrun \ + --nnodes 1 \ + --nproc_per_node $num_gpus \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29522 \ + -m smoe.entrypoint.sft.train_sft \ + --do_train \ + --freeze_gate True \ + --evaluation_strategy no \ + --run_name $task_name \ + --model_type $model_type \ + --model_name_or_path $model_name_or_path \ + --dataset_dir_or_path $dataset_dir_or_path \ + --output_dir $output_dir \ + --deepspeed conf/ds_bf16_zero1.json \ + --seed 12306 \ + --bf16 True \ + --tf32 True \ + --torch_dtype bfloat16 \ + --per_device_train_batch_size 4 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 8 \ + --num_train_epochs 2 \ + --save_strategy steps \ + --save_steps 9999999999999 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. 
\ + --warmup_ratio 0.03 \ + --lr_scheduler_type cosine \ + --logging_steps 1 \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --report_to wandb + +} diff --git a/scripts/sft/2_8.sh b/scripts/sft/2_8.sh new file mode 100644 index 0000000..113d658 --- /dev/null +++ b/scripts/sft/2_8.sh @@ -0,0 +1,80 @@ +#!/usr/bin/bash + +#SBATCH --job-name=llama_moe_2_8_deita +#SBATCH --output=logs/%x-%j.log +#SBATCH --error=logs/%x-%j.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=64G + +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --quotatype=auto + +export WANDB_PROJECT="llama_moe_sft" +num_gpus=4 + +{ + task_name="llama_moe_2_8_deita" + model_type="auto" + model_name_or_path="/mnt/petrelfs/zhutong/llama-moe-models/LLaMA-MoE-v1-3_5B-2_8-new" + dataset_dir_or_path="data/deita/deita_6k.jsonl" + + comment="llama-moe 2/8, deita, w/ balance loss, w/ freeze gate, w/ gate noise" + base_dir="outputs/llama_moe_sft" + output_dir="${base_dir}/${task_name}/$SLURM_JOB_NAME-$SLURM_JOB_ID" + mkdir -p $output_dir + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/env + echo -e "Job ID: ${SLURM_JOB_ID}\n\nLog: logs/llama_moe_2_8_deita-$SLURM_JOB_ID.log\n\nGit commit: $(git log -1 --oneline)\n\nGit branch: $(git branch | grep "*")\n\nComment: ${comment}" > $output_dir/comment.txt + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $output_dir/log.log + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + echo "Node: $head_node" + + torchrun \ + --nnodes 1 \ + --nproc_per_node $num_gpus \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29522 \ + -m smoe.entrypoint.sft.train_sft \ + --do_train \ + --freeze_gate True \ + --evaluation_strategy no \ + --run_name $task_name \ + --model_type $model_type \ + --model_name_or_path $model_name_or_path \ + --dataset_dir_or_path $dataset_dir_or_path \ + --output_dir $output_dir \ + --deepspeed conf/deepspeed/bf16_zero1.json \ + --seed 12306 \ + --bf16 True \ + --tf32 True \ + --torch_dtype bfloat16 \ + --per_device_train_batch_size 4 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 8 \ + --num_train_epochs 2 \ + --save_strategy steps \ + --save_steps 9999999999999 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. 
\ + --warmup_ratio 0.03 \ + --lr_scheduler_type cosine \ + --logging_steps 1 \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --report_to wandb + +} diff --git a/scripts/sft/4_16.sh b/scripts/sft/4_16.sh new file mode 100644 index 0000000..b6a040f --- /dev/null +++ b/scripts/sft/4_16.sh @@ -0,0 +1,80 @@ +#!/usr/bin/bash + +#SBATCH --job-name=llama_moe_4_16_deita +#SBATCH --output=logs/%x-%j.log +#SBATCH --error=logs/%x-%j.log + +#SBATCH --partition=MoE +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=64G + +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --quotatype=auto + +export WANDB_PROJECT="llama_moe_sft" +num_gpus=4 + +{ + task_name="llama_moe_4_16_deita" + model_type="auto" + model_name_or_path="/mnt/petrelfs/zhutong/llama-moe-models/LLaMA-MoE-v1-3_5B-4_16-new" + dataset_dir_or_path="data/deita/deita_6k.jsonl" + + comment="llama-moe 4/16, deita, w/ balance loss, w/ freeze gate, w/ gate noise" + base_dir="outputs/llama_moe_sft" + output_dir="${base_dir}/${task_name}/$SLURM_JOB_NAME-$SLURM_JOB_ID" + mkdir -p $output_dir + scontrol write batch_script $SLURM_JOBID $output_dir/sbatch.sh + git diff > $output_dir/diff.patch + env > $output_dir/env + echo -e "Job ID: ${SLURM_JOB_ID}\n\nLog: logs/llama_moe_4_16_deita-$SLURM_JOB_ID.log\n\nGit commit: $(git log -1 --oneline)\n\nGit branch: $(git branch | grep "*")\n\nComment: ${comment}" > $output_dir/comment.txt + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $output_dir/log.log + echo "$SLURM_JOB_ID" > $base_dir/latest.jobid + ln -snf $output_dir $base_dir/latest.dir + ln -snf $(scontrol show job $SLURM_JOB_ID | grep "StdOut=" | cut -d '=' -f 2) $base_dir/latest.log + + nodes=($(scontrol show hostnames $SLURM_JOB_NODELIS)) + nodes_array=($nodes) + head_node=${nodes_array[0]} + echo "Node: $head_node" + + torchrun \ + --nnodes 1 \ + --nproc_per_node $num_gpus \ + --node_rank $SLURM_NODEID \ + --rdzv_id $RANDOM \ + --rdzv_backend c10d \ + --rdzv_endpoint $head_node:29522 \ + -m smoe.entrypoint.sft.train_sft \ + --do_train \ + --freeze_gate True \ + --evaluation_strategy no \ + --run_name $task_name \ + --model_type $model_type \ + --model_name_or_path $model_name_or_path \ + --dataset_dir_or_path $dataset_dir_or_path \ + --output_dir $output_dir \ + --deepspeed conf/ds_bf16_zero1.json \ + --seed 12306 \ + --bf16 True \ + --tf32 True \ + --torch_dtype bfloat16 \ + --per_device_train_batch_size 4 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 8 \ + --num_train_epochs 2 \ + --save_strategy steps \ + --save_steps 9999999999999 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. 
\ + --warmup_ratio 0.03 \ + --lr_scheduler_type cosine \ + --logging_steps 1 \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --report_to wandb + +} diff --git a/smoe/entrypoint/sft/__init__.py b/smoe/entrypoint/sft/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/smoe/entrypoint/sft/train_sft.py b/smoe/entrypoint/sft/train_sft.py new file mode 100644 index 0000000..4a20ae0 --- /dev/null +++ b/smoe/entrypoint/sft/train_sft.py @@ -0,0 +1,491 @@ +import math +import pathlib +import random +from dataclasses import dataclass, field +from typing import Any, Dict, Mapping, Optional + +import numpy as np +import torch +import transformers +from loguru import logger +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizer, Trainer +from transformers.trainer_pt_utils import LabelSmoother + +from smoe.utils.conversation import Conversation +from smoe.utils.io import load_json, load_jsonlines + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default=None) + tokenizer_name_or_path: Optional[str] = field(default=None) + trust_remote_code: bool = field( + default=True, + metadata={ + "help": "Whether or not to allow for custom models defined on the Hub in their own modeling files" + }, + ) + padding_side: str = field( + default="right", metadata={"help": "The padding side in tokenizer"} + ) + model_type: str = field( + default="auto", metadata={"help": "Model type: `moe` or `mixtral` or `auto`"} + ) + torch_dtype: str = field( + default="auto", + metadata={"help": "Torch dtype: `float32` or `bfloat16`"}, + ) + additional_config: str = field( + default=None, + metadata={"help": "Additional config file (in json) to load"}, + ) + attn_impl: str = field( + default="flash_attention_2", + metadata={ + "help": "attention implementation, choice from [eager, flash_attention_2, sdpa] (default: `flash_attention_2`)" + }, + ) + + def __post_init__(self): + if hasattr(torch, self.torch_dtype): + self.torch_dtype = getattr(torch, self.torch_dtype) + if self.additional_config is not None: + if not pathlib.Path(self.additional_config).exists(): + raise ValueError( + f"Additional config file {self.additional_config} not found" + ) + self.additional_config = load_json(self.additional_config) + + +@dataclass +class DataArguments: + eval_data_dir: str = field( + default=None, metadata={"help": "Path to the evaluation data folder."} + ) + dataset_dir_or_path: str = field( + default="data/merged", + metadata={"help": "Path to dataset directory or a single jsonl file"}, + ) + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=2048, + metadata={ + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 
+ }, + ) + freeze_gate: bool = field( + default=False, + metadata={"help": "Whether to freeze the gate during training."}, + ) + save_final_ckpt: bool = field( + default=True, + metadata={"help": "Whether to save final checkpoint."}, + ) + + +def trainer_save_model_safe(trainer): + from torch.distributed.fsdp import FullStateDictConfig + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp import StateDictType + + save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type( + trainer.model, StateDictType.FULL_STATE_DICT, save_policy + ): + trainer.save_model() + + +def preprocess( + instances, + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + tokenizer_legacy = getattr(tokenizer, "legacy", True) + conv = Conversation() + conv.sep2 = tokenizer.eos_token + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, ins in enumerate(instances): + if roles[ins["conversations"][0]["from"]] != roles["human"]: + # Skip the first one if it is not from human + ins["conversations"] = ins["conversations"][1:] + + conv.clear_msg() + sys_msg = ins.get("system_prompt") + if sys_msg is not None: + conv.set_system_message(sys_msg) + else: + conv.set_system_message("") + for j, turn in enumerate(ins["conversations"]): + role = roles[turn["from"]] + assert role == conv.roles[j % 2], f"{i}/{j}" + conv.append_message(role, turn["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + res = tokenizer( + conversations, + return_tensors="pt", + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + ) + input_ids = res["input_ids"] + attention_masks = res["attention_mask"] + targets = input_ids.clone() + + # Mask targets. Only compute loss on the assistant outputs. 
+ sep = conv.sep + conv.roles[1] + ": " + # attention_masks = torch.ones_like(input_ids) + for conversation, target, attention_mask in zip( + conversations, targets, attention_masks + ): + turns = conversation.split(conv.sep2) + # the eos token is included in `total_len`, llama2 will add bos token + # total_len = int(target.ne(tokenizer.pad_token_id).sum()) + len(turns) * int(tokenizer.pad_token == tokenizer.eos_token) + # attention_mask[total_len:] = 0 + total_len = attention_mask.sum() + + cur_len = 0 + has_bos = False + if target[0] == tokenizer.bos_token_id: + cur_len = 1 + target[:cur_len] = IGNORE_TOKEN_ID # bos token + has_bos = True + for i, turn in enumerate(turns): + if turn == "": + break + # +1: add sep2 token + turn_len = len(tokenizer(turn).input_ids) - int(has_bos) + 1 + + # sep: " ASSISTANT: " + parts = turn.split(sep) + if len(parts) != 2: + break + parts[0] += sep + # "-2" is hardcoded for the Llama tokenizer to make the offset correct: bos and the last space token + # -1 means remove extra suffix space in sep + instruction_len = len(tokenizer(parts[0]).input_ids) - int(has_bos) - 1 + + if i != 0 and not tokenizer_legacy: + # The legacy and non-legacy modes handle special tokens differently + instruction_len -= 1 + + # Ignore the user instructions + target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID + cur_len += turn_len + # if i < len(turns) - 1: + # # plus one for sep2 token (eos) + # cur_len += 1 + + if i != 0 and not tokenizer_legacy: + # The legacy and non-legacy modes handle special tokens differently + cur_len -= 1 + + target[cur_len:] = IGNORE_TOKEN_ID + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + logger.info( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" #turn = {len(turns) - 1}. (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=attention_masks, + ) + + +def fault_tolerance_data_collator(features: list) -> dict[str, Any]: + if not isinstance(features[0], Mapping): + try: + features = [vars(f) for f in features] + except TypeError: + print(len(features), type(features[0]), features[0]) + first = features[0] + batch = {} + + # Special handling for labels. + # Ensure that tensor is created with the correct type + # (it should be automatically the case, but let's make sure of it.) + if "label" in first and first["label"] is not None: + label = ( + first["label"].item() + if isinstance(first["label"], torch.Tensor) + else first["label"] + ) + dtype = torch.long if isinstance(label, int) else torch.float + batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype) + elif "label_ids" in first and first["label_ids"] is not None: + if isinstance(first["label_ids"], torch.Tensor): + batch["labels"] = torch.stack([f["label_ids"] for f in features]) + else: + dtype = ( + torch.long if isinstance(first["label_ids"][0], int) else torch.float + ) + batch["labels"] = torch.tensor( + [f["label_ids"] for f in features], dtype=dtype + ) + + # Handling of all other possible keys. + # Again, we will use the first element to figure out which key/values are not None for this model. 
+ + try: + for k, v in first.items(): + if ( + k not in ("label", "label_ids") + and v is not None + and not isinstance(v, str) + ): + if isinstance(v, torch.Tensor): + batch[k] = torch.stack([f[k] for f in features]) + elif isinstance(v, np.ndarray): + batch[k] = torch.tensor(np.stack([f[k] for f in features])) + else: + batch[k] = torch.tensor([f[k] for f in features]) + except ValueError: # quick fix by simply take the first example + for k, v in first.items(): + if ( + k not in ("label", "label_ids") + and v is not None + and not isinstance(v, str) + ): + if isinstance(v, torch.Tensor): + batch[k] = torch.stack([features[0][k]] * len(features)) + elif isinstance(v, np.ndarray): + batch[k] = torch.tensor(np.stack([features[0][k]] * len(features))) + else: + batch[k] = torch.tensor([features[0][k]] * len(features)) + + return batch + + +class CachedJsonlDataset(Dataset): + def __init__( + self, + datapath: str, + tokenizer: PreTrainedTokenizer, + seed: int = 1227, + ) -> None: + super().__init__() + self.datapath = datapath + self.rng = random.Random(seed) + self.tokenizer = tokenizer + self.data = load_jsonlines(datapath) + self.rng.shuffle(self.data) + + def __len__(self) -> int: + return len(self.data) + + def __getitem__(self, index): + ins = self.data[index] + processed = preprocess([ins], self.tokenizer) + ins = {} + for key in processed: + ins[key] = processed[key][0] + return ins + + def state_dict(self): + return { + "datapath": self.datapath, + "seed": self.seed, + "rng": self.rng.getstate(), + } + + +def get_tokenizer( + model_name_or_path, + cache_dir: str = None, + model_max_length: int = 2048, + padding_side: str = "right", + use_fast: bool = False, + trust_remote_code: bool = False, +): + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_name_or_path, + cache_dir=cache_dir, + model_max_length=model_max_length, + padding_side=padding_side, + use_fast=use_fast, + trust_remote_code=trust_remote_code, + ) + if tokenizer.pad_token is None: + if tokenizer.unk_token is not None: + tokenizer.pad_token = tokenizer.unk_token + else: + tokenizer.pad_token = tokenizer.eos_token + logger.info(f"tokenizer ready, pad_token: {tokenizer.pad_token}") + return tokenizer + + +def get_model( + model_type: str, + model_name_or_path: str, + torch_dtype: str = "auto", + model_max_length: int = 2048, + attn_impl: str = "flash_attention_2", + cache_dir: str = None, + trust_remote_code: bool = False, + additional_config: dict = None, +): + logger.info(f"Model type: {model_type}") + if model_type == "auto": + ConfigClass = transformers.AutoConfig + ModelClass = transformers.AutoModelForCausalLM + else: + raise ValueError(f"Unknown model type: {model_type}") + + # Set RoPE scaling factor + config = ConfigClass.from_pretrained( + model_name_or_path, + cache_dir=cache_dir, + trust_remote_code=trust_remote_code, + ) + orig_ctx_len = getattr(config, "max_position_embeddings", None) + if orig_ctx_len and model_max_length > orig_ctx_len: + scaling_factor = float(math.ceil(model_max_length / orig_ctx_len)) + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + config.use_cache = False + if additional_config is not None: + config.update(additional_config) + logger.info("Config ready") + + # Load model and tokenizer + model = ModelClass.from_pretrained( + model_name_or_path, + config=config, + cache_dir=cache_dir, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + attn_implementation=attn_impl, + ) + logger.info("model ready") + + return model + + +def 
get_model_and_tokenizer( + model_type: str, + model_name_or_path: str, + tokenizer_path: str = None, + torch_dtype: str = "auto", + model_max_length: int = 2048, + attn_impl: str = "flash_attention_2", + cache_dir: str = None, + trust_remote_code: bool = False, + padding_side: str = "right", + additional_config: dict = None, + use_fast: bool = False, +) -> tuple: + if tokenizer_path is None: + tokenizer_path = model_name_or_path + tokenizer = get_tokenizer( + tokenizer_path, + cache_dir=cache_dir, + model_max_length=model_max_length, + padding_side=padding_side, + use_fast=use_fast, + trust_remote_code=trust_remote_code, + ) + model = get_model( + model_type, + model_name_or_path, + torch_dtype=torch_dtype, + model_max_length=model_max_length, + attn_impl=attn_impl, + cache_dir=cache_dir, + trust_remote_code=trust_remote_code, + additional_config=additional_config, + ) + + return model, tokenizer + + +def train(): + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + model_args: ModelArguments + data_args: DataArguments + training_args: TrainingArguments + logger.info(f"model_args: {model_args}") + logger.info(f"data_args: {data_args}") + logger.info(f"training_args: {training_args}") + + model, tokenizer = get_model_and_tokenizer( + model_args.model_type, + model_args.model_name_or_path, + tokenizer_path=model_args.tokenizer_name_or_path, + trust_remote_code=model_args.trust_remote_code, + padding_side=model_args.padding_side, + torch_dtype=model_args.torch_dtype, + additional_config=model_args.additional_config, + attn_impl=model_args.attn_impl, + model_max_length=training_args.model_max_length, + cache_dir=training_args.cache_dir, + ) + if training_args.freeze_gate: + for name, param in model.named_parameters(): + if "gate" in name: + param.requires_grad = False + + train_dataset = None + datapath = pathlib.Path(data_args.dataset_dir_or_path) + if not datapath.exists(): + raise ValueError(f"Dataset path {datapath} not found") + elif datapath.is_file(): + logger.info(f"CachedJsonlDataset: {datapath}") + train_dataset = CachedJsonlDataset( + data_args.dataset_dir_or_path, + tokenizer, + seed=training_args.seed, + ) + else: + raise ValueError(f"Unknown dataset path type: {datapath}") + logger.info("train dataset ready") + + trainer = Trainer( + model=model, + tokenizer=tokenizer, + args=training_args, + train_dataset=train_dataset, + data_collator=fault_tolerance_data_collator, + ) + logger.info("trainer ready") + + if training_args.do_train: + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + logger.info("resume training from ckpt") + trainer.train(resume_from_checkpoint=True) + else: + logger.info("start training") + trainer.train() + + # Save model + if training_args.save_final_ckpt: + logger.info("training finished, dumping model") + model.config.use_cache = True + trainer.save_state() + if trainer.is_deepspeed_enabled: + trainer.save_model() + else: + trainer_save_model_safe(trainer) + + logger.info("🎉 All done~") + + +if __name__ == "__main__": + train() diff --git a/smoe/models/llama_moe/modeling_llama_moe_hf.py b/smoe/models/llama_moe/modeling_llama_moe_hf.py index da79503..af7de67 100644 --- a/smoe/models/llama_moe/modeling_llama_moe_hf.py +++ b/smoe/models/llama_moe/modeling_llama_moe_hf.py @@ -9,12 +9,43 @@ import torch.utils.checkpoint from torch.distributions.normal import Normal from transformers.activations import ACT2FN +from 
transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ModelOutput, logging +from transformers.utils import ( + ModelOutput, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, +) from .configuration_llama_moe import LlamaMoEConfig +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) + ) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "LlamaMoEConfig" @@ -37,18 +68,20 @@ class BaseMoEModelOutputWithPast(ModelOutput): past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None - balance_loss: Optional[float] = None + balance_loss: Optional[torch.FloatTensor] = None num_dropped_tokens: Optional[Tuple[torch.Tensor]] = None - gate_load: Optional[Tuple[list]] = None - gate_importance: Optional[Tuple[list]] = None + gate_load: Optional[torch.LongTensor] = None + gate_importance: Optional[torch.FloatTensor] = None + expert2tokens: Optional[dict] = None @dataclass class MoECausalLMOutputWithPast(CausalLMOutputWithPast): - balance_loss: Optional[float] = None + balance_loss: Optional[torch.FloatTensor] = None num_dropped_tokens: Optional[Tuple[int]] = None - gate_load: Optional[Tuple[list[torch.Tensor]]] = None - gate_importance: Optional[Tuple[list[torch.Tensor]]] = None + gate_load: Optional[Tuple[torch.LongTensor]] = None + gate_importance: Optional[Tuple[torch.FloatTensor]] = None + expert2tokens: Optional[dict] = None @dataclass @@ -56,8 +89,9 @@ class MoEMlpOutput(ModelOutput): hidden_states: Optional[torch.FloatTensor] = None balance_loss: Optional[torch.FloatTensor] = None num_dropped_tokens: Optional[int] = None - gate_load: Optional[list] = None - gate_importance: Optional[list] = None + gate_load: Optional[torch.LongTensor] = None + gate_importance: Optional[torch.FloatTensor] = None + expert2tokens: Optional[dict] = None def _make_causal_mask( @@ -262,45 +296,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): return q_embed, k_embed -class LlamaMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.pretraining_tp = config.pretraining_tp - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - if self.pretraining_tp > 1: - slice = self.intermediate_size // self.pretraining_tp - gate_proj_slices = 
self.gate_proj.weight.split(slice, dim=0) - up_proj_slices = self.up_proj.weight.split(slice, dim=0) - down_proj_slices = self.down_proj.weight.split(slice, dim=1) - - gate_proj = torch.cat( - [F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], - dim=-1, - ) - up_proj = torch.cat( - [F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], - dim=-1, - ) - - intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [ - F.linear(intermediate_states[i], down_proj_slices[i]) - for i in range(self.pretraining_tp) - ] - down_proj = sum(down_proj) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - return down_proj - - def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -318,40 +313,57 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: LlamaMoEConfig): + def __init__(self, config: LlamaMoEConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.pretraining_tp = config.pretraining_tp self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." 
) + self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=False + self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias ) self.k_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, ) self.v_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, ) self.o_proj = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=False + self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias ) self._init_rope() def _init_rope(self): if self.config.rope_scaling is None: self.rotary_emb = LlamaRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, ) else: scaling_type = self.config.rope_scaling["type"] @@ -361,12 +373,14 @@ def _init_rope(self): self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor, + base=self.rope_theta, ) elif scaling_type == "dynamic": self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor, + base=self.rope_theta, ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") @@ -383,37 +397,43 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() - if self.pretraining_tp > 1: + if self.config.pretraining_tp > 1: key_value_slicing = ( self.num_key_value_heads * self.head_dim - ) // self.pretraining_tp + ) // self.config.pretraining_tp query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 ) key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) query_states = [ F.linear(hidden_states, query_slices[i]) - for i in range(self.pretraining_tp) + for i in range(self.config.pretraining_tp) ] query_states = torch.cat(query_states, dim=-1) key_states = [ F.linear(hidden_states, key_slices[i]) - for i in range(self.pretraining_tp) + for i in range(self.config.pretraining_tp) ] key_states = torch.cat(key_states, dim=-1) value_states = [ F.linear(hidden_states, value_slices[i]) - for i in range(self.pretraining_tp) + for i in range(self.config.pretraining_tp) ] value_states = torch.cat(value_states, dim=-1) @@ -434,20 +454,24 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. 
If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids ) if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) - # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -472,6 +496,9 @@ def forward( attn_weights = nn.functional.softmax( attn_weights, dim=-1, dtype=torch.float32 ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): @@ -481,19 +508,20 @@ def forward( ) attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - if self.pretraining_tp > 1: + if self.config.pretraining_tp > 1: attn_output = attn_output.split( - self.hidden_size // self.pretraining_tp, dim=2 + self.hidden_size // self.config.pretraining_tp, dim=2 ) o_proj_slices = self.o_proj.weight.split( - self.hidden_size // self.pretraining_tp, dim=1 + self.hidden_size // self.config.pretraining_tp, dim=1 ) attn_output = sum( [ F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.pretraining_tp) + for i in range(self.config.pretraining_tp) ] ) else: @@ -505,6 +533,356 @@ def forward( return attn_output, attn_weights, past_key_value +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # LlamaFlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input( + attn_output_unpad, indices_q, batch_size, query_length + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + return attn_output + + def _upad_input( + self, query_layer, key_layer, value_layer, attention_mask, query_length + ): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. 
Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
+ is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + class TopKBalancedNoisyGate(nn.Module): def __init__( self, @@ -940,12 +1318,38 @@ def _create_calculator(self, experts, **kwargs): else: raise NotImplementedError - def forward(self, x) -> MoEMlpOutput: + def forward(self, x, attention_mask=None) -> MoEMlpOutput: original_shape = x.shape[:-1] x = x.reshape(-1, self.input_size) + flattened_mask = None + if attention_mask is not None and len(attention_mask.shape) == 2: + flattened_mask = attention_mask.flatten() + flattened_shape = flattened_mask.shape + x = x[flattened_mask.bool()] + gate_outputs: dict = self.gate(x) calc_outs: CalculatorOutput = self.calculator(x, **gate_outputs) + y = calc_outs.hidden_states + expert2tokens = None + if flattened_mask is not None: + y = torch.zeros( + flattened_shape + (self.output_size,), dtype=x.dtype, device=x.device + ) # (batch_size*seq_len, output_size) + y[ + flattened_mask.bool() + ] = calc_outs.hidden_states # (non_padding_num, output_size) + + # # for stats only + # expert2tokens = {e: set() for e in range(self.gate.num_experts)} + # i = 0 + # for t, m in enumerate(flattened_mask.tolist()): + # if m: + # selected_experts = gate_outputs["topK_indices"][i].tolist() + # for e in selected_experts: + # expert2tokens[e].add(t) + # i += 1 + y = y.reshape(original_shape + (self.output_size,)) return MoEMlpOutput( @@ -954,115 +1358,9 @@ def forward(self, x) -> MoEMlpOutput: num_dropped_tokens=calc_outs.num_dropped_tokens, gate_load=gate_outputs.get("load", torch.tensor(-1)), gate_importance=gate_outputs.get("importance", torch.tensor(-1)), + expert2tokens=expert2tokens, ) - def set_num_selects(self, num_selects): - if "num_selects" not in vars(self.gate): - raise KeyError(f'{self.gate_type} does not have a key named "num_selects".') - elif num_selects > self.gate.num_experts: - raise ValueError( - 'The value of "num_selects" must satisfy "num_selects <= num_experts"!' - ) - elif self.gate_type in ("SwitchBalancedGate",): - raise ValueError( - f"{self.gate_type} doesn't support manually setting num_selects." - ) - else: - self.num_selects = num_selects - self.gate.num_selects = num_selects - - def set_gate_use_softmax(self, use_softmax): - if "use_softmax" not in vars(self.gate): - raise KeyError(f'{self.gate_type} does not have a key named "use_softmax".') - else: - self.gate.use_softmax = use_softmax - - def set_gate_use_balance(self, use_balance): - if "use_balance" not in vars(self.gate): - raise KeyError(f'{self.gate_type} does not have a key named "use_balance".') - else: - self.gate.use_balance = use_balance - - def set_gate_balance_loss_weight(self, balance_loss_weight): - if "balance_loss_weight" not in vars(self.gate): - raise KeyError( - f'{self.gate_type} does not have a key named "balance_loss_weight".' 
- ) - else: - self.gate.balance_loss_weight = balance_loss_weight - - def set_gate_add_noise(self, add_noise): - if "add_noise" not in vars(self.gate): - raise KeyError(f'{self.gate_type} does not have a key named "add_noise".') - else: - self.gate.add_noise = add_noise - - def set_gate_noise_epsilon(self, noise_epsilon): - if "noise_epsilon" not in vars(self.gate): - raise KeyError( - f'{self.gate_type} does not have a key named "noise_epsilon".' - ) - else: - self.gate.noise_epsilon = noise_epsilon - - def set_calculator_multiply_gate_scores(self, multiply_gate_scores): - if "multiply_gate_scores" not in vars(self.calculator): - raise KeyError( - f'{self.gate_type} does not have a key named "multiply_gate_scores".' - ) - else: - self.calculator.multiply_gate_scores = multiply_gate_scores - - def set_calculator_score_scale_factor(self, score_scale_factor): - if "score_scale_factor" not in vars(self.calculator): - raise KeyError( - f'{self.gate_type} does not have a key named "score_scale_factor".' - ) - else: - self.calculator.score_scale_factor = score_scale_factor - - def set_calculator_drop_tokens(self, drop_tokens): - if "drop_tokens" not in vars(self.calculator): - raise KeyError(f'{self.gate_type} does not have a key named "drop_tokens".') - elif ( - drop_tokens - and self.calculator.dropped_padding != "zero" - and self.input_size != self.output_size - ): - warnings.warn( - 'Setting "drop_tokens=True" without zero dropped padding when "input_size != output_size" will cause error!' - ) - else: - self.calculator.drop_tokens = drop_tokens - - def set_calculator_dropped_padding(self, dropped_padding): - if "dropped_padding" not in vars(self.calculator): - raise KeyError( - f'{self.gate_type} does not have a key named "dropped_padding".' - ) - elif dropped_padding not in self.calculator.available_dropped_padding_choices: - raise ValueError( - f"'dropped_padding' type not available! (available choices: {self.calculator.available_dropped_padding_choices})" - ) - elif ( - self.calculator.drop_tokens - and dropped_padding != "zero" - and self.input_size != self.output_size - ): - warnings.warn( - f'Setting "dropped_padding={dropped_padding}" with "drop_tokens=True" when "input_size != output_size" will cause error!' - ) - else: - self.calculator.dropped_padding = dropped_padding - - def set_calculator_capacity_factor(self, capacity_factor): - if "capacity_factor" not in vars(self.calculator): - raise KeyError( - f'{self.gate_type} does not have a key named "capacity_factor".' 
- ) - else: - self.calculator.capacity_factor = capacity_factor - def reset_gate_network(self): self.gate.reset_gate_network() @@ -1113,8 +1411,11 @@ def __init__(self, config: LlamaMoEConfig, layer_index): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config) - self.mlp = LlamaMLP(config) + # self.self_attn = LlamaAttention(config=config) + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_index + ) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm( config.hidden_size, eps=config.rms_norm_eps @@ -1190,7 +1491,7 @@ def forward( # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - mlp_outs: MoEMlpOutput = self.mlp(hidden_states) + mlp_outs: MoEMlpOutput = self.mlp(hidden_states, attention_mask=attention_mask) hidden_states = residual + mlp_outs.hidden_states outputs = ( @@ -1199,6 +1500,7 @@ def forward( mlp_outs.num_dropped_tokens, mlp_outs.gate_load, mlp_outs.gate_importance, + mlp_outs.expert2tokens, ) if output_attentions: outputs += (self_attn_weights,) @@ -1207,45 +1509,6 @@ def forward( return outputs - def set_moe_num_selects(self, num_selects): - self.mlp.set_num_selects(num_selects) - - def set_moe_gate_use_softmax(self, use_softmax): - self.mlp.set_gate_use_softmax(use_softmax) - - def set_moe_gate_use_balance(self, use_balance): - self.mlp.set_gate_use_balance(use_balance) - - def set_moe_gate_balance_loss_weight(self, balance_loss_weight): - self.mlp.set_gate_balance_loss_weight(balance_loss_weight) - - def set_moe_gate_add_noise(self, add_noise): - self.mlp.set_gate_add_noise(add_noise) - - def set_moe_gate_noise_epsilon(self, noise_epsilon): - self.mlp.set_gate_noise_epsilon(noise_epsilon) - - def set_moe_calculator_multiply_gate_scores(self, multiply_gate_scores): - self.mlp.set_calculator_multiply_gate_scores(multiply_gate_scores) - - def set_moe_calculator_score_scale_factor(self, score_scale_factor): - self.mlp.set_calculator_score_scale_factor(score_scale_factor) - - def set_moe_calculator_drop_tokens(self, drop_tokens): - self.mlp.set_calculator_drop_tokens(drop_tokens) - - def set_moe_calculator_dropped_padding(self, dropped_padding): - self.mlp.set_calculator_dropped_padding(dropped_padding) - - def set_moe_calculator_capacity_factor(self, capacity_factor): - self.mlp.set_calculator_capacity_factor(capacity_factor) - - def reset_gate_network(self): - self.mlp.reset_gate_network() - - def reset_experts(self): - self.mlp.reset_experts() - class LlamaMoEPreTrainedModel(PreTrainedModel): config_class = LlamaMoEConfig @@ -1253,6 +1516,7 @@ class LlamaMoEPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LlamaMoEDecoderLayer"] _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True def _init_weights(self, module): std = self.config.initializer_range @@ -1265,10 +1529,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlamaMoEModel): - module.gradient_checkpointing = value - class LlamaMoEModel(LlamaMoEPreTrainedModel): def __init__(self, config: LlamaMoEConfig): @@ -1282,6 +1542,8 @@ def __init__(self, config: LlamaMoEConfig): self.layers = nn.ModuleList( [LlamaMoEDecoderLayer(config, i) for i in 
range(config.num_hidden_layers)] ) + self._use_sdpa = config._attn_implementation == "sdpa" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False self.post_init() @@ -1292,34 +1554,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask( - self, attention_mask, input_shape, inputs_embeds, past_key_values_length - ): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ).to(inputs_embeds.device) - combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, input_ids=None, @@ -1363,12 +1597,19 @@ def forward( "You have to specify either decoder_input_ids or decoder_inputs_embeds" ) - seq_length_with_past = seq_length - past_key_values_length = 0 + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -1378,75 +1619,68 @@ def forward( dtype=torch.long, device=device, ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + position_ids = position_ids.unsqueeze(0) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device, + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = ( + attention_mask + if (attention_mask is not None and 0 in attention_mask) + else None + ) + elif self._use_sdpa and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
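Because the model now branches on `config._attn_implementation` (and the pre-trained class advertises `_supports_flash_attn_2`), the attention backend can be chosen at load time. A minimal usage sketch, assuming a `transformers` version that accepts the `attn_implementation` argument and that flash-attn is installed; the checkpoint id follows the released SFT naming and is an assumption here, not part of this hunk.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "llama-moe/LLaMA-MoE-v1-3_5B-2_8-sft"  # assumed released checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # or "sdpa" / "eager"
    trust_remote_code=True,
)
```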
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) hidden_states = inputs_embeds balance_loss = 0.0 - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing." - " Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None num_dropped_tokens = () gate_load = () gate_importance = () + expert2tokens = () for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = ( - past_key_values[idx] if past_key_values is not None else None - ) - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs: tuple = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, position_ids, - None, + past_key_values, + output_attentions, + use_cache, ) else: - layer_outputs: tuple = decoder_layer( + layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, ) @@ -1456,7 +1690,7 @@ def custom_forward(*inputs): balance_loss += layer_outputs[1] if use_cache: - next_decoder_cache += (layer_outputs[6 if output_attentions else 5],) + next_decoder_cache = layer_outputs[6 if output_attentions else 5] if output_attentions: all_self_attns += (layer_outputs[5],) @@ -1464,6 +1698,7 @@ def custom_forward(*inputs): num_dropped_tokens += (layer_outputs[2],) gate_load += (layer_outputs[3],) gate_importance += (layer_outputs[4],) + expert2tokens += (layer_outputs[-1],) hidden_states = self.norm(hidden_states) @@ -1471,7 +1706,13 @@ def custom_forward(*inputs): if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) if not return_dict: return tuple( v @@ -1487,125 +1728,9 @@ def custom_forward(*inputs): num_dropped_tokens=num_dropped_tokens, gate_load=gate_load, gate_importance=gate_importance, + expert2tokens=expert2tokens, ) - def update_config(self): - self.config.vocab_size = self.config.vocab_size - self.config.max_position_embeddings = self.config.max_position_embeddings - # ↓↓↓↓↓↓↓↓↓↓↓↓ changed here ↓↓↓↓↓↓↓↓↓↓↓↓ # - self.config.hidden_size = self.layers[0].mlp.input_size - self.config.intermediate_size = self.layers[0].mlp.hidden_size - self.config.num_hidden_layers = len(self.layers) - self.config.num_attention_heads = 
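The cache handling above replaces per-layer key/value tuples with the `Cache` API: legacy tuples are wrapped in a `DynamicCache` on the way in and converted back with `to_legacy_cache()` before returning. A small sketch of that round trip, assuming a `transformers` version (>= 4.36) where `DynamicCache` lives in `transformers.cache_utils`; the tensor shapes are dummy values.

```python
import torch
from transformers.cache_utils import DynamicCache

# One (key, value) pair per layer: (batch, num_heads, seq_len, head_dim)
legacy_cache = tuple(
    (torch.zeros(1, 8, 4, 64), torch.zeros(1, 8, 4, 64)) for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy_cache)
print(cache.get_usable_length(1))    # 4 tokens already cached
print(len(cache.to_legacy_cache()))  # 2 layers, back in the legacy tuple format
```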
self.layers[0].self_attn.num_heads - self.config.hidden_act = self.layers[0].mlp.hidden_act - # ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑ # - self.config.initializer_range = self.config.initializer_range - self.config.rms_norm_eps = self.config.rms_norm_eps - self.config.pretraining_tp = self.config.pretraining_tp - self.config.use_cache = self.config.use_cache - self.config.rope_scaling = self.config.rope_scaling - self.config._rope_scaling_validation() - - self.config.num_experts = self.layers[0].mlp.num_experts - self.config.num_selects = self.layers[0].mlp.num_selects - self.config.size_experts = [ - self.layers[i].mlp.calculator.experts.size_experts - for i in range(self.config.num_hidden_layers) - ] - - self.config.gate_type = vars(self.layers[0].mlp).get( - "gate_type", "TopKBalancedNoisyGate" - ) - self.config.gate_network = vars(self.layers[0].mlp.gate).get( - "gate_network_type", "mlp" - ) - self.config.gate_use_softmax = vars(self.layers[0].mlp.gate).get( - "use_softmax", True - ) - self.config.gate_use_balance = vars(self.layers[0].mlp.gate).get( - "use_balance", True - ) - self.config.gate_balance_loss_weight = vars(self.layers[0].mlp.gate).get( - "balance_loss_weight", 1e-2 - ) - self.config.gate_add_noise = vars(self.layers[0].mlp.gate).get( - "add_noise", True - ) - self.config.gate_noise_epsilon = vars(self.layers[0].mlp.gate).get( - "noise_epsilon", 1e-2 - ) - - self.config.calculator_type = vars(self.layers[0].mlp).get( - "calculator_type", "UniversalCalculator" - ) - self.config.multiply_gate_scores = vars(self.layers[0].mlp.calculator).get( - "multiply_gate_scores", True - ) - self.config.score_scale_factor = [ - vars(self.layers[i].mlp.calculator).get("score_scale_factor", 1.0) - for i in range(self.config.num_hidden_layers) - ] - self.config.drop_tokens = vars(self.layers[0].mlp.calculator).get( - "drop_tokens", True - ) - self.config.dropped_padding = vars(self.layers[0].mlp.calculator).get( - "dropped_padding", "zero" - ) - self.config.capacity_factor = vars(self.layers[0].mlp.calculator).get( - "capacity_factor", 1.25 - ) - - def set_moe_num_selects(self, num_selects): - for idx, decoder_layer in enumerate(self.layers): - decoder_layer.set_moe_num_selects(num_selects) - - def set_moe_gate_use_softmax(self, use_softmax): - for idx, decoder_layer in enumerate(self.layers): - decoder_layer.set_moe_gate_use_softmax(use_softmax) - - def set_moe_gate_use_balance(self, use_balance): - for idx, decoder_layer in enumerate(self.layers): - decoder_layer.set_moe_gate_use_balance(use_balance) - - def set_moe_gate_balance_loss_weight(self, balance_loss_weight): - for idx, decoder_layer in enumerate(self.layers): - decoder_layer.set_moe_gate_balance_loss_weight(balance_loss_weight) - - def set_moe_gate_add_noise(self, add_noise): - for idx, decoder_layer in enumerate(self.layers): - decoder_layer.set_moe_gate_add_noise(add_noise) - - def set_moe_gate_noise_epsilon(self, noise_epsilon): - for idx, decoder_layer in enumerate(self.layers): - decoder_layer.set_moe_gate_noise_epsilon(noise_epsilon) - - def set_moe_calculator_multiply_gate_scores(self, multiply_gate_scores): - for idx, decoder_layer in enumerate(self.layers): - decoder_layer.set_moe_calculator_multiply_gate_scores(multiply_gate_scores) - - def set_moe_calculator_score_scale_factor( - self, score_scale_factor, layer_index=None - ): - if layer_index is None: - for idx, decoder_layer in enumerate(self.layers): - decoder_layer.set_moe_calculator_score_scale_factor(score_scale_factor) - else: - 
self.layers[layer_index].set_moe_calculator_score_scale_factor( - score_scale_factor - ) - - def set_moe_calculator_drop_tokens(self, drop_tokens): - for idx, decoder_layer in enumerate(self.layers): - decoder_layer.set_moe_calculator_drop_tokens(drop_tokens) - - def set_moe_calculator_dropped_padding(self, dropped_padding): - for idx, decoder_layer in enumerate(self.layers): - decoder_layer.set_moe_calculator_dropped_padding(dropped_padding) - - def set_moe_calculator_capacity_factor(self, capacity_factor): - for idx, decoder_layer in enumerate(self.layers): - decoder_layer.set_moe_calculator_capacity_factor(capacity_factor) - def reset_gate_network(self): for idx, decoder_layer in enumerate(self.layers): decoder_layer.reset_gate_network() @@ -1719,6 +1844,7 @@ def forward( balance_loss=outputs.balance_loss, gate_load=outputs.gate_load, gate_importance=outputs.gate_importance, + expert2tokens=outputs.expert2tokens, ) def prepare_inputs_for_generation( @@ -1729,8 +1855,37 @@ def prepare_inputs_for_generation( inputs_embeds=None, **kwargs, ): - if past_key_values: - input_ids = input_ids[:, -1:] + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
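The rewritten `prepare_inputs_for_generation` only feeds the model tokens the cache has not seen yet; everything up to `past_length` is dropped from `input_ids`. A toy illustration of that slicing rule (case 2 above), using made-up token ids.

```python
import torch

input_ids = torch.tensor([[101, 102, 103, 104, 105]])
past_length = 3  # tokens already processed and stored in the KV cache

if past_length < input_ids.shape[1]:
    input_ids = input_ids[:, past_length:]

print(input_ids)  # tensor([[104, 105]]) -- only the unprocessed suffix is passed on
```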
+ if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -1738,7 +1893,7 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: @@ -1768,46 +1923,6 @@ def _reorder_cache(past_key_values, beam_idx): ) return reordered_past - def update_config(self): - self.model.update_config() - - def set_moe_num_selects(self, num_selects): - self.model.set_moe_num_selects(num_selects) - - def set_moe_gate_use_softmax(self, use_softmax): - self.model.set_moe_gate_use_softmax(use_softmax) - - def set_moe_gate_use_balance(self, use_balance): - self.model.set_moe_gate_use_balance(use_balance) - - def set_moe_gate_balance_loss_weight(self, balance_loss_weight): - self.model.set_moe_gate_balance_loss_weight(balance_loss_weight) - - def set_moe_gate_add_noise(self, add_noise): - self.model.set_moe_gate_add_noise(add_noise) - - def set_moe_gate_noise_epsilon(self, noise_epsilon): - self.model.set_moe_gate_noise_epsilon(noise_epsilon) - - def set_moe_calculator_multiply_gate_scores(self, multiply_gate_scores): - self.model.set_moe_calculator_multiply_gate_scores(multiply_gate_scores) - - def set_moe_calculator_score_scale_factor( - self, score_scale_factor, layer_index=None - ): - self.model.set_moe_calculator_score_scale_factor( - score_scale_factor, layer_index=layer_index - ) - - def set_moe_calculator_drop_tokens(self, drop_tokens): - self.model.set_moe_calculator_drop_tokens(drop_tokens) - - def set_moe_calculator_dropped_padding(self, dropped_padding): - self.model.set_moe_calculator_dropped_padding(dropped_padding) - - def set_moe_calculator_capacity_factor(self, capacity_factor): - self.model.set_moe_calculator_capacity_factor(capacity_factor) - def reset_gate_network(self): self.model.reset_gate_network() diff --git a/smoe/utils/conversation.py b/smoe/utils/conversation.py new file mode 100644 index 0000000..c1d610c --- /dev/null +++ b/smoe/utils/conversation.py @@ -0,0 +1,114 @@ +from typing import List, Tuple, Union + + +class Conversation: + """A class that manages prompt templates and keeps all conversation history.""" + + def __init__(self): + # The name of this template + self.name: str = "vicuna_v1.1" + # The template of the system prompt + self.system_template: str = "{system_message}" + # The system message + self.system_message: str = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." + # The names of two roles + self.roles: Tuple[str] = ("USER", "ASSISTANT") + # All messages. Each item is (role, message). 
+ self.messages: List[List[str]] = [] + # The number of few shot examples + self.offset: int = 0 + self.sep: str = " " + self.sep2: str = "" + # Stop criteria (the default one is EOS token) + self.stop_str: Union[str, List[str]] = None + # Stops generation if meeting any token in this list + self.stop_token_ids: List[int] = None + + def clear_msg(self): + self.messages.clear() + + def get_prompt(self) -> str: + """Get the prompt for generation.""" + system_prompt = self.system_template.format(system_message=self.system_message) + seps = [self.sep, self.sep2] + ret = system_prompt + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + + @classmethod + def parse(cls, instance: dict) -> str: + conv = cls() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + sys_msg = instance.get("system_prompt") + if sys_msg: + conv.set_system_message(sys_msg) + for j, turn in enumerate(instance["conversations"]): + role = roles[turn["from"]] + assert role == conv.roles[j % 2] + conv.append_message(role, turn["value"]) + return conv.get_prompt() + + def set_system_message(self, system_message: str): + """Set the system message.""" + self.system_message = system_message + + def append_message(self, role: str, message: str): + """Append a new message.""" + self.messages.append([role, message]) + + def update_last_message(self, message: str): + """Update the last output. + + The last message is typically set to be None when constructing the prompt, + so we need to update it in-place after getting the response from a model. + """ + self.messages[-1][1] = message + + def to_gradio_chatbot(self): + """Convert the conversation to gradio chatbot format.""" + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def to_openai_api_messages(self): + """Convert the conversation to OpenAI chat completion format.""" + ret = [{"role": "system", "content": self.system_message}] + + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append({"role": "user", "content": msg}) + else: + if msg is not None: + ret.append({"role": "assistant", "content": msg}) + return ret + + def copy(self): + return Conversation( + name=self.name, + system_template=self.system_template, + system_message=self.system_message, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep=self.sep, + sep2=self.sep2, + stop_str=self.stop_str, + stop_token_ids=self.stop_token_ids, + ) + + def dict(self): + return { + "template_name": self.name, + "system_message": self.system_message, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + } diff --git a/smoe/utils/io.py b/smoe/utils/io.py index 230448a..297800a 100644 --- a/smoe/utils/io.py +++ b/smoe/utils/io.py @@ -77,6 +77,16 @@ def __iter__(self): self.fin.close() +def load_json(filepath): + with open(filepath, "r", encoding="utf8") as fin: + return json.load(fin) + + +def dump_json(obj, filepath, **kwargs): + with open(filepath, "w", encoding="utf8") as fout: + json.dump(obj, fout, ensure_ascii=False, **kwargs) + + def load_jsonlines(filepath): data = [] with open(filepath, "r", encoding="utf8") as fin:
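Taken together, `smoe/utils/conversation.py` and the new JSON helpers in `smoe/utils/io.py` give the SFT pipeline a Vicuna-style prompt builder and simple (de)serialization. A usage sketch based only on the methods shown above; the ShareGPT-style sample and the output path are made up for illustration.

```python
from smoe.utils.conversation import Conversation
from smoe.utils.io import dump_json, load_json

# A ShareGPT-style instance in the shape expected by Conversation.parse()
instance = {
    "system_prompt": "You are a helpful assistant.",
    "conversations": [
        {"from": "human", "value": "What is a Mixture-of-Experts model?"},
        {"from": "gpt", "value": "A sparse model that routes tokens to expert FFNs."},
    ],
}

# Builds the vicuna_v1.1-style prompt:
# "You are a helpful assistant. USER: ... ASSISTANT: ..."
prompt = Conversation.parse(instance)
print(prompt)

dump_json({"prompt": prompt}, "/tmp/sft_prompt_example.json", indent=2)
print(load_json("/tmp/sft_prompt_example.json")["prompt"][:40])
```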