AWS Inferentia2 TGI server (#214)
* feat(generate): reintroduce batch_size parameter in example

* feat(generate): default number of cores is now 1 on example

* feat(generate): pad inputs up to static batch size

* feat(tgi): add neuronx TGI inference server

The basic features of the Text Generation Inference product are supported:

- continuous batching,
- token streaming,
- greedy search and multinomial sampling using transformers.

The main differences with the standard service for CUDA and CPU backends
are that:

- the service uses a single internal static batch,
- new requests are inserted in the static batch during prefill,
- the static KV cache is rebuilt entirely during prefill (which makes it
  even more costly).

* feat(tgi): build Neuron TGI docker image

* chore(tgi): add TGI docker image Makefile target

* doc(tgi): add TGI README

* feat(tgi): update base TGI version

* review: fix type hints

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* review: fix typo

* review: restrict choices in example

* review: correct docstring

Co-authored-by: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com>

---------

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
Co-authored-by: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com>
3 people authored Sep 8, 2023
1 parent 6c2128d commit 3d297c9
Showing 13 changed files with 1,071 additions and 12 deletions.
3 changes: 3 additions & 0 deletions Makefile
@@ -40,6 +40,9 @@ PACKAGE_FILES = $(PACKAGE_PYTHON_FILES) \
$(PACKAGE_DIST) $(PACKAGE_WHEEL): $(PACKAGE_FILES)
python -m build

neuronx-tgi: $(PACKAGE_DIST)
docker build --rm -f text-generation-inference/Dockerfile --build-arg VERSION=$(VERSION) -t neuronx-tgi:$(VERSION) .

# Creates example scripts from Transformers
transformers_examples:
rm -f examples/**/*.py
16 changes: 13 additions & 3 deletions examples/text-generation/run_generation.py
@@ -52,11 +52,19 @@ def generate(model, tokenizer, prompts, length, temperature):
default="One of my fondest memory is",
help="The prompts to use for generation, using | as separator.",
)
+ parser.add_argument(
+     "--batch_size",
+     type=int,
+     default=None,
+     help="The batch size (optional). If not specified it will be based on the number of prompts.",
+ )
parser.add_argument("--length", type=int, default=128, help="The number of tokens in the generated sequences.")
parser.add_argument(
"--num_cores", type=int, default=2, help="The number of cores on which the model should be split."
"--num_cores", type=int, default=1, help="The number of cores on which the model should be split."
)
+ parser.add_argument(
+     "--auto_cast_type", type=str, default="f32", choices=["f32", "f16", "bf16"], help="One of f32, f16, bf16."
+ )
- parser.add_argument("--auto_cast_type", type=str, default="f32", help="One of f32, f16, bf16.")
parser.add_argument(
"--temperature",
type=float,
@@ -72,7 +80,9 @@ def generate(model, tokenizer, prompts, length, temperature):
if args.seed is not None:
set_seed(args.seed)
prompts = args.prompts.split("|")
- batch_size = len(prompts)
+ batch_size = len(prompts) if args.batch_size is None else args.batch_size
+ if len(prompts) < batch_size:
+     prompts = prompts + [prompts[-1]] * (batch_size - len(prompts))
model = load_llm_optimum(args.model, batch_size, args.num_cores, args.auto_cast_type)
tokenizer = AutoTokenizer.from_pretrained(args.model)
outputs, latency = generate(model, tokenizer, prompts, args.length, args.temperature)
33 changes: 24 additions & 9 deletions optimum/neuron/modeling.py
@@ -688,29 +688,41 @@ def generate(

# Verify that the inputs are compatible with the model static input dimensions
batch_size, sequence_length = input_ids.shape
- if batch_size != self.batch_size:
-     raise ValueError(
-         f"The specified batch_size ({batch_size}) does not match the model static batch size ({self.batch_size})"
-     )
if sequence_length > self.max_length:
raise ValueError(
f"The input sequence length ({sequence_length}) exceeds the model static sequence length ({self.max_length})"
)

+ padded_input_ids = input_ids
+ padded_attention_mask = attention_mask
+ if batch_size > self.batch_size:
+     raise ValueError(
+         f"The specified batch_size ({batch_size}) exceeds the model static batch size ({self.batch_size})"
+     )
+ elif batch_size < self.batch_size:
+     logger.warning("Inputs will be padded to match the model static batch size. This will increase latency.")
+     padding_shape = [self.batch_size - batch_size, sequence_length]
+     padding = torch.full(padding_shape, fill_value=self.config.eos_token_id, dtype=torch.int64)
+     padded_input_ids = torch.cat([input_ids, padding])
+     if attention_mask is not None:
+         padding = torch.zeros(padding_shape, dtype=torch.int64)
+         padded_attention_mask = torch.cat([attention_mask, padding])
# Drop the current generation context and clear the Key/Value cache
self.reset_generation()

- return self.generate_tokens(
-     input_ids,
+ output_ids = self.generate_tokens(
+     padded_input_ids,
      selector,
-     attention_mask=attention_mask,
+     batch_size,
+     attention_mask=padded_attention_mask,
      **model_kwargs,
  )
+ return output_ids[:batch_size, :]

def generate_tokens(
self,
input_ids: torch.LongTensor,
selector: TokenSelector,
+ batch_size: int,
attention_mask: Optional[torch.Tensor] = None,
**model_kwargs,
) -> torch.LongTensor:
@@ -722,6 +734,8 @@ def generate_tokens(
The sequence used as a prompt for the generation.
selector (`TokenSelector`):
The object implementing the generation logic based on transformers processors and stopping criterias.
+ batch_size (`int`):
+     The actual input batch size. Used to avoid generating tokens for padded inputs.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices.
model_kwargs:
@@ -732,7 +746,8 @@
"""
# keep track of which sequences are already finished
- unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+ unfinished_sequences = torch.zeros(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+ unfinished_sequences[:batch_size] = 1

# auto-regressive generation
while True:
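
Taken together, these changes pad smaller inputs up to the model's static batch size, restrict token generation to the real rows, and slice the generated ids back down to the caller's batch. A condensed, standalone sketch of the padding step (the `pad_to_static_batch` helper below is hypothetical, not part of the codebase):

```
import torch

def pad_to_static_batch(input_ids, attention_mask, static_batch_size, eos_token_id):
    """Pad inputs up to the static batch size, mirroring the generate() logic above."""
    batch_size, sequence_length = input_ids.shape
    if batch_size > static_batch_size:
        raise ValueError(f"batch_size ({batch_size}) exceeds the static batch size ({static_batch_size})")
    if batch_size == static_batch_size:
        return input_ids, attention_mask
    padding_shape = [static_batch_size - batch_size, sequence_length]
    # Extra rows are filled with EOS tokens and masked out so they do not affect the real rows.
    padded_ids = torch.cat([input_ids, torch.full(padding_shape, eos_token_id, dtype=torch.int64)])
    padded_mask = None
    if attention_mask is not None:
        padded_mask = torch.cat([attention_mask, torch.zeros(padding_shape, dtype=torch.int64)])
    return padded_ids, padded_mask

# After generation, only the first `batch_size` rows are real outputs:
# output_ids = output_ids[:batch_size, :]
```
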
131 changes: 131 additions & 0 deletions text-generation-inference/Dockerfile
@@ -0,0 +1,131 @@
# Fetch and extract the TGI sources
FROM alpine AS tgi
RUN mkdir -p /tgi
ADD https://github.com/huggingface/text-generation-inference/archive/refs/tags/v1.0.2.tar.gz /tgi/sources.tar.gz
RUN tar -C /tgi -xf /tgi/sources.tar.gz --strip-components=1

# Build cargo components (adapted from TGI original Dockerfile)
# Note that the build image is aligned on the same Linux version as the base image (Debian bookworm/ Ubuntu 22.04)
FROM lukemathwalker/cargo-chef:latest-rust-1.71-bookworm AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

FROM chef as planner
COPY --from=tgi /tgi/Cargo.toml Cargo.toml
COPY --from=tgi /tgi/rust-toolchain.toml rust-toolchain.toml
COPY --from=tgi /tgi/proto proto
COPY --from=tgi /tgi/benchmark benchmark
COPY --from=tgi /tgi/router router
COPY --from=tgi /tgi/launcher launcher
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP

COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json

COPY --from=tgi /tgi/Cargo.toml Cargo.toml
COPY --from=tgi /tgi/rust-toolchain.toml rust-toolchain.toml
COPY --from=tgi /tgi/proto proto
COPY --from=tgi /tgi/benchmark benchmark
COPY --from=tgi /tgi/router router
COPY --from=tgi /tgi/launcher launcher
RUN cargo build --release --workspace --exclude benchmark

# Python base image
FROM ubuntu:22.04 AS base

RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
python3-pip \
python3-setuptools \
python-is-python3 \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
RUN pip3 --no-cache-dir install --upgrade pip

# Python server build image
FROM base AS pyserver

RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
make \
python3-venv \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean

RUN install -d /pyserver
WORKDIR /pyserver
COPY text-generation-inference/server server
COPY --from=tgi /tgi/proto proto
RUN pip3 install -r server/build-requirements.txt
RUN VERBOSE=1 BUILDDIR=/pyserver/build PROTODIR=/pyserver/proto make -C server gen-server

# Neuron base image (used for deployment)
FROM base AS neuron

# VERSION is a mandatory parameter
ARG VERSION
RUN test -n ${VERSION:?}

# Install system prerequisites
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
gnupg2 \
wget \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean

RUN echo "deb https://apt.repos.neuron.amazonaws.com jammy main" > /etc/apt/sources.list.d/neuron.list
RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | apt-key add -

# Install neuronx 2.12.2 packages
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
aws-neuronx-dkms=2.11.9.0 \
aws-neuronx-collectives=2.15.16.0-db4e2d9a9 \
aws-neuronx-runtime-lib=2.15.14.0-279f319f2 \
aws-neuronx-tools=2.12.2.0 \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean

ENV PATH="/opt/bin/:/opt/aws/neuron/bin:${PATH}"

RUN pip3 install \
torch-neuronx==1.13.1.1.9.1 \
transformers-neuronx==0.5.58 \
--extra-index-url=https://pip.repos.neuron.amazonaws.com

# Install HuggingFace packages
RUN pip3 install \
hf_transfer

# Install optimum-neuron
COPY dist/optimum-neuron-${VERSION}.tar.gz optimum-neuron.tar.gz
RUN pip3 install optimum-neuron.tar.gz

# TGI base env
ENV HUGGINGFACE_HUB_CACHE=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80

# Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
# Install python server
COPY --from=pyserver /pyserver/build/dist dist
RUN pip install dist/text-generation-server*.tar.gz

# Final image
FROM neuron

ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]
116 changes: 116 additions & 0 deletions text-generation-inference/README.md
@@ -0,0 +1,116 @@
# Text-generation-inference docker image

This docker image integrates into a base image:

- the AWS Neuron SDK for Inferentia2,
- the [Text Generation Inference](https://github.com/huggingface/text-generation-inference) launcher and scheduling front-end,
- a neuron specific inference server for text-generation.

## Features

The basic features of the [Text Generation Inference](https://github.com/huggingface/text-generation-inference) product are supported:

- continuous batching,
- token streaming,
- greedy search and multinomial sampling using [transformers](https://huggingface.co/docs/transformers/generation_strategies#customize-text-generation).

The main differences with the standard service for CUDA and CPU backends are that:

- the service uses a single internal static batch,
- new requests are inserted in the static batch during prefill,
- the static KV cache is rebuilt entirely during prefill (which makes it even more costly).
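
The toy sketch below (schematic only, not the actual server implementation) illustrates this scheduling model: a fixed number of slots, new requests inserted at prefill time, and one token produced per active slot at each decode step.

```
# Schematic illustration only — not the server implementation.
batch_size = 4
slots = [None] * batch_size  # None marks a free slot of the static batch

def prefill(new_prompts):
    # New requests are inserted into free slots of the static batch,
    # and the KV cache is rebuilt for every active slot (hence the extra cost).
    for prompt in new_prompts:
        slots[slots.index(None)] = {"prompt": prompt, "generated": []}
    print("prefill: KV cache rebuilt for", sum(s is not None for s in slots), "active slot(s)")

def decode_step():
    # Each decode step produces (and streams) one token per active slot.
    for slot in slots:
        if slot is not None:
            slot["generated"].append("<token>")

prefill(["What is Deep Learning?"])
decode_step()
prefill(["Hello world"])  # a new request joins mid-generation (continuous batching)
decode_step()
```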

## Build image

The image must be built from the top-level directory of the repository:

```
make neuronx-tgi
```

## Deploy the service

The service is launched simply by running the neuronx-tgi container with two sets of parameters:

```
docker run <system_parameters> neuronx-tgi:<version> <service_parameters>
```

- system parameters are used to map ports, volumes and devices between the host and the service,
- service parameters are forwarded to the `text-generation-launcher`.

The snippet below shows how you can deploy the service from a neuron model hosted on the Hugging Face hub:

```
docker run -p 8080:80 \
--device=/dev/neuron0 \
neuronx-tgi:<version> \
--model-id optimum/gpt2-neuronx-bs16 \
--max-concurrent-requests 16 \
--max-input-length 512 \
--max-total-tokens 1024 \
--max-batch-prefill-tokens 8192 \
--max-batch-total-tokens 16384
```

Alternatively, you can first compile the model locally, and deploy the service using a shared volume:

```
docker run -p 8080:80 \
-v $(pwd)/data:/data \
--device=/dev/neuron0 \
neuronx-tgi:0.0.11.dev0 \
--model-id /data/neuron_gpt2_bs16 \
--max-concurrent-requests 16 \
--max-input-length 512 \
--max-total-tokens 1024 \
--max-batch-prefill-tokens 8192 \
--max-batch-total-tokens 16384
```

### Choosing service parameters

Use the following command to list the available service parameters:

```
docker run neuronx-tgi:<version> --help
```

The configuration of an inference endpoint is always a compromise between throughput and latency: serving more requests in parallel will allow a higher throughput, but it will increase the latency.

The neuron models have static input dimensions `[batch_size, max_length]`.

This leads to a maximum number of tokens of `max_tokens = batch_size * max_length`.

This imposes several restrictions on the following parameters (a worked example follows the list):

- `--max-concurrent-requests` must be set to `batch_size`,
- `--max-input-length` must be lower than `max_length`,
- `--max-total-tokens` must be set to `max_length` (it is per-request),
- `--max-batch-prefill-tokens` must be lower than `max_tokens`,
- `--max-batch-total-tokens` must be set to `max_tokens`.
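
For example, the deployment snippets above assume a model compiled with `batch_size = 16` and `max_length = 1024`, so `max_tokens = 16 * 1024 = 16384` and the service parameters map as follows:

```
--max-concurrent-requests 16      # = batch_size
--max-input-length 512            # < max_length
--max-total-tokens 1024           # = max_length
--max-batch-prefill-tokens 8192   # < max_tokens
--max-batch-total-tokens 16384    # = max_tokens
```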

### Choosing the correct batch size

As seen in the previous paragraph, the neuron model's static batch size has a direct influence on endpoint latency and throughput.

For GPT2, a good compromise is a batch size of 16. If you need to absorb more load, you can try a model compiled with a batch size of 128, but be aware
that latency will increase significantly.

## Query the service

You can query the model using either the `/generate` or `/generate_stream` routes:

```
curl 127.0.0.1:8080/generate \
-X POST \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}' \
-H 'Content-Type: application/json'
```

```
curl 127.0.0.1:8080/generate_stream \
-X POST \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
-H 'Content-Type: application/json'
```
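
The same endpoint can also be called from any HTTP client; for instance, a minimal Python client (assuming the port mapping shown above and that the `requests` package is installed):

```
import requests

# Send a generation request and print the generated text returned by the /generate route.
response = requests.post(
    "http://127.0.0.1:8080/generate",
    json={"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 20}},
)
response.raise_for_status()
print(response.json()["generated_text"])
```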
