
custom 4d attention_mask as transformers .forward() argument #27493

Closed
@poedator


Feature request

Somewhere inside transformers models, 2D attention masks are converted into 4D. I want to be able to pass my own custom 4D mask to .forward().
Presently this raises an error.
Code example:

import torch
import transformers

# assumption: `device` was defined elsewhere in the original snippet
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "openlm-research/open_llama_3b"
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
# preparing KV cache
size0 = 5
max_token = 10000
x0 = torch.randint(max_token, (1, size0), device=device)
y0 = model.forward(x0)

# forward with mask
size1 = 3
x1 = torch.randint(max_token, (1, size1), device=device)
# bsz, heads=1 (broadcast), query_length, key_value_length (cached + new tokens)
mask_shape = (1, 1, size1, size0 + size1)
custom_mask = torch.randint(2, mask_shape, device=device)

model.forward(input_ids=x1, attention_mask=custom_mask, past_key_values=y0['past_key_values'])
# expected: the forward pass runs and applies custom_mask

Error msg:

...
File .../transformers/src/transformers/modeling_attn_mask_utils.py:154, in AttentionMaskConverter._expand_mask(mask, dtype, tgt_len)
    149 @staticmethod
    150 def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    151     """
    152     Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    153     """
--> 154     bsz, src_len = mask.size()
    155     tgt_len = tgt_len if tgt_len is not None else src_len
    157     expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

ValueError: too many values to unpack (expected 2)
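
To make the request concrete, here is a minimal sketch of the behaviour being asked for, not the library's implementation. `expand_mask_2d_or_4d` is a hypothetical helper: for 2D input it repeats the broadcast from the `_expand_mask` line in the traceback, and for 4D input it uses the mask as given. The 1 = attend / 0 = masked convention and the conversion to an additive bias are assumptions.

```python
from typing import Optional

import torch


def expand_mask_2d_or_4d(
    mask: torch.Tensor,
    dtype: torch.dtype,
    tgt_len: Optional[int] = None,
) -> torch.Tensor:
    """Hypothetical helper: build an additive 4D mask from a 2D or 4D 0/1 mask."""
    if mask.dim() == 2:
        # Same broadcast as the `_expand_mask` line shown in the traceback.
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len
        keep = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len)
    elif mask.dim() == 4:
        # Assumed layout: [bsz, 1 or num_heads, tgt_len, src_len]; used as given.
        keep = mask
    else:
        raise ValueError(f"expected a 2D or 4D mask, got {mask.dim()}D")
    # Assumption: 1 = attend, 0 = masked; masked positions get a large negative bias.
    additive = torch.zeros(keep.shape, dtype=dtype, device=mask.device)
    return additive.masked_fill(keep == 0, torch.finfo(dtype).min)


mask_2d = torch.tensor([[1, 1, 0]])                   # [bsz, seq_len]
out_2d = expand_mask_2d_or_4d(mask_2d, torch.float32)
mask_4d = torch.ones(1, 1, 3, 8, dtype=torch.long)    # [bsz, 1, tgt_len, src_len]
out_4d = expand_mask_2d_or_4d(mask_4d, torch.float32)
```

With a branch like this, the 2D path is unchanged, while a 4D mask skips the two-value unpacking that raises the `ValueError`.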

Motivation

I need a custom 4D mask for experiments with causal inference.
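
As an illustration only (the issue does not say which masks the experiments need), here is a sketch of one structured 4D mask. The new tokens attend to every cached token and are causal among themselves. It assumes the `[bsz, 1, query_length, key_value_length]` layout with 1 = attend, and reuses the sizes 5 and 3 from the code example.

```python
import torch

size0, size1 = 5, 3  # cached tokens, new tokens

# Every new token may attend to all cached tokens ...
past_part = torch.ones(size1, size0, dtype=torch.long)
# ... and the new tokens are causal among themselves (lower-triangular).
new_part = torch.tril(torch.ones(size1, size1, dtype=torch.long))

# Concatenate along the key/value axis, then add batch and head axes:
# [bsz=1, 1, query_length=size1, key_value_length=size0 + size1]
custom_mask = torch.cat([past_part, new_part], dim=-1)[None, None, :, :]
```

Other patterns, such as several independent continuations sharing one cached prefix, could be built the same way by changing `new_part`.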

Your contribution

I am ready to get involved, with HF guidance.
Tagging @patrickvonplaten, who recently authored #27086.
