
Commit

AndreaLombax committed Aug 8, 2022
1 parent 39d439e commit bb19237
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions modelv2/SegformerSkipDecodeHead.py
@@ -19,13 +19,13 @@ def forward(self, hidden_states: Tensor):
class SegformerSkipDecodeHead(SegformerPreTrainedModel):
def __init__(self, config):
super().__init__(config)
print("Using SkipDecodeHead")
# linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size
mlps = []
for i in range(config.num_encoder_blocks):
mlp = SegformerMLP(config, input_dim=config.hidden_sizes[i])
mlps.append(mlp)
self.linear_c = nn.ModuleList(mlps)
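# Each MLP projects hidden_sizes[i] channels to config.decoder_hidden_size, so the
# four feature maps can later be fused by the same convolution.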
#self.linear_c = nn.ModuleList(mlps.reverse())

# the following 3 layers implement the ConvModule of the original implementation
self.linear_fuse = nn.Conv2d(
@@ -35,6 +35,7 @@ def __init__(self, config):
bias=False,
)

# This layer is used to fuse only 2 hidden states at a time, rather than all num_encoder_blocks (four) hidden states
self.linear_fuse_2_hidden_states = nn.Conv2d(
in_channels=config.decoder_hidden_size * 2,
out_channels=config.decoder_hidden_size,
@@ -54,12 +55,6 @@ def forward(self, encoder_hidden_states):
batch_size = encoder_hidden_states[-1].shape[0]

all_hidden_states = ()
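# all_hidden_states temporarily holds the feature map that is fused with the hidden
# state processed on the next iteration; the loop below presumably walks the encoder
# hidden states from the coarsest block (idx == 3) down to the finest (idx == 0).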
#print(encoder_hidden_states[0].shape)
#print(encoder_hidden_states[1].shape)
#print(encoder_hidden_states[2].shape)
#print(encoder_hidden_states[3].shape)
#print(encoder_hidden_states[4].shape)
#input()
# MY VERSION

#reversed(encoder_hidden_states)
@@ -84,19 +79,19 @@
)

if idx==3:
# If it is the last hidden_state (with size [8, 256, 16, 16])
# print(encoder_hidden_state.shape) # [8, 256, 16, 16]
# 1. First, the multi-level features Fi from the MiT encoder go through an MLP layer to unify the channel dimension
height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
encoder_hidden_state = mlp(encoder_hidden_state)
#print(encoder_hidden_state.shape)
# print(encoder_hidden_state.shape) # [8, 256, 256]
encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
encoder_hidden_state = encoder_hidden_state.reshape(batch_size, -1, height, width)
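# The MLP returns a flattened (batch, height*width, decoder_hidden_size) tensor, so the
# permute + reshape above restore the (batch, decoder_hidden_size, height, width) layout.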
# Starting from the last one... e.g. H/32 x W/32
# 2. Features are upsampled to the previous encoder block size
encoder_hidden_state = nn.functional.interpolate(
encoder_hidden_state, size=encoder_hidden_states[idx-1].size()[2:], mode="bilinear", align_corners=False
)


# print(encoder_hidden_state.shape)
all_hidden_states += (encoder_hidden_state,)
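# Keep the upsampled coarsest feature; it is fused with the next (finer) encoder
# hidden state on the following iteration.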
else:
# 1. First, the multi-level features Fi from the MiT encoder go through an MLP layer to unify the channel dimension
@@ -107,6 +102,8 @@ def forward(self, encoder_hidden_states):
encoder_hidden_state = encoder_hidden_state.reshape(batch_size, -1, height, width)

all_hidden_states += (encoder_hidden_state,)
# Now we have 2 hidden states of the same spatial size: the previous one (upsampled) and the current one (not upsampled).
# Both have the same channel dimension (otherwise they could not be fused).
#print(all_hidden_states[0].shape)
#print(all_hidden_states[1].shape)
#fuse the concatenated features
@@ -115,16 +112,22 @@
hidden_states = self.activation(hidden_states)
fused_hidden_states = self.dropout(hidden_states)
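# The two stored hidden states are presumably concatenated along the channel axis and
# passed through linear_fuse_2_hidden_states (in_channels = 2 * decoder_hidden_size),
# which brings the result back to decoder_hidden_size before the activation and dropout.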

#print("fused: ", fused_hidden_states.shape)
# print("fused: ", fused_hidden_states.shape) # the shape is the same, e.g.
# [8, 256, 32, 32] fused with [8, 256, 32, 32]
# gives [8, 256, 32, 32]


if idx!=0:
# If it is not the first hidden_state (from the first encoder block)
#print(idx)
# 2. Features are upsampled to the previous encoder block size
# UPSAMPLE THE FUSED HIDDEN STATE
upsampled_hidden_states = nn.functional.interpolate(
fused_hidden_states, size=encoder_hidden_states[idx-1].size()[2:], mode="bilinear", align_corners=False
)
#print("upsampled: ", upsampled_hidden_states.shape)
all_hidden_states = ()
# We only upsample here so the size matches the hidden state processed on the next iteration
all_hidden_states += (upsampled_hidden_states,)

logits = self.classifier(fused_hidden_states)
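# The logits stay at the resolution of the first encoder block's feature map (typically
# H/4 x W/4 for SegFormer); callers usually upsample them to the input size afterwards.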
