This repository was archived by the owner on Aug 7, 2025. It is now read-only.

Commit 2cad147

Merge branch 'master' into nproc
2 parents 21b569a + f57240f commit 2cad147

11 files changed (+91 / -12 lines)

README.md

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ Refer to [torchserve docker](docker/README.md) for details.
## 🏆 Highlighted Examples
+ * [Serving Llama 2 with TorchServe](examples/LLM/llama2/README.md)
* [Chatbot with Llama 2 on Mac 🦙💬](examples/LLM/llama2/chat_app)
* [🤗 HuggingFace Transformers](examples/Huggingface_Transformers) with a [Better Transformer Integration / Flash Attention & Xformer Memory Efficient](examples/Huggingface_Transformers#Speed-up-inference-with-Better-Transformer)
* [Model parallel inference](examples/Huggingface_Transformers#model-parallelism)

examples/LLM/llama2/README.md

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
# Llama 2: Next generation of Meta's Language Model

![Llama 2](./images/llama.png)

TorchServe supports serving Llama 2 in a number of ways. The examples covered in this document range from someone new to TorchServe learning how to serve Llama 2 with an app, to an advanced user of TorchServe using micro batching and streaming responses with Llama 2.
5+
6+
## 🦙💬 Llama 2 Chatbot
7+
8+
### [Example Link](https://github.com/pytorch/serve/tree/master/examples/LLM/llama2/chat_app)
9+
10+
This example shows how to deploy a llama2 chat app using TorchServe.
11+
We use [streamlit](https://github.com/streamlit/streamlit) to create the app
12+
13+
This example is using [llama-cpp-python](https://github.com/abetlen/llama-cpp-python).
14+
15+
You can run this example on your laptop to understand how to use TorchServe, how to scale up/down TorchServe backend workers and play around with batch_size to see its effect on inference time
16+
17+
![Chatbot Architecture](./chat_app/screenshots/architecture.png)
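For illustration (not part of this commit), a minimal sketch of how a client such as the Streamlit app can scale backend workers and send a prompt through TorchServe's REST APIs; the model name `llamacpp` and the default ports 8080/8081 are assumptions.

```python
import requests

MODEL = "llamacpp"  # hypothetical registered model name

# Scale backend workers up or down via the management API (default port 8081).
requests.put(f"http://localhost:8081/models/{MODEL}", params={"min_worker": 2})

# Send a prompt to the inference API (default port 8080) and print the response.
resp = requests.post(f"http://localhost:8080/predictions/{MODEL}", data="What is TorchServe?")
print(resp.text)
```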
## Llama 2 with HuggingFace

### [Example Link](https://github.com/pytorch/serve/tree/master/examples/large_models/Huggingface_accelerate/llama2)

This example shows how to serve the Llama 2 70B model with limited resources using [HuggingFace](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf). It shows the following optimizations:

1) HuggingFace `accelerate`. This option can be activated with `low_cpu_mem_usage=True`.
2) Quantization from [`bitsandbytes`](https://github.com/TimDettmers/bitsandbytes) using `load_in_8bit=True`.

The model is first created on the meta device (with empty weights) and the state dict is then loaded into it (shard by shard in the case of a sharded checkpoint); see the sketch below.
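For illustration (not part of this commit), a minimal sketch of the two optimizations above using the standard `transformers` API; the exact arguments in the linked example may differ, and `device_map="auto"` is an added assumption.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-70b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# low_cpu_mem_usage=True: build the model on the meta device (empty weights) via accelerate,
# then load the checkpoint shard by shard.
# load_in_8bit=True: quantize weights to 8-bit with bitsandbytes at load time.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    load_in_8bit=True,
    device_map="auto",
)
```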
## Llama 2 on Inferentia

### [Example Link](https://github.com/pytorch/serve/tree/master/examples/large_models/inferentia2/llama2)

### [PyTorch Blog](https://pytorch.org/blog/high-performance-llama/)

This example shows how to serve the [Llama 2](https://huggingface.co/meta-llama) model on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) for text completion with [micro batching](https://github.com/pytorch/serve/tree/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/examples/micro_batching) and [streaming response](https://github.com/pytorch/serve/blob/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/docs/inference_api.md#curl-example-1) support.

Inferentia2 uses the [Neuron SDK](https://aws.amazon.com/machine-learning/neuron/), which is built on top of the PyTorch XLA stack. For large model inference, the [`transformers-neuronx`](https://github.com/aws-neuron/transformers-neuronx) package is used, which takes care of model partitioning and running inference.

![Inferentia 2 Software Stack](./images/software_stack_inf2.jpg)
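For illustration (not part of this commit), a minimal sketch of consuming a streaming response from the TorchServe inference API with `requests`; the registered model name `llama-2-13b` is an assumption, and the linked docs show the equivalent `curl` call.

```python
import requests

MODEL = "llama-2-13b"  # hypothetical registered model name

# TorchServe streams intermediate results as chunked HTTP; print text as it arrives.
with requests.post(
    f"http://localhost:8080/predictions/{MODEL}",
    data="Today the weather is really nice and I am planning on",
    stream=True,
) as resp:
    for chunk in resp.iter_content(chunk_size=None):
        if chunk:
            print(chunk.decode("utf-8"), end="", flush=True)
```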

examples/LLM/llama2/chat_app/client_app.py

Lines changed: 0 additions & 1 deletion
@@ -6,7 +6,6 @@
# App title
st.set_page_config(page_title="🦙💬 Llama 2 Chatbot")

- # Replicate Credentials
with st.sidebar:
    st.title("🦙💬 Llama 2 Chatbot")

examples/LLM/llama2/images/llama.png

1.79 MB (new image file)

(second new image file, 39.4 KB)

examples/pt2/README.md

Lines changed: 47 additions & 6 deletions
@@ -13,7 +13,7 @@ python ts_scripts/install_dependencies.py --cuda=cu118
pip install torchserve torch-model-archiver
```

- ## Package your model
+ ## torch.compile

PyTorch 2.0 supports several compiler backends; you pick the one you want by passing an optional `model_config.yaml` file during model packaging.
@@ -34,10 +34,10 @@ The exact same approach works with any other model, what's going on is the below
opt_mod = torch.compile(mod)
# 2. Train the optimized module
# ....
- # 3. Save the original module (weights are shared)
- torch.save(model, "model.pt")
+ # 3. Save the optimized module's state dict
+ torch.save(opt_mod.state_dict(), "model.pt")

- # 4. Load the non optimized model
+ # 4. Reload the model
mod = torch.load("model.pt")

# 5. Compile the module and then run inferences with it
@@ -46,6 +46,47 @@ opt_mod = torch.compile(mod)
TorchServe takes care of steps 4 and 5 for you, while the remaining steps are your responsibility. You can do the exact same thing with the vast majority of TIMM or HuggingFace models.

- ## Next steps
+ ## torch.export.export

Export your model from a training script; keep in mind that an exported model cannot have graph breaks.

```python
import io
import torch

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

ep = torch.export.export(MyModule(), (torch.randn(5),))

# Save to file
# torch.export.save(ep, 'exported_program.pt2')
extra_files = {'foo.txt': b'bar'.decode('utf-8')}
torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.export.save(ep, buffer)
```

Serve your exported model from a custom handler:

```python
# from initialize()
ep = torch.export.load('exported_program.pt2')

with open('exported_program.pt2', 'rb') as f:
    buffer = io.BytesIO(f.read())
buffer.seek(0)
extra_files = {'foo.txt': ''}  # filled in by torch.export.load
ep = torch.export.load(buffer, extra_files=extra_files)

# Make sure everything looks good
print(ep)
print(extra_files['foo.txt'])

# from inference()
print(ep(torch.randn(5)))
```

- For now PyTorch 2.0 has mostly been focused on accelerating training so production grade applications should instead opt for TensorRT for accelerated inference performance which is also natively supported in torchserve. We just wanted to make it really easy for users to experiment with the PyTorch 2.0 stack. You can learn more here https://github.com/pytorch/serve/blob/master/docs/performance_guide.md
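For illustration (not part of this commit), a minimal sketch of a custom handler that places the `initialize()` and `inference()` fragments above into TorchServe's `BaseHandler` shape; the file name, input format, and pre/post-processing are assumptions.

```python
import io

import torch
from ts.torch_handler.base_handler import BaseHandler


class ExportedProgramHandler(BaseHandler):
    """Serves a program saved with torch.export.save as exported_program.pt2 (assumed name)."""

    def initialize(self, context):
        model_dir = context.system_properties.get("model_dir")
        with open(f"{model_dir}/exported_program.pt2", "rb") as f:
            buffer = io.BytesIO(f.read())
        buffer.seek(0)
        self.ep = torch.export.load(buffer)
        self.initialized = True

    def preprocess(self, data):
        # Assumes each request body is a JSON list of floats, e.g. [1.0, 2.0, 3.0, 4.0, 5.0].
        row = data[0].get("data") or data[0].get("body")
        return torch.tensor(row, dtype=torch.float32)

    def inference(self, x):
        # An ExportedProgram can be called directly on example-shaped inputs (torch 2.1).
        return self.ep(x)

    def postprocess(self, output):
        return [output.tolist()]
```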

requirements/torch_cu121_linux.txt

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
#pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121
- --extra-index-url https://download.pytorch.org/whl/test/cu121
+ --extra-index-url https://download.pytorch.org/whl/cu121
-r torch_common.txt
torch==2.1.0+cu121; sys_platform == 'linux'
torchvision==0.16.0+cu121; sys_platform == 'linux'

requirements/torch_darwin.txt

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
#pip install torch torchvision torchaudio
- --extra-index-url https://download.pytorch.org/whl/test/cpu
+ --extra-index-url https://download.pytorch.org/whl/cpu
-r torch_common.txt
torch==2.1.0; sys_platform == 'darwin'
torchvision==0.16.0; sys_platform == 'darwin'

requirements/torch_linux.txt

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
#pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
- --extra-index-url https://download.pytorch.org/whl/test/cpu
+ --extra-index-url https://download.pytorch.org/whl/cpu
-r torch_common.txt
torch==2.1.0+cpu; sys_platform == 'linux'
torchvision==0.16.0+cpu; sys_platform == 'linux'

requirements/torch_windows.txt

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
#pip install torch torchvision torchaudio
- --extra-index-url https://download.pytorch.org/whl/test/cpu
+ --extra-index-url https://download.pytorch.org/whl/cpu
-r torch_common.txt
torch==2.1.0; sys_platform == 'win32'
torchvision==0.16.0; sys_platform == 'win32'

0 commit comments
