🚨🚨 Setting default behavior of assisted decoding (#33657)
jmamou authored Sep 25, 2024
1 parent 5f0c181 commit 52daf4e
Showing 4 changed files with 24 additions and 8 deletions.
2 changes: 2 additions & 0 deletions src/transformers/generation/candidate_generator.py
@@ -159,6 +159,8 @@ def __init__(
         self.generation_config.return_dict_in_generate = True
         self.generation_config.output_scores = True
         self.generation_config.assistant_confidence_threshold = self.assistant_confidence_threshold
+        # this flag allows us to set the confidence stopping criterion for assistant model generation.
+        self.generation_config.is_assistant = True

         # avoid unnecessary warnings that min_length is larger than max_new_tokens
         # remove the `MinLengthLogitsProcessor` if exists (NOTE: no need to check for `MinNewTokensLogitsProcessor`)
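The new flag is set only on the candidate generator's own copy of the generation config, so the target model's config keeps the `False` default. A minimal sketch of the resulting behavior, using only the public `GenerationConfig` API:

```python
from transformers import GenerationConfig

# A fresh config is not an assistant; the candidate generator flips the flag
# on its own copy so that the confidence-based stopping criterion is attached
# only to the draft model's generation, never to the target model's.
config = GenerationConfig()
assert config.is_assistant is False  # default introduced by this commit

config.is_assistant = True  # what the candidate generator's __init__ now does
assert config.assistant_confidence_threshold == 0.4  # new default threshold
```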
16 changes: 9 additions & 7 deletions src/transformers/generation/configuration_utils.py
@@ -338,19 +338,20 @@ class GenerationConfig(PushToHubMixin):
             (e.g. multilingual models with different target languages in one batch)
         > Generation parameters exclusive to assistant generation
-        num_assistant_tokens (`int`, *optional*, defaults to 5):
+        is_assistant (`bool`, *optional*, defaults to `False`):
+            Whether the model is an assistant (draft) model.
+        num_assistant_tokens (`int`, *optional*, defaults to 20):
             Defines the number of _speculative tokens_ that shall be generated by the assistant model before being
             checked by the target model at each iteration. Higher values for `num_assistant_tokens` make the generation
             more _speculative_: if the assistant model is performant, larger speed-ups can be reached; if the assistant
             model requires many corrections, speed-ups are lower.
-        num_assistant_tokens_schedule (`str`, *optional*, defaults to `"heuristic"`):
+        num_assistant_tokens_schedule (`str`, *optional*, defaults to `"constant"`):
             Defines the schedule at which max assistant tokens shall be changed during inference.
             - `"heuristic"`: When all speculative tokens are correct, increase `num_assistant_tokens` by 2; else
               reduce it by 1. The `num_assistant_tokens` value is persistent over multiple generation calls with the same assistant model.
             - `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call.
             - `"constant"`: `num_assistant_tokens` stays unchanged during generation.
-        assistant_confidence_threshold (`float`, *optional*):
+        assistant_confidence_threshold (`float`, *optional*, defaults to 0.4):
             The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower
             than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_
             (defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead
@@ -452,9 +453,10 @@ def __init__(self, **kwargs):
         self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

         # Assistant generation
-        self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
-        self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")
-        self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", None)
+        self.is_assistant = False
+        self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 20)
+        self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "constant")
+        self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", 0.4)

         # Prompt lookup decoding
         self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
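In practice, the new defaults change what a plain `generate(..., assistant_model=...)` call does. A minimal sketch, assuming placeholder `gpt2` checkpoints (any target/draft pair sharing a tokenizer works); passing the old defaults explicitly restores the pre-commit behavior:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoints, for illustration only.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
target = AutoModelForCausalLM.from_pretrained("gpt2-large")
draft = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# After this commit, plain assisted decoding uses num_assistant_tokens=20,
# a "constant" schedule, and early exit when the draft's confidence < 0.4.
new_behavior = target.generate(**inputs, assistant_model=draft, max_new_tokens=32)

# Overriding the three parameters recovers the previous defaults.
old_behavior = target.generate(
    **inputs,
    assistant_model=draft,
    max_new_tokens=32,
    num_assistant_tokens=5,
    num_assistant_tokens_schedule="heuristic",
    assistant_confidence_threshold=None,
)
```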
3 changes: 2 additions & 1 deletion src/transformers/generation/utils.py
Expand Up @@ -953,7 +953,8 @@ def _get_stopping_criteria(
         if generation_config._eos_token_tensor is not None:
             criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor))
         if (
-            generation_config.assistant_confidence_threshold is not None
+            generation_config.is_assistant
+            and generation_config.assistant_confidence_threshold is not None
             and generation_config.assistant_confidence_threshold > 0
         ):
             criteria.append(
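Because a target model's config keeps `is_assistant=False`, the confidence criterion can never be installed for the main model, even though the threshold now defaults to 0.4. A small sketch restating the gate (the helper name is illustrative, not from the library):

```python
from transformers import GenerationConfig

def uses_confidence_criterion(cfg: GenerationConfig) -> bool:
    # Mirrors the condition in _get_stopping_criteria after this commit.
    return (
        cfg.is_assistant
        and cfg.assistant_confidence_threshold is not None
        and cfg.assistant_confidence_threshold > 0
    )

target_cfg = GenerationConfig()  # is_assistant stays False
draft_cfg = GenerationConfig()
draft_cfg.is_assistant = True    # set internally for the draft model only

print(uses_confidence_criterion(target_cfg))  # False: main model unaffected
print(uses_confidence_criterion(draft_cfg))   # True: 0.4 default applies
```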
11 changes: 11 additions & 0 deletions tests/generation/test_utils.py
@@ -2069,6 +2069,7 @@ def test_assisted_decoding_with_num_logits_to_keep(self):
"assistant_model": assistant_model,
}

assistant_model.generation_config.assistant_confidence_threshold = None
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
with_all_logits = model.generate(
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict, num_logits_to_keep=0
@@ -3098,6 +3099,16 @@ def test_length_warning_assisted_generation(self):
         )
         self.assertEqual(len(warning_list), 0)

+    def test_default_assisted_generation(self):
+        # Initialize the GenerationConfig object
+        config = GenerationConfig()
+
+        # Check the default values
+        self.assertEqual(config.num_assistant_tokens, 20)
+        self.assertEqual(config.num_assistant_tokens_schedule, "constant")
+        self.assertEqual(config.assistant_confidence_threshold, 0.4)
+        self.assertEqual(config.is_assistant, False)
+
     def test_generated_length_assisted_generation(self):
         # PT-only test: TF doesn't support assisted decoding yet.
         model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
