[Train][Templates] Add LoRA support to Llama-2 finetuning example (#37794)

Signed-off-by: Artur Niederfahrenhorst <attaismyname@googlemail.com>
Co-authored-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com>
ArturNiederfahrenhorst and kouroshHakha authored Oct 19, 2023
1 parent d6baf12 commit 1a08c16
Showing 9 changed files with 388 additions and 29 deletions.
45 changes: 36 additions & 9 deletions doc/source/templates/04_finetuning_llms_with_deepspeed/README.md
@@ -1,24 +1,24 @@
# Finetuning Llama-2 series models with Deepspeed, Accelerate, and Ray Train TorchTrainer
# Fine-tuning Llama-2 series models with Deepspeed, Accelerate, and Ray Train TorchTrainer
| Template Specification | Description |
| ---------------------- | ----------- |
| Summary | This template, demonstrates how to perform full parameter fine-tuning for Llama-2 series models (7B, 13B, and 70B) using TorchTrainer with the DeepSpeed ZeRO-3 strategy. |
| Summary | This template demonstrates how to perform fine-tuning (full-parameter or LoRA) for Llama-2 series models (7B, 13B, and 70B) using TorchTrainer with the DeepSpeed ZeRO-3 strategy. |
| Time to Run | ~14 min. for 7B for 1 epoch on 3.5M tokens. ~26 min for 13B for 1 epoch. |
| Minimum Compute Requirements | 16xg5.4xlarge for worker nodes for 7B model, 4xg5.12xlarge nodes for 13B model, and 4xg5.48xlarge (or 2xp4de.24xlarge) nodes for 70B|
| Cluster Environment | This template uses a docker image built on top of the latest Anyscale-provided Ray image using Python 3.9: [`anyscale/ray:latest-py39-cu118`](https://docs.anyscale.com/reference/base-images/overview). |

## Getting Started

For 7B, set up a cluster on AWS with the following settings:
For full-parameter fine-tuning of the 7B model, set up a cluster on AWS with the following settings:

| | num | instance type | GPU per node | GPU Memory | CPU Memory |
|------------|-----|---------------|--------------|------------|------------|
| Head node | 1 | m5.xlarge | - | - | - |
| Worker node| 16 | g5.4xlarge | 1 x A10G | 24 GB | 64 GB |

And launch the following script:
And launch the following script to fine-tune Llama-2 7B:

```
./run_llama_ft.sh --size=7b [--as-test]
./run_llama_ft.sh --size=7b --as-test
```

The `--as-test` flag is for demo/testing purposes: it runs through only one forward and backward pass of the model. Model loading and remote checkpointing still run.
@@ -84,7 +84,7 @@ And the special tokens can be:
{"tokens": ["<ASSISTANT>", "</ASSISTANT>", "<USER>", "</USER>"]}
```

Depending on the dataset you want to finetune on, the tokenization and dataset pre-processing will likely need to be adjusted. The current code is configured to train on the Grade School Math 8k (GSM8K) dataset. By running the code below we create three files that are needed to launch the training script with.
Depending on the dataset you want to fine-tune on, the tokenization and dataset pre-processing will likely need to be adjusted. The current code is configured to train on the Grade School Math 8k (GSM8K) dataset. Running the code below creates the three files that are needed to launch the training script.

```
python create_dataset.py
@@ -100,7 +100,7 @@ This dataset is trained with a context length of 512 which includes excessive padding

The script is written using the Ray Train + DeepSpeed integration via the Accelerate API. It is general enough to fine-tune all released sizes of Llama-2 models.
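To see how the pieces fit together, here is a simplified sketch (not the template's exact code) of that structure: an Accelerate + DeepSpeed training loop wrapped in a Ray Train `TorchTrainer`. The mixed-precision setting, worker count, and config values below are illustrative.

```
# Simplified sketch of the Ray Train + Accelerate/DeepSpeed structure (illustrative only).
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def training_function(config: dict):
    accelerator = Accelerator(
        deepspeed_plugin=config["ds_plugin"],
        gradient_accumulation_steps=config["gradient_accumulation_steps"],
        mixed_precision="bf16",
    )
    # ... build the tokenizer, model, optimizer, and dataloaders, pass them through
    # accelerator.prepare(...), then run the usual forward/backward/step loop and
    # report metrics and checkpoints via ray.train.report().


trainer = TorchTrainer(
    training_function,
    train_loop_config={
        "ds_plugin": DeepSpeedPlugin(hf_ds_config="./deepspeed_configs/zero_3_llama_2_7b.json"),
        "gradient_accumulation_steps": 1,
    },
    scaling_config=ScalingConfig(num_workers=16, use_gpu=True),
)
result = trainer.fit()
```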

The CLI for seeing all the options is:
The command for seeing all the options is:

```
python finetune_hf_llm.py --help
@@ -116,6 +116,34 @@ This script was tested across three model sizes on the following cluster configurations
| 70B | `meta-llama/Llama-2-70b-hf` | 8 | 32x A10G (24G) | ~190 min. |


To launch a full-parameter fine-tuning run, use the following command:

```
./run_llama_ft.sh --size=7b
```

### Launching LoRA fine-tuning

You can use [LoRA](https://arxiv.org/abs/2106.09685) for more resource-efficient fine-tuning than full-parameter fine-tuning, unlocking smaller instance types and more efficient model serving.
To launch LoRA fine-tuning, use the following command (or the analogous command for another model size):

```
./run_llama_ft.sh --size=7b --lora
```
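Under the hood (see the script changes below), passing `--lora` makes the training script read the LoRA settings from `./lora_configs/lora.json` (included in this template) and wrap the model with [peft](https://github.com/huggingface/peft) before training. A minimal sketch of that step, assuming access to the Llama-2 weights on Hugging Face:

```
# Minimal sketch of what --lora enables inside the training script (illustrative only).
import json

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Load in fp16 to keep CPU memory usage manageable.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16
)

with open("./lora_configs/lora.json", "r") as f:
    lora_config = LoraConfig(**json.load(f))

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA adapter weights are trainable
```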

Fine-tuning a model with LoRA results in a checkpoint containing only the fine-tuned adapter weights.
As an example, the default Llama-2 LoRA configuration should yield a 42/64/202 MB checkpoint for the 7B/13B/70B models, respectively.
If you want to evaluate the model after training, you can merge the fine-tuned weights with the original (non-fine-tuned) model.
We provide a script that merges the fine-tuned weights with the original weights to produce a full-parameter checkpoint.
The script has high CPU memory requirements because it loads all parameters into memory at the same time:
13 GB/24 GB/152 GB for the 7B/13B/70B models. Downloading and loading the original weights should take ~1 min/~2 min/~10 min each
on a p4de.24xlarge instance. You can run the script as follows:

```
python merge_lora_weights.py --model-name=7b --checkpoint=<path to your checkpoint> --output-path=<desired output path>
```

This leaves a self-contained, LoRA-fine-tuned model, config, and tokenizer at the desired output path.
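After merging, the output directory can be loaded like any standard Hugging Face checkpoint. As a quick, illustrative smoke test (the path, prompt, and generation settings below are placeholders):

```
# Quick generation check against a merged checkpoint (placeholders only).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

output_path = "/path/to/merged/checkpoint"  # the --output-path used above

tokenizer = AutoTokenizer.from_pretrained(output_path)
model = AutoModelForCausalLM.from_pretrained(
    output_path, torch_dtype=torch.float16, device_map="auto"
)

prompt = "Q: Natalia has 12 apples and buys 7 more. How many apples does she have? A:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```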

### Guidelines for picking node instances when A100s are not available

@@ -202,7 +230,6 @@ scaling_config=air.ScalingConfig(
)
```


### Submitting a production job
You can easily submit a production job using the following command:

@@ -214,4 +241,4 @@ This will create a job yaml file that you can use to submit a production job on

```
anyscale job submit job.yaml
```
```
@@ -29,6 +29,7 @@
get_linear_schedule_with_warmup,
)

from peft import LoraConfig, get_peft_model
import ray
from ray import train
import ray.util.scheduling_strategies
@@ -45,12 +46,72 @@

OPTIM_BETAS = (0.9, 0.999)
OPTIM_EPS = 1e-8
NUM_WARMUP_STEPS = 10
OPTIM_WEIGHT_DECAY = 0.0
ATTENTION_LAYER_NAME = "self_attn"


def get_expected_lora_num_parameters(
model, lora_config: LoraConfig, attn_layer_name: str = ATTENTION_LAYER_NAME
):
"""Calculate the expected number of parameters for lora finetuning."""
sum_params = 0
num_attention_layers = 0
modules = model.named_modules()
loraified_modules = 0
# We calculate the number of parameters we need for LoRA fine-tuning by calculating
# the sizes of the decomposed weight matrices according to the LoRA paper: each
# adapted layer contributes r * (d_in + d_out) trainable parameters.
for full_name, target in modules:
layer_name = full_name.split(".")[-1]

if layer_name == attn_layer_name:
# Detected another attention layer (for example, llama 2 70b should have 80
# of these)
num_attention_layers += 1
elif layer_name in lora_config.modules_to_save:
# Detect another non-lora module to save, which will also contribute to the
# number of checkpointed parameters. This will result in one set of
# trainable parameters "<layer>.original_module.weight" and another one with
# "<layer>.modules_to_save.default.weight"
# Therefore, each layer contributes 2 x the number of actual elements in
# that layer.
sum_params += 2 * target.weight.numel()
print(
"Found non-lora-layer to checkpoint: ",
layer_name,
" with num params ",
target.weight.numel(),
)
else:
for module_name in lora_config.target_modules:
if layer_name == module_name:
loraified_modules += 1
if isinstance(target, nn.Linear):
# Target is a linear layer: LoRA adds r * (in_features + out_features) params
sum_params += (
target.in_features + target.out_features
) * lora_config.r
elif isinstance(target, nn.Embedding):
# Target is an embedding layer: LoRA adds r * (num_embeddings + embedding_dim) params
sum_params += (
target.embedding_dim + target.num_embeddings
) * lora_config.r

print(
f"Detected {num_attention_layers} attention layers, containing"
f" {loraified_modules} modules to modify according to LoRA's `target_modules`."
f" This should yield {sum_params} trainable parameters."
)

return sum_params


def get_number_of_params(model: nn.Module):
state_dict = model.state_dict()
return sum(p.numel() for p in state_dict.values())
# Count only trainable parameters (with LoRA, that is just the adapter weights).
num_params = 0
for _, param in model.named_parameters():
if param.requires_grad:
num_params += param.numel()
return num_params


def collate_fn(batch, tokenizer, block_size, device):
@@ -228,7 +289,38 @@ def training_function(kwargs: dict):
use_cache=False,
)
print(f"Done loading model in {time.time() - s} seconds.")

model.resize_token_embeddings(len(tokenizer))

if config["lora"]:
# Apply LoRA
s = time.time()
lora_config = LoraConfig(**config["lora_config"])

expected_num_parameters = get_expected_lora_num_parameters(
lora_config=lora_config, model=model
)

print(f"Attempting to apply LoRA config: {lora_config}")

model.enable_input_require_grads()
model = get_peft_model(model, lora_config)

num_parameters = get_number_of_params(model)

if num_parameters != expected_num_parameters:
raise ValueError(
f"Expected {expected_num_parameters} parameters, got {num_parameters} "
f"parameters. LoRA-ification failed."
)

print(
f"LoRA-ification done in {time.time() - s} seconds. Estimated checkpoint "
f"size (fp16): {num_parameters * 2 / 1e6} MB"
)

print(f"Number of checkpointed parameters: {get_number_of_params(model)}")

print("Model initialized with pretrained weights. Training starting...")
if not args.no_grad_ckpt:
model.gradient_checkpointing_enable()
@@ -249,26 +341,29 @@ def training_function(kwargs: dict):
)

# Instantiate scheduler
# Creates Dummy Scheduler if `scheduler` was specified in the config file
# Creates a DummyScheduler if a `scheduler` entry is present in the DeepSpeed config;
# otherwise, creates a linear warmup scheduler.
# get train and valid dataset lengths

num_steps_per_epoch = math.ceil(train_ds_len / args.batch_size_per_device)
total_training_steps = (
num_steps_per_epoch * num_epochs // gradient_accumulation_steps
)

if (
accelerator.state.deepspeed_plugin is None
or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
):
lr_scheduler = get_linear_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=100,
num_training_steps=(
(train_ds_len * num_epochs) // gradient_accumulation_steps
),
num_warmup_steps=NUM_WARMUP_STEPS * args.num_devices,
num_training_steps=total_training_steps * args.num_devices,
)
else:
lr_scheduler = DummyScheduler(
optimizer,
total_num_steps=(train_ds_len * num_epochs) // gradient_accumulation_steps,
warmup_num_steps=100,
warmup_num_steps=NUM_WARMUP_STEPS * args.num_devices,
total_num_steps=total_training_steps * args.num_devices,
)

# Prepare everything
@@ -284,7 +379,6 @@ def training_function(kwargs: dict):
print("Number of batches on main process", train_ds_len // batch_size)

for epoch in range(num_epochs):

fwd_time_sum, bwd_time_sum, optim_step_time_sum = 0, 0, 0
s_epoch = time.time()
model.train()
@@ -328,12 +422,13 @@ def training_function(kwargs: dict):
f"loss: {loss.item()} step-time: {e_opt_step - s_fwd}"
)

aggregated_loss = torch.mean(accelerator.gather(loss[None])).item()

if config["as_test"]:
break

# as long as this is not the last step report here
if step != (train_ds_len // batch_size - 1):
aggregated_loss = torch.mean(accelerator.gather(loss[None])).item()
train.report(
{
"epoch": epoch,
@@ -378,7 +473,7 @@ def training_function(kwargs: dict):
metrics = {
"epoch": epoch,
"iteration": step,
"train_loss_batch": loss.item(),
"train_loss_batch": aggregated_loss,
"avg_train_loss_epoch": loss_sum.item() / (step + 1),
"eval_loss": eloss,
"perplexity": perplex,
@@ -459,6 +554,13 @@ def training_function(kwargs: dict):
time.perf_counter() - checkpoint_save_start,
)

if perplex < args.stop_perplexity:
print(f"Perplexity reached {perplex} < {args.stop_perplexity}. Stopping.")
break

if config["as_test"]:
break


def parse_args():

@@ -481,6 +583,14 @@ def parse_args():
help="Batch size to use per device.",
)

parser.add_argument(
"--stop-perplexity",
default=0,
type=float,
help="Target perplexity to reach after which to stop training. Default is 0. "
"If 0, training will not stop on perplexity.",
)

parser.add_argument(
"--eval-batch-size-per-device",
type=int,
@@ -495,7 +605,9 @@ def parse_args():
"--grad_accum", type=int, default=1, help="Gradient accumulation steps."
)
parser.add_argument("--train_path", type=str, help="Path to training jsonl file")

parser.add_argument("--test_path", type=str, help="Path to testing jsonl file")

parser.add_argument(
"--special_token_path", type=str, help="Path to token json file"
)
@@ -505,6 +617,7 @@ def parse_args():
help="If passed, will not use gradient checkpointing.",
)
parser.add_argument("--output_dir", type=str, help="Path to output directory.")

parser.add_argument(
"--model_name", default="meta-llama/Llama-2-7b-chat-hf", type=str
)
@@ -539,6 +652,15 @@ def parse_args():
default="./deepspeed_configs/zero_3_llama_2_7b.json",
help="Deepspeed config json to use.",
)

parser.add_argument(
"--lora",
action="store_true",
default=False,
help="If passed, will enable parameter efficient fine-tuning with LoRA ("
"https://arxiv.org/pdf/2106.09685.pdf).",
)

args = parser.parse_args()

return args
@@ -566,6 +688,12 @@ def main():
}
)

# Add LoRA config if needed
if args.lora:
with open("./lora_configs/lora.json", "r") as json_file:
lora_config = json.load(json_file)
config["lora_config"] = lora_config

# Add deepspeed plugin to the config
ds_plugin = DeepSpeedPlugin(hf_ds_config=config.get("ds_config"))
config.update(ds_plugin=ds_plugin)
@@ -602,6 +730,10 @@ def main():
f"{artifact_storage}/{user_name}/ft_llms_with_deepspeed/{args.model_name}"
)

trial_name = f"{args.model_name}".split("/")[-1]
if args.lora:
trial_name += "-lora"

trainer = TorchTrainer(
training_function,
train_loop_config={
@@ -0,0 +1,11 @@
{
"r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"task_type": "CAUSAL_LM",
"modules_to_save": [],
"bias": "none",
"fan_in_fan_out": false,
"init_lora_weights": true
}