Skip to content

Commit

Permalink
[MISC] Consolidate cleanup() and refactor offline_inference_with_pref…
Browse files Browse the repository at this point in the history
…ix.py (vllm-project#9510)

Signed-off-by: Vinay Damodaran <vrdn@hey.com>
  • Loading branch information
comaniac authored and vrdn-23 committed Oct 23, 2024
1 parent 618735f commit b96a4a1
Show file tree
Hide file tree
Showing 20 changed files with 84 additions and 105 deletions.
19 changes: 12 additions & 7 deletions examples/offline_inference_with_prefix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory

# NOTE: This is just a running example. For benchmarking purpose,
# please see benchmarks/benchmark_prefix_caching.py
Expand Down Expand Up @@ -28,14 +29,9 @@
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0)

# Create an LLM.
regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.3)
# Create an LLM without prefix caching as a baseline.
regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.4)

# The second LLM needs to request a higher gpu_memory_utilization because
# the first LLM has already allocated a full 30% of the gpu memory.
prefix_cached_llm = LLM(model="facebook/opt-125m",
enable_prefix_caching=True,
gpu_memory_utilization=0.6)
print("Results without `enable_prefix_caching`")

# Generate texts from the prompts. The output is a list of RequestOutput objects
Expand All @@ -52,6 +48,15 @@

print("-" * 80)

# Destroy the LLM object and free up the GPU memory.
del regular_llm
cleanup_dist_env_and_memory()

# Create an LLM with prefix caching enabled.
prefix_cached_llm = LLM(model="facebook/opt-125m",
enable_prefix_caching=True,
gpu_memory_utilization=0.4)

# Warmup so that the shared prompt's KV cache is computed.
prefix_cached_llm.generate(generating_prompts[0], sampling_params)

Expand Down
4 changes: 2 additions & 2 deletions tests/async_engine/test_async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@

from vllm import SamplingParams
from vllm.config import ParallelConfig
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
from vllm.outputs import RequestOutput as RealRequestOutput
from vllm.sampling_params import RequestOutputKind

from ..conftest import cleanup
from ..utils import wait_for_gpu_memory_to_clear


Expand Down Expand Up @@ -157,7 +157,7 @@ async def async_engine():
engine.shutdown_background_loop()
del engine
await asyncio.sleep(0.1)
cleanup()
cleanup_dist_env_and_memory()


@pytest.fixture()
Expand Down
23 changes: 5 additions & 18 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import contextlib
import gc
import json
import os
import sys
Expand Down Expand Up @@ -27,8 +25,7 @@
from vllm.assets.video import VideoAsset
from vllm.config import TaskOption, TokenizerPoolConfig
from vllm.connections import global_http_connection
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel,
from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment,
initialize_model_parallel)
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
Expand Down Expand Up @@ -140,17 +137,7 @@ def dist_init():
)
initialize_model_parallel(1, 1)
yield
cleanup()


def cleanup():
destroy_model_parallel()
destroy_distributed_environment()
with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group()
gc.collect()
if not is_cpu():
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()


@pytest.fixture()
Expand All @@ -167,7 +154,7 @@ def should_do_global_cleanup_after_test(request) -> bool:
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
yield
if should_do_global_cleanup_after_test:
cleanup()
cleanup_dist_env_and_memory()


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -606,7 +593,7 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, traceback):
del self.model
cleanup()
cleanup_dist_env_and_memory()


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -861,7 +848,7 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, traceback):
del self.model
cleanup()
cleanup_dist_env_and_memory()


@pytest.fixture(scope="session")
Expand Down
5 changes: 2 additions & 3 deletions tests/core/block/e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import pytest

from vllm import LLM
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.utils import set_random_seed

from ....conftest import cleanup


@pytest.fixture
def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
Expand Down Expand Up @@ -37,7 +36,7 @@ def generator_inner():

yield llm
del llm
cleanup()
cleanup_dist_env_and_memory()

for llm in generator_inner():
yield llm
Expand Down
5 changes: 2 additions & 3 deletions tests/entrypoints/llm/test_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import pytest

from vllm import LLM, EmbeddingRequestOutput, PoolingParams

from ...conftest import cleanup
from vllm.distributed import cleanup_dist_env_and_memory

MODEL_NAME = "intfloat/e5-mistral-7b-instruct"

Expand Down Expand Up @@ -41,7 +40,7 @@ def llm():

del llm

cleanup()
cleanup_dist_env_and_memory()


def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
Expand Down
5 changes: 2 additions & 3 deletions tests/entrypoints/llm/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import pytest

from vllm import LLM, RequestOutput, SamplingParams

from ...conftest import cleanup
from vllm.distributed import cleanup_dist_env_and_memory

MODEL_NAME = "facebook/opt-125m"

Expand Down Expand Up @@ -39,7 +38,7 @@ def llm():

del llm

cleanup()
cleanup_dist_env_and_memory()


def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
Expand Down
5 changes: 2 additions & 3 deletions tests/entrypoints/llm/test_generate_multiple_loras.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
from huggingface_hub import snapshot_download

from vllm import LLM
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest

from ...conftest import cleanup

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

