Changes to support fsdp+qlora and dsz3+qlora #1550

Merged Mar 13, 2024 (11 commits)

Changes from 1 commit:
add example and start docs
pacman100 committed Mar 12, 2024
commit e54c97da8d5410b371565206bc4abd087b9ccdb7
9 changes: 7 additions & 2 deletions docs/source/accelerate/deepspeed.md
@@ -16,13 +16,15 @@ Below is a table that summarizes the compatibility between PEFT's LoRA, [`bitsandbytes`]
|---|---|
| Zero-1 | 🟢 |
| Zero-2 | 🟢 |
- | Zero-3 | 🔴 |
+ | Zero-3 | 🟢 |

+ For using DeepSpeed Stage 3 + QLoRA, please refer to the section [Use PEFT QLoRA and DeepSpeed with ZeRO3 for finetuning large models on a single GPU]() below:

To confirm these observations, we ran the SFT (Supervised Fine-tuning) [official example scripts](https://github.com/huggingface/trl/tree/main/examples) of the [Transformers Reinforcement Learning (TRL) library](https://github.com/huggingface/trl) using QLoRA + PEFT and the accelerate configs available [here](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs). We ran these experiments on 2x NVIDIA T4 GPUs.

Note: DeepSpeed-Zero3 and `bitsandbytes` are currently **not** compatible.
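The flipped Zero-3 row above is what this PR enables: when `bitsandbytes` stores the packed 4-bit weights in a regular floating-point dtype, DeepSpeed ZeRO-3 (and FSDP) can flatten and shard them like ordinary parameters. A minimal sketch of the idea, assuming `transformers`/`bitsandbytes` versions that expose the `bnb_4bit_quant_storage` parameter used later in this diff:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    # Key change: store the packed 4-bit weights as bf16 tensors so that
    # ZeRO-3/FSDP can flatten and shard them like any other bf16 parameter.
    bnb_4bit_quant_storage=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    quantization_config=bnb_config,
    # Keep the non-quantized modules (norms, embeddings) in the same dtype
    # so every parameter in the model shards uniformly.
    torch_dtype=torch.bfloat16,
)
```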

- # Use PEFT and DeepSpeed with ZeRO3 for finetuning large models on multiple machines and multiple nodes
+ # Use PEFT and DeepSpeed with ZeRO3 for finetuning large models on multiple devices and multiple nodes

This section of the guide will help you learn how to use our DeepSpeed [training script](https://github.com/huggingface/peft/blob/main/examples/sft/train.py) for performing SFT. You'll configure the script to run supervised fine-tuning of the Llama-70B model with LoRA and ZeRO-3 on 8x H100 80GB GPUs on a single machine. You can configure it to scale to multiple machines by changing the accelerate config.

@@ -171,6 +173,9 @@ In the above example, the memory consumed per GPU is 64 GB (80%) as seen in the
## More resources
You can also refer to the blog post [Falcon 180B Finetuning using 🤗 PEFT and DeepSpeed](https://medium.com/@sourabmangrulkar/falcon-180b-finetuning-using-peft-and-deepspeed-b92643091d99) on how to finetune the 180B Falcon model on 16 A100 GPUs across 2 machines.


+ # Use PEFT QLoRA and DeepSpeed with ZeRO3 for finetuning large models on a single GPU

# Use PEFT and DeepSpeed with ZeRO3 and CPU Offloading for finetuning large models on a single GPU
This section of the guide will help you learn how to use our DeepSpeed [training script](https://github.com/huggingface/peft/blob/main/examples/conditional_generation/peft_lora_seq2seq_accelerate_ds_zero3_offload.py). You'll configure the script to train a large model for conditional generation with ZeRO-3 and CPU Offload.

4 changes: 2 additions & 2 deletions examples/sft/README.md
@@ -23,10 +23,10 @@ Note:
1. At present, `use_reentrant` needs to be `False` when using gradient checkpointing with multi-GPU QLoRA, else it will lead to errors. However, setting `use_reentrant=False` leads to much higher GPU memory consumption (see the sketch below).
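As a sketch of where that flag ends up, assuming the standard `transformers` `TrainingArguments` API (the `train.py` changes in this PR wire it up the same way):

```python
from transformers import TrainingArguments

# Hypothetical values; train.py exposes --use_reentrant as a CLI argument.
training_args = TrainingArguments(
    output_dir="llama-sft-qlora-fsdp",
    gradient_checkpointing=True,
    # False is required for multi-GPU QLoRA at present, at the cost of
    # noticeably higher GPU memory consumption.
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```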

## Multi-GPU SFT with LoRA and DeepSpeed
- When you have access to multiple GPUs, it would be better to use normal LoRA with DeepSpeed/FSDP. TO use LoRA with DeepSpeed, refer the docs at [PEFT with DeepSpeed](https://huggingface.co/docs/peft/accelerate/deepspeed).
+ When you have access to multiple GPUs, it would be better to use normal LoRA with DeepSpeed/FSDP. To use LoRA with DeepSpeed, refer to the docs at [PEFT with DeepSpeed](https://huggingface.co/docs/peft/accelerate/deepspeed).


## Multi-GPU SFT with LoRA and FSDP
- When you have access to multiple GPUs, it would be better to use normal LoRA with DeepSpeed/FSDP. TO use LoRA with DeepSpeed, refer the docs at [PEFT with FSDP](https://huggingface.co/docs/peft/accelerate/fsdp).
+ When you have access to multiple GPUs, it would be better to use normal LoRA with DeepSpeed/FSDP. To use LoRA with FSDP, refer to the docs at [PEFT with FSDP](https://huggingface.co/docs/peft/accelerate/fsdp).


22 changes: 22 additions & 0 deletions examples/sft/configs/deepspeed_config_z3_qlora.yaml
@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
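A brief orientation, assuming standard 🤗 Accelerate semantics for these fields: `zero_stage: 3` with both offload devices set to `cpu` pushes parameters and optimizer state to CPU RAM, `zero3_init_flag: true` initializes the large model directly in sharded form, and `num_processes: 2` matches the 2x GPU setup used for the compatibility tests above.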
25 changes: 25 additions & 0 deletions examples/sft/configs/fsdp_config_qlora.yaml
@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
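Two fields here matter most for QLoRA with FSDP, assuming standard Accelerate semantics: `fsdp_cpu_ram_efficient_loading: true` lets only rank 0 materialize the full weights (with `fsdp_sync_module_states: true` broadcasting them to the other ranks), and `mixed_precision: 'no'` because the quantization storage dtype passed to the model, rather than an autocast policy, determines the dtype FSDP flattens and shards.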
41 changes: 41 additions & 0 deletions examples/sft/run_peft_qlora_deepspeed_stage3.sh
@@ -0,0 +1,41 @@
accelerate launch --config_file "configs/deepspeed_config_z3_qlora.yaml" train.py \
--seed 100 \
--model_name_or_path "meta-llama/Llama-2-70b-hf" \
--dataset_name "smangrul/ultrachat-10k-chatml" \
--chat_template_format "chatml" \
--add_special_tokens False \
--append_concat_token False \
--splits "train,test" \
--max_seq_len 2048 \
--num_train_epochs 1 \
--logging_steps 5 \
--log_level "info" \
--logging_strategy "steps" \
--evaluation_strategy "epoch" \
--save_strategy "epoch" \
--push_to_hub \
--hub_private_repo True \
--hub_strategy "every_save" \
--bf16 True \
--packing True \
--learning_rate 1e-4 \
--lr_scheduler_type "cosine" \
--weight_decay 1e-4 \
--warmup_ratio 0.0 \
--max_grad_norm 1.0 \
--output_dir "llama-sft-qlora-dsz3" \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 2 \
--gradient_checkpointing True \
--use_reentrant True \
--dataset_text_field "content" \
--use_flash_attn True \
--use_peft_lora True \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.1 \
--lora_target_modules "all-linear" \
--use_4bit_quantization True \
--use_nested_quant True \
--bnb_4bit_compute_dtype "bfloat16"
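This launch consumes the ZeRO-3 config above. Note that it leaves `--bnb_4bit_quant_storage_dtype` at the script's default (`float32`, per the `ModelArguments` definition later in this diff); the FSDP variant below sets it explicitly.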
42 changes: 42 additions & 0 deletions examples/sft/run_peft_qlora_fsdp.sh
@@ -0,0 +1,42 @@
accelerate launch --config_file "configs/fsdp_config_qlora.yaml" train.py \
--seed 100 \
--model_name_or_path "meta-llama/Llama-2-70b-hf" \
--dataset_name "smangrul/ultrachat-10k-chatml" \
--chat_template_format "chatml" \
--add_special_tokens False \
--append_concat_token False \
--splits "train,test" \
--max_seq_len 2048 \
--num_train_epochs 1 \
--logging_steps 5 \
--log_level "info" \
--logging_strategy "steps" \
--evaluation_strategy "epoch" \
--save_strategy "epoch" \
--push_to_hub \
--hub_private_repo True \
--hub_strategy "every_save" \
--bf16 True \
--packing True \
--learning_rate 1e-4 \
--lr_scheduler_type "cosine" \
--weight_decay 1e-4 \
--warmup_ratio 0.0 \
--max_grad_norm 1.0 \
--output_dir "llama-sft-qlora-fsdp" \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 2 \
--gradient_checkpointing True \
--use_reentrant True \
--dataset_text_field "content" \
--use_flash_attn True \
--use_peft_lora True \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.1 \
--lora_target_modules "all-linear" \
--use_4bit_quantization True \
--use_nested_quant True \
--bnb_4bit_compute_dtype "bfloat16" \
--bnb_4bit_quant_storage_dtype "bfloat16"
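The one functional difference from the DeepSpeed launch above is the extra `--bnb_4bit_quant_storage_dtype "bfloat16"` line: it makes the packed 4-bit weights live in bf16 tensors that FSDP can flatten and shard, and (see `utils.py` below) it also becomes the model's `torch_dtype`.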
53 changes: 35 additions & 18 deletions examples/sft/train.py
@@ -16,7 +16,9 @@ class ModelArguments:
    """

    model_name_or_path: str = field(
-        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
+        metadata={
+            "help": "Path to pretrained model or model identifier from huggingface.co/models"
+        }
    )
    chat_template_format: Optional[str] = field(
        default="none",
@@ -29,7 +31,9 @@ class ModelArguments:
    lora_r: Optional[int] = field(default=64)
    lora_target_modules: Optional[str] = field(
        default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
-        metadata={"help": "comma separated list of target modules to apply LoRA layers to"},
+        metadata={
+            "help": "comma separated list of target modules to apply LoRA layers to"
+        },
    )
    use_nested_quant: Optional[bool] = field(
        default=False,
@@ -39,6 +43,10 @@ class ModelArguments:
default="float16",
metadata={"help": "Compute dtype for 4bit base models"},
)
bnb_4bit_quant_storage_dtype: Optional[str] = field(
default="float32",
metadata={"help": "Quantization storage dtype for 4bit base models"},
)
bnb_4bit_quant_type: Optional[str] = field(
default="nf4",
metadata={"help": "Quantization type fp4 or nf4"},
Expand Down Expand Up @@ -79,15 +87,21 @@ class DataTrainingArguments:
        default=False,
        metadata={"help": "Use packing dataset creating."},
    )
-    dataset_text_field: str = field(default="text", metadata={"help": "Dataset field to use as input text."})
+    dataset_text_field: str = field(
+        default="text", metadata={"help": "Dataset field to use as input text."}
+    )
    max_seq_length: Optional[int] = field(default=512)
    append_concat_token: Optional[bool] = field(
        default=False,
-        metadata={"help": "If True, appends `eos_token_id` at the end of each sample being packed."},
+        metadata={
+            "help": "If True, appends `eos_token_id` at the end of each sample being packed."
+        },
    )
    add_special_tokens: Optional[bool] = field(
        default=False,
-        metadata={"help": "If True, tokenizers adds special tokens to each sample being packed."},
+        metadata={
+            "help": "If True, tokenizers adds special tokens to each sample being packed."
+        },
    )
    splits: Optional[str] = field(
        default="train,test",
@@ -100,13 +114,19 @@ def main(model_args, data_args, training_args):
    set_seed(training_args.seed)

    # model
-    model, peft_config, tokenizer = create_and_prepare_model(model_args, data_args, training_args)
+    model, peft_config, tokenizer = create_and_prepare_model(
+        model_args, data_args, training_args
+    )

    # gradient ckpt
    model.config.use_cache = not training_args.gradient_checkpointing
-    training_args.gradient_checkpointing = training_args.gradient_checkpointing and not model_args.use_unsloth
+    training_args.gradient_checkpointing = (
+        training_args.gradient_checkpointing and not model_args.use_unsloth
+    )
    if training_args.gradient_checkpointing:
-        training_args.gradient_checkpointing_kwargs = {"use_reentrant": model_args.use_reentrant}
+        training_args.gradient_checkpointing_kwargs = {
+            "use_reentrant": model_args.use_reentrant
+        }

    # datasets
    train_dataset, eval_dataset = create_datasets(
@@ -133,14 +153,7 @@ def main(model_args, data_args, training_args):
        max_seq_length=data_args.max_seq_length,
    )
    trainer.accelerator.print(f"{trainer.model}")
-    if model_args.use_peft_lora:
-        # handle PEFT+FSDP case
-        trainer.model.print_trainable_parameters()
-        if getattr(trainer.accelerator.state, "fsdp_plugin", None):
-            from peft.utils.other import fsdp_auto_wrap_policy
-
-            fsdp_plugin = trainer.accelerator.state.fsdp_plugin
-            fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(trainer.model)
+    trainer.model.print_trainable_parameters()

    # train
    checkpoint = None
@@ -155,11 +168,15 @@


if __name__ == "__main__":
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    parser = HfArgumentParser(
+        (ModelArguments, DataTrainingArguments, TrainingArguments)
+    )
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
-        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+        model_args, data_args, training_args = parser.parse_json_file(
+            json_file=os.path.abspath(sys.argv[1])
+        )
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    main(model_args, data_args, training_args)
29 changes: 16 additions & 13 deletions examples/sft/utils.py
@@ -64,7 +64,9 @@ def preprocess(samples):
        elif "test" in split:
            raw_datasets["test"] = dataset
        else:
-            raise ValueError(f"Split type {split} not recognized as one of test or train.")
+            raise ValueError(
+                f"Split type {split} not recognized as one of test or train."
+            )

    if apply_chat_template:
        raw_datasets = raw_datasets.map(
@@ -75,7 +77,9 @@ def preprocess(samples):

    train_data = raw_datasets["train"]
    valid_data = raw_datasets["test"]
-    print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
+    print(
+        f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}"
+    )
    print(f"A sample of train dataset: {train_data[0]}")

    return train_data, valid_data
@@ -84,8 +88,8 @@ def preprocess(samples):
def create_and_prepare_model(args, data_args, training_args):
    if args.use_unsloth:
        from unsloth import FastLanguageModel
-    device_map = None
    bnb_config = None
+    quant_storage_stype = None

    if (
        torch.distributed.is_available()
@@ -97,30 +101,27 @@ def create_and_prepare_model(args, data_args, training_args):

    if args.use_4bit_quantization:
        compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)
+        quant_storage_stype = getattr(torch, args.bnb_4bit_quant_storage_dtype)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=args.use_4bit_quantization,
            bnb_4bit_quant_type=args.bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=args.use_nested_quant,
+            bnb_4bit_quant_storage=quant_storage_stype,
        )

        if compute_dtype == torch.float16 and args.use_4bit_quantization:
            major, _ = torch.cuda.get_device_capability()
            if major >= 8:
                print("=" * 80)
-                print("Your GPU supports bfloat16, you can accelerate training with the argument --bf16")
+                print(
+                    "Your GPU supports bfloat16, you can accelerate training with the argument --bf16"
+                )
                print("=" * 80)
    elif args.use_8bit_quantization:
        bnb_config = BitsAndBytesConfig(load_in_8bit=args.use_8bit_quantization)

-    if args.use_4bit_quantization or args.use_8bit_quantization:
-        device_map = (
-            int(os.environ.get("LOCAL_RANK", -1))
-            if torch.distributed.is_available() and torch.distributed.is_initialized()
-            else "auto"
-        )  # {"": 0}

    if args.use_unsloth:
        # Load model
        model, _ = FastLanguageModel.from_pretrained(
@@ -133,9 +134,9 @@
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            quantization_config=bnb_config,
-            device_map=device_map,
            trust_remote_code=True,
            attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
+            torch_dtype=quant_storage_stype or torch.float32,
        )

    peft_config = None
@@ -174,7 +175,9 @@
        # make embedding resizing configurable?
        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
    else:
-        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained(
+            args.model_name_or_path, trust_remote_code=True
+        )
        tokenizer.pad_token = tokenizer.eos_token

    if args.use_unsloth:
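Taken together, the diff reduces model preparation to roughly the following flow. This is a condensed sketch rather than the script itself; names follow the diff, and recent `transformers`/`peft`/`bitsandbytes` versions with 4-bit quant-storage support are assumptions:

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Must match --bnb_4bit_quant_storage_dtype from the launch scripts.
quant_storage_dtype = torch.bfloat16

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=quant_storage_dtype,
    ),
    # One uniform, shardable dtype for quantized and non-quantized params;
    # no device_map, since DeepSpeed/FSDP place parameters themselves.
    torch_dtype=quant_storage_dtype,
)

peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
```

In the actual script, `SFTTrainer` receives `peft_config` and applies it internally; `get_peft_model` is shown here only to keep the sketch self-contained.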