save cpu mem by leveraging FSDP rank0 broadcasting (#77)
chauhang authored Aug 11, 2023
2 parents 3f1fef7 + feaa344 commit 205e5a4
Showing 6 changed files with 97 additions and 69 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -125,6 +125,16 @@ torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --mode

```

### Fine-tuning using FSDP on 70B Model

If you are interested in running full-parameter fine-tuning on the 70B model, you can enable `low_cpu_fsdp` mode as in the following command. This option loads the model on rank 0 only before moving it to the devices to construct FSDP, which dramatically reduces CPU memory usage when loading large models like the 70B (on an 8-GPU node, it cuts CPU memory from 2+ TB to 280 GB). This has been tested with `BF16` on 16x A100, 80GB GPUs. A minimal sketch of what this mode does under the hood follows the command.

```bash

torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /path_of_model_folder/70B --batch_size_training 1 --micro_batch_size 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned

```
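Under the hood, `low_cpu_fsdp` loads real weights on rank 0 only, builds a weight-free skeleton on the meta device on every other rank, and lets FSDP broadcast rank 0's weights when it wraps the model. The following is a minimal sketch of that pattern under torchrun, not the full training script: the checkpoint path is a placeholder, and the real script's auto-wrap policy and mixed-precision settings are omitted.

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import LlamaConfig, LlamaForCausalLM

MODEL_PATH = "/path_of_model_folder/70B"  # placeholder

dist.init_process_group("nccl")  # torchrun supplies the rank/world-size env vars
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
rank = dist.get_rank()

if rank == 0:
    # Only rank 0 pays the host-RAM cost of the full checkpoint (~280 GB for 70B).
    model = LlamaForCausalLM.from_pretrained(MODEL_PATH)
else:
    # Other ranks build the architecture on the meta device: shapes and dtypes
    # only, no storage, so their CPU memory stays flat regardless of model size.
    config = LlamaConfig.from_pretrained(MODEL_PATH)
    with torch.device("meta"):
        model = LlamaForCausalLM(config)

model = FSDP(
    model,
    device_id=torch.cuda.current_device(),
    # Broadcast rank 0's module states (weights and buffers) to all other ranks.
    sync_module_states=True,
    # Give the meta tensors real, uninitialized GPU storage before the broadcast.
    param_init_fn=(
        (lambda m: m.to_empty(device=torch.device("cuda"), recurse=False))
        if rank != 0
        else None
    ),
)
```

FSDP then shards the parameters across ranks, so no rank ever holds the full model on a single GPU either.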

### Multi GPU Multi Node:

```bash
3 changes: 2 additions & 1 deletion configs/training.py
@@ -7,7 +7,8 @@
@dataclass
class train_config:
model_name: str="PATH/to/LLAMA/7B"
enable_fsdp: bool= False
enable_fsdp: bool=False
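# load the full model on rank 0 only, then broadcast its weights to the other ranks via FSDP sync_module_states (requires a recent PyTorch nightly)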
low_cpu_fsdp: bool=False
run_validation: bool=True
batch_size_training: int=4
num_epochs: int=3
10 changes: 10 additions & 0 deletions docs/multi_gpu.md
@@ -62,6 +62,16 @@ torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --mode

```

### Fine-tuning using FSDP on 70B Model

If you are interested in running full-parameter fine-tuning on the 70B model, you can enable `low_cpu_fsdp` mode as in the following command. This option loads the model on rank 0 only before moving it to the devices to construct FSDP, which dramatically reduces CPU memory usage when loading large models like the 70B (on an 8-GPU node, it cuts CPU memory from 2+ TB to 280 GB). This has been tested with `BF16` on 16x A100, 80GB GPUs.

```bash

torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /path_of_model_folder/70B --batch_size_training 1 --micro_batch_size 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned

```
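For a rough sense of where those numbers come from (following the `70 * 4 * 8` note in the training script): the 70B checkpoint in full precision takes about 70B params x 4 bytes ≈ 280 GB of host RAM, and without `low_cpu_fsdp` each of the 8 ranks on a node loads its own copy, roughly 280 GB x 8 ≈ 2.24 TB. With the flag, only rank 0 loads real weights before FSDP shards them across the GPUs.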

**Multi GPU multi node**:

Here we use a slurm script to schedule a job with slurm over multiple nodes.
137 changes: 72 additions & 65 deletions llama_finetuning.py
@@ -2,68 +2,47 @@
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
import sys
from typing import List, Union

import fire
import torch
import transformers
from datasets import load_dataset
import os.path as osp
from tqdm import tqdm

# Unused imports removed
from utils import fsdp_auto_wrap_policy
import torch.distributed as dist
import torch.optim as optim
from peft import get_peft_model, prepare_model_for_int8_training
from pkg_resources import packaging
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
)
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DistributedSampler
from transformers import (
LlamaForCausalLM,
LlamaTokenizer,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
LlamaConfig,
default_data_collator,
BitsAndBytesConfig
)
import torch.distributed as dist
# Unused imports removed
from utils.train_utils import (
set_tokenizer_params,
train,
evaluation,
freeze_transformer_layers,
check_frozen_layers_peft_model,
setup,
setup_environ_flags,
cleanup,
clear_gpu_cache,
get_parameter_dtypes,
print_model_size,
get_policies
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

from utils.dataset_utils import get_preprocessed_dataset
import policies
from configs import fsdp_config, train_config
from policies import AnyPrecisionAdamW

from utils import fsdp_auto_wrap_policy
from utils.config_utils import (
update_config,
generate_peft_config,
generate_dataset_config,
)
from peft import get_peft_model, TaskType, prepare_model_for_int8_training
import configs
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
)
from utils.dataset_utils import get_preprocessed_dataset

from utils.train_utils import (
train,
freeze_transformer_layers,
setup,
setup_environ_flags,
clear_gpu_cache,
print_model_size,
get_policies
)
from torch.utils.data import DistributedSampler
import policies
from policies import AnyPrecisionAdamW
from configs import fsdp_config, train_config
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from pkg_resources import packaging
import torch
import torch.cuda.nccl as nccl
import torch.distributed as dist
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def main(**kwargs):
@@ -82,18 +61,43 @@ def main(**kwargs):
world_size = int(os.environ["WORLD_SIZE"])

if torch.distributed.is_initialized():
torch.cuda.set_device(rank)
torch.cuda.set_device(local_rank)
clear_gpu_cache(local_rank)
setup_environ_flags(rank)

# Calculate gradient accumulation steps
gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size

# Load the pre-trained model and setup its configuration
model = LlamaForCausalLM.from_pretrained(
train_config.model_name,
load_in_8bit=True if train_config.quantization else None,
device_map="auto" if train_config.quantization else None,
)
if train_config.enable_fsdp and train_config.low_cpu_fsdp:
"""
for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
this avoids cpu oom when loading large models like llama 70B, in which case
model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
overhead and currently requires latest nightly.
"""
v = packaging.version.parse(torch.__version__)
verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
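# nightly wheels carry a .devYYYYMMDD version suffix, so v.dev encodes the build date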
if not verify_latest_nightly:
raise Exception("the latest PyTorch nightly build is required to run with the low_cpu_fsdp config, "
"please install the latest nightly.")
if rank == 0:
model = LlamaForCausalLM.from_pretrained(
train_config.model_name,
load_in_8bit=True if train_config.quantization else None,
device_map="auto" if train_config.quantization else None,
)
else:
llama_config = LlamaConfig.from_pretrained(train_config.model_name)
with torch.device("meta"):
model = LlamaForCausalLM(llama_config)
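# meta-device tensors hold only shape/dtype metadata, so no weight memory is allocated on non-zero ranks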

else:
model = LlamaForCausalLM.from_pretrained(
train_config.model_name,
load_in_8bit=True if train_config.quantization else None,
device_map="auto" if train_config.quantization else None,
)
if train_config.enable_fsdp and train_config.use_fast_kernels:
"""
For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
@@ -106,11 +110,11 @@ def main(**kwargs):
except ImportError:
print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

# Prepare the model for int8 training if quantization is enabled
if train_config.quantization:
model = prepare_model_for_int8_training(model)

# Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
if train_config.enable_fsdp and fsdp_config.pure_bf16:
model.to(torch.bfloat16)
@@ -119,46 +123,49 @@ def main(**kwargs):
tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
tokenizer.add_special_tokens(
{

"pad_token": "<PAD>",
}
)
if train_config.use_peft:
peft_config = generate_peft_config(train_config, kwargs)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Set up FSDP if enable_fsdp is enabled
if train_config.enable_fsdp:
if not train_config.use_peft and train_config.freeze_layers:

freeze_transformer_layers(train_config.num_freeze_layers)

mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)

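# With low_cpu_fsdp, rank 0 holds the real weights and every other rank holds meta tensors:
# param_init_fn materializes those meta tensors as empty GPU storage, and
# sync_module_states then broadcasts rank 0's weights so all ranks start identical.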
model = FSDP(
model,
auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
sharding_strategy=fsdp_config.sharding_strategy,
device_id=torch.cuda.current_device(),
limit_all_gathers=True,
sync_module_states=train_config.low_cpu_fsdp,
param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
if train_config.low_cpu_fsdp and rank != 0 else None,
)
if fsdp_config.fsdp_activation_checkpointing:
policies.apply_fsdp_checkpointing(model)
elif not train_config.quantization and not train_config.enable_fsdp:
model.to("cuda")

dataset_config = generate_dataset_config(train_config, kwargs)

# Load and preprocess the dataset for training and validation
dataset_train = get_preprocessed_dataset(
tokenizer,
dataset_config,
split="train",
)

if not train_config.enable_fsdp or rank == 0:
print(f"--> Training Set Length = {len(dataset_train)}")

@@ -185,7 +192,7 @@ def main(**kwargs):
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
)

# Create DataLoaders for the training and validation dataset
train_dataloader = torch.utils.data.DataLoader(
dataset_train,
@@ -207,7 +214,7 @@ def main(**kwargs):
drop_last=True,
collate_fn=default_data_collator,
)

# Initialize the optimizer and learning rate scheduler
if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
optimizer = AnyPrecisionAdamW(
@@ -229,7 +236,7 @@ def main(**kwargs):
results = train(
model,
train_dataloader,
eval_dataloader,
eval_dataloader,
tokenizer,
optimizer,
scheduler,
4 changes: 2 additions & 2 deletions model_checkpointing/checkpoint_handler.py
@@ -65,7 +65,7 @@ def load_model_sharded(model, rank, cfg):
reader = FileSystemReader(load_dir)

with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
checkpoint = model.state_dict()
checkpoint = {"model": model.state_dict()}
if rank == 0:
ck = checkpoint.keys()
print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
@@ -78,7 +78,7 @@ def load_model_sharded(model, rank, cfg):
print(f"checkpoint after load_state_dict()")
ck = checkpoint.keys()
print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
model.load_state_dict(checkpoint)
model.load_state_dict(checkpoint["model"])
if rank == 0:
print(f"Sharded state checkpoint loaded from {load_dir}")

2 changes: 1 addition & 1 deletion utils/train_utils.py
@@ -141,7 +141,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
lr_scheduler.step()

if train_config.run_validation:
eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)
eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
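# evaluation needs the node-local device index; the global rank can exceed the per-node GPU count on multi-node runs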
checkpoint_start_time = time.perf_counter()
if train_config.save_model and eval_epoch_loss < best_val_loss:
if train_config.enable_fsdp:
