Skip to content

Commit

Permalink
Add Adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
xq141839 authored Sep 27, 2023
1 parent a7079bf commit 1ce6d9d
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions modeling/image_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from typing import Optional, Tuple, Type

from .common import LayerNorm2d, MLPBlock
from .common import LayerNorm2d, MLPBlock, Adapter


# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
Expand Down Expand Up @@ -149,6 +149,7 @@ def __init__(
"""
super().__init__()
self.norm1 = norm_layer(dim)

self.attn = Attention(
dim,
num_heads=num_heads,
Expand All @@ -163,6 +164,13 @@ def __init__(

self.window_size = window_size

#-----------------------------------------------

self.ft = Adapter(dim)
self.MLP_Adapter = Adapter(dim, skip_connect=False)

#-----------------------------------------------

def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
x = self.norm1(x)
Expand All @@ -172,12 +180,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x, pad_hw = window_partition(x, self.window_size)

x = self.attn(x)
#-----------------------------------------------

x = self.ft(x)

#-----------------------------------------------

# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))

x = shortcut + x
x = x + self.mlp(self.norm2(x))

#-----------------------------------------------
xn = self.norm2(x)
x = x + self.mlp(xn) + 0.5 * self.MLP_Adapter(xn)

# x = x + self.mlp(self.norm2(x))
#-----------------------------------------------

return x

Expand Down

0 comments on commit 1ce6d9d

Please sign in to comment.