From 2e27291ce4adbea9d2cb2f9bd6c43ec492e2bb5c Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Mon, 13 May 2024 16:08:45 +0100
Subject: [PATCH] Generate: assistant should be greedy in assisted decoding
 (#30778)

* assistant should be greedy

* better comment

* Update src/transformers/generation/candidate_generator.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 src/transformers/generation/candidate_generator.py | 6 ++++++
 src/transformers/generation/configuration_utils.py | 5 +++++
 2 files changed, 11 insertions(+)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index a958228d9be27f..52371d94dc56d1 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -150,6 +150,12 @@ def __init__(
         self.generation_config.return_dict_in_generate = True
         self.generation_config.output_scores = True
 
+        # Disable sampling -- this implementation of assisted generation/speculative decoding uses the assistant
+        # greedily to maximize matches. Disables sampling-related flags to prevent warnings
+        self.generation_config.do_sample = False
+        for attr in ("temperature", "top_p", "min_p", "typical_p", "top_k", "epsilon_cutoff", "eta_cutoff"):
+            setattr(self.generation_config, attr, None)
+
         # avoid unnecessary warnings that min_length is larger than max_new_tokens
         # remove the `MinLengthLogitsProcessor` if exists (NOTE: no need to check for `MinNewTokensLogitsProcessor`)
         self.main_model_min_length = self.generation_config.min_length
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 85fcc055948c41..2bdf20c68613e9 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -496,6 +496,11 @@ def validate(self, is_init=False):
                     greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
                     UserWarning,
                 )
+            if self.min_p is not None:
+                warnings.warn(
+                    greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p),
+                    UserWarning,
+                )
             if self.typical_p is not None and self.typical_p != 1.0:
                 warnings.warn(
                     greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),