From 4413d8aba8eea8438f7e8c048321e578d72d1d51 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Fri, 18 Oct 2024 14:30:55 -0700
Subject: [PATCH] [MISC] Consolidate cleanup() and refactor
 offline_inference_with_prefix.py (#9510)

Signed-off-by: Tyler Michael Smith
---
 examples/offline_inference_with_prefix.py     | 19 +++++++++-----
 tests/async_engine/test_async_llm_engine.py   |  4 +--
 tests/conftest.py                             | 23 ++++----------
 tests/core/block/e2e/conftest.py              |  5 ++--
 tests/entrypoints/llm/test_encode.py          |  5 ++--
 tests/entrypoints/llm/test_generate.py        |  5 ++--
 .../llm/test_generate_multiple_loras.py       |  5 ++--
 tests/entrypoints/llm/test_guided_generate.py |  5 ++--
 tests/entrypoints/llm/test_lazy_outlines.py   |  9 +++++--
 .../offline_mode/test_offline_mode.py         |  5 ++--
 tests/lora/conftest.py                        | 26 +++++----------
 tests/lora/test_baichuan.py                   |  9 +++----
 tests/lora/test_llama.py                      |  9 +++----
 tests/lora/test_quant_model.py                |  9 +++----
 tests/metrics/test_metrics.py                 |  5 ++--
 .../vision_language/test_intern_vit.py        |  7 ++---
 .../test_disable_sliding_window.py            |  6 ++---
 tests/spec_decode/e2e/conftest.py             |  4 +--
 tests/tensorizer_loader/conftest.py           | 13 ++--------
 vllm/distributed/parallel_state.py            | 16 +++++++++++-
 20 files changed, 84 insertions(+), 105 deletions(-)

diff --git a/examples/offline_inference_with_prefix.py b/examples/offline_inference_with_prefix.py
index f8a9727ea192f..67b755a155966 100644
--- a/examples/offline_inference_with_prefix.py
+++ b/examples/offline_inference_with_prefix.py
@@ -1,4 +1,5 @@
 from vllm import LLM, SamplingParams
+from vllm.distributed import cleanup_dist_env_and_memory
 
 # NOTE: This is just a running example. For benchmarking purpose,
 # please see benchmarks/benchmark_prefix_caching.py
@@ -28,14 +29,9 @@
 # Create a sampling params object.
 sampling_params = SamplingParams(temperature=0.0)
 
-# Create an LLM.
-regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.3)
+# Create an LLM without prefix caching as a baseline.
+regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.4)
 
-# The second LLM needs to request a higher gpu_memory_utilization because
-# the first LLM has already allocated a full 30% of the gpu memory.
-prefix_cached_llm = LLM(model="facebook/opt-125m",
-                        enable_prefix_caching=True,
-                        gpu_memory_utilization=0.6)
 print("Results without `enable_prefix_caching`")
 
 # Generate texts from the prompts. The output is a list of RequestOutput objects
@@ -52,6 +48,15 @@
 
 print("-" * 80)
 
+# Destroy the LLM object and free up the GPU memory.
+del regular_llm
+cleanup_dist_env_and_memory()
+
+# Create an LLM with prefix caching enabled.
+prefix_cached_llm = LLM(model="facebook/opt-125m",
+                        enable_prefix_caching=True,
+                        gpu_memory_utilization=0.4)
+
 # Warmup so that the shared prompt's KV cache is computed.
 prefix_cached_llm.generate(generating_prompts[0], sampling_params)
 
diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py
index 1903a7582dc89..8a04693ba676d 100644
--- a/tests/async_engine/test_async_llm_engine.py
+++ b/tests/async_engine/test_async_llm_engine.py
@@ -12,11 +12,11 @@
 
 from vllm import SamplingParams
 from vllm.config import ParallelConfig
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
 from vllm.outputs import RequestOutput as RealRequestOutput
 from vllm.sampling_params import RequestOutputKind
 
-from ..conftest import cleanup
 from ..utils import wait_for_gpu_memory_to_clear
 
 
@@ -157,7 +157,7 @@ async def async_engine():
     engine.shutdown_background_loop()
     del engine
     await asyncio.sleep(0.1)
-    cleanup()
+    cleanup_dist_env_and_memory()
 
 
 @pytest.fixture()
diff --git a/tests/conftest.py b/tests/conftest.py
index ea7156c60e334..4c9180415da32 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,5 +1,3 @@
-import contextlib
-import gc
 import json
 import os
 import sys
@@ -27,8 +25,7 @@
 from vllm.assets.video import VideoAsset
 from vllm.config import TaskOption, TokenizerPoolConfig
 from vllm.connections import global_http_connection
-from vllm.distributed import (destroy_distributed_environment,
-                              destroy_model_parallel,
+from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
@@ -140,17 +137,7 @@ def dist_init():
     )
     initialize_model_parallel(1, 1)
     yield
-    cleanup()
-
-
-def cleanup():
-    destroy_model_parallel()
-    destroy_distributed_environment()
-    with contextlib.suppress(AssertionError):
-        torch.distributed.destroy_process_group()
-    gc.collect()
-    if not is_cpu():
-        torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
 
 
 @pytest.fixture()
@@ -167,7 +154,7 @@ def should_do_global_cleanup_after_test(request) -> bool:
 def cleanup_fixture(should_do_global_cleanup_after_test: bool):
     yield
     if should_do_global_cleanup_after_test:
-        cleanup()
+        cleanup_dist_env_and_memory()
 
 
 @pytest.fixture(autouse=True)
@@ -606,7 +593,7 @@ def __enter__(self):
 
     def __exit__(self, exc_type, exc_value, traceback):
         del self.model
-        cleanup()
+        cleanup_dist_env_and_memory()
 
 
 @pytest.fixture(scope="session")
@@ -861,7 +848,7 @@ def __enter__(self):
 
     def __exit__(self, exc_type, exc_value, traceback):
         del self.model
-        cleanup()
+        cleanup_dist_env_and_memory()
 
 
 @pytest.fixture(scope="session")
diff --git a/tests/core/block/e2e/conftest.py b/tests/core/block/e2e/conftest.py
index e870597b7a011..70577ec052a2c 100644
--- a/tests/core/block/e2e/conftest.py
+++ b/tests/core/block/e2e/conftest.py
@@ -3,10 +3,9 @@
 import pytest
 
 from vllm import LLM
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.model_executor.utils import set_random_seed
 
-from ....conftest import cleanup
-
 
 @pytest.fixture
 def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
@@ -37,7 +36,7 @@ def generator_inner():
 
         yield llm
         del llm
-        cleanup()
+        cleanup_dist_env_and_memory()
 
     for llm in generator_inner():
         yield llm
diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py
index 1885f2e168d80..4c9f796e5ed71 100644
--- a/tests/entrypoints/llm/test_encode.py
+++ b/tests/entrypoints/llm/test_encode.py
@@ -4,8 +4,7 @@
 import pytest
 
 from vllm import LLM, EmbeddingRequestOutput, PoolingParams
-
-from ...conftest import cleanup
+from vllm.distributed import cleanup_dist_env_and_memory
 
 MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
 
@@ -41,7 +40,7 @@ def llm():
 
     del llm
 
-    cleanup()
+    cleanup_dist_env_and_memory()
 
 
 def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py
index 5e32d7baabe4b..7d2b377752725 100644
--- a/tests/entrypoints/llm/test_generate.py
+++ b/tests/entrypoints/llm/test_generate.py
@@ -4,8 +4,7 @@
 import pytest
 
 from vllm import LLM, RequestOutput, SamplingParams
-
-from ...conftest import cleanup
+from vllm.distributed import cleanup_dist_env_and_memory
 
 MODEL_NAME = "facebook/opt-125m"
 
@@ -39,7 +38,7 @@ def llm():
 
     del llm
 
-    cleanup()
+    cleanup_dist_env_and_memory()
 
 
 def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
diff --git a/tests/entrypoints/llm/test_generate_multiple_loras.py b/tests/entrypoints/llm/test_generate_multiple_loras.py
index 9f5727ecd0406..eb2113692e7b4 100644
--- a/tests/entrypoints/llm/test_generate_multiple_loras.py
+++ b/tests/entrypoints/llm/test_generate_multiple_loras.py
@@ -5,10 +5,9 @@
 from huggingface_hub import snapshot_download
 
 from vllm import LLM
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.lora.request import LoRARequest
 
-from ...conftest import cleanup
-
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
 
 PROMPTS = [
@@ -39,7 +38,7 @@ def llm():
 
     del llm
 
-    cleanup()
+    cleanup_dist_env_and_memory()
 
 
 @pytest.fixture(scope="module")
diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py
index 2841dfc6bd9c2..67c79415f322a 100644
--- a/tests/entrypoints/llm/test_guided_generate.py
+++ b/tests/entrypoints/llm/test_guided_generate.py
@@ -5,12 +5,11 @@
 import jsonschema
 import pytest
 
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.entrypoints.llm import LLM
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
-from ...conftest import cleanup
-
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
 
 
@@ -23,7 +22,7 @@ def llm():
     with llm.deprecate_legacy_api():
         yield weakref.proxy(llm)
         del llm
-    cleanup()
+    cleanup_dist_env_and_memory()
 
 
 @pytest.mark.skip_global_cleanup
diff --git a/tests/entrypoints/llm/test_lazy_outlines.py b/tests/entrypoints/llm/test_lazy_outlines.py
index 010969ad4750d..cbfb0cc32c1ce 100644
--- a/tests/entrypoints/llm/test_lazy_outlines.py
+++ b/tests/entrypoints/llm/test_lazy_outlines.py
@@ -1,6 +1,7 @@
 import sys
 
 from vllm import LLM, SamplingParams
+from vllm.distributed import cleanup_dist_env_and_memory
 
 
 def test_lazy_outlines(sample_regex):
@@ -14,6 +15,7 @@ def test_lazy_outlines(sample_regex):
     ]
     sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 
+    # Create an LLM without guided decoding as a baseline.
     llm = LLM(model="facebook/opt-125m",
               enforce_eager=True,
               gpu_memory_utilization=0.3)
@@ -26,8 +28,11 @@ def test_lazy_outlines(sample_regex):
     # make sure outlines is not imported
     assert 'outlines' not in sys.modules
 
-    # The second LLM needs to request a higher gpu_memory_utilization because
-    # the first LLM has already allocated a full 30% of the gpu memory.
+    # Destroy the LLM object and free up the GPU memory.
+    del llm
+    cleanup_dist_env_and_memory()
+
+    # Create an LLM with guided decoding enabled.
     llm = LLM(model="facebook/opt-125m",
               enforce_eager=True,
               guided_decoding_backend="lm-format-enforcer",
diff --git a/tests/entrypoints/offline_mode/test_offline_mode.py b/tests/entrypoints/offline_mode/test_offline_mode.py
index fe40af271c1cd..c89d315b664af 100644
--- a/tests/entrypoints/offline_mode/test_offline_mode.py
+++ b/tests/entrypoints/offline_mode/test_offline_mode.py
@@ -6,8 +6,7 @@
 import pytest
 
 from vllm import LLM
-
-from ...conftest import cleanup
+from vllm.distributed import cleanup_dist_env_and_memory
 
 MODEL_NAME = "facebook/opt-125m"
 
@@ -27,7 +26,7 @@ def llm():
 
     del llm
 
-    cleanup()
+    cleanup_dist_env_and_memory()
 
 
 @pytest.mark.skip_global_cleanup
diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py
index 405c0d0efad65..e40f0dd74602e 100644
--- a/tests/lora/conftest.py
+++ b/tests/lora/conftest.py
@@ -1,20 +1,16 @@
-import contextlib
-import gc
 import tempfile
 from collections import OrderedDict
 from typing import Dict, List, TypedDict
 from unittest.mock import MagicMock, patch
 
 import pytest
-import ray
 import torch
 import torch.nn as nn
 from huggingface_hub import snapshot_download
 
 import vllm
 from vllm.config import LoRAConfig
-from vllm.distributed import (destroy_distributed_environment,
-                              destroy_model_parallel,
+from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -48,16 +44,6 @@ class ContextInfo(TypedDict):
 }]
 
 
-def cleanup():
-    destroy_model_parallel()
-    destroy_distributed_environment()
-    with contextlib.suppress(AssertionError):
-        torch.distributed.destroy_process_group()
-    gc.collect()
-    torch.cuda.empty_cache()
-    ray.shutdown()
-
-
 @pytest.fixture()
 def should_do_global_cleanup_after_test(request) -> bool:
     """Allow subdirectories to skip global cleanup by overriding this fixture.
@@ -72,7 +58,7 @@ def should_do_global_cleanup_after_test(request) -> bool:
 def cleanup_fixture(should_do_global_cleanup_after_test: bool):
     yield
     if should_do_global_cleanup_after_test:
-        cleanup()
+        cleanup_dist_env_and_memory(shutdown_ray=True)
 
 
 @pytest.fixture
@@ -87,7 +73,7 @@ def dist_init():
     )
     initialize_model_parallel(1, 1)
     yield
-    cleanup()
+    cleanup_dist_env_and_memory(shutdown_ray=True)
 
 
 @pytest.fixture
@@ -238,7 +224,7 @@ def long_context_lora_files_32k():
 def long_context_infos(long_context_lora_files_16k_1,
                        long_context_lora_files_16k_2,
                        long_context_lora_files_32k):
-    cleanup()
+    cleanup_dist_env_and_memory(shutdown_ray=True)
     infos: Dict[int, ContextInfo] = {}
     for lora_checkpoint_info in LONG_LORA_INFOS:
         lora_id = lora_checkpoint_info["lora_id"]
@@ -259,7 +245,7 @@ def long_context_infos(long_context_lora_files_16k_1,
 
 @pytest.fixture
 def llama_2_7b_engine_extra_embeddings():
-    cleanup()
+    cleanup_dist_env_and_memory(shutdown_ray=True)
     get_model_old = get_model
 
     def get_model_patched(*, model_config, device_config, **kwargs):
@@ -272,7 +258,7 @@ def get_model_patched(*, model_config, device_config, **kwargs):
     engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
     yield engine.llm_engine
     del engine
-    cleanup()
+    cleanup_dist_env_and_memory(shutdown_ray=True)
 
 
 @pytest.fixture
diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py
index cbc3668997817..0ba2ce3617b67 100644
--- a/tests/lora/test_baichuan.py
+++ b/tests/lora/test_baichuan.py
@@ -3,10 +3,9 @@
 import pytest
 
 import vllm
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.lora.request import LoRARequest
 
-from .conftest import cleanup
-
 MODEL_PATH = "baichuan-inc/Baichuan-7B"
 
 PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:"""  # noqa: E501
 
@@ -80,7 +79,7 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
     output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)
 
     del llm_tp1
-    cleanup()
+    cleanup_dist_env_and_memory()
 
     llm_tp2 = vllm.LLM(MODEL_PATH,
                        enable_lora=True,
@@ -93,7 +92,7 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
     output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)
 
     del llm_tp2
-    cleanup()
+    cleanup_dist_env_and_memory()
 
     assert output_tp1 == output_tp2
 
@@ -108,6 +107,6 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
     output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)
 
     del llm_tp4
-    cleanup()
+    cleanup_dist_env_and_memory()
 
     assert output_tp1 == output_tp4
diff --git a/tests/lora/test_llama.py b/tests/lora/test_llama.py
index ad8490353998f..e2a4f1ed0496a 100644
--- a/tests/lora/test_llama.py
+++ b/tests/lora/test_llama.py
@@ -4,10 +4,9 @@
 import ray
 
 import vllm
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.lora.request import LoRARequest
 
-from .conftest import cleanup
-
 MODEL_PATH = "meta-llama/Llama-2-7b-hf"
 
 
@@ -93,7 +92,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available):
     output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1)
 
     del llm_tp1
-    cleanup()
+    cleanup_dist_env_and_memory()
 
     llm_tp2 = vllm.LLM(MODEL_PATH,
                        enable_lora=True,
@@ -103,7 +102,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available):
     output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1)
 
     del llm_tp2
-    cleanup()
+    cleanup_dist_env_and_memory()
 
     assert output_tp1 == output_tp2
 
@@ -115,7 +114,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available):
     output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1)
 
     del llm_tp4
-    cleanup()
+    cleanup_dist_env_and_memory()
 
     assert output_tp1 == output_tp4
diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py
index 5636c96435024..d004c65929418 100644
--- a/tests/lora/test_quant_model.py
+++ b/tests/lora/test_quant_model.py
@@ -6,11 +6,10 @@
 import pytest
 
 import vllm
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.lora.request import LoRARequest
 from vllm.utils import is_hip
 
-from .conftest import cleanup
-
 
 @dataclass
 class ModelWithQuantization:
@@ -160,7 +159,7 @@ def expect_match(output, expected_output):
     print("removing lora")
 
     del llm
-    cleanup()
+    cleanup_dist_env_and_memory()
 
 
 @pytest.mark.parametrize("model", MODELS)
@@ -181,7 +180,7 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
     output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)
 
     del llm_tp1
-    cleanup()
+    cleanup_dist_env_and_memory()
 
     llm_tp2 = vllm.LLM(
         model=model.model_path,
@@ -194,6 +193,6 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
     output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)
 
     del llm_tp2
-    cleanup()
+    cleanup_dist_env_and_memory()
 
     assert output_tp1 == output_tp2
diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py
index 8798ff078843a..92e6086e312f7 100644
--- a/tests/metrics/test_metrics.py
+++ b/tests/metrics/test_metrics.py
@@ -6,13 +6,12 @@
 from prometheus_client import REGISTRY
 
 from vllm import EngineArgs, LLMEngine
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.metrics import RayPrometheusStatLogger
 from vllm.sampling_params import SamplingParams
 
-from ..conftest import cleanup
-
 MODELS = [
     "facebook/opt-125m",
 ]
@@ -307,7 +306,7 @@ def test_metric_spec_decode_interval(
 
     finally:
         del engine
-        cleanup()
+        cleanup_dist_env_and_memory()
 
 
 def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
diff --git a/tests/models/decoder_only/vision_language/test_intern_vit.py b/tests/models/decoder_only/vision_language/test_intern_vit.py
index 3c3b95b38baac..98f313eb9b9af 100644
--- a/tests/models/decoder_only/vision_language/test_intern_vit.py
+++ b/tests/models/decoder_only/vision_language/test_intern_vit.py
@@ -6,7 +6,7 @@
 from huggingface_hub import snapshot_download
 from transformers import AutoConfig, AutoModel, CLIPImageProcessor
 
-from ....conftest import _ImageAssets, cleanup
+from ....conftest import _ImageAssets
 
 # we use snapshot_download to prevent conflicts between
 # dynamic_module and trust_remote_code for hf_runner
@@ -45,12 +45,13 @@ def run_intern_vit_test(
         for pixel_value in pixel_values
     ]
 
+    from vllm.distributed import cleanup_dist_env_and_memory
     from vllm.model_executor.models.intern_vit import InternVisionModel
     vllm_model = InternVisionModel(config)
     vllm_model.load_weights(hf_model.state_dict().items())
 
     del hf_model
-    cleanup()
+    cleanup_dist_env_and_memory()
 
     vllm_model = vllm_model.to("cuda", dtype)
     vllm_outputs_per_image = [
@@ -58,7 +59,7 @@ def run_intern_vit_test(
         for pixel_value in pixel_values
     ]
     del vllm_model
-    cleanup()
+    cleanup_dist_env_and_memory()
 
     cos_similar = nn.CosineSimilarity(dim=-1)
     for vllm_output, hf_output in zip(vllm_outputs_per_image,
diff --git a/tests/prefix_caching/test_disable_sliding_window.py b/tests/prefix_caching/test_disable_sliding_window.py
index eeac6ab43c05f..5a28943b7ecbc 100644
--- a/tests/prefix_caching/test_disable_sliding_window.py
+++ b/tests/prefix_caching/test_disable_sliding_window.py
@@ -4,8 +4,8 @@
 """
 import pytest
 
-from tests.conftest import cleanup
 from vllm import LLM
+from vllm.distributed import cleanup_dist_env_and_memory
 
 MODEL_LEN_LEN = [
     # Example models with sliding window.
@@ -31,7 +31,7 @@ def test_disable_sliding_window(model_len_len, ):
                          model_config.max_model_len)
 
     del vllm_disabled_model
-    cleanup()
+    cleanup_dist_env_and_memory()
 
     vllm_enabled_model = LLM(model, disable_sliding_window=False)
     vllm_enabled_model.generate("Hi my name is")
@@ -41,4 +41,4 @@ def test_disable_sliding_window(model_len_len, ):
                          model_config.max_model_len)
 
     del vllm_enabled_model
-    cleanup()
+    cleanup_dist_env_and_memory()
diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py
index b450ef97c89d4..b9cb3858c0068 100644
--- a/tests/spec_decode/e2e/conftest.py
+++ b/tests/spec_decode/e2e/conftest.py
@@ -4,10 +4,10 @@
 import pytest
 
 from vllm import LLM, SamplingParams
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import PromptLogprobs, SampleLogprobs
 
-from ...conftest import cleanup
 from ...models.utils import (TokensTextLogprobs,
                              TokensTextLogprobsPromptLogprobs,
                              check_logprobs_close, check_outputs_equal)
@@ -44,7 +44,7 @@ def generate():
         yield llm
 
         del llm
-        cleanup()
+        cleanup_dist_env_and_memory()
 
     return generate
diff --git a/tests/tensorizer_loader/conftest.py b/tests/tensorizer_loader/conftest.py
index 07b9c6b3c6be6..2a45653622448 100644
--- a/tests/tensorizer_loader/conftest.py
+++ b/tests/tensorizer_loader/conftest.py
@@ -1,27 +1,18 @@
-import contextlib
 import functools
 import gc
 from typing import Callable, TypeVar
 
 import pytest
-import ray
 import torch
 from typing_extensions import ParamSpec
 
-from vllm.distributed import (destroy_distributed_environment,
-                              destroy_model_parallel)
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 
 
 @pytest.fixture(autouse=True)
 def cleanup():
-    destroy_model_parallel()
-    destroy_distributed_environment()
-    with contextlib.suppress(AssertionError):
-        torch.distributed.destroy_process_group()
-    ray.shutdown()
-    gc.collect()
-    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory(shutdown_ray=True)
 
 
 _P = ParamSpec("_P")
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 6e1970bfed98a..8d4b673d2e6e4 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -20,6 +20,7 @@
 steps.
 """
 import contextlib
+import gc
 import pickle
 import weakref
 from collections import namedtuple
@@ -36,7 +37,7 @@
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import supports_custom_op
+from vllm.utils import is_cpu, supports_custom_op
 
 
 @dataclass
@@ -1129,6 +1130,19 @@ def destroy_distributed_environment():
         torch.distributed.destroy_process_group()
 
 
+def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
+    destroy_model_parallel()
+    destroy_distributed_environment()
+    with contextlib.suppress(AssertionError):
+        torch.distributed.destroy_process_group()
+    if shutdown_ray:
+        import ray  # Lazy import Ray
+        ray.shutdown()
+    gc.collect()
+    if not is_cpu():
+        torch.cuda.empty_cache()
+
+
 def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
     """
     This is a collective operation that returns if each rank is in the same node
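
Usage sketch for offline scripts after this change (mirrors the refactored
examples/offline_inference_with_prefix.py above; the prompt and the
gpu_memory_utilization values are illustrative, not prescribed by the patch):

    from vllm import LLM, SamplingParams
    from vllm.distributed import cleanup_dist_env_and_memory

    prompts = ["Hello, my name is"]
    sampling_params = SamplingParams(temperature=0.0)

    # First engine: plain decoding as a baseline.
    regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.4)
    regular_outputs = regular_llm.generate(prompts, sampling_params)

    # Tear the first engine down before creating the second one, so both can
    # request the same memory fraction instead of splitting the GPU budget
    # (the old example used 0.3 and 0.6 because it kept both engines alive).
    del regular_llm
    cleanup_dist_env_and_memory()

    # Second engine: prefix caching enabled, reusing the freed GPU memory.
    prefix_cached_llm = LLM(model="facebook/opt-125m",
                            enable_prefix_caching=True,
                            gpu_memory_utilization=0.4)
    cached_outputs = prefix_cached_llm.generate(prompts, sampling_params)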
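
For Ray-backed test suites, the same helper replaces the hand-rolled cleanup
functions; a minimal fixture sketch following the tests/tensorizer_loader
change above:

    import pytest

    from vllm.distributed import cleanup_dist_env_and_memory

    @pytest.fixture(autouse=True)
    def cleanup():
        # shutdown_ray=True makes the helper lazily import ray and call
        # ray.shutdown() in addition to destroying model parallelism and the
        # distributed environment, running gc.collect(), and emptying the
        # CUDA cache on GPU platforms.
        cleanup_dist_env_and_memory(shutdown_ray=True)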