Add CoVisionTransformer and clean up VisionTransformer_v3
LukasHedegaard committed Dec 13, 2021
1 parent 168a55a commit 24209db
Showing 1 changed file with 88 additions and 92 deletions.
180 changes: 88 additions & 92 deletions transformer_models/ViT.py
@@ -1,4 +1,6 @@
from typing import Tuple
import torch
from torch.functional import Tensor
import torch.nn as nn
import torch.nn.functional as F
from .decoder import Decoder, DecoderLayer
@@ -10,10 +12,86 @@
    LearnedPositionalEncoding,
    ShiftingLearnedPositionalEncoding,
)
import continual as co

__all__ = ["ViT_B16", "ViT_B32", "ViT_L16", "ViT_L32", "ViT_H14"]


def CoVisionTransformer(
    args,
    img_dim,
    patch_dim,
    out_dim,
    embedding_dim,
    num_heads,
    num_layers,
    hidden_dim,
    dropout_rate=0.0,
    attn_dropout_rate=0.0,
    use_representation=True,
    conv_patch_representation=False,
    positional_encoding_type="learned",
    with_camera=True,
    with_motion=True,
    num_channels=3072,
):

    assert embedding_dim % num_heads == 0
    assert img_dim % patch_dim == 0

    num_patches = int(img_dim // patch_dim)
    seq_length = num_patches  # no class token
    flatten_dim = patch_dim * patch_dim * num_channels
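    # With the defaults (patch_dim=1, num_channels=3072), flatten_dim equals the
    # per-step feature size produced by concat_inputs below; the 3072 is assumed
    # to be the concatenated RGB + flow feature size.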

    linear_encoding = nn.Linear(flatten_dim, embedding_dim)
    if positional_encoding_type == "learned":
        position_encoding = LearnedPositionalEncoding(
            seq_length, embedding_dim, seq_length
        )
    elif positional_encoding_type == "fixed":
        position_encoding = FixedPositionalEncoding(
            embedding_dim,
        )
    elif positional_encoding_type == "shifting_learned":
        position_encoding = ShiftingLearnedPositionalEncoding(
            2 * seq_length, embedding_dim, seq_length
        )
    print("position encoding :", positional_encoding_type)

    pe_dropout = nn.Dropout(p=dropout_rate)

    encoder = CoTransformerModel(
        embedding_dim,
        num_layers,
        num_heads,
        hidden_dim,
        dropout_rate,
        attn_dropout_rate,
    )
    pre_head_ln = nn.LayerNorm(embedding_dim)
    mlp_head = nn.Linear(hidden_dim, out_dim)

    def concat_inputs(inputs: Tuple[Tensor, Tensor]) -> Tensor:
        sequence_input_rgb, sequence_input_flow = inputs
        if with_camera and with_motion:
            x = torch.cat((sequence_input_rgb, sequence_input_flow), 2)
        elif with_camera:
            x = sequence_input_rgb
        elif with_motion:
            x = sequence_input_flow
        return x
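    # The (rgb, flow) inputs are assumed to be (batch, seq_len, feature) tensors,
    # so dim 2 above is the feature dimension, and the concatenated size must
    # equal flatten_dim for linear_encoding to apply.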

    return co.Sequential(
        co.Lambda(concat_inputs),
        linear_encoding,
        position_encoding,
        pe_dropout,
        encoder,
        pre_head_ln,
        mlp_head,
    )
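
# Illustrative usage sketch (an assumption, not part of this commit): the values
# below mirror the defaults, with hidden_dim == embedding_dim so that the linear
# head matches the LayerNorm'ed encoder output.
#
#   net = CoVisionTransformer(
#       args=None, img_dim=64, patch_dim=1, out_dim=22, embedding_dim=1024,
#       num_heads=8, num_layers=3, hidden_dim=1024,
#   )
#   rgb, flow = torch.randn(2, 64, 2048), torch.randn(2, 64, 1024)
#   y = net((rgb, flow))  # regular clip-wise forward
#
# Being a `co.Sequential`, the model is also intended for step-wise streaming
# inference (e.g., `net.forward_step(...)` in the continual-inference library).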


class VisionTransformer_v3(nn.Module):
    def __init__(
        self,
@@ -72,15 +150,7 @@ def __init__(

        self.pe_dropout = nn.Dropout(p=self.dropout_rate)

        # self.encoder = TransformerModel(
        #     embedding_dim,
        #     num_layers,
        #     num_heads,
        #     hidden_dim,
        #     self.dropout_rate,
        #     self.attn_dropout_rate,
        # )
        self.encoder = CoTransformerModel(
        self.encoder = TransformerModel(
            embedding_dim,
            num_layers,
            num_heads,
@@ -103,15 +173,6 @@ def __init__(
        self.mlp_head = nn.Linear(hidden_dim, out_dim)

        if self.conv_patch_representation:
            # self.conv_x = nn.Conv2d(
            #     self.num_channels,
            #     self.embedding_dim,
            #     kernel_size=(self.patch_dim, self.patch_dim),
            #     stride=(self.patch_dim, self.patch_dim),
            #     padding=self._get_padding(
            #         'VALID', (self.patch_dim, self.patch_dim),
            #     ),
            # )
            self.conv_x = nn.Conv1d(
                self.num_channels,
                self.embedding_dim,
@@ -126,57 +187,11 @@ def __init__(
            self.conv_x = None

        self.to_cls_token = nn.Identity()

        # Decoder
        # factor = 1  # 5
        # dropout = args.decoder_attn_dropout_rate
        # # d_model = args.decoder_embedding_dim
        # n_heads = args.decoder_num_heads
        # d_layers = args.decoder_layers
        # d_ff = (
        #     args.decoder_embedding_dim_out
        # )  # args.decoder_embedding_dim_out or 4*args.decoder_embedding_dim None
        # activation = "gelu"  # 'gelu'
        # self.decoder = Decoder(
        #     [
        #         DecoderLayer(
        #             AttentionLayer(
        #                 FullAttention(True, factor, attention_dropout=dropout),  # True
        #                 d_model,
        #                 n_heads,
        #             ),  # ProbAttention FullAttention
        #             AttentionLayer(
        #                 FullAttention(
        #                     False, factor, attention_dropout=dropout
        #                 ),  # False
        #                 d_model,
        #                 n_heads,
        #             ),
        #             d_model,
        #             d_ff,
        #             dropout=dropout,
        #             activation=activation,
        #         )
        #         for l in range(d_layers)
        #     ],
        #     norm_layer=torch.nn.LayerNorm(d_model),
        # )
        # self.decoder_cls_token = nn.Parameter(torch.zeros(1, args.query_num, d_model))
        # if positional_encoding_type == "learned":
        #     self.decoder_position_encoding = LearnedPositionalEncoding(
        #         args.query_num, self.embedding_dim, args.query_num
        #     )
        # elif positional_encoding_type == "fixed":
        #     self.decoder_position_encoding = FixedPositionalEncoding(
        #         self.embedding_dim,
        #     )
        print("position decoding :", positional_encoding_type)
        self.classifier = nn.Linear(d_model, out_dim)
        self.after_dropout = nn.Dropout(p=self.dropout_rate)
        # self.merge_fc = nn.Linear(d_model, 1)
        # self.merge_sigmoid = nn.Sigmoid()

    def forward(self, sequence_input_rgb, sequence_input_flow):
    def forward(self, inputs: Tuple[Tensor, Tensor]):
        sequence_input_rgb, sequence_input_flow = inputs
        if self.with_camera and self.with_motion:
            x = torch.cat((sequence_input_rgb, sequence_input_flow), 2)
        elif self.with_camera:
@@ -185,35 +200,16 @@ def forward(self, sequence_input_rgb, sequence_input_flow):
            x = sequence_input_flow

        x = self.linear_encoding(x)
        # cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        # x = torch.cat((cls_tokens, x), dim=1)
        # x = torch.cat((x, cls_tokens), dim=1)
        x = self.position_encoding(x)
        x = self.pe_dropout(x)  # not delete
        x = self.pe_dropout(x)

        # apply transformer
        x = self.encoder(x)
        x = self.pre_head_ln(x)  # [128, 33, 1024]
        # x = self.after_dropout(x)  # add
        # decoder
        # decoder_cls_token = self.decoder_cls_token.expand(x.shape[0], -1, -1)
        # # decoder_cls_token = self.after_dropout(decoder_cls_token)  # add
        # # decoder_cls_token = self.decoder_position_encoding(decoder_cls_token)  # [128, 8, 1024]
        # dec = self.decoder(decoder_cls_token, x)  # [128, 8, 1024]
        # dec = self.after_dropout(dec)  # add
        # # merge_atte = self.merge_sigmoid(self.merge_fc(dec))  # [128, 8, 1]
        # # dec_for_token = (merge_atte*dec).sum(dim=1)  # [128, 1024]
        # # dec_for_token = (merge_atte*dec).sum(dim=1)/(merge_atte.sum(dim=-2) + 0.0001)
        # dec_for_token = dec.mean(dim=1)
        # # dec_for_token = dec.max(dim=1)[0]
        # dec_cls_out = self.classifier(dec)
        # # set_trace()
        # # x = self.to_cls_token(x[:, 0])
        # x = torch.cat((self.to_cls_token(x[:, -1]), dec_for_token), dim=1)
        x = self.mlp_head(x)
        # x = F.log_softmax(x, dim=-1)

        return x[:, -1]  # x , dec_cls_out
        return x[:, -1]  # x
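        # Note: only the last time step's prediction is returned, which fits the
        # streaming/online setting that the continual CoVisionTransformer above
        # also targets.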

    def _get_padding(self, padding_type, kernel_size):
        assert padding_type in ["SAME", "VALID"]
@@ -233,7 +229,7 @@ def ViT_B16(dataset="imagenet"):
    out_dim = 10
    patch_dim = 4

    return VisionTransformer(
    return VisionTransformer_v3(
        img_dim=img_dim,
        patch_dim=patch_dim,
        out_dim=out_dim,
@@ -260,7 +256,7 @@ def ViT_B32(dataset="imagenet"):
    out_dim = 10
    patch_dim = 4

    return VisionTransformer(
    return VisionTransformer_v3(
        img_dim=img_dim,
        patch_dim=patch_dim,
        out_dim=out_dim,
@@ -287,7 +283,7 @@ def ViT_L16(dataset="imagenet"):
    out_dim = 10
    patch_dim = 4

    return VisionTransformer(
    return VisionTransformer_v3(
        img_dim=img_dim,
        patch_dim=patch_dim,
        out_dim=out_dim,
@@ -314,7 +310,7 @@ def ViT_L32(dataset="imagenet"):
    out_dim = 10
    patch_dim = 4

    return VisionTransformer(
    return VisionTransformer_v3(
        img_dim=img_dim,
        patch_dim=patch_dim,
        out_dim=out_dim,
@@ -341,7 +337,7 @@ def ViT_H14(dataset="imagenet"):
    out_dim = 10
    patch_dim = 4

    return VisionTransformer(
    return VisionTransformer_v3(
        img_dim=img_dim,
        patch_dim=patch_dim,
        out_dim=out_dim,
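A minimal usage sketch for the factories above (assuming the elided tails mirror
the visible pattern and return VisionTransformer_v3 instances):

    model = ViT_B16(dataset="imagenet")
    logits = model((sequence_input_rgb, sequence_input_flow))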
