Script to export 🤗 models (#4723)
Summary:
bypass-github-export-checks

[Done] ~~Require PR [Make StaticCache configurable at model construct time](huggingface/transformers#32830) in order to export, lower and run the 🤗 model OOTB.~~
[Done] ~~Require huggingface/transformers#33303 or huggingface/transformers#33287 to be merged into 🤗 `transformers` to resolve the export issue introduced by huggingface/transformers#32543.~~

-----------

Now we can use the integration point from 🤗 `transformers` to lower compatible models to ExecuTorch OOTB.
  - This PR creates a simple export script with an XNNPACK recipe.
  - This PR also creates a secret `EXECUTORCH_HF_TOKEN` to allow downloading checkpoints in CI.
  - This PR connects the 🤗 "Export to ExecuTorch" e2e workflow to ExecuTorch CI.

### Instructions to run the demo:

1. Run `export_hf_model.py` to lower gemma-2b to ExecuTorch (a condensed sketch of what the script does appears after step 4):
```
python -m extension.export_util.export_hf_model -hfm "google/gemma-2b" # The model is exported with static dims and a static KV cache
```
2. Run `tokenizer.py` to convert the tokenizer to the binary format used by the ExecuTorch runtime:
```
python -m extension.llm.tokenizer.tokenizer -t <path_to_downloaded_gemma_checkpoint_dir>/tokenizer.model -o tokenizer.bin
```
3. Build the llama runner by following [step 4](https://github.com/pytorch/executorch/tree/main/examples/models/llama2#step-4-run-on-your-computer-to-validate) of this guide.

4. Run the lowered model:
```
cmake-out/examples/models/llama2/llama_main --model_path=gemma.pte --tokenizer_path=tokenizer.bin --prompt="My name is"
```
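
As referenced in step 1, `export_hf_model.py` is a thin wrapper around the `convert_and_export_with_cache` integration point in 🤗 `transformers`. The condensed sketch below shows the flow it implements for gemma-2b; it assumes the script's defaults, hard-codes the output filename, and omits the runner metadata (`constant_methods`) that the full script in the diff below embeds.
```
# Condensed sketch of the export flow in extension/export_util/export_hf_model.py.
# Assumptions: gemma-2b, the script's default settings, hard-coded output name.
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
from torch.nn.attention import SDPBackend
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.configuration_utils import GenerationConfig
from transformers.integrations.executorch import convert_and_export_with_cache

repo = "google/gemma-2b"
max_length = 123  # static sequence length baked into the exported graph

# Load the model with a static KV cache configured at construct time.
model = AutoModelForCausalLM.from_pretrained(
    repo,
    attn_implementation="sdpa",
    device_map="cpu",
    torch_dtype=torch.float32,
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        max_length=max_length,
        cache_config={"batch_size": 1, "max_cache_len": max_length},
    ),
)

# Example export inputs: token ids for an empty prompt and the initial cache position.
tokenizer = AutoTokenizer.from_pretrained(repo)
input_ids = tokenizer([""], return_tensors="pt")["input_ids"]
cache_position = torch.tensor([0], dtype=torch.long)

# Export with the static cache, lower to XNNPACK, and serialize the program.
# (The real script also embeds runner metadata via constant_methods; omitted for brevity.)
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    exported = convert_and_export_with_cache(model, input_ids, cache_position)
    prog = (
        to_edge(
            exported,
            compile_config=EdgeCompileConfig(_check_ir_validity=False, _skip_dim_order=True),
        )
        .to_backend(XnnpackPartitioner())
        .to_executorch(ExecutorchBackendConfig(extract_delegate_segments=True))
    )

with open("gemma.pte", "wb") as f:
    prog.write_to_file(f)
```
Configuring the static KV cache and `max_length` on the model's `GenerationConfig` is what gives the export fixed shapes that the XNNPACK-lowered program expects.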
OOTB output and performance:
```
I 00:00:00.003110 executorch:cpuinfo_utils.cpp:62] Reading file /sys/devices/soc0/image_version
I 00:00:00.003360 executorch:cpuinfo_utils.cpp:78] Failed to open midr file /sys/devices/soc0/image_version
I 00:00:00.003380 executorch:cpuinfo_utils.cpp:158] Number of efficient cores 4
I 00:00:00.003384 executorch:main.cpp:65] Resetting threadpool with num threads = 6
I 00:00:00.014716 executorch:runner.cpp:51] Creating LLaMa runner: model_path=gemma.pte, tokenizer_path=tokenizer_gemma.bin
I 00:00:03.065359 executorch:runner.cpp:66] Reading metadata from model
I 00:00:03.065391 executorch:metadata_util.h:43] get_n_bos: 1
I 00:00:03.065396 executorch:metadata_util.h:43] get_n_eos: 1
I 00:00:03.065399 executorch:metadata_util.h:43] get_max_seq_len: 123
I 00:00:03.065402 executorch:metadata_util.h:43] use_kv_cache: 1
I 00:00:03.065404 executorch:metadata_util.h:41] The model does not contain use_sdpa_with_kv_cache method, using default value 0
I 00:00:03.065405 executorch:metadata_util.h:43] use_sdpa_with_kv_cache: 0
I 00:00:03.065407 executorch:metadata_util.h:41] The model does not contain append_eos_to_prompt method, using default value 0
I 00:00:03.065409 executorch:metadata_util.h:43] append_eos_to_prompt: 0
I 00:00:03.065411 executorch:metadata_util.h:41] The model does not contain enable_dynamic_shape method, using default value 0
I 00:00:03.065412 executorch:metadata_util.h:43] enable_dynamic_shape: 0
I 00:00:03.130388 executorch:metadata_util.h:43] get_vocab_size: 256000
I 00:00:03.130405 executorch:metadata_util.h:43] get_bos_id: 2
I 00:00:03.130408 executorch:metadata_util.h:43] get_eos_id: 1
My name is Melle. I am a 20 year old girl from Belgium. I am living in the southern part of Belgium. I am 165 cm tall and I weigh 45kg. I like to play sports like swimming, running and playing tennis. I am very interested in music and I like to listen to classical music. I like to sing and I can play the piano. I would like to go to the USA because I like to travel a lot. I am looking for a boy from the USA who is between 18 and 25 years old. I
PyTorchObserver {"prompt_tokens":4,"generated_tokens":118,"model_load_start_ms":1723685715497,"model_load_end_ms":1723685718612,"inference_start_ms":1723685718612,"inference_end_ms":1723685732965,"prompt_eval_end_ms":1723685719087,"first_token_ms":1723685719087,"aggregate_sampling_time_ms":182,"SCALING_FACTOR_UNITS_PER_SECOND":1000}
I 00:00:17.482472 executorch:stats.h:70] 	Prompt Tokens: 4    Generated Tokens: 118
I 00:00:17.482475 executorch:stats.h:76] 	Model Load Time:		3.115000 (seconds)
I 00:00:17.482481 executorch:stats.h:86] 	Total inference time:		14.353000 (seconds)		 Rate: 	8.221278 (tokens/second)
I 00:00:17.482483 executorch:stats.h:94] 		Prompt evaluation:	0.475000 (seconds)		 Rate: 	8.421053 (tokens/second)
I 00:00:17.482485 executorch:stats.h:105] 		Generated 118 tokens:	13.878000 (seconds)		 Rate: 	8.502666 (tokens/second)
I 00:00:17.482486 executorch:stats.h:113] 	Time to first generated token:	0.475000 (seconds)
I 00:00:17.482488 executorch:stats.h:120] 	Sampling time over 122 tokens:	0.182000 (seconds)
```
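(The reported rates are token counts over wall-clock time, e.g. generation: 118 tokens / 13.878 s ≈ 8.50 tokens/s.)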

Pull Request resolved: #4723

Reviewed By: huydhn, kirklandsign

Differential Revision: D62543933

Pulled By: guangy10

fbshipit-source-id: 00401a39ba03d7383e4b284d25c8fc62a6695b34
Guang Yang authored and facebook-github-bot committed Sep 14, 2024
1 parent 2001b3c commit 67be84b
Showing 2 changed files with 200 additions and 0 deletions.
90 changes: 90 additions & 0 deletions .github/workflows/trunk.yml
@@ -351,3 +351,93 @@ jobs:
PYTHON_EXECUTABLE=python ${CONDA_RUN} bash .ci/scripts/test_model.sh "${MODEL_NAME}" "${BUILD_TOOL}" "${BACKEND}"
echo "::endgroup::"
done
test-huggingface-transformers:
name: test-huggingface-transformers
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
secrets: inherit
strategy:
matrix:
hf_model_repo: [google/gemma-2b]
fail-fast: false
with:
secrets-env: EXECUTORCH_HF_TOKEN
runner: linux.12xlarge
docker-image: executorch-ubuntu-22.04-clang12
submodules: 'true'
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
timeout: 90
script: |
echo "::group::Set up ExecuTorch"
      # The generic Linux job chooses to use the base env, not the one set up by the image
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
conda activate "${CONDA_ENV}"
PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh cmake
echo "Installing libexecutorch.a, libextension_module.so, libportable_ops_lib.a"
rm -rf cmake-out
cmake \
-DCMAKE_INSTALL_PREFIX=cmake-out \
-DCMAKE_BUILD_TYPE=Release \
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
-DEXECUTORCH_BUILD_XNNPACK=ON \
-DPYTHON_EXECUTABLE=python \
-Bcmake-out .
cmake --build cmake-out -j9 --target install --config Release
echo "Build llama runner"
dir="examples/models/llama2"
cmake \
-DCMAKE_INSTALL_PREFIX=cmake-out \
-DCMAKE_BUILD_TYPE=Release \
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
-DEXECUTORCH_BUILD_XNNPACK=ON \
-DPYTHON_EXECUTABLE=python \
-Bcmake-out/${dir} \
${dir}
cmake --build cmake-out/${dir} -j9 --config Release
echo "::endgroup::"
echo "::group::Set up HuggingFace Dependencies"
pip install -U "huggingface_hub[cli]"
huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN
pip install accelerate sentencepiece
# TODO(guangyang): Switch to use released transformers library after all required patches are included
pip install "git+https://github.com/huggingface/transformers.git@6cc4dfe3f1e8d421c6d6351388e06e9b123cbfe1"
pip list
echo "::endgroup::"
echo "::group::Export to ExecuTorch"
TOKENIZER_FILE=tokenizer.model
TOKENIZER_BIN_FILE=tokenizer.bin
ET_MODEL_NAME=et_model
      # Fetch the tokenizer file from the Hub with a short Python snippet
DOWNLOADED_TOKENIZER_FILE_PATH=$(python -c "
from huggingface_hub import hf_hub_download
# Download the file from the Hugging Face Hub
downloaded_path = hf_hub_download(
repo_id='${{ matrix.hf_model_repo }}',
filename='${TOKENIZER_FILE}'
)
print(downloaded_path)
")
if [ -f "$DOWNLOADED_TOKENIZER_FILE_PATH" ]; then
echo "${TOKENIZER_FILE} downloaded successfully at: $DOWNLOADED_TOKENIZER_FILE_PATH"
python -m extension.llm.tokenizer.tokenizer -t $DOWNLOADED_TOKENIZER_FILE_PATH -o ./${TOKENIZER_BIN_FILE}
ls ./tokenizer.bin
else
echo "Failed to download ${TOKENIZER_FILE} from ${{ matrix.hf_model_repo }}."
exit 1
fi
python -m extension.export_util.export_hf_model -hfm=${{ matrix.hf_model_repo }} -o ${ET_MODEL_NAME}
cmake-out/examples/models/llama2/llama_main --model_path=${ET_MODEL_NAME}.pte --tokenizer_path=${TOKENIZER_BIN_FILE} --prompt="My name is"
echo "::endgroup::"
110 changes: 110 additions & 0 deletions extension/export_util/export_hf_model.py
@@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os

import torch
import torch.export._trace
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
from torch.nn.attention import SDPBackend
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.configuration_utils import GenerationConfig
from transformers.integrations.executorch import convert_and_export_with_cache
from transformers.modeling_utils import PreTrainedModel


def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"-hfm",
"--hf_model_repo",
required=True,
default=None,
help="a valid huggingface model repo name",
)
parser.add_argument(
"-o",
"--output_name",
required=False,
default=None,
help="output name of the exported model",
)

args = parser.parse_args()

    # Configs for the HF model
device = "cpu"
dtype = torch.float32
batch_size = 1
max_length = 123
cache_implementation = "static"
attn_implementation = "sdpa"

# Load and configure a HF model
model = AutoModelForCausalLM.from_pretrained(
args.hf_model_repo,
attn_implementation=attn_implementation,
device_map=device,
torch_dtype=dtype,
generation_config=GenerationConfig(
use_cache=True,
cache_implementation=cache_implementation,
max_length=max_length,
cache_config={
"batch_size": batch_size,
"max_cache_len": max_length,
},
),
)
print(f"{model.config}")
print(f"{model.generation_config}")

tokenizer = AutoTokenizer.from_pretrained(args.hf_model_repo)
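    # Example inputs for export: token ids for an empty prompt plus the initial cache position.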
input_ids = tokenizer([""], return_tensors="pt").to(device)["input_ids"]
cache_position = torch.tensor([0], dtype=torch.long)

def _get_constant_methods(model: PreTrainedModel):
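        # These values are embedded in the .pte as named constant methods; the ExecuTorch
        # llama runner reads them at load time (see the get_* lines in the runtime log above).
        # "get_dtype" uses ScalarType codes: Half = 5, Float = 6.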
return {
"get_dtype": 5 if model.config.torch_dtype == torch.float16 else 6,
"get_bos_id": model.config.bos_token_id,
"get_eos_id": model.config.eos_token_id,
"get_head_dim": model.config.hidden_size / model.config.num_attention_heads,
"get_max_batch_size": model.generation_config.cache_config.batch_size,
"get_max_seq_len": model.generation_config.cache_config.max_cache_len,
"get_n_bos": 1,
"get_n_eos": 1,
"get_n_kv_heads": model.config.num_key_value_heads,
"get_n_layers": model.config.num_hidden_layers,
"get_vocab_size": model.config.vocab_size,
"use_kv_cache": model.generation_config.use_cache,
}

with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
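        # Use the math SDPA backend so scaled_dot_product_attention decomposes into
        # primitive ops that torch.export can trace.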

exported_prog = convert_and_export_with_cache(model, input_ids, cache_position)
prog = (
to_edge(
exported_prog,
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
_skip_dim_order=True,
),
constant_methods=_get_constant_methods(model),
)
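            # Delegate XNNPACK-supported subgraphs, then serialize to the ExecuTorch program format.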
.to_backend(XnnpackPartitioner())
.to_executorch(ExecutorchBackendConfig(extract_delegate_segments=True))
)
out_name = args.output_name if args.output_name else model.config.model_type
filename = os.path.join("./", f"{out_name}.pte")
with open(filename, "wb") as f:
prog.write_to_file(f)
print(f"Saved exported program to {filename}")


if __name__ == "__main__":
main()
