Commit aafa130 (2 parents: 719b63d + 25a26e2)

Update

[ghstack-poisoned]

File tree

20 files changed: +169 −156 lines

.gitmodules

Lines changed: 3 additions & 0 deletions
@@ -70,3 +70,6 @@
 [submodule "third-party/pocketfft"]
 	path = third-party/pocketfft
 	url = https://github.com/mreineck/pocketfft
+[submodule "shim"]
+	path = shim
+	url = https://github.com/facebook/buck2-shims-meta
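Note (not part of the diff): after checking out this commit, the new submodule presumably needs to be initialized, e.g. with `git submodule update --init shim`, before builds that rely on the Buck2 shims will resolve it.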

build/Test.cmake

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ function(et_cxx_test target_name)
   cmake_parse_arguments(ET_CXX_TEST "" "" "${multi_arg_names}" ${ARGN})
 
   add_executable(${target_name} ${ET_CXX_TEST_SOURCES} ${EXECUTORCH_ROOT}/runtime/core/exec_aten/testing_util/tensor_util.cpp)
+  find_package(GTest)
   # Includes gtest, gmock, executorch by default
   target_link_libraries(
     ${target_name} GTest::gtest GTest::gtest_main GTest::gmock executorch
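Context for the one-line change (not stated in the diff): `find_package(GTest)` has to run before `target_link_libraries` references the `GTest::gtest`, `GTest::gtest_main`, and `GTest::gmock` imported targets, since CMake requires `::`-style targets to exist at the point they are used; calling it inside `et_cxx_test` makes the helper self-contained for builds that have not already located GTest.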

build/build_android_llm_demo.sh

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,7 @@ build_android_native_library() {
     -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \
     -DANDROID_ABI="${ANDROID_ABI}" \
     -DANDROID_PLATFORM=android-26 \
+    -DBUILD_TESTING=OFF \
     -DEXECUTORCH_ENABLE_LOGGING=ON \
     -DEXECUTORCH_LOG_LEVEL=Info \
     -DEXECUTORCH_BUILD_XNNPACK=ON \
@@ -73,6 +74,7 @@ build_android_native_library() {
     -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake \
     -DANDROID_ABI="${ANDROID_ABI}" \
     -DANDROID_PLATFORM=android-26 \
+    -DBUILD_TESTING=OFF \
     -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \
     -DEXECUTORCH_ENABLE_LOGGING=ON \
     -DEXECUTORCH_LOG_LEVEL=Info \

examples/llm_pte_finetuning/README.md

Lines changed: 48 additions & 20 deletions
@@ -6,21 +6,43 @@ In this tutorial, we show how to fine-tune an LLM using executorch.
 
 You will need to have a model's checkpoint, in the Hugging Face format. For example:
 
-```
-git clone https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
+```console
+git clone https://huggingface.co/Qwen/Qwen2-0.5B-Instruct
 ```
 
 You will need to install [torchtune](https://github.com/pytorch/torchtune) following [its installation instructions](https://github.com/pytorch/torchtune?tab=readme-ov-file#installation).
 
+You might run into an issue with the `triton` package when installing `torchtune`. You can build `triton` locally following the [instructions in their repo](https://github.com/triton-lang/triton?tab=readme-ov-file#install-from-source).
+
 ## Config Files
 
+The directory structure of `llm_pte_finetuning` is:
+
+```console
+examples/llm_pte_finetuning
+├── README.md
+├── TARGETS
+├── __init__.py
+├── model_exporter.py
+├── model_loading_lib.py
+├── phi3_alpaca_code_config.yaml
+├── phi3_config.yaml
+├── qwen_05b_config.yaml
+├── runner.py
+└── training_lib.py
+```
+
+We already provide configs out of the box. The following sections explain how you can set up the config for your own model or dataset.
+
 As mentioned in the previous section, we internally use `torchtune` APIs, and thus, we use config files that follow `torchtune`'s structure. Typically, in the following sections we go through a working example which can be found in the `phi3_config.yaml` config file.
 
 ### Tokenizer
 
 We need to define the tokenizer. Let's suppose we would like to use [PHI3 Mini Instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) model from Microsoft. We need to define the tokenizer component:
 
-```
+```yaml
 tokenizer:
   _component_: torchtune.models.phi3.phi3_mini_tokenizer
   path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model
@@ -33,7 +55,7 @@ This will load the tokenizer, and set the max sequence length to 1024. The class
 
 In this example we use the [Alpaca-Cleaned dataset](https://huggingface.co/datasets/yahma/alpaca-cleaned). We need to define the following parameters:
 
-```
+```yaml
 dataset:
   _component_: torchtune.datasets.alpaca_cleaned_dataset
   seed: null
@@ -47,7 +69,7 @@ Torchtune supports datasets using huggingface dataloaders, so custom datasets co
 
 For the loss function, we use PyTorch losses. In this example we use the `CrossEntropyLoss`:
 
-```
+```yaml
 loss:
   _component_: torch.nn.CrossEntropyLoss
 ```
@@ -56,7 +78,7 @@ loss:
 
 Model parameters can be set, in this example we replicate the configuration for phi3 mini instruct benchmarks:
 
-```
+```yaml
 model:
   _component_: torchtune.models.phi3.lora_phi3_mini
   lora_attn_modules: ['q_proj', 'v_proj']
@@ -70,7 +92,7 @@ model:
 
 Depending on how your model is defined, you will need to instantiate different components. In these examples we use checkpoints from HF (hugging face format), and thus we will need to instantiate a `FullModelHFCheckpointer` object. We need to pass the checkpoint directory, the files with the tensors, the output directory for training and the model type:
 
-```
+```yaml
 checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
   checkpoint_dir: /tmp/Phi-3-mini-4k-instruct
@@ -87,7 +109,7 @@ checkpointer:
 
 Torchtune supports `cuda` and `bf16` tensors. However, for ExecuTorch training we only support `cpu` and `fp32`:
 
-```
+```yaml
 device: cpu
 dtype: fp32
 ```
@@ -101,28 +123,34 @@ The `model_exporter.py` exports the LLM checkpoint into an ExecuTorch checkpoint
 * `cfg`: Configuration file
 * `output_file`: The `.pte` output path
 
-```
-python model_exporter.py --cfg=phi3_config.yaml --output_file=phi3_mini_lora.pte
+```console
+python model_exporter.py \
+    --cfg=qwen_05b_config.yaml \
+    --output_file=qwen2_0_5B.pte
 ```
 
 ### Step 2: Run the fine-tuning job
 
 To run the fine-tuning job:
 
-```
-python runner.py --cfg=phi3_config.yaml --model_file=phi3_mini_lora.pte
+```console
+python runner.py \
+    --cfg=qwen_05b_config.yaml \
+    --model_file=qwen2_0_5B.pte \
+    --num_training_steps=10 \
+    --num_eval_steps=5
 ```
 
 You need to use **the same** config file from the previous step. The `model_file` arg is the `.pte` model from the previous step.
 
 Example output:
 
-```
-Evaluating the model before training...
-100%|██████████| 3/3 [31:23<00:00, 627.98s/it]
-Eval loss: tensor(2.3778)
-100%|██████████| 5/5 [52:29<00:00, 629.84s/it]
-Losses: [2.7152762413024902, 0.7890686988830566, 2.249271869659424, 1.4777560234069824, 0.8378427624702454]
-100%|██████████| 3/3 [30:35<00:00, 611.90s/it]
-Eval loss: tensor(0.8464)
+```console
+Evaluating the model before training
+100%|██████████| 5/5 [00:47<00:00, 9.45s/it]
+Eval loss: tensor(0.9441)
+100%|██████████| 10/10 [01:30<00:00, 9.09s/it]
+Losses: [0.5646533966064453, 1.3464953899383545, 1.297974705696106, 1.2249481678009033, 0.6750457286834717, 0.7721152901649475, 1.0774847269058228, 0.7962403893470764, 0.8448256850242615, 0.8731598854064941]
+100%|██████████| 5/5 [00:45<00:00, 9.18s/it]
+Eval loss: tensor(0.7679)
 ```

examples/llm_pte_finetuning/TARGETS

Lines changed: 3 additions & 2 deletions
@@ -12,7 +12,7 @@ python_library(
         "fbcode//caffe2:torch",
         "fbcode//executorch/examples/llm_pte_finetuning:training_lib",
        "fbcode//executorch/exir:lib",
-        "fbcode//executorch/extension/pybindings:aten_lib",  # @manual For PTE loader
+        "fbcode//executorch/extension/pybindings:portable_lib",  # @manual For PTE loader
         "fbcode//pytorch/torchtune:lib",
         "fbsource//third-party/pypi/blobfile:blobfile",  # @manual For tokenizer
         "fbsource//third-party/pypi/omegaconf:omegaconf",
@@ -27,11 +27,12 @@ python_library(
     ],
     deps = [
         "fbcode//caffe2:torch",
-        "fbcode//executorch/extension/pybindings:aten_lib",  # @manual For PTE loader
+        "fbcode//executorch/extension/pybindings:portable_lib",  # @manual For PTE loader
         "fbcode//pytorch/torchtune:lib",
         "fbsource//third-party/pypi/blobfile:blobfile",  # @manual For tokenizer
         "fbsource//third-party/pypi/tiktoken:tiktoken",  # @manual For tokenizer
         "fbsource//third-party/pypi/tqdm:tqdm",
+        "fbcode//executorch/backends/xnnpack/partition:xnnpack_partitioner",
     ],
 )

examples/llm_pte_finetuning/__init__.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# Copyright 2024 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .model_loading_lib import export_model_lora_training, load_checkpoint, setup_model
+from .training_lib import eval_model, get_dataloader, TrainingModule, update_function
+
+__all__ = [
+    "eval_model",
+    "get_dataloader",
+    "update_function",
+    "TrainingModule",
+    "export_model_lora_training",
+    "load_checkpoint",
+    "setup_model",
+]
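For context, a minimal usage sketch (an illustration, not part of the commit): the new `__init__.py` re-exports the helpers listed in `__all__`, so callers can import them directly from the package rather than from the individual modules.

```python
# Hypothetical caller; the imported names are exactly the ones
# re-exported by the new examples/llm_pte_finetuning/__init__.py.
from executorch.examples.llm_pte_finetuning import (
    TrainingModule,
    export_model_lora_training,
    get_dataloader,
)
```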

examples/llm_pte_finetuning/model_loading_lib.py

Lines changed: 57 additions & 2 deletions
@@ -9,8 +9,9 @@
 from typing import Any, Dict, Tuple
 
 import torch
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
 from executorch.examples.llm_pte_finetuning.training_lib import TrainingModule
-from executorch.exir import to_edge
+from executorch.exir import EdgeCompileConfig, to_edge
 
 from omegaconf import DictConfig
 from torch.export import export, ExportedProgram
@@ -72,16 +73,70 @@ def export_model_lora_training(
     exported_graph: ExportedProgram = export(model, example_args, strict=False)
     print("Creating a joint forward-backwards graph for training")
     joint_graph = _export_forward_backward(exported_graph)
+    ep = joint_graph
+
+    # Currently there is no implementation of empty_permuted for edge dialect.
+    # We manually make a pass to rewrite the empty_permuted to empty and permute.
+    for node in ep.graph.nodes:
+        if (
+            node.op == "call_function"
+            and node.target == torch.ops.aten.empty_permuted.out
+        ):
+            print("found empty_permute: ", node)
+            empty_permuted_node = node
+            with ep.graph.inserting_before(empty_permuted_node):
+                empty_node = ep.graph.create_node(
+                    "call_function",
+                    torch.ops.aten.empty.memory_format,
+                    (node.args[0],),
+                    empty_permuted_node.kwargs,
+                )
+                permute_node = ep.graph.create_node(
+                    "call_function",
+                    torch.ops.aten.permute,
+                    (empty_node, node.args[1]),
+                )
+                for user in empty_permuted_node.users.copy():
+                    user.replace_input_with(empty_permuted_node, permute_node)
+        if (
+            node.op == "call_function"
+            and node.target == torch.ops.aten.empty_permuted.default
+        ):
+            print("found empty_permute default: ", node)
+            empty_permuted_node = node
+            with ep.graph.inserting_before(empty_permuted_node):
+                empty_node = ep.graph.create_node(
+                    "call_function",
+                    torch.ops.aten.empty.memory_format,
+                    (node.args[0],),
+                    empty_permuted_node.kwargs,
+                )
+                permute_node = ep.graph.create_node(
+                    "call_function",
+                    torch.ops.aten.permute.default,
+                    (empty_node, node.args[1]),
+                )
+                for user in empty_permuted_node.users.copy():
+                    user.replace_input_with(empty_permuted_node, permute_node)
 
     # 2. to_edge: Make optimizations for Edge devices.
     print("Lowering to edge dialect")
-    edge_program = to_edge(joint_graph)
+    edge_program = to_edge(
+        joint_graph,
+        compile_config=EdgeCompileConfig(
+            _core_aten_ops_exception_list=[torch.ops.aten.empty_permuted.default]
+        ),
+    )
 
     print(edge_program._edge_programs["forward"].graph_module)
 
     # 3. to_executorch: Convert the graph to an ExecuTorch program.
     print("Exporting to executorch")
+    edge_program = edge_program.to_backend(
+        XnnpackPartitioner(force_fp32_dynamic_linear=True)
+    )
     executorch_program = edge_program.to_executorch()
+
     print(executorch_program.exported_program().graph_signature)
     print(f"Saving to {output_file}")
     with open(output_file, "wb") as file:
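Condensed sketch of the new lowering tail (an illustration, not part of the commit): assuming `joint_graph` is the joint forward/backward `ExportedProgram` the function already builds, the patch threads it through `to_edge` with an exception list for `empty_permuted`, delegates supported subgraphs to XNNPACK with `force_fp32_dynamic_linear=True`, and then converts the result to an ExecuTorch program.

```python
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeCompileConfig, to_edge


def lower_joint_graph(joint_graph):
    # Let empty_permuted through core-ATen verification; the manual pass in
    # export_model_lora_training has already rewritten it into empty + permute.
    edge_program = to_edge(
        joint_graph,
        compile_config=EdgeCompileConfig(
            _core_aten_ops_exception_list=[torch.ops.aten.empty_permuted.default]
        ),
    )
    # Delegate supported subgraphs (including fp32 dynamic linear) to XNNPACK.
    edge_program = edge_program.to_backend(
        XnnpackPartitioner(force_fp32_dynamic_linear=True)
    )
    # Convert the partitioned edge program into an ExecuTorch program.
    return edge_program.to_executorch()
```
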

examples/llm_pte_finetuning/qwen_05b_config.yaml

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ checkpointer:
     model.safetensors
   ]
   recipe_checkpoint: null
-  output_dir: /tmp/Qwen2-0.5B-Instruct
+  output_dir: /tmp/qwen_0.5B_ft-output
   model_type: QWEN2
 resume_from_checkpoint: False
 save_adapter_weights_only: False

examples/llm_pte_finetuning/runner.py

Lines changed: 19 additions & 6 deletions
@@ -15,7 +15,7 @@
     update_function,
 )
 
-from executorch.extension.pybindings.aten_lib import (  # @manual
+from executorch.extension.pybindings.portable_lib import (  # @manual
     _load_for_executorch_from_buffer,
 )
 from omegaconf import OmegaConf
@@ -30,6 +30,18 @@
 )
 parser.add_argument("--cfg", type=str, help="Path to the config file.")
 parser.add_argument("--model_file", type=str, help="Path to the ET model file.")
+parser.add_argument(
+    "--num_training_steps",
+    type=int,
+    help="Number of training steps, assuming 1 epoch.",
+    default=100,
+)
+parser.add_argument(
+    "--num_eval_steps",
+    type=int,
+    help="Number of eval steps, assuming 1 epoch.",
+    default=5,
+)
 
 
 def main() -> None:
@@ -47,10 +59,11 @@ def main() -> None:
     train_set, val_set = torch.utils.data.random_split(ds, [0.8, 0.2])
     train_dataloader = get_dataloader(cfg, train_set, tokenizer, loss_fn)
     val_dataloader = get_dataloader(cfg, val_set, tokenizer, loss_fn)
+    num_training_steps = args.num_training_steps
+    num_eval_steps = args.num_eval_steps
 
     max_seq_len = cfg.tokenizer.max_seq_len
     # Num of steps to run training. Assume 1 epoch
-    num_steps = 100
     with open(file, "rb") as f:
         model_bytes = f.read()
         et_mod = _load_for_executorch_from_buffer(model_bytes)
@@ -62,7 +75,7 @@ def main() -> None:
         dataloader=val_dataloader,
         loss_fn=loss_fn,
         max_seq_len=max_seq_len,
-        num_eval_steps=10,
+        num_eval_steps=num_eval_steps,
     )
     print("Eval loss: ", eval_loss)
 
@@ -74,9 +87,9 @@ def main() -> None:
     learning_rate = 5e-3
     f.seek(0)
     losses = []
-    for i, batch in tqdm(enumerate(train_dataloader), total=num_steps):
+    for i, batch in tqdm(enumerate(train_dataloader), total=num_training_steps):
         # Run for a limited number of steps.
-        if i >= num_steps:
+        if i >= num_training_steps:
             break
         tokens, labels = batch["tokens"], batch["labels"]
         token_size = tokens.shape[1]
@@ -113,7 +126,7 @@ def main() -> None:
         dataloader=val_dataloader,
         loss_fn=loss_fn,
         max_seq_len=max_seq_len,
-        num_eval_steps=10,
+        num_eval_steps=num_eval_steps,
     )
     print("Eval loss: ", eval_loss)
examples/llm_pte_finetuning/training_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing import Any
1111

1212
import torch
13-
from executorch.extension.pybindings.aten_lib import ExecuTorchModule # @manual
13+
from executorch.extension.pybindings.portable_lib import ExecuTorchModule # @manual
1414

1515
from torch.nn import functional as F
1616
from torch.utils.data import DataLoader, Dataset, DistributedSampler
