
Commit

Remove inefficient computation from AttentionPool2d Module (openai#271)

* fix inefficient attention computation

* remove erroneous formatting

* simplified flatten

Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
jenkspt and jongwook authored Jul 21, 2022
1 parent 4d120f3 commit f69a9bc
Showing 1 changed file with 3 additions and 4 deletions.
clip/model.py (7 changes: 3 additions & 4 deletions)

@@ -66,11 +66,11 @@ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim:
         self.num_heads = num_heads
 
     def forward(self, x):
-        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
+        x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
         x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
         x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
         x, _ = F.multi_head_attention_forward(
-            query=x, key=x, value=x,
+            query=x[:1], key=x, value=x,
             embed_dim_to_check=x.shape[-1],
             num_heads=self.num_heads,
             q_proj_weight=self.q_proj.weight,
@@ -88,8 +88,7 @@ def forward(self, x):
             training=self.training,
             need_weights=False
         )
-
-        return x[0]
+        return x.squeeze(0)
 
 
 class ModifiedResNet(nn.Module):
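
The optimization is behavior-preserving: in attention, the output at a given position depends only on that position's query (plus all keys and values), so the old code's full self-attention over all HW+1 tokens computed HW rows of output that `return x[0]` immediately discarded. Querying with `x[:1]` keeps only the pooled token's row. Below is a minimal sketch, not from the repository, that checks both equivalences on random data; the identity projection weights and the bare `F.multi_head_attention_forward` call are illustrative assumptions standing in for CLIP's learned q/k/v/c projections.

```python
# Minimal sketch (assumes identity projections and random input) verifying
# that the commit's changes preserve AttentionPool2d's output.
import torch
import torch.nn.functional as F

N, C, H, W, num_heads = 2, 64, 7, 7, 8
x = torch.randn(N, C, H, W)

# 1) The "simplified flatten": same tensor, simpler call.
old_flat = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
new_flat = x.flatten(start_dim=2)
assert torch.equal(old_flat, new_flat)

# NCHW -> (HW)NC, then prepend the mean token: (HW+1)NC
seq = new_flat.permute(2, 0, 1)
seq = torch.cat([seq.mean(dim=0, keepdim=True), seq], dim=0)

# 2) The single-query attention: identity in/out projections keep this
# self-contained (CLIP uses learned projections instead).
in_proj_weight = torch.cat([torch.eye(C)] * 3, dim=0)
out_proj_weight = torch.eye(C)

old_out, _ = F.multi_head_attention_forward(
    seq, seq, seq, C, num_heads,        # old: all HW+1 positions as queries
    in_proj_weight, None, None, None, False, 0.0,
    out_proj_weight, None, need_weights=False)
new_out, _ = F.multi_head_attention_forward(
    seq[:1], seq, seq, C, num_heads,    # new: only the pooled token as query
    in_proj_weight, None, None, None, False, 0.0,
    out_proj_weight, None, need_weights=False)

assert torch.allclose(old_out[0], new_out.squeeze(0), atol=1e-6)
print("outputs match:", seq.shape[0], "query rows reduced to 1")
```

For the 7x7 feature grid above, the attention score matrix shrinks from (HW+1) x (HW+1) to 1 x (HW+1), about a 50x reduction in that step, while key and value projections are still computed for every position.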
