Better defaults for assisted generation #40976
@@ -2217,6 +2217,7 @@ def _extract_generation_mode_kwargs(
             "assistant_tokenizer": kwargs.pop("assistant_tokenizer", None),
             "assistant_model": assistant_model,
             "streamer": streamer,
+            "assistant_temperature": kwargs.pop("assistant_temperature", None),
         }
         generation_mode_kwargs["synced_gpus"] = (
Comment on lines 2218 to 2222

Contributor (Author):
No need to change the generate signature!! It gets automatically forwarded. In fact, we could remove

Contributor:
No signature change, but it's still an argument (that should be documented). In any case, I'd rather have it controlled by

Contributor:
+1, it should lie within the assistant's generation config if possible. That would be cleaner.
             (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
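The reviewers above suggest steering this through the assistant model's own generation config rather than a new keyword argument. A minimal sketch of what that could look like from the user side, assuming the PR is reworked that way (the checkpoint names are placeholders, not part of this PR):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoints; any main/assistant pair sharing a tokenizer works.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
assistant = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

# Under the reviewers' proposal, the assistant would read its temperature from
# its own generation config, so a hotter value here would only affect candidate
# drafting, not the main model's sampling.
assistant.generation_config.do_sample = True
assistant.generation_config.temperature = 1.5

inputs = tokenizer("Speculative decoding works by", return_tensors="pt")
outputs = model.generate(**inputs, assistant_model=assistant, do_sample=True, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```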
@@ -3457,6 +3458,7 @@ def _assisted_decoding(
         assistant_model: Optional["PreTrainedModel"] = None,
         assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
         tokenizer: Optional["PreTrainedTokenizerBase"] = None,
+        assistant_temperature: Optional[float] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
@@ -3491,6 +3493,9 @@ def _assisted_decoding(
                 The tokenizer used for the assistant model. If not provided, the token space is assumed to be the same.
             tokenizer (`PreTrainedTokenizerBase`, *optional*):
                 The tokenizer used for the main model. If not provided, the token space is assumed to be the same.
+            assistant_temperature (`float`, *optional*):
+                The temperature to use for the assistant model. If not provided and the main generation temperature is
+                below 1.5, it will be set to 1.5 (to improve decoding speed).
             model_kwargs:
                 Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
                 If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
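A minimal usage sketch of the keyword argument documented above, assuming it is forwarded through `generate()` as this PR proposes; it reuses the `model`, `assistant`, `tokenizer`, and `inputs` from the earlier sketch:

```python
# Main sampling stays at 0.7; only the drafted candidates would use 1.2.
outputs = model.generate(
    **inputs,
    assistant_model=assistant,
    do_sample=True,
    temperature=0.7,            # main model's sampling temperature
    assistant_temperature=1.2,  # overrides the PR's 1.5 default for the assistant
    max_new_tokens=64,
)

# Omitting assistant_temperature with a main temperature below 1.5 would,
# per this PR, bump the assistant to 1.5 and emit the warning shown below.
```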
@@ -3511,6 +3516,20 @@ def _assisted_decoding(
             and any(getattr(l, "is_compileable", False) for l in model_kwargs["past_key_values"].layers)
         ):
             raise ValueError("assisted generate is not supported with Static cache classes")
+        # Prefer a slightly higher temperature for the assistant when not explicitly provided
+        idx = next((i for i, p in enumerate(logits_processor) if isinstance(p, TemperatureLogitsWarper)), None)
+        temp_processor = logits_processor.pop(idx) if idx is not None else TemperatureLogitsWarper(temperature=1.0)
+
+        if assistant_temperature is None and temp_processor is not None and temp_processor.temperature < 1.5:
+            logger.warning_once(
+                f"The assistant's sampling temperature comes from the main generation loop and is set to {temp_processor.temperature}, "
+                "but speculative decoding benefits from slightly hotter candidate generation (see #40976), so we are setting it "
+                "to 1.5. This should improve decoding speed in most cases. Use `assistant_temperature` to override this value."
+            )
+            assistant_temperature = 1.5
+
+        if assistant_temperature is not None:
+            logits_processor.insert(0, TemperatureLogitsWarper(temperature=assistant_temperature))
Comment on lines +3519 to +3532

Contributor:
Doesn't this change the temperature for both models? 👀 (

Contributor:
Yeah, that's a good question, we only up the base temperature, no? We could also just modify the temperature in place if that's the case.
         # Get the candidate generator, given the parameterization
         candidate_generator = self._get_candidate_generator(
             generation_config=generation_config,
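To make the thread above concrete, here is a standalone sketch of the processor shuffle added in this hunk, using the public `LogitsProcessorList` and `TemperatureLogitsWarper` classes. Because the pop/insert happens on one shared list, any later consumer of `logits_processor` sees the hotter temperature, which is the behaviour the reviewers are asking about:

```python
import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper

# The main generation loop's temperature (0.7 here) arrives as a warper in the list.
logits_processor = LogitsProcessorList([TemperatureLogitsWarper(temperature=0.7)])

# Same lookup as the new code path: find and pop the existing temperature warper.
idx = next((i for i, p in enumerate(logits_processor) if isinstance(p, TemperatureLogitsWarper)), None)
temp_processor = logits_processor.pop(idx) if idx is not None else TemperatureLogitsWarper(temperature=1.0)

# No explicit override and 0.7 < 1.5, so the PR would insert a hotter warper.
assistant_temperature = 1.5
logits_processor.insert(0, TemperatureLogitsWarper(temperature=assistant_temperature))

# The shared list now carries 1.5 instead of 0.7.
print(temp_processor.temperature, logits_processor[0].temperature)  # 0.7 1.5

# The mutated list still works as a normal processor over dummy scores.
dummy_ids = torch.zeros((1, 1), dtype=torch.long)
dummy_scores = torch.randn(1, 32)
_ = logits_processor(dummy_ids, dummy_scores)
```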
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Length is controlled by main model's generation loop, so we should just discard those on the assistant right? @gante
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes 👍
(see comment on L175-176)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So we essentially remove the thrown error? Not sure if this is really relevant to this PR, more of a shortener no?
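For context on the length discussion, a hedged sketch of how length is typically split in assisted generation (assuming current transformers behaviour; `num_assistant_tokens` is an existing generation-config field, not something this PR adds): the main call's `max_new_tokens` bounds the final output, while the assistant only controls how many candidate tokens it drafts per step, so length limits configured on the assistant itself are redundant.

```python
# Reusing model / assistant / tokenizer / inputs from the earlier sketch.

# Speculation window: how many candidate tokens the assistant drafts per step.
assistant.generation_config.num_assistant_tokens = 10

outputs = model.generate(
    **inputs,
    assistant_model=assistant,
    max_new_tokens=128,  # total length budget, enforced by the main generation loop
)
```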