Commit

export phi-3-mini-wrapper (#4478)
Summary:
Pull Request resolved: #4478

imported-using-ghimport

Test Plan: Imported from OSS

Reviewed By: iseeyuan

Differential Revision: D60483506

Pulled By: helunwencser

fbshipit-source-id: f5f019035af66af6380186e4bc57a949e6cc5480
helunwencser authored and facebook-github-bot committed Aug 1, 2024
1 parent a65700c commit a743a3b
Showing 1 changed file with 43 additions and 20 deletions.
examples/models/phi-3-mini/export_phi-3-mini.py (43 additions, 20 deletions)
@@ -4,6 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+
+import argparse
+
 import torch
 
 from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
@@ -20,30 +23,43 @@
     XNNPACKQuantizer,
 )
 
-from transformers import Phi3ForCausalLM
+from transformers import AutoTokenizer, Phi3ForCausalLM
 
 from .phi_3_mini import Phi3Mini
 
 
-def main() -> None:
+def main(args) -> None:
     torch.manual_seed(0)
 
-    # pyre-ignore: Undefined attribute [16]: Module `transformers` has no attribute `Phi3ForCausalLM`
-    model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+    model_name = "microsoft/Phi-3-mini-4k-instruct"
 
-    example_inputs = (torch.randint(0, 100, (1, 100), dtype=torch.long),)
-    dynamic_shape = {"input_ids": {1: torch.export.Dim("sequence_length", max=128)}}
-
-    xnnpack_quant_config = get_symmetric_quantization_config(
-        is_per_channel=True, is_dynamic=True
-    )
-    xnnpack_quantizer = XNNPACKQuantizer()
-    xnnpack_quantizer.set_global(xnnpack_quant_config)
-
-    with torch.nn.attention.sdpa_kernel(
-        [torch.nn.attention.SDPBackend.MATH]
-    ), torch.no_grad():
-        model = capture_pre_autograd_graph(
-            model, example_inputs, dynamic_shapes=dynamic_shape
-        )
+    with torch.no_grad():
+        model = Phi3Mini(
+            # pyre-ignore: Undefined attribute [16]: Module `transformers` has no attribute `Phi3ForCausalLM`
+            model=Phi3ForCausalLM.from_pretrained(model_name),
+            max_batch_size=1,
+            max_seq_len=args.seq_len,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+        tokens = tokenizer.encode("Tell me a story", return_tensors="pt")
+        for input_pos in range(tokens.shape[-1]):
+            result = model.forward(
+                input_ids=tokens[:, input_pos : input_pos + 1],
+            )
+        current_token = torch.argmax(result, dim=-1).item()
+
+        example_inputs = (
+            torch.tensor([[current_token]], dtype=torch.long, requires_grad=False),
+        )
+
+        xnnpack_quant_config = get_symmetric_quantization_config(
+            is_per_channel=True, is_dynamic=True
+        )
+        xnnpack_quantizer = XNNPACKQuantizer()
+        xnnpack_quantizer.set_global(xnnpack_quant_config)
+
+        model = capture_pre_autograd_graph(model, example_inputs)
         model = prepare_pt2e(model, xnnpack_quantizer)
         model(*example_inputs)
         model = convert_pt2e(model, fold_quantize=False)
Expand All @@ -53,19 +69,26 @@ def main() -> None:
model = torch.export._trace._export(
model,
example_inputs,
dynamic_shapes=dynamic_shape,
strict=False,
pre_dispatch=False,
)

edge_config = get_xnnpack_edge_compile_config()
edge_manager = to_edge(model, compile_config=edge_config)
edge_manager = edge_manager.to_backend(XnnpackPartitioner(has_dynamic_shapes=True))
edge_manager = edge_manager.to_backend(XnnpackPartitioner())
et_program = edge_manager.to_executorch()

with open("phi-3-mini.pte", "wb") as file:
file.write(et_program.buffer)


if __name__ == "__main__":
main()
parser = argparse.ArgumentParser()
parser.add_argument(
"-s",
"--seq_len",
type=int,
default=128,
help="Maximum number of tokens including prompt to generate",
)
main(parser.parse_args())
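
For readers who want to sanity-check the phi-3-mini.pte this script writes, below is a minimal smoke-test sketch. It is not part of this commit: it assumes the ExecuTorch pybindings expose _load_for_executorch under executorch.extension.pybindings.portable_lib (the module path and call signature may differ across ExecuTorch versions) and that the exported Phi3Mini wrapper takes a single (1, 1) token-id tensor per decode step, as the example_inputs above suggest.

import torch

# Hypothetical smoke test; assumes this pybindings path and API exist in your build.
from executorch.extension.pybindings.portable_lib import _load_for_executorch

program = _load_for_executorch("phi-3-mini.pte")

# The exported wrapper is assumed to manage its KV cache internally, so each call
# takes one (1, 1) token id and returns the logits for the next position.
token = torch.tensor([[1]], dtype=torch.long)
logits = program.forward((token,))[0]
next_token = torch.argmax(logits, dim=-1).item()
print(f"next token id: {next_token}")

A full generation loop would first feed the prompt token by token, mirroring the prefill loop added in main() above, before sampling new tokens.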
