This repository was archived by the owner on Aug 7, 2025. It is now read-only.

Commit 0dd2d87

Author: Naman Nandan

Add llama2 text completion example with streaming response support

1 parent 4249e45 · commit 0dd2d87

File tree

8 files changed (+136, -189 lines)

examples/large_models/inferentia2/llama/inf2_handler.py

Lines changed: 0 additions & 167 deletions
This file was deleted.

examples/large_models/inferentia2/llama/requirements.txt

Lines changed: 0 additions & 5 deletions
This file was deleted.

examples/large_models/inferentia2/llama/sample_text.txt

Lines changed: 0 additions & 1 deletion
This file was deleted.

examples/large_models/inferentia2/llama/Readme.md renamed to examples/large_models/inferentia2/llama2/Readme.md

Lines changed: 16 additions & 12 deletions
@@ -1,12 +1,12 @@
# Large model inference on Inferentia2

-This document briefs on serving large HuggingFace (HF) models on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) instances.
+This document describes serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) with streaming response support.

-Inferentia2 uses [Neuron SDK](https://aws.amazon.com/machine-learning/neuron/) which is build on top of PyTorch XLA stack. For large model inference [`transformers-neuronx`](https://github.com/aws-neuron/transformers-neuronx) package is used that takes care of model partitioning and running inference.
+Inferentia2 uses the [Neuron SDK](https://aws.amazon.com/machine-learning/neuron/), which is built on top of the PyTorch XLA stack. For large model inference, the [`transformers-neuronx`](https://github.com/aws-neuron/transformers-neuronx) package is used to take care of model partitioning and running inference.

Let's take a look at the steps to prepare our model for inference on Inf2 instances.

-**Note** To run the model on an Inf2 instance, the model gets compiled as a preprocessing step. As part of the compilation process, to generate the model graph, a specific batch size is used. Following this, when running inference, we need to pass the same batch size that was used during compilation. This example uses batch size of 2 but make sure to change it and register the model according to your batch size.
+**Note** To run the model on an Inf2 instance, the model is compiled as a preprocessing step. The compilation process generates the model graph for a specific batch size, and the same batch size must then be used when running inference. This example uses a batch size of 1 to demonstrate real-time inference with streaming response.

### Step 1: Inf2 instance

@@ -30,7 +30,7 @@ source /opt/aws_neuron_venv_pytorch/bin/activate
python -m pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com

# Update Neuron Compiler and Framework
-python -m pip install --upgrade neuronx-cc==2.* torch-neuronx torchvision
+python -m pip install --upgrade neuronx-cc==2.* torch-neuronx

pip install git+https://github.com/aws-neuron/transformers-neuronx.git transformers -U

@@ -40,38 +40,42 @@ pip install git+https://github.com/aws-neuron/transformers-neuronx.git transform

### Step 2: Save the model split checkpoints compatible with `transformers-neuronx`

-Navigate to `large_model/inferentia2/llama` directory.
+Navigate to the `large_model/inferentia2/llama2` directory.

```bash
-python ../util/inf2_save_split_checkpoints.py --model_name decapoda-research/llama-7b-hf --save_path './decapoda_llama_7b_split'
+python ../util/inf2_save_split_checkpoints.py --model_name meta-llama/Llama-2-7b-hf --save_path './llama-2-7b-split'

```


### Step 3: Generate Tar/ MAR file

```bash
-torch-model-archiver --model-name decapoda_llama_7b --version 1.0 --handler inf2_handler.py --extra-files ./decapoda_llama_7b_split -r requirements.txt --config-file model-config.yaml --archive-format no-archive
+torch-model-archiver --model-name llama-2-7b --version 1.0 --handler inf2_handler.py --extra-files ./llama-2-7b-split -r requirements.txt --config-file model-config.yaml --archive-format no-archive

```

### Step 4: Add the mar file to model store

```bash
mkdir model_store
-mv decapoda_llama_7b model_store
+mv llama-2-7b model_store
```

### Step 5: Start torchserve

-Update config.properties and start torchserve
+```bash
+torchserve --ncs --start --model-store model_store
+```
+
+### Step 6: Register model

```bash
-torchserve --ncs --start --model-store model_store --models decapoda_llama_7b
+curl -X POST "http://localhost:8081/models?url=llama-2-7b"
```

-### Step 6: Run inference
+### Step 7: Run inference

```bash
-curl -v "http://localhost:8080/predictions/decapoda_llama_7b" -T sample_text.txt
+python test_stream_response.py
```
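
After registering the model in Step 6, one quick way to confirm the worker is up before running inference is to query the TorchServe management API. A minimal sketch, assuming the default management port 8081 and the model name `llama-2-7b` used above:

```python
import requests

# Hypothetical sanity check (not part of this commit): describe the model
# registered in Step 6 via the TorchServe management API on port 8081.
resp = requests.get("http://localhost:8081/models/llama-2-7b")
resp.raise_for_status()
print(resp.json())  # worker status and model metadata
```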
examples/large_models/inferentia2/llama2/inf2_handler.py

Lines changed: 100 additions & 0 deletions
import logging
import os
from abc import ABC
from threading import Thread

import torch_neuronx
from transformers import AutoConfig, LlamaTokenizer, TextIteratorStreamer
from transformers_neuronx.generation_utils import HuggingFaceGenerationModelAdapter
from transformers_neuronx.llama.model import LlamaForSampling

from ts.protocol.otf_message_handler import send_intermediate_predict_response
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class LLMHandler(BaseHandler, ABC):
    """
    Transformers handler class for text completion streaming on Inferentia2
    """

    def __init__(self):
        super(LLMHandler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")

        # settings for model compilation and loading
        model_name = ctx.model_yaml_config["handler"]["model_name"]
        tp_degree = ctx.model_yaml_config["handler"]["tp_degree"]
        self.max_length = ctx.model_yaml_config["handler"]["max_length"]

        # allocate "tp_degree" number of neuron cores to the worker process
        os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree)
        try:
            num_neuron_cores_available = (
                torch_neuronx.xla_impl.data_parallel.device_count()
            )
            assert num_neuron_cores_available >= int(tp_degree)
        except (RuntimeError, AssertionError) as error:
            raise RuntimeError(
                "Required number of neuron cores for tp_degree "
                + str(tp_degree)
                + " are not available: "
                + str(error)
            )

        os.environ["NEURON_CC_FLAGS"] = "--model-type=transformer-inference"

        self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
        self.model = LlamaForSampling.from_pretrained(
            model_dir, batch_size=1, tp_degree=tp_degree
        )
        logger.info("Starting to compile the model")
        self.model.to_neuron()
        logger.info("Model has been successfully compiled")
        model_config = AutoConfig.from_pretrained(model_dir)
        # wrap the Neuron model so it exposes the HuggingFace generate() API
        self.model = HuggingFaceGenerationModelAdapter(model_config, self.model)
        self.output_streamer = TextIteratorStreamer(self.tokenizer)

        self.initialized = True

    def preprocess(self, requests):
        input_texts = []
        for req in requests:
            data = req.get("data") or req.get("body")
            if isinstance(data, (bytes, bytearray)):
                data = data.decode("utf-8")
            input_texts.append(data)

        return self.tokenizer(input_texts, return_tensors="pt")

    def inference(self, tokenized_input):
        generation_kwargs = dict(
            tokenized_input,
            streamer=self.output_streamer,
            max_new_tokens=self.max_length,
        )
        self.model.reset_generation()
        # run generate() on a background thread and stream tokens as they are produced
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        for new_text in self.output_streamer:
            send_intermediate_predict_response(
                [new_text],
                self.context.request_ids,
                "Intermediate Prediction success",
                200,
                self.context,
            )

        thread.join()

        return [""]

    def postprocess(self, inference_output):
        return inference_output
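
The `inference` method above relies on the standard `transformers` streaming pattern: `generate()` runs on a background thread while the caller drains a `TextIteratorStreamer`, and each decoded chunk is forwarded to the client via `send_intermediate_predict_response`. A minimal sketch of that pattern in isolation, using a small generic model (`gpt2`, purely a stand-in for the Neuron-compiled Llama 2 model), looks roughly like this:

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Illustrative stand-in: a small CPU model instead of the Neuron-compiled Llama 2.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer(["Today the weather is really nice and I am planning on "], return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer)

# generate() blocks until the sequence is complete, so it runs on a worker thread
# while the main thread consumes decoded text chunks as they become available.
thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=50))
thread.start()

for new_text in streamer:
    print(new_text, end="", flush=True)

thread.join()
```

The handler does the same thing, except that each chunk is pushed back to the HTTP client as an intermediate response instead of being printed.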

examples/large_models/inferentia2/llama/model-config.yaml renamed to examples/large_models/inferentia2/llama2/model-config.yaml

Lines changed: 1 addition & 4 deletions
@@ -5,8 +5,5 @@ responseTimeout: 600

handler:
    max_length: 50
-    manual_seed: 40
-    batch_size: 2
    tp_degree: 2
-    amp: f16
-    model_name: decapoda-research/llama-7b-hf
+    model_name: meta-llama/Llama-2-7b-hf
examples/large_models/inferentia2/llama2/requirements.txt

Lines changed: 5 additions & 0 deletions
torch-neuronx
transformers-neuronx
transformers
tokenizers
sentencepiece
examples/large_models/inferentia2/llama2/test_stream_response.py

Lines changed: 14 additions & 0 deletions
import requests

response = requests.post(
    "http://localhost:8080/predictions/llama-2-7b",
    data="Today the weather is really nice and I am planning on ",
    stream=True,
)

for chunk in response.iter_content(chunk_size=None):
    if chunk:
        data = chunk.decode("utf-8")
        print(data, end="", flush=True)

print("")
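
Since the point of streaming is lower perceived latency, a small variation of the script above (hypothetical, not part of this commit) can also report time-to-first-chunk versus total generation time against the same endpoint:

```python
import time

import requests

# Hypothetical companion to test_stream_response.py: measures time-to-first-chunk
# and total latency for the streaming endpoint registered above.
start = time.perf_counter()
first_chunk_at = None

response = requests.post(
    "http://localhost:8080/predictions/llama-2-7b",
    data="Today the weather is really nice and I am planning on ",
    stream=True,
)

for chunk in response.iter_content(chunk_size=None):
    if chunk:
        if first_chunk_at is None:
            first_chunk_at = time.perf_counter()
        print(chunk.decode("utf-8"), end="", flush=True)

print()
if first_chunk_at is not None:
    total = time.perf_counter() - start
    print(f"time to first chunk: {first_chunk_at - start:.2f}s, total: {total:.2f}s")
```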

0 commit comments
