-
Notifications
You must be signed in to change notification settings - Fork 26.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generate: Deprecate returning legacy cache by default; Handle use_cache=False
#32863
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
use_cache=False
@@ -130,9 +130,29 @@ class GenerationConfig(PushToHubMixin): | |||
[this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. | |||
penalty_alpha (`float`, *optional*): | |||
The values balance the model confidence and the degeneration penalty in contrastive search decoding. | |||
dola_layers (`str` or `List[int]`, *optional*): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
moved up to this documentation section (Parameters that control the generation strategy used
), which makes more sense
`'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md) | ||
or [the paper](https://arxiv.org/abs/2309.03883) for more details. | ||
|
||
> Parameters that control the cache |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
new cache-related docs section in GenerationConfig
, moved all cache-related flags here
@@ -544,8 +539,9 @@ def validate(self, is_init=False): | |||
raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.") | |||
if self.pad_token_id is not None and self.pad_token_id < 0: | |||
warnings.warn( | |||
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch generating, if there is padding. " | |||
"Please set `pad_token_id` explicitly by `model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation, and ensure your `input_ids` input does not have negative values." | |||
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(>120 chars/line)
@@ -675,6 +671,14 @@ def validate(self, is_init=False): | |||
group_error_prefix | |||
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical." | |||
) | |||
# DoLa generation | |||
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(moved)
@@ -136,27 +136,23 @@ class GenerateDecoderOnlyOutput(ModelOutput): | |||
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): | |||
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter | |||
if all batches finished early due to the `eos_token_id`. | |||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): | |||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In our docs we often mention that there are two ways to parameterize generate
(generation_config
or pass arg to generate
). I don't think we need to be verbose here.
Also, setting through config
is deprecated 😉
Returns the model cache, used to speed up decoding. Different models have a different cache format, check | ||
the model's documentation. Usually, a [`~cache_utils.Cache`] instance. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rewrote this one.
The old description was outdated (legacy cache), and we now know that different models have different caches, so we shouldn't be precise here. The model class docs can be more precise, let's redirect users there.
@@ -328,6 +312,7 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput): | |||
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None | |||
|
|||
|
|||
# TODO (joao): remove the equivalent classes and typing shortcuts below in v5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(These aliases made sense in the past, not anymore. They are, however, hard to deprecate!)
@@ -1497,6 +1482,127 @@ def _supports_default_dynamic_cache(self) -> bool: | |||
""" | |||
return self._supports_cache_class and "jamba" not in self.__class__.__name__.lower() | |||
|
|||
def _prepare_cache_for_generation( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
New function, moving the cache logic from generate
. I've organized the logic in blocks, putting the cases where we DON'T prepare a new cache at the top.
It is doing essentially the same, except for the Quick escape route 2
, which is new. Added the warning in Quick escape route 3
.
if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"): | ||
if isinstance(result.past_key_values, (DynamicCache, EncoderDecoderCache)): | ||
result.past_key_values = result.past_key_values.to_legacy_cache() | ||
# Convert to legacy cache format if requested |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This logic is expanded to handle a deprecation cycle
@@ -194,6 +194,7 @@ def _greedy_generate( | |||
output_attentions=False, | |||
output_hidden_states=False, | |||
return_dict_in_generate=False, | |||
use_cache=True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changes in this file: pass use_cache
to generate
, instead of relying on model.config.use_cache=False
and its side-effects
added a check to confirm that the cache is None when we pass use_cache=False
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool! Let's make sure slow tests all pass as well here!
# TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches, | ||
# which is only supported in dynamic caches atm |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's create an issue and leave it up to the community in the mean time!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a442522
to
e8492bf
Compare
What does this PR do?
Another step towards using
Cache
everywhere 💪This PR makes the following [
Cache
+generate
]-related changes:use_cache=False
(fixes Cache updating when use_cache = False #32843 )generate
tests now explicitly passuse_cache
, instead of setting it inmodel.config
🤢 We were relying on a LOT of side effects, and missing the incorrect case mentioned in Cache updating when use_cache = False #32843generation_config
Cache
instance by default ongenerate
generate
into a single function, and reorganize the logic by blocks