[Core] Pipeline parallel with Ray ADAG (vllm-project#6837)
Support pipeline-parallelism with Ray accelerated DAG.

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
ruisearch42 authored Aug 2, 2024
1 parent a8d604c commit 0530889
Showing 10 changed files with 199 additions and 77 deletions.
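
For reference, a minimal sketch of how the new mode is turned on, mirroring the environment variables and "vllm serve" harness added in this commit (the --pipeline-parallel-size and --distributed-executor-backend flags are assumptions about the surrounding vLLM CLI, not part of this diff):

import os
import subprocess

# Environment toggles used by this commit; see vllm/envs.py below.
env = dict(os.environ)
env.update({
    "VLLM_USE_RAY_SPMD_WORKER": "1",                # required for compiled DAG
    "VLLM_USE_RAY_COMPILED_DAG": "1",               # enable Ray accelerated DAG
    "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",  # NCCL channel (the default)
})

# Hypothetical invocation: pipeline parallelism over the Ray backend.
subprocess.run(
    ["vllm", "serve", "meta-llama/Meta-Llama-3-8B",
     "--pipeline-parallel-size", "2",
     "--distributed-executor-backend", "ray"],
    env=env,
    check=True,
)

Note that Ray ADAG is only supported with the Ray distributed backend; the test below enforces this with an assert.
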
2 changes: 2 additions & 0 deletions Dockerfile
@@ -42,6 +42,7 @@ WORKDIR /workspace

 # install build and runtime dependencies
 COPY requirements-common.txt requirements-common.txt
+COPY requirements-adag.txt requirements-adag.txt
 COPY requirements-cuda.txt requirements-cuda.txt
 RUN --mount=type=cache,target=/root/.cache/pip \
     python3 -m pip install -r requirements-cuda.txt
@@ -78,6 +79,7 @@ COPY setup.py setup.py
 COPY cmake cmake
 COPY CMakeLists.txt CMakeLists.txt
 COPY requirements-common.txt requirements-common.txt
+COPY requirements-adag.txt requirements-adag.txt
 COPY requirements-cuda.txt requirements-cuda.txt
 COPY pyproject.toml pyproject.toml
 COPY vllm vllm
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -1,4 +1,5 @@
 include LICENSE
+include requirements-adag.txt
 include requirements-common.txt
 include requirements-cuda.txt
 include requirements-rocm.txt
3 changes: 3 additions & 0 deletions requirements-adag.txt
@@ -0,0 +1,3 @@
+# Dependencies for Ray accelerated DAG
+cupy-cuda12x
+ray >= 2.32
3 changes: 3 additions & 0 deletions requirements-test.txt
@@ -1,3 +1,6 @@
+# Needed for Ray accelerated DAG tests
+-r requirements-adag.txt
+
 # testing
 pytest
 tensorizer>=2.9.0
51 changes: 35 additions & 16 deletions tests/distributed/test_pipeline_parallel.py
@@ -15,22 +15,31 @@


 @pytest.mark.parametrize(
-    "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, DIST_BACKEND",
-    [
-        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
-        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
-        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
-        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
-        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
-        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
-        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
-        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
-        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
-        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
-    ])
-@fork_new_process_for_each_test
+    ("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, "
+     "MODEL_NAME, DIST_BACKEND, USE_RAY_ADAG, USE_RAY_ADAG_NCCL"), [
+        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
+        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
+        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
+        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
+        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
+        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
+        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
+        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
+        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
+        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
+        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
+        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
+        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
+        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
+        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
+        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
+        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
+        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
+        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
+        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
+    ])
 def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
-                    DIST_BACKEND):
+                    DIST_BACKEND, USE_RAY_ADAG, USE_RAY_ADAG_NCCL):
     if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
         pytest.skip("Skipping multi-node pipeline parallel test for "
                     "multiprocessing distributed backend")
@@ -67,8 +76,18 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
     if EAGER_MODE:
         pp_args.append("--enforce-eager")
         tp_args.append("--enforce-eager")
+    pp_env = None
+    if USE_RAY_ADAG:
+        assert DIST_BACKEND == "ray", (
+            "Ray ADAG is only supported with Ray distributed backend")
+        pp_env = {
+            "VLLM_USE_RAY_COMPILED_DAG": "1",
+            "VLLM_USE_RAY_SPMD_WORKER": "1",
+            "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL":
+            str(int(USE_RAY_ADAG_NCCL)),
+        }
 
-    compare_two_settings(MODEL_NAME, pp_args, tp_args)
+    compare_two_settings(MODEL_NAME, pp_args, tp_args, pp_env)
 
 
 @pytest.mark.parametrize("PP_SIZE, MODEL_NAME", [
31 changes: 25 additions & 6 deletions tests/utils.py
@@ -7,7 +7,7 @@
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import openai
 import ray
@@ -57,6 +57,7 @@ def __init__(
         model: str,
         cli_args: List[str],
         *,
+        env_dict: Optional[Dict[str, str]] = None,
         auto_port: bool = True,
     ) -> None:
         if auto_port:
@@ -77,6 +78,8 @@ def __init__(
         # the current process might initialize cuda,
         # to be safe, we should use spawn method
         env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
+        if env_dict is not None:
+            env.update(env_dict)
         self.proc = subprocess.Popen(["vllm", "serve"] + [model] + cli_args,
                                      env=env,
                                      stdout=sys.stdout,
@@ -89,6 +92,11 @@ def __enter__(self):

     def __exit__(self, exc_type, exc_value, traceback):
         self.proc.terminate()
+        try:
+            self.proc.wait(3)
+        except subprocess.TimeoutExpired:
+            # force kill if needed
+            self.proc.kill()
 
     def _wait_for_server(self, *, url: str, timeout: float):
         # run health check
@@ -127,19 +135,30 @@ def get_async_client(self):
         )
 
 
-def compare_two_settings(model: str, arg1: List[str], arg2: List[str]):
+def compare_two_settings(model: str,
+                         arg1: List[str],
+                         arg2: List[str],
+                         env1: Optional[Dict[str, str]] = None,
+                         env2: Optional[Dict[str, str]] = None):
     """
-    Launch API server with two different sets of arguments and compare the
-    results of the API calls. The arguments are after the model name.
+    Launch API server with two different sets of arguments/environments
+    and compare the results of the API calls.
+
+    Args:
+        model: The model to test.
+        arg1: The first set of arguments to pass to the API server.
+        arg2: The second set of arguments to pass to the API server.
+        env1: The first set of environment variables to pass to the API server.
+        env2: The second set of environment variables to pass to the API server.
     """
 
     tokenizer = AutoTokenizer.from_pretrained(model)
 
     prompt = "Hello, my name is"
     token_ids = tokenizer(prompt)["input_ids"]
     results = []
-    for args in (arg1, arg2):
-        with RemoteOpenAIServer(model, args) as server:
+    for args, env in ((arg1, env1), (arg2, env2)):
+        with RemoteOpenAIServer(model, args, env_dict=env) as server:
             client = server.get_client()
 
             # test models list
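
As a usage note, a sketch of calling the extended helper with per-server environments (the argument values here are hypothetical, chosen to mirror the pipeline-parallel test above):

compare_two_settings(
    "meta-llama/Meta-Llama-3-8B",
    arg1=["--pipeline-parallel-size", "2"],   # hypothetical first setting
    arg2=["--tensor-parallel-size", "2"],     # hypothetical second setting
    env1={"VLLM_USE_RAY_COMPILED_DAG": "1",
          "VLLM_USE_RAY_SPMD_WORKER": "1"},
    env2=None,  # second server inherits the unmodified environment
)
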
12 changes: 10 additions & 2 deletions vllm/envs.py
@@ -38,6 +38,7 @@
     VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
     VLLM_USE_RAY_SPMD_WORKER: bool = False
     VLLM_USE_RAY_COMPILED_DAG: bool = False
+    VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: bool = True
     VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
     VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
     VLLM_IMAGE_FETCH_TIMEOUT: int = 5
@@ -273,13 +274,20 @@ def get_default_config_root():
     # execution on all workers.
     # Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it.
     "VLLM_USE_RAY_SPMD_WORKER":
-    lambda: bool(os.getenv("VLLM_USE_RAY_SPMD_WORKER", 0)),
+    lambda: bool(int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0"))),
 
     # If the env var is set, it uses the Ray's compiled DAG API
     # which optimizes the control plane overhead.
     # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
     "VLLM_USE_RAY_COMPILED_DAG":
-    lambda: bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)),
+    lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))),
+
+    # If the env var is set, it uses NCCL for communication in
+    # Ray's compiled DAG. This flag is ignored if
+    # VLLM_USE_RAY_COMPILED_DAG is not set.
+    "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL":
+    lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL", "1"))
+                 ),
 
     # Use dedicated multiprocess context for workers.
     # Both spawn and fork work
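
Worth noting: the int() cast in the rewritten parsers is a real fix, not a style change. In Python every non-empty string is truthy, so the old bool(os.getenv(..., 0)) returned True even when the variable was explicitly set to "0". A quick illustration:

import os

os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "0"

# Old parsing: bool("0") is True, so an explicit "0" wrongly enabled the flag.
print(bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)))         # True
# New parsing: int("0") is 0, so "0" correctly disables it.
print(bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))))  # False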