
Conversation


@manueldeprada manueldeprada commented Sep 15, 2025

Each decoding method has a common block of output handling boilerplate that worsens readability:

```python
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate

# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
    encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
    encoder_hidden_states = (
        model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
    )

...

while not finished:
    # Store scores, attentions and hidden_states when required
    if return_dict_in_generate:
        if output_scores:
            scores += (next_token_scores,)
        if output_logits:
            raw_logits += (next_token_logits,)
        if output_attentions:
            decoder_attentions += (
                (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
            )
            if self.config.is_encoder_decoder:
                cross_attentions += (outputs.cross_attentions,)

        if output_hidden_states:
            decoder_hidden_states += (
                (outputs.decoder_hidden_states,)
                if self.config.is_encoder_decoder
                else (outputs.hidden_states,)
            )

...

if return_dict_in_generate:
    if self.config.is_encoder_decoder:
        return XXXEncoderDecoderOutput(
            sequences=input_ids,
            scores=scores,
            logits=raw_logits,
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
            past_key_values=model_kwargs.get("past_key_values"),
        )
    else:
        return XXXDecoderOnlyOutput(
            sequences=input_ids,
            scores=scores,
            logits=raw_logits,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
            past_key_values=model_kwargs.get("past_key_values"),
        )
else:
    return input_ids
```

This PR moves that boilerplate into reusable `generate` helpers.

TODO: add a generalization so that users can request `output_x` and have `x` from the forward pass forwarded to the generation output.
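A minimal sketch of what such a shared helper could look like. All names here (`FakeGenerationConfig`, `init_generate_output`) are illustrative stand-ins, not the PR's actual API:

```python
# Hypothetical sketch of a reusable output-init helper; this is NOT the PR's
# real implementation, just the shape of the idea.
from dataclasses import dataclass


@dataclass
class FakeGenerationConfig:
    return_dict_in_generate: bool = True
    output_scores: bool = True
    output_logits: bool = False
    output_attentions: bool = False
    output_hidden_states: bool = False


def init_generate_output(generation_config, is_encoder_decoder=False):
    """Build the accumulator dict that every decoding loop can share."""
    if not generation_config.return_dict_in_generate:
        return None
    out = {
        "scores": () if generation_config.output_scores else None,
        "logits": () if generation_config.output_logits else None,
        "decoder_attentions": () if generation_config.output_attentions else None,
        "decoder_hidden_states": () if generation_config.output_hidden_states else None,
    }
    # cross-attentions only exist for encoder-decoder models
    if is_encoder_decoder and generation_config.output_attentions:
        out["cross_attentions"] = ()
    return out


acc = init_generate_output(FakeGenerationConfig())
print(acc)
```

Each decoding method would then call one helper instead of repeating the five-way `() if ... else None` dance.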


@manueldeprada manueldeprada requested a review from gante September 17, 2025 19:36
@manueldeprada

@gante this PR is more of an RFC to see what you think than a full PR. If you agree this simplifies `generate`, I will put in more work to make it clean for assisted generation and make the tests pass!

@manueldeprada

related: #39834

@gante (Contributor) left a comment:

Very much on board with this 👍 👍 👍

@manueldeprada (Contributor Author) commented Nov 4, 2025:

@gante This is ready for review; I left some comments in the code. The `_accumulate` method in this PR does not hardcode arg names (attentions, hidden states, etc.) but rather iterates over a general dict. This is less readable, but it prepares the ground for custom `output_xxx` outputs like `image_hidden_states` in a future PR (as suggested in #39834).

I would suggest merging this PR as-is, and then I can make a second PR that enables custom `output_xxx`. For that second PR, we need to agree on the best interface for users to specify which extra args from the model output they want. A big caveat also to be discussed: assisted generation needs to use `_split_model_outputs`, which might not work for special model outputs, so we might simply not support extra outputs in assisted generation. WDYT?
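The dict-driven accumulation idea can be sketched like this (the function name `accumulate` and the plain-string stand-in values are hypothetical; the PR's actual `_accumulate` may differ):

```python
# Illustrative sketch: append each step's values to every enabled (non-None)
# tuple, keyed by name, instead of hardcoding one branch per attribute.
def accumulate(generate_output, step_values):
    """Extend each requested output tuple with this decoding step's value."""
    for key, value in step_values.items():
        # a None entry means the user did not request that output
        if generate_output.get(key) is not None:
            generate_output[key] += (value,)
    return generate_output


state = {"scores": (), "attentions": None}
state = accumulate(state, {"scores": "step0_scores", "attentions": "step0_attn"})
state = accumulate(state, {"scores": "step1_scores", "attentions": "step1_attn"})
print(state)  # "attentions" stays None because it was never requested
```

Because the loop is generic over keys, supporting a new `output_xxx` later only means adding an entry to the dict, not another `if` branch in every decoding method.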

@manueldeprada manueldeprada marked this pull request as ready for review November 4, 2025 11:02
@manueldeprada manueldeprada requested a review from gante November 4, 2025 11:02
```python
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
```
@manueldeprada (Contributor Author) commented:
Since output_x comes from generation config, how do you suggest we enable extra generation outputs?

It could be an `output_features=['attentions', 'hidden_states', 'scores']` list, etc.
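A rough sketch of how such an `output_features` list could map onto the existing per-feature flags. This is purely illustrative; no such parameter exists in transformers today, and `resolve_output_flags` is a made-up name:

```python
# Hypothetical: translate an `output_features` list into the per-feature
# boolean flags that the decoding loops already consume.
def resolve_output_flags(output_features):
    known = ("attentions", "hidden_states", "scores", "logits")
    unknown = set(output_features) - set(known)
    if unknown:
        raise ValueError(f"Unknown output features: {sorted(unknown)}")
    return {f"output_{name}": name in output_features for name in known}


flags = resolve_output_flags(["attentions", "scores"])
print(flags)
```

One open design question is whether unknown names should error out (as above) or be passed through, so that model-specific outputs like `image_hidden_states` can be requested without a registry.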

"will be skipped."
)

if can_compile:
@manueldeprada (Contributor Author) commented Nov 6, 2025:
this was missing from #40652!! Just noticed while merging main here.

I added it so that it gets merged this week, @gante; otherwise I can push a separate fix!

@zucchini-nlp zucchini-nlp requested review from zucchini-nlp and removed request for gante November 10, 2025 13:07
@zucchini-nlp (Member) left a comment:

hey @manueldeprada, great job! I am happy to have a first step toward better generation output handling.

Do you think we can add the dynamic output dict in this PR, since we already started the refactor? Would be super cool to get rid of the near-duplicate code.

Comment on lines +3720 to +3733
```python
if not generation_config.return_dict_in_generate:
    return {"return_dict_in_generate": False, "next_scores": None}
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits

next_scores = () if output_scores else None
next_logits = () if output_logits else None
decoder_attentions = () if output_attentions else None
cross_attentions = () if output_attentions and self.config.is_encoder_decoder else None
decoder_hidden_states = () if output_hidden_states else None

encoder_attentions = encoder_hidden_states = None
```

i think we have to push further and make it output any value dynamically, as requested by users. Currently the PR splits the existing logic out into its own fn, but that code is still very repetitive.

IMO we can check `model_outputs.keys()` and dynamically update our generation output dict with the keys for which `getattr(generation_config, f"output_{key}")` is set. Since all models follow standard naming in the output dict, it should have no edge cases.
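That suggestion can be sketched roughly as follows. `Cfg` and `keys_to_record` are illustrative names, not transformers API; the point is only that the recorded keys are discovered from the config at runtime instead of being hardcoded:

```python
# Sketch of the reviewer's idea: decide which model-output keys to record by
# probing `output_{key}` flags on the generation config.
class Cfg:
    output_attentions = True
    output_hidden_states = False


def keys_to_record(model_output_keys, generation_config):
    """Keep only the output keys whose `output_{key}` flag is set."""
    return [
        key
        for key in model_output_keys
        if getattr(generation_config, f"output_{key}", False)
    ]


print(keys_to_record(["attentions", "hidden_states", "logits"], Cfg()))
# → ['attentions']
```

With standard naming across models, a model-specific key like `image_hidden_states` would be picked up automatically once a user sets a matching `output_image_hidden_states` flag.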

Comment on lines +3782 to +3798
```python
if cur_len is not None:
    for arg in splittable_args:
        if generate_output.get(arg) is not None:
            kwargs[arg] = _split_model_outputs(
                kwargs[arg],
                cur_len,
                added_len,
                is_prefill_pass=len(generate_output[arg]) == 0,
                is_decoder_attention=(arg == "decoder_attentions"),
            )
    for arg in cropable_args:
        if generate_output[arg] is not None:
            kwargs[arg] = tuple(kwargs[arg][:, i, :] for i in range(added_len))
else:
    for arg in all_args:
        if generate_output.get(arg) is not None:
            kwargs[arg] = (kwargs[arg],)
```

hmm, this could be simplified, no? If we set `cur_len=1` as the default, then we can always try to split the output, and it will catch up depending on the length value.

Comment on lines +3839 to +3841
```python
if any(cache_key in model_kwargs for cache_key in ALL_CACHE_NAMES):
    cache_key = next(cache_key for cache_key in ALL_CACHE_NAMES if cache_key in model_kwargs)
    cache = model_kwargs[cache_key]
```

nit: with smth like `caches_in_kwargs := [cache_key for cache_key in ALL_CACHE_NAMES if cache_key in model_kwargs]` we can avoid looping twice
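A runnable sketch of that single-pass lookup; `ALL_CACHE_NAMES` and `model_kwargs` are stand-in values here, not the real transformers objects:

```python
# One pass instead of `any(...)` + `next(...)`: the walrus expression both
# filters the matching cache keys and serves as the truthiness check.
ALL_CACHE_NAMES = ["past_key_values", "cache_params"]  # stand-in list
model_kwargs = {"cache_params": "dummy-cache"}  # stand-in kwargs

cache = None
if caches_in_kwargs := [k for k in ALL_CACHE_NAMES if k in model_kwargs]:
    cache = model_kwargs[caches_in_kwargs[0]]

print(cache)  # → dummy-cache
```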

Comment on lines +3844 to +3866
```python
    return encoder_decoder_cls(
        sequences=sequences,
        scores=generate_output["next_scores"],
        logits=generate_output["next_logits"],
        encoder_attentions=generate_output["encoder_attentions"],
        encoder_hidden_states=generate_output["encoder_hidden_states"],
        decoder_attentions=generate_output["decoder_attentions"],
        cross_attentions=generate_output["cross_attentions"],
        decoder_hidden_states=generate_output["decoder_hidden_states"],
        past_key_values=cache,
        **kwargs,
    )
else:
    return decoder_only_cls(
        sequences=sequences,
        scores=generate_output["next_scores"],
        logits=generate_output["next_logits"],
        attentions=generate_output["decoder_attentions"],
        hidden_states=generate_output["decoder_hidden_states"],
        past_key_values=cache,
        **kwargs,
    )
```


maybe for the future: it would be great to not rely on an expected set of keys and instead unpack everything in `generate_output` into the output class. The `GenerationOutput` class would have to be able to hold anything for that.

Pseudo code like below:

```python
output_cls = encoder_decoder_cls if self.config.is_encoder_decoder else decoder_only_cls
return output_cls(sequences=sequences, past_key_values=cache, **generate_output)
```
