Description
🐛 Bug
I am trying to use xformers to replace my native PyTorch MHA implementation, which looks something like this:
import torch.nn.functional as F

def native_attention(query, key, value, attn_bias=None, p=0.0):
    # scale queries by 1/sqrt(head_dim)
    scale = 1 / query.shape[-1] ** 0.5
    query = query * scale
    attn = query @ key.transpose(-2, -1)
    if attn_bias is not None:
        attn = attn + attn_bias
    attn = attn.softmax(-1)
    attn = F.dropout(attn, p)
    return attn @ value
After switching to xformers, I am using xops.memory_efficient_attention(q, k, v, attn_bias).
This works fine when I use a lower triangular mask, either by passing in a LowerTriangularMask() or by passing a torch.Tensor of the same shape that I build myself.
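For reference, the working causal case looks roughly like this (a minimal sketch; the sizes are illustrative and tensors use xformers' [batch, seq_len, heads, head_dim] layout):

import torch
import xformers.ops as xops

B, M, H, K = 2, 1024, 16, 64  # batch, seq_len, heads, head_dim
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
k = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
v = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)

# Built-in causal bias: works as expected
out = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())

# Hand-built additive causal mask: also works for me
causal_bias = torch.zeros(M, M, device="cuda", dtype=torch.float16)
causal_bias = causal_bias.masked_fill(
    torch.ones(M, M, dtype=torch.bool, device="cuda").triu(1), float("-inf")
)
causal_bias = causal_bias.expand(B, H, M, M).contiguous()
out = xops.memory_efficient_attention(q, k, v, attn_bias=causal_bias)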
However, when I switch to an arbitrary mask (suppose that in the pretraining stage you enable the reset_position_ids and reset_attention_mask flags, so a new document can start in the middle of one packed sequence), I get NaNs both during evaluation (no-grad forward) and during training (with grad). Based on the logs, the program is using the CUTLASS op.
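To make the failing case concrete, here is a minimal sketch of the kind of arbitrary mask I mean (the sequence lengths, sizes, and exact construction are illustrative): with reset_attention_mask, each packed sequence contains several documents, so the additive bias is causal inside each document and -inf across document boundaries.

import torch
import xformers.ops as xops

B, M, H, K = 1, 8, 2, 64
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
k = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
v = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)

# Two documents packed into one sequence of length 8: causal attention inside
# each document, everything across the document boundary masked with -inf.
seqlens = [3, 5]
doc_id = torch.repeat_interleave(
    torch.arange(len(seqlens), device="cuda"),
    torch.tensor(seqlens, device="cuda"),
)
same_doc = doc_id[:, None] == doc_id[None, :]
causal = torch.ones(M, M, dtype=torch.bool, device="cuda").tril()
allowed = same_doc & causal
attn_bias = torch.zeros(M, M, device="cuda", dtype=torch.float16)
attn_bias = attn_bias.masked_fill(~allowed, float("-inf"))
attn_bias = attn_bias.expand(B, H, M, M).contiguous()

# A bias of this form is what triggers the NaNs for me.
out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias)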
Based on my observations, xformers saves 10-15% of GPU memory and improves overall TFLOPS by 10-15%, so I really want to use it to replace my native PyTorch implementation. Could you help with this issue?
Environment
Some dependencies:
triton==2.0
xformers==0.0.18
pytorch==2.0