save cpu mem by leveraging FSDP rank0 broadcasting #77

Merged · 14 commits · Aug 11, 2023
32 changes: 26 additions & 6 deletions llama_finetuning.py
@@ -17,6 +17,7 @@
from transformers import (
LlamaForCausalLM,
LlamaTokenizer,
LlamaConfig,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
@@ -62,8 +63,10 @@
from torch.optim.lr_scheduler import StepLR
from pkg_resources import packaging
import torch
import torch.nn as nn
import torch.cuda.nccl as nccl
import torch.distributed as dist
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


@@ -90,11 +93,26 @@ def main(**kwargs):
gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size

# Load the pre-trained model and setup its configuration
model = LlamaForCausalLM.from_pretrained(
train_config.model_name,
load_in_8bit=True if train_config.quantization else None,
device_map="auto" if train_config.quantization else None,
)
if train_config.enable_fsdp:
    # for FSDP, we save cpu memory by loading the pretrained model on rank0 only.
    # this avoids cpu oom when loading large models like llama 70B, in which case
    # the model alone would consume 2+ TB of cpu mem (70B params * 4 bytes * 8 ranks)
    if rank == 0:
        model = LlamaForCausalLM.from_pretrained(
            train_config.model_name,
            load_in_8bit=True if train_config.quantization else None,
            device_map="auto" if train_config.quantization else None,
        )
    else:
        llama_config = LlamaConfig.from_pretrained(train_config.model_name)
        with torch.device("meta"):
            model = LlamaForCausalLM(llama_config)
else:
    model = LlamaForCausalLM.from_pretrained(
        train_config.model_name,
        load_in_8bit=True if train_config.quantization else None,
        device_map="auto" if train_config.quantization else None,
    )

Review thread on the rank-0 from_pretrained call:
Contributor: can we figure out why torch.device("meta") init doesn't work here?
Contributor Author: @rohan-varma for non-0 ranks, we are using torch.device("meta") init.
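For context on the meta-device question above, here is a minimal standalone sketch (my own example, not code from this PR) of why constructing the model under torch.device("meta") costs almost no CPU memory: meta tensors have no storage until they are materialized, e.g. via to_empty(), which is what the FSDP param_init_fn further down relies on.

```python
# Minimal sketch: parameters created under the meta device allocate no storage.
import torch
import torch.nn as nn

with torch.device("meta"):
    layer = nn.Linear(8192, 8192)   # no 8192*8192*4 bytes of weight storage allocated

print(layer.weight.is_meta)         # True -> tensor has no backing data

# Materializing with to_empty() allocates uninitialized storage; in the PR this
# happens per-module on CUDA via FSDP's param_init_fn, and sync_module_states=True
# then overwrites it with rank 0's real weights.
layer = layer.to_empty(device="cpu")
print(layer.weight.is_meta)         # False -> storage now exists (values uninitialized)
```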

print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

@@ -127,14 +145,16 @@ def main(**kwargs):

mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)

model = FSDP(
model,
auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
sharding_strategy=fsdp_config.sharding_strategy,
device_id=torch.cuda.current_device(),
limit_all_gathers=True,
sync_module_states=True,
param_init_fn=None if rank == 0 else lambda module: module.to_empty(device=torch.device("cuda"), recurse=False),
)
if fsdp_config.fsdp_activation_checkpointing:
policies.apply_fsdp_checkpointing(model)
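To see how the pieces fit together, here is a hedged, self-contained sketch of the same pattern (my own example, not code from this repo): rank 0 builds real weights, other ranks build on the meta device, and FSDP's sync_module_states=True plus param_init_fn materialize and broadcast rank 0's weights. It assumes a single node with one GPU per rank, NCCL, and a launch such as `torchrun --nproc_per_node=2 sketch.py`; build_model and sketch.py are illustrative names only.

```python
# Sketch of FSDP rank0 loading + broadcast (assumptions: single node, 1 GPU per rank, NCCL).
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def build_model() -> nn.Module:
    # Stand-in for LlamaForCausalLM; any nn.Module works for the pattern.
    return nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))


def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)  # ok on a single node where rank == local rank

    if rank == 0:
        model = build_model()                # real weights; only rank 0 uses CPU memory
    else:
        with torch.device("meta"):
            model = build_model()            # no storage allocated on other ranks

    model = FSDP(
        model,
        device_id=torch.cuda.current_device(),
        sync_module_states=True,             # broadcast rank 0's weights to all ranks
        # Non-zero ranks first materialize each meta module with empty CUDA storage,
        # which the broadcast above then fills in.
        param_init_fn=None if rank == 0 else
            lambda module: module.to_empty(device=torch.device("cuda"), recurse=False),
    )

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```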