From 1a08c168136e08ab917c4ab12d0508099dc37f96 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 18 Oct 2023 20:13:00 -0700 Subject: [PATCH] [Train][Templates] Add LoRA support to Llama-2 finetuning example (#37794) Signed-off-by: Artur Niederfahrenhorst Co-authored-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com> --- .../README.md | 45 ++++- .../finetune_hf_llm.py | 156 +++++++++++++++-- .../lora_configs/lora.json | 11 ++ .../merge_lora_weights.py | 157 ++++++++++++++++++ .../run_llama_ft.sh | 14 +- .../Dockerfile | 4 +- .../requirements.txt | 3 +- .../ray_release/byod/byod_finetune_llvms.sh | 3 +- release/release_tests.yaml | 24 +++ 9 files changed, 388 insertions(+), 29 deletions(-) create mode 100644 doc/source/templates/04_finetuning_llms_with_deepspeed/lora_configs/lora.json create mode 100644 doc/source/templates/04_finetuning_llms_with_deepspeed/merge_lora_weights.py diff --git a/doc/source/templates/04_finetuning_llms_with_deepspeed/README.md b/doc/source/templates/04_finetuning_llms_with_deepspeed/README.md index 734004fef205b..5f7e0f133f7cd 100644 --- a/doc/source/templates/04_finetuning_llms_with_deepspeed/README.md +++ b/doc/source/templates/04_finetuning_llms_with_deepspeed/README.md @@ -1,24 +1,24 @@ -# Finetuning Llama-2 series models with Deepspeed, Accelerate, and Ray Train TorchTrainer +# Fine-tuning Llama-2 series models with Deepspeed, Accelerate, and Ray Train TorchTrainer | Template Specification | Description | | ---------------------- | ----------- | -| Summary | This template, demonstrates how to perform full parameter fine-tuning for Llama-2 series models (7B, 13B, and 70B) using TorchTrainer with the DeepSpeed ZeRO-3 strategy. | +| Summary | This template, demonstrates how to perform fine-tuning (full parameter or LoRA) for Llama-2 series models (7B, 13B, and 70B) using TorchTrainer with the DeepSpeed ZeRO-3 strategy. | | Time to Run | ~14 min. for 7B for 1 epoch on 3.5M tokens. ~26 min for 13B for 1 epoch. | | Minimum Compute Requirements | 16xg5.4xlarge for worker nodes for 7B model, 4xg5.12xlarge nodes for 13B model, and 4xg5.48xlarge (or 2xp4de.24xlarge) nodes for 70B| | Cluster Environment | This template uses a docker image built on top of the latest Anyscale-provided Ray image using Python 3.9: [`anyscale/ray:latest-py39-cu118`](https://docs.anyscale.com/reference/base-images/overview). | ## Getting Started -For 7B, set up a cluster on AWS with the following settings: +For a full-parameter fine-tuning of 7B models, set up a cluster on AWS with the following settings: | | num | instance type | GPU per node | GPU Memory | CPU Memory | |------------|-----|---------------|--------------|------------|------------| | Head node | 1 | m5.xlarge | - | - | - | | Worker node| 16 | g5.4xlarge | 1 x A10G | 24 GB | 64 GB | -And launch the following script: +And launch the following script to fine-tune LLaMA 2 7B: ``` -./run_llama_ft.sh --size=7b [--as-test] +./run_llama_ft.sh --size=7b --as-test ``` The flag `--as-test` is for demo / testing purposes as it runs through only one forward and backward pass of the model. The model loading, and remote checkpointing would still run. @@ -84,7 +84,7 @@ And the special tokens can be: {"tokens": ["", "", "", ""]} ``` -Depending on the dataset you want to finetune on, the tokenization and dataset pre-processing will likely need to be adjusted. The current code is configured to train on the Grade School Math 8k (GSM8K) dataset. By running the code below we create three files that are needed to launch the training script with. +Depending on the dataset you want to fine-tune on, the tokenization and dataset pre-processing will likely need to be adjusted. The current code is configured to train on the Grade School Math 8k (GSM8K) dataset. By running the code below we create three files that are needed to launch the training script with. ``` python create_dataset.py @@ -100,7 +100,7 @@ This dataset is trained with a context length of 512 which includes excessive pa The script is written using Ray Train + Deepspeed integration via accelerate API. The script is general enough that it can be used to fine-tune all released sizes of Llama-2 models. -The CLI for seeing all the options is: +The command for seeing all the options is: ``` python finetune_hf_llm.py --help @@ -116,6 +116,34 @@ This script was tested across three model sizes on the following cluster configu | 70B | `meta-llama/Llama-2-70b-hf` | 8 | 32x A10G (24G) | ~190 min. | +To launch a full fine-tuning you can use the following command: + +``` +./run_llama_ft.sh --size=7b +``` + +### Launching LoRA fine-tuning + +You can utilize [LoRA](https://arxiv.org/abs/2106.09685) to achieve more resource efficient fine-tuning results than full-parameter fine-tuning, but unlocking smaller instance types and more effecient model serving. +To launch a LoRA fine-tuning, you can use the following command or similar commands for other model sizes: + +``` +./run_llama_ft.sh --size=7b --lora +``` + +Fine-tuning a model with LoRA results in a checkpoint containing only the fine-tuned weights. +As an example, the default Llama 2 LoRA configuration should yield a 42/64/202MB checkpoint for 7B/13B/70B models. +If we want to evaluate the model after training, we can merge the model weights with the original (non-fine-tuned) model. +We provide a script to merge the fine-tuned weights with the original weights to produce a full-parameter checkpoint. +The script has high CPU memory requirements because it requires us to load all parameters into memory at the same time, +13GB/24GB/152GB for 7B/13B/70B models. Downloading and loading the original weights should take ~1min/~2min/~10min each +on a p4de.24xlarge instance. You can run the script as follows: + +``` +python merge_lora_weights.py --model-name=7b --checkpoint= --output-path= +``` + +This leaves a self-contained LoRA fine-tuned model, config and tokenizer at the desired output path. ### Guideline on how to pick node instances when A100s are not available. @@ -202,7 +230,6 @@ scaling_config=air.ScalingConfig( ) ``` - ### Submiting a production job You can easily submit a production job using the following command: @@ -214,4 +241,4 @@ This will create a job yaml file that you can use to submit a production job on ``` anyscale job submit job.yaml -``` \ No newline at end of file +``` diff --git a/doc/source/templates/04_finetuning_llms_with_deepspeed/finetune_hf_llm.py b/doc/source/templates/04_finetuning_llms_with_deepspeed/finetune_hf_llm.py index 2cb1a7e73055c..a95cc742e11b1 100644 --- a/doc/source/templates/04_finetuning_llms_with_deepspeed/finetune_hf_llm.py +++ b/doc/source/templates/04_finetuning_llms_with_deepspeed/finetune_hf_llm.py @@ -29,6 +29,7 @@ get_linear_schedule_with_warmup, ) +from peft import LoraConfig, get_peft_model import ray from ray import train import ray.util.scheduling_strategies @@ -45,12 +46,72 @@ OPTIM_BETAS = (0.9, 0.999) OPTIM_EPS = 1e-8 +NUM_WARMUP_STEPS = 10 OPTIM_WEIGHT_DECAY = 0.0 +ATTENTION_LAYER_NAME = "self_attn" + + +def get_expected_lora_num_parameters( + model, lora_config: LoraConfig, attn_layer_name: str = ATTENTION_LAYER_NAME +): + """Calculate the expected number of parameters for lora finetuning.""" + sum_params = 0 + num_attention_layers = 0 + modules = model.named_modules() + loraified_modules = 0 + # We calculate the number of parameters we need for lora finetuning by calculating + # the sizes of the deecomposed weight matrices according to the paper. + for full_name, target in modules: + layer_name = full_name.split(".")[-1] + + if layer_name == attn_layer_name: + # Detected another attention layer (for example, llama 2 70b should have 80 + # of these) + num_attention_layers += 1 + elif layer_name in lora_config.modules_to_save: + # Detect another non-lora module to save, which will also contribute to the + # number of checkpointed parameters. This will result in one set of + # trainable parameters ".original_module.weight" and another one with + # ".modules_to_save.default.weight" + # Therefore, each layer contributes 2 x the number of actual elements in + # that layer. + sum_params += 2 * target.weight.numel() + print( + "Found non-lora-layer to checkpoint: ", + layer_name, + " with num params ", + target.weight.numel(), + ) + else: + for module_name in lora_config.target_modules: + if layer_name == module_name: + loraified_modules += 1 + if isinstance(target, nn.Linear): + # Target is attention weight + sum_params += ( + target.in_features + target.out_features + ) * lora_config.r + elif isinstance(target, nn.Embedding): + # Target is linear weight + sum_params += ( + target.embedding_dim + target.num_embeddings + ) * lora_config.r + + print( + f"Detected {num_attention_layers} attention layers, containing" + f" {loraified_modules} modules to modify according to LoRA's `target_modules`." + f" This should yield {sum_params} trainable parameters." + ) + + return sum_params def get_number_of_params(model: nn.Module): - state_dict = model.state_dict() - return sum(p.numel() for p in state_dict.values()) + sum = 0 + for name, param in model.named_parameters(): + if param.requires_grad: + sum += param.numel() + return sum def collate_fn(batch, tokenizer, block_size, device): @@ -228,7 +289,38 @@ def training_function(kwargs: dict): use_cache=False, ) print(f"Done loading model in {time.time() - s} seconds.") + model.resize_token_embeddings(len(tokenizer)) + + if config["lora"]: + # Apply LoRA + s = time.time() + lora_config = LoraConfig(**config["lora_config"]) + + expected_num_parameters = get_expected_lora_num_parameters( + lora_config=lora_config, model=model + ) + + print(f"Attempting to apply LoRA config: {lora_config}") + + model.enable_input_require_grads() + model = get_peft_model(model, lora_config) + + num_parameters = get_number_of_params(model) + + if num_parameters != expected_num_parameters: + raise ValueError( + f"Expected {expected_num_parameters} parameters, got {num_parameters} " + f"parameters. LoRA-ification failed." + ) + + print( + f"LoRA-ification done in {time.time() - s} seconds. Estimated checkpoint " + f"size (fp16): {num_parameters * 2 / 1e6} MB" + ) + + print(f"Number of checkpointed parameters: {get_number_of_params(model)}") + print("Model initialized with pretrained weights. Training starting...") if not args.no_grad_ckpt: model.gradient_checkpointing_enable() @@ -249,26 +341,29 @@ def training_function(kwargs: dict): ) # Instantiate scheduler - # Creates Dummy Scheduler if `scheduler` was specified in the config file + # Creates Dummy Scheduler if `scheduler` was specified in the config file or # else, creates `args.lr_scheduler_type` Scheduler # get train and valid dataset lengths + num_steps_per_epoch = math.ceil(train_ds_len / args.batch_size_per_device) + total_training_steps = ( + num_steps_per_epoch * num_epochs // gradient_accumulation_steps + ) + if ( accelerator.state.deepspeed_plugin is None or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config ): lr_scheduler = get_linear_schedule_with_warmup( optimizer=optimizer, - num_warmup_steps=100, - num_training_steps=( - (train_ds_len * num_epochs) // gradient_accumulation_steps - ), + num_warmup_steps=NUM_WARMUP_STEPS * args.num_devices, + num_training_steps=total_training_steps * args.num_devices, ) else: lr_scheduler = DummyScheduler( optimizer, - total_num_steps=(train_ds_len * num_epochs) // gradient_accumulation_steps, - warmup_num_steps=100, + warmup_num_steps=NUM_WARMUP_STEPS * args.num_devices, + total_num_steps=total_training_steps * args.num_devices, ) # Prepare everything @@ -284,7 +379,6 @@ def training_function(kwargs: dict): print("Number of batches on main process", train_ds_len // batch_size) for epoch in range(num_epochs): - fwd_time_sum, bwd_time_sum, optim_step_time_sum = 0, 0, 0 s_epoch = time.time() model.train() @@ -328,12 +422,13 @@ def training_function(kwargs: dict): f"loss: {loss.item()} step-time: {e_opt_step - s_fwd}" ) + aggregated_loss = torch.mean(accelerator.gather(loss[None])).item() + if config["as_test"]: break # as long as this is not the last step report here if step != (train_ds_len // batch_size - 1): - aggregated_loss = torch.mean(accelerator.gather(loss[None])).item() train.report( { "epoch": epoch, @@ -378,7 +473,7 @@ def training_function(kwargs: dict): metrics = { "epoch": epoch, "iteration": step, - "train_loss_batch": loss.item(), + "train_loss_batch": aggregated_loss, "avg_train_loss_epoch": loss_sum.item() / (step + 1), "eval_loss": eloss, "perplexity": perplex, @@ -459,6 +554,13 @@ def training_function(kwargs: dict): time.perf_counter() - checkpoint_save_start, ) + if perplex < args.stop_perplexity: + print(f"Perplexity reached {perplex} < {args.stop_perplexity}. Stopping.") + break + + if config["as_test"]: + break + def parse_args(): @@ -481,6 +583,14 @@ def parse_args(): help="Batch size to use per device.", ) + parser.add_argument( + "--stop-perplexity", + default=0, + type=float, + help="Target perplexity to reach after which to stop training. Default is 0. " + "If 0, training will not stop on perplexity.", + ) + parser.add_argument( "--eval-batch-size-per-device", type=int, @@ -495,7 +605,9 @@ def parse_args(): "--grad_accum", type=int, default=1, help="Gradient accumulation steps." ) parser.add_argument("--train_path", type=str, help="Path to training jsonl file") + parser.add_argument("--test_path", type=str, help="Path to testing jsonl file") + parser.add_argument( "--special_token_path", type=str, help="Path to token json file" ) @@ -505,6 +617,7 @@ def parse_args(): help="If passed, will not use gradient checkpointing.", ) parser.add_argument("--output_dir", type=str, help="Path to output directory.") + parser.add_argument( "--model_name", default="meta-llama/Llama-2-7b-chat-hf", type=str ) @@ -539,6 +652,15 @@ def parse_args(): default="./deepspeed_configs/zero_3_llama_2_7b.json", help="Deepspeed config json to use.", ) + + parser.add_argument( + "--lora", + action="store_true", + default=False, + help="If passed, will enable parameter efficient fine-tuning with LoRA (" + "https://arxiv.org/pdf/2106.09685.pdf).", + ) + args = parser.parse_args() return args @@ -566,6 +688,12 @@ def main(): } ) + # Add LoRA config if needed + if args.lora: + with open("./lora_configs/lora.json", "r") as json_file: + lora_config = json.load(json_file) + config["lora_config"] = lora_config + # Add deepspeed plugin to the config ds_plugin = DeepSpeedPlugin(hf_ds_config=config.get("ds_config")) config.update(ds_plugin=ds_plugin) @@ -602,6 +730,10 @@ def main(): f"{artifact_storage}/{user_name}/ft_llms_with_deepspeed/{args.model_name}" ) + trial_name = f"{args.model_name}".split("/")[-1] + if args.lora: + trial_name += "-lora" + trainer = TorchTrainer( training_function, train_loop_config={ diff --git a/doc/source/templates/04_finetuning_llms_with_deepspeed/lora_configs/lora.json b/doc/source/templates/04_finetuning_llms_with_deepspeed/lora_configs/lora.json new file mode 100644 index 0000000000000..f02ebe53c68e5 --- /dev/null +++ b/doc/source/templates/04_finetuning_llms_with_deepspeed/lora_configs/lora.json @@ -0,0 +1,11 @@ +{ + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "target_modules": ["gate_proj", "up_proj", "down_proj"], + "task_type": "CAUSAL_LM", + "modules_to_save": [], + "bias": "none", + "fan_in_fan_out": false, + "init_lora_weights": true +} diff --git a/doc/source/templates/04_finetuning_llms_with_deepspeed/merge_lora_weights.py b/doc/source/templates/04_finetuning_llms_with_deepspeed/merge_lora_weights.py new file mode 100644 index 0000000000000..be06cf94b173f --- /dev/null +++ b/doc/source/templates/04_finetuning_llms_with_deepspeed/merge_lora_weights.py @@ -0,0 +1,157 @@ +""" +This script merges the weights of a LoRA checkpoint with the base model weights +to create a single model that can be used for model evaluation. +""" + +import torch +import argparse +import time +import peft +from pathlib import Path + +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + StoppingCriteriaList, +) + +from utils import download_model, get_mirror_link, get_checkpoint_and_refs_dir + +# In addition to merging the lora weights, you can also formulate a prompt for the +# model here to quickly test it after merging +TEST_EVAL = False +TEST_PROMPT = ( + "Natalia sold clips to 48 of her friends in April, and then " + "she sold half as many clips in May. How many clips did Natalia sell " + "altogether in April and May?" +) +STOP_TOKEN = "" + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of training script.") + + parser.add_argument( + "--output-path", + type=str, + help="Path to output directory. Defaults to the orginal checkpoint directory.", + required=True, + ) + + parser.add_argument("--model-name", required=True, type=str, help="7b, 13b or 70b.") + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to checkpoint containing the LoRA weights.", + ) + + args = parser.parse_args() + + return args + + +def test_eval(model, tokenizer): + """Query the model with a single prompt to sanity check it.""" + + print("Starting model evaluation...") + + model.eval() + model.to("cuda") + + print("Prompting model with promtp : ", TEST_PROMPT) + input_ids = tokenizer(TEST_PROMPT, return_tensors="pt")["input_ids"].to("cuda") + + stop_token_embeding = tokenizer( + STOP_TOKEN, return_tensors="pt", add_special_tokens=False + )["input_ids"].to("cuda") + + def custom_stopping_criteria(embeddings, *args, **kwargs) -> bool: + return stop_token_embeding in embeddings + + stopping_criteria = StoppingCriteriaList([custom_stopping_criteria]) + + with torch.no_grad(): + generation_output = model.generate( + input_ids=input_ids, + output_scores=True, + max_new_tokens=500, + stopping_criteria=stopping_criteria, + ) + + decoded = tokenizer.batch_decode(generation_output) + print("Outputs: ", decoded) + + +def main(): + args = parse_args() + + # Sanity checks + if not Path(args.checkpoint).exists(): + raise ValueError(f"Checkpoint {args.checkpoint} does not exist.") + + if not args.output_path: + args.output_path = Path(args.checkpoint) / "merged_model" + print(f"Output path not specified. Using {args.output_path}") + + Path(args.output_path).mkdir(parents=True, exist_ok=True) + + # Load orignal model + s = time.time() + model_id = f"meta-llama/Llama-2-{args.model_name}-hf" + s3_bucket = get_mirror_link(model_id) + ckpt_path, _ = get_checkpoint_and_refs_dir(model_id=model_id, bucket_uri=s3_bucket) + + print(f"Downloading original model {model_id} from {s3_bucket} to {ckpt_path} ...") + print("Loading tokenizer...") + + tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, legacy=True) + tokenizer.save_pretrained(Path(args.output_path)) + + print(f"Saved tokenizer to {args.output_path}") + + download_model( + model_id=model_id, + bucket_uri=s3_bucket, + s3_sync_args=["--no-sign-request"], + ) + + print(f"Downloading to {ckpt_path} finished after {time.time() - s} seconds.") + print(f"Loading original model from {ckpt_path} ...") + + s2 = time.time() + + model = AutoModelForCausalLM.from_pretrained( + ckpt_path, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + use_cache=False, + ) + model.resize_token_embeddings(len(tokenizer)) + + print(f"Done downloading and loading model after {time.time() - s2} seconds.") + print("Loading and merging peft weights...") + s3 = time.time() + + # Load LoRA weights + model: peft.PeftModel = peft.PeftModel.from_pretrained( + model=model, + model_id=args.checkpoint, + ) + + # Merge weights and save + model = model.merge_and_unload() + output_path = Path(args.output_path) + model.save_pretrained(output_path, safe_serialization=True) + model.config.save_pretrained(output_path) + + print(f"Saved merged model to {args.output_path} after {time.time() - s3} seconds.") + print(f"This script took {time.time() - s} seconds to execute.") + + if TEST_EVAL: + test_eval(model, tokenizer) + + +if __name__ == "__main__": + main() diff --git a/doc/source/templates/04_finetuning_llms_with_deepspeed/run_llama_ft.sh b/doc/source/templates/04_finetuning_llms_with_deepspeed/run_llama_ft.sh index 5b762a32fe4e6..8f7024b7d0f8a 100755 --- a/doc/source/templates/04_finetuning_llms_with_deepspeed/run_llama_ft.sh +++ b/doc/source/templates/04_finetuning_llms_with_deepspeed/run_llama_ft.sh @@ -35,7 +35,7 @@ fine_tune() { --test_path "${test_path}" \ --special_token_path "${token_path}" \ --num-checkpoints-to-keep 1 \ - --num-epochs 1 \ + --num-epochs 3 \ "${params[@]}"; then echo "Failed to fine-tune the model. Exiting..." exit 1 @@ -56,9 +56,14 @@ do key=${arg%%=*} value=${arg#*=} if [[ "$key" == "--size" ]]; then - SIZE=${value} - elif [ "$arg" = "--as-test" ]; then - params+=("--as-test") + SIZE=${value}; + elif [[ "$arg" == "--as-test" ]]; then + params+=("--as-test"); + elif [[ "$arg" == "--lora" ]]; then + params+=("--lora"); + # Lora usually requires a lower learning rate + params+=("--lr"); + params+=("1e-4"); fi done @@ -87,6 +92,7 @@ MODEL_ID="meta-llama/Llama-2-${SIZE}-hf" CONFIG_DIR="./deepspeed_configs/zero_3_llama_2_${SIZE}.json" check_and_create_dataset "${DATA_DIR}" + fine_tune "$BS" "$ND" "$MODEL_ID" "$BASE_DIR" "$CONFIG_DIR" "$TRAIN_PATH" "$TEST_PATH" "$TOKEN_PATH" "${params[@]}" echo "Process completed." diff --git a/doc/source/templates/testing/docker/04_finetuning_llms_with_deepspeed/Dockerfile b/doc/source/templates/testing/docker/04_finetuning_llms_with_deepspeed/Dockerfile index c313fb5a9ac70..299c30eafb726 100644 --- a/doc/source/templates/testing/docker/04_finetuning_llms_with_deepspeed/Dockerfile +++ b/doc/source/templates/testing/docker/04_finetuning_llms_with_deepspeed/Dockerfile @@ -1,5 +1,5 @@ -# Dockerfile used to create the docker image for `03_serving_stable_diffusion`. -FROM anyscale/ray:2.6.1-py39-cu117 +# Dockerfile used to create the docker image for `04_finetuning_llms_with_deepspeed`. +FROM anyscale/ray:2.7.1-py310-cu121 COPY requirements.txt ./ diff --git a/doc/source/templates/testing/docker/04_finetuning_llms_with_deepspeed/requirements.txt b/doc/source/templates/testing/docker/04_finetuning_llms_with_deepspeed/requirements.txt index c4cd96d21a138..0f30a1e954d6d 100644 --- a/doc/source/templates/testing/docker/04_finetuning_llms_with_deepspeed/requirements.txt +++ b/doc/source/templates/testing/docker/04_finetuning_llms_with_deepspeed/requirements.txt @@ -14,4 +14,5 @@ protobuf<3.21.0 torchmetrics lm_eval==0.3.0 tiktoken==0.1.2 -sentencepiece \ No newline at end of file +sentencepiece +peft @ git+https://github.com/huggingface/peft.git@08368a1fba16de09756f067637ff326c71598fb3 \ No newline at end of file diff --git a/release/ray_release/byod/byod_finetune_llvms.sh b/release/ray_release/byod/byod_finetune_llvms.sh index 5a9ed91b6593a..5052b4ebfc3bd 100755 --- a/release/ray_release/byod/byod_finetune_llvms.sh +++ b/release/ray_release/byod/byod_finetune_llvms.sh @@ -22,4 +22,5 @@ pip3 install -U \ tiktoken==0.1.2 \ sentencepiece==0.1.99 \ "urllib3<1.27" \ - git+https://github.com/huggingface/transformers.git@d0c1aeb + git+https://github.com/huggingface/transformers.git@d0c1aeb \ + git+https://github.com/huggingface/peft.git@08368a1fba16de09756f067637ff326c71598fb3 diff --git a/release/release_tests.yaml b/release/release_tests.yaml index 41a1a812525a1..dc49d45d7e952 100644 --- a/release/release_tests.yaml +++ b/release/release_tests.yaml @@ -1054,6 +1054,30 @@ cluster: cluster_compute: ../testing/compute_configs/04_finetuning_llms_with_deepspeed/gce_7b.yaml +- name: workspace_template_finetuning_llms_with_deepspeed_llama_2_7b_lora + group: Workspace templates + working_dir: workspace_templates/04_finetuning_llms_with_deepspeed + python: "3.9" + frequency: nightly-3x + team: ml + cluster: + byod: + type: cu121 + # This needs to be in sync with requirements under go/llm-forge. + post_build_script: byod_finetune_llvms.sh + cluster_compute: ../testing/compute_configs/04_finetuning_llms_with_deepspeed/aws_7b.yaml + + run: + timeout: 1000 + script: chmod +x ./run_llama_ft.sh && ./run_llama_ft.sh --size=7b --lora --as-test + + variations: + - __suffix__: aws + - __suffix__: gce + env: gce + frequency: manual + cluster: + cluster_compute: ../testing/compute_configs/04_finetuning_llms_with_deepspeed/gce_7b.yaml #######################