PROMPTS = [
Expand Down Expand Up @@ -39,7 +38,7 @@ def llm():

del llm

cleanup()
cleanup_dist_env_and_memory()


@pytest.fixture(scope="module")
Expand Down
5 changes: 2 additions & 3 deletions tests/entrypoints/llm/test_guided_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
import jsonschema
import pytest

from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

from ...conftest import cleanup

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"


Expand All @@ -23,7 +22,7 @@ def llm():
with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
del llm
cleanup()
cleanup_dist_env_and_memory()


@pytest.mark.skip_global_cleanup
Expand Down
9 changes: 7 additions & 2 deletions tests/entrypoints/llm/test_lazy_outlines.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory


def test_lazy_outlines(sample_regex):
Expand All @@ -14,6 +15,7 @@ def test_lazy_outlines(sample_regex):
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM without guided decoding as a baseline.
llm = LLM(model="facebook/opt-125m",
enforce_eager=True,
gpu_memory_utilization=0.3)
Expand All @@ -26,8 +28,11 @@ def test_lazy_outlines(sample_regex):
# make sure outlines is not imported
assert 'outlines' not in sys.modules

# The second LLM needs to request a higher gpu_memory_utilization because
# the first LLM has already allocated a full 30% of the gpu memory.
# Destroy the LLM object and free up the GPU memory.
del llm
cleanup_dist_env_and_memory()

# Create an LLM with guided decoding enabled.
llm = LLM(model="facebook/opt-125m",
enforce_eager=True,
guided_decoding_backend="lm-format-enforcer",
Expand Down
5 changes: 2 additions & 3 deletions tests/entrypoints/offline_mode/test_offline_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
import pytest

from vllm import LLM

from ...conftest import cleanup
from vllm.distributed import cleanup_dist_env_and_memory

MODEL_NAME = "facebook/opt-125m"

Expand All @@ -27,7 +26,7 @@ def llm():

del llm

cleanup()
cleanup_dist_env_and_memory()


@pytest.mark.skip_global_cleanup
Expand Down
26 changes: 6 additions & 20 deletions tests/lora/conftest.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
import contextlib
import gc
import tempfile
from collections import OrderedDict
from typing import Dict, List, TypedDict
from unittest.mock import MagicMock, patch

import pytest
import ray
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download

import vllm
from vllm.config import LoRAConfig
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel,
from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment,
initialize_model_parallel)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
Expand Down Expand Up @@ -48,16 +44,6 @@ class ContextInfo(TypedDict):
}]


def cleanup():
destroy_model_parallel()
destroy_distributed_environment()
with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group()
gc.collect()
torch.cuda.empty_cache()
ray.shutdown()


@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
"""Allow subdirectories to skip global cleanup by overriding this fixture.
Expand All @@ -72,7 +58,7 @@ def should_do_global_cleanup_after_test(request) -> bool:
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
yield
if should_do_global_cleanup_after_test:
cleanup()
cleanup_dist_env_and_memory(shutdown_ray=True)


@pytest.fixture
Expand All @@ -87,7 +73,7 @@ def dist_init():
)
initialize_model_parallel(1, 1)
yield
cleanup()
cleanup_dist_env_and_memory(shutdown_ray=True)


@pytest.fixture
Expand Down Expand Up @@ -238,7 +224,7 @@ def long_context_lora_files_32k():
def long_context_infos(long_context_lora_files_16k_1,
long_context_lora_files_16k_2,
long_context_lora_files_32k):
cleanup()
cleanup_dist_env_and_memory(shutdown_ray=True)
infos: Dict[int, ContextInfo] = {}
for lora_checkpoint_info in LONG_LORA_INFOS:
lora_id = lora_checkpoint_info["lora_id"]
Expand All @@ -259,7 +245,7 @@ def long_context_infos(long_context_lora_files_16k_1,

@pytest.fixture
def llama_2_7b_engine_extra_embeddings():
cleanup()
cleanup_dist_env_and_memory(shutdown_ray=True)
get_model_old = get_model

def get_model_patched(*, model_config, device_config, **kwargs):
Expand All @@ -272,7 +258,7 @@ def get_model_patched(*, model_config, device_config, **kwargs):
engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
yield engine.llm_engine
del engine
cleanup()
cleanup_dist_env_and_memory(shutdown_ray=True)


@pytest.fixture
Expand Down
9 changes: 4 additions & 5 deletions tests/lora/test_baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import pytest

import vllm
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest

from .conftest import cleanup

MODEL_PATH = "baichuan-inc/Baichuan-7B"

PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
Expand Down Expand Up @@ -80,7 +79,7 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)

del llm_tp1
cleanup()
cleanup_dist_env_and_memory()

llm_tp2 = vllm.LLM(MODEL_PATH,
enable_lora=True,
Expand All @@ -93,7 +92,7 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)

del llm_tp2
cleanup()
cleanup_dist_env_and_memory()

assert output_tp1 == output_tp2

Expand All @@ -108,6 +107,6 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)

del llm_tp4
cleanup()
cleanup_dist_env_and_memory()

assert output_tp1 == output_tp4
Loading

0 comments on commit b96a4a1

Please sign in to comment.