[ci][distributed] fix device count call
[ci][distributed] fix some cuda init that makes it necessary to use spawn (vllm-project#5991)
youkaichao authored and jimpang committed Jul 8, 2024
1 parent 307e13b commit e42cac1
Showing 6 changed files with 83 additions and 51 deletions.
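The root cause, per the commit message and the notes inside the diff below: some test code paths called torch.cuda.device_count() at import time, which left CUDA initialized in the pytest parent process, and a CUDA-initialized parent cannot start GPU workers with the default fork start method. That is why the CI previously had to export VLLM_WORKER_MULTIPROC_METHOD=spawn. A minimal repro sketch of the failure mode (not part of the commit; assumes a CUDA-capable machine):

import torch
import torch.multiprocessing as mp


def worker(rank: int) -> None:
    # In a child forked from a CUDA-initialized parent, this raises
    # "RuntimeError: Cannot re-initialize CUDA in forked subprocess".
    torch.zeros(1, device="cuda")


if __name__ == "__main__":
    torch.zeros(1, device="cuda")  # initialize CUDA in the parent first
    ctx = mp.get_context("fork")
    proc = ctx.Process(target=worker, args=(0,))
    proc.start()
    proc.join()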
12 changes: 1 addition & 11 deletions .buildkite/test-pipeline.yaml
@@ -45,9 +45,6 @@ steps:
num_gpus: 2
commands:
- bash ../.buildkite/download-images.sh
# FIXIT: find out which code initialize cuda before running the test
# before the fix, we need to use spawn to test it
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
@@ -60,8 +57,7 @@ steps:
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
# FIXIT: find out why TP is failing with mp backend on phi3-v
# - TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
- TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
@@ -71,9 +67,6 @@ steps:
working_dir: "/vllm-workspace/tests"
num_gpus: 4
commands:
# FIXIT: find out which code initialize cuda before running the test
# before the fix, we need to use spawn to test it
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s distributed/test_pynccl.py
# We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here.
# See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context.
@@ -225,9 +218,6 @@ steps:
gpu: a100
num_gpus: 4
commands:
# FIXIT: find out which code initialize cuda before running the test
# before the fix, we need to use spawn to test it
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
# NOTE: don't test llama model here, it seems hf implementation is buggy
# see https://github.com/vllm-project/vllm/pull/5689 for details
- pytest -v -s distributed/test_custom_all_reduce.py
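With CUDA no longer initialized before the workers start, the pipeline can drop the explicit export VLLM_WORKER_MULTIPROC_METHOD=spawn lines and fall back to the default start method. A simplified sketch of how such an environment variable typically selects the multiprocessing start method (an illustration, not vLLM's actual launcher code):

import multiprocessing as mp
import os


def get_worker_context() -> mp.context.BaseContext:
    # "fork" is the usual Linux default and is cheap, but it is only safe if
    # the parent has not initialized CUDA; "spawn" always works but is slower.
    method = os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "fork")
    return mp.get_context(method)


if __name__ == "__main__":
    ctx = get_worker_context()
    proc = ctx.Process(target=print, args=("worker started",))
    proc.start()
    proc.join()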
21 changes: 16 additions & 5 deletions tests/conftest.py
@@ -5,25 +5,29 @@
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import (Any, Dict, List, Literal, Optional, Tuple, TypedDict,
TypeVar)
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple,
TypedDict, TypeVar)

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
AutoProcessor, AutoTokenizer, BatchEncoding)
AutoTokenizer, BatchEncoding)

from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
from vllm.multimodal import MultiModalData
from vllm.multimodal.image import ImageFeatureData, ImagePixelData

if TYPE_CHECKING:
from vllm.multimodal import MultiModalData
else:
# it will call torch.cuda.device_count()
MultiModalData = None
from vllm.sequence import SampleLogprobs
from vllm.utils import cuda_device_count_stateless, is_cpu

@@ -63,6 +67,10 @@ def for_hf(self) -> Image.Image:
return self.pil_image

def for_vllm(self, vision_config: VisionLanguageConfig) -> MultiModalData:
# don't put this import at the top level
# it will call torch.cuda.device_count()
from vllm.multimodal.image import ImageFeatureData # noqa: F401
from vllm.multimodal.image import ImagePixelData
image_input_type = vision_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType

@@ -216,6 +224,9 @@ def __init__(
)

try:
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoProcessor # noqa: F401
self.processor = AutoProcessor.from_pretrained(
model_name,
torch_dtype=torch_dtype,
15 changes: 10 additions & 5 deletions tests/distributed/test_basic_distributed_correctness.py
@@ -15,7 +15,8 @@
import os

import pytest
import torch

from vllm.utils import cuda_device_count_stateless

from ..models.utils import check_outputs_equal

@@ -25,7 +26,7 @@
DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"


@pytest.mark.skipif(torch.cuda.device_count() < 2,
@pytest.mark.skipif(cuda_device_count_stateless() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@@ -40,16 +41,20 @@ def test_models(
) -> None:
distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=2,
distributed_executor_backend=distributed_executor_backend
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
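The skip condition above now uses cuda_device_count_stateless() from vllm.utils instead of torch.cuda.device_count(); in line with the commit's goal, merely collecting the test should no longer touch CUDA in the parent process. A rough sketch of the idea behind a stateless device count (an illustration only, not vLLM's implementation; assumes the pynvml bindings are installed):

import os

import pynvml


def device_count_stateless() -> int:
    # Query NVML directly so the CUDA runtime is never initialized
    # in the calling process.
    pynvml.nvmlInit()
    try:
        total = pynvml.nvmlDeviceGetCount()
    finally:
        pynvml.nvmlShutdown()
    visible = os.getenv("CUDA_VISIBLE_DEVICES")
    if visible is None:
        return total
    # Roughly mirror how the CUDA runtime honors CUDA_VISIBLE_DEVICES.
    ids = [v for v in visible.split(",") if v.strip()]
    return min(len(ids), total)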
14 changes: 10 additions & 4 deletions tests/distributed/test_chunked_prefill_distributed.py
@@ -14,7 +14,8 @@
import os

import pytest
import torch

from vllm.utils import cuda_device_count_stateless

from ..models.utils import check_outputs_equal

@@ -24,7 +25,7 @@
DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"


@pytest.mark.skipif(torch.cuda.device_count() < 2,
@pytest.mark.skipif(cuda_device_count_stateless() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@@ -47,8 +48,10 @@ def test_models(
enable_chunked_prefill = True
max_num_batched_tokens = chunked_prefill_token_size

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).

with vllm_runner(
model,
Expand All @@ -61,6 +64,9 @@ def test_models(
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
30 changes: 20 additions & 10 deletions tests/models/test_llava.py
@@ -88,28 +88,38 @@ def run_test(
"""
model_id, vlm_config = model_and_config
hf_images = [asset.for_hf() for asset in image_assets]
vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]

with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
max_tokens,
images=hf_images)

vllm_image_prompts = [
p.replace("<image>", "<image>" * vlm_config.image_feature_size)
for p in HF_IMAGE_PROMPTS
]
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).

with vllm_runner(model_id,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
**vlm_config.as_cli_args_dict()) as vllm_model:

# NOTE: `asset.for_vllm` will call `torch.cuda.device_count()`
# we must put it inside the vllm_runner context manager
# i.e. after creating vLLM instance.
vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]

vllm_image_prompts = [
p.replace("<image>", "<image>" * vlm_config.image_feature_size)
for p in HF_IMAGE_PROMPTS
]

vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
max_tokens,
images=vllm_images)

with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
max_tokens,
images=hf_images)

check_outputs_equal(
hf_outputs,
[
42 changes: 26 additions & 16 deletions tests/models/test_phi3v.py
@@ -96,23 +96,11 @@ def run_test(
"""
model_id, vlm_config = model_and_config
hf_images = [asset.for_hf() for asset in image_assets]
vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]

# use eager mode for hf runner, since phi3_v didn't work with flash_attn
hf_model_kwargs = {"_attn_implementation": "eager"}
with hf_runner(model_id, dtype=dtype,
model_kwargs=hf_model_kwargs) as hf_model:
hf_outputs = hf_model.generate_greedy(
HF_IMAGE_PROMPTS,
max_tokens,
images=hf_images,
eos_token_id=hf_model.processor.tokenizer.eos_token_id)

vllm_image_prompts = [
p.replace("<|image_1|>",
"<|image|>" * vlm_config.image_feature_size + "<s>")
for p in HF_IMAGE_PROMPTS
]
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).

with vllm_runner(model_id,
max_model_len=2048,
@@ -121,10 +109,32 @@
enforce_eager=True,
distributed_executor_backend=distributed_executor_backend,
**vlm_config.as_cli_args_dict()) as vllm_model:
# NOTE: `asset.for_vllm` will call `torch.cuda.device_count()`
# we must put it inside the vllm_runner context manager
# i.e. after creating vLLM instance.

vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]

vllm_image_prompts = [
p.replace("<|image_1|>",
"<|image|>" * vlm_config.image_feature_size + "<s>")
for p in HF_IMAGE_PROMPTS
]

vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
max_tokens,
images=vllm_images)

# use eager mode for hf runner, since phi3_v didn't work with flash_attn
hf_model_kwargs = {"_attn_implementation": "eager"}
with hf_runner(model_id, dtype=dtype,
model_kwargs=hf_model_kwargs) as hf_model:
hf_outputs = hf_model.generate_greedy(
HF_IMAGE_PROMPTS,
max_tokens,
images=hf_images,
eos_token_id=hf_model.processor.tokenizer.eos_token_id)

check_outputs_equal(
hf_outputs,
[
