Commit 300acb8

[Core][Bugfix] Use correct device to initialize GPU data during CUDA-graph-capture (#11233)
Signed-off-by: Yan Burman <yanburman@users.noreply.github.com>
Signed-off-by: Ido Asraff <idoa@atero.ai>
1 parent d91457d commit 300acb8
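
In short: a bare `.cuda()` call places a tensor on the *current* CUDA device, which in a multi-GPU worker is not necessarily the device that worker owns. A minimal sketch of the bug class this commit fixes (illustrative only; the worker/device assignment below is hypothetical, not vLLM code):

import torch

# Device this (hypothetical) worker was assigned -- not necessarily cuda:0.
worker_device = torch.device("cuda:1")

# Bug: .cuda() with no argument uses the current device (cuda:0 unless
# torch.cuda.set_device was called), so the buffer lands on the wrong GPU.
buf_wrong = torch.zeros(8, dtype=torch.long).cuda()

# Fix: pass the device explicitly, as this commit does throughout.
buf_right = torch.zeros(8, dtype=torch.long, device=worker_device)

print(buf_wrong.device)  # cuda:0
print(buf_right.device)  # cuda:1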

File tree

5 files changed: +23 −15 lines changed

tests/distributed/test_custom_all_reduce.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):

     for sz in test_sizes:
         for dtype in [torch.float32, torch.float16, torch.bfloat16]:
-            with graph_capture() as graph_capture_context:
+            with graph_capture(device=device) as graph_capture_context:
                 # use integers so result matches NCCL exactly
                 inp1 = torch.randint(1,
                                      16, (sz, ),

tests/distributed/test_pynccl.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def multiple_allreduce_with_vllm_worker_fn():
     device = torch.device(f"cuda:{torch.distributed.get_rank()}")
     ensure_model_parallel_initialized(2, 2)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with graph_capture():
+    with graph_capture(device=device):
         # two tp groups can communicate independently
         if torch.distributed.get_rank() in [0, 1]:
             tensor = tensor_model_parallel_all_reduce(tensor)

vllm/distributed/parallel_state.py

Lines changed: 4 additions & 3 deletions
@@ -920,7 +920,7 @@ def get_kv_transfer_group() -> kv_transfer.KVTransferAgent:


 @contextmanager
-def graph_capture():
+def graph_capture(device: torch.device):
     """
     `graph_capture` is a context manager which should surround the code that
     is capturing the CUDA graph. Its main purpose is to ensure that the
@@ -934,8 +934,9 @@ def graph_capture():
     in order to explicitly distinguish the kernels to capture
     from other kernels possibly launched on background in the default stream.
     """
-    with get_tp_group().graph_capture() as context, get_pp_group(
-    ).graph_capture(context):
+    context = GraphCaptureContext(torch.cuda.Stream(device=device))
+    with get_tp_group().graph_capture(context), get_pp_group().graph_capture(
+            context):
         yield context
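
With the new signature, callers pass the device their worker owns and the capture stream is created on that device rather than on the current default device. A hedged usage sketch (assumes torch.distributed and vLLM's model-parallel groups are already initialized, as in the tests above):

import torch
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import graph_capture

# Each rank captures on its own GPU.
device = torch.device(f"cuda:{torch.distributed.get_rank()}")

with graph_capture(device=device) as ctx:
    # ctx.stream is a torch.cuda.Stream created on `device`; kernels issued
    # inside this block run on that stream instead of the default stream.
    out = tensor_model_parallel_all_reduce(torch.ones(16, device=device))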

vllm/v1/worker/gpu_model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -836,7 +836,7 @@ def capture_model(self) -> None:
         # Trigger CUDA graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
-        with graph_capture():
+        with graph_capture(device=self.device):
             for num_tokens in reversed(self.cudagraph_batch_sizes):
                 for _ in range(self.vllm_config.compilation_config.
                                cudagraph_num_of_warmups):
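
The "large shapes first" comment works because CUDA graphs captured into a shared memory pool can reuse one another's allocations, so the largest capture sets the pool's high-water mark. A standalone sketch of that pattern using PyTorch's public graph APIs (illustrative only, not vLLM's code):

import torch

pool = torch.cuda.graph_pool_handle()  # one pool shared by every capture
static_input = torch.zeros(8, 16, device="cuda")
graphs = {}

for batch_size in [8, 4, 2, 1]:  # capture the largest shape first
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=pool):
        # Stand-in for a model forward pass over a slice of the static input.
        out = static_input[:batch_size] * 2
    graphs[batch_size] = g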

vllm/worker/model_runner.py

Lines changed: 16 additions & 9 deletions
@@ -1426,10 +1426,15 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:

         # Prepare dummy inputs. These will be reused for all batch sizes.
         max_batch_size = self.max_batchsize_to_capture
-        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
-        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
+        input_tokens = torch.zeros(max_batch_size,
+                                   dtype=torch.long,
+                                   device=self.device)
+        input_positions = torch.zeros(max_batch_size,
+                                      dtype=torch.long,
+                                      device=self.device)
         if self.model_config.uses_mrope:
-            input_positions = torch.tile(input_positions, (3, 1))
+            input_positions = torch.tile(input_positions,
+                                         (3, 1)).cuda(device=self.device)
         # Prepare dummy previous_hidden_states only if needed by the model.
         # This is used by draft models such as EAGLE.
         previous_hidden_states = None
@@ -1448,8 +1453,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                 dtype=self.model_config.dtype,
                 device=self.device)

-        with self.attn_state.graph_capture(
-                max_batch_size), graph_capture() as graph_capture_context:
+        with self.attn_state.graph_capture(max_batch_size), graph_capture(
+                self.device) as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for virtual_engine in range(
@@ -1549,10 +1554,12 @@ def _update_inputs_to_capture_for_enc_dec_model(self,
         """
         # During the decode phase encoder_input_ids and encoder_positions are
         # unset. Do the same thing for graph capture.
-        capture_inputs["encoder_input_ids"] = torch.tensor(
-            [], dtype=torch.long).cuda()
-        capture_inputs["encoder_positions"] = torch.tensor(
-            [], dtype=torch.long).cuda()
+        capture_inputs["encoder_input_ids"] = torch.tensor([],
+                                                           dtype=torch.long,
+                                                           device=self.device)
+        capture_inputs["encoder_positions"] = torch.tensor([],
+                                                           dtype=torch.long,
+                                                           device=self.device)

     @property
     def vocab_size(self) -> int:
