
Commit b9d3460

DarkLight1337 and sumitd2 · authored and committed
[CI/Build] Add test decorator for minimum GPU memory (vllm-project#8925)
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
1 parent 0b90392 commit b9d3460

File tree · 14 files changed: +117 -73 lines
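
For quick reference, the `large_gpu_test` decorator added in `tests/utils.py` (see the hunks below) is applied to memory-hungry tests roughly like this; the model name and test body here are illustrative only, not part of the commit:

```python
import pytest

from ....utils import large_gpu_test  # relative path as used inside vllm's test tree


@large_gpu_test(min_gb=48)  # skipped unless the GPU reports at least 48 GB of memory
@pytest.mark.parametrize("model", ["example/large-model"])  # illustrative model name
def test_large_model(vllm_runner, model: str) -> None:
    ...  # illustrative test body
```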

tests/lora/test_baichuan.py  (+4 -5)

@@ -63,12 +63,11 @@ def test_baichuan_lora(baichuan_lora_files):
         assert output2[i] == expected_lora_output[i]
 
 
-@pytest.mark.skip("Requires multiple GPUs")
 @pytest.mark.parametrize("fully_sharded", [True, False])
-def test_baichuan_tensor_parallel_equality(baichuan_lora_files, fully_sharded):
-    # Cannot use as it will initialize torch.cuda too early...
-    # if torch.cuda.device_count() < 4:
-    #     pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
+def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
+                                           num_gpus_available, fully_sharded):
+    if num_gpus_available < 4:
+        pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
 
     llm_tp1 = vllm.LLM(MODEL_PATH,
                        enable_lora=True,
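
The LoRA tests above and below now receive a `num_gpus_available` fixture instead of calling `torch.cuda.device_count()` at collection time, which would initialize CUDA too early. The fixture itself is not part of this diff; a minimal sketch of what such a conftest fixture could look like, assuming it wraps `cuda_device_count_stateless` from `vllm.utils`:

```python
# Hypothetical conftest.py sketch -- the real fixture lives in vllm's test
# conftest and is not shown in this commit.
import pytest

from vllm.utils import cuda_device_count_stateless


@pytest.fixture
def num_gpus_available() -> int:
    # Count GPUs without initializing torch.cuda in the current process.
    return cuda_device_count_stateless()
```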

tests/lora/test_quant_model.py  (+8 -9)

@@ -71,10 +71,10 @@ def format_prompt_tuples(prompt):
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("tp_size", [1])
-def test_quant_model_lora(tinyllama_lora_files, model, tp_size):
-    # Cannot use as it will initialize torch.cuda too early...
-    # if torch.cuda.device_count() < tp_size:
-    #     pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
+def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
+                          tp_size):
+    if num_gpus_available < tp_size:
+        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
 
     llm = vllm.LLM(
         model=model.model_path,

@@ -164,11 +164,10 @@ def expect_match(output, expected_output):
 
 
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.skip("Requires multiple GPUs")
-def test_quant_model_tp_equality(tinyllama_lora_files, model):
-    # Cannot use as it will initialize torch.cuda too early...
-    # if torch.cuda.device_count() < 2:
-    #     pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
+def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
+                                 model):
+    if num_gpus_available < 2:
+        pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
 
     llm_tp1 = vllm.LLM(
         model=model.model_path,

tests/models/decoder_only/language/test_phimoe.py  (+2 -11)

@@ -7,6 +7,7 @@
 
 from vllm.utils import is_cpu
 
+from ....utils import large_gpu_test
 from ...utils import check_logprobs_close
 
 MODELS = [

@@ -69,20 +70,10 @@ def test_phimoe_routing_function():
     assert torch.equal(topk_ids, ground_truth[test_id]["topk_ids"])
 
 
-def get_gpu_memory():
-    try:
-        props = torch.cuda.get_device_properties(torch.cuda.current_device())
-        gpu_memory = props.total_memory / (1024**3)
-        return gpu_memory
-    except Exception:
-        return 0
-
-
 @pytest.mark.skipif(condition=is_cpu(),
                     reason="This test takes a lot time to run on CPU, "
                     "and vllm CI's disk space is not enough for this model.")
-@pytest.mark.skipif(condition=get_gpu_memory() < 100,
-                    reason="Skip this test if GPU memory is insufficient.")
+@large_gpu_test(min_gb=80)
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [64])

tests/models/decoder_only/vision_language/test_llava_onevision.py  (+4 -9)

@@ -11,6 +11,7 @@
 
 from ....conftest import (VIDEO_ASSETS, HfRunner, PromptImageInput, VllmRunner,
                           _VideoAssets)
+from ....utils import large_gpu_test
 from ...utils import check_logprobs_close
 
 # Video test

@@ -164,9 +165,7 @@ def process(hf_inputs: BatchEncoding):
     )
 
 
-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on L40 locally but will OOM on CI machine.")
+@large_gpu_test(min_gb=48)
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize(
     "size_factors",

@@ -210,9 +209,7 @@ def test_models(hf_runner, vllm_runner, video_assets, model, size_factors,
     )
 
 
-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on L40 locally but will OOM on CI machine.")
+@large_gpu_test(min_gb=48)
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize(
     "sizes",

@@ -306,9 +303,7 @@ def process(hf_inputs: BatchEncoding):
     )
 
 
-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on L40 locally but will OOM on CI machine.")
+@large_gpu_test(min_gb=48)
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])

tests/models/decoder_only/vision_language/test_pixtral.py  (+3 -9)

@@ -17,7 +17,7 @@
 from vllm.multimodal import MultiModalDataBuiltins
 from vllm.sequence import Logprob, SampleLogprobs
 
-from ....utils import VLLM_PATH
+from ....utils import VLLM_PATH, large_gpu_test
 from ...utils import check_logprobs_close
 
 if TYPE_CHECKING:

@@ -121,10 +121,7 @@ def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs:
                            for tokens, text, logprobs in json_data]
 
 
-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on A100 locally but will OOM on CI machine."
-)
+@large_gpu_test(min_gb=80)
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN)
 @pytest.mark.parametrize("dtype", ["bfloat16"])

@@ -157,10 +154,7 @@ def test_chat(
                         name_1="output")
 
 
-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on A100 locally but will OOM on CI machine."
-)
+@large_gpu_test(min_gb=80)
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 def test_model_engine(vllm_runner, model: str, dtype: str) -> None:

tests/models/encoder_decoder/vision_language/test_mllama.py  (+20 -22)

@@ -9,6 +9,7 @@
 
 from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
                           _ImageAssets)
+from ....utils import large_gpu_test
 from ...utils import check_logprobs_close
 
 _LIMIT_IMAGE_PER_PROMPT = 1

@@ -227,29 +228,26 @@ def process(hf_inputs: BatchEncoding):
     )
 
 
-SIZES = [
-    # Text only
-    [],
-    # Single-size
-    [(512, 512)],
-    # Single-size, batched
-    [(512, 512), (512, 512), (512, 512)],
-    # Multi-size, batched
-    [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
-     (1024, 1024), (512, 1536), (512, 2028)],
-    # Multi-size, batched, including text only
-    [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
-     (1024, 1024), (512, 1536), (512, 2028), None],
-    # mllama has 8 possible aspect ratios, carefully set the sizes
-    # to cover all of them
-]
-
-
-@pytest.mark.skip(
-    reason=
-    "Model is too big, test passed on L40 locally but will OOM on CI machine.")
+@large_gpu_test(min_gb=48)
 @pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("sizes", SIZES)
+@pytest.mark.parametrize(
+    "sizes",
+    [
+        # Text only
+        [],
+        # Single-size
+        [(512, 512)],
+        # Single-size, batched
+        [(512, 512), (512, 512), (512, 512)],
+        # Multi-size, batched
+        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
+         (1024, 1024), (512, 1536), (512, 2028)],
+        # Multi-size, batched, including text only
+        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
+         (1024, 1024), (512, 1536), (512, 2028), None],
+        # mllama has 8 possible aspect ratios, carefully set the sizes
+        # to cover all of them
+    ])
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])

tests/utils.py  (+33 -2)

@@ -24,8 +24,8 @@
 from vllm.entrypoints.openai.cli_args import make_arg_parser
 from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.platforms import current_platform
-from vllm.utils import (FlexibleArgumentParser, cuda_device_count_stateless,
-                        get_open_port, is_hip)
+from vllm.utils import (FlexibleArgumentParser, GB_bytes,
+                        cuda_device_count_stateless, get_open_port, is_hip)
 
 if current_platform.is_rocm():
     from amdsmi import (amdsmi_get_gpu_vram_usage,

@@ -455,6 +455,37 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
     return wrapper
 
 
+def large_gpu_test(*, min_gb: int):
+    """
+    Decorate a test to be skipped if no GPU is available or it does not have
+    sufficient memory.
+
+    Currently, the CI machine uses L4 GPU which has 24 GB VRAM.
+    """
+    try:
+        if current_platform.is_cpu():
+            memory_gb = 0
+        else:
+            memory_gb = current_platform.get_device_total_memory() / GB_bytes
+    except Exception as e:
+        warnings.warn(
+            f"An error occurred when finding the available memory: {e}",
+            stacklevel=2,
+        )
+
+        memory_gb = 0
+
+    test_skipif = pytest.mark.skipif(
+        memory_gb < min_gb,
+        reason=f"Need at least {memory_gb}GB GPU memory to run the test.",
+    )
+
+    def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
+        return test_skipif(fork_new_process_for_each_test(f))
+
+    return wrapper
+
+
 def multi_gpu_test(*, num_gpus: int):
     """
     Decorate a test to be run only when multiple GPUs are available.
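
The decorator resolves its memory check through the platform hook added in the `vllm/platforms/*` files below. To see what it observes on a given machine, the same calls can be made directly (assuming a vllm build that includes this commit):

```python
from vllm.platforms import current_platform
from vllm.utils import GB_bytes

if current_platform.is_cpu():
    print("No GPU platform detected; large_gpu_test-marked tests would be skipped.")
else:
    total_bytes = current_platform.get_device_total_memory()  # device 0, in bytes
    print(f"Device 0 total memory: {total_bytes / GB_bytes:.1f} GB")
```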

vllm/platforms/cpu.py  (+5)

@@ -1,3 +1,4 @@
+import psutil
 import torch
 
 from .interface import Platform, PlatformEnum

@@ -10,6 +11,10 @@ class CpuPlatform(Platform):
     def get_device_name(cls, device_id: int = 0) -> str:
         return "cpu"
 
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        return psutil.virtual_memory().total
+
     @classmethod
     def inference_mode(cls):
         return torch.no_grad()

vllm/platforms/cuda.py  (+12)

@@ -59,6 +59,13 @@ def get_physical_device_name(device_id: int = 0) -> str:
     return pynvml.nvmlDeviceGetName(handle)
 
 
+@lru_cache(maxsize=8)
+@with_nvml_context
+def get_physical_device_total_memory(device_id: int = 0) -> int:
+    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
+    return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)
+
+
 @with_nvml_context
 def warn_if_different_devices():
     device_ids: int = pynvml.nvmlDeviceGetCount()

@@ -107,6 +114,11 @@ def get_device_name(cls, device_id: int = 0) -> str:
         physical_device_id = device_id_to_physical_device_id(device_id)
         return get_physical_device_name(physical_device_id)
 
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        physical_device_id = device_id_to_physical_device_id(device_id)
+        return get_physical_device_total_memory(physical_device_id)
+
     @classmethod
     @with_nvml_context
     def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
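
The CUDA implementation queries NVML rather than `torch.cuda`, so no CUDA context is created. The same lookup can be reproduced standalone with `pynvml` (a minimal sketch; error handling beyond init/shutdown omitted):

```python
import pynvml

pynvml.nvmlInit()
try:
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)               # physical device 0
    total = int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)   # bytes
    print(f"GPU 0 total memory: {total} bytes")
finally:
    pynvml.nvmlShutdown()
```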

vllm/platforms/interface.py  (+6)

@@ -85,6 +85,12 @@ def has_device_capability(
 
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
+        """Get the name of a device."""
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        """Get the total memory of a device in bytes."""
         raise NotImplementedError
 
     @classmethod
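
With this change, every `Platform` subclass is expected to provide `get_device_total_memory`. A hypothetical out-of-tree platform could implement the two hooks along these lines (the class name, device size, and any other attributes the base class may require are illustrative or omitted here):

```python
from vllm.platforms.interface import Platform


class MyAcceleratorPlatform(Platform):
    # Hypothetical platform; other required attributes are omitted for brevity.

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        return "my-accelerator"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        # Must return the device's total memory in bytes.
        return 32 * (1 << 30)  # pretend 32 GiB device, hardcoded for illustration
```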

vllm/platforms/rocm.py  (+5)

@@ -29,3 +29,8 @@ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
     @lru_cache(maxsize=8)
     def get_device_name(cls, device_id: int = 0) -> str:
         return torch.cuda.get_device_name(device_id)
+
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        device_props = torch.cuda.get_device_properties(device_id)
+        return device_props.total_memory

vllm/platforms/tpu.py  (+4)

@@ -10,6 +10,10 @@ class TpuPlatform(Platform):
     def get_device_name(cls, device_id: int = 0) -> str:
         raise NotImplementedError
 
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        raise NotImplementedError
+
     @classmethod
     def inference_mode(cls):
         return torch.no_grad()

vllm/platforms/xpu.py  (+8 -6)

@@ -8,13 +8,15 @@ class XPUPlatform(Platform):
 
     @staticmethod
     def get_device_capability(device_id: int = 0) -> DeviceCapability:
-        return DeviceCapability(major=int(
-            torch.xpu.get_device_capability(device_id)['version'].split('.')
-            [0]),
-                                minor=int(
-                                    torch.xpu.get_device_capability(device_id)
-                                    ['version'].split('.')[1]))
+        major, minor, *_ = torch.xpu.get_device_capability(
+            device_id)['version'].split('.')
+        return DeviceCapability(major=int(major), minor=int(minor))
 
     @staticmethod
     def get_device_name(device_id: int = 0) -> str:
         return torch.xpu.get_device_name(device_id)
+
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        device_props = torch.xpu.get_device_properties(device_id)
+        return device_props.total_memory
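
The rewritten XPU capability code parses the driver-reported version string once and unpacks only the leading components; for example (the version value below is illustrative):

```python
version = "1.3.27642"                  # illustrative 'version' string
major, minor, *_ = version.split('.')  # any trailing components are ignored
print(int(major), int(minor))          # -> 1 3
```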

vllm/utils.py  (+3)

@@ -119,6 +119,9 @@
 STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
 STR_INVALID_VAL: str = "INVALID"
 
+GB_bytes = 1_000_000_000
+"""The number of bytes in one gigabyte (GB)."""
+
 GiB_bytes = 1 << 30
 """The number of bytes in one gibibyte (GiB)."""
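
Note that the new `GB_bytes` is decimal (10^9), unlike the existing binary `GiB_bytes` (2^30), so the value compared against `min_gb` in `large_gpu_test` comes out slightly larger than the binary figure. A quick check:

```python
GB_bytes = 1_000_000_000   # as added in vllm/utils.py
GiB_bytes = 1 << 30        # 1_073_741_824

total = 24 * GiB_bytes     # a card advertising 24 GiB of VRAM
print(total / GB_bytes)    # 25.769803776 -> about 25.8 "GB" with the decimal divisor
print(total / GiB_bytes)   # 24.0
```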
