This repository was archived by the owner on Aug 7, 2025. It is now read-only.

Commit a1602ba

mreso and agunapal authored
GPT fast example (#2815)
* Initial commit for gpt_fast example
  Remove files and finish tests
  Add readme
  Complete readme
* Add int8 quantization to example
* Add missing json file
* Enable streaming response
* Remove print
* Adapt unit test to list return value, fix lint error
* Assert if batch_size is not 1
* Addressed review comments
* Added GPU compatibility remark

---------

Co-authored-by: Ankith Gunapal <agunapal@ischool.Berkeley.edu>
1 parent f3a2267 commit a1602ba

File tree

5 files changed: +509, -0 lines changed
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
## GPT-Fast

[GPT fast](https://github.com/pytorch-labs/gpt-fast) is a simple and efficient PyTorch-native transformer text generation library.

It features:
* Very low latency
* <1000 lines of python
* No dependencies other than PyTorch and sentencepiece
* int8/int4 quantization
* Speculative decoding
* Tensor parallelism
* Supports Nvidia and AMD GPUs

More details about gpt-fast can be found in this [blog](https://pytorch.org/blog/accelerating-generative-ai-2/).
The example has been tested on A10, A100, and H100 GPUs.

#### Pre-requisites

`cd` to the example folder `examples/large_models/gpt_fast`

Install dependencies and upgrade torch to a nightly build (currently required):
```
git clone https://github.com/pytorch-labs/gpt-fast/
pip install sentencepiece huggingface_hub
pip uninstall torchtext torchdata torch torchvision torchaudio -y
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 --ignore-installed
```
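Optionally, you can verify that the nightly build is active and that a GPU is visible before continuing; a quick sanity check, not part of the example's scripts:
```
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```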
### Step 1: Download and convert the weights

Currently supported models:
```
openlm-research/open_llama_7b
meta-llama/Llama-2-7b-chat-hf
meta-llama/Llama-2-13b-chat-hf
meta-llama/Llama-2-70b-chat-hf
codellama/CodeLlama-7b-Python-hf
codellama/CodeLlama-34b-Python-hf
```
Prepare weights:
```
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
cd gpt-fast
huggingface-cli login
./scripts/prepare.sh $MODEL_REPO
cd ..
```

### (Optional) Step 1.5: Quantize the model to int4

To speed up model loading and inference even further, we can optionally quantize the model to int4 instead of int8. Please see the [blog post](https://pytorch.org/blog/accelerating-generative-ai-2/) for details on the potential accuracy loss.

```
cd gpt-fast
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
cd ..
```

The quantized model will show up as checkpoints/$MODEL_REPO/model_int4.pth. To enable it in the example, you need to change the checkpoint filename in the [`model_config.yaml`](./model_config.yaml) file accordingly.
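For example, assuming the checkpoint layout produced in Step 1, the `handler` section of `model_config.yaml` would then point at the int4 file instead of model.pth; a sketch, not the shipped default:
```
handler:
    converted_ckpt_dir: "checkpoints/meta-llama/Llama-2-7b-chat-hf/model_int4.pth"
    max_new_tokens: 50
    compile: true
```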
### Step 2: Generate model archive

```
torch-model-archiver --model-name gpt_fast --version 1.0 --handler handler.py --config-file model_config.yaml --extra-files "gpt-fast/generate.py,gpt-fast/model.py,gpt-fast/quantize.py,gpt-fast/tp.py" --archive-format no-archive
mv gpt-fast/checkpoints gpt_fast/
```

### Step 3: Add the model archive to model store

```
mkdir model_store
mv gpt_fast model_store
```

### Step 4: Start torchserve

```
torchserve --start --ncs --model-store model_store --models gpt_fast
```
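Optionally, confirm that the model has been registered by querying the TorchServe management API, assuming the default management port 8081:
```
curl http://localhost:8081/models/gpt_fast
```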
### Step 5: Run inference

```
curl "http://localhost:8080/predictions/gpt_fast" -T request.json
# Returns: The capital of France, Paris, is a city of romance, fashion, and art. The city is home to the Eiffel Tower, the Louvre, and the Arc de Triomphe. Paris is also known for its cafes, restaurants
```
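The handler streams intermediate predictions back to the client (see the handler code later in this commit), so the generated text arrives incrementally rather than as a single response. A minimal streaming client sketch, assuming the third-party `requests` package is installed (it is not part of this example):
```
import json

import requests

# Same payload as request.json: a prompt plus the number of tokens to generate.
payload = {"prompt": "The capital of France", "max_new_tokens": 50}

response = requests.post(
    "http://localhost:8080/predictions/gpt_fast",
    data=json.dumps(payload),
    stream=True,
)

# Each chunk corresponds to an intermediate prediction sent by the handler.
for chunk in response.iter_content(chunk_size=None):
    if chunk:
        print(chunk.decode("utf-8"), end="", flush=True)
```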
Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
import json
import logging
import os
import time
from pathlib import Path

import torch
from generate import _load_model, decode_one_token, encode_tokens, prefill
from sentencepiece import SentencePieceProcessor

from ts.handler_utils.timer import timed
from ts.protocol.otf_message_handler import send_intermediate_predict_response
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class GptHandler(BaseHandler):
    def __init__(self):
        super().__init__()

        self.model = None
        self.tokenizer = None
        self.context = None
        self.prefill = prefill
        self.decode_one_token = decode_one_token
        self.initialized = False
        self.device = torch.device("cpu")
        self.prompt_length = 0

    def initialize(self, ctx):
        self.context = ctx
        properties = ctx.system_properties
        if torch.cuda.is_available():
            self.map_location = "cuda"
            self.device = torch.device(
                self.map_location + ":" + str(os.getenv("LOCAL_RANK", 0))
            )

        checkpoint_path = Path(ctx.model_yaml_config["handler"]["converted_ckpt_dir"])
        assert checkpoint_path.is_file(), checkpoint_path

        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), tokenizer_path

        logger.info("Loading model ...")
        t0 = time.time()
        self.model = _load_model(checkpoint_path, self.device, torch.bfloat16, False)
        torch.cuda.synchronize()
        logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

        self.tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))

        if ctx.model_yaml_config["handler"]["compile"]:
            self.decode_one_token = torch.compile(
                self.decode_one_token, mode="reduce-overhead", fullgraph=True
            )
            self.prefill = torch.compile(self.prefill, fullgraph=True, dynamic=True)

        torch.manual_seed(42 * 42)

        self.initialized = True

    @timed
    def preprocess(self, requests):
        assert (
            len(requests) == 1
        ), "GPT fast is currently only supported with batch_size=1"
        req_data = requests[0]

        input_data = req_data.get("data") or req_data.get("body")

        if isinstance(input_data, (bytes, bytearray)):
            input_data = input_data.decode("utf-8")

        input_data = json.loads(input_data)

        prompt = input_data["prompt"]

        encoded = encode_tokens(self.tokenizer, prompt, bos=True, device=self.device)

        self.prompt_length = encoded.size(0)

        return {
            "encoded": encoded,
            "max_new_tokens": input_data.get("max_new_tokens", 50),
        }

    @timed
    def inference(self, input_data):
        tokenizer = self.tokenizer
        period_id = tokenizer.encode(".")[0]

        def call_me(x):
            nonlocal period_id, tokenizer
            # Decode with a leading "." so the tokenizer preserves leading whitespace,
            # then strip that "." from the decoded text.
            text = self.tokenizer.decode([period_id] + x.tolist())[1:]
            send_intermediate_predict_response(
                [text],
                self.context.request_ids,
                "Intermediate Prediction success",
                200,
                self.context,
            )

        y = self.generate(
            input_data["encoded"],
            input_data["max_new_tokens"],
            callback=call_me,
            temperature=0.8,
            top_k=1,
        )
        logger.info(f"Num tokens = {y.size(0) - self.prompt_length}")
        return y

    def postprocess(self, y):
        # The generated text has already been sent as intermediate responses during
        # inference, so the final response body is left empty.
        return [""]

    @torch.no_grad()
    def generate(
        self,
        prompt: torch.Tensor,
        max_new_tokens: int,
        *,
        callback=lambda x: x,
        **sampling_kwargs,
    ) -> torch.Tensor:
        """
        Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
        """
        # create an empty tensor of the expected final shape and fill in the current tokens
        T = prompt.size(0)
        T_new = T + max_new_tokens

        max_seq_length = min(T_new, self.model.config.block_size)

        device, dtype = prompt.device, prompt.dtype
        with torch.device(device):
            self.model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

        # create an empty tensor of the expected final shape and fill in the current tokens
        empty = torch.empty(T_new, dtype=dtype, device=device)
        empty[:T] = prompt
        seq = empty
        input_pos = torch.arange(0, T, device=device)

        next_token = self.prefill(
            self.model, prompt.view(1, -1), input_pos, **sampling_kwargs
        )
        period_id = self.tokenizer.encode(".")[0]
        text = self.tokenizer.decode([period_id] + next_token.tolist())[1:]
        send_intermediate_predict_response(
            [text],
            self.context.request_ids,
            "Intermediate Prediction success",
            200,
            self.context,
        )

        seq[T] = next_token

        input_pos = torch.tensor([T], device=device, dtype=torch.int)

        generated_tokens, _ = self.decode_n_tokens(
            next_token.view(1, -1),
            input_pos,
            max_new_tokens - 1,
            callback=callback,
            **sampling_kwargs,
        )
        seq[T + 1 :] = torch.cat(generated_tokens)

        return seq

    def decode_n_tokens(
        self,
        cur_token: torch.Tensor,
        input_pos: torch.Tensor,
        num_new_tokens: int,
        callback=lambda _: _,
        **sampling_kwargs,
    ):
        new_tokens, new_probs = [], []
        for i in range(num_new_tokens):
            with torch.backends.cuda.sdp_kernel(
                enable_flash=False, enable_mem_efficient=False, enable_math=True
            ):  # Actually better for Inductor to codegen attention here
                next_token, next_prob = self.decode_one_token(
                    self.model, cur_token, input_pos, **sampling_kwargs
                )
                input_pos += 1
                new_tokens.append(next_token.clone())
                callback(new_tokens[-1])
                new_probs.append(next_prob.clone())
                cur_token = next_token.view(1, -1)
        return new_tokens, new_probs
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
#frontend settings
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 200
responseTimeout: 300
deviceType: "gpu"
continuousBatching: false
handler:
    converted_ckpt_dir: "checkpoints/meta-llama/Llama-2-7b-hf/model.pth"
    max_new_tokens: 50
    compile: true
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
{
    "prompt": "The capital of France",
    "max_new_tokens": 50
}
