Fix test
zifeitong committed Jun 18, 2024
1 parent dc70f1e commit cd44c86
Showing 3 changed files with 51 additions and 44 deletions.
tests/async_engine/test_async_llm_engine.py (9 additions, 0 deletions)
@@ -2,10 +2,13 @@
from dataclasses import dataclass

import pytest
import torch

from vllm import SamplingParams
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine

from ..utils import wait_for_gpu_memory_to_clear


@dataclass
class RequestOutput:
@@ -98,6 +101,12 @@ async def test_new_requests_event():


def test_asyncio_run():
    wait_for_gpu_memory_to_clear(
        devices=list(range(torch.cuda.device_count())),
        threshold_bytes=2 * 2**30,
        timeout_s=60,
    )

    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))

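The guard added above makes test_asyncio_run wait until GPU memory held by earlier tests has been released before it builds its own engine. A minimal sketch of how the shared helper is parameterized (the test name here is hypothetical; the helper is the one added to tests/utils.py in this commit):

import torch

from ..utils import wait_for_gpu_memory_to_clear


def test_needs_mostly_free_gpus():  # hypothetical example, not in this commit
    # Block until every visible CUDA device reports at most 2 GiB
    # (2 * 2**30 bytes) of used memory, polling NVML every few seconds and
    # raising ValueError if the threshold is not reached within 60 seconds.
    wait_for_gpu_memory_to_clear(
        devices=list(range(torch.cuda.device_count())),
        threshold_bytes=2 * 2**30,
        timeout_s=60,
    )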
tests/spec_decode/e2e/conftest.py (1 addition, 42 deletions)
@@ -1,18 +1,11 @@
import asyncio
import time
from itertools import cycle
from typing import Dict, List, Optional, Tuple, Union

import pytest
import ray
import torch

from vllm.utils import is_hip

if (not is_hip()):
    from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
                        nvmlInit)

from vllm import LLM
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -26,6 +19,7 @@
from vllm.utils import Counter, random_uuid

from ...conftest import cleanup
from ...utils import wait_for_gpu_memory_to_clear


class AsyncLLM:
@@ -291,38 +285,3 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
        print(f'{i=} {baseline_token_ids=}')
        print(f'{i=} {spec_token_ids=}')
        assert baseline_token_ids == spec_token_ids


def wait_for_gpu_memory_to_clear(devices: List[int],
                                 threshold_bytes: int,
                                 timeout_s: float = 120) -> None:
    # Use nvml instead of pytorch to reduce measurement error from torch cuda
    # context.
    nvmlInit()
    start_time = time.time()
    while True:
        output: Dict[int, str] = {}
        output_raw: Dict[int, float] = {}
        for device in devices:
            dev_handle = nvmlDeviceGetHandleByIndex(device)
            mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
            gb_used = mem_info.used / 2**30
            output_raw[device] = gb_used
            output[device] = f'{gb_used:.02f}'

        print('gpu memory used (GB): ', end='')
        for k, v in output.items():
            print(f'{k}={v}; ', end='')
        print('')

        dur_s = time.time() - start_time
        if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
            print(f'Done waiting for free GPU memory on devices {devices=} '
                  f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
            break

        if dur_s >= timeout_s:
            raise ValueError(f'Memory of devices {devices=} not free after '
                             f'{dur_s=:.02f} ({threshold_bytes/2**30=})')

        time.sleep(5)
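The wait_for_gpu_memory_to_clear definition deleted here is moved verbatim into tests/utils.py (next file), so the spec-decode tests now reach it through the package-level import shown above. A sketch of one way a conftest could apply the shared helper automatically, assuming a hypothetical autouse fixture that is not part of this commit:

import pytest
import torch

from ...utils import wait_for_gpu_memory_to_clear


@pytest.fixture(autouse=True)
def _require_clear_gpu_memory():  # hypothetical fixture, not in this commit
    # Before each test in the package, wait for GPU memory from previous
    # tests to be released so the test starts from a known baseline.
    wait_for_gpu_memory_to_clear(
        devices=list(range(torch.cuda.device_count())),
        threshold_bytes=2 * 2**30,
        timeout_s=60,
    )
    yield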
tests/utils.py (41 additions, 2 deletions)
@@ -4,7 +4,7 @@
import time
import warnings
from contextlib import contextmanager
from typing import List
from typing import Dict, List

import openai
import ray
@@ -13,7 +13,11 @@
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import get_open_port
from vllm.utils import get_open_port, is_hip

if (not is_hip()):
    from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
                        nvmlInit)

# Path to root of repository so that utilities can be imported by ray workers
VLLM_PATH = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))
@@ -154,3 +158,38 @@ def error_on_warning():
        warnings.simplefilter("error")

        yield


def wait_for_gpu_memory_to_clear(devices: List[int],
                                 threshold_bytes: int,
                                 timeout_s: float = 120) -> None:
    # Use nvml instead of pytorch to reduce measurement error from torch cuda
    # context.
    nvmlInit()
    start_time = time.time()
    while True:
        output: Dict[int, str] = {}
        output_raw: Dict[int, float] = {}
        for device in devices:
            dev_handle = nvmlDeviceGetHandleByIndex(device)
            mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
            gb_used = mem_info.used / 2**30
            output_raw[device] = gb_used
            output[device] = f'{gb_used:.02f}'

        print('gpu memory used (GB): ', end='')
        for k, v in output.items():
            print(f'{k}={v}; ', end='')
        print('')

        dur_s = time.time() - start_time
        if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
            print(f'Done waiting for free GPU memory on devices {devices=} '
                  f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
            break

        if dur_s >= timeout_s:
            raise ValueError(f'Memory of devices {devices=} not free after '
                             f'{dur_s=:.02f} ({threshold_bytes/2**30=})')

        time.sleep(5)
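The helper polls NVML rather than torch.cuda because NVML reports device-wide memory use, including the CUDA context and any other processes, while PyTorch's allocator counters only cover allocations made by the current process. A rough illustration of that difference, assuming a CUDA machine with pynvml installed (not part of this commit):

import torch
from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
                    nvmlInit)

nvmlInit()
torch.ones(1, device='cuda:0')  # force creation of a CUDA context

# Per-process allocator view: excludes the CUDA context and other processes.
print(f'torch allocated: {torch.cuda.memory_allocated(0) / 2**30:.02f} GiB')

# Device-wide view: includes the context and every other process on GPU 0.
used = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(0)).used
print(f'nvml used: {used / 2**30:.02f} GiB')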
