From bb1923713caa47380cebc9b555b49af9eda9d3bb Mon Sep 17 00:00:00 2001
From: Andrea Lombardi
Date: Mon, 8 Aug 2022 10:57:29 +0200
Subject: [PATCH] comments

---
 modelv2/SegformerSkipDecodeHead.py | 27 +++++++++++++++------------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/modelv2/SegformerSkipDecodeHead.py b/modelv2/SegformerSkipDecodeHead.py
index 1b7a68b..ed716ca 100644
--- a/modelv2/SegformerSkipDecodeHead.py
+++ b/modelv2/SegformerSkipDecodeHead.py
@@ -19,13 +19,13 @@ def forward(self, hidden_states: Tensor):
 
 class SegformerSkipDecodeHead(SegformerPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
+        print("Using SkipDecodeHead")
         # linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size
         mlps = []
         for i in range(config.num_encoder_blocks):
             mlp = SegformerMLP(config, input_dim=config.hidden_sizes[i])
             mlps.append(mlp)
         self.linear_c = nn.ModuleList(mlps)
-        #self.linear_c = nn.ModuleList(mlps.reverse())
 
         # the following 3 layers implement the ConvModule of the original implementation
         self.linear_fuse = nn.Conv2d(
@@ -35,6 +35,7 @@ def __init__(self, config):
             bias=False,
         )
 
+        # This layer fuses only 2 hidden states at a time, not all num_encoder_blocks (four) hidden states
        self.linear_fuse_2_hidden_states = nn.Conv2d(
             in_channels=config.decoder_hidden_size * 2,
             out_channels=config.decoder_hidden_size,
@@ -54,12 +55,6 @@ def forward(self, encoder_hidden_states):
         batch_size = encoder_hidden_states[-1].shape[0]
 
         all_hidden_states = ()
-        #print(encoder_hidden_states[0].shape)
-        #print(encoder_hidden_states[1].shape)
-        #print(encoder_hidden_states[2].shape)
-        #print(encoder_hidden_states[3].shape)
-        #print(encoder_hidden_states[4].shape)
-        #input()
 
         # MY VERSION
         #reversed(encoder_hidden_states)
@@ -84,19 +79,19 @@ def forward(self, encoder_hidden_states):
             )
 
             if idx==3:
+                # If it is the last hidden_state (e.g. shape [8, 256, 16, 16])
+                # print(encoder_hidden_state.shape) # [8, 256, 16, 16]
                 # 1. First, multi-level features Fi from the MiT encoder go through an MLP layer to unify the channel dimension
                 height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
                 encoder_hidden_state = mlp(encoder_hidden_state)
-                #print(encoder_hidden_state.shape)
+                # print(encoder_hidden_state.shape) # [8, 256, 256]
                 encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
                 encoder_hidden_state = encoder_hidden_state.reshape(batch_size, -1, height, width)
-                # Starting from the last one... e.g. H/32xW/32
                 # 2. Features are upsampled to the previous encoder block size
                 encoder_hidden_state = nn.functional.interpolate(
                     encoder_hidden_state, size=encoder_hidden_states[idx-1].size()[2:], mode="bilinear", align_corners=False
                 )
-
-
+                # print(encoder_hidden_state.shape)
                 #print(encoder_hidden_state.shape)
                 all_hidden_states += (encoder_hidden_state,)
             else:
                 # 1. First, multi-level features Fi from the MiT encoder go through an MLP layer to unify the channel dimension
@@ -107,6 +102,8 @@ def forward(self, encoder_hidden_states):
                 encoder_hidden_state = encoder_hidden_state.reshape(batch_size, -1, height, width)
                 all_hidden_states += (encoder_hidden_state,)
 
+                # Now we have 2 hidden_states of the same spatial size: the previous one (upsampled) and the current one (not upsampled);
+                # both have the same channel dimension (otherwise they couldn't be fused)
                 #print(all_hidden_states[0].shape)
                 #print(all_hidden_states[1].shape)
                 #fuse the concatenated features
@@ -115,16 +112,22 @@ def forward(self, encoder_hidden_states):
                 hidden_states = self.activation(hidden_states)
                 fused_hidden_states = self.dropout(hidden_states)
 
-                #print("fused: ", fused_hidden_states.shape)
+                # print("fused: ", fused_hidden_states.shape) # the shape is unchanged, e.g.
+                # [8, 256, 32, 32] fused with [8, 256, 32, 32]
+                # gives [8, 256, 32, 32]
+
                 if idx!=0:
+                    # If this is not the first hidden_state (the one from the first encoder block)
                     #print(idx)
                     # 2. Features are upsampled to the previous encoder block size
+                    # UPSAMPLE THE FUSED HIDDEN STATE
                     upsampled_hidden_states = nn.functional.interpolate(
                         fused_hidden_states, size=encoder_hidden_states[idx-1].size()[2:], mode="bilinear", align_corners=False
                     )
                     #print("upsampled: ", upsampled_hidden_states.shape)
                     all_hidden_states = ()
+                    # We just upsample here; fusion with the next block's features happens on the next iteration
                     all_hidden_states += (upsampled_hidden_states,)
 
         logits = self.classifier(fused_hidden_states)
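
The comments in this patch describe a pairwise, top-down fusion: each encoder feature map is projected to decoder_hidden_size channels by its MLP, the running state is upsampled to the next block's resolution, and linear_fuse_2_hidden_states merges exactly two same-shaped states. Below is a minimal standalone sketch of that loop. It assumes coarsest-to-finest processing, plain nn.Linear projections in place of SegformerMLP, MiT-B0 channel sizes (32, 64, 160, 256), 19 output labels, and a BatchNorm + ReLU "ConvModule"; all of these are assumptions, not the patched file itself.

import torch
from torch import nn

class SkipFusionSketch(nn.Module):
    # A sketch of the pairwise "skip" fusion; not the actual SegformerSkipDecodeHead.
    def __init__(self, hidden_sizes=(32, 64, 160, 256), decoder_hidden_size=256, num_labels=19):
        super().__init__()
        # one linear projection per encoder block, unifying channels to decoder_hidden_size
        self.linear_c = nn.ModuleList(nn.Linear(h, decoder_hidden_size) for h in hidden_sizes)
        # fuses exactly 2 hidden states: 2 * decoder_hidden_size channels in, decoder_hidden_size out
        self.linear_fuse_2 = nn.Conv2d(2 * decoder_hidden_size, decoder_hidden_size, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(decoder_hidden_size)
        self.activation = nn.ReLU()
        self.classifier = nn.Conv2d(decoder_hidden_size, num_labels, kernel_size=1)

    def project(self, x, mlp):
        # flatten HxW, project the channel dimension, restore the spatial layout
        b, _, h, w = x.shape
        x = mlp(x.flatten(2).transpose(1, 2))          # (B, H*W, C')
        return x.transpose(1, 2).reshape(b, -1, h, w)  # (B, C', H, W)

    def forward(self, encoder_hidden_states):
        # walk from the coarsest feature map (index 3) to the finest (index 0)
        fused = self.project(encoder_hidden_states[3], self.linear_c[3])
        for idx in range(2, -1, -1):
            # upsample the running state to the next-finer block's resolution
            fused = nn.functional.interpolate(
                fused, size=encoder_hidden_states[idx].shape[2:], mode="bilinear", align_corners=False
            )
            current = self.project(encoder_hidden_states[idx], self.linear_c[idx])
            # fuse exactly two same-shaped states, as linear_fuse_2_hidden_states does
            fused = self.activation(self.bn(self.linear_fuse_2(torch.cat([fused, current], dim=1))))
        return self.classifier(fused)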
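
A quick shape check under the same assumptions (batch 8, 512x512 input, so the coarsest map is 16x16 as in the shape comments above):

feats = [torch.randn(8, c, 512 // s, 512 // s) for c, s in zip((32, 64, 160, 256), (4, 8, 16, 32))]
logits = SkipFusionSketch()(feats)
print(logits.shape)  # torch.Size([8, 19, 128, 128])

This differs from the stock SegformerDecodeHead in transformers, which upsamples all four projected states to the finest resolution and fuses them in a single num_encoder_blocks * decoder_hidden_size -> decoder_hidden_size convolution; fusing two states at a time, FPN-style, is exactly what the patched linear_fuse_2_hidden_states layer enables.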