diff --git a/transformer_models/Transformer.py b/transformer_models/Transformer.py
index cf66e9a..7a7f1c2 100644
--- a/transformer_models/Transformer.py
+++ b/transformer_models/Transformer.py
@@ -1,5 +1,7 @@
 from torch import nn
 from .Attention import SelfAttention
+import continual as co
+from continual_transformers import CoReMultiheadAttention, CoSiMultiheadAttention
 
 
 class Residual(nn.Module):
@@ -80,6 +82,51 @@ def forward(self, x):
         return self.net(x)
 
 
+def CoTransformerModel(
+    dim,
+    depth,
+    heads,
+    mlp_dim,
+    dropout_rate=0.1,
+    attn_dropout_rate=0.1,
+    sequence_len=64,
+):
+    assert depth in {1, 2}
+
+    layers = []
+
+    for d in range(depth):
+        CoMHA = (
+            CoReMultiheadAttention if d == 0 and depth == 2 else CoSiMultiheadAttention
+        )
+        layers.extend(
+            [
+                co.Residual(
+                    co.Sequential(
+                        co.forward_stepping(nn.LayerNorm(dim)),
+                        CoMHA(
+                            dim,
+                            heads,
+                            attn_dropout_rate,
+                            sequence_len=sequence_len,
+                            forward_returns_attn_mask=False,
+                        ),
+                        nn.Dropout(p=dropout_rate),
+                    )
+                ),
+                co.Residual(
+                    co.Sequential(
+                        co.forward_stepping(nn.LayerNorm(dim)),
+                        co.forward_stepping(FeedForward(dim, mlp_dim, dropout_rate)),
+                    )
+                ),
+            ]
+        )
+
+    net = co.Sequential(*layers)
+    return net
+
+
 def _register_ptflops():
     try:
         from ptflops import flops_counter as fc
diff --git a/transformer_models/ViT.py b/transformer_models/ViT.py
index 4dec168..a1d8920 100644
--- a/transformer_models/ViT.py
+++ b/transformer_models/ViT.py
@@ -3,7 +3,7 @@
 import torch.nn.functional as F
 from .decoder import Decoder, DecoderLayer
 from .attn import FullAttention, ProbAttention, AttentionLayer
-from .Transformer import TransformerModel
+from .Transformer import TransformerModel, CoTransformerModel
 from ipdb import set_trace
 from .PositionalEncoding import (
     FixedPositionalEncoding,
@@ -72,7 +72,15 @@
 
         self.pe_dropout = nn.Dropout(p=self.dropout_rate)
 
-        self.encoder = TransformerModel(
+        # self.encoder = TransformerModel(
+        #     embedding_dim,
+        #     num_layers,
+        #     num_heads,
+        #     hidden_dim,
+        #     self.dropout_rate,
+        #     self.attn_dropout_rate,
+        # )
+        self.encoder = CoTransformerModel(
             embedding_dim,
             num_layers,
             num_heads,