Allow passing 2D attention mask #27640

@UniverseFly

Feature request

Allow passing a 2D attention mask to model.forward.

Motivation

With this feature, it would be much easier to avoid cross-context contamination during pretraining and supervised fine-tuning when packing sequences together for more efficient training.

An example use case is discussed in huggingface/trl#805.
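As an illustration of the packing setup this would enable, here is a minimal sketch (not taken from the linked thread): two sequences of lengths 3 and 2 packed into a single row, with a causal, block-diagonal (seq_len, seq_len) mask so that tokens never attend across segment boundaries.

    import torch

    def packed_causal_mask(segment_lengths):
        """Boolean (seq_len, seq_len) mask: True = attention allowed."""
        seq_len = sum(segment_lengths)
        mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
        start = 0
        for length in segment_lengths:
            end = start + length
            # causal attention, restricted to the current segment
            mask[start:end, start:end] = torch.tril(
                torch.ones(length, length, dtype=torch.bool)
            )
            start = end
        return mask

    print(packed_causal_mask([3, 2]).int())
    # tensor([[1, 0, 0, 0, 0],
    #         [1, 1, 0, 0, 0],
    #         [1, 1, 1, 0, 0],
    #         [0, 0, 0, 1, 0],
    #         [0, 0, 0, 1, 1]])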

Your contribution

After looking into the source code, I found that the logic for initializing attention masks is mostly a fixed code snippet duplicated in each model:

        if getattr(self.config, "_flash_attn_2_enabled", False):
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )

Enabling this behavior may require modifying each model. I should be able to handle some of them and submit a draft PR, but before that I would like to know whether this feature request is reasonable.
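One possible direction, sketched below purely as an illustration and not as the actual transformers API, would be to pass a caller-supplied (seq_len, seq_len) or 4D mask through unchanged and only fall back to the usual expansion when a standard 2D (batch, seq_len) padding mask is given. The function prepare_attention_mask and the helper _to_additive are hypothetical names, and dtype is assumed to be a float dtype.

    import torch

    def _to_additive(mask_bool, dtype):
        # True -> 0.0, False -> large negative, matching the additive-mask convention
        min_value = torch.finfo(dtype).min
        return torch.zeros(mask_bool.shape, dtype=dtype).masked_fill(~mask_bool, min_value)

    def prepare_attention_mask(attention_mask, seq_length, dtype):
        causal = torch.tril(torch.ones(seq_length, seq_length, dtype=torch.bool))

        if attention_mask is None:
            # no mask given: plain causal mask, broadcast over batch and heads
            return _to_additive(causal, dtype)[None, None, :, :]
        if attention_mask.dim() == 4:
            # trust a caller-supplied 4D mask and pass it through unchanged
            return attention_mask
        if attention_mask.shape == (seq_length, seq_length):
            # custom 2D mask, e.g. block-diagonal causal for packed sequences
            return _to_additive(attention_mask.bool(), dtype)[None, None, :, :]
        # standard (batch, seq_len) padding mask, combined with the causal mask
        combined = causal[None, :, :] & attention_mask.bool()[:, None, :]
        return _to_additive(combined, dtype)[:, None, :, :]

With this kind of dispatch, the packed mask from the earlier sketch could be passed directly, e.g. prepare_attention_mask(packed_causal_mask([3, 2]), 5, torch.float32).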
