Skip to content

Commit

Permalink
Tuning LLM from PTE (pytorch#5233)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#5233

* Add example of finetuning using executorch

Reviewed By: JacobSzwejbka, dvorjackz

Differential Revision: D61689035
  • Loading branch information
dpalmasan authored and facebook-github-bot committed Sep 11, 2024
1 parent 338ef26 commit afb6c11
Show file tree
Hide file tree
Showing 10 changed files with 718 additions and 15 deletions.
70 changes: 70 additions & 0 deletions examples/llm_pte_finetuning/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

oncall("papaya_oncall")

python_library(
name = "model_loading_lib",
srcs = [
"model_loading_lib.py",
],
deps = [
"fbcode//caffe2:torch",
"fbcode//executorch/examples/llm_pte_finetuning:training_lib",
"fbcode//executorch/exir:lib",
"fbcode//executorch/extension/pybindings:aten_lib", # @manual For PTE loader
"fbcode//pytorch/torchtune:lib",
"fbsource//third-party/pypi/blobfile:blobfile", # @manual For tokenizer
"fbsource//third-party/pypi/omegaconf:omegaconf",
"fbsource//third-party/pypi/tiktoken:tiktoken", # @manual For tokenizer
],
)

python_library(
name = "training_lib",
srcs = [
"training_lib.py",
],
deps = [
"fbcode//caffe2:torch",
"fbcode//executorch/extension/pybindings:aten_lib", # @manual For PTE loader
"fbcode//pytorch/torchtune:lib",
"fbsource//third-party/pypi/blobfile:blobfile", # @manual For tokenizer
"fbsource//third-party/pypi/tiktoken:tiktoken", # @manual For tokenizer
"fbsource//third-party/pypi/tqdm:tqdm",
],
)

python_binary(
name = "runner",
srcs = [
"runner.py",
],
main_function = "executorch.examples.llm_pte_finetuning.runner.main",
deps = [
"fbcode//caffe2:torch",
"fbcode//executorch/examples/llm_pte_finetuning:training_lib",
"fbcode//pytorch/torchtune:lib",
"fbsource//third-party/pypi/blobfile:blobfile", # @manual For tokenizer
"fbsource//third-party/pypi/omegaconf:omegaconf",
"fbsource//third-party/pypi/tiktoken:tiktoken", # @manual For tokenizer
"fbsource//third-party/pypi/tqdm:tqdm",
],
)

python_binary(
name = "model_exporter",
srcs = [
"model_exporter.py",
],
main_function = "executorch.examples.llm_pte_finetuning.model_exporter.main",
deps = [
"fbcode//caffe2:torch",
"fbcode//executorch/examples/llm_pte_finetuning:model_loading_lib", # @manual for model loading
"fbcode//executorch/examples/llm_pte_finetuning:training_lib", # @manual for model exporting
"fbcode//pytorch/torchtune:lib",
"fbsource//third-party/pypi/blobfile:blobfile", # @manual For tokenizer
"fbsource//third-party/pypi/omegaconf:omegaconf",
"fbsource//third-party/pypi/tiktoken:tiktoken", # @manual For tokenizer
],
)
87 changes: 87 additions & 0 deletions examples/llm_pte_finetuning/model_exporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import argparse

import torch
from executorch.examples.llm_pte_finetuning.model_loading_lib import (
export_model_lora_training,
load_checkpoint,
setup_model,
)

from executorch.examples.llm_pte_finetuning.training_lib import (
get_dataloader,
TrainingModule,
)

from omegaconf import OmegaConf
from torch.nn import functional as F
from torchtune import config

from torchtune.training import MODEL_KEY

parser = argparse.ArgumentParser(
prog="ModelExporter",
description="Export a LoRA model to ExecuTorch.",
epilog="Model exported to be used for fine-tuning.",
)

parser.add_argument("--cfg", type=str, help="Path to the config file.")
parser.add_argument("--output_file", type=str, help="Path to the output ET model.")


def main() -> None:
args = parser.parse_args()
config_file = args.cfg
output_file = args.output_file
cfg = OmegaConf.load(config_file)
tokenizer = config.instantiate(
cfg.tokenizer,
)

loss_fn = config.instantiate(cfg.loss)

ds = config.instantiate(cfg.dataset, tokenizer)
train_set, val_set = torch.utils.data.random_split(ds, [0.8, 0.2])
train_dataloader = get_dataloader(cfg, train_set, tokenizer, loss_fn)

max_seq_len = cfg.tokenizer.max_seq_len

# Example inputs, needed for ET export.
batch = next(iter(train_dataloader))
tokens, labels = batch["tokens"], batch["labels"]
token_size = tokens.shape[1]
labels_size = labels.shape[1]

if token_size > max_seq_len:
tokens = tokens[:, :max_seq_len]
else:
tokens = F.pad(tokens, (0, max_seq_len - token_size), value=0)

if labels_size > max_seq_len:
labels = labels[:, :max_seq_len]
else:
labels = F.pad(labels, (0, max_seq_len - labels_size), value=0)

# Load pre-trained checkpoint.
checkpoint_dict = load_checkpoint(cfg=cfg)
model = setup_model(
# pyre-ignore
cfg=cfg,
base_model_state_dict=checkpoint_dict[MODEL_KEY],
)

training_module = TrainingModule(model, loss_fn)

# Export the model to ExecuTorch for training.
export_model_lora_training(training_module, (tokens, labels), output_file)


if __name__ == "__main__":
main()
88 changes: 88 additions & 0 deletions examples/llm_pte_finetuning/model_loading_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, Dict, Tuple

import torch
from executorch.examples.llm_pte_finetuning.training_lib import TrainingModule
from executorch.exir import to_edge

from omegaconf import DictConfig
from torch.export import export, ExportedProgram
from torch.export.experimental import _export_forward_backward
from torch.nn.attention import sdpa_kernel, SDPBackend
from torchtune import config
from torchtune.modules.peft import get_adapter_params, set_trainable_params
from torchtune.training.precision import get_dtype, set_default_dtype
from torchtune.utils._device import get_device


def load_checkpoint(cfg: Any) -> Dict[str, Any]: # pyre-ignore[2]
"""
Extract the checkpoint state from file and validate. This includes the
base model weights. If resume_from_checkpoint is True, this also includes
the adapter weights and recipe state
"""
checkpointer = config.instantiate(
cfg.checkpointer,
resume_from_checkpoint=cfg.resume_from_checkpoint,
)
checkpoint_dict = checkpointer.load_checkpoint()
return checkpoint_dict


def setup_model(
cfg: DictConfig,
base_model_state_dict: Dict[str, Any],
) -> torch.nn.Module:
device = get_device(device=cfg.device)
dtype = get_dtype(cfg.dtype, device=device)
with set_default_dtype(dtype), device:
model = config.instantiate(cfg.model)

adapter_params = get_adapter_params(model)
set_trainable_params(model, adapter_params)
model.load_state_dict(base_model_state_dict, strict=False)
return model


def export_model_lora_training(
model: TrainingModule,
example_args: Tuple[Any, ...], # pyre-ignore[2]
output_file: str,
) -> None:
"""
Export model with LoRA model to executorch for training, only.
"""

# 0. Mark the LoRA layers as trainable (requires_grad = True) in order
# to just export the backwards pass for these layers later in the
# export process.
set_trainable_params(model, get_adapter_params(model))

print("Exporting model with LoRA for training")
# 1. torch.export: Defines the program with the ATen operator set.

with sdpa_kernel([SDPBackend.MATH]):
exported_graph: ExportedProgram = export(model, example_args, strict=False)
print("Creating a joint forward-backwards graph for training")
joint_graph = _export_forward_backward(exported_graph)

# 2. to_edge: Make optimizations for Edge devices.
print("Lowering to edge dialect")
edge_program = to_edge(joint_graph)

print(edge_program._edge_programs["forward"].graph_module)

# 3. to_executorch: Convert the graph to an ExecuTorch program.
print("Exporting to executorch")
executorch_program = edge_program.to_executorch()
print(executorch_program.exported_program().graph_signature)
print(f"Saving to {output_file}")
with open(output_file, "wb") as file:
file.write(executorch_program.buffer)
49 changes: 49 additions & 0 deletions examples/llm_pte_finetuning/phi3_alpaca_code_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
tokenizer:
_component_: torchtune.models.phi3.phi3_mini_tokenizer
path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model
max_seq_len: 1024

dataset:
_component_: torchtune.datasets.instruct_dataset
template: papaya.toolkit.experimental.llm_pte_finetuning.utils.DatabricksDolly
source: iamtarun/python_code_instructions_18k_alpaca
split: train
column_map:
instruction: instruction
prompt: prompt
input: input
output: output
seed: null
shuffle: True
batch_size: 1

loss:
_component_: torch.nn.CrossEntropyLoss

model:
_component_: torchtune.models.phi3.lora_phi3_mini
lora_attn_modules: ['q_proj', 'v_proj']
apply_lora_to_mlp: False
apply_lora_to_output: False
lora_rank: 8
lora_alpha: 16

checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Phi-3-mini-4k-instruct
checkpoint_files: [
model-00001-of-00002.safetensors,
model-00002-of-00002.safetensors
]
recipe_checkpoint: null
output_dir: /tmp/Phi-3-mini-4k-instruct/
model_type: PHI3_MINI

resume_from_checkpoint: False
save_adapter_weights_only: False

device: cpu
dtype: fp32

enable_activation_checkpointing: True
compile: False
40 changes: 40 additions & 0 deletions examples/llm_pte_finetuning/phi3_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
tokenizer:
_component_: torchtune.models.phi3.phi3_mini_tokenizer
path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model
max_seq_len: 512

dataset:
_component_: torchtune.datasets.alpaca_cleaned_dataset
seed: null
shuffle: True
batch_size: 1

loss:
_component_: torch.nn.CrossEntropyLoss

model:
_component_: torchtune.models.phi3.lora_phi3_mini
lora_attn_modules: ['q_proj', 'v_proj']
apply_lora_to_mlp: False
apply_lora_to_output: False
lora_rank: 8
lora_alpha: 16

checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Phi-3-mini-4k-instruct
checkpoint_files: [
model-00001-of-00002.safetensors,
model-00002-of-00002.safetensors
]
recipe_checkpoint: null
output_dir: /tmp/Phi-3-mini-4k-instruct/
model_type: PHI3_MINI
resume_from_checkpoint: False
save_adapter_weights_only: False

device: cpu
dtype: fp32

enable_activation_checkpointing: True
compile: False
39 changes: 39 additions & 0 deletions examples/llm_pte_finetuning/qwen_05b_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
tokenizer:
_component_: torchtune.models.qwen2.qwen2_tokenizer
path: /tmp/Qwen2-0.5B-Instruct/vocab.json
merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt
max_seq_len: 512

dataset:
_component_: torchtune.datasets.alpaca_cleaned_dataset
seed: null
shuffle: True
batch_size: 1

loss:
_component_: torch.nn.CrossEntropyLoss

model:
_component_: torchtune.models.qwen2.lora_qwen2_0_5b
lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
apply_lora_to_mlp: False
lora_rank: 32
lora_alpha: 64

checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Qwen2-0.5B-Instruct
checkpoint_files: [
model.safetensors
]
recipe_checkpoint: null
output_dir: /tmp/Qwen2-0.5B-Instruct
model_type: QWEN2
resume_from_checkpoint: False
save_adapter_weights_only: False

device: cpu
dtype: fp32

enable_activation_checkpointing: True
compile: False
Loading

0 comments on commit afb6c11

Please sign in to comment.