Fix logic error in prepare_inputs_for_generation cache slicing condition
#41764
Background
I think the PR "`generate` delegates default cache initialization to the model" #41505 introduced a logic error: the condition for calling `_cache_dependant_input_preparation` uses `is None` instead of `is not None`, causing crashes when `prepare_inputs_for_generation` is called with `past_key_values=None` and `use_cache=False`.
Bug
PR #41505 introduced this condition: `past_key_values is None or use_cache`.
The condition is true whenever `past_key_values` is `None` or `use_cache` is truthy, so it triggers the function even when `past_key_values=None` and `use_cache=False`.
This combination is invalid for cache-dependent preparation and causes a crash when accessing `cache_position[-1]` (line 456).
Note that during normal generation it works fine because `use_cache=True`, which makes the buggy `past_key_values is None` part irrelevant.
Fix
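This PR changes the condition to `past_key_values is not None or use_cache`.
The condition `past_key_values is not None or use_cache` is true whenever a cache object is passed or `use_cache` is truthy. This is semantically correct and matches the intent described in the PR #41505 comment: #41505 (comment)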
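As a rough illustration, the two conditions can be compared in plain Python without importing the library. This is only a sketch: `old_condition`, `new_condition`, `slice_with_cache`, and `CACHE` are hypothetical names invented here, and it assumes `cache_position` is `None` in the crashing call, which is consistent with the `TypeError` named in the TRL issue but is not taken from the library source.

```python
def old_condition(past_key_values, use_cache):
    # Condition quoted from PR #41505.
    return bool(past_key_values is None or use_cache)


def new_condition(past_key_values, use_cache):
    # Condition proposed in this PR.
    return bool(past_key_values is not None or use_cache)


def slice_with_cache(cache_position):
    # Hypothetical stand-in for the cache-dependent step; it only mimics
    # the `cache_position[-1]` access and is not the library function.
    return cache_position[-1]


CACHE = object()  # placeholder for any non-None cache object

# Evaluate both conditions for every combination of inputs.
table = [
    (pkv is not None, use_cache, old_condition(pkv, use_cache), new_condition(pkv, use_cache))
    for pkv in (None, CACHE)
    for use_cache in (False, True)
]
for has_cache, use_cache, old, new in table:
    print(f"has_cache={has_cache!s:5} use_cache={use_cache!s:5} old={old!s:5} new={new}")

# With no cache and use_cache=False, only the old condition enters the
# cache-dependent path; indexing None there raises TypeError (assumption:
# cache_position is None in that call).
if old_condition(None, False):
    try:
        slice_with_cache(None)
    except TypeError as exc:
        print("old condition would crash:", exc)
```

The two conditions agree whenever `use_cache` is true and differ only when it is false. In that case the old condition fires without a cache, while the new one fires only when a cache is present.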
The `use_cache` part handles stateful models, while `past_key_values is not None` handles normal cached models.
Testing
This PR fixes the downstream failing test in TRL:
See the associated issue:
Related
This PR addresses a logic error introduced by:
`generate` delegates default cache initialization to the model #41505
This PR will fix: CI fails with dev dependencies: TypeError: 'NoneType' object is not subscriptable (trl#4272)
CC:
`generate` delegates default cache initialization to the model #41505; see 🚨 [v5] `generate` delegates default cache initialization to the model #41505 (comment)