
Qwen2_5_VLForConditionalGeneration cfg forward twice error #39749

@guozhiyao

Description


System Info

transformers 4.49.0

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

    import torch
    import torch.nn.functional as F

    # Snippet from a custom generation method on the model (hence `self`);
    # input_ids, negative_input_ids, inputs_emb, negative_emb, batch_size,
    # cfg_scale, temperature and max_steps are prepared by the caller.
    accepted_tokens = torch.zeros(batch_size, 0, dtype=torch.long, device=input_ids.device)

    # Conditional (positive) stream state.
    attention_mask = input_ids.ne(self.config.t5_pad_token_id)
    pos_kwargs = dict(
        inputs_embeds=inputs_emb, attention_mask=attention_mask,
        use_cache=True,
    )
    pos_kwargs = self._get_initial_cache_position(None, pos_kwargs)

    # Unconditional (negative) stream state, kept separate for CFG.
    attention_mask = negative_input_ids.ne(self.config.t5_pad_token_id)
    neg_kwargs = dict(
        inputs_embeds=negative_emb, attention_mask=attention_mask,
        use_cache=True,
    )
    neg_kwargs = self._get_initial_cache_position(None, neg_kwargs)

    first_token = None
    for it in range(max_steps):
        # Conditional forward.
        pos_kwargs["input_ids"] = first_token
        pos_kwargs = self.prepare_inputs_for_generation(**pos_kwargs)
        output = self(**pos_kwargs)
        logits = output.logits
        pos_kwargs = self._update_model_kwargs_for_generation(output, pos_kwargs)

        cond_draft_logits = logits[:, -1:, :]

        if cfg_scale > 1.0 and negative_input_ids is not None:
            # Unconditional forward. This second call overwrites the
            # self.rope_deltas that the conditional forward just computed.
            neg_kwargs["input_ids"] = first_token
            neg_kwargs = self.prepare_inputs_for_generation(**neg_kwargs)
            output = self(**neg_kwargs)
            uncond_logits = output.logits
            neg_kwargs = self._update_model_kwargs_for_generation(output, neg_kwargs)

            uncond_draft_logits = uncond_logits[:, -1:, :]

            # Classifier-free guidance: uncond + scale * (cond - uncond).
            draft_logits = uncond_draft_logits + cfg_scale * (cond_draft_logits - uncond_draft_logits)
        else:
            draft_logits = cond_draft_logits

        draft_logits = draft_logits / temperature
        draft_probs = F.softmax(draft_logits, dim=-1, dtype=torch.float32)

        # Greedy pick; the chosen token is fed to both streams next step.
        draft_tokens = torch.argmax(draft_probs, dim=-1)
        first_token = draft_tokens[:, :1]
        accepted_tokens = torch.cat([accepted_tokens, first_token], dim=1)

        # After prefill, decoding continues from input_ids only.
        neg_kwargs["inputs_embeds"] = None
        pos_kwargs["inputs_embeds"] = None

    # Final output: the generated token ids.
    output = accepted_tokens[:, :max_steps]
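
For context, the inputs to the loop above could be prepared roughly like this (a hypothetical setup: the model id, prompts, and hyperparameter values are placeholders, and `t5_pad_token_id` is assumed to be a custom field added to the config):

    import torch
    from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

    # Placeholder model id and prompts; the generation loop itself is assumed
    # to live in a custom method added to the model class.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

    # Positive (conditional) and negative (unconditional) prompts for CFG.
    pos = processor(text=["Describe a sunny day in detail."], return_tensors="pt")
    neg = processor(text=["Describe a sunny day."], return_tensors="pt")

    input_ids = pos["input_ids"].to(model.device)
    negative_input_ids = neg["input_ids"].to(model.device)
    inputs_emb = model.get_input_embeddings()(input_ids)
    negative_emb = model.get_input_embeddings()(negative_input_ids)

    batch_size = input_ids.shape[0]
    cfg_scale, temperature, max_steps = 3.0, 1.0, 64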

Expected behavior

I use Qwen2-VL to do CFG (classifier-free guidance) generation, which runs a conditional and an unconditional forward pass each step. But Qwen caches self.rope_deltas on the module itself, so the second forward overwrites the value the first one just computed.
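
In miniature, the hazard looks like this (a toy sketch, not the real Qwen forward; `ToyQwen` and `is_prefill` are illustrative):

    import torch
    import torch.nn as nn

    class ToyQwen(nn.Module):
        """Caches a per-prompt offset on the module, the way Qwen2-VL
        caches self.rope_deltas during the prefill forward."""
        def __init__(self):
            super().__init__()
            self.rope_deltas = None

        def forward(self, x, is_prefill):
            if is_prefill:
                # prefill recomputes the offset from the prompt
                self.rope_deltas = x.shape[1]
            # decode steps silently reuse whatever is cached
            return x.sum() + self.rope_deltas

    model = ToyQwen()
    cond, uncond = torch.zeros(1, 5), torch.zeros(1, 9)
    model(cond, is_prefill=True)    # conditional prefill caches deltas for length 5
    model(uncond, is_prefill=True)  # unconditional prefill overwrites them with 9
    print(model.rope_deltas)        # 9: the next conditional decode step now
                                    # runs with the wrong offset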

I work around it by modifying the Qwen code as follows:

  1. Override _update_model_kwargs_for_generation to save rope_deltas into model_kwargs:

         def _update_model_kwargs_for_generation(
             self,
             outputs: ModelOutput,
             model_kwargs: Dict[str, Any],
             is_encoder_decoder: bool = False,
             num_new_tokens: int = 1,
         ) -> Dict[str, Any]:
             # save the value this stream's forward just computed
             model_kwargs["rope_deltas"] = self.rope_deltas
             return super()._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder, num_new_tokens)
  2. In prepare_inputs_for_generation, restore rope_deltas from model_kwargs:

         # restore the value saved for this stream before its next forward
         self.rope_deltas = kwargs.get("rope_deltas", None)

With these two changes, each CFG stream carries its own rope_deltas through its model_kwargs, and the bug is fixed.
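
Put together, the workaround can be packaged as a subclass (a sketch under the same assumptions; `PatchedQwen2_5_VL` is a hypothetical name):

    from typing import Any, Dict

    from transformers import Qwen2_5_VLForConditionalGeneration
    from transformers.utils import ModelOutput

    class PatchedQwen2_5_VL(Qwen2_5_VLForConditionalGeneration):
        """rope_deltas travels through model_kwargs, so the conditional and
        unconditional CFG streams no longer share mutable module state."""

        def _update_model_kwargs_for_generation(
            self,
            outputs: ModelOutput,
            model_kwargs: Dict[str, Any],
            is_encoder_decoder: bool = False,
            num_new_tokens: int = 1,
        ) -> Dict[str, Any]:
            # Snapshot the deltas this stream's forward just computed, before
            # the other stream's forward overwrites self.rope_deltas.
            model_kwargs["rope_deltas"] = self.rope_deltas
            return super()._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder, num_new_tokens
            )

        def prepare_inputs_for_generation(self, *args, **kwargs):
            # Restore this stream's deltas right before its next forward;
            # None on the first step lets prefill recompute them.
            self.rope_deltas = kwargs.pop("rope_deltas", None)
            return super().prepare_inputs_for_generation(*args, **kwargs)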
