Generate: Deprecate returning legacy cache by default; Handle `use_cache=False` (huggingface#32863)
gante authored and zucchini-nlp committed Aug 30, 2024
1 parent e879b36 commit 210ba5b
Showing 4 changed files with 311 additions and 256 deletions.
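As a quick orientation before the diff, here is a minimal sketch of the user-facing behavior this commit targets. The checkpoint name, prompt, and generation lengths are illustrative only, and the deprecation path for the legacy cache format lives in the other changed files of this PR, not in the hunks shown below.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

# Leaving `return_legacy_cache` unset keeps returning the legacy (tuple-of-tuples)
# cache format for now; passing `return_legacy_cache=False` opts into the new
# `Cache` objects directly.
out = model.generate(
    **inputs, max_new_tokens=5, return_dict_in_generate=True, return_legacy_cache=False
)

# `use_cache=False` passed straight to `generate` is now handled explicitly:
# any other cache-related argument that is also set is ignored with a warning.
out_no_cache = model.generate(**inputs, max_new_tokens=5, use_cache=False)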
109 changes: 58 additions & 51 deletions src/transformers/generation/configuration_utils.py
@@ -130,9 +130,29 @@ class GenerationConfig(PushToHubMixin):
[this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
penalty_alpha (`float`, *optional*):
The value balances the model confidence and the degeneration penalty in contrastive search decoding.
dola_layers (`str` or `List[int]`, *optional*):
The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
be one of "low" or "high", which means using the lower or higher half of the model layers, respectively.
"low" means the first half of the layers, capped at the first 20 layers, and "high" means the second half
of the layers, capped at the last 20 layers.
If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa.
The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
`'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
or [the paper](https://arxiv.org/abs/2309.03883) for more details.
> Parameters that control the cache
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should use the past key/values attentions (if applicable to the model) to
speed up decoding.
cache_implementation (`str`, *optional*, defaults to `None`):
Cache class that should be used when generating.
cache_config (`CacheConfig` or `dict`, *optional*, defaults to `None`):
Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict`
and it will be converted to its respective `CacheConfig` class internally.
Otherwise it can be passed as a `CacheConfig` class matching the indicated `cache_implementation`.
return_legacy_cache (`bool`, *optional*, defaults to `True`):
Whether to return the legacy or new format of the cache when `DynamicCache` is used by default.
> Parameters for manipulation of the model output logits
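To illustrate the `cache_config` entry documented in the hunk above: both a config object and a plain dict are accepted, and the dict is converted to the matching `CacheConfig` subclass internally. A small sketch, assuming the "quantized" implementation and `QuantizedCacheConfig` from `transformers.cache_utils` (neither is part of this diff):

from transformers import GenerationConfig
from transformers.cache_utils import QuantizedCacheConfig

# Equivalent ways of configuring a quantized key-value cache; the option shown
# ("nbits") is an illustrative value for this particular cache implementation.
as_object = GenerationConfig(
    cache_implementation="quantized", cache_config=QuantizedCacheConfig(nbits=4)
)
as_dict = GenerationConfig(
    cache_implementation="quantized", cache_config={"nbits": 4}
)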
@@ -307,29 +327,6 @@ class GenerationConfig(PushToHubMixin):
max_matching_ngram_size (`int`, *optional*, defaults to `None`):
The maximum ngram size to be considered for matching in the prompt. Defaults to 2 if not provided.
> Generation parameters exclusive to [DoLa decoding](https://arxiv.org/abs/2309.03883)
dola_layers (`str` or `List[int]`, *optional*):
The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively.
"low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the
layers up to the last 20 layers.
If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa.
The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
`'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
or [the paper](https://arxiv.org/abs/2309.03883) for more details.
> Parameters specific to the caching mechanism:
cache_implementation (`str`, *optional*, default to `None`):
Cache class that should be used when generating.
cache_config (`CacheConfig` or `dict`, *optional*, default to `None`):
Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
it will be converted to its repsective `CacheConfig` internally.
Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`.
return_legacy_cache (`bool`, *optional*, default to `True`):
Whether to return the legacy or new format of the cache when `DynamicCache` is used by default.
> Wild card
generation_kwargs:
@@ -352,7 +349,19 @@ def __init__(self, **kwargs):
self.num_beams = kwargs.pop("num_beams", 1)
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
self.penalty_alpha = kwargs.pop("penalty_alpha", None)
self.dola_layers = kwargs.pop("dola_layers", None)

# Parameters that control the cache
self.use_cache = kwargs.pop("use_cache", True)
self.cache_implementation = kwargs.pop("cache_implementation", None)
self.cache_config = kwargs.pop("cache_config", None)
if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
if self.cache_config is None:
self.cache_config = cache_config_class()
elif isinstance(self.cache_config, dict):
self.cache_config = cache_config_class.from_dict(self.cache_config)
self.return_legacy_cache = kwargs.pop("return_legacy_cache", None)

# Parameters for manipulation of the model output logits
self.temperature = kwargs.pop("temperature", 1.0)
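A small sketch of the observable effect of the constructor changes in the hunk above: the cache-related kwargs are now popped together, and `return_legacy_cache` defaults to `None` rather than `True`, presumably so the legacy-format default (and its deprecation warning) can be resolved later in `generate()`, which is modified in the other files of this PR.

from transformers import GenerationConfig

cfg = GenerationConfig()
print(cfg.use_cache)            # True (unchanged default)
print(cfg.return_legacy_cache)  # None (was True before this commit)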
@@ -411,20 +420,6 @@ def __init__(self, **kwargs):
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")

# DoLa generation
self.dola_layers = kwargs.pop("dola_layers", None)

# Cache implementation
self.cache_implementation = kwargs.pop("cache_implementation", None)
self.cache_config = kwargs.pop("cache_config", None)
if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
if self.cache_config is None:
self.cache_config = cache_config_class()
elif isinstance(self.cache_config, dict):
self.cache_config = cache_config_class.from_dict(self.cache_config)
self.return_legacy_cache = kwargs.pop("return_legacy_cache", True)

# Prompt lookup decoding
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
@@ -544,8 +539,9 @@ def validate(self, is_init=False):
raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
if self.pad_token_id is not None and self.pad_token_id < 0:
warnings.warn(
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch generating, if there is padding. "
"Please set `pad_token_id` explicitly by `model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation, and ensure your `input_ids` input does not have negative values."
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch "
"generating, if there is padding. Please set `pad_token_id` explicitly as "
"`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation"
)

# Validation of attribute relations:
@@ -675,6 +671,14 @@ def validate(self, is_init=False):
group_error_prefix
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
)
# DoLa generation
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
warnings.warn(
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
"DoLa decoding is `repetition_penalty>=1.2`.",
UserWarning,
)

# 4. check `num_return_sequences`
if self.num_return_sequences != 1:
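Regarding the DoLa check added in the hunk above, a brief sketch of when it fires; the warning is emitted during validation, which also runs at construction time:

from transformers import GenerationConfig

# `repetition_penalty` defaults to 1.0, so this emits the DoLa UserWarning.
GenerationConfig(dola_layers="high")

# Raising `repetition_penalty` to 1.2 or above silences it.
GenerationConfig(dola_layers="high", repetition_penalty=1.2)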
@@ -690,7 +694,7 @@ def validate(self, is_init=False):
f"({self.num_beams})."
)

# 5. check `cache_config`
# 5. check cache-related arguments
if self.cache_config is not None:
cache_class = NEEDS_CACHE_CONFIG.get(self.cache_implementation)
if cache_class is None:
Expand All @@ -702,6 +706,20 @@ def validate(self, is_init=False):
if not isinstance(self.cache_config, cache_class):
self.cache_config = cache_class.from_dict(self.cache_config)
self.cache_config.validate()
if self.use_cache is False:
# In this case, all cache-related arguments should be unset. However, since `use_cache=False` is often
# passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an error
# (otherwise a user might need to overwrite several parameters).
no_cache_warning = (
"You have set `use_cache` to `False`, but {cache_arg} is set to {cache_arg_value}. {cache_arg} will "
"have no effect."
)
for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"):
if getattr(self, arg_name) is not None:
logger.warning_once(
no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)),
UserWarning,
)

# 6. check watermarking arguments
if self.watermarking_config is not None:
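To illustrate the `use_cache is False` branch added in the hunk above: conflicting cache arguments are no longer silently accepted, but they warn instead of raising. The "static" implementation value below is illustrative.

from transformers import GenerationConfig

# Validation (run automatically at construction, or explicitly via `validate()`)
# logs a one-time warning for each cache argument that is set while
# `use_cache=False`, here `cache_implementation` and `return_legacy_cache`.
cfg = GenerationConfig(
    use_cache=False, cache_implementation="static", return_legacy_cache=False
)
cfg.validate()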
@@ -727,17 +745,6 @@ def validate(self, is_init=False):
"`generate()` (or a pipeline) directly."
)

# 6. if dola_layers is set, check if repetition_penalty is set to >= 1.2
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
dola_decoding_wrong_parameter_msg = (
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of {repetition_penalty}, "
"which could induce unwanted repetition. The recommended value for DoLa decoding is `repetition_penalty>=1.2`."
)
warnings.warn(
dola_decoding_wrong_parameter_msg.format(repetition_penalty=self.repetition_penalty),
UserWarning,
)

def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
