Commit

pre-commit
abhi-glitchhg committed Mar 25, 2022
1 parent e4c9d36 commit a7637d9
Showing 1 changed file with 9 additions and 4 deletions.
vformer/attention/cross.py
@@ -5,7 +5,6 @@
 from ..utils import ATTENTION_REGISTRY
 
 
-
 @ATTENTION_REGISTRY.register()
 class CrossAttention(nn.Module):
     """
@@ -30,14 +29,20 @@ def __init__(self, cls_dim, patch_dim, num_heads=8, head_dim=64):
         inner_dim = num_heads * head_dim
         self.num_heads = num_heads
         self.scale = head_dim ** -0.5
-        self.fl = nn.Linear(cls_dim, patch_dim) if cls_dim != patch_dim else nn.Identity()
+        self.fl = (
+            nn.Linear(cls_dim, patch_dim) if cls_dim != patch_dim else nn.Identity()
+        )
 
-        self.gl = nn.Linear(patch_dim, cls_dim) if patch_dim != cls_dim else nn.Identity()
+        self.gl = (
+            nn.Linear(patch_dim, cls_dim) if patch_dim != cls_dim else nn.Identity()
+        )
 
         self.to_k = nn.Linear(patch_dim, inner_dim)
         self.to_v = nn.Linear(patch_dim, inner_dim)
         self.to_q = nn.Linear(patch_dim, inner_dim)
-        self.cls_project = nn.Linear(inner_dim, patch_dim) if inner_dim != patch_dim else nn.Identity()
+        self.cls_project = (
+            nn.Linear(inner_dim, patch_dim) if inner_dim != patch_dim else nn.Identity()
+        )
 
         self.attend = nn.Softmax(dim=-1)
 
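The diff only rewrites how three conditional projections are expressed (a pre-commit formatter, Black-style, wraps the long ternaries in parentheses), so behavior is unchanged. A minimal sketch of how to sanity-check that, assuming only what the hunks above show: the constructor signature CrossAttention(cls_dim, patch_dim, num_heads=8, head_dim=64), the module path matching the changed file, and the attribute names fl, gl, and cls_project.

```python
# Not part of this commit: verifies the reformatted conditionals still
# build the same modules. Assumes only what the diff itself shows.
import torch.nn as nn

from vformer.attention.cross import CrossAttention

# When cls_dim == patch_dim, fl and gl collapse to nn.Identity().
attn = CrossAttention(cls_dim=64, patch_dim=64)
assert isinstance(attn.fl, nn.Identity)
assert isinstance(attn.gl, nn.Identity)

# When the dims differ, they become learned projections between the two spaces.
attn = CrossAttention(cls_dim=64, patch_dim=128)
assert isinstance(attn.fl, nn.Linear)  # cls_dim -> patch_dim
assert isinstance(attn.gl, nn.Linear)  # patch_dim -> cls_dim

# inner_dim = num_heads * head_dim = 512 with the defaults, so with
# patch_dim != 512 cls_project is also a Linear rather than an Identity.
assert isinstance(attn.cls_project, nn.Linear)
```

The forward pass sits outside this diff, so the sketch deliberately stops at construction.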
