diff --git a/transformer_models/ViT.py b/transformer_models/ViT.py
index a1d8920..05c92c8 100644
--- a/transformer_models/ViT.py
+++ b/transformer_models/ViT.py
@@ -1,4 +1,6 @@
+from typing import Tuple
 import torch
+from torch.functional import Tensor
 import torch.nn as nn
 import torch.nn.functional as F
 from .decoder import Decoder, DecoderLayer
@@ -10,10 +12,86 @@
     LearnedPositionalEncoding,
     ShiftingLearnedPositionalEncoding,
 )
+import continual as co
 
 __all__ = ["ViT_B16", "ViT_B32", "ViT_L16", "ViT_L32", "ViT_H14"]
 
 
+def CoVisionTransformer(
+    args,
+    img_dim,
+    patch_dim,
+    out_dim,
+    embedding_dim,
+    num_heads,
+    num_layers,
+    hidden_dim,
+    dropout_rate=0.0,
+    attn_dropout_rate=0.0,
+    use_representation=True,
+    conv_patch_representation=False,
+    positional_encoding_type="learned",
+    with_camera=True,
+    with_motion=True,
+    num_channels=3072,
+):
+
+    assert embedding_dim % num_heads == 0
+    assert img_dim % patch_dim == 0
+
+    num_patches = int(img_dim // patch_dim)
+    seq_length = num_patches  # no class token
+    flatten_dim = patch_dim * patch_dim * num_channels
+
+    linear_encoding = nn.Linear(flatten_dim, embedding_dim)
+    if positional_encoding_type == "learned":
+        position_encoding = LearnedPositionalEncoding(
+            seq_length, embedding_dim, seq_length
+        )
+    elif positional_encoding_type == "fixed":
+        position_encoding = FixedPositionalEncoding(
+            embedding_dim,
+        )
+    if positional_encoding_type == "shifting_learned":
+        position_encoding = ShiftingLearnedPositionalEncoding(
+            2 * seq_length, embedding_dim, seq_length
+        )
+    print("position encoding :", positional_encoding_type)
+
+    pe_dropout = nn.Dropout(p=dropout_rate)
+
+    encoder = CoTransformerModel(
+        embedding_dim,
+        num_layers,
+        num_heads,
+        hidden_dim,
+        dropout_rate,
+        attn_dropout_rate,
+    )
+    pre_head_ln = nn.LayerNorm(embedding_dim)
+    mlp_head = nn.Linear(hidden_dim, out_dim)
+
+    def concat_inputs(inputs: Tuple[Tensor, Tensor]) -> Tensor:
+        sequence_input_rgb, sequence_input_flow = inputs
+        if with_camera and with_motion:
+            x = torch.cat((sequence_input_rgb, sequence_input_flow), 2)
+        elif with_camera:
+            x = sequence_input_rgb
+        elif with_motion:
+            x = sequence_input_flow
+        return x
+
+    return co.Sequential(
+        co.Lambda(concat_inputs),
+        linear_encoding,
+        position_encoding,
+        pe_dropout,
+        encoder,
+        pre_head_ln,
+        mlp_head,
+    )
+
+
 class VisionTransformer_v3(nn.Module):
     def __init__(
         self,
@@ -72,15 +150,7 @@ def __init__(
 
         self.pe_dropout = nn.Dropout(p=self.dropout_rate)
 
-        # self.encoder = TransformerModel(
-        #     embedding_dim,
-        #     num_layers,
-        #     num_heads,
-        #     hidden_dim,
-        #     self.dropout_rate,
-        #     self.attn_dropout_rate,
-        # )
-        self.encoder = CoTransformerModel(
+        self.encoder = TransformerModel(
             embedding_dim,
             num_layers,
             num_heads,
@@ -103,15 +173,6 @@ def __init__(
         self.mlp_head = nn.Linear(hidden_dim, out_dim)
 
         if self.conv_patch_representation:
-            # self.conv_x = nn.Conv2d(
-            #     self.num_channels,
-            #     self.embedding_dim,
-            #     kernel_size=(self.patch_dim, self.patch_dim),
-            #     stride=(self.patch_dim, self.patch_dim),
-            #     padding=self._get_padding(
-            #         'VALID', (self.patch_dim, self.patch_dim),
-            #     ),
-            # )
             self.conv_x = nn.Conv1d(
                 self.num_channels,
                 self.embedding_dim,
@@ -126,57 +187,11 @@ def __init__(
             self.conv_x = None
 
         self.to_cls_token = nn.Identity()
-
-        # Decoder
-        # factor = 1  # 5
-        # dropout = args.decoder_attn_dropout_rate
-        # # d_model = args.decoder_embedding_dim
-        # n_heads = args.decoder_num_heads
-        # d_layers = args.decoder_layers
-        # d_ff = (
-        #     args.decoder_embedding_dim_out
-        # )  # args.decoder_embedding_dim_out or 4*args.decoder_embedding_dim None
-        # activation = "gelu"  # 'gelu'
-        # self.decoder = Decoder(
-        #     [
-        #         DecoderLayer(
-        #             AttentionLayer(
-        #                 FullAttention(True, factor, attention_dropout=dropout),  # True
-        #                 d_model,
-        #                 n_heads,
-        #             ),  # ProbAttention FullAttention
-        #             AttentionLayer(
-        #                 FullAttention(
-        #                     False, factor, attention_dropout=dropout
-        #                 ),  # False
-        #                 d_model,
-        #                 n_heads,
-        #             ),
-        #             d_model,
-        #             d_ff,
-        #             dropout=dropout,
-        #             activation=activation,
-        #         )
-        #         for l in range(d_layers)
-        #     ],
-        #     norm_layer=torch.nn.LayerNorm(d_model),
-        # )
-        # self.decoder_cls_token = nn.Parameter(torch.zeros(1, args.query_num, d_model))
-        # if positional_encoding_type == "learned":
-        #     self.decoder_position_encoding = LearnedPositionalEncoding(
-        #         args.query_num, self.embedding_dim, args.query_num
-        #     )
-        # elif positional_encoding_type == "fixed":
-        #     self.decoder_position_encoding = FixedPositionalEncoding(
-        #         self.embedding_dim,
-        #     )
         print("position decoding :", positional_encoding_type)
         self.classifier = nn.Linear(d_model, out_dim)
         self.after_dropout = nn.Dropout(p=self.dropout_rate)
-        # self.merge_fc = nn.Linear(d_model, 1)
-        # self.merge_sigmoid = nn.Sigmoid()
-
-    def forward(self, sequence_input_rgb, sequence_input_flow):
+    def forward(self, inputs: Tuple[Tensor, Tensor]):
+        sequence_input_rgb, sequence_input_flow = inputs
         if self.with_camera and self.with_motion:
             x = torch.cat((sequence_input_rgb, sequence_input_flow), 2)
         elif self.with_camera:
@@ -185,35 +200,16 @@ def forward(self, sequence_input_rgb, sequence_input_flow):
             x = sequence_input_flow
 
         x = self.linear_encoding(x)
-        # cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
-        # x = torch.cat((cls_tokens, x), dim=1)
-        # x = torch.cat((x, cls_tokens), dim=1)
         x = self.position_encoding(x)
-        x = self.pe_dropout(x)  # not delete
+        x = self.pe_dropout(x)
 
         # apply transformer
         x = self.encoder(x)
         x = self.pre_head_ln(x)  # [128, 33, 1024]
-        # x = self.after_dropout(x)  # add
-        # decoder
-        # decoder_cls_token = self.decoder_cls_token.expand(x.shape[0], -1, -1)
-        # # decoder_cls_token = self.after_dropout(decoder_cls_token)  # add
-        # # decoder_cls_token = self.decoder_position_encoding(decoder_cls_token)  # [128, 8, 1024]
-        # dec = self.decoder(decoder_cls_token, x)  # [128, 8, 1024]
-        # dec = self.after_dropout(dec)  # add
-        # # merge_atte = self.merge_sigmoid(self.merge_fc(dec))  # [128, 8, 1]
-        # # dec_for_token = (merge_atte*dec).sum(dim=1)  # [128, 1024]
-        # # dec_for_token = (merge_atte*dec).sum(dim=1)/(merge_atte.sum(dim=-2) + 0.0001)
-        # dec_for_token = dec.mean(dim=1)
-        # # dec_for_token = dec.max(dim=1)[0]
-        # dec_cls_out = self.classifier(dec)
-        # # set_trace()
-        # # x = self.to_cls_token(x[:, 0])
-        # x = torch.cat((self.to_cls_token(x[:, -1]), dec_for_token), dim=1)
         x = self.mlp_head(x)
         # x = F.log_softmax(x, dim=-1)
 
-        return x[:, -1]  # x , dec_cls_out
+        return x[:, -1]  # x
 
     def _get_padding(self, padding_type, kernel_size):
         assert padding_type in ["SAME", "VALID"]
@@ -233,7 +229,7 @@ def ViT_B16(dataset="imagenet"):
         out_dim = 10
         patch_dim = 4
 
-    return VisionTransformer(
+    return VisionTransformer_v3(
         img_dim=img_dim,
         patch_dim=patch_dim,
         out_dim=out_dim,
@@ -260,7 +256,7 @@ def ViT_B32(dataset="imagenet"):
         out_dim = 10
         patch_dim = 4
 
-    return VisionTransformer(
+    return VisionTransformer_v3(
         img_dim=img_dim,
         patch_dim=patch_dim,
         out_dim=out_dim,
@@ -287,7 +283,7 @@ def ViT_L16(dataset="imagenet"):
         out_dim = 10
         patch_dim = 4
 
-    return VisionTransformer(
+    return VisionTransformer_v3(
         img_dim=img_dim,
         patch_dim=patch_dim,
         out_dim=out_dim,
@@ -314,7 +310,7 @@ def ViT_L32(dataset="imagenet"):
         out_dim = 10
         patch_dim = 4
 
-    return VisionTransformer(
+    return VisionTransformer_v3(
         img_dim=img_dim,
         patch_dim=patch_dim,
         out_dim=out_dim,
@@ -341,7 +337,7 @@ def ViT_H14(dataset="imagenet"):
         out_dim = 10
         patch_dim = 4
 
-    return VisionTransformer(
+    return VisionTransformer_v3(
         img_dim=img_dim,
         patch_dim=patch_dim,
         out_dim=out_dim,
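For context, the patch above adds a purely functional factory, CoVisionTransformer, which assembles the two-stream ViT as a continual-inference co.Sequential pipeline (concatenation, linear encoding, positional encoding, dropout, CoTransformerModel encoder, layer norm, classification head), while VisionTransformer_v3 reverts to the plain TransformerModel encoder and takes its two streams as a single tuple argument. The snippet below is a minimal usage sketch, not part of the patch: all feature sizes, the class count, and the two-stream input shapes are hypothetical placeholders, and it assumes only the continual package imported as co in the patch plus the modules already defined in this file. The continual pipeline is intended to also allow frame-by-frame processing via continual-inference's step-wise API, which is the motivation for the rewrite.

# Usage sketch for the CoVisionTransformer factory added above.
# All dimensions below are hypothetical placeholders.
import torch

model = CoVisionTransformer(
    args=None,            # accepted but unused inside the factory body
    img_dim=64,           # 64 feature vectors per clip (with patch_dim=1)
    patch_dim=1,
    out_dim=22,           # hypothetical number of action classes
    embedding_dim=1024,
    num_heads=8,
    num_layers=3,
    hidden_dim=1024,
    num_channels=3072,    # RGB (2048-d) + flow (1024-d) features concatenated
)

rgb = torch.randn(2, 64, 2048)   # (batch, time, rgb_features), hypothetical
flow = torch.randn(2, 64, 1024)  # (batch, time, flow_features), hypothetical

# Clip-wise inference: concat_inputs fuses the two streams, then the linear
# encoding, positional encoding, encoder, and classification head are applied.
scores = model((rgb, flow))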