doc(training): update SFT tutorial's wording, separate shell script
The shell script is used to launch precompilation and fine-tuning.
tengomucho committed Jan 29, 2025
1 parent bebd673 commit fd4a20e
Showing 2 changed files with 67 additions and 57 deletions.
77 changes: 20 additions & 57 deletions docs/source/training_tutorials/sft_lora_finetune_llm.mdx
@@ -208,7 +208,7 @@ PyTorch Neuron uses `torch_xla`. It evaluates operations lazily during the execu

When training models on AWS Trainium we first need to compile our model with our training arguments.

To ease this step, we added a [model cache repository](https://huggingface.co/aws-neuron/optimum-neuron-cache), which allows us to use precompiled models from the Hugging Face Hub to skip the compilation step. But be careful: every change in the model configuration might lead to a new compilation, which could result in some cache misses.
To ease this step, we added a [model cache repository](https://huggingface.co/aws-neuron/optimum-neuron-cache), which allows us to use precompiled models from the Hugging Face Hub and skip the compilation step. This is useful because precompiling is much faster than compiling on the fly during the actual training, since the compilation of the different graphs can be parallelized. But be careful: every change in the model configuration might lead to a new compilation, which could result in some cache misses.
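
If you are curious about what has been compiled so far, you can simply list the local compiler cache directory. This is only an illustrative check, not a required step; the path is the `--cache_dir` exported through `NEURON_CC_FLAGS` in the script below:

```bash
# Illustrative only: list the graph artifacts accumulated in the local compiler cache.
# A change in the model configuration (batch size, sequence length, parallelism
# degrees, ...) produces new graph hashes, so new entries appear here instead of
# the existing ones being reused.
ls -lh /home/ubuntu/cache_dir_neuron/
```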

<Tip>

@@ -224,28 +224,29 @@ set -ex

export NEURON_FUSE_SOFTMAX=1
export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3
export MALLOC_ARENA_MAX=64
export MALLOC_ARENA_MAX=64 # limit the CPU allocation to avoid potential crashes
export NEURON_CC_FLAGS="--model-type=transformer --distribution-strategy=llm-training --enable-saturate-infinity --cache_dir=/home/ubuntu/cache_dir_neuron/"

PROCESSES_PER_NODE=8

NUM_EPOCHS=1
TP_DEGREE=2
PP_DEGREE=1
BS=1
GRADIENT_ACCUMULATION_STEPS=8
LOGGING_STEPS=1
MODEL_NAME="meta-llama/Meta-Llama-3-8B"
OUTPUT_DIR=output
OUTPUT_DIR=dolly_llama

if [ "$NEURON_EXTRACT_GRAPHS_ONLY" = "1" ]; then
MAX_STEPS=$((LOGGING_STEPS + 5))
MAX_STEPS=10
NUM_EPOCHS=1
else
MAX_STEPS=-1
NUM_EPOCHS=3
fi


XLA_USE_BF16=1 neuron_parallel_compile torchrun --nproc_per_node $PROCESSES_PER_NODE docs/source/training_tutorials/sft_lora_finetune_llm.py \
XLA_USE_BF16=1 torchrun --nproc_per_node $PROCESSES_PER_NODE docs/source/training_tutorials/sft_lora_finetune_llm.py \
--model_id $MODEL_NAME \
--num_train_epochs $NUM_EPOCHS \
--do_train \
@@ -267,13 +268,19 @@ XLA_USE_BF16=1 neuron_parallel_compile torchrun --nproc_per_node $PROCESSES_PER_
--overwrite_output_dir
```

<Tip>
For convenience, we saved this shell script to a file, [sft_lora_finetune_llm.sh](https://github.com/huggingface/optimum-neuron/blob/main/docs/source/training_tutorials/sft_lora_finetune_llm.sh). You can now pass it to the `neuron_parallel_compile` tool to trigger the compilation:

```bash
neuron_parallel_compile bash docs/source/training_tutorials/sft_lora_finetune_llm.sh
```

_Note: at the end of compilation, a `FileNotFoundError` message can appear. You can safely ignore it, as the compilation cache has still been created._

Make sure to run this precompilation phase for around 10 training steps. It is usually enough to accumulate and compile all the graphs that will be needed during the actual training.
This precompilation phase runs for 10 training steps, which is usually enough to accumulate and compile all the graphs that will be needed during the actual training.
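
The 10-step limit comes from the branch at the top of the shell script: when the script runs under `neuron_parallel_compile`, the `NEURON_EXTRACT_GRAPHS_ONLY` environment variable is set, and the script uses it to pick the short run (this is the assumption the check below relies on):

```bash
# Excerpt from sft_lora_finetune_llm.sh: a short 10-step run during graph
# extraction, the full 3 epochs during the actual training.
if [ "$NEURON_EXTRACT_GRAPHS_ONLY" = "1" ]; then
    MAX_STEPS=10
    NUM_EPOCHS=1
else
    MAX_STEPS=-1
    NUM_EPOCHS=3
fi
```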

</Tip>

_Note: Compiling without a cache can take a while. It will also create dummy files in the `dolly_llama_sharded` during compilation you will have to remove them afterwards. We also need to add `MALLOC_ARENA_MAX=64` to limit the CPU allocation to avoid potential crashes, don't remove it for now._
_Note: Compiling without a cache can take a while. It will also create dummy files in the `dolly_llama` directory during compilation, which you will have to remove afterwards._

```bash
# remove dummy artifacts which are created by the precompilation command
@@ -286,57 +293,13 @@ After compilation is done we can start our actual training with a similar comman

We will use `torchrun` to launch our training script. `torchrun` is a tool that automatically distributes a PyTorch model across multiple accelerators. We can pass the number of accelerators as `nproc_per_node` arguments alongside our hyperparameters.
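
For reference, the launch pattern boils down to the sketch below; the script name and arguments are placeholders, the actual invocation is the one wrapped in the shell script:

```bash
# torchrun spawns the requested number of worker processes on this node (8 here)
# and forwards the remaining flags unchanged to the training script.
torchrun --nproc_per_node 8 your_training_script.py --your_training_args
```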

The difference to the compilation command is that we changed from `max_steps=10` to `num_train_epochs=3`.
The difference from the compilation run is that the script switches the variables automatically: instead of stopping after `max_steps=10`, it now trains for the full `num_train_epochs=3`.

Launch the training, with the following command.
Launch the training with the same command used in the precompilation step, but without `neuron_parallel_compile`:

```bash
#!/bin/bash
set -ex

export NEURON_FUSE_SOFTMAX=1
export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3
export MALLOC_ARENA_MAX=64
export NEURON_CC_FLAGS="--model-type=transformer --distribution-strategy=llm-training --enable-saturate-infinity --cache_dir=/home/ubuntu/cache_dir_neuron/"
bash docs/source/training_tutorials/sft_lora_finetune_llm.sh

PROCESSES_PER_NODE=8

NUM_EPOCHS=1
TP_DEGREE=2
PP_DEGREE=1
BS=1
GRADIENT_ACCUMULATION_STEPS=8
LOGGING_STEPS=1
MODEL_NAME="meta-llama/Meta-Llama-3-8B"
OUTPUT_DIR=output

if [ "$NEURON_EXTRACT_GRAPHS_ONLY" = "1" ]; then
MAX_STEPS=$((LOGGING_STEPS + 5))
else
MAX_STEPS=-1
fi


XLA_USE_BF16=1 torchrun --nproc_per_node $PROCESSES_PER_NODE docs/source/training_tutorials/sft_lora_finetune_llm.py \
--model_id $MODEL_NAME \
--num_train_epochs $NUM_EPOCHS \
--do_train \
--learning_rate 5e-5 \
--warmup_ratio 0.03 \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BS \
--per_device_eval_batch_size $BS \
--gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
--gradient_checkpointing true \
--bf16 \
--zero_1 false \
--tensor_parallel_size $TP_DEGREE \
--pipeline_parallel_size $PP_DEGREE \
--logging_steps $LOGGING_STEPS \
--save_total_limit 1 \
--output_dir $OUTPUT_DIR \
--lr_scheduler_type "constant" \
--overwrite_output_dir
```

That's it, we successfully trained Llama-3 8B on AWS Trainium!
@@ -372,7 +335,7 @@ model = NeuronModelForCausalLM.from_pretrained(
**input_shapes)
```

_Note: Inference compilation can take ~25minutes. Luckily, you need to only run this onces. Since you can save the model afterwards. If you are going to run on Inferentia2 you need to recompile again. The compilation is parameter and hardware specific._
_Note: Inference compilation can take ~25 minutes. Luckily, you only need to run it once. You will also need to recompile if you change the hardware where you run the inference, e.g. if you move from Trainium to Inferentia2. The compilation is parameter and hardware specific._

```python
# COMMENT IN if you want to save the compiled model
47 changes: 47 additions & 0 deletions docs/source/training_tutorials/sft_lora_finetune_llm.sh
@@ -0,0 +1,47 @@
#!/bin/bash
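# Fine-tunes Llama-3 8B with LoRA (SFT) on AWS Trainium; see the tutorial above.
# Run it through neuron_parallel_compile first to precompile the graphs, then
# directly for the actual training:
#   neuron_parallel_compile bash docs/source/training_tutorials/sft_lora_finetune_llm.sh
#   bash docs/source/training_tutorials/sft_lora_finetune_llm.sh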
set -ex

export NEURON_FUSE_SOFTMAX=1
export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3
export MALLOC_ARENA_MAX=64 # limit the CPU allocation to avoid potential crashes
export NEURON_CC_FLAGS="--model-type=transformer --distribution-strategy=llm-training --enable-saturate-infinity --cache_dir=/home/ubuntu/cache_dir_neuron/"

PROCESSES_PER_NODE=8 # number of worker processes launched by torchrun on this node

TP_DEGREE=2 # tensor parallelism degree: shard the model across 2 Neuron cores
PP_DEGREE=1 # pipeline parallelism disabled
BS=1 # per-device train/eval batch size
GRADIENT_ACCUMULATION_STEPS=8
LOGGING_STEPS=1
MODEL_NAME="meta-llama/Meta-Llama-3-8B"
OUTPUT_DIR=dolly_llama_output # where checkpoints and the final model are written

if [ "$NEURON_EXTRACT_GRAPHS_ONLY" = "1" ]; then
MAX_STEPS=10
NUM_EPOCHS=1
else
MAX_STEPS=-1
NUM_EPOCHS=3
fi


XLA_USE_BF16=1 torchrun --nproc_per_node $PROCESSES_PER_NODE docs/source/training_tutorials/sft_lora_finetune_llm.py \
--model_id $MODEL_NAME \
--num_train_epochs $NUM_EPOCHS \
--do_train \
--learning_rate 5e-5 \
--warmup_ratio 0.03 \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BS \
--per_device_eval_batch_size $BS \
--gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
--gradient_checkpointing true \
--bf16 \
--zero_1 false \
--tensor_parallel_size $TP_DEGREE \
--pipeline_parallel_size $PP_DEGREE \
--logging_steps $LOGGING_STEPS \
--save_total_limit 1 \
--output_dir $OUTPUT_DIR \
--lr_scheduler_type "constant" \
--overwrite_output_dir
