@@ -591,6 +591,7 @@ def generate(
             Dict[str, Any]
         ] = None,  # List of Image prompt tensors for multimodal models
         start_pos: int = 0,
+        skip_cache_setup: bool = False,
         draft_model: Model,
         speculate_k: Optional[int] = 8,
         sequential_prefill=True,
@@ -614,26 +615,27 @@ def generate(
         max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
         if start_pos == 0:
-            model = model.to(device=device)
-            with torch.device(device):
-                if (
-                    self.is_torchtune_model
-                    or self.model.config.model_type == ModelType.Flamingo
-                ):
-                    # 6404 is one-gpu affordable max_seq_length for single image input
-                    model.setup_caches(
-                        batch_size=1,
-                        dtype=self.dtype,
-                        encoder_max_seq_len=6404,
-                        decoder_max_seq_len=max_seq_length,
-                    )
-                else:
-                    model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
-            if is_speculative and draft_model is not model:
-                draft_model.setup_caches(
-                    max_batch_size=1,
-                    max_seq_length=max_seq_length,
-                )
+            if not skip_cache_setup:
+                model = model.to(device=device)
+                with torch.device(device):
+                    if (
+                        self.is_torchtune_model
+                        or self.model.config.model_type == ModelType.Flamingo
+                    ):
+                        # 6404 is one-gpu affordable max_seq_length for single image input
+                        model.setup_caches(
+                            batch_size=1,
+                            dtype=self.dtype,
+                            encoder_max_seq_len=6404,
+                            decoder_max_seq_len=max_seq_length,
+                        )
+                    else:
+                        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+                if is_speculative and draft_model is not model:
+                    draft_model.setup_caches(
+                        max_batch_size=1,
+                        max_seq_length=max_seq_length,
+                    )
         if model.config.model_type == ModelType.Flamingo:
             model.reset_caches()
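
The hunk above gates cache setup behind the new skip_cache_setup flag: when start_pos == 0 but the caller has already allocated KV caches on an earlier sample, generate() no longer re-creates them. Below is a minimal, runnable sketch of that guard pattern; ToyModel and its setup_caches signature are simplified stand-ins for illustration, not torchchat's actual Model API.

class ToyModel:
    def __init__(self):
        self.kv_cache = None

    def setup_caches(self, max_batch_size: int, max_seq_length: int) -> None:
        # Stand-in for KV-cache allocation: a (batch, seq) table of slots.
        self.kv_cache = [[0] * max_seq_length for _ in range(max_batch_size)]


def generate(model: ToyModel, start_pos: int = 0,
             skip_cache_setup: bool = False, max_seq_length: int = 16) -> None:
    # Mirrors the diff: allocate caches only on the first inference
    # (start_pos == 0), and only if the caller did not opt out.
    if start_pos == 0:
        if not skip_cache_setup:
            model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)


model = ToyModel()
generate(model)                         # sample 0: caches allocated
first_cache = model.kv_cache
generate(model, skip_cache_setup=True)  # later samples: caches reused
assert model.kv_cache is first_cache
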
@@ -1013,6 +1015,7 @@ def chat(
         )
         for i in range(num_samples):
             device_sync(device=self.builder_args.device)
+            is_first_sample: bool = i == 0
             if generator_args.chat_mode:
                 prompt = input("User: ")
                 if prompt == "/bye":
@@ -1038,7 +1041,7 @@ def chat(
                         ]
                     )
                     self.system_prompt = None
-                elif i == 0:
+                elif is_first_sample:
                     encoded = self.chat_formatter.encode_dialog_prompt(
                         [{"role": "user", "content": prompt}]
                     )
@@ -1107,6 +1110,7 @@ def callback(x, *, done_generating=False):
                 top_k=generator_args.top_k,
                 sequential_prefill=generator_args.sequential_prefill,
                 start_pos=start_pos,
+                skip_cache_setup=not is_first_sample,
                 max_seq_length=max_seq_length,
             )
             for token_tensor, metrics in generator_func:
@@ -1116,7 +1120,7 @@ def callback(x, *, done_generating=False):
                 if metrics is not None:
                     aggregate_metrics.update(metrics)
                 yield token_tensor, metrics
-            jit_compile = (i == 0) and (
+            jit_compile = is_first_sample and (
                 generator_args.compile or generator_args.compile_prefill
             )
             compilation_time = time.perf_counter() - t0
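
On the caller side, chat() derives both skip_cache_setup and the jit_compile flag from is_first_sample, so only sample 0 pays for cache allocation and any torch.compile warm-up. A minimal sketch of that driving loop follows; run_samples and generate_one_sample are hypothetical names standing in for the chat()/self.generate pair, not torchchat's code.

from typing import Callable

def run_samples(generate_one_sample: Callable[..., None],
                num_samples: int, compile_enabled: bool = False) -> None:
    # Hypothetical driver loop; names are illustrative only.
    for i in range(num_samples):
        is_first_sample: bool = i == 0
        # Sample 0 sets up caches; samples 1..n-1 reuse them.
        generate_one_sample(skip_cache_setup=not is_first_sample)
        # Any torch.compile warm-up cost can only occur on the first
        # sample, so compilation time is attributed to sample 0 alone.
        jit_compile = is_first_sample and compile_enabled
        print(f"sample {i}: jit_compile={jit_compile}")

run_samples(lambda **kwargs: None, num_samples=3, compile_enabled=True)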