Refactor output handling in generate for cleaner decoding methods #40887
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
This reverts commit e3aed39.
@gante this PR is more of an RFC to see what you think than a full PR. If you agree this simplifies generate, I will put more work into making it clean for assisted generation and making the tests happy!
related: #39834
gante left a comment:
Very much on board with this 👍 👍 👍
@gante This is ready for review, I left some comments in the code. I would suggest merging this PR as-is, and then I can make a second PR that enables custom
```python
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
```
Since `output_x` comes from the generation config, how do you suggest we enable extra generation outputs?
It could be an `output_features=['attentions', 'hidden_states', 'scores']`, etc.
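One possible shape for that suggestion is sketched below; `output_features` and the fallback behavior are assumptions for illustration, not an existing `GenerationConfig` API:

```python
# Hypothetical sketch: a single `output_features` list on the generation
# config, falling back to today's per-feature boolean flags. The attribute
# names here are assumptions for illustration, not confirmed API.
KNOWN_FEATURES = ("scores", "logits", "attentions", "hidden_states")

def resolve_output_flags(generation_config):
    """Return {feature: bool} saying which extra outputs were requested."""
    requested = getattr(generation_config, "output_features", None)
    if requested is None:
        # Fall back to the existing output_<feature> booleans.
        requested = [
            f for f in KNOWN_FEATURES
            if getattr(generation_config, f"output_{f}", False)
        ]
    return {f: f in requested for f in KNOWN_FEATURES}

class _Config:  # stand-in for GenerationConfig
    output_features = ["attentions", "scores"]

flags = resolve_output_flags(_Config())
```

This keeps backward compatibility with the boolean flags while letting a single list express arbitrary extra outputs.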
```python
        "will be skipped."
    )

if can_compile:
```
zucchini-nlp left a comment:
hey @manueldeprada, great job! I am happy to have a first step toward better generation output handling.
Do you think we can build the dynamic output dict in this PR, since we already started the refactor? It would be super cool to get rid of the near-duplicate code.
```python
if not generation_config.return_dict_in_generate:
    return {"return_dict_in_generate": False, "next_scores": None}
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits

next_scores = () if output_scores else None
next_logits = () if output_logits else None
decoder_attentions = () if output_attentions else None
cross_attentions = () if output_attentions and self.config.is_encoder_decoder else None
decoder_hidden_states = () if output_hidden_states else None

encoder_attentions = encoder_hidden_states = None
```
I think we have to push further and make it output any value dynamically, as requested by users. Currently the PR splits the existing logic out into its own function, but the code is still very repetitive.
IMO we can check `model_outputs.keys()` and dynamically update our generation output dict with the keys for which `getattr(generation_config, f"output_{key}")` is set. Since all models follow standard naming in the output dict, it should have no edge cases.
```python
if cur_len is not None:
    for arg in splittable_args:
        if generate_output.get(arg) is not None:
            kwargs[arg] = _split_model_outputs(
                kwargs[arg],
                cur_len,
                added_len,
                is_prefill_pass=len(generate_output[arg]) == 0,
                is_decoder_attention=(arg == "decoder_attentions"),
            )
    for arg in cropable_args:
        if generate_output[arg] is not None:
            kwargs[arg] = tuple(kwargs[arg][:, i, :] for i in range(added_len))
else:
    for arg in all_args:
        if generate_output.get(arg) is not None:
            kwargs[arg] = (kwargs[arg],)
```
hmm, this could be simplified, no? If we set `cur_len=1` as the default, then we can always try to split the output; it will work out depending on the length value.
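A toy illustration of that simplification, with stand-in names (`_split` plays the role of `_split_model_outputs`; none of this is the PR's real code):

```python
# Toy sketch: defaulting cur_len/added_len lets every call take the same
# splitting path, removing the separate `else` branch that wrapped values
# in a 1-tuple.
def _split(value, cur_len, added_len):
    if added_len == 1:
        return (value,)  # the old else-branch behavior, now just a base case
    return tuple(value[i] for i in range(added_len))

def collect_outputs(kwargs, generate_output, args, cur_len=1, added_len=1):
    # one branch instead of `if cur_len is not None: ... else: ...`
    for arg in args:
        if generate_output.get(arg) is not None:
            kwargs[arg] = _split(kwargs[arg], cur_len, added_len)
    return kwargs
```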
```python
if any(cache_key in model_kwargs for cache_key in ALL_CACHE_NAMES):
    cache_key = next(cache_key for cache_key in ALL_CACHE_NAMES if cache_key in model_kwargs)
    cache = model_kwargs[cache_key]
```
nit: with something like `caches_in_kwargs := [cache_key for cache_key in ALL_CACHE_NAMES if cache_key in model_kwargs]` we can avoid looping twice
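A sketch of the single-pass version the nit suggests, using an illustrative name list (the library's actual `ALL_CACHE_NAMES` constant may differ):

```python
# Single pass over the cache names using a walrus assignment, instead of an
# any(...) check followed by next(...). The tuple below is an illustrative
# stand-in, not the library's actual constant.
ALL_CACHE_NAMES = ("past_key_values", "cache_params", "state", "mems")

def extract_cache(model_kwargs):
    if caches_in_kwargs := [k for k in ALL_CACHE_NAMES if k in model_kwargs]:
        return model_kwargs[caches_in_kwargs[0]]
    return None
```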
```python
    return encoder_decoder_cls(
        sequences=sequences,
        scores=generate_output["next_scores"],
        logits=generate_output["next_logits"],
        encoder_attentions=generate_output["encoder_attentions"],
        encoder_hidden_states=generate_output["encoder_hidden_states"],
        decoder_attentions=generate_output["decoder_attentions"],
        cross_attentions=generate_output["cross_attentions"],
        decoder_hidden_states=generate_output["decoder_hidden_states"],
        past_key_values=cache,
        **kwargs,
    )
else:
    return decoder_only_cls(
        sequences=sequences,
        scores=generate_output["next_scores"],
        logits=generate_output["next_logits"],
        attentions=generate_output["decoder_attentions"],
        hidden_states=generate_output["decoder_hidden_states"],
        past_key_values=cache,
        **kwargs,
    )
```
maybe for the future: it would be great to not rely on an expected set of keys and instead unpack everything in `generate_output` into the output class. The `GenerationOutput` class would have to be able to hold anything for that.
Pseudo code like below:

```python
output_cls = encoder_decoder_cls if self.config.is_encoder_decoder else decoder_only_cls
return output_cls(sequences=sequences, past_key_values=cache, **generate_output)
```
Each decoding method has a common block of output-handling boilerplate that worsens readability.
This PR moves that boilerplate into reusable `generate` helpers.
TODO: add generalization so that users can request `output_x` and `x` from the model's forward pass gets forwarded.