Skip to content

Commit

Permalink
Merge branch 'vllm-project:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
sroy745 authored Jul 10, 2024
2 parents eb7a1c4 + 5ed3505 commit 7e2c87e
Show file tree
Hide file tree
Showing 113 changed files with 6,651 additions and 1,121 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test -b 32 -l 250 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.752
- name: "exact_match,flexible-extract"
value: 0.752
limit: 250
num_fewshot: 5
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1
model_name: "neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
tasks:
- name: "gsm8k"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test -b "auto" -l 250 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.728
- name: "exact_match,flexible-extract"
value: 0.728
limit: 250
num_fewshot: 5
2 changes: 2 additions & 0 deletions .buildkite/lm-eval-harness/configs/models-small.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
Meta-Llama-3-8B-Instruct.yaml
Meta-Llama-3-8B-Instruct-FP8.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
done

lm_eval --model vllm \
--model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE \
--model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true \
--tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
--batch_size $BATCH_SIZE
3 changes: 2 additions & 1 deletion .buildkite/lm-eval-harness/test_lm_eval_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@

def launch_lm_eval(eval_config):
model_args = f"pretrained={eval_config['model_name']}," \
f"tensor_parallel_size={TP_SIZE}"
f"tensor_parallel_size={TP_SIZE}," \
f"add_bos_token=true"

