Example for Llama2 on Inf2 #2458
@@ -0,0 +1,113 @@
# Large model inference on Inferentia2

This document describes serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) for text completion with [micro batching](https://github.com/pytorch/serve/tree/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/examples/micro_batching) and [streaming response](https://github.com/pytorch/serve/blob/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/docs/inference_api.md#curl-example-1) support.

Inferentia2 uses the [Neuron SDK](https://aws.amazon.com/machine-learning/neuron/), which is built on top of the PyTorch XLA stack. For large model inference, the [`transformers-neuronx`](https://github.com/aws-neuron/transformers-neuronx) package is used, which takes care of model partitioning and running inference.

**Note**: To run the model on an Inf2 instance, the model is compiled as a preprocessing step. The compilation process generates the model graph for a specific batch size, so the input passed at inference time must match the batch size used during compilation. Model compilation and input padding to match the compiled model batch size are taken care of by the [custom handler](inf2_handler.py) in this example.

The batch size and micro batch size configurations are present in [model-config.yaml](model-config.yaml). The batch size indicates the maximum number of requests TorchServe will aggregate and send to the custom handler within the batch delay.
The batch size is chosen to be a relatively large value, say 16, since micro batching enables running the preprocess (tokenization) and inference steps in parallel on the micro batches. The micro batch size is the batch size used for the Inf2 model compilation.
Since the compilation batch size can influence compile time and is also constrained by the Inf2 instance type, it is chosen to be a relatively small value, say 4.
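
The snippet below is a minimal illustration (not part of this example's code; `split_into_micro_batches` is a hypothetical helper) of how an aggregated batch maps onto micro batches of the compiled batch size, with the last micro batch padded with empty prompts, similar to what the custom handler's `preprocess` does for each micro batch it receives.

```python
# Illustration only: TorchServe's micro batching utility performs the actual
# splitting; the custom handler only pads the micro batch it is given.
MICRO_BATCH_SIZE = 4  # batch size the Inf2 model is compiled for

def split_into_micro_batches(prompts):
    """Split an aggregated batch into micro batches of MICRO_BATCH_SIZE and
    pad the last one with empty prompts to match the compiled batch size."""
    micro_batches = []
    for start in range(0, len(prompts), MICRO_BATCH_SIZE):
        micro_batch = prompts[start : start + MICRO_BATCH_SIZE]
        micro_batch.extend([""] * (MICRO_BATCH_SIZE - len(micro_batch)))
        micro_batches.append(micro_batch)
    return micro_batches

# 10 requests -> three micro batches of size 4 (the last has 2 padding entries)
print([len(mb) for mb in split_into_micro_batches([f"prompt {i}" for i in range(10)])])
```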

This example also demonstrates the use of the neuronx cache to store Inf2 model compilation artifacts, using the `NEURONX_CACHE` and `NEURONX_DUMP_TO` environment variables in the custom handler.
When the model is loaded for the first time, it is compiled for the configured micro batch size and the compilation artifacts are saved to the neuronx cache.
On subsequent model loads, the compilation artifacts in the neuronx cache serve as `Ahead of Time (AOT)` compilation artifacts and significantly reduce the model load time.
For convenience, the compiled model artifacts for this example are made available on the TorchServe model zoo: `s3://torchserve/mar_files/llama-2-13b-neuronx-b4`\
Instructions on how to use the AOT compiled model artifacts are shown below.
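
As a minimal sketch of how the cache is enabled (mirroring what [inf2_handler.py](inf2_handler.py) does at model load time; the `model_dir` path here is hypothetical), the handler sets the following environment variables before compiling the model:

```python
import os

# Hypothetical model directory; the handler derives this from
# ctx.system_properties["model_dir"] at load time.
model_dir = "/home/ubuntu/serve/model_store/llama-2-13b"

os.environ["NEURONX_CACHE"] = "on"  # enable the neuronx compilation cache
os.environ["NEURONX_DUMP_TO"] = f"{model_dir}/neuron_cache"  # where artifacts are stored
os.environ["NEURON_CC_FLAGS"] = "--model-type=transformer-inference"
```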

### Step 1: Inf2 instance

Get an Inf2 instance (Note: this example was tested on instance type `inf2.24xlarge`) and ssh to it. Make sure to use the following DLAMI, as it comes with PyTorch and the necessary packages for the AWS Neuron SDK pre-installed.

DLAMI Name: `Deep Learning AMI Neuron PyTorch 1.13 (Ubuntu 20.04) 20230720 Amazon Machine Image (AMI)` or higher.

### Step 2: Package Installations

Follow the steps below to complete the package installations.

```bash
sudo apt-get update
sudo apt-get upgrade

# Update Neuron Runtime
sudo apt-get install aws-neuronx-collectives=2.* -y
sudo apt-get install aws-neuronx-runtime-lib=2.* -y

# Activate Python venv
source /opt/aws_neuron_venv_pytorch/bin/activate

# Clone Torchserve git repository
git clone https://github.com/pytorch/serve.git
cd serve

# Install dependencies
python ts_scripts/install_dependencies.py --neuronx --environment=dev

# Install torchserve and torch-model-archiver
python ts_scripts/install_from_src.py

# Navigate to `examples/large_models/inferentia2/llama2` directory
cd examples/large_models/inferentia2/llama2/

# Install additional necessary packages
python -m pip install -r requirements.txt
```

### Step 3: Save the model artifacts compatible with `transformers-neuronx`
In order to use the pre-compiled model artifacts, copy them from the model zoo using the command shown below and skip to **Step 5**.
```bash
aws s3 cp s3://torchserve/mar_files/llama-2-13b-neuronx-b4/ llama-2-13b --recursive
```

In order to download and compile the Llama2 model from scratch for support on Inf2:\
Request access to the Llama2 model\
https://huggingface.co/meta-llama/Llama-2-13b-hf

Log in to Hugging Face
```bash
huggingface-cli login
```

Run the `inf2_save_split_checkpoints.py` script
```bash
python ../util/inf2_save_split_checkpoints.py --model_name meta-llama/Llama-2-13b-hf --save_path './llama-2-13b-split'
```

### Step 4: Package model artifacts

```bash
torch-model-archiver --model-name llama-2-13b --version 1.0 --handler inf2_handler.py -r requirements.txt --config-file model-config.yaml --archive-format no-archive
mv llama-2-13b-split llama-2-13b
```

### Step 5: Add the model artifacts to model store

```bash
mkdir model_store
mv llama-2-13b model_store
```

### Step 6: Start torchserve

```bash
torchserve --ncs --start --model-store model_store --ts-config config.properties
```

### Step 7: Register model

```bash
curl -X POST "http://localhost:8081/models?url=llama-2-13b"
```
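
If you prefer Python over curl, the same registration call against the TorchServe management API (default port 8081) can be made with the `requests` library. This is a sketch, not part of the example's files:

```python
import requests

# Register the model via the TorchServe management API (default port 8081)
response = requests.post("http://localhost:8081/models", params={"url": "llama-2-13b"})
print(response.status_code, response.text)
```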

### Step 8: Run inference

```bash
python test_stream_response.py
```

### Step 9: Stop torchserve

```bash
torchserve --stop
```
@@ -0,0 +1 @@
install_py_dep_per_model=true
Review comment: Would be great to have a unit test for the handler. You can mock out the Inferentia and model related parts. This example shows how to mock the context etc.
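
A minimal sketch of such a test, assuming the Neuron-specific packages are mocked out so the handler module can be imported without Inf2 hardware; the test name and structure are illustrative only, not the test actually added to this PR:

```python
import sys
from unittest.mock import MagicMock

# Mock out the Inferentia/Neuron packages so inf2_handler can be imported
# without the Neuron SDK or Inf2 hardware.
for mod in (
    "torch_neuronx",
    "transformers_neuronx",
    "transformers_neuronx.generation_utils",
    "transformers_neuronx.llama.model",
):
    sys.modules.setdefault(mod, MagicMock())

from inf2_handler import LLMHandler


def test_get_micro_batch_req_id_map():
    # Exercise only the request-id mapping logic; no model is loaded.
    handler = LLMHandler()
    handler.handle.micro_batch_size = 4
    handler.context = MagicMock()
    handler.context.request_ids = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e", 5: "f"}

    # Micro batch 1 covers batch indices 4..7; only 4 and 5 have live requests.
    assert handler.get_micro_batch_req_id_map(1) == {0: "e", 1: "f"}
```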
@@ -0,0 +1,163 @@
import logging
import os
from abc import ABC
from threading import Thread

import torch_neuronx
from transformers import AutoConfig, LlamaTokenizer
from transformers_neuronx.generation_utils import HuggingFaceGenerationModelAdapter
from transformers_neuronx.llama.model import LlamaForSampling

from ts.handler_utils.hf_batch_streamer import TextIteratorStreamerBatch
from ts.handler_utils.micro_batching import MicroBatching
from ts.protocol.otf_message_handler import send_intermediate_predict_response
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class LLMHandler(BaseHandler, ABC):
    """
    Transformers handler class for text completion streaming on Inferentia2
    """

    def __init__(self):
        super().__init__()
        self.initialized = False
        self.max_length = None
        self.tokenizer = None
        self.output_streamer = None
        # enable micro batching
        self.handle = MicroBatching(self)

    def initialize(self, ctx):
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        model_checkpoint_dir = ctx.model_yaml_config.get("handler", {}).get(
            "model_checkpoint_dir", ""
        )
        model_checkpoint_path = f"{model_dir}/{model_checkpoint_dir}"
        os.environ["NEURONX_CACHE"] = "on"
        os.environ["NEURONX_DUMP_TO"] = f"{model_dir}/neuron_cache"
        os.environ["NEURON_CC_FLAGS"] = "--model-type=transformer-inference"

        # micro batching initialization
        micro_batching_parallelism = ctx.model_yaml_config.get(
            "micro_batching", {}
        ).get("parallelism", None)
        if micro_batching_parallelism:
            logger.info(
                f"Setting micro batching parallelism from model_config_yaml: {micro_batching_parallelism}"
            )
            self.handle.parallelism = micro_batching_parallelism

        micro_batch_size = ctx.model_yaml_config.get("micro_batching", {}).get(
            "micro_batch_size", 1
        )
        logger.info(f"Setting micro batching size: {micro_batch_size}")
        self.handle.micro_batch_size = micro_batch_size

        # settings for model compilation and loading
        amp = ctx.model_yaml_config.get("handler", {}).get("amp", "f32")
        tp_degree = ctx.model_yaml_config.get("handler", {}).get("tp_degree", 6)
        self.max_length = ctx.model_yaml_config.get("handler", {}).get("max_length", 50)

        # allocate "tp_degree" number of neuron cores to the worker process
        os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree)
        try:
            num_neuron_cores_available = (
                torch_neuronx.xla_impl.data_parallel.device_count()
            )
            assert num_neuron_cores_available >= int(tp_degree)
        except (RuntimeError, AssertionError) as error:
            logger.error(
                "Required number of neuron cores for tp_degree "
                + str(tp_degree)
                + " are not available: "
                + str(error)
            )

            raise error

        self.tokenizer = LlamaTokenizer.from_pretrained(model_checkpoint_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = LlamaForSampling.from_pretrained(
            model_checkpoint_path,
            batch_size=self.handle.micro_batch_size,
            amp=amp,
            tp_degree=tp_degree,
        )
        logger.info("Starting to compile the model")
        self.model.to_neuron()

Review comment: @namannandan I am wondering if compilation can be done ahead of time and we just load the compiled graphs here, the way it was working for Inf1?
Reply: I tested `_save_compiled_artifacts`. It is able to generate a neuron model. However, `transformers_neuronx` still needs to recompile. I already let the Neuron team know they need more work on the experimental feature `_save_compiled_artifacts`.

        logger.info("Model has been successfully compiled")
        model_config = AutoConfig.from_pretrained(model_checkpoint_path)
        self.model = HuggingFaceGenerationModelAdapter(model_config, self.model)
        self.output_streamer = TextIteratorStreamerBatch(
            self.tokenizer,
            batch_size=self.handle.micro_batch_size,
            skip_special_tokens=True,
        )

        self.initialized = True

    def preprocess(self, requests):
        input_text = []
        for req in requests:
            data = req.get("data") or req.get("body")
            if isinstance(data, (bytes, bytearray)):
                data = data.decode("utf-8")
            logger.info(f"received req={data}")
            input_text.append(data.strip())

        # Ensure the compiled model can handle the input received
        if len(input_text) > self.handle.micro_batch_size:
            raise ValueError(
                f"Model is compiled for batch size {self.handle.micro_batch_size} but received input of size {len(input_text)}"
            )

        # Pad input to match compiled model batch size
        input_text.extend([""] * (self.handle.micro_batch_size - len(input_text)))

        return self.tokenizer(input_text, return_tensors="pt", padding=True)

    def inference(self, tokenized_input):
        generation_kwargs = dict(
            tokenized_input,
            streamer=self.output_streamer,
            max_new_tokens=self.max_length,
        )
        self.model.reset_generation()
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        micro_batch_idx = self.handle.get_micro_batch_idx()
        micro_batch_req_id_map = self.get_micro_batch_req_id_map(micro_batch_idx)
        # Stream generated text back to the live requests in this micro batch
        # as the generation thread produces it
        for new_text in self.output_streamer:
            logger.debug("send response stream")
            send_intermediate_predict_response(
                new_text[: len(micro_batch_req_id_map)],
                micro_batch_req_id_map,
                "Intermediate Prediction success",
                200,
                self.context,
            )

        thread.join()

        # The full output has already been streamed; return empty strings as the
        # final response for each live request in the micro batch
        return [""] * len(micro_batch_req_id_map)

    def postprocess(self, inference_output):
        return inference_output

    def get_micro_batch_req_id_map(self, micro_batch_idx: int):
        start_idx = micro_batch_idx * self.handle.micro_batch_size
        micro_batch_req_id_map = {
            index: self.context.request_ids[batch_index]
            for index, batch_index in enumerate(
                range(start_idx, start_idx + self.handle.micro_batch_size)
            )
            if batch_index in self.context.request_ids
        }

        return micro_batch_req_id_map
@@ -0,0 +1,18 @@
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 100
responseTimeout: 10800
batchSize: 16

handler:
    model_checkpoint_dir: "llama-2-13b-split"
    amp: "bf16"
    tp_degree: 6
    max_length: 100

micro_batching:
    micro_batch_size: 4
    parallelism:
        preprocess: 2
        inference: 1
        postprocess: 2
@@ -0,0 +1,6 @@
--extra-index-url https://pip.repos.neuron.amazonaws.com
torch-neuronx==1.13.1.1.9.0
transformers-neuronx==0.5.58
transformers==4.31.0
tokenizers==0.13.3
sentencepiece==0.1.99
@@ -0,0 +1,14 @@
import requests

response = requests.post(
    "http://localhost:8080/predictions/llama-2-13b",
    data="Today the weather is really nice and I am planning on ",
    stream=True,
)

for chunk in response.iter_content(chunk_size=None):
    if chunk:
        data = chunk.decode("utf-8")
        print(data, end="", flush=True)

print("")