[BUG] Wrong scale in Attention2d in fused_attn mode #2385

Closed
@laclouis5

Description

Describe the bug
When fused_attn is enabled, the attention scale is not passed to torch.nn.functional.scaled_dot_product_attention, so it falls back to the default q.size(-1) ** -0.5 (i.e. dim_head ** -0.5), which differs from the scale used by the Attention2d layer's non-fused path (num_heads ** -0.5).

This means that the results from the fused implementation and the vanilla one are different.
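As a concrete illustration (assumed dimensions, not taken from the report): with dim=256 and the default num_heads=32, the two code paths scale the attention logits quite differently.

dim, num_heads = 256, 32          # hypothetical values, for illustration only
dim_head = dim // num_heads       # 8

sdpa_default_scale = dim_head ** -0.5   # ~0.354, used by the fused (SDPA) path
attention2d_scale = num_heads ** -0.5   # ~0.177, used by the non-fused path

The current implementation (timm 1.0.12) is reproduced below for reference.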

class Attention2d(nn.Module):
    fused_attn: torch.jit.Final[bool]

    """ multi-head attention for 2D NCHW tensors"""
    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            num_heads: int = 32,
            bias: bool = True,
            expand_first: bool = False,
            head_first: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.
    ):
        super().__init__()
        dim_out = dim_out or dim
        dim_attn = dim_out if expand_first else dim
        self.num_heads = num_heads
        self.dim_head = dim_attn // num_heads
        self.head_first = head_first
        self.scale = num_heads ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        B, C, H, W = x.shape

        if self.head_first:
            q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2)
        else:
            q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)

        if self.fused_attn:
            x = torch.nn.functional.scaled_dot_product_attention(
                q.transpose(-1, -2).contiguous(),
                k.transpose(-1, -2).contiguous(),
                v.transpose(-1, -2).contiguous(),
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
            ).transpose(-1, -2).reshape(B, -1, H, W)
        else:
            q = q * self.scale
            attn = q.transpose(-2, -1) @ k
            if attn_mask is not None:
                # NOTE: assumes mask is float and in correct shape
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
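
A minimal reproduction sketch (assuming Attention2d can be imported from timm.layers, and toggling the fused_attn flag directly on the instance for illustration):

import torch
from timm.layers import Attention2d  # assumed import path

torch.manual_seed(0)
attn = Attention2d(dim=256, num_heads=32).eval()
x = torch.randn(1, 256, 14, 14)

with torch.no_grad():
    attn.fused_attn = True    # force the scaled_dot_product_attention path
    out_fused = attn(x)
    attn.fused_attn = False   # force the explicit softmax path
    out_vanilla = attn(x)

# Differs because the two paths use different scales
# (dim_head ** -0.5 vs num_heads ** -0.5).
print(torch.allclose(out_fused, out_vanilla, atol=1e-5))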

Expected behavior
The fused and non-fused implementations should produce the same results.

Desktop (please complete the following information):

  • OS: macOS
  • This repository version: 1.0.12
  • PyTorch version: 2.5 (CPU)

Additional context
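One way to make the two paths agree (a sketch of a possible fix, not necessarily the intended one) is to pass the module's scale explicitly to scaled_dot_product_attention, which accepts a scale argument since PyTorch 2.1. Alternatively, self.scale could be set to self.dim_head ** -0.5 so that the non-fused path matches SDPA's default, which is the conventional attention scaling.

        if self.fused_attn:
            x = torch.nn.functional.scaled_dot_product_attention(
                q.transpose(-1, -2).contiguous(),
                k.transpose(-1, -2).contiguous(),
                v.transpose(-1, -2).contiguous(),
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
                scale=self.scale,  # use the same scale as the non-fused path
            ).transpose(-1, -2).reshape(B, -1, H, W)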
