Revert "[V1][Core] Fix memory issue with logits & sampling" #14504

Merged · 3 commits · Mar 9, 2025
11 changes: 1 addition & 10 deletions tests/basic_correctness/test_cumem.py
@@ -142,16 +142,7 @@ def test_end_to_end(model: str, use_v1: bool):
     used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
     # now the memory usage is mostly cudagraph memory pool,
     # and it should be less than the model weights (1B model, 2GiB weights)
-
-    # NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
-    # is captured but cannot be releasesd from PyTorch due to a known bug,
-    # therefore high memory usage after `llm.sleep` is called is expected.
-    # FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode
-    # in V1.
-    if use_v1:
-        assert used_bytes < 7 * GiB_bytes
-    else:
-        assert used_bytes < 2 * GiB_bytes
+    assert used_bytes < 2 * GiB_bytes

     llm.wake_up()
     output2 = llm.generate(prompt, sampling_params)
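
For reference, the assertion restored above is the end-to-end sleep-mode check: it measures how much GPU memory stays in use after `llm.sleep()` and expects it to remain well below the size of the model weights. A minimal sketch of that measurement, assuming the vLLM sleep-mode API exercised by this test (`enable_sleep_mode=True`, `LLM.sleep()`, `LLM.wake_up()`); the model name and sleep level are illustrative assumptions, not taken from this PR:

# Sketch only: mirrors the measurement in test_end_to_end; model name and
# sleep level are illustrative assumptions, not code from this PR.
import torch
from vllm import LLM, SamplingParams

GiB_bytes = 1 << 30

free_baseline, total = torch.cuda.mem_get_info()
used_bytes_baseline = total - free_baseline  # memory held by other processes

llm = LLM("facebook/opt-125m", enable_sleep_mode=True)  # assumed small model
prompt = "How are you?"
sampling_params = SamplingParams(temperature=0, max_tokens=10)
output = llm.generate(prompt, sampling_params)

llm.sleep(level=1)  # level 1: offload weights to CPU, discard KV cache

free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info()
used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
# With this revert, what remains is expected to be mostly the cudagraph
# memory pool, so the same 2 GiB bound is used for V0 and V1.
assert used_bytes < 2 * GiB_bytes

llm.wake_up()
output2 = llm.generate(prompt, sampling_params)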
68 changes: 29 additions & 39 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1238,43 +1238,6 @@ def _dummy_run(
             )
         return hidden_states

-    @torch.inference_mode()
-    def _dummy_sampler_run(
-        self,
-        hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
-
-        logits = self.model.compute_logits(hidden_states, None)
-        num_reqs = logits.size(0)
-
-        dummy_tensors = lambda v: torch.full(
-            (num_reqs, ), v, device=self.device)
-
-        dummy_metadata = SamplingMetadata(
-            temperature=dummy_tensors(0.5),
-            all_greedy=False,
-            all_random=False,
-            top_p=dummy_tensors(0.9),
-            top_k=dummy_tensors(logits.size(1) - 1),
-            min_p=None,
-            generators={},
-            max_num_logprobs=None,
-            no_penalties=True,
-            prompt_token_ids=None,
-            frequency_penalties=dummy_tensors(0.1),
-            presence_penalties=dummy_tensors(0.1),
-            repetition_penalties=dummy_tensors(0.1),
-            output_token_ids=[[] for _ in range(num_reqs)],
-            min_tokens={},
-            logit_bias=[None for _ in range(num_reqs)],
-            allowed_token_ids_mask=None,
-            bad_words_token_ids={},
-        )
-        sampler_output = self.model.sample(logits=logits,
-                                           sampling_metadata=dummy_metadata)
-
-        return sampler_output
-
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
         # TODO: handle encoder-decoder models once we support them.
@@ -1390,11 +1353,38 @@ def profile_run(self) -> None:
         hidden_states = self._dummy_run(self.max_num_tokens)
         if get_pp_group().is_last_rank:
             hidden_states = hidden_states[logit_indices]
-            sampler_output = self._dummy_sampler_run(hidden_states)
+            logits = self.model.compute_logits(hidden_states, None)
+            dummy_tensors = lambda v: torch.full(
+                (num_reqs, ), v, device=self.device)
+            dummy_metadata = SamplingMetadata(
+                temperature=dummy_tensors(0.5),
+                all_greedy=False,
+                all_random=False,
+                top_p=dummy_tensors(0.9),
+                top_k=dummy_tensors(logits.size(1) - 1),
+                min_p=None,
+                generators={},
+                max_num_logprobs=None,
+                no_penalties=True,
+                prompt_token_ids=torch.ones_like(logits,
+                                                 dtype=torch.int64),
+                frequency_penalties=dummy_tensors(0.1),
+                presence_penalties=dummy_tensors(0.1),
+                repetition_penalties=dummy_tensors(0.1),
+                output_token_ids=[[] for _ in range(num_reqs)],
+                min_tokens={},
+                logit_bias=[None for _ in range(num_reqs)],
+                allowed_token_ids_mask=None,
+                bad_words_token_ids={},
+            )
+            sampler_output = self.model.sample(
+                logits=logits, sampling_metadata=dummy_metadata)
         else:
+            logits = None
             sampler_output = None
+            dummy_metadata = None
         torch.cuda.synchronize()
-        del hidden_states, sampler_output
+        del hidden_states, logits, sampler_output, dummy_metadata
         self.encoder_cache.clear()
         gc.collect()

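With the dummy sampler run inlined above, the logits and sampling tensors exist only for the duration of `profile_run` and are explicitly dropped at the end (`del ...`, `gc.collect()`), instead of surviving as a preallocated buffer. A plain-PyTorch sketch of why that matters for memory accounting, with illustrative sizes that are assumptions, not values from this PR:

# Sketch only (plain PyTorch, not vLLM code): a logits-sized buffer created
# and deleted during profiling does not stay resident, whereas a buffer kept
# alive as an attribute would.
import gc
import torch

device = torch.device("cuda")
max_num_reqs, vocab_size = 1024, 128 * 1024  # illustrative sizes

before = torch.cuda.memory_allocated(device)

# Temporary "logits" buffer, comparable to max_num_reqs x vocab_size.
logits = torch.empty(max_num_reqs, vocab_size, dtype=torch.float32, device=device)
print("during profiling:", torch.cuda.memory_allocated(device) - before)  # ~512 MiB

# Dropping the reference (as profile_run now does via `del ... logits ...` and
# gc.collect()) lets PyTorch's caching allocator reuse the block; calling
# torch.cuda.empty_cache() afterwards returns the cached block to the driver.
del logits
gc.collect()
torch.cuda.empty_cache()
print("after cleanup:", torch.cuda.memory_allocated(device) - before)  # back to ~0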
21 changes: 0 additions & 21 deletions vllm/v1/worker/gpu_worker.py
@@ -119,8 +119,6 @@ def init_device(self):
         self.model_runner: GPUModelRunner = GPUModelRunner(
             self.vllm_config, self.device)

-    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
-    # to hijack tensor allocation.
     def load_model(self) -> None:
         if self.vllm_config.model_config.enable_sleep_mode:
             allocator = CuMemAllocator.get_instance()
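
The FIXME removed here pointed at the current sleep-mode mechanism: tensor allocation is hijacked through CuMemAllocator's memory pool rather than through a `TorchDispatchMode`. A rough sketch of that pool-based flow, assuming `use_memory_pool(tag=...)`, `sleep(offload_tags=...)`, and `wake_up()` methods alongside the `get_instance()` call shown above; treat the import path and exact signatures as assumptions rather than code from this diff:

# Sketch only: method names beyond get_instance() are assumptions based on
# the surrounding sleep-mode code path, not verbatim from this PR.
import torch
from vllm.device_allocator.cumem import CuMemAllocator

allocator = CuMemAllocator.get_instance()

# Tensors created inside this context come from the allocator's pool and are
# tagged, so a later sleep() can offload or free them by tag.
with allocator.use_memory_pool(tag="weights"):
    weights = torch.empty(1024, 1024, device="cuda")  # stand-in for model weights

allocator.sleep(offload_tags=("weights", ))  # offload tagged memory, free the rest
allocator.wake_up()                          # restore it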
@@ -213,25 +211,6 @@ def compile_or_warm_up_model(self) -> None:
             self.model_runner._dummy_run(size)
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()
-
-        # Warm up sampler and preallocate memory buffer for logits and other
-        # sampling related tensors of max possible shape to avoid memory
-        # fragmentation issue.
-        # NOTE: This is called after `capture_model` on purpose to prevent
-        # memory buffers from being cleared by `torch.cuda.empty_cache`.
-        try:
-            self.model_runner._dummy_sampler_run(
-                hidden_states=self.model_runner._dummy_run(
-                    num_tokens=self.scheduler_config.max_num_seqs))
-        except RuntimeError as e:
-            if 'out of memory' in str(e):
-                raise RuntimeError(
-                    "CUDA out of memory occurred when warming up sampler. "
-                    "Please try lowering `gpu_memory_utilization` when "
-                    "initializing the engine.") from None
-            else:
-                raise e
-
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)
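
For completeness, the warmup removed above wrapped the dummy sampler run in an out-of-memory guard that re-raised with a hint to lower `gpu_memory_utilization`. A generic sketch of that guard, with a hypothetical `warmup` callable standing in for the removed `_dummy_sampler_run` call (CUDA OOM surfaces as a `RuntimeError` whose message contains 'out of memory'):

# Sketch only: generic version of the OOM-translation pattern removed along
# with the sampler warmup; `warmup` is a hypothetical callable.
import torch


def run_warmup_with_oom_hint(warmup) -> None:
    try:
        warmup()
    except RuntimeError as e:
        if "out of memory" in str(e):
            raise RuntimeError(
                "CUDA out of memory occurred during warmup. "
                "Please try lowering `gpu_memory_utilization` when "
                "initializing the engine.") from None
        raise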