Commit

Update fsdp mmlu

mzio committed Sep 20, 2024
1 parent 1189bbb commit f102c8a
Showing 4 changed files with 745 additions and 0 deletions.
19 changes: 19 additions & 0 deletions configs/experiment/eval_mmlu.yaml
@@ -0,0 +1,19 @@
dataset:
  name: eval_mmlu
  dataset_config:
    lm_evaluation_harness_path: '/juice2/scr2/mzhang/projects/lm-evaluation-harness'  # Change this to your local lm-evaluation-harness clone
    tasks: null  # null -> all MMLU tasks
    num_fewshot: 5
    limit: null
    cache_dir: 'data/mmlu'  # Change this to where you want to save
  pretrained_model_config:
    pretrained_model_name_or_path: 'mistralai/Mistral-7B-v0.1'  # will be updated based on model_config
    cache_dir: '/scr-ssd/mzhang/models/mistral-v0.1'  # Change this to your local model cache
  preprocess_config: null

dataloader:
  batch_size: 4
  num_workers: 2
  drop_last: false
  pin_memory: true
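For reference, a minimal sketch of how this config gets consumed (it mirrors the data-loading call near the bottom of the script below; assumes you run from the repo root so the relative path resolves, and that load_data in src.dataloaders.eval_mmlu accepts these keys as keyword arguments):

from omegaconf import OmegaConf
from src.dataloaders.eval_mmlu import load_data

eval_config = OmegaConf.load('./configs/experiment/eval_mmlu.yaml')
# the dataset and dataloader keys map directly onto load_data's keyword arguments
eval_dataloader = load_data(**eval_config['dataset'], **eval_config['dataloader'])
print(len(eval_dataloader.dataset))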

378 changes: 378 additions & 0 deletions llama_recipes/distill_llama_eval_mmlu.py
@@ -0,0 +1,378 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

"""
Evaluate MMLU
torchrun
"""
import os
from os.path import join
# import sys
# sys.path.append('/workspace/lolcats') # needed for vast-ai instances
import dataclasses
import random
import argparse # ours
from pkg_resources import packaging

import torch
import torch.optim as optim

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    StateDictType
)

from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

from llama_recipes.configs import fsdp_config as FSDP_CONFIG
# from llama_recipes.configs import train_config as TRAIN_CONFIG
from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing

from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy
from llama_recipes.utils.config_utils import (
    update_config,
    # generate_peft_config,
    # generate_dataset_config,
    # get_dataloader_kwargs,
)
from llama_recipes.utils.fsdp_utils import (
    hsdp_device_mesh as get_hsdp_device_mesh
)
from llama_recipes.trainer_finetune import (
    train as _train_normal,
    setup,
    setup_environ_flags,
    clear_gpu_cache,
    print_model_size,
    get_policies,
)
from llama_recipes.model_checkpointing.distill_checkpoint_handler import (
    load_model_sharded,
    load_sharded_model_single_gpu,
)
# from llama_recipes.trainer_finetune_chunked import train as train_chunked

from accelerate.utils import is_xpu_available

# -------------
# Our arguments
# -------------
from omegaconf import OmegaConf

from src.utils.setup import (
    update_config_from_args,
    update_model_config_from_args
)
from src.utils.logging import print_header, print_config
# from src.dataloaders import load_data
from src.trainer import get_scheduler

from src.finetune import prepare_finetune_configs # get_finetuner

from src.model.pretrained import get_pretrained_loader
from src.model.load_model import (
    load_and_convert_attns,
    load_and_convert_finetune
)
from distill_llama import (
    setup_wandb, get_args,  # get_run_name_from_checkpoint
    setup_fsdp_config
)

from src.dataloaders.eval_mmlu import load_data


def main():
    # ---------
    # 1. SET UP
    # ---------
    args = get_args()
    args.checkpoint_dir = join(args.checkpoint_dir, args.model_config)
    if not os.path.isdir(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)

    kwargs = vars(args)

    # if 'finetune_long' in args.finetune_config:
    #     train = train_chunked
    # else:
    #     train = _train_normal
    train = _train_normal

    # Load distillation + attention configs
    distill_config_path = join('./configs/experiment', f'{args.distill_config}.yaml')
    distill_config = OmegaConf.load(distill_config_path)
    distill_config = update_config_from_args(distill_config, args)

    model_config_path = join('./configs/model', f'{args.model_config}.yaml')
    model_config = OmegaConf.load(model_config_path)
    model_config = update_model_config_from_args(model_config, args)
    if args.enable_fsdp:
        if getattr(model_config.model, 'load_in_4bit', False):
            model_config.model.device_map = 'auto'
        elif getattr(model_config.model, 'load_in_8bit', False):
            model_config.model.device_map = 'auto'
        else:
            model_config.model.device_map = None  # FSDP will complain about device placement o.w.

    # Update dataset pretrained model config
    for k in distill_config.dataset.pretrained_model_config:
        distill_config.dataset.pretrained_model_config[k] = getattr(model_config.model, k)

    args.run_name = args.run_name.replace('True', '1').replace('False', '0')  # concise hacks

    # Update the configuration for the training and sharding process
    distill_config = setup_fsdp_config(distill_config, args, 'distill')  # patch llama-recipes args
    fsdp_config = FSDP_CONFIG()
    update_config((fsdp_config), **vars(args))
    # Set the seeds for reproducibility
    if is_xpu_available():
        torch.xpu.manual_seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    local_rank = 0  # defaults for single-process (non-FSDP) runs
    rank = 0
    if args.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    if rank == 0 or not args.enable_fsdp:
        print_header('Distillation Config')
        print_config(distill_config)
        print_header('Model Config')
        print_config(model_config)
        print_header('FSDP Config')
        print_config(dataclasses.asdict(fsdp_config))

    if torch.distributed.is_initialized():
        if is_xpu_available():
            torch.xpu.set_device(local_rank)
        elif torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    wandb_run = None

    if not args.no_wandb:
        if not args.enable_fsdp or rank == 0:
            wandb_run = setup_wandb(distill_config, fsdp_config, **kwargs)

    # ------------------------
    # 2. LOAD PRETRAINED MODEL
    # ------------------------
    # Load the pre-trained model and setup its configuration
    # Initialize tokenizer and model loader
    model_loader = get_pretrained_loader(**model_config.model,
                                         huggingface_token=args.huggingface_token)
    tokenizer = model_loader.load_tokenizer()
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = 'left'  # left-pad so batched scoring / generation lines up on the last token

    use_cache = False if args.enable_fsdp else None

    if 'llama' in model_config.model.pretrained_model_name_or_path:
        from transformers import LlamaConfig as ModelConfig
        from transformers.models.llama.modeling_llama import LlamaDecoderLayer as DecoderLayer
        from src.model.modeling_llama import LolcatsLlamaForCausalLM as ModelClass
        model_type = 'llama'
    else:
        raise NotImplementedError(
            f'Unsupported model: {model_config.model.pretrained_model_name_or_path}')

    # Convert model
    try:
        args.attention_type = model_config['attention']['attention_type']
    except AttributeError:
        args.attention_type = 'lolcats_llama'

    if args.enable_fsdp and args.low_cpu_fsdp:
        # for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
        # this avoids cpu oom when loading large models like llama 70B, in which case
        # model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
        # overhead and currently requires latest nightly.
        v = packaging.version.parse(torch.__version__)
        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
        if not verify_latest_nightly and rank == 0:
            print(f'-> Pytorch version is {v} ({v.dev})')
            print(' - Llama-recipes says "latest pytorch nightly build is required to run with low_cpu_fsdp config"')
            print(" - But who knows maybe this will work. We're just trying stuff.")
            print(" - (Also if PyTorch was installed after July 1, 2023 we should be good)")
        # raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
        #                 "please install latest nightly.")
        model = model_loader.load(args.attention_type)
        model.state_chunk_len = model_config['attention']['state_chunk_len']
    else:
        model = model_loader.load(args.attention_type)
        model.state_chunk_len = model_config['attention']['state_chunk_len']

    if rank == 0 or not args.enable_fsdp:
        print_header('Pretrained Model')

    model_config.model_name = model_config.model.pretrained_model_name_or_path
    print_model_size(model, model_config, rank if args.enable_fsdp else 0)

    # Prepare the model for int8 training if quantization is enabled
    # -> But we only use this script for FSDP without quantization
    # if train_config.quantization:
    #     model = prepare_model_for_int8_training(model)

    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
    if args.enable_fsdp and fsdp_config.pure_bf16:
        model.to(torch.bfloat16)

    # -------------------------------
    # 3. CONVERT DISTILLED ATTENTIONS
    # -------------------------------
    model, distill_peft_config = load_and_convert_attns(model, model_config,
                                                        attention_type=args.attention_type,
                                                        checkpoint_path=None,  # args.load_distill_checkpoint,
                                                        print_model=args.verbose,
                                                        merge_loras=False,
                                                        peft_gradient_checkpointing=not args.no_peft_grad_ckpt,
                                                        train_attention=False,
                                                        rank=rank)
    if rank == 0:
        print_header('** Sanity check model weights **')
        for n, p in model.named_parameters():
            if 'layers.0.' in n and ('feature_map' in n or 'lora' in n):
                print(f'-> {n}:\n', p)

    if distill_config.trainer.name is not None:
        if args.load_distill_checkpoint is not None:
            model = load_sharded_model_single_gpu(model, model_path=args.load_distill_checkpoint,
                                                  cfg=distill_config, rank=rank)
        else:
            model = load_sharded_model_single_gpu(model, model_path=None,
                                                  cfg=distill_config, rank=rank)
    else:
        print(" -> Proceeding without learned linear attentions")

    if wandb_run and distill_peft_config is not None:
        wandb_run.config.update(distill_peft_config)

    # ----------------------------
    # 4. ADD FINETUNING PARAMETERS
    # ----------------------------
    finetune_config, args = prepare_finetune_configs(args, model_config,
                                                     args.finetune_config)
    # finetune_config = update_config_from_args(finetune_config, args)
    finetune_config = setup_fsdp_config(finetune_config, args, 'finetune')
    if args.finetune_lr is not None:
        finetune_config.model_name += f'=flr={args.finetune_lr}'

    # model, ft_peft_config
    model, _ = load_and_convert_finetune(model, finetune_config,
                                         checkpoint_path=None,
                                         print_model=args.verbose,
                                         merge_loras=False,
                                         peft_gradient_checkpointing=not args.no_peft_grad_ckpt,
                                         rank=rank)

    if args.load_finetune_checkpoint is not None:
        model = load_sharded_model_single_gpu(model, model_path=args.load_finetune_checkpoint,
                                              cfg=finetune_config, rank=rank)
    else:
        print(" -> Proceeding without finetuned parameters")


    # ------------------------------------------------------
    # 5. SETUP FSDP AND LOAD DISTILLED ATTENTION CHECKPOINTS
    # ------------------------------------------------------
    if args.enable_fsdp:

        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank, model=model_type)
        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, DecoderLayer)

        device_id = 0
        if is_xpu_available():
            device_id = torch.xpu.current_device()
        elif torch.cuda.is_available():
            device_id = torch.cuda.current_device()
        print('-> device_id:', device_id)
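        # With args.low_cpu_fsdp, only rank 0 holds real pretrained weights before wrapping:
        # sync_module_states=True broadcasts rank 0's weights to the other ranks during FSDP
        # init, while param_init_fn moves the other ranks' (empty) parameters onto GPU via
        # to_empty, so no rank ever needs a second full CPU copy of the model.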

        model = FSDP(
            model,
            auto_wrap_policy=my_auto_wrapping_policy,  # if train_config.use_peft else wrapping_policy,
            cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
            mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
            sharding_strategy=fsdp_config.sharding_strategy,
            # device_mesh=hsdp_device_mesh,
            device_id=device_id,
            limit_all_gathers=True,
            sync_module_states=args.low_cpu_fsdp,  # train_config.low_cpu_fsdp
            param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
            if args.low_cpu_fsdp and rank != 0 else None,
        )
        if fsdp_config.fsdp_activation_checkpointing:
            apply_fsdp_checkpointing(model)

        # Load distilled checkpoints
        if args.verbose and rank == 0:
            print_header('*** FSDP Model ***')
            print(model)
            print('Loading checkpoints from:', distill_config.model_name)

        # load_model_sharded(model, rank, distill_config, model_path=args.load_distill_checkpoint)

        if rank == 0 or not args.enable_fsdp:  # debugging
            print_header('** Sanity check model weights **')
            for n, p in model.named_parameters():
                if ('layers.0.' in n and 'base_attn' not in n and
                        '.0.mlp.' not in n and '.block_sparse_moe' not in n):
                    print(f'-> {n}:\n', p)

    else:  # if not model_config.model.quantization and not args.enable_fsdp:
        if is_xpu_available():
            model.to("xpu:0")
        elif torch.cuda.is_available():
            model.to("cuda")

    if args.verbose and (rank == 0 or not args.enable_fsdp):
        print_header('*** FSDP MODEL ***')
        print(model)
        print_header('*** Trainable Parameters ***')
        for n, p in model.named_parameters():
            if p.requires_grad:
                print(f'├── {n} (dtype = {p.dtype})')
        # print_header('*** model.state_dict() ***')
        # for k in model.state_dict().keys():
        #     print(f'├── {k}')

    # Get data
    eval_config_path = join('./configs/experiment', 'eval_mmlu.yaml')
    eval_config = OmegaConf.load(eval_config_path)
    for k in eval_config.dataset.pretrained_model_config:
        eval_config.dataset.pretrained_model_config[k] = getattr(model_config.model, k, None)

    eval_dataloader = load_data(**eval_config['dataset'], **eval_config['dataloader'])

    if not args.enable_fsdp or rank == 0:
        print(f"--> Validation Set Length = {len(eval_dataloader.dataset)}")

    # -----------
    # 6. EVALUATE
    # -----------
    from llama_recipes.trainer_eval_mmlu import evaluate_mmlu

    if rank == 0 or not args.enable_fsdp:
        print_header('*** Evaluating ***')
        if args.verbose:
            print_config(eval_config)

    results = evaluate_mmlu(
        model,
        finetune_config,  # llama-recipes-style training config patched by setup_fsdp_config
        eval_dataloader,
        local_rank,
        tokenizer,
        wandb_run,
        epoch=0,
        rank=rank,
    )

    if not args.enable_fsdp or rank == 0:
        for k, v in results.items():
            print(f'{k}: {v}')

if __name__ == "__main__":
    main()
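For intuition, MMLU evaluation with a causal LM typically scores each question by comparing the model's next-token likelihoods for the answer letters after the few-shot prompt. A minimal, self-contained sketch of that idea (illustrative only; it is independent of llama_recipes.trainer_eval_mmlu.evaluate_mmlu, whose interface is not shown here):

import torch

def score_choices(model, tokenizer, prompt, choices=('A', 'B', 'C', 'D')):
    # Return the answer letter the model assigns the highest next-token probability to.
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits[0, -1]  # logits for the token following the prompt
    choice_ids = [tokenizer(c, add_special_tokens=False).input_ids[-1] for c in choices]
    scores = torch.tensor([logits[i].item() for i in choice_ids])
    return choices[int(scores.argmax())]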