Commit 6eae887

Update Caching logic to only trigger on the first inference sample (#1369)
* Only set up during the first sample
* Cleaner
1 parent 93f713f commit 6eae887
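
The change in a nutshell: `generate` previously re-ran the device move and KV-cache allocation every time it was entered with `start_pos == 0`, which in a multi-sample run can mean once per sample instead of once per run. The new `skip_cache_setup` parameter lets the sampling loop in `chat` perform that setup only for the first sample. Below is a minimal, self-contained sketch of the pattern; `DummyModel` and its `setup_caches` signature are illustrative stand-ins, not torchchat's actual classes.

```python
# Minimal sketch of the pattern (hypothetical stand-ins, not torchchat's API):
# one-time cache setup guarded by a flag derived from the sample index.


class DummyModel:
    """Stand-in for a model whose cache setup is expensive."""

    def setup_caches(self, max_batch_size: int, max_seq_length: int) -> None:
        # In the real model this allocates KV caches on the target device;
        # here we just record that the (expensive) setup ran.
        print(f"caches allocated: batch={max_batch_size}, seq={max_seq_length}")


def generate(
    model: DummyModel, start_pos: int = 0, skip_cache_setup: bool = False
) -> None:
    # Mirrors the new guard: set up only on first inference, and only when
    # the caller has not already done it for an earlier sample.
    if start_pos == 0 and not skip_cache_setup:
        model.setup_caches(max_batch_size=1, max_seq_length=2048)
    # ... prefill and decode would follow here ...


model = DummyModel()
for i in range(3):  # num_samples
    is_first_sample = i == 0
    generate(model, start_pos=0, skip_cache_setup=not is_first_sample)
# "caches allocated: ..." prints once, for the first sample only.
```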

File tree

1 file changed (+26, -22 lines)


torchchat/generate.py

Lines changed: 26 additions & 22 deletions
@@ -591,6 +591,7 @@ def generate(
             Dict[str, Any]
         ] = None,  # List of Image prompt tensors for multimodal models
         start_pos: int = 0,
+        skip_cache_setup: bool = False,
         draft_model: Model,
         speculate_k: Optional[int] = 8,
         sequential_prefill=True,
@@ -614,26 +615,27 @@ def generate(
         max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
         if start_pos == 0:
-            model = model.to(device=device)
-            with torch.device(device):
-                if (
-                    self.is_torchtune_model
-                    or self.model.config.model_type == ModelType.Flamingo
-                ):
-                    # 6404 is one-gpu affordable max_seq_length for single image input
-                    model.setup_caches(
-                        batch_size=1,
-                        dtype=self.dtype,
-                        encoder_max_seq_len=6404,
-                        decoder_max_seq_len=max_seq_length,
-                    )
-                else:
-                    model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
-                if is_speculative and draft_model is not model:
-                    draft_model.setup_caches(
-                        max_batch_size=1,
-                        max_seq_length=max_seq_length,
-                    )
+            if not skip_cache_setup:
+                model = model.to(device=device)
+                with torch.device(device):
+                    if (
+                        self.is_torchtune_model
+                        or self.model.config.model_type == ModelType.Flamingo
+                    ):
+                        # 6404 is one-gpu affordable max_seq_length for single image input
+                        model.setup_caches(
+                            batch_size=1,
+                            dtype=self.dtype,
+                            encoder_max_seq_len=6404,
+                            decoder_max_seq_len=max_seq_length,
+                        )
+                    else:
+                        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+                    if is_speculative and draft_model is not model:
+                        draft_model.setup_caches(
+                            max_batch_size=1,
+                            max_seq_length=max_seq_length,
+                        )
             if model.config.model_type == ModelType.Flamingo:
                 model.reset_caches()
 
@@ -1013,6 +1015,7 @@ def chat(
         )
         for i in range(num_samples):
             device_sync(device=self.builder_args.device)
+            is_first_sample: bool = i == 0
             if generator_args.chat_mode:
                 prompt = input("User: ")
                 if prompt == "/bye":
@@ -1038,7 +1041,7 @@ def chat(
                         ]
                     )
                     self.system_prompt = None
-                elif i == 0:
+                elif is_first_sample:
                     encoded = self.chat_formatter.encode_dialog_prompt(
                         [{"role": "user", "content": prompt}]
                     )
@@ -1107,6 +1110,7 @@ def callback(x, *, done_generating=False):
                 top_k=generator_args.top_k,
                 sequential_prefill=generator_args.sequential_prefill,
                 start_pos=start_pos,
+                skip_cache_setup=not is_first_sample,
                 max_seq_length=max_seq_length,
             )
             for token_tensor, metrics in generator_func:
@@ -1116,7 +1120,7 @@ def callback(x, *, done_generating=False):
                 if metrics is not None:
                     aggregate_metrics.update(metrics)
                 yield token_tensor, metrics
-            jit_compile = (i == 0) and (
+            jit_compile = is_first_sample and (
                 generator_args.compile or generator_args.compile_prefill
             )
             compilation_time = time.perf_counter() - t0
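
Two details worth noting in the diff above: the Flamingo branch's `model.reset_caches()` remains outside the new `if not skip_cache_setup:` block, so a Flamingo model still clears its cache state on every sample that starts at `start_pos == 0`, even when reallocation is skipped. And the new `is_first_sample` local replaces both bare `i == 0` checks (the `elif` in prompt encoding and the `jit_compile` condition), keeping the first-sample semantics under one name.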