results = lm_eval.simple_evaluate(
model="vllm",
Expand Down
77 changes: 77 additions & 0 deletions .buildkite/run-multi-node-test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#!/bin/bash

set -euox pipefail

if [[ $# -lt 3 ]]; then
echo "Please provide the number of nodes and GPU per node."
exit 1
fi

NUM_NODES=$1
NUM_GPUS=$2
DOCKER_IMAGE=$3

shift 3
COMMANDS=("$@")
if [ ${#COMMANDS[@]} -ne $NUM_NODES ]; then
echo "The number of commands must be equal to the number of nodes."
echo "Number of nodes: $NUM_NODES"
echo "Number of commands: ${#COMMANDS[@]}"
exit 1
fi

echo "List of commands"
for command in "${COMMANDS[@]}"; do
echo $command
done

start_network() {
docker network create --subnet=192.168.10.0/24 docker-net
}

start_nodes() {
for node in $(seq 0 $(($NUM_NODES-1))); do
GPU_DEVICES='"device='
for node_gpu in $(seq 0 $(($NUM_GPUS - 1))); do
DEVICE_NUM=$(($node * $NUM_GPUS + $node_gpu))
GPU_DEVICES+=$(($DEVICE_NUM))
if [ $node_gpu -lt $(($NUM_GPUS - 1)) ]; then
GPU_DEVICES+=','
fi
done
GPU_DEVICES+='"'
# echo "Starting node$node with GPU devices: $GPU_DEVICES"
docker run -d --gpus "$GPU_DEVICES" --name node$node --network docker-net --ip 192.168.10.$((10 + $node)) --rm $DOCKER_IMAGE tail -f /dev/null
done
}

run_nodes() {
for node in $(seq 0 $(($NUM_NODES-1))); do
GPU_DEVICES='"device='
for node_gpu in $(seq 0 $(($NUM_GPUS - 1))); do
DEVICE_NUM=$(($node * $NUM_GPUS + $node_gpu))
GPU_DEVICES+=$(($DEVICE_NUM))
if [ $node_gpu -lt $(($NUM_GPUS - 1)) ]; then
GPU_DEVICES+=','
fi
done
GPU_DEVICES+='"'
echo "Running node$node with GPU devices: $GPU_DEVICES"
if [ $node -lt $(($NUM_NODES - 1)) ]; then
docker exec -d node$node /bin/bash -c "${COMMANDS[$node]}"
else
docker exec node$node /bin/bash -c "${COMMANDS[$node]}"
fi
done
}
cleanup() {
for node in $(seq 0 $(($NUM_NODES-1))); do
docker stop node$node
done
docker network rm docker-net
}
trap cleanup EXIT
start_network
start_nodes
run_nodes

3 changes: 3 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \
--mount=type=cache,target=/root/.cache/pip \
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir

RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
#################### vLLM installation IMAGE ####################


Expand Down
10 changes: 7 additions & 3 deletions Dockerfile.tpu
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,22 @@ ARG NIGHTLY_DATE="20240601"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"

FROM $BASE_IMAGE

WORKDIR /workspace
COPY . /workspace/vllm

ENV VLLM_TARGET_DEVICE="tpu"
# Install aiohttp separately to avoid build errors.
RUN pip install aiohttp
# Install the TPU and Pallas dependencies.
RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

# Build vLLM.
COPY . /workspace/vllm
ENV VLLM_TARGET_DEVICE="tpu"
RUN cd /workspace/vllm && python setup.py develop

# Re-install outlines to avoid dependency errors.
# The outlines version must follow requirements-common.txt.
RUN pip uninstall outlines -y
RUN pip install "outlines>=0.0.43"

CMD ["/bin/bash"]
65 changes: 60 additions & 5 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
--dataset-path <path to dataset> \
--request-rate <request_rate> \ # By default <request_rate> is inf
--num-prompts <num_prompts> # By default <num_prompts> is 1000
when using tgi backend, add
--endpoint /generate_stream
to the end of the command above.
Expand Down Expand Up @@ -77,7 +77,6 @@ def sample_sharegpt_requests(
) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")

# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
Expand Down Expand Up @@ -185,6 +184,31 @@ def sample_sonnet_requests(
return sampled_requests


def sample_random_requests(
input_len: int, output_len: int, num_prompts: int, range_ratio: float,
tokenizer: PreTrainedTokenizerBase) -> List[Tuple[str, int, int]]:

input_lens = np.random.randint(
int(input_len * range_ratio),
input_len + 1,
size=num_prompts,
)
output_lens = np.random.randint(
int(output_len * range_ratio),
output_len + 1,
size=num_prompts,
)
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
input_requests = []
for i in range(args.num_prompts):
prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size
for j in range(input_lens[i])])
input_requests.append(
(prompt, int(input_lens[i]), int(output_lens[i])))

return input_requests


async def get_request(
input_requests: List[Tuple[str, int, int]],
request_rate: float,
Expand All @@ -196,6 +220,7 @@ async def get_request(
if request_rate == float("inf"):
# If the request rate is infinity, then we don't need to wait.
continue

# Sample the request interval from the exponential distribution.
interval = np.random.exponential(1.0 / request_rate)
# The next request will be sent after the interval.
Expand All @@ -219,7 +244,7 @@ def calculate_metrics(
# We use the tokenizer to count the number of output tokens for all
# serving backends instead of looking at len(outputs[i].itl) since
# multiple output tokens may be bundled together
# Note: this may inflate the output token count slightly
# Note : this may inflate the output token count slightly
output_len = len(
tokenizer(outputs[i].generated_text,
add_special_tokens=False).input_ids)
Expand Down Expand Up @@ -456,6 +481,15 @@ def main(args: argparse.Namespace):
for prompt, prompt_formatted, prompt_len,
output_len in input_requests]

elif args.dataset_name == "random":
input_requests = sample_random_requests(
input_len=args.random_input_len,
output_len=args.random_output_len,
num_prompts=args.num_prompts,
range_ratio=args.random_range_ratio,
tokenizer=tokenizer,
)

else:
raise ValueError(f"Unknown dataset: {args.dataset_name}")

Expand Down Expand Up @@ -549,7 +583,7 @@ def main(args: argparse.Namespace):
"--dataset-name",
type=str,
default="sharegpt",
choices=["sharegpt", "sonnet"],
choices=["sharegpt", "sonnet", "random"],
help="Name of the dataset to benchmark on.",
)
parser.add_argument("--dataset-path",
Expand All @@ -566,7 +600,7 @@ def main(args: argparse.Namespace):
"--tokenizer",
type=str,
help=
"Name or path of the tokenizer, if not using the default tokenizer.",
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
)
parser.add_argument(
"--best-of",
Expand Down Expand Up @@ -609,6 +643,27 @@ def main(args: argparse.Namespace):
help=
"Number of prefix tokens per request, used only for sonnet dataset.",
)
parser.add_argument(
"--random-input-len",
type=int,
default=1024,
help=
"Number of input tokens per request, used only for random sampling.",
)
parser.add_argument(
"--random-output-len",
type=int,
default=128,
help=
"Number of output tokens per request, used only for random sampling.",
)
parser.add_argument(
"--random-range-ratio",
type=float,
default=1.0,
help="Range of sampled ratio of input/output length, "
"used only for random sampling.",
)
parser.add_argument(
"--request-rate",
type=float,
Expand Down
4 changes: 4 additions & 0 deletions docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ Vision Language Models
- LLaVA-NeXT
- :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
-
* - :code:`PaliGemmaForConditionalGeneration`
- PaliGemma
- :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc.
-
* - :code:`Phi3VForCausalLM`
- Phi-3-Vision
- :code:`microsoft/Phi-3-vision-128k-instruct`, etc.
Expand Down
17 changes: 15 additions & 2 deletions docs/source/serving/distributed_serving.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Distributed Inference and Serving
=================================

vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm <https://arxiv.org/pdf/1909.08053.pdf>`_. We manage the distributed runtime with either `Ray <https://github.com/ray-project/ray>`_ or python native multiprocessing. Multiprocessing can be used when deploying on a single node, multi-node inferencing currently requires Ray.
vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm <https://arxiv.org/pdf/1909.08053.pdf>`_. We also support pipeline parallel as a beta feature for online serving. We manage the distributed runtime with either `Ray <https://github.com/ray-project/ray>`_ or python native multiprocessing. Multiprocessing can be used when deploying on a single node, multi-node inferencing currently requires Ray.

Multiprocessing will be used by default when not running in a Ray placement group and if there are sufficient GPUs available on the same node for the configured :code:`tensor_parallel_size`, otherwise Ray will be used. This default can be overridden via the :code:`LLM` class :code:`distributed-executor-backend` argument or :code:`--distributed-executor-backend` API server argument. Set it to :code:`mp` for multiprocessing or :code:`ray` for Ray. It's not required for Ray to be installed for the multiprocessing case.

Expand All @@ -23,6 +23,19 @@ To run multi-GPU serving, pass in the :code:`--tensor-parallel-size` argument wh
$ --model facebook/opt-13b \
$ --tensor-parallel-size 4
You can also additionally specify :code:`--pipeline-parallel-size` to enable pipeline parallelism. For example, to run API server on 8 GPUs with pipeline parallelism and tensor parallelism:

.. code-block:: console
$ python -m vllm.entrypoints.openai.api_server \
$ --model gpt2 \
$ --tensor-parallel-size 4 \
$ --pipeline-parallel-size 2 \
$ --distributed-executor-backend ray
.. note::
Pipeline parallel is a beta feature. It is only supported for online serving and the ray backend for now, as well as LLaMa and GPT2 style models.

To scale vLLM beyond a single machine, install and start a `Ray runtime <https://docs.ray.io/en/latest/ray-core/starting-ray.html>`_ via CLI before running vLLM:

.. code-block:: console
Expand All @@ -35,7 +48,7 @@ To scale vLLM beyond a single machine, install and start a `Ray runtime <https:/
$ # On worker nodes
$ ray start --address=<ray-head-address>
After that, you can run inference and serving on multiple machines by launching the vLLM process on the head node by setting :code:`tensor_parallel_size` to the number of GPUs to be the total number of GPUs across all machines.
After that, you can run inference and serving on multiple machines by launching the vLLM process on the head node by setting :code:`tensor_parallel_size` multiplied by :code:`pipeline_parallel_size` to the number of GPUs to be the total number of GPUs across all machines.

.. warning::
Please make sure you downloaded the model to all the nodes, or the model is downloaded to some distributed file system that is accessible by all nodes.
52 changes: 52 additions & 0 deletions examples/paligemma_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os
import subprocess

from PIL import Image

from vllm import LLM

# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
# You can use `.buildkite/download-images.sh` to download them


def run_paligemma():
llm = LLM(model="google/paligemma-3b-mix-224")

prompt = "caption es"

image = Image.open("images/stop_sign.jpg")

outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {
"image": image
},
})

for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)


def main():
run_paligemma()


if __name__ == "__main__":
# Download from s3
s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
local_directory = "images"

# Make sure the local directory exists or create it
os.makedirs(local_directory, exist_ok=True)

# Use AWS CLI to sync the directory, assume anonymous access
subprocess.check_call([
"aws",
"s3",
"sync",
s3_bucket_path,
local_directory,
"--no-sign-request",
])
main()
Loading

0 comments on commit 7e2c87e

Please sign in to comment.