Skip to content

Commit 43735bf

Browse files
authored
[TPU] Remove redundant input tensor cloning (#7660)
1 parent da11523 commit 43735bf

File tree

1 file changed

+8
-20
lines changed

1 file changed

+8
-20
lines changed

vllm/worker/tpu_model_runner.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -516,27 +516,19 @@ def execute_model(
516516
raise ValueError(
517517
"TPUModelRunner does not support multi-step execution.")
518518

519-
def _execute_model(*args, clone: bool = False) -> torch.Tensor:
519+
def _execute_model(*args):
520520
"""Move input args from CPU to device and execute the model."""
521521

522-
def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
523-
if clone:
524-
# When x is a slice of a CPU tensor, XLA may copy the whole
525-
# original tensor to TPU instead of only copying x.
526-
# To avoid this, we copy x after cloning.
527-
x = x.clone()
528-
return x.to(self.device)
529-
530522
new_args = []
531523
for arg in args:
532524
if isinstance(arg, torch.Tensor):
533-
arg = _copy_to_device(arg)
525+
arg = arg.to(self.device)
534526
elif isinstance(arg, AttentionMetadata):
535-
arg.slot_mapping = _copy_to_device(arg.slot_mapping)
527+
arg.slot_mapping = arg.slot_mapping.to(self.device)
536528
if getattr(arg, "block_tables", None) is not None:
537-
arg.block_tables = _copy_to_device(arg.block_tables)
529+
arg.block_tables = arg.block_tables.to(self.device)
538530
if getattr(arg, "context_lens", None) is not None:
539-
arg.context_lens = _copy_to_device(arg.context_lens)
531+
arg.context_lens = arg.context_lens.to(self.device)
540532
new_args.append(arg)
541533
return self.model(*new_args)
542534

@@ -563,13 +555,9 @@ def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
563555
output_token_ids = _execute_model(
564556
model_input.token_ids[None, start_idx:end_idx],
565557
model_input.position_ids[None, start_idx:end_idx],
566-
model_input.attn_metadata,
567-
model_input.input_lens[i:i + 1],
568-
model_input.t[i:i + 1],
569-
model_input.p[i:i + 1],
570-
model_input.num_samples,
571-
kv_caches,
572-
clone=True)
558+
model_input.attn_metadata, model_input.input_lens[i:i + 1],
559+
model_input.t[i:i + 1], model_input.p[i:i + 1],
560+
model_input.num_samples, kv_caches)
573561
# Retrieve the outputs to CPU.
574562
next_token_ids += output_token_ids.cpu().tolist()
575563
start_idx = end_idx

0 commit comments

Comments
 (0)