
Commit 357e763

Merge pull request #2 from triton-inference-server/kaiyu/update
Update TensorRT-LLM backend code
2 parents d5c3ef6 + 7d171ca commit 357e763

File tree: 12 files changed (+1123 −834 lines)

README.md

Lines changed: 86 additions & 7 deletions
@@ -1,7 +1,78 @@
 # TensorRT-LLM Backend
-The Triton backend for TensorRT-LLM.
+The Triton backend for [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM).
 
-## Usage
+## Introduction
+
+This document describes how to serve models with the TensorRT-LLM Triton backend. The backend is only a thin interface that calls TensorRT-LLM from Triton; the heavy lifting of the implementation lives in the TensorRT-LLM source code.
+
+## Setup Environment
+
+### Prepare the repository
+
+Clone the repository and update the submodules recursively.
+```
+git clone git@github.com:triton-inference-server/tensorrtllm_backend.git
+git submodule update --init --recursive
+git lfs install
+git lfs pull
+```
+
+### Build the Docker image
+```
+cd tensorrtllm_backend
+docker build -f dockerfile/Dockerfile.trt_llm_backend -t tritonserver:w_trt_llm_backend .
+```
+
+The rest of the documentation assumes that the Docker image has already been built.
+
+### How to select the models
+There are two models under `all_models/`:
+- gpt: A Python implementation of the TensorRT-LLM Triton backend
+- inflight_batcher_llm: A C++ implementation of the TensorRT-LLM Triton backend
+
+### Prepare TensorRT-LLM engines
+Follow the [guide](https://github.com/NVIDIA/TensorRT-LLM/blob/main/README.md) in TensorRT-LLM to prepare the engines for deployment.
+
+For example, see the TensorRT-LLM GPT example documentation for instructions on building GPT engines: [link](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/gpt#usage)
+
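For orientation only, a single-GPU GPT2 engine build with the TensorRT-LLM example scripts looks roughly like the sketch below. The script names (`hf_gpt_convert.py`, `build.py`) and flags are taken from the TensorRT-LLM GPT example and may differ between TensorRT-LLM versions; the linked guide is the source of truth.

```bash
# Sketch only: run inside the TensorRT-LLM repository; script names and flags
# follow the GPT example linked above and may change between releases.
cd tensorrt_llm/examples/gpt

# Convert the Hugging Face GPT2 checkpoint into the format expected by build.py.
python3 hf_gpt_convert.py -i gpt2 -o ./c-model/gpt2 \
    --tensor-parallelism 1 --storage-type float16

# Build a single-GPU FP16 engine; the plugin/batching flags below are what the
# C++ in-flight batching backend expects, but verify them against your version.
python3 build.py --model_dir=./c-model/gpt2/1-gpu \
    --dtype float16 \
    --use_gpt_attention_plugin float16 \
    --use_gemm_plugin float16 \
    --remove_input_padding \
    --use_inflight_batching \
    --paged_kv_cache \
    --output_dir=engines/fp16/1-gpu
```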
+### How to set the model configuration
+
+**TensorRT-LLM Triton Serving Configuration: config.pbtxt**
+
+- This file is loaded by the Triton server.
+- It mainly describes the server and TensorRT-LLM inference hyperparameters.
+
+Each implemented backend consists of several components, and each component has its own config.pbtxt. Taking `all_models/inflight_batcher_llm` as an example:
+- preprocessing: Used for tokenizing.
+- tensorrt_llm: Inferencing.
+- postprocessing: Used for de-tokenizing.
+- ensemble: Connects preprocessing -> tensorrt_llm -> postprocessing.
+
+The following tables show the fields that need to be modified before deployment:
+
+*all_models/inflight_batcher_llm/preprocessing/config.pbtxt*
+
+| Name | Description |
+| :--------------: | :----------------------------------------------------------: |
+| `tokenizer_dir`  | The path to the tokenizer for the model |
+| `tokenizer_type` | The type of the tokenizer for the model; t5, auto and llama are supported |
+
+*all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt*
+
+| Name | Description |
+| :--------------: | :----------------------------------------------------------: |
+| `decoupled`      | Controls streaming. Decoupled mode must be set to true when using the streaming option from the client. |
+| `gpt_model_type` | "inflight_fused_batching" to enable in-flight batching, or "V1" to disable it |
+| `gpt_model_path` | Path to the TensorRT-LLM engines for deployment |
+
+*all_models/inflight_batcher_llm/postprocessing/config.pbtxt*
+
+| Name | Description |
+| :--------------: | :----------------------------------------------------------: |
+| `tokenizer_dir`  | The path to the tokenizer for the model |
+| `tokenizer_type` | The type of the tokenizer for the model; t5, auto and llama are supported |
+
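The `${...}` strings above are template placeholders in the shipped config files (the `${engine_dir}` placeholder, for example, is visible in the `tensorrt_llm/config.pbtxt` diff further down). A minimal way to fill them is plain `sed`; the sketch below assumes a GPT2 tokenizer and a single-GPU engine directory, and the placeholder names other than `${engine_dir}` are taken from the tables rather than verified against every config file.

```bash
# Illustrative only: substitute the template placeholders before starting Triton.
# Paths are assumptions; point them at your own tokenizer and engine directories,
# and check each config.pbtxt for the exact placeholder names it uses.
cd all_models/inflight_batcher_llm

# Tokenizer settings shared by preprocessing and postprocessing.
sed -i 's|${tokenizer_dir}|/workspace/gpt2|g; s|${tokenizer_type}|auto|g' \
    preprocessing/config.pbtxt postprocessing/config.pbtxt

# Engine location for the tensorrt_llm model.
sed -i 's|${engine_dir}|/workspace/engines/fp16/1-gpu|g' tensorrt_llm/config.pbtxt
```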
+## Run Serving on Single Node
 
 ### Launch the backend *within Docker*
 
@@ -15,7 +86,7 @@ nvidia-docker run -it --rm -e LOCAL_USER_ID=`id -u ${USER}` --shm-size=2g -v <yo
 3. all_models/<model>/postprocessing/config.pbtxt
 
 # 3. Launch triton server
-python3 scripts/launch_triton_server.py --world_size=1 \
+python3 scripts/launch_triton_server.py --world_size=<num_gpus> \
     --model_repo=all_models/<model>
 ```
 
@@ -56,20 +127,28 @@ ${TRITONSERVER} --model-repository=${MODEL_REPO} --disable-auto-complete-config
 sbatch tensorrt_llm_triton.sub
 ```
 
+When successfully deployed, the server produces logs similar to the following.
+```
+I0919 14:52:10.475738 293 grpc_server.cc:2451] Started GRPCInferenceService at 0.0.0.0:8001
+I0919 14:52:10.475968 293 http_server.cc:3558] Started HTTPService at 0.0.0.0:8000
+I0919 14:52:10.517138 293 http_server.cc:187] Started Metrics Service at 0.0.0.0:8002
+```
+
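Once those lines appear, Triton's standard HTTP health endpoints can be used to confirm that the server and the models are ready (adjust host and port if they were remapped):

```bash
# Returns HTTP 200 once the server is live and ready.
curl -sf -o /dev/null -w "%{http_code}\n" localhost:8000/v2/health/ready

# Readiness of a single model, e.g. the ensemble that chains pre/post-processing.
curl -sf -o /dev/null -w "%{http_code}\n" localhost:8000/v2/models/ensemble/ready
```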
 ### Kill the backend
 
 ```bash
 pgrep tritonserver | xargs kill -9
 ```
 
-## Examples
+## C++ backend examples (with in-flight batching support)
+Please follow the guide in [`inflight_batcher_llm/README.md`](inflight_batcher_llm/README.md).
+
+## Python backend examples (no in-flight batching support)
 
-### GPT/OPT/LLaMA/GPT-J...
+### GPT
 ```bash
 cd tools/gpt/
 
-# Download vocab and merge table for HF models
-# Take GPT as an example:
 rm -rf gpt2 && git clone https://huggingface.co/gpt2
 pushd gpt2 && rm pytorch_model.bin model.safetensors && \
     wget -q https://huggingface.co/gpt2/resolve/main/pytorch_model.bin && popd

all_models/gpt/preprocessing/1/model.py

Lines changed: 1 addition & 2 deletions
@@ -5,11 +5,10 @@
 import numpy as np
 import torch
 import triton_python_backend_utils as pb_utils
+from tensorrt_llm.runtime import to_word_list_format
 from torch.nn.utils.rnn import pad_sequence
 from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer
 
-from tensorrt_llm.runtime import to_word_list_format
-
 
 class TritonPythonModel:
     """Your Python model must use the same class name. Every Python model

all_models/gpt/tensorrt_llm/1/model.py

Lines changed: 1 addition & 1 deletion
@@ -3,10 +3,10 @@
 
 import torch
 import triton_python_backend_utils as pb_utils
+from tensorrt_llm.runtime import GenerationSession, ModelConfig, SamplingConfig
 from torch import from_numpy
 
 import tensorrt_llm
-from tensorrt_llm.runtime import GenerationSession, ModelConfig, SamplingConfig
 
 
 def mpi_comm():

all_models/inflight_batcher_llm/preprocessing/1/model.py

Lines changed: 1 addition & 2 deletions
@@ -5,11 +5,10 @@
 import numpy as np
 import torch
 import triton_python_backend_utils as pb_utils
+from tensorrt_llm.runtime import to_word_list_format
 from torch.nn.utils.rnn import pad_sequence
 from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer
 
-from tensorrt_llm.runtime import to_word_list_format
-
 
 class TritonPythonModel:
     """Your Python model must use the same class name. Every Python model

all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt

Lines changed: 24 additions & 0 deletions
@@ -162,3 +162,27 @@ parameters: {
     string_value: "${engine_dir}"
   }
 }
+parameters: {
+  key: "max_tokens_in_paged_kv_cache"
+  value: {
+    string_value: "${max_tokens_in_paged_kv_cache}"
+  }
+}
+parameters: {
+  key: "batch_scheduler_policy"
+  value: {
+    string_value: "${batch_scheduler_policy}"
+  }
+}
+parameters: {
+  key: "kv_cache_free_gpu_mem_fraction"
+  value: {
+    string_value: "${kv_cache_free_gpu_mem_fraction}"
+  }
+}
+parameters: {
+  key: "max_num_sequences"
+  value: {
+    string_value: "${max_num_sequences}"
+  }
+}
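Like `${engine_dir}`, the four new keys are exposed as template placeholders, so they can be filled the same way before the server starts. The values in the sketch below are illustrative assumptions (including the scheduler policy name), not recommendations; consult the backend documentation for the accepted values and sensible defaults.

```bash
# Illustrative fill of the new in-flight batching parameters; tune the numbers
# for your GPU memory budget and workload, and verify the accepted policy names.
sed -i \
    -e 's|${max_tokens_in_paged_kv_cache}|10000|g' \
    -e 's|${batch_scheduler_policy}|max_utilization|g' \
    -e 's|${kv_cache_free_gpu_mem_fraction}|0.9|g' \
    -e 's|${max_num_sequences}|64|g' \
    all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt
```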

dockerfile/Dockerfile.trt_llm_backend

Lines changed: 34 additions & 13 deletions
@@ -1,6 +1,7 @@
-ARG BASE_IMAGE=nvcr.io/nvidia/tritonserver:23.07-py3
+ARG BASE_IMAGE=nvcr.io/nvidia/tritonserver
+ARG BASE_TAG=23.08-py3
 
-FROM ${BASE_IMAGE} as base
+FROM ${BASE_IMAGE}:${BASE_TAG} as base
 
 COPY requirements.txt /tmp/
 RUN pip3 install -r /tmp/requirements.txt --extra-index-url https://pypi.ngc.nvidia.com
@@ -10,17 +11,37 @@ RUN pip3 install -r /tmp/requirements.txt --extra-index-url https://pypi.ngc.nvi
 RUN apt-get remove --purge -y tensorrt*
 RUN pip uninstall -y tensorrt
 
-# Download and install TensorRT
-RUN wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/secure/9.0.1/tars/TensorRT-9.0.1.4.Linux.x86_64-gnu.cuda-12.2.tar.gz -P /workspace
-RUN tar -xvf /workspace/TensorRT-9.0.1.4.Linux.x86_64-gnu.cuda-12.2.tar.gz -C /usr/local/ && mv /usr/local/TensorRT-9.0.1.4 /usr/local/tensorrt
-RUN pip install /usr/local/tensorrt/python/tensorrt-9.0.1*cp310-none-linux_x86_64.whl && rm -fr /workspace/TensorRT-9.0.1.4.Linux.x86_64-gnu.cuda-12.2.tar.gz
-ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib/:$LD_LIBRARY_PATH
-ENV TRT_ROOT=/usr/local/tensorrt
-
 FROM base as dev
 
-# Download and install polygraphy, only required if you need to run TRT-LLM python tests
-RUN pip install https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/secure/9.0.1/tars/polygraphy-0.48.1-py2.py3-none-any.whl
+# Download & install internal TRT release
+ARG TENSOR_RT_VERSION="9.1.0.1"
+ARG CUDA_VERSION="12.2"
+ARG RELEASE_URL_TRT
+ARG TARGETARCH
+
+RUN --mount=type=cache,target=/root/.cache \
+    if [ -z "$RELEASE_URL_TRT" ]; then \
+        ARCH=${TARGETARCH} && \
+        if [ "$ARCH" = "arm64" ]; then ARCH="aarch64"; fi && \
+        if [ "$ARCH" = "amd64" ]; then ARCH="x86_64"; fi && \
+        if [ "$ARCH" = "x86_64" ]; then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH}; fi && \
+        if [ "$ARCH" = "aarch64" ]; then OS1="Ubuntu22_04" && OS2="Ubuntu-22.04"; else OS1="Linux" && OS2="Linux"; fi && \
+        RELEASE_URL_TRT=http://cuda-repo.nvidia.com/release-candidates/Libraries/TensorRT/v9.1/${TENSOR_RT_VERSION}-b6aa91dc/${CUDA_VERSION}-r535/${OS1}-${DIR_NAME}/tar/TensorRT-${TENSOR_RT_VERSION}.${OS2}.${ARCH}-gnu.cuda-${CUDA_VERSION}.tar.gz; \
+    fi && \
+    wget --no-verbose ${RELEASE_URL_TRT} -O /workspace/TensorRT.tar && \
+    tar -xf /workspace/TensorRT.tar -C /usr/local/ && \
+    mv /usr/local/TensorRT-${TENSOR_RT_VERSION} /usr/local/tensorrt && \
+    pip install /usr/local/tensorrt/python/tensorrt-*-cp310-*.whl && \
+    rm -rf /workspace/TensorRT.tar
+
+ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib:${LD_LIBRARY_PATH}
+ENV TRT_ROOT=/usr/local/tensorrt
+
+# Install latest Polygraphy
+ARG RELEASE_URL_PG=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/secure/9.0.1/tars/polygraphy-0.48.1-py2.py3-none-any.whl
+RUN --mount=type=cache,target=/root/.cache \
+    pip uninstall -y polygraphy && \
+    pip install ${RELEASE_URL_PG}
 
 # CMake
 RUN wget https://github.com/Kitware/CMake/releases/download/v3.18.1/cmake-3.18.1-Linux-x86_64.sh
@@ -35,13 +56,13 @@ FROM dev as trt_llm_builder
 WORKDIR /app
 COPY scripts scripts
 COPY tensorrt_llm tensorrt_llm
-RUN cd tensorrt_llm; python3 scripts/build_wheel.py --trt_root="${TRT_ROOT}" -i; cd ..
+RUN cd tensorrt_llm && python3 scripts/build_wheel.py --trt_root="${TRT_ROOT}" -i && cd ..
 
 FROM trt_llm_builder as trt_llm_backend_builder
 
 WORKDIR /app/
 COPY inflight_batcher_llm inflight_batcher_llm
-RUN cd inflight_batcher_llm; bash scripts/build.sh; cd ..
+RUN cd inflight_batcher_llm && bash scripts/build.sh && cd ..
 
 FROM trt_llm_backend_builder as final
 
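With `BASE_IMAGE`/`BASE_TAG` split out and `RELEASE_URL_TRT` optional, the image build can be parameterized from the command line. The following sketch shows the default build plus a hypothetical override; the override URL is a placeholder, not a real download location.

```bash
# Default build, matching the values hard-coded in the Dockerfile.
docker build -f dockerfile/Dockerfile.trt_llm_backend \
    -t tritonserver:w_trt_llm_backend .

# Hypothetical build pinning a different Triton base tag and supplying an
# explicit TensorRT tarball URL instead of the auto-derived one.
docker build -f dockerfile/Dockerfile.trt_llm_backend \
    --build-arg BASE_TAG=23.08-py3 \
    --build-arg RELEASE_URL_TRT=https://example.com/TensorRT-9.1.0.1.Linux.x86_64-gnu.cuda-12.2.tar.gz \
    -t tritonserver:w_trt_llm_backend .
```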
