Skip to content

Commit

Permalink
Merge pull request #174 from wufei-png/wf/einsum
Browse files Browse the repository at this point in the history
replace einsum() with other ops
  • Loading branch information
wondervictor authored Mar 26, 2024
2 parents 3264b61 + 04954ff commit c960c05
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions yolo_world/models/layers/yolo_bricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ def __init__(self,
with_scale: bool = False,
num_feats: int = 3,
num_heads: int = 8,
pool_size: int = 3):
pool_size: int = 3,
use_einsum: bool = True):
super().__init__()

self.text_channels = text_channels
Expand All @@ -169,7 +170,7 @@ def __init__(self,
self.num_feats = num_feats
self.head_channels = embed_channels // num_heads
self.pool_size = pool_size

self.use_einsum = use_einsum
if with_scale:
self.scale = nn.Parameter(torch.tensor([0.]), requires_grad=True)
else:
Expand Down Expand Up @@ -209,12 +210,21 @@ def forward(self, text_features, image_features):
q = q.reshape(B, -1, self.num_heads, self.head_channels)
k = k.reshape(B, -1, self.num_heads, self.head_channels)
v = v.reshape(B, -1, self.num_heads, self.head_channels)
if self.use_einsum:
attn_weight = torch.einsum('bnmc,bkmc->bmnk', q, k)
else:
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 3, 1)
attn_weight = torch.matmul(q, k)

attn_weight = torch.einsum('bnmc,bkmc->bmnk', q, k)
attn_weight = attn_weight / (self.head_channels**0.5)
attn_weight = F.softmax(attn_weight, dim=-1)

x = torch.einsum('bmnk,bkmc->bnmc', attn_weight, v)
if self.use_einsum:
x = torch.einsum('bmnk,bkmc->bnmc', attn_weight, v)
else:
v = v.permute(0, 2, 1, 3)
x = torch.matmul(attn_weight, v)
x = x.permute(0, 2, 1, 3)
x = self.proj(x.reshape(B, -1, self.embed_channels))
return x * self.scale + text_features

Expand Down

0 comments on commit c960c05

Please sign in to comment.