9 changes: 3 additions & 6 deletions src/transformers/generation/candidate_generator.py
@@ -177,12 +177,9 @@ def __init__(
self.main_model_min_length = self.generation_config.min_length
self.generation_config.min_length = 0
self.generation_config.min_new_tokens = None
for processor in self.logits_processor:
if isinstance(processor, MinLengthLogitsProcessor):
raise ValueError(
"Passing `MinLengthLogitsProcessor` when using `assisted_generation is disabled. "
"Please pass in `min_length` into `.generate()` instead"
)
self.logits_processor = [
processor for processor in self.logits_processor if not isinstance(processor, MinLengthLogitsProcessor)
]
Comment on lines +180 to +182
Contributor Author:
Length is controlled by main model's generation loop, so we should just discard those on the assistant right? @gante

Contributor:
yes 👍

(see comment on L175-176)

Contributor:
So we essentially remove the thrown error? Not sure if this is really relevant to this PR, more of a shortener no?
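For readers skimming the diff, a minimal usage sketch of the behavior under discussion (checkpoint names are placeholders, not part of this PR): `min_length` passed to `.generate()` is enforced by the main model's loop, while any `MinLengthLogitsProcessor` in a custom `logits_processor` list is now silently filtered out of the assistant's processors instead of raising.

```python
# Illustrative sketch only -- checkpoint names are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("main-model-checkpoint")
model = AutoModelForCausalLM.from_pretrained("main-model-checkpoint")
assistant = AutoModelForCausalLM.from_pretrained("assistant-model-checkpoint")

inputs = tokenizer("Speculative decoding example:", return_tensors="pt")

# Prefer `min_length` over a hand-built MinLengthLogitsProcessor: the main
# model's generation loop enforces it, and the assistant drops it anyway.
outputs = model.generate(
    **inputs,
    assistant_model=assistant,
    min_length=32,
    max_new_tokens=64,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```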


# We need to roll back the cache in assisted generation, only DynamicCache is supported
self.generation_config.cache_implementation = "dynamic_full"
19 changes: 19 additions & 0 deletions src/transformers/generation/utils.py
@@ -2217,6 +2217,7 @@ def _extract_generation_mode_kwargs(
"assistant_tokenizer": kwargs.pop("assistant_tokenizer", None),
"assistant_model": assistant_model,
"streamer": streamer,
"assistant_temperature": kwargs.pop("assistant_temperature", None),
}
generation_mode_kwargs["synced_gpus"] = (
Comment on lines 2218 to 2222
Contributor Author:
no need to change generate signature!! it gets automatically forwarded.

In fact, we could remove assistant_model from the signature (👀 v5?) and all the decoding method-specific kwargs get automatically forwarded.

Contributor:
no signature change, but it's still an argument (that should be documented)

In any case, I'd rather have it controlled by assistant_model.generation_config.temperature, in AssistedCandidateGenerator.__init__ -- if it's the default value (1.0), or == main model temperature, then override.

Contributor:
+1, it should lie within the assistant's generation config if possible. That would be cleaner
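To make the suggestion above concrete, a hypothetical sketch (not what this PR implements) of resolving the value inside AssistedCandidateGenerator.__init__ from the assistant's own generation config:

```python
# Hypothetical sketch of the reviewers' suggestion -- not the code in this PR.
DEFAULT_TEMPERATURE = 1.0
SPECULATIVE_TEMPERATURE = 1.5  # hotter candidates tend to speed up speculative decoding

def resolve_assistant_temperature(assistant_generation_config, main_temperature: float) -> float:
    """Pick the assistant temperature from the assistant's own generation config.

    Override only when it is the library default (1.0) or merely mirrors the
    main model's temperature; otherwise respect the user's explicit choice.
    """
    assistant_temperature = getattr(assistant_generation_config, "temperature", DEFAULT_TEMPERATURE)
    if assistant_temperature in (DEFAULT_TEMPERATURE, main_temperature):
        return SPECULATIVE_TEMPERATURE
    return assistant_temperature
```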

(is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
@@ -3457,6 +3458,7 @@ def _assisted_decoding(
assistant_model: Optional["PreTrainedModel"] = None,
assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
assistant_temperature: Optional[float] = None,
**model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
r"""
@@ -3491,6 +3493,9 @@
The tokenizer used for the assistant model. If not provided, the token space is assumed to be the same.
tokenizer (`PreTrainedTokenizerBase`, *optional*):
The tokenizer used for the main model. If not provided, the token space is assumed to be the same.
assistant_temperature (`float`, *optional*):
The temperature to use for the assistant model. If not provided and the main generation temperature is
below 1.5, it will be set to 1.5 (to improve decoding speed).
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -3511,6 +3516,20 @@
and any(getattr(l, "is_compileable", False) for l in model_kwargs["past_key_values"].layers)
):
raise ValueError("assisted generate is not supported with Static cache classes`")
# Prefer a slightly higher temperature for the assistant when not explicitly provided
idx = next((i for i, p in enumerate(logits_processor) if isinstance(p, TemperatureLogitsWarper)), None)
temp_processor = logits_processor.pop(idx) if idx is not None else TemperatureLogitsWarper(temperature=1.0)

if assistant_temperature is None and temp_processor is not None and temp_processor.temperature < 1.5:
logger.warning_once(
f"The assistant's sampling temperature comes from main generation loop set to {temp_processor.temperature}, "
"but speculative decoding benefits from slightly hotter candidate generation, (see #40976) so we are setting it "
"to 1.5. This should improve decoding speed in most cases. Use `assistant_temperature` to override this value."
)
assistant_temperature = 1.5

if assistant_temperature is not None:
logits_processor.insert(0, TemperatureLogitsWarper(temperature=assistant_temperature))
Comment on lines +3519 to +3532
Contributor:
doesn't this change the temperature for both models? 👀 (logits_processor also used in step 2.3)

Contributor:
Yea that's a good question, we only up the base temperature no? We could also just modify the temperature in place if that's the case
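For completeness, a usage sketch of the argument as wired in this diff (checkpoint names are placeholders; whether the inserted warper also affects the main model's logits is exactly the open question in this thread):

```python
# Usage sketch for the `assistant_temperature` kwarg added in this PR
# (checkpoint names are placeholders).
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("main-model-checkpoint")
model = AutoModelForCausalLM.from_pretrained("main-model-checkpoint")
assistant = AutoModelForCausalLM.from_pretrained("assistant-model-checkpoint")

inputs = tokenizer("Speculative decoding example:", return_tensors="pt")

# Without `assistant_temperature`, the assistant is bumped to 1.5 whenever the
# main temperature is below 1.5; pass an explicit value to keep control of it.
outputs = model.generate(
    **inputs,
    do_sample=True,
    temperature=0.7,            # main model sampling temperature
    assistant_model=assistant,
    assistant_temperature=1.2,  # overrides the 1.5 default for candidate generation
    max_new_tokens=64,
)
```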

# Get the candidate generator, given the parameterization
candidate_generator = self._get_candidate_generator(
generation_config=generation_config,