
[0.0.18, memory_efficient_attention with attn_bias] Getting NaNs with arbitrary attn_bias mask with xformers==0.0.18 #722

Closed
@toothacher17

Description

🐛 Bug

I am trying to use xformers to replace my native PyTorch MHA implementation, which looks something like this:

import torch
import torch.nn.functional as F

def attention(query, key, value, attn_bias=None, p=0.0):
    # Scale queries by 1 / sqrt(head_dim)
    scale = 1 / query.shape[-1] ** 0.5
    query = query * scale
    attn = query @ key.transpose(-2, -1)
    # Additive bias: 0 keeps a position, -inf masks it out
    if attn_bias is not None:
        attn = attn + attn_bias
    attn = attn.softmax(-1)
    attn = F.dropout(attn, p)
    return attn @ value

After switching to xformers, I am using xops.memory_efficient_attention(q, k, v, attn_bias).
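For context, a minimal sketch of the call (the shapes, dtype, and device below are illustrative assumptions; xformers expects (batch, seq_len, num_heads, head_dim)):

import torch
import xformers.ops as xops

# Illustrative shapes: (batch, seq_len, num_heads, head_dim)
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# p maps to the dropout probability of my native implementation
out = xops.memory_efficient_attention(q, k, v, attn_bias=None, p=0.0)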

This works fine when I am using a lower triangular mask, either by passing in a LowerTriangularMask() or by passing a torch.Tensor of the same shape that I built myself.
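Both working variants, as a minimal self-contained sketch (the explicit tensor is built to match what LowerTriangularMask() encodes; shapes are again illustrative):

import torch
import xformers.ops as xops

B, M, H, K = 2, 128, 8, 64  # illustrative shapes
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Variant 1: the built-in causal mask
out1 = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())

# Variant 2: an explicit additive bias, 0 on/below the diagonal, -inf above
bias = torch.triu(
    torch.full((M, M), float("-inf"), device="cuda", dtype=torch.float16),
    diagonal=1,
).expand(B, H, M, M).contiguous()  # (batch, heads, seq_q, seq_k)
out2 = xops.memory_efficient_attention(q, k, v, attn_bias=bias)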

However, when I switch to an arbitrary mask (suppose that in the pretraining stage you enable the reset_position_ids and reset_attention_mask flags, so a new document can start inside one packed sequence), I get NaNs both during evaluation (no-grad forward) and during training (with grad). Based on the logs, the program is using the CUTLASS op.
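For concreteness, the arbitrary mask is block-diagonal causal: each packed document attends causally within itself and never across a reset boundary. A simplified sketch of how I build it (the document lengths are hypothetical):

import torch

# Hypothetical example: one packed sequence holding two documents of
# lengths 3 and 2; attention is causal within each document and fully
# masked across the reset boundary.
seqlens = [3, 2]
total = sum(seqlens)

bias = torch.full((total, total), float("-inf"))
start = 0
for n in seqlens:
    # Causal pattern inside the block: 0 on/below the diagonal, -inf above
    block = torch.triu(torch.full((n, n), float("-inf")), diagonal=1)
    bias[start:start + n, start:start + n] = block
    start += n
# bias is then cast/expanded to (batch, heads, total, total) and passed
# to memory_efficient_attention as attn_bias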

Based on my observations, xformers saves 10-15% GPU memory and improves overall TFLOPs by 10-15%, so I really want to use it to replace my native PyTorch implementation. Could you help with this issue?

Environment

Some dependencies:
triton==2.0
xformers==0.0.18
pytorch==2.0


Metadata

Labels

bug (Something isn't working), ongoing
