Summary: Pull Request resolved: #5233

* Add example of finetuning using executorch

Reviewed By: JacobSzwejbka, dvorjackz

Differential Revision: D61689035
1 parent d80f78f · commit 69fa5f3
Showing 10 changed files with 713 additions and 15 deletions.
New file (+70 lines): Buck TARGETS file defining the model_loading_lib and training_lib libraries and the runner and model_exporter binaries.
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary") | ||
load("@fbcode_macros//build_defs:python_library.bzl", "python_library") | ||
|
||
oncall("papaya_oncall") | ||
|
||
python_library( | ||
name = "model_loading_lib", | ||
srcs = [ | ||
"model_loading_lib.py", | ||
], | ||
deps = [ | ||
"fbcode//caffe2:torch", | ||
"fbcode//executorch/examples/llm_pte_finetuning:training_lib", | ||
"fbcode//executorch/exir:lib", | ||
"fbcode//executorch/extension/pybindings:aten_lib", # @manual For PTE loader | ||
"fbcode//pytorch/torchtune:lib", | ||
"fbsource//third-party/pypi/blobfile:blobfile", # @manual For tokenizer | ||
"fbsource//third-party/pypi/omegaconf:omegaconf", | ||
"fbsource//third-party/pypi/tiktoken:tiktoken", # @manual For tokenizer | ||
], | ||
) | ||
|
||
python_library( | ||
name = "training_lib", | ||
srcs = [ | ||
"training_lib.py", | ||
], | ||
deps = [ | ||
"fbcode//caffe2:torch", | ||
"fbcode//executorch/extension/pybindings:aten_lib", # @manual For PTE loader | ||
"fbcode//pytorch/torchtune:lib", | ||
"fbsource//third-party/pypi/blobfile:blobfile", # @manual For tokenizer | ||
"fbsource//third-party/pypi/tiktoken:tiktoken", # @manual For tokenizer | ||
"fbsource//third-party/pypi/tqdm:tqdm", | ||
], | ||
) | ||
|
||
python_binary( | ||
name = "runner", | ||
srcs = [ | ||
"runner.py", | ||
], | ||
main_function = "executorch.examples.llm_pte_finetuning.runner.main", | ||
deps = [ | ||
"fbcode//caffe2:torch", | ||
"fbcode//executorch/examples/llm_pte_finetuning:training_lib", | ||
"fbcode//pytorch/torchtune:lib", | ||
"fbsource//third-party/pypi/blobfile:blobfile", # @manual For tokenizer | ||
"fbsource//third-party/pypi/omegaconf:omegaconf", | ||
"fbsource//third-party/pypi/tiktoken:tiktoken", # @manual For tokenizer | ||
"fbsource//third-party/pypi/tqdm:tqdm", | ||
], | ||
) | ||
|
||
python_binary( | ||
name = "model_exporter", | ||
srcs = [ | ||
"model_exporter.py", | ||
], | ||
main_function = "executorch.examples.llm_pte_finetuning.model_exporter.main", | ||
deps = [ | ||
"fbcode//caffe2:torch", | ||
"fbcode//executorch/examples/llm_pte_finetuning:model_loading_lib", # @manual for model loading | ||
"fbcode//executorch/examples/llm_pte_finetuning:training_lib", # @manual for model exporting | ||
"fbcode//pytorch/torchtune:lib", | ||
"fbsource//third-party/pypi/blobfile:blobfile", # @manual For tokenizer | ||
"fbsource//third-party/pypi/omegaconf:omegaconf", | ||
"fbsource//third-party/pypi/tiktoken:tiktoken", # @manual For tokenizer | ||
], | ||
) |
New file (+87 lines): model_exporter.py — builds example inputs from a torchtune config and exports the LoRA model to a .pte training program.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import argparse

import torch
from executorch.examples.llm_pte_finetuning.model_loading_lib import (
    export_model_lora_training,
    load_checkpoint,
    setup_model,
)

from executorch.examples.llm_pte_finetuning.training_lib import (
    get_dataloader,
    TrainingModule,
)

from omegaconf import OmegaConf
from torch.nn import functional as F
from torchtune import config

from torchtune.training import MODEL_KEY

parser = argparse.ArgumentParser(
    prog="ModelExporter",
    description="Export a LoRA model to ExecuTorch.",
    epilog="Model exported to be used for fine-tuning.",
)

parser.add_argument("--cfg", type=str, help="Path to the config file.")
parser.add_argument("--output_file", type=str, help="Path to the output ET model.")


def main() -> None:
    args = parser.parse_args()
    config_file = args.cfg
    output_file = args.output_file
    cfg = OmegaConf.load(config_file)
    tokenizer = config.instantiate(
        cfg.tokenizer,
    )

    loss_fn = config.instantiate(cfg.loss)

    ds = config.instantiate(cfg.dataset, tokenizer)
    train_set, val_set = torch.utils.data.random_split(ds, [0.8, 0.2])
    train_dataloader = get_dataloader(cfg, train_set, tokenizer, loss_fn)

    max_seq_len = cfg.tokenizer.max_seq_len

    # Example inputs, needed for ET export.
    batch = next(iter(train_dataloader))
    tokens, labels = batch["tokens"], batch["labels"]
    token_size = tokens.shape[1]
    labels_size = labels.shape[1]

    if token_size > max_seq_len:
        tokens = tokens[:, :max_seq_len]
    else:
        tokens = F.pad(tokens, (0, max_seq_len - token_size), value=0)

    if labels_size > max_seq_len:
        labels = labels[:, :max_seq_len]
    else:
        labels = F.pad(labels, (0, max_seq_len - labels_size), value=0)

    # Load pre-trained checkpoint.
    checkpoint_dict = load_checkpoint(cfg=cfg)
    model = setup_model(
        # pyre-ignore
        cfg=cfg,
        base_model_state_dict=checkpoint_dict[MODEL_KEY],
    )

    training_module = TrainingModule(model, loss_fn)

    # Export the model to ExecuTorch for training.
    export_model_lora_training(training_module, (tokens, labels), output_file)


if __name__ == "__main__":
    main()
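The exporter is driven entirely by the two flags defined above: given one of the torchtune YAML configs added in this commit, an invocation along the lines of `python -m executorch.examples.llm_pte_finetuning.model_exporter --cfg <config.yaml> --output_file <model.pte>` (file names are placeholders; the module path comes from `main_function` in the TARGETS file) writes the .pte training program that the `runner` target later consumes.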
New file (+86 lines): model_loading_lib.py — checkpoint loading, LoRA model setup, and export of the joint forward/backward training graph to ExecuTorch.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, Dict, Tuple

import torch
from executorch.examples.llm_pte_finetuning.training_lib import TrainingModule
from executorch.exir import to_edge

from omegaconf import DictConfig
from torch.export import export, ExportedProgram
from torch.export.experimental import _export_forward_backward
from torch.nn.attention import sdpa_kernel, SDPBackend
from torchtune import config
from torchtune.modules.peft import get_adapter_params, set_trainable_params
from torchtune.training.precision import get_dtype, set_default_dtype
from torchtune.utils._device import get_device


def load_checkpoint(cfg: Any) -> Dict[str, Any]:  # pyre-ignore[2]
    """
    Extract the checkpoint state from file and validate. This includes the
    base model weights. If resume_from_checkpoint is True, this also includes
    the adapter weights and recipe state.
    """
    checkpointer = config.instantiate(
        cfg.checkpointer,
        resume_from_checkpoint=cfg.resume_from_checkpoint,
    )
    checkpoint_dict = checkpointer.load_checkpoint()
    return checkpoint_dict


def setup_model(
    cfg: DictConfig,
    base_model_state_dict: Dict[str, Any],
) -> torch.nn.Module:
    device = get_device(device=cfg.device)
    dtype = get_dtype(cfg.dtype, device=device)
    with set_default_dtype(dtype), device:
        model = config.instantiate(cfg.model)

    adapter_params = get_adapter_params(model)
    set_trainable_params(model, adapter_params)
    model.load_state_dict(base_model_state_dict, strict=False)
    return model


def export_model_lora_training(
    model: TrainingModule, example_args: Tuple[Any, ...], output_file: str  # pyre-ignore[2]
) -> None:
    """
    Export a model with LoRA adapters to ExecuTorch for training only.
    """

    # 0. Mark the LoRA layers as trainable (requires_grad = True) in order
    # to just export the backwards pass for these layers later in the
    # export process.
    set_trainable_params(model, get_adapter_params(model))

    print("Exporting model with LoRA for training")
    # 1. torch.export: Defines the program with the ATen operator set.

    with sdpa_kernel([SDPBackend.MATH]):
        exported_graph: ExportedProgram = export(model, example_args, strict=False)
        print("Creating a joint forward-backwards graph for training")
        joint_graph = _export_forward_backward(exported_graph)

    # 2. to_edge: Make optimizations for Edge devices.
    print("Lowering to edge dialect")
    edge_program = to_edge(joint_graph)

    print(edge_program._edge_programs["forward"].graph_module)

    # 3. to_executorch: Convert the graph to an ExecuTorch program.
    print("Exporting to executorch")
    executorch_program = edge_program.to_executorch()
    print(executorch_program.exported_program().graph_signature)
    print(f"Saving to {output_file}")
    with open(output_file, "wb") as file:
        file.write(executorch_program.buffer)
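Because `_export_forward_backward` bakes the backward pass into the exported program, a single forward call on the saved .pte computes both the loss and the gradients for the trainable LoRA parameters. The consuming side lives in runner.py, which is not rendered in this capture; the snippet below is only a rough sketch of loading the file with the pybindings that TARGETS marks "For PTE loader". The file name and the output ordering are assumptions, not taken from the commit.

# Rough sketch, not the actual runner.py: load the exported program with the
# ExecuTorch pybindings and run one joint forward/backward step.
import torch
from executorch.extension.pybindings.aten_lib import _load_for_executorch

et_mod = _load_for_executorch("phi3_mini_lora.pte")  # hypothetical output_file from model_exporter

# Inputs must match the fixed shape baked in at export time (batch of 1, max_seq_len tokens).
tokens = torch.zeros(1, 1024, dtype=torch.long)
labels = torch.zeros(1, 1024, dtype=torch.long)

outputs = et_mod.forward((tokens, labels))
loss = outputs[0]         # assumption: loss is the first output of the joint graph
lora_grads = outputs[1:]  # assumption: remaining outputs carry gradients (the real runner reads graph_signature)
print(loss)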
New file (+49 lines): torchtune YAML config — Phi-3 Mini LoRA fine-tuning on the python_code_instructions_18k_alpaca instruct dataset.
tokenizer:
  _component_: torchtune.models.phi3.phi3_mini_tokenizer
  path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model
  max_seq_len: 1024

dataset:
  _component_: torchtune.datasets.instruct_dataset
  template: papaya.toolkit.experimental.llm_pte_finetuning.utils.DatabricksDolly
  source: iamtarun/python_code_instructions_18k_alpaca
  split: train
  column_map:
    instruction: instruction
    prompt: prompt
    input: input
    output: output
seed: null
shuffle: True
batch_size: 1

loss:
  _component_: torch.nn.CrossEntropyLoss

model:
  _component_: torchtune.models.phi3.lora_phi3_mini
  lora_attn_modules: ['q_proj', 'v_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Phi-3-mini-4k-instruct
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Phi-3-mini-4k-instruct/
  model_type: PHI3_MINI

resume_from_checkpoint: False
save_adapter_weights_only: False

device: cpu
dtype: fp32

enable_activation_checkpointing: True
compile: False
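This config is consumed the same way model_exporter.py does it, via OmegaConf plus torchtune's `config.instantiate`. A minimal sketch follows; the YAML file name is hypothetical, and the `template:` path points at an internal (`papaya...`) module, so running this outside that environment would require substituting your own instruct template.

# Minimal sketch of instantiating the pieces of this config (file name hypothetical).
from omegaconf import OmegaConf
from torchtune import config

cfg = OmegaConf.load("phi3_config.yaml")
tokenizer = config.instantiate(cfg.tokenizer)         # phi3_mini_tokenizer, max_seq_len=1024
loss_fn = config.instantiate(cfg.loss)                # torch.nn.CrossEntropyLoss
dataset = config.instantiate(cfg.dataset, tokenizer)  # instruct_dataset over python_code_instructions_18k_alpaca
print(len(dataset), cfg.tokenizer.max_seq_len)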
New file (+40 lines): torchtune YAML config — Phi-3 Mini LoRA fine-tuning on the alpaca_cleaned dataset.
tokenizer:
  _component_: torchtune.models.phi3.phi3_mini_tokenizer
  path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model
  max_seq_len: 512

dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
seed: null
shuffle: True
batch_size: 1

loss:
  _component_: torch.nn.CrossEntropyLoss

model:
  _component_: torchtune.models.phi3.lora_phi3_mini
  lora_attn_modules: ['q_proj', 'v_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Phi-3-mini-4k-instruct
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Phi-3-mini-4k-instruct/
  model_type: PHI3_MINI
resume_from_checkpoint: False
save_adapter_weights_only: False

device: cpu
dtype: fp32

enable_activation_checkpointing: True
compile: False
New file (+39 lines): torchtune YAML config — Qwen2 0.5B LoRA fine-tuning on the alpaca_cleaned dataset.
tokenizer:
  _component_: torchtune.models.qwen2.qwen2_tokenizer
  path: /tmp/Qwen2-0.5B-Instruct/vocab.json
  merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt
  max_seq_len: 512

dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
seed: null
shuffle: True
batch_size: 1

loss:
  _component_: torch.nn.CrossEntropyLoss

model:
  _component_: torchtune.models.qwen2.lora_qwen2_0_5b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
  apply_lora_to_mlp: False
  lora_rank: 32
  lora_alpha: 64

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-0.5B-Instruct
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Qwen2-0.5B-Instruct
  model_type: QWEN2
resume_from_checkpoint: False
save_adapter_weights_only: False

device: cpu
dtype: fp32

enable_activation_checkpointing: True
compile: False