
Commit f30c225

Fix logic error in prepare_inputs_for_generation cache slicing condition (#41764)
Fix logic error in cache slicing condition

Co-authored-by: Raushan Turganbay <raushan@huggingface.co>
1 parent 496c283 commit f30c225

File tree

1 file changed: +1 addition, -1 deletion


src/transformers/generation/utils.py

Lines changed: 1 addition & 1 deletion
@@ -608,7 +608,7 @@ def prepare_inputs_for_generation(
         use_cache = kwargs.get("use_cache")
         if use_cache is None:
             use_cache = getattr(self.config, "use_cache", False)
-        if past_key_values is None or use_cache:
+        if past_key_values is not None or use_cache:
             # TODO (joao): handle the case where cache length == input_ids length. The function below results in an
             # exception because we get empty input_ids after slicing. In essence, we need to roll back the cache 1
             # token to recompute the logits for the first token to be generated (but not all caches support roll backs)
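Below is a minimal, self-contained sketch (not part of the commit) that tabulates the old and new guard conditions side by side. The helpers should_slice_old / should_slice_new are hypothetical names introduced only to isolate the condition in front of the cache-slicing block.

# Hypothetical helpers (not from transformers) isolating the guard condition
# around the cache-slicing code in prepare_inputs_for_generation.

def should_slice_old(past_key_values, use_cache):
    # Condition before the fix: true when there is NO cache, or when use_cache is set.
    return past_key_values is None or use_cache

def should_slice_new(past_key_values, use_cache):
    # Condition after the fix: true when a cache IS present, or when use_cache is set.
    return past_key_values is not None or use_cache

cache = object()  # stand-in for a real cache instance

for pkv, uc in [(None, False), (None, True), (cache, False), (cache, True)]:
    has_cache = pkv is not None
    print(f"has_cache={has_cache!s:<5} use_cache={uc!s:<5} "
          f"old={should_slice_old(pkv, uc)!s:<5} new={should_slice_new(pkv, uc)}")

The two rows where the results differ are the cases the commit title describes: with the old condition the slicing branch fired when no cache was present (use_cache False), and was skipped when a populated cache was passed with use_cache False; the one-character fix reverses both.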
