Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experimenting with differential attention #2314

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions timm/models/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,82 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


class DiffAttention(nn.Module):
fused_attn: Final[bool]

def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RmsNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads // 2
self.scale = self.head_dim ** -0.5
self.fused_attn = use_fused_attn()

self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

self.lambda_init = 0.8
self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))

self.sub_norm = RmsNorm(2 * self.head_dim, eps=1e-5)

def _set_lambda_init(self, depth: int):
self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)

def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
q, k, v = self.qkv(x).chunk(3, dim=2)
q = q.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
k = k.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
v = v.reshape(B, N, self.num_heads, 2 * self.head_dim).transpose(1, 2)
q, k = self.q_norm(q), self.k_norm(k)

if self.fused_attn:
q = q.reshape(B, self.num_heads, 2, N, self.head_dim)
k = k.reshape(B, self.num_heads, 2, N, self.head_dim)
q1, q2 = q.unbind(2)
k1, k2 = k.unbind(2)
attn1 = F.scaled_dot_product_attention(q1, k1, v)
attn2 = F.scaled_dot_product_attention(q2, k2, v)
lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
lambda_full = lambda_1 - lambda_2 + self.lambda_init
x = attn1 - lambda_full * attn2
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
lambda_full = lambda_1 - lambda_2 + self.lambda_init
attn = attn.view(B, self.num_heads, 2, N, N)
attn = attn[:, :, 0] - lambda_full * attn[:, :, 1]
x = attn @ v

x = self.sub_norm(x)
x = x * (1 - self.lambda_init)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x


class LayerScale(nn.Module):
def __init__(
self,
Expand Down
Loading