@@ -516,27 +516,19 @@ def execute_model(
             raise ValueError(
                 "TPUModelRunner does not support multi-step execution.")
 
-        def _execute_model(*args, clone: bool = False) -> torch.Tensor:
+        def _execute_model(*args):
             """Move input args from CPU to device and execute the model."""
 
-            def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
-                if clone:
-                    # When x is a slice of a CPU tensor, XLA may copy the whole
-                    # original tensor to TPU instead of only copying x.
-                    # To avoid this, we copy x after cloning.
-                    x = x.clone()
-                return x.to(self.device)
-
             new_args = []
             for arg in args:
                 if isinstance(arg, torch.Tensor):
-                    arg = _copy_to_device(arg)
+                    arg = arg.to(self.device)
                 elif isinstance(arg, AttentionMetadata):
-                    arg.slot_mapping = _copy_to_device(arg.slot_mapping)
+                    arg.slot_mapping = arg.slot_mapping.to(self.device)
                     if getattr(arg, "block_tables", None) is not None:
-                        arg.block_tables = _copy_to_device(arg.block_tables)
+                        arg.block_tables = arg.block_tables.to(self.device)
                     if getattr(arg, "context_lens", None) is not None:
-                        arg.context_lens = _copy_to_device(arg.context_lens)
+                        arg.context_lens = arg.context_lens.to(self.device)
                 new_args.append(arg)
             return self.model(*new_args)
 
@@ -563,13 +555,9 @@ def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
                 output_token_ids = _execute_model(
                     model_input.token_ids[None, start_idx:end_idx],
                     model_input.position_ids[None, start_idx:end_idx],
-                    model_input.attn_metadata,
-                    model_input.input_lens[i:i + 1],
-                    model_input.t[i:i + 1],
-                    model_input.p[i:i + 1],
-                    model_input.num_samples,
-                    kv_caches,
-                    clone=True)
+                    model_input.attn_metadata, model_input.input_lens[i:i + 1],
+                    model_input.t[i:i + 1], model_input.p[i:i + 1],
+                    model_input.num_samples, kv_caches)
                 # Retrieve the outputs to CPU.
                 next_token_ids += output_token_ids.cpu().tolist()
                 start_idx = end_idx
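
For context, the comment deleted in the first hunk describes the XLA quirk the old helper worked around: moving a slice of a CPU tensor to the TPU could transfer the entire underlying tensor rather than just the slice, so call sites that passed slices used `clone=True`. Below is a minimal sketch of that removed pattern; the names `copy_slice_to_device` and `input_lens` are illustrative and not part of this diff.

```python
import torch


def copy_slice_to_device(x: torch.Tensor,
                         device: torch.device,
                         clone: bool = False) -> torch.Tensor:
    """Sketch of the removed _copy_to_device workaround."""
    if clone:
        # Detach the slice from its (possibly much larger) base tensor so
        # only the slice's data is transferred to the device.
        x = x.clone()
    return x.to(device)


# Hypothetical usage: move a one-element slice of a large CPU tensor.
# On TPU the target would be an XLA device (e.g. from torch_xla); CPU is
# used here only so the snippet runs anywhere.
input_lens = torch.zeros(4096, dtype=torch.int32)
lens_on_device = copy_slice_to_device(input_lens[0:1],
                                      torch.device("cpu"),
                                      clone=True)
```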