
Commit eac1b9f

ywang96 authored and lulmer committed

[V1][Core] Fix memory issue with logits & sampling (vllm-project#13776)

Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>

1 parent baad27f commit eac1b9f

File tree

3 files changed: +69 -29 lines changed

tests/basic_correctness/test_cumem.py

Lines changed: 10 additions & 1 deletion
@@ -142,7 +142,16 @@ def test_end_to_end(model: str, use_v1: bool):
     used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
     # now the memory usage is mostly cudagraph memory pool,
     # and it should be less than the model weights (1B model, 2GiB weights)
-    assert used_bytes < 2 * GiB_bytes
+
+    # NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
+    # is captured but cannot be released from PyTorch due to a known bug,
+    # therefore high memory usage after `llm.sleep` is called is expected.
+    # FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode
+    # in V1.
+    if use_v1:
+        assert used_bytes < 7 * GiB_bytes
+    else:
+        assert used_bytes < 2 * GiB_bytes
 
     llm.wake_up()
     output2 = llm.generate(prompt, sampling_params)
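For context on the new 7 GiB bound: the test measures residency with `torch.cuda.mem_get_info`, which returns free and total device bytes. A standalone sketch of that accounting pattern (the helper name `used_gpu_bytes` is illustrative, not part of the diff):

import torch

GiB_bytes = 1 << 30  # same 2**30 value as the constant the test uses

def used_gpu_bytes(baseline: int = 0) -> int:
    # torch.cuda.mem_get_info() returns (free, total) for the current device;
    # subtracting a pre-recorded baseline isolates this process's usage.
    free, total = torch.cuda.mem_get_info()
    return total - free - baseline

# Usage mirroring the test: record a baseline before the engine starts, then
# bound what is still resident after llm.sleep(level=1).
#   baseline = used_gpu_bytes()
#   ... LLM(model, enable_sleep_mode=True), generate, llm.sleep(level=1) ...
#   assert used_gpu_bytes(baseline) < 7 * GiB_bytes  # V1 keeps the logits buffer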

vllm/v1/worker/gpu_model_runner.py

Lines changed: 38 additions & 28 deletions
@@ -1238,6 +1238,42 @@ def _dummy_run(
         )
         return hidden_states
 
+    @torch.inference_mode()
+    def _dummy_sampler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+
+        logits = self.model.compute_logits(hidden_states, None)
+        num_reqs = logits.size(0)
+
+        dummy_tensors = lambda v: torch.full(
+            (num_reqs, ), v, device=self.device)
+
+        dummy_metadata = SamplingMetadata(
+            temperature=dummy_tensors(0.5),
+            all_greedy=False,
+            all_random=False,
+            top_p=dummy_tensors(0.9),
+            top_k=dummy_tensors(logits.size(1) - 1),
+            min_p=None,
+            generators={},
+            max_num_logprobs=None,
+            no_penalties=True,
+            prompt_token_ids=None,
+            frequency_penalties=dummy_tensors(0.1),
+            presence_penalties=dummy_tensors(0.1),
+            repetition_penalties=dummy_tensors(0.1),
+            output_token_ids=[[] for _ in range(num_reqs)],
+            min_tokens={},
+            logit_bias=[None for _ in range(num_reqs)],
+            allowed_token_ids_mask=None,
+        )
+        sampler_output = self.model.sample(logits=logits,
+                                           sampling_metadata=dummy_metadata)
+
+        return sampler_output
+
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
         # TODO: handle encoder-decoder models once we support them.

@@ -1353,37 +1389,11 @@ def profile_run(self) -> None:
         hidden_states = self._dummy_run(self.max_num_tokens)
         if get_pp_group().is_last_rank:
             hidden_states = hidden_states[logit_indices]
-            logits = self.model.compute_logits(hidden_states, None)
-            dummy_tensors = lambda v: torch.full(
-                (num_reqs, ), v, device=self.device)
-            dummy_metadata = SamplingMetadata(
-                temperature=dummy_tensors(0.5),
-                all_greedy=False,
-                all_random=False,
-                top_p=dummy_tensors(0.9),
-                top_k=dummy_tensors(logits.size(1) - 1),
-                min_p=None,
-                generators={},
-                max_num_logprobs=None,
-                no_penalties=True,
-                prompt_token_ids=torch.ones_like(logits,
-                                                 dtype=torch.int64),
-                frequency_penalties=dummy_tensors(0.1),
-                presence_penalties=dummy_tensors(0.1),
-                repetition_penalties=dummy_tensors(0.1),
-                output_token_ids=[[] for _ in range(num_reqs)],
-                min_tokens={},
-                logit_bias=[None for _ in range(num_reqs)],
-                allowed_token_ids_mask=None,
-            )
-            sampler_output = self.model.sample(
-                logits=logits, sampling_metadata=dummy_metadata)
+            sampler_output = self._dummy_sampler_run(hidden_states)
         else:
-            logits = None
             sampler_output = None
-            dummy_metadata = None
         torch.cuda.synchronize()
-        del hidden_states, logits, sampler_output, dummy_metadata
+        del hidden_states, sampler_output
         self.encoder_cache.clear()
         gc.collect()
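Why warming the sampler at the maximum batch shape helps: PyTorch's caching allocator keeps freed blocks and serves later, equal-or-smaller requests out of them, so reserving the worst-case logits buffer once prevents later variable-sized allocations from fragmenting free memory. A simplified, self-contained sketch of that behavior (shapes are illustrative, not vLLM's actual sizes; assumes a CUDA device):

import torch

assert torch.cuda.is_available()  # sketch assumes a CUDA device

max_num_reqs, vocab_size = 1024, 128 * 1024  # illustrative worst-case shape

# Warmup: allocate the worst-case logits buffer once, then drop the reference.
warmup_logits = torch.empty(max_num_reqs, vocab_size, device="cuda")
del warmup_logits

# The freed block stays cached inside the allocator (memory_reserved stays
# high) rather than being returned to the driver.
print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())

# A later, smaller per-step logits allocation is carved out of that cached
# block instead of requesting fresh, potentially fragmented memory.
step_logits = torch.empty(64, vocab_size, device="cuda")
print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())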
vllm/v1/worker/gpu_worker.py

Lines changed: 21 additions & 0 deletions
@@ -119,6 +119,8 @@ def init_device(self):
         self.model_runner: GPUModelRunner = GPUModelRunner(
             self.vllm_config, self.device)
 
+    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
+    # to hijack tensor allocation.
     def load_model(self) -> None:
         if self.vllm_config.model_config.enable_sleep_mode:
             allocator = CuMemAllocator.get_instance()

@@ -211,6 +213,25 @@ def compile_or_warm_up_model(self) -> None:
             self.model_runner._dummy_run(size)
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()
+
+        # Warm up sampler and preallocate memory buffer for logits and other
+        # sampling related tensors of max possible shape to avoid memory
+        # fragmentation issue.
+        # NOTE: This is called after `capture_model` on purpose to prevent
+        # memory buffers from being cleared by `torch.cuda.empty_cache`.
+        try:
+            self.model_runner._dummy_sampler_run(
+                hidden_states=self.model_runner._dummy_run(
+                    num_tokens=self.scheduler_config.max_num_seqs))
+        except RuntimeError as e:
+            if 'out of memory' in str(e):
+                raise RuntimeError(
+                    "CUDA out of memory occurred when warming up sampler. "
+                    "Please try lowering `gpu_memory_utilization` when "
+                    "initializing the engine.") from None
+            else:
+                raise e
+
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)
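The NOTE about ordering is the other half of the fix: `torch.cuda.empty_cache()` returns cached-but-unreferenced blocks to the driver, so if the sampler warmup ran before `capture_model` (whose capture path clears the cache, as the NOTE implies), the preallocated buffers would be released again. A minimal sketch of that effect, assuming a CUDA device:

import torch

assert torch.cuda.is_available()  # sketch assumes a CUDA device

buf = torch.empty(1024, 128 * 1024, device="cuda")  # stand-in warmup buffer
del buf

# Freed, but still held by PyTorch's caching allocator for reuse.
print("reserved before empty_cache:", torch.cuda.memory_reserved())

# empty_cache() hands the unreferenced block back to the CUDA driver, so a
# warmup performed before this point would lose its preallocated space.
torch.cuda.empty_cache()
print("reserved after empty_cache: ", torch.cuda.memory_reserved())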
