
Commit 3add14d

dpalmasan authored and facebook-github-bot committed

Tuning LLM from PTE (#5233)

Summary: Pull Request resolved: #5233

* Add example of finetuning using executorch

Differential Revision: D61689035

1 parent b54206d

File tree

10 files changed: +671 -14 lines changed

examples/llm_pte_finetuning/TARGETS

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

oncall("papaya_oncall")

python_library(
    name = "model_loading_lib",
    srcs = [
        "model_loading_lib.py",
    ],
    deps = [
        "fbcode//caffe2:torch",
        "fbcode//executorch/exir:lib",
        "fbcode//executorch/extension/pybindings:aten_lib",  # @manual For PTE loader
        "fbcode//pytorch/torchtune:lib",
        "fbsource//third-party/pypi/blobfile:blobfile",  # @manual For tokenizer
        "fbsource//third-party/pypi/omegaconf:omegaconf",
        "fbsource//third-party/pypi/tiktoken:tiktoken",  # @manual For tokenizer
    ],
)

python_library(
    name = "training_lib",
    srcs = [
        "training_lib.py",
    ],
    deps = [
        "fbcode//caffe2:torch",
        "fbcode//executorch/extension/pybindings:aten_lib",  # @manual For PTE loader
        "fbcode//pytorch/torchtune:lib",
        "fbsource//third-party/pypi/blobfile:blobfile",  # @manual For tokenizer
        "fbsource//third-party/pypi/tiktoken:tiktoken",  # @manual For tokenizer
        "fbsource//third-party/pypi/tqdm:tqdm",
    ],
)

python_binary(
    name = "runner",
    srcs = [
        "runner.py",
    ],
    main_function = "executorch.examples.llm_pte_finetuning.runner.main",
    deps = [
        "fbcode//caffe2:torch",
        "fbcode//executorch/exir:lib",
        "fbcode//pytorch/torchtune:lib",
        "fbsource//third-party/pypi/blobfile:blobfile",  # @manual For tokenizer
        "fbsource//third-party/pypi/omegaconf:omegaconf",
        "fbsource//third-party/pypi/tiktoken:tiktoken",  # @manual For tokenizer
        ":training_lib",
        "//pytorch/vision:torchvision",
    ],
)

python_binary(
    name = "model_exporter",
    srcs = [
        "model_exporter.py",
    ],
    main_function = "executorch.examples.llm_pte_finetuning.model_exporter.main",
    deps = [
        "fbcode//caffe2:torch",
        "fbcode//executorch/exir:lib",
        "fbcode//pytorch/torchtune:lib",
        "fbsource//third-party/pypi/blobfile:blobfile",  # @manual For tokenizer
        "fbsource//third-party/pypi/omegaconf:omegaconf",
        "fbsource//third-party/pypi/tiktoken:tiktoken",  # @manual For tokenizer
        "fbsource//third-party/pypi/tqdm:tqdm",
        ":model_loading_lib",  # @manual for model loading
        ":training_lib",  # @manual for model exporting
    ],
)
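The model_exporter binary above declares main_function = "executorch.examples.llm_pte_finetuning.model_exporter.main", so the same entry point can also be driven directly from Python outside of Buck. A minimal sketch follows, assuming the package is importable on the current PYTHONPATH; the config and output paths are hypothetical placeholders, and the --cfg/--output_file flags come from the argparse definition in the exporter shown next.

import sys

from executorch.examples.llm_pte_finetuning.model_exporter import main

# Hypothetical paths; point --cfg at one of the YAML configs in this commit.
sys.argv = [
    "model_exporter",
    "--cfg", "phi3_config.yaml",
    "--output_file", "phi3_mini_lora.pte",
]
main()  # parses sys.argv and writes the training-ready .pte file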
Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import torch
from executorch.examples.llm_pte_finetuning.model_loading_lib import (
    export_model_lora_training,
    load_checkpoint,
    setup_model,
)

from executorch.examples.llm_pte_finetuning.training_lib import (
    get_dataloader,
    TrainingModule,
)

from omegaconf import OmegaConf
from torch.nn import functional as F
from torchtune import config

from torchtune.training import MODEL_KEY

parser = argparse.ArgumentParser(
    prog="ModelExporter",
    description="Export a LoRA model to ExecuTorch.",
    epilog="Model exported to be used for fine-tuning.",
)

parser.add_argument("--cfg", type=str, help="Path to the config file.")
parser.add_argument("--output_file", type=str, help="Path to the output ET model.")


def main() -> None:
    args = parser.parse_args()
    config_file = args.cfg
    output_file = args.output_file
    cfg = OmegaConf.load(config_file)
    tokenizer = config.instantiate(
        cfg.tokenizer,
    )

    loss_fn = config.instantiate(cfg.loss)

    ds = config.instantiate(cfg.dataset, tokenizer)
    train_set, val_set = torch.utils.data.random_split(ds, [0.8, 0.2])
    train_dataloader = get_dataloader(cfg, train_set, tokenizer, loss_fn)

    max_seq_len = cfg.tokenizer.max_seq_len

    # Example inputs, needed for ET export.
    batch = next(iter(train_dataloader))
    tokens, labels = batch["tokens"], batch["labels"]
    token_size = tokens.shape[1]
    labels_size = labels.shape[1]

    if token_size > max_seq_len:
        tokens = tokens[:, :max_seq_len]
    else:
        tokens = F.pad(tokens, (0, max_seq_len - token_size), value=0)

    if labels_size > max_seq_len:
        labels = labels[:, :max_seq_len]
    else:
        labels = F.pad(labels, (0, max_seq_len - labels_size), value=0)

    # Load pre-trained checkpoint.
    checkpoint_dict = load_checkpoint(cfg=cfg)
    model = setup_model(
        # pyre-ignore
        cfg=cfg,
        base_model_state_dict=checkpoint_dict[MODEL_KEY],
    )

    training_module = TrainingModule(model, loss_fn)

    # Export the model to ExecuTorch for training.
    export_model_lora_training(training_module, (tokens, labels), output_file)


if __name__ == "__main__":
    main()
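Since the example inputs above are captured with fixed shapes at export time, every batch later fed to the exported program has to match max_seq_len exactly. The truncate-or-pad logic in main() can be read as the small helper below; fit_to_seq_len is a hypothetical name introduced here purely for illustration.

import torch
from torch.nn import functional as F


def fit_to_seq_len(t: torch.Tensor, max_seq_len: int, pad_value: int = 0) -> torch.Tensor:
    """Truncate or right-pad a [batch, seq] tensor to exactly max_seq_len columns."""
    seq_len = t.shape[1]
    if seq_len > max_seq_len:
        return t[:, :max_seq_len]
    return F.pad(t, (0, max_seq_len - seq_len), value=pad_value)


# Example: a 300-token batch padded out to a 512-token export shape.
tokens = fit_to_seq_len(torch.randint(0, 100, (1, 300)), max_seq_len=512)
assert tokens.shape == (1, 512)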
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict

import torch
from executorch.exir import to_edge

from omegaconf import DictConfig
from torch.export import export, ExportedProgram
from torch.export.experimental import _export_forward_backward
from torch.nn.attention import sdpa_kernel, SDPBackend
from torchtune import config
from torchtune.modules.peft import get_adapter_params, set_trainable_params
from torchtune.training.precision import get_dtype, set_default_dtype
from torchtune.utils._device import get_device


def load_checkpoint(cfg) -> Dict[str, Any]:
    """
    Extract the checkpoint state from file and validate it. This includes the
    base model weights. If resume_from_checkpoint is True, this also includes
    the adapter weights and recipe state.
    """
    checkpointer = config.instantiate(
        cfg.checkpointer,
        resume_from_checkpoint=cfg.resume_from_checkpoint,
    )
    checkpoint_dict = checkpointer.load_checkpoint()
    return checkpoint_dict


def setup_model(
    cfg: DictConfig,
    base_model_state_dict: Dict[str, Any],
) -> torch.nn.Module:
    device = get_device(device=cfg.device)
    dtype = get_dtype(cfg.dtype, device=device)
    with set_default_dtype(dtype), device:
        model = config.instantiate(cfg.model)

    adapter_params = get_adapter_params(model)
    set_trainable_params(model, adapter_params)
    model.load_state_dict(base_model_state_dict, strict=False)
    return model


def export_model_lora_training(model, example_args, output_file) -> Any:
    """
    Export a model with LoRA adapters to ExecuTorch, for training only.
    """

    # 0. Mark the LoRA layers as trainable (requires_grad = True) in order
    # to just export the backwards pass for these layers later in the
    # export process.
    set_trainable_params(model, get_adapter_params(model))

    print("Exporting model with LoRA for training")
    # 1. torch.export: Defines the program with the ATen operator set.
    with sdpa_kernel([SDPBackend.MATH]):
        exported_graph: ExportedProgram = export(model, example_args, strict=False)
        print("Creating a joint forward-backwards graph for training")
        joint_graph = _export_forward_backward(exported_graph)

    # 2. to_edge: Make optimizations for Edge devices.
    print("Lowering to edge dialect")
    edge_program = to_edge(joint_graph)

    print(edge_program._edge_programs["forward"].graph_module)

    # 3. to_executorch: Convert the graph to an ExecuTorch program.
    print("Exporting to executorch")
    executorch_program = edge_program.to_executorch()
    print(executorch_program.exported_program().graph_signature)
    print(f"Saving to {output_file}")
    with open(output_file, "wb") as file:
        file.write(executorch_program.buffer)
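The export chain used above (torch.export, then _export_forward_backward, then to_edge and to_executorch) is independent of the LLM itself. Below is a minimal, self-contained sketch of the same pipeline on a toy module; TinyModule and the output file name are hypothetical, and the module returns its loss so the joint graph can carry the backward pass for the trainable parameters.

import torch
from executorch.exir import to_edge
from torch.export import export
from torch.export.experimental import _export_forward_backward


class TinyModule(torch.nn.Module):
    """Toy stand-in for TrainingModule: forward returns the scalar loss."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, x: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        return self.loss(self.linear(x), labels)


example_args = (torch.randn(2, 8), torch.randint(0, 4, (2,)))
exported = export(TinyModule(), example_args, strict=False)
joint = _export_forward_backward(exported)   # forward + backward in one graph
et_program = to_edge(joint).to_executorch()  # lower to an ExecuTorch program
with open("tiny_train.pte", "wb") as f:      # hypothetical output path
    f.write(et_program.buffer)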
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
tokenizer:
  _component_: torchtune.models.phi3.phi3_mini_tokenizer
  path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model
  max_seq_len: 1024

dataset:
  _component_: torchtune.datasets.instruct_dataset
  template: papaya.toolkit.experimental.llm_pte_finetuning.utils.DatabricksDolly
  source: iamtarun/python_code_instructions_18k_alpaca
  split: train
  column_map:
    instruction: instruction
    prompt: prompt
    input: input
    output: output
seed: null
shuffle: True
batch_size: 1

loss:
  _component_: torch.nn.CrossEntropyLoss

model:
  _component_: torchtune.models.phi3.lora_phi3_mini
  lora_attn_modules: ['q_proj', 'v_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Phi-3-mini-4k-instruct
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Phi-3-mini-4k-instruct/
  model_type: PHI3_MINI

resume_from_checkpoint: False
save_adapter_weights_only: False

device: cpu
dtype: fp32

enable_activation_checkpointing: True
compile: False
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
tokenizer:
  _component_: torchtune.models.phi3.phi3_mini_tokenizer
  path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model
  max_seq_len: 512

dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
seed: null
shuffle: True
batch_size: 1

loss:
  _component_: torch.nn.CrossEntropyLoss

model:
  _component_: torchtune.models.phi3.lora_phi3_mini
  lora_attn_modules: ['q_proj', 'v_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Phi-3-mini-4k-instruct
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Phi-3-mini-4k-instruct/
  model_type: PHI3_MINI
resume_from_checkpoint: False
save_adapter_weights_only: False

device: cpu
dtype: fp32

enable_activation_checkpointing: True
compile: False
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
tokenizer:
  _component_: torchtune.models.qwen2.qwen2_tokenizer
  path: /tmp/Qwen2-0.5B-Instruct/vocab.json
  merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt
  max_seq_len: 512

dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
seed: null
shuffle: True
batch_size: 1

loss:
  _component_: torch.nn.CrossEntropyLoss

model:
  _component_: torchtune.models.qwen2.lora_qwen2_0_5b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
  apply_lora_to_mlp: False
  lora_rank: 32
  lora_alpha: 64

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-0.5B-Instruct
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Qwen2-0.5B-Instruct
  model_type: QWEN2
resume_from_checkpoint: False
save_adapter_weights_only: False

device: cpu
dtype: fp32

enable_activation_checkpointing: True
compile: False
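All three YAML files follow the torchtune recipe-config convention: each _component_ key names a callable that is resolved at runtime. The sketch below mirrors how model_exporter.py consumes such a file via OmegaConf and torchtune's config.instantiate; the config file name is a hypothetical placeholder.

from omegaconf import OmegaConf
from torchtune import config

cfg = OmegaConf.load("qwen_05b_config.yaml")     # hypothetical file name
tokenizer = config.instantiate(cfg.tokenizer)    # qwen2_tokenizer(...)
ds = config.instantiate(cfg.dataset, tokenizer)  # alpaca_cleaned_dataset(...)
loss_fn = config.instantiate(cfg.loss)           # torch.nn.CrossEntropyLoss()
print(cfg.tokenizer.max_seq_len, type(loss_fn).__name__)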
